use std::fmt;
use std::io;
use std::marker::PhantomData;
use bytes::{Buf, Bytes};
use http::header::{HeaderValue, CONNECTION};
use http::{HeaderMap, Method, Version};
use tokio::io::{AsyncRead, AsyncWrite};
use super::io::Buffered;
use super::{Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext, Wants};
use crate::body::DecodedLength;
use crate::common::{task, Pin, Poll, Unpin};
use crate::headers::connection_keep_alive;
use crate::proto::{BodyLength, MessageHead};
/// The 24-byte HTTP/2 client connection preface.
///
/// `Conn::has_h2_prefix` compares the start of the read buffer against it so
/// that a parse failure caused by an HTTP/2 peer can be reported as such.
const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
/// One HTTP/1 connection: a buffered IO object plus the read/write/keep-alive
/// state that drives it.
///
/// `I` is the IO type, `B` the buffer type of outgoing body chunks, and `T` the
/// `Http1Transaction` implementation that selects role-specific behaviour.
pub(crate) struct Conn<I, B, T> {
// Buffered reader/writer wrapped around the IO object.
io: Buffered<I, EncodedBuf<B>>,
// Reading/writing/keep-alive bookkeeping for the current message.
state: State,
// `T` is only used at the type level; no value of `T` is stored.
_marker: PhantomData<fn(T)>,
}
impl<I, B, T> Conn<I, B, T>
where
I: AsyncRead + AsyncWrite + Unpin,
B: Buf,
T: Http1Transaction,
{
pub(crate) fn new(io: I) -> Conn<I, B, T> {
Conn {
io: Buffered::new(io),
state: State {
allow_half_close: false,
cached_headers: None,
error: None,
keep_alive: KA::Busy,
method: None,
#[cfg(feature = "ffi")]
preserve_header_case: false,
title_case_headers: false,
notify_read: false,
reading: Reading::Init,
writing: Writing::Init,
upgrade: None,
version: Version::HTTP_11,
},
_marker: PhantomData,
}
}
/// Forwards the pipeline-flush setting to the buffered IO layer (server builds only).
#[cfg(feature = "server")]
pub(crate) fn set_flush_pipeline(&mut self, enabled: bool) {
self.io.set_flush_pipeline(enabled);
}
/// Forwards the maximum buffer size `max` to the buffered IO layer.
pub(crate) fn set_max_buf_size(&mut self, max: usize) {
self.io.set_max_buf_size(max);
}
/// Forwards an exact read-buffer size `sz` to the buffered IO layer (client builds only).
#[cfg(feature = "client")]
pub(crate) fn set_read_buf_exact_size(&mut self, sz: usize) {
self.io.set_read_buf_exact_size(sz);
}
/// Turns on the `title_case_headers` flag, which `encode_head` passes to the
/// header encoder (client builds only). There is no way to turn it back off.
#[cfg(feature = "client")]
pub(crate) fn set_title_case_headers(&mut self) {
self.state.title_case_headers = true;
}
/// Turns on `allow_half_close` (server builds only). With it set,
/// `mid_message_detect_eof` no longer probes the IO for an early EOF.
#[cfg(feature = "server")]
pub(crate) fn set_allow_half_close(&mut self) {
self.state.allow_half_close = true;
}
/// Consumes the connection and returns what `Buffered::into_inner` yields: the
/// IO object and a `Bytes`.
// NOTE(review): the `Bytes` is presumably the already-read but unparsed
// data — confirm against `Buffered::into_inner`.
pub(crate) fn into_inner(self) -> (I, Bytes) {
self.io.into_inner()
}
/// Takes the pending upgrade handle stored by `on_upgrade`, leaving `None` behind.
pub(crate) fn pending_upgrade(&mut self) -> Option<crate::upgrade::Pending> {
self.state.upgrade.take()
}
/// Returns `true` when the read side is in `Reading::Closed`.
pub(crate) fn is_read_closed(&self) -> bool {
self.state.is_read_closed()
}
/// Returns `true` when the write side is in `Writing::Closed`.
pub(crate) fn is_write_closed(&self) -> bool {
self.state.is_write_closed()
}
/// Whether a new message head may be parsed right now.
///
/// The read side must be in `Reading::Init`. Roles that read first may then
/// always proceed; other roles must already have started writing.
pub(crate) fn can_read_head(&self) -> bool {
    if !matches!(self.state.reading, Reading::Init) {
        return false;
    }
    T::should_read_first() || !matches!(self.state.writing, Writing::Init)
}
/// Whether an incoming body is currently readable, i.e. the read side is in
/// `Reading::Body` or `Reading::Continue`.
pub(crate) fn can_read_body(&self) -> bool {
    matches!(
        self.state.reading,
        Reading::Body(..) | Reading::Continue(..)
    )
}
/// An EOF counts as an error only when the role asks for it
/// (`T::should_error_on_parse_eof`) and the connection is not idle.
fn should_error_on_eof(&self) -> bool {
T::should_error_on_parse_eof() && !self.state.is_idle()
}
/// Returns `true` when the buffered, unparsed bytes begin with the HTTP/2
/// connection preface (`H2_PREFACE`).
fn has_h2_prefix(&self) -> bool {
    // `starts_with` is false for a buffer shorter than the preface, so no
    // separate length check is needed.
    self.io.read_buf().starts_with(H2_PREFACE)
}
/// Polls the buffered IO for the next incoming message head.
///
/// On success yields the parsed head, the decoded body length and the `Wants`
/// flags (upgrade and/or expect-continue), and moves `state.reading` to
/// `KeepAlive`, `Continue` or `Body` depending on the body length and on
/// `expect_continue`. Parse failures are delegated to `on_read_head_error`.
pub(super) fn poll_read_head(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Option<crate::Result<(MessageHead<T::Incoming>, DecodedLength, Wants)>>> {
debug_assert!(self.can_read_head());
trace!("Conn::read_head");
let msg = match ready!(self.io.parse::<T>(
cx,
ParseContext {
cached_headers: &mut self.state.cached_headers,
req_method: &mut self.state.method,
#[cfg(feature = "ffi")]
preserve_header_case: self.state.preserve_header_case,
}
)) {
Ok(msg) => msg,
Err(e) => return self.on_read_head_error(e),
};
debug!("incoming body is {}", msg.decode);
self.state.busy();
// `KA` implements `BitAndAssign<bool>`: a `false` here disables keep-alive.
self.state.keep_alive &= msg.keep_alive;
// Remember the peer's version; `enforce_version` uses it for outgoing heads.
self.state.version = msg.head.version;
let mut wants = if msg.wants_upgrade {
Wants::UPGRADE
} else {
Wants::EMPTY
};
if msg.decode == DecodedLength::ZERO {
// No body to read, so the read side of this message is already complete.
if msg.expect_continue {
debug!("ignoring expect-continue since body is empty");
}
self.state.reading = Reading::KeepAlive;
if !T::should_read_first() {
self.try_keep_alive(cx);
}
} else if msg.expect_continue {
// The "100 Continue" itself is queued later, by `poll_read_body`.
self.state.reading = Reading::Continue(Decoder::new(msg.decode));
wants = wants.add(Wants::EXPECT);
} else {
self.state.reading = Reading::Body(Decoder::new(msg.decode));
}
Poll::Ready(Some(Ok((msg.head, msg.decode, wants))))
}
/// Handles a failure from parsing a message head.
///
/// The read side is closed in every case. If nothing was mid-parse and an EOF
/// is acceptable in the current state, the write side is closed too and the
/// stream ends with `Ready(None)`. Otherwise the error goes to
/// `on_parse_error`, which either queues an error response (this function
/// then returns `Pending`) or hands the error back to be returned.
fn on_read_head_error<Z>(&mut self, e: crate::Error) -> Poll<Option<crate::Result<Z>>> {
    // Must be evaluated before `close_read()`, which disables keep-alive and
    // therefore changes what `should_error_on_eof()` reports.
    let must_error = self.should_error_on_eof();
    self.close_read();
    self.io.consume_leading_lines();
    let was_mid_parse = e.is_parse() || !self.io.read_buf().is_empty();

    if !(was_mid_parse || must_error) {
        debug!("read eof");
        self.close_write();
        return Poll::Ready(None);
    }

    debug!(
        "parse error ({}) with {} bytes",
        e,
        self.io.read_buf().len()
    );
    match self.on_parse_error(e) {
        Err(e) => Poll::Ready(Some(Err(e))),
        Ok(()) => Poll::Pending,
    }
}
/// Polls for the next chunk of the incoming body.
///
/// Returns `Ready(Some(Ok(chunk)))` for data, `Ready(None)` once the decoder
/// reports EOF with no trailing data, and `Ready(Some(Err(_)))` on a decode
/// error (which also closes the read side). In `Reading::Continue`, a
/// "100 Continue" is queued first if nothing has been written yet.
///
/// Panics (via `unreachable!`) when called in any state other than
/// `Reading::Body` / `Reading::Continue`.
pub(crate) fn poll_read_body(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Option<io::Result<Bytes>>> {
debug_assert!(self.can_read_body());
let (reading, ret) = match self.state.reading {
Reading::Body(ref mut decoder) => {
match ready!(decoder.decode(cx, &mut self.io)) {
Ok(slice) => {
let (reading, chunk) = if decoder.is_eof() {
debug!("incoming body completed");
(
Reading::KeepAlive,
// The final decode may or may not carry data along with EOF.
if !slice.is_empty() {
Some(Ok(slice))
} else {
None
},
)
} else if slice.is_empty() {
// An empty slice without EOF is treated as a broken body.
error!("incoming body unexpectedly ended");
(Reading::Closed, None)
} else {
// Ordinary mid-body chunk: state is unchanged, so return directly.
return Poll::Ready(Some(Ok(slice)));
};
(reading, Poll::Ready(chunk))
}
Err(e) => {
debug!("incoming body decode error: {}", e);
(Reading::Closed, Poll::Ready(Some(Err(e))))
}
}
}
Reading::Continue(ref decoder) => {
// Only queue the interim response if no head has been written yet.
if let Writing::Init = self.state.writing {
trace!("automatically sending 100 Continue");
let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
self.io.headers_buf().extend_from_slice(cont);
}
// Switch to `Body` and recurse once to perform the actual read.
self.state.reading = Reading::Body(decoder.clone());
return self.poll_read_body(cx);
}
_ => unreachable!("poll_read_body invalid state: {:?}", self.state.reading),
};
// Reached only when the body finished or failed: record the new state and
// re-evaluate keep-alive.
self.state.reading = reading;
self.try_keep_alive(cx);
ret
}
/// Returns the `notify_read` flag and clears it, so each notification is
/// reported to the caller at most once.
pub(crate) fn wants_read_again(&mut self) -> bool {
    let flag = &mut self.state.notify_read;
    let wanted = *flag;
    *flag = false;
    wanted
}
/// Polls the read side while there is neither a head nor a body to read.
///
/// With reading closed there is nothing to do (`Pending`). In the middle of a
/// message the IO is probed for an unexpected EOF; otherwise the connection is
/// required to have no incoming bytes at all.
pub(crate) fn poll_read_keep_alive(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
    debug_assert!(!self.can_read_head() && !self.can_read_body());

    if self.is_read_closed() {
        return Poll::Pending;
    }
    if self.is_mid_message() {
        return self.mid_message_detect_eof(cx);
    }
    self.require_empty_read(cx)
}
/// A message is in flight unless both the read and the write side are in
/// their `Init` state.
fn is_mid_message(&self) -> bool {
    !matches!(
        (&self.state.reading, &self.state.writing),
        (&Reading::Init, &Writing::Init)
    )
}
/// Checks that no bytes arrive on a connection that is between messages.
///
/// Any buffered or newly read bytes yield an "unexpected message" error. An
/// EOF closes the read side and is an error only if `should_error_on_eof()`
/// says so. Debug-asserts that the role is the client.
fn require_empty_read(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
    debug_assert!(!self.can_read_head() && !self.can_read_body() && !self.is_read_closed());
    debug_assert!(!self.is_mid_message());
    debug_assert!(T::is_client());

    let buffered = self.io.read_buf().len();
    if buffered != 0 {
        debug!("received an unexpected {} bytes", buffered);
        return Poll::Ready(Err(crate::Error::new_unexpected_message()));
    }

    let num_read = ready!(self.force_io_read(cx)).map_err(crate::Error::new_io)?;
    if num_read != 0 {
        debug!(
            "received unexpected {} bytes on an idle connection",
            num_read
        );
        return Poll::Ready(Err(crate::Error::new_unexpected_message()));
    }

    // Decide the outcome BEFORE closing the read side: `close_read()` disables
    // keep-alive, which would change the answer of `should_error_on_eof()`.
    let ret = if self.should_error_on_eof() {
        trace!("found unexpected EOF on busy connection: {:?}", self.state);
        Poll::Ready(Err(crate::Error::new_incomplete()))
    } else {
        trace!("found EOF on idle connection, closing");
        Poll::Ready(Ok(()))
    };
    self.state.close_read();
    ret
}
/// Probes the IO for an EOF while a message is still in flight.
///
/// Does nothing (`Pending`) when half-close is allowed or unread bytes are
/// already buffered. Otherwise one read is forced: zero bytes means the peer
/// went away mid-message, which closes the read side and is reported as an
/// incomplete-message error.
fn mid_message_detect_eof(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
    debug_assert!(!self.can_read_head() && !self.can_read_body() && !self.is_read_closed());
    debug_assert!(self.is_mid_message());

    let skip_probe = self.state.allow_half_close || !self.io.read_buf().is_empty();
    if skip_probe {
        return Poll::Pending;
    }

    let num_read = ready!(self.force_io_read(cx)).map_err(crate::Error::new_io)?;
    if num_read != 0 {
        return Poll::Ready(Ok(()));
    }

    trace!("found unexpected EOF on busy connection: {:?}", self.state);
    self.state.close_read();
    Poll::Ready(Err(crate::Error::new_incomplete()))
}
/// Performs one read from the IO into the read buffer and returns the number
/// of bytes read. An IO error closes both sides of the connection before it
/// is returned.
fn force_io_read(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<usize>> {
    debug_assert!(!self.state.is_read_closed());

    match ready!(self.io.poll_read_from_io(cx)) {
        Ok(num_read) => Poll::Ready(Ok(num_read)),
        Err(e) => {
            trace!("force_io_read; io error = {:?}", e);
            self.state.close();
            Poll::Ready(Err(e))
        }
    }
}
/// Decides whether another read pass should be requested by setting
/// `state.notify_read` (consumed by `wants_read_again`).
///
/// Only acts when the read side is in `Init` and no body is being written.
/// If the IO is not read-blocked and the read buffer is empty, one read is
/// attempted to find out whether more data (or an EOF / error) is pending.
fn maybe_notify(&mut self, cx: &mut task::Context<'_>) {
// Anything other than `Reading::Init` means a read is in progress or over.
match self.state.reading {
Reading::Continue(..) | Reading::Body(..) | Reading::KeepAlive | Reading::Closed => {
return
}
Reading::Init => (),
};
// Never notify while a body is still being written.
match self.state.writing {
Writing::Body(..) => return,
Writing::Init | Writing::KeepAlive | Writing::Closed => (),
}
if !self.io.is_read_blocked() {
if self.io.read_buf().is_empty() {
match self.io.poll_read_from_io(cx) {
Poll::Ready(Ok(n)) => {
if n == 0 {
trace!("maybe_notify; read eof");
// EOF while idle closes everything; otherwise only the read side.
if self.state.is_idle() {
self.state.close();
} else {
self.close_read()
}
return;
}
}
Poll::Pending => {
trace!("maybe_notify; read_from_io blocked");
return;
}
Poll::Ready(Err(e)) => {
trace!("maybe_notify; read_from_io error: {}", e);
self.state.close();
// Stored for later retrieval through `take_error`; note that this arm
// does not return, so `notify_read` is still set below.
self.state.error = Some(crate::Error::new_io(e));
}
}
}
self.state.notify_read = true;
}
}
/// Runs the state-level keep-alive transition for role `T`, then
/// `maybe_notify` to check whether another read should be attempted.
fn try_keep_alive(&mut self, cx: &mut task::Context<'_>) {
self.state.try_keep_alive::<T>();
self.maybe_notify(cx);
}
/// Whether a message head may be written right now.
///
/// The write side must be in `Writing::Init`. In addition, roles that do not
/// read first refuse to start a message once the read side is closed.
pub(crate) fn can_write_head(&self) -> bool {
    let blocked_by_closed_read =
        !T::should_read_first() && matches!(self.state.reading, Reading::Closed);

    !blocked_by_closed_read && matches!(self.state.writing, Writing::Init)
}
/// Whether an outgoing body is in progress, i.e. the write side is in
/// `Writing::Body`.
pub(crate) fn can_write_body(&self) -> bool {
    matches!(self.state.writing, Writing::Body(..))
}
/// Asks the buffered IO layer whether it can accept more buffered data.
pub(crate) fn can_buffer_body(&self) -> bool {
self.io.can_buffer()
}
/// Encodes `head` into the header buffer and advances the write state.
///
/// If the encoder still expects body data the state becomes `Writing::Body`;
/// a message that is already complete goes to `Closed` when it is the last
/// one on this connection and to `KeepAlive` otherwise. If encoding fails,
/// `encode_head` has already recorded the error and closed the write side.
pub(crate) fn write_head(&mut self, head: MessageHead<T::Outgoing>, body: Option<BodyLength>) {
    let encoder = match self.encode_head(head, body) {
        Some(encoder) => encoder,
        None => return,
    };

    self.state.writing = if !encoder.is_eof() {
        Writing::Body(encoder)
    } else {
        match encoder.is_last() {
            true => Writing::Closed,
            false => Writing::KeepAlive,
        }
    };
}
/// Writes a head together with its complete body in one step.
///
/// The body length is declared as `body.remaining()`. The body is only
/// written when the encoder actually expects one. Afterwards the write side
/// is `Closed` for a last message and `KeepAlive` otherwise.
pub(crate) fn write_full_msg(&mut self, head: MessageHead<T::Outgoing>, body: B) {
    let declared_len = body.remaining() as u64;
    let encoder = match self.encode_head(head, Some(BodyLength::Known(declared_len))) {
        Some(encoder) => encoder,
        None => return,
    };

    // Read `is_last` before the encoder is handed to `danger_full_buf`.
    let is_last = encoder.is_last();
    if !encoder.is_eof() {
        encoder.danger_full_buf(body, self.io.write_buf());
    }

    self.state.writing = match is_last {
        true => Writing::Closed,
        false => Writing::KeepAlive,
    };
}
/// Encodes `head` into the IO's header buffer and returns the body encoder.
///
/// On failure the error is stored in `state.error`, the write side is closed
/// and `None` is returned. On success the (asserted empty) header map of
/// `head` is kept in `state.cached_headers`.
fn encode_head(
&mut self,
mut head: MessageHead<T::Outgoing>,
body: Option<BodyLength>,
) -> Option<Encoder> {
debug_assert!(self.can_write_head());
// Roles that write first mark the connection busy when they start a message.
if !T::should_read_first() {
self.state.busy();
}
// Adjust the outgoing head if the peer is known to speak HTTP/1.0.
self.enforce_version(&mut head);
#[cfg(feature = "ffi")]
{
// A client head carrying a `HeaderCaseMap` extension turns on header-case
// preservation; once on, it is never turned off here.
if T::is_client() && !self.state.preserve_header_case {
self.state.preserve_header_case =
head.extensions.get::<crate::ffi::HeaderCaseMap>().is_some();
}
}
let buf = self.io.headers_buf();
match super::role::encode_headers::<T>(
Encode {
head: &mut head,
body,
#[cfg(feature = "server")]
keep_alive: self.state.wants_keep_alive(),
req_method: &mut self.state.method,
title_case_headers: self.state.title_case_headers,
},
buf,
) {
Ok(encoder) => {
debug_assert!(self.state.cached_headers.is_none());
debug_assert!(head.headers.is_empty());
// Keep the header map so the parser can use it again (see `ParseContext`).
self.state.cached_headers = Some(head.headers);
Some(encoder)
}
Err(err) => {
self.state.error = Some(err);
self.state.writing = Writing::Closed;
None
}
}
}
fn fix_keep_alive(&mut self, head: &mut MessageHead<T::Outgoing>) {
let outgoing_is_keep_alive = head
.headers
.get(CONNECTION)
.map(connection_keep_alive)
.unwrap_or(false);
if !outgoing_is_keep_alive {
match head.version {
Version::HTTP_10 => self.state.disable_keep_alive(),
Version::HTTP_11 => {
if self.state.wants_keep_alive() {
head.headers
.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
}
}
_ => (),
}
}
}
/// When the peer's last parsed head was HTTP/1.0, fixes up keep-alive for the
/// outgoing head and forces its version down to HTTP/1.0. Heads are left
/// untouched for any other peer version.
fn enforce_version(&mut self, head: &mut MessageHead<T::Outgoing>) {
    if self.state.version != Version::HTTP_10 {
        return;
    }
    // `fix_keep_alive` inspects `head.version`, so it must run before the
    // version is overwritten.
    self.fix_keep_alive(head);
    head.version = Version::HTTP_10;
}
/// Encodes one non-empty body chunk and buffers it for writing.
///
/// The write state only changes once the encoder reports EOF: `Closed` for
/// the last message on the connection, `KeepAlive` otherwise.
///
/// Panics (via `unreachable!`) when the write side is not in `Writing::Body`.
pub(crate) fn write_body(&mut self, chunk: B) {
debug_assert!(self.can_write_body() && self.can_buffer_body());
// Empty chunks are not expected here.
debug_assert!(chunk.remaining() != 0);
let state = match self.state.writing {
Writing::Body(ref mut encoder) => {
self.io.buffer(encoder.encode(chunk));
if encoder.is_eof() {
if encoder.is_last() {
Writing::Closed
} else {
Writing::KeepAlive
}
} else {
// More body is expected: stay in `Writing::Body`.
return;
}
}
_ => unreachable!("write_body invalid state: {:?}", self.state.writing),
};
self.state.writing = state;
}
/// Encodes `chunk` as the final piece of the outgoing body and ends the message.
///
/// The encoder reports whether the connection may be kept alive; the write
/// side becomes `KeepAlive` if so and `Closed` otherwise.
///
/// Panics (via `unreachable!`) when the write side is not in `Writing::Body`.
pub(crate) fn write_body_and_end(&mut self, chunk: B) {
    debug_assert!(self.can_write_body() && self.can_buffer_body());
    // Empty chunks are not expected here.
    debug_assert!(chunk.remaining() != 0);

    let state = match self.state.writing {
        Writing::Body(ref encoder) => {
            let can_keep_alive = encoder.encode_and_end(chunk, self.io.write_buf());
            if can_keep_alive {
                Writing::KeepAlive
            } else {
                Writing::Closed
            }
        }
        // The message names this function so a panic points at the right caller.
        _ => unreachable!(
            "write_body_and_end invalid state: {:?}",
            self.state.writing
        ),
    };

    self.state.writing = state;
}
/// Finishes the outgoing body.
///
/// Buffers whatever terminating bytes the encoder produces. The write side
/// becomes `Closed` for a last or close-delimited message and `KeepAlive`
/// otherwise. If the encoder refuses to end, the write side is closed and a
/// user-body "write aborted" error is returned. Outside `Writing::Body` this
/// is a no-op returning `Ok(())`.
pub(crate) fn end_body(&mut self) -> crate::Result<()> {
debug_assert!(self.can_write_body());
let mut res = Ok(());
let state = match self.state.writing {
Writing::Body(ref mut encoder) => {
match encoder.end() {
Ok(end) => {
// `end` holds optional closing bytes to put on the wire.
if let Some(end) = end {
self.io.buffer(end);
}
if encoder.is_last() || encoder.is_close_delimited() {
Writing::Closed
} else {
Writing::KeepAlive
}
}
Err(_not_eof) => {
// The encoder could not end the body here.
res = Err(crate::Error::new_user_body(
crate::Error::new_body_write_aborted(),
));
Writing::Closed
}
}
}
_ => return Ok(()),
};
self.state.writing = state;
res
}
/// Reacts to a parse error, replying to the peer when that is still possible.
///
/// This only applies while nothing has been written yet:
/// - if the buffered bytes start with the HTTP/2 preface, an "HTTP/2 version"
///   error is returned instead of `err`;
/// - if the role supplies an error message via `T::on_error`, its head is
///   written, `err` is stored in `state.error`, and `Ok(())` is returned.
///
/// In every other case `err` is handed back to the caller.
fn on_parse_error(&mut self, err: crate::Error) -> crate::Result<()> {
    if !matches!(self.state.writing, Writing::Init) {
        return Err(err);
    }
    if self.has_h2_prefix() {
        return Err(crate::Error::new_version_h2());
    }

    match T::on_error(&err) {
        Some(msg) => {
            // `encode_head` debug-asserts that no cached headers are present.
            self.state.cached_headers.take();
            self.write_head(msg, None);
            self.state.error = Some(err);
            Ok(())
        }
        None => Err(err),
    }
}
/// Flushes buffered writes to the IO. Once the flush completes, keep-alive
/// state is re-evaluated via `try_keep_alive`. IO errors are returned as-is.
pub(crate) fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
ready!(Pin::new(&mut self.io).poll_flush(cx))?;
self.try_keep_alive(cx);
trace!("flushed({}): {:?}", T::LOG, self.state);
Poll::Ready(Ok(()))
}
/// Shuts down the underlying IO object, logging the outcome and returning it
/// unchanged.
pub(crate) fn poll_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
    let outcome = ready!(Pin::new(self.io.io_mut()).poll_shutdown(cx));
    match outcome {
        Ok(()) => {
            trace!("shut down IO complete");
        }
        Err(ref e) => {
            debug!("error shutting down IO: {}", e);
        }
    }
    Poll::Ready(outcome)
}
/// Drains the incoming body if a single read completes it; otherwise closes
/// the read side.
///
/// This is the path for a body nobody is going to consume, so no automatic
/// "100 Continue" is sent from here: a pending `Reading::Continue` is moved
/// straight to `Reading::Body`, and one read is still attempted in case the
/// peer sent a small body without waiting.
pub(super) fn poll_drain_or_close_read(&mut self, cx: &mut task::Context<'_>) {
    if let Reading::Continue(ref decoder) = self.state.reading {
        // Going through `poll_read_body` in `Continue` would queue the interim
        // response and invite a body that is only going to be thrown away.
        self.state.reading = Reading::Body(decoder.clone());
    }

    // The chunk (or error) is intentionally discarded; only the resulting
    // read state matters.
    let _ = self.poll_read_body(cx);

    match self.state.reading {
        Reading::Init | Reading::KeepAlive => {
            trace!("body drained");
        }
        // Still mid-body (or failed): give up on reading.
        _ => self.close_read(),
    }
}
/// Closes the read side; this also disables keep-alive (see `State::close_read`).
pub(crate) fn close_read(&mut self) {
self.state.close_read();
}
/// Closes the write side; this also disables keep-alive (see `State::close_write`).
pub(crate) fn close_write(&mut self) {
self.state.close_write();
}
/// Disables keep-alive (server builds only).
///
/// A connection with a message in progress merely has keep-alive switched
/// off; an idle connection is closed on the spot.
#[cfg(feature = "server")]
pub(crate) fn disable_keep_alive(&mut self) {
    if !self.state.is_idle() {
        trace!("disable_keep_alive; in-progress connection");
        self.state.disable_keep_alive();
        return;
    }

    trace!("disable_keep_alive; closing idle connection");
    self.state.close();
}
/// Takes the stored connection error, if any, and returns it as `Err`.
/// Returns `Ok(())` when no error was recorded. The stored error is cleared
/// either way.
pub(crate) fn take_error(&mut self) -> crate::Result<()> {
    match self.state.error.take() {
        Some(err) => Err(err),
        None => Ok(()),
    }
}
/// Prepares for a possible HTTP upgrade: a new pending/on-upgrade pair is
/// created, the pending half is stored in the state (retrievable through
/// `pending_upgrade`) and the `OnUpgrade` half is returned.
pub(super) fn on_upgrade(&mut self) -> crate::upgrade::OnUpgrade {
trace!("{}: prepare possible HTTP upgrade", T::LOG);
self.state.prepare_upgrade()
}
}
/// Debug output shows the connection state followed by the buffered IO.
impl<I, B: Buf, T> fmt::Debug for Conn<I, B, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Conn");
        out.field("state", &self.state);
        out.field("io", &self.io);
        out.finish()
    }
}
// `Conn` is `Unpin` whenever the IO type is, regardless of `B` and `T`.
impl<I: Unpin, B, T> Unpin for Conn<I, B, T> {}
/// Per-connection bookkeeping owned by `Conn`.
struct State {
// When set, `Conn::mid_message_detect_eof` does not probe for an early EOF.
allow_half_close: bool,
// Header map handed to the parser via `ParseContext`; refilled with the
// outgoing head's header map by `Conn::encode_head`.
cached_headers: Option<HeaderMap>,
// Error recorded for later retrieval through `Conn::take_error`.
error: Option<crate::Error>,
// Keep-alive status of the connection.
keep_alive: KA,
// Passed as `req_method` to both parser and encoder; reset to `None` in `idle`.
method: Option<Method>,
// Forwarded to the parser; turned on in `encode_head` when a client head
// carries a `HeaderCaseMap` extension.
#[cfg(feature = "ffi")]
preserve_header_case: bool,
// Forwarded to the header encoder.
title_case_headers: bool,
// Set when another read pass should be attempted; consumed by
// `Conn::wants_read_again`.
notify_read: bool,
// Progress of the incoming message.
reading: Reading,
// Progress of the outgoing message.
writing: Writing,
// Pending half of an upgrade, stored by `prepare_upgrade`.
upgrade: Option<crate::upgrade::Pending>,
// Version from the most recently parsed head (starts as HTTP/1.1); used by
// `Conn::enforce_version`.
version: Version,
}
/// Progress of the incoming message.
#[derive(Debug)]
enum Reading {
// No head parsed yet; the only state in which `can_read_head` may be true.
Init,
// Head parsed with expect-continue and a non-empty body; `poll_read_body`
// switches this to `Body`, queueing "100 Continue" first if nothing has been
// written yet.
Continue(Decoder),
// Body is being decoded.
Body(Decoder),
// Message fully read (empty body, or decoder reached EOF).
KeepAlive,
// No further reads on this connection.
Closed,
}
/// Progress of the outgoing message.
enum Writing {
// No head written yet; required by `can_write_head`.
Init,
// Head written, body still being encoded.
Body(Encoder),
// Message fully written and it was not the last one.
KeepAlive,
// No further writes on this connection.
Closed,
}
/// Debug output always shows `reading`, `writing` and `keep_alive`. `error`
/// and `allow_half_close` only appear when they are set.
impl fmt::Debug for State {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("State");
        out.field("reading", &self.reading);
        out.field("writing", &self.writing);
        out.field("keep_alive", &self.keep_alive);

