diff --git a/examples/server.nim b/examples/server.nim index baa79af..aacf78d 100644 --- a/examples/server.nim +++ b/examples/server.nim @@ -16,6 +16,7 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = if ws.readyState != Open: error "Failed to open websocket connection." return + debug "Websocket handshake completed." while true: let recvData = await ws.recv() @@ -23,15 +24,11 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = debug "Websocket closed." break debug "Client Response: ", size = recvData.len - await ws.send(recvData) + await ws.send(recvData, Opcode.Text) except WebSocketError as exc: error "WebSocket error:", exception = exc.msg - let header = HttpTable.init([ - ("Server", "nim-ws example server") - ]) - discard await request.respond(Http200, "Hello World") else: return dumbResponse() diff --git a/examples/tlsserver.nim b/examples/tlsserver.nim index d50ea05..9dc7631 100644 --- a/examples/tlsserver.nim +++ b/examples/tlsserver.nim @@ -33,7 +33,7 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = if ws.readyState == ReadyState.Closed: return debug "Response: ", data = string.fromBytes(recvData) - await ws.send(recvData) + await ws.send(recvData, Opcode.Text) except WebSocketError: error "WebSocket error:", exception = getCurrentExceptionMsg() discard await request.respond(Http200, "Hello World") diff --git a/tests/testwebsockets.nim b/tests/testwebsockets.nim index c981a8d..d0763f0 100644 --- a/tests/testwebsockets.nim +++ b/tests/testwebsockets.nim @@ -6,9 +6,7 @@ import pkg/[asynctest, chronicles, stew/byteutils] -import ../ws/[ws, stream] - -include ../ws/ws +import ../ws/[ws, stream, utils] var server: HttpServerRef let address = initTAddress("127.0.0.1:8888") @@ -233,13 +231,14 @@ suite "Test ping-pong": proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = if r.isErr(): return dumbResponse() + let request = r.get() check request.uri.path == "/ws" let ws = await createServer( request, "proto", - onPing = proc() = - ping = true + onPing = proc(data: openArray[byte] = []) = + ping = true ) let respData = await ws.recv() @@ -257,34 +256,35 @@ suite "Test ping-pong": path = "/ws", protocols = @["proto"], frameSize = maxFrameSize, - onPong = proc() = - pong = true + onPong = proc(data: openArray[byte] = []) = + pong = true ) let maskKey = genMaskKey(newRng()) - let encframe = encodeFrame(Frame( - fin: false, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: Opcode.Text, - mask: true, - data: msg[0..4], - maskKey: maskKey)) + await wsClient.stream.writer.write( + encodeFrame(Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: true, + data: msg[0..4], + maskKey: maskKey))) - await wsClient.stream.writer.write(encframe) await wsClient.ping() - let encframe1 = encodeFrame(Frame( - fin: true, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: Opcode.Cont, - mask: true, - data: msg[5..9], - maskKey: maskKey)) - await wsClient.stream.writer.write(encframe1) + await wsClient.stream.writer.write( + encodeFrame(Frame( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Cont, + mask: true, + data: msg[5..9], + maskKey: maskKey))) + await wsClient.close() check: ping @@ -306,7 +306,7 @@ suite "Test ping-pong": let ws = await createServer( request, "proto", - onPong = proc() = + onPong = proc(data: openArray[byte] = []) = pong = true ) @@ -322,8 +322,8 @@ suite "Test ping-pong": Port(8888), path = "/ws", protocols = @["proto"], - onPing = proc() = - ping = true + onPing = proc(data: openArray[byte] = []) = + ping = true ) await waitForClose(wsClient) @@ -342,8 +342,8 @@ suite "Test ping-pong": let ws = await createServer( request, "proto", - onPing = proc() = - ping = true + onPing = proc(data: openArray[byte] = []) = + ping = true ) await waitForClose(ws) check: @@ -359,8 +359,8 @@ suite "Test ping-pong": Port(8888), path = "/ws", protocols = @["proto"], - onPong = proc() = - pong = true + onPong = proc(data: openArray[byte] = []) = + pong = true ) await wsClient.ping() @@ -744,6 +744,7 @@ suite "Test Payload": expect WSPayloadTooLarge: discard await ws.recv() + await waitForClose(ws) let res = HttpServerRef.new( @@ -758,7 +759,7 @@ suite "Test Payload": path = "/ws", protocols = @["proto"]) - await wsClient.send(toBytes(str), Opcode.Ping) + await wsClient.ping(toBytes(str)) await wsClient.close() test "Test single empty payload": @@ -825,8 +826,8 @@ suite "Test Payload": let ws = await createServer( request, "proto", - onPing = proc() = - ping = true + onPing = proc(data: openArray[byte]) = + ping = data == testData ) await waitForClose(ws) @@ -841,11 +842,11 @@ suite "Test Payload": Port(8888), path = "/ws", protocols = @["proto"], - onPong = proc() = - pong = true + onPong = proc(data: openArray[byte] = []) = + pong = true ) - await wsClient.send(testData, Opcode.Ping) + await wsClient.ping(testData) await wsClient.close() check: ping diff --git a/ws/stream.nim b/ws/stream.nim index bb48c1b..551209d 100644 --- a/ws/stream.nim +++ b/ws/stream.nim @@ -46,10 +46,11 @@ proc readHeaders*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} = return buffer proc closeWait*(wsStream: AsyncStream) {.async.} = - + # TODO: this is most likelly wrongs await allFutures( wsStream.writer.closeWait(), wsStream.reader.closeWait()) + await allFutures( wsStream.writer.tsource.closeWait(), wsStream.reader.tsource.closeWait()) diff --git a/ws/ws.nim b/ws/ws.nim index bf672d8..7d08c1f 100644 --- a/ws/ws.nim +++ b/ws/ws.nim @@ -78,15 +78,12 @@ type WSPayloadLengthError* = object of WebSocketError WSInvalidOpcodeError* = object of WebSocketError - - Base16Error* = object of CatchableError - ## Base16 specific exception type - HeaderFlag* {.size: sizeof(uint8).} = enum rsv3 rsv2 rsv1 fin + HeaderFlags = set[HeaderFlag] HttpCode* = enum @@ -121,19 +118,19 @@ type # 3000-3999 reserved for libs # 4000-4999 reserved for applications - Frame = ref object - fin: bool ## Indicates that this is the final fragment in a message. - rsv1: bool ## MUST be 0 unless negotiated that defines meanings - rsv2: bool ## MUST be 0 - rsv3: bool ## MUST be 0 - opcode: Opcode ## Defines the interpretation of the "Payload data". - mask: bool ## Defines whether the "Payload data" is masked. - data: seq[byte] ## Payload data - maskKey: array[4, char] ## Masking key - length: uint64 ## Message size. - consumed: uint64 ## how much has been consumed from the frame + Frame* = ref object + fin*: bool ## Indicates that this is the final fragment in a message. + rsv1*: bool ## MUST be 0 unless negotiated that defines meanings + rsv2*: bool ## MUST be 0 + rsv3*: bool ## MUST be 0 + opcode*: Opcode ## Defines the interpretation of the "Payload data". + mask*: bool ## Defines whether the "Payload data" is masked. + data*: seq[byte] ## Payload data + maskKey*: array[4, char] ## Masking key + length*: uint64 ## Message size. + consumed*: uint64 ## how much has been consumed from the frame - ControlCb* = proc() {.gcsafe, raises: [Defect].} + ControlCb* = proc(data: openArray[byte] = []) {.gcsafe, raises: [Defect].} CloseResult* = tuple code: Status @@ -172,7 +169,7 @@ proc `$`(ht: HttpTables): string = res.add(CRLF) res -proc unmask*( +proc mask*( data: var openArray[byte], maskKey: array[4, char], offset = 0) = @@ -215,8 +212,9 @@ proc handshake*( wantProtocol & ")") let cKey = ws.key & WSGuid - let acceptKey = Base64Pad.encode(sha1.digest(cKey.toOpenArray(0, - cKey.high)).data) + let acceptKey = Base64Pad.encode( + sha1.digest(cKey.toOpenArray(0, cKey.high)).data) + var headerData = [ ("Connection", "Upgrade"), ("Upgrade", "webSocket"), @@ -231,6 +229,7 @@ proc handshake*( except CatchableError as exc: raise newException(WSHandshakeError, "Failed to sent handshake response. Error: " & exc.msg) + ws.readyState = ReadyState.Open proc createServer*( @@ -263,12 +262,12 @@ proc createServer*( await ws.handshake(request) return ws -proc encodeFrame*(f: Frame): seq[byte] = +proc encodeFrame*(f: Frame, offset = 0): seq[byte] = ## Encodes a frame into a string buffer. ## See https://tools.ietf.org/html/rfc6455#section-5.2 var ret: seq[byte] - var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags. + var b0 = (f.opcode.uint8 and 0x0F) # 0th byte: opcodes and flags. if f.fin: b0 = b0 or 128'u8 @@ -280,7 +279,7 @@ proc encodeFrame*(f: Frame): seq[byte] = if f.data.len <= 125: b1 = f.data.len.uint8 - elif f.data.len > 125 and f.data.len <= 0xffff: + elif f.data.len > 125 and f.data.len <= 0xFFFF: b1 = 126'u8 else: b1 = 127'u8 @@ -291,12 +290,12 @@ proc encodeFrame*(f: Frame): seq[byte] = ret.add(uint8 b1) # Only need more bytes if data len is 7+16 bits, or 7+64 bits. - if f.data.len > 125 and f.data.len <= 0xffff: + if f.data.len > 125 and f.data.len <= 0xFFFF: # Data len is 7+16 bits. var len = f.data.len.uint16 - ret.add ((len shr 8) and 255).uint8 - ret.add (len and 255).uint8 - elif f.data.len > 0xffff: + ret.add ((len shr 8) and 0xFF).uint8 + ret.add (len and 0xFF).uint8 + elif f.data.len > 0xFFFF: # Data len is 7+64 bits. var len = f.data.len.uint64 ret.add(len.toBytesBE()) @@ -305,8 +304,8 @@ proc encodeFrame*(f: Frame): seq[byte] = if f.mask: # If we need to mask it generate random mask key and mask the data. - for i in 0..= data.len): true else: false, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames - mask: ws.masked, - data: data[i ..< len], - maskKey: maskKey)) + await ws.stream.writer.write( + encodeFrame(Frame( + fin: if (i + len >= data.len): true else: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames + mask: ws.masked, + data: data[i ..< len], + maskKey: maskKey))) - await ws.stream.writer.write(encFrame) i += len - if i >= data.len: break @@ -377,8 +378,6 @@ proc send*(ws: WebSocket, data: string): Future[void] = proc handleClose*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async.} = - if ws.readyState notin {ReadyState.Open}: - return logScope: fin = frame.fin masked = frame.mask @@ -386,22 +385,31 @@ proc handleClose*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async serverState = ws.readyState debug "Handling close sequence" + + if ws.readyState notin {ReadyState.Open}: + return + var code = Status.Fulfilled reason = "" if payLoad.len == 1: - raise newException(WSPayloadLengthError, "Invalid close frame with payload length 1!") + raise newException(WSPayloadLengthError, + "Invalid close frame with payload length 1!") - elif payLoad.len > 1: + if payLoad.len > 1: # first two bytes are the status let ccode = uint16.fromBytesBE(payLoad[0..<2]) if ccode <= 999 or ccode > 1015: - raise newException(WSInvalidCloseCodeError, "Invalid code in close message!") + raise newException(WSInvalidCloseCodeError, + "Invalid code in close message!") + try: code = Status(ccode) except RangeError: - code = Status.Fulfilled + raise newException(WSInvalidCloseCodeError, + "Status code out of range!") + # remining payload bytes are reason for closing reason = string.fromBytes(payLoad[2..payLoad.high]) @@ -423,17 +431,35 @@ proc handleClose*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async ws.readyState = ReadyState.Closed await ws.stream.closeWait() -proc handleControl*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async.} = +proc handleControl*(ws: WebSocket, frame: Frame) {.async.} = ## handle control frames ## + if not frame.fin: + raise newException(WSFragmentedControlFrameError, + "Control frame cannot be fragmented!") + + if frame.length > 125: + raise newException(WSPayloadTooLarge, + "Control message payload is greater than 125 bytes!") + try: + var payLoad = newSeq[byte](frame.length.int) + if frame.length > 0: + payLoad.setLen(frame.length.int) + # Read control frame payload. + await ws.stream.reader.readExactly(addr payLoad[0], frame.length.int) + if frame.mask: + mask( + payLoad.toOpenArray(0, payLoad.high), + frame.maskKey) + # Process control frame payload. case frame.opcode: of Opcode.Ping: if not isNil(ws.onPing): try: - ws.onPing() + ws.onPing(payLoad) except CatchableError as exc: debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg @@ -442,7 +468,7 @@ proc handleControl*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.asy of Opcode.Pong: if not isNil(ws.onPong): try: - ws.onPong() + ws.onPong(payLoad) except CatchableError as exc: debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg of Opcode.Close: @@ -469,7 +495,8 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = await ws.stream.reader.readExactly(addr header[0], 2) if header.len != 2: debug "Invalid websocket header length" - raise newException(WSMalformedHeaderError, "Invalid websocket header length") + raise newException(WSMalformedHeaderError, + "Invalid websocket header length") let b0 = header[0].uint8 let b1 = header[1].uint8 @@ -487,12 +514,8 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = if opcode > ord(Opcode.high): raise newException(WSOpcodeMismatchError, "Wrong opcode!") - let frameOpcode = (opcode).Opcode - if frameOpcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary, - Opcode.Ping, Opcode.Pong, Opcode.Close}: - raise newException(WSReserverdOpcodeError, "Unknown opcode received!") + frame.opcode = (opcode).Opcode - frame.opcode = frameOpcode # If any of the rsv are set close the socket. if frame.rsv1 or frame.rsv2 or frame.rsv3: raise newException(WSRsvMismatchError, "WebSocket rsv mismatch") @@ -530,34 +553,16 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = for i in 0.. 125: - raise newException(WSPayloadTooLarge, - "Control message payload is greater than 125 bytes!") - var payLoad = newSeq[byte](frame.length) - if frame.length > 0: - # Read control frame payload. - await ws.stream.reader.readExactly(addr payLoad[0], frame.length.int) - unmask(payLoad.toOpenArray(0, payLoad.high), frame.maskKey) - await ws.handleControl(frame, payLoad) # process control frames# process control frames + await ws.handleControl(frame) # process control frames# process control frames continue - debug "Decoded new frame", opcode = frame.opcode, len = frame.length, - mask = frame.mask - return frame - - except WSReserverdOpcodeError as exc: - trace "Handled websocket opcode exception", exc = exc.msg - raise exc - except WSPayloadTooLarge as exc: - debug "Handled payload too large exception", exc = exc.msg - raise exc except WebSocketError as exc: - debug "Handled websocket exception", exc = exc.msg + trace "Websocket error", exc = exc.msg raise exc except CatchableError as exc: debug "Exception reading frame, dropping socket", exc = exc.msg @@ -565,8 +570,8 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = await ws.stream.closeWait() raise exc -proc ping*(ws: WebSocket): Future[void] = - ws.send(opcode = Opcode.Ping) +proc ping*(ws: WebSocket, data: seq[byte] = @[]): Future[void] = + ws.send(data, opcode = Opcode.Ping) proc recv*( ws: WebSocket, @@ -585,27 +590,32 @@ proc recv*( var consumed = 0 var pbuffer = cast[ptr UncheckedArray[byte]](data) try: + + # read the first frame + ws.frame = await ws.readFrame() + + # This could happen if the connection is closed. + if isNil(ws.frame): + return consumed + + if ws.frame.opcode == Opcode.Cont: + raise newException(WSOpcodeMismatchError, + "First frame cannot be continue frame") + while consumed < size: # we might have to read more than # one frame to fill the buffer - # all has been consumed from the frame - # read the next frame - if isNil(ws.frame): - ws.frame = await ws.readFrame() - # This could happen if the connection is closed. - if isNil(ws.frame): - return consumed.int - if ws.frame.opcode == Opcode.Cont: - raise newException(WSOpcodeMismatchError, "First frame cannot be continue frame") - elif (not ws.frame.fin and ws.frame.remainder() <= 0): + if (not ws.frame.fin and ws.frame.remainder() <= 0): ws.frame = await ws.readFrame() # This could happen if the connection is closed. if isNil(ws.frame): - return consumed.int + return consumed if ws.frame.opcode != Opcode.Cont: - raise newException(WSOpcodeMismatchError, "expected continue frame") + raise newException(WSOpcodeMismatchError, + "expected continue frame") + if ws.frame.fin and ws.frame.remainder().int <= 0: ws.frame = nil break @@ -613,13 +623,14 @@ proc recv*( let len = min(ws.frame.remainder().int, size - consumed) if len == 0: continue + let read = await ws.stream.reader.readOnce(addr pbuffer[consumed], len) if read <= 0: continue if ws.frame.mask: # unmask data using offset - unmask( + mask( pbuffer.toOpenArray(consumed, (consumed + read) - 1), ws.frame.maskKey, ws.frame.consumed.int)