summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cli/dts/lib.deno.unstable.d.ts1
-rw-r--r--cli/tests/testdata/websocketstream_test.ts57
-rw-r--r--ext/websocket/02_websocketstream.js5
-rw-r--r--ext/websocket/lib.rs29
4 files changed, 92 insertions, 0 deletions
diff --git a/cli/dts/lib.deno.unstable.d.ts b/cli/dts/lib.deno.unstable.d.ts
index 6b7755ee5..442f5d7d4 100644
--- a/cli/dts/lib.deno.unstable.d.ts
+++ b/cli/dts/lib.deno.unstable.d.ts
@@ -1172,6 +1172,7 @@ declare interface WorkerOptions {
declare interface WebSocketStreamOptions {
protocols?: string[];
signal?: AbortSignal;
+ headers?: HeadersInit;
}
declare interface WebSocketConnection {
diff --git a/cli/tests/testdata/websocketstream_test.ts b/cli/tests/testdata/websocketstream_test.ts
index 1198c4164..b43b90139 100644
--- a/cli/tests/testdata/websocketstream_test.ts
+++ b/cli/tests/testdata/websocketstream_test.ts
@@ -3,6 +3,7 @@
import {
assert,
assertEquals,
+ assertNotEquals,
assertRejects,
assertThrows,
unreachable,
@@ -137,3 +138,59 @@ Deno.test("aborting immediately with a primitive as reason throws that primitive
(e) => assertEquals(e, "Some string"),
);
});
+
+Deno.test("headers", async () => {
+ const listener = Deno.listen({ port: 4501 });
+ const promise = (async () => {
+ const httpConn = Deno.serveHttp(await listener.accept());
+ const { request, respondWith } = (await httpConn.nextRequest())!;
+ assertEquals(request.headers.get("x-some-header"), "foo");
+ const {
+ response,
+ socket,
+ } = Deno.upgradeWebSocket(request);
+ socket.onopen = () => socket.close();
+ await respondWith(response);
+ })();
+
+ const ws = new WebSocketStream("ws://localhost:4501", {
+ headers: [["x-some-header", "foo"]],
+ });
+ await promise;
+ await ws.closed;
+ listener.close();
+});
+
+Deno.test("forbidden headers", async () => {
+ const forbiddenHeaders = [
+ "sec-websocket-accept",
+ "sec-websocket-extensions",
+ "sec-websocket-key",
+ "sec-websocket-protocol",
+ "sec-websocket-version",
+ "upgrade",
+ "connection",
+ ];
+
+ const listener = Deno.listen({ port: 4501 });
+ const promise = (async () => {
+ const httpConn = Deno.serveHttp(await listener.accept());
+ const { request, respondWith } = (await httpConn.nextRequest())!;
+ for (const header of request.headers.keys()) {
+ assertNotEquals(header, "foo");
+ }
+ const {
+ response,
+ socket,
+ } = Deno.upgradeWebSocket(request);
+ socket.onopen = () => socket.close();
+ await respondWith(response);
+ })();
+
+ const ws = new WebSocketStream("ws://localhost:4501", {
+ headers: forbiddenHeaders.map((header) => [header, "foo"]),
+ });
+ await promise;
+ await ws.closed;
+ listener.close();
+});
diff --git a/ext/websocket/02_websocketstream.js b/ext/websocket/02_websocketstream.js
index 8b032d1c2..d0a4e055d 100644
--- a/ext/websocket/02_websocketstream.js
+++ b/ext/websocket/02_websocketstream.js
@@ -39,6 +39,10 @@
key: "signal",
converter: webidl.converters.AbortSignal,
},
+ {
+ key: "headers",
+ converter: webidl.converters.HeadersInit,
+ },
],
);
webidl.converters.WebSocketCloseInfo = webidl.createDictionaryConverter(
@@ -139,6 +143,7 @@
? ArrayPrototypeJoin(options.protocols, ", ")
: "",
cancelHandle: cancelRid,
+ headers: [...new Headers(options.headers).entries()],
}),
(create) => {
options.signal?.[remove](abort);
diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs
index 4796eddc6..544423066 100644
--- a/ext/websocket/lib.rs
+++ b/ext/websocket/lib.rs
@@ -1,6 +1,7 @@
// Copyright 2018-2021 the Deno authors. All rights reserved. MIT license.
use deno_core::error::invalid_hostname;
+use deno_core::error::type_error;
use deno_core::error::AnyError;
use deno_core::futures::stream::SplitSink;
use deno_core::futures::stream::SplitStream;
@@ -11,6 +12,7 @@ use deno_core::op_async;
use deno_core::op_sync;
use deno_core::url;
use deno_core::AsyncRefCell;
+use deno_core::ByteString;
use deno_core::CancelFuture;
use deno_core::CancelHandle;
use deno_core::Extension;
@@ -20,6 +22,8 @@ use deno_core::Resource;
use deno_core::ResourceId;
use deno_core::ZeroCopyBuf;
use deno_tls::create_client_config;
+use http::header::HeaderName;
+use http::HeaderValue;
use http::Method;
use http::Request;
use http::Uri;
@@ -215,6 +219,7 @@ pub struct CreateArgs {
url: String,
protocols: String,
cancel_handle: Option<ResourceId>,
+ headers: Option<Vec<(ByteString, ByteString)>>,
}
#[derive(Serialize)]
@@ -267,6 +272,30 @@ where
request = request.header("Sec-WebSocket-Protocol", args.protocols);
}
+ if let Some(headers) = args.headers {
+ for (key, value) in headers {
+ let name = HeaderName::from_bytes(&key)
+ .map_err(|err| type_error(err.to_string()))?;
+ let v = HeaderValue::from_bytes(&value)
+ .map_err(|err| type_error(err.to_string()))?;
+
+ let is_disallowed_header = matches!(
+ name,
+ http::header::HOST
+ | http::header::SEC_WEBSOCKET_ACCEPT
+ | http::header::SEC_WEBSOCKET_EXTENSIONS
+ | http::header::SEC_WEBSOCKET_KEY
+ | http::header::SEC_WEBSOCKET_PROTOCOL
+ | http::header::SEC_WEBSOCKET_VERSION
+ | http::header::UPGRADE
+ | http::header::CONNECTION
+ );
+ if !is_disallowed_header {
+ request = request.header(name, v);
+ }
+ }
+ }
+
let request = request.body(())?;
let domain = &uri.host().unwrap().to_string();
let port = &uri.port_u16().unwrap_or(match uri.scheme_str() {