        if let Some(error) = self.error.as_ref() {
            out.field("error", error);
        }
        if self.allow_half_close {
            out.field("allow_half_close", &true);
        }

        out.finish()
    }
}
/// Debug output is the bare variant name, with the encoder shown as a tuple
/// field for `Body`.
impl fmt::Debug for Writing {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let unit_name = match self {
            Writing::Body(enc) => return f.debug_tuple("Body").field(enc).finish(),
            Writing::Init => "Init",
            Writing::KeepAlive => "KeepAlive",
            Writing::Closed => "Closed",
        };
        f.write_str(unit_name)
    }
}
/// `ka &= enabled` disables keep-alive when `enabled` is `false` and leaves
/// the status untouched when it is `true`.
impl std::ops::BitAndAssign<bool> for KA {
    fn bitand_assign(&mut self, enabled: bool) {
        if enabled {
            return;
        }
        trace!("remote disabling keep-alive");
        *self = KA::Disabled;
    }
}
/// Keep-alive status of a connection.
#[derive(Clone, Copy, Debug)]
enum KA {
    /// No message in flight.
    Idle,
    /// A message is in flight; this is the default status.
    Busy,
    /// Keep-alive has been switched off.
    Disabled,
}

impl Default for KA {
    fn default() -> Self {
        Self::Busy
    }
}

