summaryrefslogtreecommitdiff
path: root/ext/http/response_body.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ext/http/response_body.rs')
-rw-r--r--ext/http/response_body.rs291
1 files changed, 266 insertions, 25 deletions
diff --git a/ext/http/response_body.rs b/ext/http/response_body.rs
index ea6cc5ab8..7da2142d3 100644
--- a/ext/http/response_body.rs
+++ b/ext/http/response_body.rs
@@ -118,6 +118,7 @@ trait PollFrame: Unpin {
pub enum Compression {
None,
GZip,
+ Brotli,
}
pub enum ResponseStream {
@@ -140,6 +141,8 @@ pub enum ResponseBytesInner {
UncompressedStream(ResponseStream),
/// A GZip stream.
GZipStream(GZipResponseStream),
+ /// A Brotli stream.
+ BrotliStream(BrotliResponseStream),
}
impl std::fmt::Debug for ResponseBytesInner {
@@ -150,6 +153,7 @@ impl std::fmt::Debug for ResponseBytesInner {
Self::Bytes(..) => f.write_str("Bytes"),
Self::UncompressedStream(..) => f.write_str("Uncompressed"),
Self::GZipStream(..) => f.write_str("GZip"),
+ Self::BrotliStream(..) => f.write_str("Brotli"),
}
}
}
@@ -197,14 +201,17 @@ impl ResponseBytesInner {
Self::Bytes(bytes) => SizeHint::with_exact(bytes.len() as u64),
Self::UncompressedStream(res) => res.size_hint(),
Self::GZipStream(..) => SizeHint::default(),
+ Self::BrotliStream(..) => SizeHint::default(),
}
}
fn from_stream(compression: Compression, stream: ResponseStream) -> Self {
- if compression == Compression::GZip {
- Self::GZipStream(GZipResponseStream::new(stream))
- } else {
- Self::UncompressedStream(stream)
+ match compression {
+ Compression::GZip => Self::GZipStream(GZipResponseStream::new(stream)),
+ Compression::Brotli => {
+ Self::BrotliStream(BrotliResponseStream::new(stream))
+ }
+ _ => Self::UncompressedStream(stream),
}
}
@@ -227,22 +234,45 @@ impl ResponseBytesInner {
}
pub fn from_slice(compression: Compression, bytes: &[u8]) -> Self {
- if compression == Compression::GZip {
- let mut writer = GzEncoder::new(Vec::new(), flate2::Compression::fast());
- writer.write_all(bytes).unwrap();
- Self::Bytes(BufView::from(writer.finish().unwrap()))
- } else {
- Self::Bytes(BufView::from(bytes.to_vec()))
+ match compression {
+ Compression::GZip => {
+ let mut writer =
+ GzEncoder::new(Vec::new(), flate2::Compression::fast());
+ writer.write_all(bytes).unwrap();
+ Self::Bytes(BufView::from(writer.finish().unwrap()))
+ }
+ Compression::Brotli => {
+ // quality level 6 is based on google's nginx default value for
+ // on-the-fly compression
+ // https://github.com/google/ngx_brotli#brotli_comp_level
+ // lgwin 22 is equivalent to brotli window size of (2**22)-16 bytes
+ // (~4MB)
+ let mut writer =
+ brotli::CompressorWriter::new(Vec::new(), 65 * 1024, 6, 22);
+ writer.write_all(bytes).unwrap();
+ writer.flush().unwrap();
+ Self::Bytes(BufView::from(writer.into_inner()))
+ }
+ _ => Self::Bytes(BufView::from(bytes.to_vec())),
}
}
pub fn from_vec(compression: Compression, vec: Vec<u8>) -> Self {
- if compression == Compression::GZip {
- let mut writer = GzEncoder::new(Vec::new(), flate2::Compression::fast());
- writer.write_all(&vec).unwrap();
- Self::Bytes(BufView::from(writer.finish().unwrap()))
- } else {
- Self::Bytes(BufView::from(vec))
+ match compression {
+ Compression::GZip => {
+ let mut writer =
+ GzEncoder::new(Vec::new(), flate2::Compression::fast());
+ writer.write_all(&vec).unwrap();
+ Self::Bytes(BufView::from(writer.finish().unwrap()))
+ }
+ Compression::Brotli => {
+ let mut writer =
+ brotli::CompressorWriter::new(Vec::new(), 65 * 1024, 6, 22);
+ writer.write_all(&vec).unwrap();
+ writer.flush().unwrap();
+ Self::Bytes(BufView::from(writer.into_inner()))
+ }
+ _ => Self::Bytes(BufView::from(vec)),
}
}
}
@@ -273,6 +303,9 @@ impl Body for ResponseBytes {
ResponseBytesInner::GZipStream(stm) => {
ready!(Pin::new(stm).poll_frame(cx))
}
+ ResponseBytesInner::BrotliStream(stm) => {
+ ready!(Pin::new(stm).poll_frame(cx))
+ }
};
// This is where we retry the NoData response
if matches!(res, ResponseStreamResult::NoData) {
@@ -546,6 +579,157 @@ impl PollFrame for GZipResponseStream {
}
}
+#[derive(Copy, Clone, Debug)]
+enum BrotliState {
+ Streaming,
+ Flushing,
+ EndOfStream,
+}
+
+#[pin_project]
+pub struct BrotliResponseStream {
+ state: BrotliState,
+ stm: *mut brotli::ffi::compressor::BrotliEncoderState,
+ current_cursor: usize,
+ output_written_so_far: usize,
+ #[pin]
+ underlying: ResponseStream,
+}
+
+impl BrotliResponseStream {
+ pub fn new(underlying: ResponseStream) -> Self {
+ Self {
+ // SAFETY: creating an FFI instance should be OK with these args.
+ stm: unsafe {
+ brotli::ffi::compressor::BrotliEncoderCreateInstance(
+ None,
+ None,
+ std::ptr::null_mut(),
+ )
+ },
+ output_written_so_far: 0,
+ current_cursor: 0,
+ state: BrotliState::Streaming,
+ underlying,
+ }
+ }
+}
+
+fn max_compressed_size(input_size: usize) -> usize {
+ if input_size == 0 {
+ return 2;
+ }
+
+ // [window bits / empty metadata] + N * [uncompressed] + [last empty]
+ let num_large_blocks = input_size >> 14;
+ let overhead = 2 + (4 * num_large_blocks) + 3 + 1;
+ let result = input_size + overhead;
+
+ if result < input_size {
+ 0
+ } else {
+ result
+ }
+}
+
+impl PollFrame for BrotliResponseStream {
+ fn poll_frame(
+ self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<ResponseStreamResult> {
+ let this = self.get_mut();
+ let state = &mut this.state;
+ let frame = match *state {
+ BrotliState::Streaming => {
+ ready!(Pin::new(&mut this.underlying).poll_frame(cx))
+ }
+ BrotliState::Flushing => ResponseStreamResult::EndOfStream,
+ BrotliState::EndOfStream => {
+ return std::task::Poll::Ready(ResponseStreamResult::EndOfStream);
+ }
+ };
+
+ let res = match frame {
+ ResponseStreamResult::NonEmptyBuf(buf) => {
+ let mut output_written = 0;
+ let mut total_output_written = 0;
+ let mut input_size = buf.len();
+ let input_buffer = buf.as_ref();
+ let mut len = max_compressed_size(input_size);
+ let mut output_buffer = vec![0u8; len];
+ let mut ob_ptr = output_buffer.as_mut_ptr();
+
+ // SAFETY: these are okay arguments to these FFI calls.
+ unsafe {
+ brotli::ffi::compressor::BrotliEncoderCompressStream(
+ this.stm,
+ brotli::ffi::compressor::BrotliEncoderOperation::BROTLI_OPERATION_PROCESS,
+ &mut input_size,
+ &input_buffer.as_ptr() as *const *const u8 as *mut *const u8,
+ &mut len,
+ &mut ob_ptr,
+ &mut output_written,
+ );
+ total_output_written += output_written;
+ output_written = 0;
+
+ brotli::ffi::compressor::BrotliEncoderCompressStream(
+ this.stm,
+ brotli::ffi::compressor::BrotliEncoderOperation::BROTLI_OPERATION_FLUSH,
+ &mut input_size,
+ &input_buffer.as_ptr() as *const *const u8 as *mut *const u8,
+ &mut len,
+ &mut ob_ptr,
+ &mut output_written,
+ );
+ total_output_written += output_written;
+ };
+
+ output_buffer
+ .truncate(total_output_written - this.output_written_so_far);
+ this.output_written_so_far = total_output_written;
+ ResponseStreamResult::NonEmptyBuf(BufView::from(output_buffer))
+ }
+ ResponseStreamResult::EndOfStream => {
+ let mut len = 1024usize;
+ let mut output_buffer = vec![0u8; len];
+ let mut input_size = 0;
+ let mut output_written = 0;
+ let ob_ptr = output_buffer.as_mut_ptr();
+
+ // SAFETY: these are okay arguments to these FFI calls.
+ unsafe {
+ brotli::ffi::compressor::BrotliEncoderCompressStream(
+ this.stm,
+ brotli::ffi::compressor::BrotliEncoderOperation::BROTLI_OPERATION_FINISH,
+ &mut input_size,
+ std::ptr::null_mut(),
+ &mut len,
+ &ob_ptr as *const *mut u8 as *mut *mut u8,
+ &mut output_written,
+ );
+ };
+
+ if output_written == 0 {
+ this.state = BrotliState::EndOfStream;
+ ResponseStreamResult::EndOfStream
+ } else {
+ this.state = BrotliState::Flushing;
+ output_buffer.truncate(output_written - this.output_written_so_far);
+ ResponseStreamResult::NonEmptyBuf(BufView::from(output_buffer))
+ }
+ }
+ _ => frame,
+ };
+
+ std::task::Poll::Ready(res)
+ }
+
+ fn size_hint(&self) -> SizeHint {
+ SizeHint::default()
+ }
+}
+
/// A response body object that can be passed to V8. This body will feed byte buffers to a channel which
/// feed's hyper's HTTP response.
pub struct V8StreamHttpResponseBody(
@@ -670,7 +854,7 @@ mod tests {
vec![v, v2].into_iter()
}
- async fn test(i: impl Iterator<Item = Vec<u8>> + Send + 'static) {
+ async fn test_gzip(i: impl Iterator<Item = Vec<u8>> + Send + 'static) {
let v = i.collect::<Vec<_>>();
let mut expected: Vec<u8> = vec![];
for v in &v {
@@ -712,19 +896,66 @@ mod tests {
handle.await.unwrap();
}
+ async fn test_brotli(i: impl Iterator<Item = Vec<u8>> + Send + 'static) {
+ let v = i.collect::<Vec<_>>();
+ let mut expected: Vec<u8> = vec![];
+ for v in &v {
+ expected.extend(v);
+ }
+ let (tx, rx) = tokio::sync::mpsc::channel(1);
+ let underlying = ResponseStream::V8Stream(rx);
+ let mut resp = BrotliResponseStream::new(underlying);
+ let handle = tokio::task::spawn(async move {
+ for chunk in v {
+ tx.send(chunk.into()).await.ok().unwrap();
+ }
+ });
+ // Limit how many times we'll loop
+ const LIMIT: usize = 1000;
+ let mut v: Vec<u8> = vec![];
+ for i in 0..=LIMIT {
+ assert_ne!(i, LIMIT);
+ let frame = poll_fn(|cx| Pin::new(&mut resp).poll_frame(cx)).await;
+ if matches!(frame, ResponseStreamResult::EndOfStream) {
+ break;
+ }
+ if matches!(frame, ResponseStreamResult::NoData) {
+ continue;
+ }
+ let ResponseStreamResult::NonEmptyBuf(buf) = frame else {
+ panic!("Unexpected stream type");
+ };
+ assert_ne!(buf.len(), 0);
+ v.extend(&*buf);
+ }
+
+ let mut gz = brotli::Decompressor::new(&*v, v.len());
+ let mut v = vec![];
+ if !expected.is_empty() {
+ gz.read_to_end(&mut v).unwrap();
+ }
+
+ assert_eq!(v, expected);
+
+ handle.await.unwrap();
+ }
+
#[tokio::test]
async fn test_simple() {
- test(vec![b"hello world".to_vec()].into_iter()).await
+ test_brotli(vec![b"hello world".to_vec()].into_iter()).await;
+ test_gzip(vec![b"hello world".to_vec()].into_iter()).await;
}
#[tokio::test]
async fn test_empty() {
- test(vec![].into_iter()).await
+ test_brotli(vec![].into_iter()).await;
+ test_gzip(vec![].into_iter()).await;
}
#[tokio::test]
async fn test_simple_zeros() {
- test(vec![vec![0; 0x10000]].into_iter()).await
+ test_brotli(vec![vec![0; 0x10000]].into_iter()).await;
+ test_gzip(vec![vec![0; 0x10000]].into_iter()).await;
}
macro_rules! test {
@@ -733,31 +964,41 @@ mod tests {
#[tokio::test]
async fn chunk() {
let iter = super::chunk(super::$vec());
- super::test(iter).await;
+ super::test_gzip(iter).await;
+ let br_iter = super::chunk(super::$vec());
+ super::test_brotli(br_iter).await;
}
#[tokio::test]
async fn front_load() {
let iter = super::front_load(super::$vec());
- super::test(iter).await;
+ super::test_gzip(iter).await;
+ let br_iter = super::front_load(super::$vec());
+ super::test_brotli(br_iter).await;
}
#[tokio::test]
async fn front_load_but_one() {
let iter = super::front_load_but_one(super::$vec());
- super::test(iter).await;
+ super::test_gzip(iter).await;
+ let br_iter = super::front_load_but_one(super::$vec());
+ super::test_brotli(br_iter).await;
}
#[tokio::test]
async fn back_load() {
let iter = super::back_load(super::$vec());
- super::test(iter).await;
+ super::test_gzip(iter).await;
+ let br_iter = super::back_load(super::$vec());
+ super::test_brotli(br_iter).await;
}
#[tokio::test]
async fn random() {
let iter = super::random(super::$vec());
- super::test(iter).await;
+ super::test_gzip(iter).await;
+ let br_iter = super::random(super::$vec());
+ super::test_brotli(br_iter).await;
}
}
};