diff options
Diffstat (limited to 'ext/http/lib.rs')
-rw-r--r-- | ext/http/lib.rs | 100 |
1 files changed, 77 insertions, 23 deletions
diff --git a/ext/http/lib.rs b/ext/http/lib.rs index af117d3f9..812394d94 100644 --- a/ext/http/lib.rs +++ b/ext/http/lib.rs @@ -70,9 +70,12 @@ use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tokio::io::AsyncWriteExt; use tokio::task::spawn_local; -use tokio_util::io::ReaderStream; + +use crate::reader_stream::ExternallyAbortableReaderStream; +use crate::reader_stream::ShutdownHandle; pub mod compressible; +mod reader_stream; pub fn init() -> Extension { Extension::builder() @@ -414,8 +417,11 @@ impl Default for HttpRequestReader { /// The write half of an HTTP stream. enum HttpResponseWriter { Headers(oneshot::Sender<Response<Body>>), - Body(Pin<Box<dyn tokio::io::AsyncWrite>>), - BodyUncompressed(hyper::body::Sender), + Body { + writer: Pin<Box<dyn tokio::io::AsyncWrite>>, + shutdown_handle: ShutdownHandle, + }, + BodyUncompressed(BodyUncompressedSender), Closed, } @@ -425,6 +431,36 @@ impl Default for HttpResponseWriter { } } +struct BodyUncompressedSender(Option<hyper::body::Sender>); + +impl BodyUncompressedSender { + fn sender(&mut self) -> &mut hyper::body::Sender { + // This is safe because we only ever take the sender out of the option + // inside of the shutdown method. + self.0.as_mut().unwrap() + } + + fn shutdown(mut self) { + // take the sender out of self so that when self is dropped at the end of + // this block, it doesn't get aborted + self.0.take(); + } +} + +impl From<hyper::body::Sender> for BodyUncompressedSender { + fn from(sender: hyper::body::Sender) -> Self { + BodyUncompressedSender(Some(sender)) + } +} + +impl Drop for BodyUncompressedSender { + fn drop(&mut self) { + if let Some(sender) = self.0.take() { + sender.abort(); + } + } +} + // We use a tuple instead of struct to avoid serialization overhead of the keys. #[derive(Serialize)] #[serde(rename_all = "camelCase")] @@ -668,14 +704,22 @@ fn http_response( Encoding::Gzip => Box::pin(GzipEncoder::new(writer)), _ => unreachable!(), // forbidden by accepts_compression }; + let (stream, shutdown_handle) = + ExternallyAbortableReaderStream::new(reader); Ok(( - HttpResponseWriter::Body(writer), - Body::wrap_stream(ReaderStream::new(reader)), + HttpResponseWriter::Body { + writer, + shutdown_handle, + }, + Body::wrap_stream(stream), )) } None => { let (body_tx, body_rx) = Body::channel(); - Ok((HttpResponseWriter::BodyUncompressed(body_tx), body_rx)) + Ok(( + HttpResponseWriter::BodyUncompressed(body_tx.into()), + body_rx, + )) } } } @@ -768,10 +812,10 @@ async fn op_http_write_resource( } match &mut *wr { - HttpResponseWriter::Body(body) => { - let mut result = body.write_all(&view).await; + HttpResponseWriter::Body { writer, .. } => { + let mut result = writer.write_all(&view).await; if result.is_ok() { - result = body.flush().await; + result = writer.flush().await; } if let Err(err) = result { assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe); @@ -784,7 +828,7 @@ async fn op_http_write_resource( } HttpResponseWriter::BodyUncompressed(body) => { let bytes = Bytes::from(view); - if let Err(err) = body.send_data(bytes).await { + if let Err(err) = body.sender().send_data(bytes).await { assert!(err.is_closed()); // Pull up the failure associated with the transport connection instead. http_stream.conn.closed().await?; @@ -813,10 +857,10 @@ async fn op_http_write( match &mut *wr { HttpResponseWriter::Headers(_) => Err(http_error("no response headers")), HttpResponseWriter::Closed => Err(http_error("response already completed")), - HttpResponseWriter::Body(body) => { - let mut result = body.write_all(&buf).await; + HttpResponseWriter::Body { writer, .. } => { + let mut result = writer.write_all(&buf).await; if result.is_ok() { - result = body.flush().await; + result = writer.flush().await; } match result { Ok(_) => Ok(()), @@ -833,7 +877,7 @@ async fn op_http_write( } HttpResponseWriter::BodyUncompressed(body) => { let bytes = Bytes::from(buf); - match body.send_data(bytes).await { + match body.sender().send_data(bytes).await { Ok(_) => Ok(()), Err(err) => { assert!(err.is_closed()); @@ -862,17 +906,27 @@ async fn op_http_shutdown( .get::<HttpStreamResource>(rid)?; let mut wr = RcRef::map(&stream, |r| &r.wr).borrow_mut().await; let wr = take(&mut *wr); - if let HttpResponseWriter::Body(mut body_writer) = wr { - match body_writer.shutdown().await { - Ok(_) => {} - Err(err) => { - assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe); - // Don't return "broken pipe", that's an implementation detail. - // Pull up the failure associated with the transport connection instead. - stream.conn.closed().await?; + match wr { + HttpResponseWriter::Body { + mut writer, + shutdown_handle, + } => { + shutdown_handle.shutdown(); + match writer.shutdown().await { + Ok(_) => {} + Err(err) => { + assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe); + // Don't return "broken pipe", that's an implementation detail. + // Pull up the failure associated with the transport connection instead. + stream.conn.closed().await?; + } } } - } + HttpResponseWriter::BodyUncompressed(body) => { + body.shutdown(); + } + _ => {} + }; Ok(()) } |