impl KA {
    /// Marks the connection as idle.
    fn idle(&mut self) {
        *self = Self::Idle;
    }

    /// Marks the connection as busy.
    fn busy(&mut self) {
        *self = Self::Busy;
    }

    /// Switches keep-alive off.
    fn disable(&mut self) {
        *self = Self::Disabled;
    }

    /// Returns a copy of the current status.
    fn status(&self) -> Self {
        *self
    }
}
impl State {
    /// Closes both directions and disables keep-alive.
    fn close(&mut self) {
        trace!("State::close()");
        self.reading = Reading::Closed;
        self.writing = Writing::Closed;
        self.keep_alive.disable();
    }

    /// Closes the read side and disables keep-alive.
    fn close_read(&mut self) {
        trace!("State::close_read()");
        self.reading = Reading::Closed;
        self.keep_alive.disable();
    }

    /// Closes the write side and disables keep-alive.
    fn close_write(&mut self) {
        trace!("State::close_write()");
        self.writing = Writing::Closed;
        self.keep_alive.disable();
    }

    /// `true` unless keep-alive has been disabled.
    fn wants_keep_alive(&self) -> bool {
        !matches!(self.keep_alive.status(), KA::Disabled)
    }

    /// Transitions once a message has finished in at least one direction.
    ///
    /// With both sides at `KeepAlive`, a busy connection goes idle and any
    /// other status closes it. With one side at `KeepAlive` and the other
    /// `Closed`, the connection is closed. Every other combination is left
    /// alone.
    fn try_keep_alive<T: Http1Transaction>(&mut self) {
        match (&self.reading, &self.writing) {
            (&Reading::KeepAlive, &Writing::KeepAlive) => {
                if matches!(self.keep_alive.status(), KA::Busy) {
                    self.idle::<T>();
                } else {
                    trace!(
                        "try_keep_alive({}): could keep-alive, but status = {:?}",
                        T::LOG,
                        self.keep_alive
                    );
                    self.close();
                }
            }
            (&Reading::Closed, &Writing::KeepAlive) | (&Reading::KeepAlive, &Writing::Closed) => {
                self.close()
            }
            _ => (),
        }
    }

