diff --git a/src/ws.nim b/src/ws.nim index 8baa4d84..169291bb 100644 --- a/src/ws.nim +++ b/src/ws.nim @@ -1,4 +1,5 @@ -import chronos, chronicles, httputils, strutils, base64, std/sha1, random +import chronos, chronicles, httputils, strutils, base64, std/sha1, random, + streams, nativesockets const MaxHttpHeadersSize = 8192 # maximum size of HTTP headers in octets @@ -78,7 +79,8 @@ proc handshake*(ws: WebSocket, header: HttpRequestHeader) {.async.} = let wantProtocol = header["Sec-WebSocket-Protocol"].strip() if ws.protocol != wantProtocol: raise newException(WebSocketError, - "Protocol mismatch (expected: " & ws.protocol & ", got: " & wantProtocol & ")") + "Protocol mismatch (expected: " & ws.protocol & ", got: " & + wantProtocol & ")") let sh = secureHash(ws.key & "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") @@ -96,8 +98,8 @@ proc handshake*(ws: WebSocket, header: HttpRequestHeader) {.async.} = discard await ws.tcpSocket.write(response) ws.readyState = Open -proc newWebSocket*(header: HttpRequestHeader, transp: StreamTransport, protocol: string = ""): Future[ - WebSocket] {.async.} = +proc newWebSocket*(header: HttpRequestHeader, transp: StreamTransport, + protocol: string = ""): Future[WebSocket] {.async.} = ## Creates a new socket from a request. try: if not header.contains("Sec-WebSocket-Version"): @@ -115,6 +117,260 @@ proc newWebSocket*(header: HttpRequestHeader, transp: StreamTransport, protocol: "Failed to create WebSocket from request: " & getCurrentExceptionMsg() ) +type + Opcode* = enum + ## 4 bits. Defines the interpretation of the "Payload data". + Cont = 0x0 ## Denotes a continuation frame. + Text = 0x1 ## Denotes a text frame. + Binary = 0x2 ## Denotes a binary frame. + # 3-7 are reserved for further non-control frames. + Close = 0x8 ## Denotes a connection close. + Ping = 0x9 ## Denotes a ping. + Pong = 0xa ## Denotes a pong. + # B-F are reserved for further control frames. + + #[ + +---------------------------------------------------------------+ + |0 1 2 3 | + |0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1| + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ + ]# + Frame = tuple + fin: bool ## Indicates that this is the final fragment in a message. + rsv1: bool ## MUST be 0 unless negotiated that defines meanings + rsv2: bool ## MUST be 0 + rsv3: bool ## MUST be 0 + opcode: Opcode ## Defines the interpretation of the "Payload data". + mask: bool ## Defines whether the "Payload data" is masked. + data: string ## Payload data + +proc encodeFrame(f: Frame): string = + ## Encodes a frame into a string buffer. + ## See https://tools.ietf.org/html/rfc6455#section-5.2 + + var ret = newStringStream() + + var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags. + if f.fin: + b0 = b0 or 128u8 + + ret.write(b0) + + # Payload length can be 7 bits, 7+16 bits, or 7+64 bits. + # 1st byte: payload len start and mask bit. + var b1 = 0u8 + + if f.data.len <= 125: + b1 = f.data.len.uint8 + elif f.data.len > 125 and f.data.len <= 0xffff: + b1 = 126u8 + else: + b1 = 127u8 + + if f.mask: + b1 = b1 or (1 shl 7) + + ret.write(uint8 b1) + + # Only need more bytes if data len is 7+16 bits, or 7+64 bits. + if f.data.len > 125 and f.data.len <= 0xffff: + # Data len is 7+16 bits. + ret.write(htons(f.data.len.uint16)) + elif f.data.len > 0xffff: + # Data len is 7+64 bits. + var len = f.data.len + ret.write char((len shr 56) and 255) + ret.write char((len shr 48) and 255) + ret.write char((len shr 40) and 255) + ret.write char((len shr 32) and 255) + ret.write char((len shr 24) and 255) + ret.write char((len shr 16) and 255) + ret.write char((len shr 8) and 255) + ret.write char(len and 255) + + var data = f.data + + if f.mask: + # If we need to mask it generate random mask key and mask the data. + let maskKey = genMaskKey() + for i in 0..