diff options
Diffstat (limited to 'ext/http/websocket_upgrade.rs')
-rw-r--r-- | ext/http/websocket_upgrade.rs | 333 |
1 files changed, 333 insertions, 0 deletions
diff --git a/ext/http/websocket_upgrade.rs b/ext/http/websocket_upgrade.rs new file mode 100644 index 000000000..042a46721 --- /dev/null +++ b/ext/http/websocket_upgrade.rs @@ -0,0 +1,333 @@ +// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license. + +use bytes::Bytes; +use bytes::BytesMut; +use deno_core::error::AnyError; +use httparse::Status; +use hyper::http::HeaderName; +use hyper::http::HeaderValue; +use hyper::Body; +use hyper::Response; +use memmem::Searcher; +use memmem::TwoWaySearcher; +use once_cell::sync::OnceCell; + +use crate::http_error; + +/// Given a buffer that ends in `\n\n` or `\r\n\r\n`, returns a parsed [`Request<Body>`]. +fn parse_response( + header_bytes: &[u8], +) -> Result<(usize, Response<Body>), AnyError> { + let mut headers = [httparse::EMPTY_HEADER; 16]; + let status = httparse::parse_headers(header_bytes, &mut headers)?; + match status { + Status::Complete((index, parsed)) => { + let mut resp = Response::builder().status(101).body(Body::empty())?; + for header in parsed.iter() { + resp.headers_mut().append( + HeaderName::from_bytes(header.name.as_bytes())?, + HeaderValue::from_str(std::str::from_utf8(header.value)?)?, + ); + } + Ok((index, resp)) + } + _ => Err(http_error("invalid headers")), + } +} + +/// Find a newline in a slice. +fn find_newline(slice: &[u8]) -> Option<usize> { + for (i, byte) in slice.iter().enumerate() { + if *byte == b'\n' { + return Some(i); + } + } + None +} + +/// WebSocket upgrade state machine states. +#[derive(Default)] +enum WebSocketUpgradeState { + #[default] + Initial, + StatusLine, + Headers, + Complete, +} + +static HEADER_SEARCHER: OnceCell<TwoWaySearcher> = OnceCell::new(); +static HEADER_SEARCHER2: OnceCell<TwoWaySearcher> = OnceCell::new(); + +#[derive(Default)] +pub struct WebSocketUpgrade { + state: WebSocketUpgradeState, + buf: BytesMut, +} + +impl WebSocketUpgrade { + /// Ensures that the status line starts with "HTTP/1.1 101 " which matches all of the node.js + /// WebSocket libraries that are known. We don't care about the trailing status text. + fn validate_status(&self, status: &[u8]) -> Result<(), AnyError> { + if status.starts_with(b"HTTP/1.1 101 ") { + Ok(()) + } else { + Err(http_error("invalid HTTP status line")) + } + } + + /// Writes bytes to our upgrade buffer, returning [`Ok(None)`] if we need to keep feeding it data, + /// [`Ok(Some(Response))`] if we got a valid upgrade header, or [`Err`] if something went badly. + pub fn write( + &mut self, + bytes: &[u8], + ) -> Result<Option<(Response<Body>, Bytes)>, AnyError> { + use WebSocketUpgradeState::*; + + match self.state { + Initial => { + if let Some(index) = find_newline(bytes) { + let (status, rest) = bytes.split_at(index + 1); + self.validate_status(status)?; + + // Fast path for the most common node.js WebSocket libraries that use \r\n as the + // separator between header lines and send the whole response in one packet. + if rest.ends_with(b"\r\n\r\n") { + let (index, response) = parse_response(rest)?; + if index == rest.len() { + return Ok(Some((response, Bytes::default()))); + } else { + let bytes = Bytes::copy_from_slice(&rest[index..]); + return Ok(Some((response, bytes))); + } + } + + self.state = Headers; + self.write(rest) + } else { + self.state = StatusLine; + self.buf.extend_from_slice(bytes); + Ok(None) + } + } + StatusLine => { + if let Some(index) = find_newline(bytes) { + let (status, rest) = bytes.split_at(index + 1); + self.buf.extend_from_slice(status); + self.validate_status(&self.buf)?; + self.buf.clear(); + // Recursively process this write + self.state = Headers; + self.write(rest) + } else { + self.buf.extend_from_slice(bytes); + Ok(None) + } + } + Headers => { + self.buf.extend_from_slice(bytes); + let header_searcher = + HEADER_SEARCHER.get_or_init(|| TwoWaySearcher::new(b"\r\n\r\n")); + let header_searcher2 = + HEADER_SEARCHER2.get_or_init(|| TwoWaySearcher::new(b"\n\n")); + if let Some(..) = header_searcher.search_in(&self.buf) { + let (index, response) = parse_response(&self.buf)?; + let mut buf = std::mem::take(&mut self.buf); + self.state = Complete; + Ok(Some((response, buf.split_off(index).freeze()))) + } else if let Some(..) = header_searcher2.search_in(&self.buf) { + let (index, response) = parse_response(&self.buf)?; + let mut buf = std::mem::take(&mut self.buf); + self.state = Complete; + Ok(Some((response, buf.split_off(index).freeze()))) + } else { + Ok(None) + } + } + Complete => { + Err(http_error("attempted to write to completed upgrade buffer")) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + type ExpectedResponseAndHead = Option<(Response<Body>, &'static [u8])>; + + fn assert_response( + result: Result<Option<(Response<Body>, Bytes)>, AnyError>, + expected: Result<ExpectedResponseAndHead, &'static str>, + chunk_info: Option<(usize, usize)>, + ) { + let formatted = format!("{result:?}"); + match expected { + Ok(Some((resp1, remainder1))) => match result { + Ok(Some((resp2, remainder2))) => { + assert_eq!(format!("{resp1:?}"), format!("{resp2:?}")); + if let Some((byte_len, chunk_size)) = chunk_info { + // We need to compute how many bytes should be in the trailing data + + // We know how many bytes of header data we had + let last_packet_header_size = + (byte_len - remainder1.len() + chunk_size - 1) % chunk_size + 1; + + // Which means we can compute how much was in the remainder + let remaining = + (chunk_size - last_packet_header_size).min(remainder1.len()); + + assert_eq!(remainder1[..remaining], remainder2); + } else { + assert_eq!(remainder1, remainder2); + } + } + _ => panic!("Expected Ok(Some(...)), was {formatted}"), + }, + Ok(None) => assert!( + result.ok().unwrap().is_none(), + "Expected Ok(None), was {formatted}", + ), + Err(e) => assert_eq!( + e, + result.err().map(|e| format!("{e:?}")).unwrap_or_default(), + "Expected error, was {formatted}", + ), + } + } + + fn validate_upgrade_all_at_once( + s: &str, + expected: Result<ExpectedResponseAndHead, &'static str>, + ) { + let mut upgrade = WebSocketUpgrade::default(); + let res = upgrade.write(s.as_bytes()); + + assert_response(res, expected, None); + } + + fn validate_upgrade_chunks( + s: &str, + size: usize, + expected: Result<ExpectedResponseAndHead, &'static str>, + ) { + let chunk_info = Some((s.as_bytes().len(), size)); + let mut upgrade = WebSocketUpgrade::default(); + let mut result = Ok(None); + for chunk in s.as_bytes().chunks(size) { + result = upgrade.write(chunk); + if let Ok(Some(..)) = &result { + assert_response(result, expected, chunk_info); + return; + } + } + assert_response(result, expected, chunk_info); + } + + fn validate_upgrade( + s: &str, + expected: fn() -> Result<ExpectedResponseAndHead, &'static str>, + ) { + validate_upgrade_all_at_once(s, expected()); + validate_upgrade_chunks(s, 1, expected()); + validate_upgrade_chunks(s, 2, expected()); + validate_upgrade_chunks(s, 10, expected()); + + // Replace \n with \r\n, but only in headers + let (headers, trailing) = s.split_once("\n\n").unwrap(); + let s = headers.replace('\n', "\r\n") + "\r\n\r\n" + trailing; + let s = s.as_ref(); + + validate_upgrade_all_at_once(s, expected()); + validate_upgrade_chunks(s, 1, expected()); + validate_upgrade_chunks(s, 2, expected()); + validate_upgrade_chunks(s, 10, expected()); + } + + #[test] + fn upgrade1() { + validate_upgrade( + "HTTP/1.1 101 Switching Protocols\nConnection: Upgrade\n\n", + || { + let mut expected = + Response::builder().status(101).body(Body::empty()).unwrap(); + expected.headers_mut().append( + HeaderName::from_static("connection"), + HeaderValue::from_static("Upgrade"), + ); + Ok(Some((expected, b""))) + }, + ); + } + + #[test] + fn upgrade_trailing() { + validate_upgrade( + "HTTP/1.1 101 Switching Protocols\nConnection: Upgrade\n\ntrailing data", + || { + let mut expected = + Response::builder().status(101).body(Body::empty()).unwrap(); + expected.headers_mut().append( + HeaderName::from_static("connection"), + HeaderValue::from_static("Upgrade"), + ); + Ok(Some((expected, b"trailing data"))) + }, + ); + } + + #[test] + fn upgrade_trailing_with_newlines() { + validate_upgrade( + "HTTP/1.1 101 Switching Protocols\nConnection: Upgrade\n\ntrailing data\r\n\r\n", + || { + let mut expected = + Response::builder().status(101).body(Body::empty()).unwrap(); + expected.headers_mut().append( + HeaderName::from_static("connection"), + HeaderValue::from_static("Upgrade"), + ); + Ok(Some((expected, b"trailing data\r\n\r\n"))) + }, + ); + } + + #[test] + fn upgrade2() { + validate_upgrade( + "HTTP/1.1 101 Switching Protocols\nConnection: Upgrade\nOther: 123\n\n", + || { + let mut expected = + Response::builder().status(101).body(Body::empty()).unwrap(); + expected.headers_mut().append( + HeaderName::from_static("connection"), + HeaderValue::from_static("Upgrade"), + ); + expected.headers_mut().append( + HeaderName::from_static("other"), + HeaderValue::from_static("123"), + ); + Ok(Some((expected, b""))) + }, + ); + } + + #[test] + fn upgrade_invalid_status() { + validate_upgrade("HTTP/1.1 200 OK\nConnection: Upgrade\n\n", || { + Err("invalid HTTP status line") + }); + } + + #[test] + fn upgrade_too_many_headers() { + let headers = (0..20) + .map(|i| format!("h{i}: {i}")) + .collect::<Vec<_>>() + .join("\n"); + validate_upgrade( + &format!("HTTP/1.1 101 Switching Protocols\n{headers}\n\n"), + || Err("too many headers"), + ); + } +} |