    /// Switches keep-alive off without touching the read/write states.
    fn disable_keep_alive(&mut self) {
        self.keep_alive.disable()
    }

    /// Marks the connection busy, unless keep-alive is disabled (a disabled
    /// status is never overwritten here).
    fn busy(&mut self) {
        if !matches!(self.keep_alive.status(), KA::Disabled) {
            self.keep_alive.busy();
        }
    }

    /// Resets the connection for the next message.
    ///
    /// Clears the remembered method, marks keep-alive idle and puts both sides
    /// back to `Init`. Roles that do not read first also request another read
    /// pass through `notify_read`.
    fn idle<T: Http1Transaction>(&mut self) {
        debug_assert!(!self.is_idle(), "State::idle() called while idle");

        self.method = None;
        self.keep_alive.idle();

        if !self.is_idle() {
            self.close();
            return;
        }

        self.reading = Reading::Init;
        self.writing = Writing::Init;
        if !T::should_read_first() {
            self.notify_read = true;
        }
    }

    /// `true` when the keep-alive status is `Idle`.
    fn is_idle(&self) -> bool {
        matches!(self.keep_alive.status(), KA::Idle)
    }

    /// `true` when the read side is `Closed`.
    fn is_read_closed(&self) -> bool {
        matches!(self.reading, Reading::Closed)
    }

