summaryrefslogtreecommitdiff
path: root/ext/http/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ext/http/lib.rs')
-rw-r--r--ext/http/lib.rs100
1 files changed, 77 insertions, 23 deletions
diff --git a/ext/http/lib.rs b/ext/http/lib.rs
index af117d3f9..812394d94 100644
--- a/ext/http/lib.rs
+++ b/ext/http/lib.rs
@@ -70,9 +70,12 @@ use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt;
use tokio::task::spawn_local;
-use tokio_util::io::ReaderStream;
+
+use crate::reader_stream::ExternallyAbortableReaderStream;
+use crate::reader_stream::ShutdownHandle;
pub mod compressible;
+mod reader_stream;
pub fn init() -> Extension {
Extension::builder()
@@ -414,8 +417,11 @@ impl Default for HttpRequestReader {
/// The write half of an HTTP stream.
enum HttpResponseWriter {
Headers(oneshot::Sender<Response<Body>>),
- Body(Pin<Box<dyn tokio::io::AsyncWrite>>),
- BodyUncompressed(hyper::body::Sender),
+ Body {
+ writer: Pin<Box<dyn tokio::io::AsyncWrite>>,
+ shutdown_handle: ShutdownHandle,
+ },
+ BodyUncompressed(BodyUncompressedSender),
Closed,
}
@@ -425,6 +431,36 @@ impl Default for HttpResponseWriter {
}
}
+struct BodyUncompressedSender(Option<hyper::body::Sender>);
+
+impl BodyUncompressedSender {
+ fn sender(&mut self) -> &mut hyper::body::Sender {
+ // This is safe because we only ever take the sender out of the option
+ // inside of the shutdown method.
+ self.0.as_mut().unwrap()
+ }
+
+ fn shutdown(mut self) {
+ // take the sender out of self so that when self is dropped at the end of
+ // this block, it doesn't get aborted
+ self.0.take();
+ }
+}
+
+impl From<hyper::body::Sender> for BodyUncompressedSender {
+ fn from(sender: hyper::body::Sender) -> Self {
+ BodyUncompressedSender(Some(sender))
+ }
+}
+
+impl Drop for BodyUncompressedSender {
+ fn drop(&mut self) {
+ if let Some(sender) = self.0.take() {
+ sender.abort();
+ }
+ }
+}
+
// We use a tuple instead of struct to avoid serialization overhead of the keys.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
@@ -668,14 +704,22 @@ fn http_response(
Encoding::Gzip => Box::pin(GzipEncoder::new(writer)),
_ => unreachable!(), // forbidden by accepts_compression
};
+ let (stream, shutdown_handle) =
+ ExternallyAbortableReaderStream::new(reader);
Ok((
- HttpResponseWriter::Body(writer),
- Body::wrap_stream(ReaderStream::new(reader)),
+ HttpResponseWriter::Body {
+ writer,
+ shutdown_handle,
+ },
+ Body::wrap_stream(stream),
))
}
None => {
let (body_tx, body_rx) = Body::channel();
- Ok((HttpResponseWriter::BodyUncompressed(body_tx), body_rx))
+ Ok((
+ HttpResponseWriter::BodyUncompressed(body_tx.into()),
+ body_rx,
+ ))
}
}
}
@@ -768,10 +812,10 @@ async fn op_http_write_resource(
}
match &mut *wr {
- HttpResponseWriter::Body(body) => {
- let mut result = body.write_all(&view).await;
+ HttpResponseWriter::Body { writer, .. } => {
+ let mut result = writer.write_all(&view).await;
if result.is_ok() {
- result = body.flush().await;
+ result = writer.flush().await;
}
if let Err(err) = result {
assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
@@ -784,7 +828,7 @@ async fn op_http_write_resource(
}
HttpResponseWriter::BodyUncompressed(body) => {
let bytes = Bytes::from(view);
- if let Err(err) = body.send_data(bytes).await {
+ if let Err(err) = body.sender().send_data(bytes).await {
assert!(err.is_closed());
// Pull up the failure associated with the transport connection instead.
http_stream.conn.closed().await?;
@@ -813,10 +857,10 @@ async fn op_http_write(
match &mut *wr {
HttpResponseWriter::Headers(_) => Err(http_error("no response headers")),
HttpResponseWriter::Closed => Err(http_error("response already completed")),
- HttpResponseWriter::Body(body) => {
- let mut result = body.write_all(&buf).await;
+ HttpResponseWriter::Body { writer, .. } => {
+ let mut result = writer.write_all(&buf).await;
if result.is_ok() {
- result = body.flush().await;
+ result = writer.flush().await;
}
match result {
Ok(_) => Ok(()),
@@ -833,7 +877,7 @@ async fn op_http_write(
}
HttpResponseWriter::BodyUncompressed(body) => {
let bytes = Bytes::from(buf);
- match body.send_data(bytes).await {
+ match body.sender().send_data(bytes).await {
Ok(_) => Ok(()),
Err(err) => {
assert!(err.is_closed());
@@ -862,17 +906,27 @@ async fn op_http_shutdown(
.get::<HttpStreamResource>(rid)?;
let mut wr = RcRef::map(&stream, |r| &r.wr).borrow_mut().await;
let wr = take(&mut *wr);
- if let HttpResponseWriter::Body(mut body_writer) = wr {
- match body_writer.shutdown().await {
- Ok(_) => {}
- Err(err) => {
- assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
- // Don't return "broken pipe", that's an implementation detail.
- // Pull up the failure associated with the transport connection instead.
- stream.conn.closed().await?;
+ match wr {
+ HttpResponseWriter::Body {
+ mut writer,
+ shutdown_handle,
+ } => {
+ shutdown_handle.shutdown();
+ match writer.shutdown().await {
+ Ok(_) => {}
+ Err(err) => {
+ assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
+ // Don't return "broken pipe", that's an implementation detail.
+ // Pull up the failure associated with the transport connection instead.
+ stream.conn.closed().await?;
+ }
}
}
- }
+ HttpResponseWriter::BodyUncompressed(body) => {
+ body.shutdown();
+ }
+ _ => {}
+ };
Ok(())
}