diff options
author | Yusuke Sakurai <kerokerokerop@gmail.com> | 2020-02-06 22:42:32 +0900 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-02-06 08:42:32 -0500 |
commit | 699d10bd9e5f19ad2f4ffb82225c86690a092c07 (patch) | |
tree | f62d22e4f945917ae2cad7f0f824405a0ab6719e | |
parent | ed680552a24b7d4b936b7c16a63b46e0f24c0e60 (diff) |
fix: make WebSocket.send() exclusive (#3885)
-rw-r--r-- | std/ws/mod.ts | 127 | ||||
-rw-r--r-- | std/ws/test.ts | 63 |
2 files changed, 138 insertions, 52 deletions
diff --git a/std/ws/mod.ts b/std/ws/mod.ts index 96ba4df62..217ebc8b5 100644 --- a/std/ws/mod.ts +++ b/std/ws/mod.ts @@ -10,6 +10,7 @@ import { readLong, readShort, sliceLongToBytes } from "../io/ioutil.ts"; import { Sha1 } from "./sha1.ts"; import { writeResponse } from "../http/server.ts"; import { TextProtoReader } from "../textproto/mod.ts"; +import { Deferred, deferred } from "../util/async.ts"; export enum OpCode { Continue = 0x0, @@ -193,21 +194,30 @@ function createMask(): Uint8Array { } class WebSocketImpl implements WebSocket { + readonly conn: Conn; private readonly mask?: Uint8Array; private readonly bufReader: BufReader; private readonly bufWriter: BufWriter; + private sendQueue: Array<{ + frame: WebSocketFrame; + d: Deferred<void>; + }> = []; - constructor( - readonly conn: Conn, - opts: { - bufReader?: BufReader; - bufWriter?: BufWriter; - mask?: Uint8Array; - } - ) { - this.mask = opts.mask; - this.bufReader = opts.bufReader || new BufReader(conn); - this.bufWriter = opts.bufWriter || new BufWriter(conn); + constructor({ + conn, + bufReader, + bufWriter, + mask + }: { + conn: Conn; + bufReader?: BufReader; + bufWriter?: BufWriter; + mask?: Uint8Array; + }) { + this.conn = conn; + this.mask = mask; + this.bufReader = bufReader || new BufReader(conn); + this.bufWriter = bufWriter || new BufWriter(conn); } async *receive(): AsyncIterableIterator<WebSocketEvent> { @@ -250,14 +260,11 @@ class WebSocketImpl implements WebSocket { yield { code, reason }; return; case OpCode.Ping: - await writeFrame( - { - opcode: OpCode.Pong, - payload: frame.payload, - isLastFrame: true - }, - this.bufWriter - ); + await this.enqueue({ + opcode: OpCode.Pong, + payload: frame.payload, + isLastFrame: true + }); yield ["ping", frame.payload] as WebSocketPingEvent; break; case OpCode.Pong: @@ -268,6 +275,27 @@ class WebSocketImpl implements WebSocket { } } + private dequeue(): void { + const [e] = this.sendQueue; + if (!e) return; + writeFrame(e.frame, this.bufWriter) + .then(() => e.d.resolve()) + .catch(e => e.d.reject(e)) + .finally(() => { + this.sendQueue.shift(); + this.dequeue(); + }); + } + + private enqueue(frame: WebSocketFrame): Promise<void> { + const d = deferred<void>(); + this.sendQueue.push({ d, frame }); + if (this.sendQueue.length === 1) { + this.dequeue(); + } + return d; + } + async send(data: WebSocketMessage): Promise<void> { if (this.isClosed) { throw new SocketClosedError("socket has been closed"); @@ -276,28 +304,24 @@ class WebSocketImpl implements WebSocket { typeof data === "string" ? OpCode.TextFrame : OpCode.BinaryFrame; const payload = typeof data === "string" ? encode(data) : data; const isLastFrame = true; - await writeFrame( - { - isLastFrame, - opcode, - payload, - mask: this.mask - }, - this.bufWriter - ); + const frame = { + isLastFrame, + opcode, + payload, + mask: this.mask + }; + return this.enqueue(frame); } async ping(data: WebSocketMessage = ""): Promise<void> { const payload = typeof data === "string" ? encode(data) : data; - await writeFrame( - { - isLastFrame: true, - opcode: OpCode.Ping, - mask: this.mask, - payload - }, - this.bufWriter - ); + const frame = { + isLastFrame: true, + opcode: OpCode.Ping, + mask: this.mask, + payload + }; + return this.enqueue(frame); } private _isClosed = false; @@ -317,15 +341,12 @@ class WebSocketImpl implements WebSocket { } else { payload = new Uint8Array(header); } - await writeFrame( - { - isLastFrame: true, - opcode: OpCode.Close, - mask: this.mask, - payload - }, - this.bufWriter - ); + await this.enqueue({ + isLastFrame: true, + opcode: OpCode.Close, + mask: this.mask, + payload + }); } catch (e) { throw e; } finally { @@ -380,7 +401,7 @@ export async function acceptWebSocket(req: { }): Promise<WebSocket> { const { conn, headers, bufReader, bufWriter } = req; if (acceptable(req)) { - const sock = new WebSocketImpl(conn, { bufReader, bufWriter }); + const sock = new WebSocketImpl({ conn, bufReader, bufWriter }); const secKey = headers.get("sec-websocket-key"); if (typeof secKey !== "string") { throw new Error("sec-websocket-key is not provided"); @@ -499,9 +520,19 @@ export async function connectWebSocket( conn.close(); throw err; } - return new WebSocketImpl(conn, { + return new WebSocketImpl({ + conn, bufWriter, bufReader, mask: createMask() }); } + +export function createWebSocket(params: { + conn: Conn; + bufWriter?: BufWriter; + bufReader?: BufReader; + mask?: Uint8Array; +}): WebSocket { + return new WebSocketImpl(params); +} diff --git a/std/ws/test.ts b/std/ws/test.ts index e9cdd1d40..d3148f7ad 100644 --- a/std/ws/test.ts +++ b/std/ws/test.ts @@ -3,6 +3,7 @@ import { BufReader, BufWriter } from "../io/bufio.ts"; import { assert, assertEquals, assertThrowsAsync } from "../testing/asserts.ts"; import { runIfMain, test } from "../testing/mod.ts"; import { TextProtoReader } from "../textproto/mod.ts"; +import * as bytes from "../bytes/mod.ts"; import { acceptable, connectWebSocket, @@ -11,10 +12,13 @@ import { OpCode, readFrame, unmask, - writeFrame + writeFrame, + createWebSocket } from "./mod.ts"; -import { encode } from "../strings/mod.ts"; - +import { encode, decode } from "../strings/mod.ts"; +type Writer = Deno.Writer; +type Reader = Deno.Reader; +type Conn = Deno.Conn; const { Buffer } = Deno; test(async function wsReadUnmaskedTextFrame(): Promise<void> { @@ -30,7 +34,7 @@ test(async function wsReadUnmaskedTextFrame(): Promise<void> { }); test(async function wsReadMaskedTextFrame(): Promise<void> { - //a masked single text frame with payload "Hello" + // a masked single text frame with payload "Hello" const buf = new BufReader( new Buffer( new Uint8Array([ @@ -272,4 +276,55 @@ test("handshake should send search correctly", async function wsHandshakeWithSea assertEquals(statusLine, "GET /?a=1 HTTP/1.1"); }); +function dummyConn(r: Reader, w: Writer): Conn { + return { + rid: -1, + closeRead: (): void => {}, + closeWrite: (): void => {}, + read: (x): Promise<number | Deno.EOF> => r.read(x), + write: (x): Promise<number> => w.write(x), + close: (): void => {}, + localAddr: { transport: "tcp", hostname: "0.0.0.0", port: 0 }, + remoteAddr: { transport: "tcp", hostname: "0.0.0.0", port: 0 } + }; +} + +function delayedWriter(ms: number, dest: Writer): Writer { + return { + write(p: Uint8Array): Promise<number> { + return new Promise<number>(resolve => { + setTimeout(async (): Promise<void> => { + resolve(await dest.write(p)); + }, ms); + }); + } + }; +} +test("WebSocket.send(), WebSocket.ping() should be exclusive", async (): Promise< + void +> => { + const buf = new Buffer(); + const conn = dummyConn(new Buffer(), delayedWriter(1, buf)); + const sock = createWebSocket({ conn }); + // Ensure send call + await Promise.all([ + sock.send("first"), + sock.send("second"), + sock.ping(), + sock.send(new Uint8Array([3])) + ]); + const bufr = new BufReader(buf); + const first = await readFrame(bufr); + const second = await readFrame(bufr); + const ping = await readFrame(bufr); + const third = await readFrame(bufr); + assertEquals(first.opcode, OpCode.TextFrame); + assertEquals(decode(first.payload), "first"); + assertEquals(first.opcode, OpCode.TextFrame); + assertEquals(decode(second.payload), "second"); + assertEquals(ping.opcode, OpCode.Ping); + assertEquals(third.opcode, OpCode.BinaryFrame); + assertEquals(bytes.equal(third.payload, new Uint8Array([3])), true); +}); + runIfMain(import.meta); |