    /// `true` when the write side is `Closed`.
    fn is_write_closed(&self) -> bool {
        matches!(self.writing, Writing::Closed)
    }

    /// Creates a pending/on-upgrade pair, keeps the pending half in
    /// `self.upgrade` and returns the `OnUpgrade` half.
    fn prepare_upgrade(&mut self) -> crate::upgrade::OnUpgrade {
        let (pending, on_upgrade) = crate::upgrade::pending();
        self.upgrade = Some(pending);
        on_upgrade
    }
}
#[cfg(test)]
mod tests {
// Benchmark: parse the same short request head over and over from a
// pre-filled read buffer. Only built with the `nightly` feature.
#[cfg(feature = "nightly")]
#[bench]
fn bench_read_head_short(b: &mut ::test::Bencher) {
use super::*;
let s = b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n";
let len = s.len();
b.bytes = len as u64;
// The mock IO has no scripted reads/writes; the request bytes are placed
// into the read buffer directly instead.
let io = tokio_test::io::Builder::new().build();
let mut conn = Conn::<_, bytes::Bytes, crate::proto::h1::ServerTransaction>::new(io);
*conn.io.read_buf_mut() = ::bytes::BytesMut::from(&s[..]);
conn.state.cached_headers = Some(HeaderMap::with_capacity(2));
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
b.iter(|| {
rt.block_on(futures_util::future::poll_fn(|cx| {
match conn.poll_read_head(cx) {
Poll::Ready(Some(Ok(x))) => {
::test::black_box(&x);
// Hand the (cleared) header map back so the next iteration can use it again.
let mut headers = x.0.headers;
headers.clear();
conn.state.cached_headers = Some(headers);
}
f => panic!("expected Ready(Some(Ok(..))): {:?}", f),
}
// Rewind: make the request bytes visible in the read buffer again.
conn.io.read_buf_mut().reserve(1);
// SAFETY (review): this relies on `reserve` handing back the original
// allocation so that the first `len` bytes are still the initialized
// request bytes — TODO confirm against the `bytes` version in use.
unsafe {
conn.io.read_buf_mut().set_len(len);
}
conn.state.reading = Reading::Init;
Poll::Ready(())
}));
});
}
}