# Nim-WebRTC
# Copyright (c) 2022 Status Research & Development GmbH
# Licensed under either of
#  * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
#  * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

import tables, bitops, posix, strutils, sequtils
import chronos, chronicles, stew/[ranges/ptr_arith, byteutils]
import usrsctp
import dtls/dtls

export chronicles

logScope:
  topics = "webrtc sctp"

# TODO:
# - Replace doAssert by a proper exception management
# - Find a clean way to manage SCTP ports

type
  SctpError* = object of CatchableError

  SctpState = enum
    Connecting
    Connected
    Closed

  SctpMessageParameters* = object
    ## Per-message metadata: payload protocol id, stream id and the
    ## end-of-record / unordered send flags (see `toFlags`).
    protocolId*: uint32
    streamId*: uint16
    endOfRecord*: bool
    unordered*: bool

  SctpMessage* = ref object
    ## A user message read from usrsctp in `handleUpcall`.
    data*: seq[byte]
    info: sctp_rcvinfo               # raw receive info filled by usrsctp_recvv
    params*: SctpMessageParameters

  SctpConn* = ref object
    ## One SCTP connection. Its address is registered with usrsctp
    ## (`usrsctp_register_address`) and used as the `ctx` of `sendCallback`.
    conn: DtlsConn
    state: SctpState
    connectEvent: AsyncEvent
    acceptEvent: AsyncEvent
    readLoop: Future[void]
    sctp: Sctp
    udp: DatagramTransport
    address: TransportAddress
    sctpSocket: ptr socket
    dataRecv: AsyncQueue[SctpMessage]
    sentFuture: Future[void]         # last DTLS send started by sendCallback

  Sctp* = ref object
    dtls: Dtls
    udp: DatagramTransport
    connections: Table[TransportAddress, SctpConn]
    gotConnection: AsyncEvent
    timersHandler: Future[void]
    isServer: bool
    sockServer: ptr socket
    pendingConnections: seq[SctpConn]
    pendingConnections2: Table[SockAddr, SctpConn]
    sentConnection: SctpConn
    sentAddress: TransportAddress
    sentFuture: Future[void]

const
  IPPROTO_SCTP = 132

proc newSctpError(msg: string): ref SctpError =
  ## Creates an `SctpError` carrying `msg`.
  result = newException(SctpError, msg)

template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped =
  ## Runs the usrsctp call `body` after clearing `self.sentFuture`. If the
  ## field was set while `body` ran, that future is awaited before
  ## continuing. `sendCallback` sets this field on the `SctpConn` it
  ## receives as `ctx`. The template evaluates to `body`'s value when it is
  ## not void. It contains `await`, so it can only be used inside async procs.
  self.sentFuture = nil
  when typeof(body) is void:
    body
    if self.sentFuture != nil: await self.sentFuture
  else:
    let res = body
    if self.sentFuture != nil: await self.sentFuture
    res

proc perror(error: cstring) {.importc, cdecl, header: "<stdio.h>".}
proc printf(format: cstring) {.cdecl, importc: "printf", varargs,
                               header: "<stdio.h>", gcsafe.}

proc packetPretty(packet: cstring): string =
  ## Shortens a usrsctp packet dump for logging. It keeps characters 1..14
  ## and the slice `[23 .. ^16]`, eliding the middle of that slice when it
  ## is longer than 30 characters.
  let data = $packet
  if data.len < 38:
    # Too short for the fixed offsets below; slicing would raise.
    return data
  let ctn = data[23..^16]
  result = data[1..14]
  if ctn.len > 30:
    result = result & ctn[0..14] & " ... " & ctn[^14..^1]
  else:
    result = result & ctn

proc new(T: typedesc[SctpConn], sctp: Sctp, udp: DatagramTransport,
         address: TransportAddress, sctpSocket: ptr socket): T =
  ## Builds a (client-side) connection around an existing usrsctp socket.
  T(sctp: sctp,
    state: Connecting,
    udp: udp,
    address: address,
    sctpSocket: sctpSocket,
    connectEvent: newAsyncEvent(),
    #TODO add some limit for backpressure?
    dataRecv: newAsyncQueue[SctpMessage]())

proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T =
  ## Builds a connection carried over the DTLS connection `conn`. The
  ## usrsctp socket is attached later, in `handleAccept`.
  T(conn: conn,
    sctp: sctp,
    state: Connecting,
    connectEvent: newAsyncEvent(),
    acceptEvent: newAsyncEvent(),
    #TODO add some limit for backpressure?
    dataRecv: newAsyncQueue[SctpMessage]())

proc read*(self: SctpConn): Future[SctpMessage] {.async.} =
  ## Waits for the next message queued by `handleUpcall`.
  return await self.dataRecv.popFirst()

proc toFlags(params: SctpMessageParameters): uint16 =
  ## Maps the boolean send options to usrsctp `snd_flags` bits.
  if params.endOfRecord:
    result = result or SCTP_EOR
  if params.unordered:
    result = result or SCTP_UNORDERED

proc write*(
    self: SctpConn,
    buf: seq[byte],
    sendParams = default(SctpMessageParameters)) {.async.} =
  ## Sends `buf` as one message on this connection's usrsctp socket.
  ##
  ## When `sendParams` differs from the default, its stream id, protocol id
  ## and flags are passed as `sctp_sndinfo`; otherwise no send info is
  ## attached. Raises `SctpError` when `usrsctp_sendv` returns a negative
  ## value.
  trace "Write", buf
  self.sctp.sentConnection = self
  self.sctp.sentAddress = self.address
  let (sendInfo, infoType) =
    if sendParams != default(SctpMessageParameters):
      (sctp_sndinfo(snd_sid: sendParams.streamId,
                    #TODO endianness?
                    snd_ppid: sendParams.protocolId,
                    snd_flags: sendParams.toFlags),
       cuint(SCTP_SENDV_SNDINFO))
    else:
      (default(sctp_sndinfo), cuint(SCTP_SENDV_NOINFO))
  # An empty seq has no element 0; hand usrsctp a nil payload instead of
  # raising IndexDefect, and let it report the failure.
  let payload: pointer =
    if buf.len > 0: unsafeAddr buf[0]
    else: nil
  let sendvErr = self.usrsctpAwait:
    self.sctpSocket.usrsctp_sendv(payload, buf.len.uint, nil, 0,
                                  unsafeAddr sendInfo,
                                  sizeof(sendInfo).SockLen, infoType, 0)
  if sendvErr < 0:
    # NOTE(review): errno is read after a possible await inside
    # usrsctpAwait; it is only reliable if no send was started.
    raise newSctpError("usrsctp_sendv failed: " & $strerror(errno))

proc write*(self: SctpConn, s: string) {.async.} =
  ## Convenience overload: sends the bytes of `s` with default parameters.
  await self.write(s.toBytes())

proc close*(self: SctpConn) {.async.} =
  ## Closes the usrsctp socket, if one is attached, and marks the
  ## connection `Closed`.
  if not self.sctpSocket.isNil:
    self.usrsctpAwait:
      self.sctpSocket.usrsctp_close()
  self.state = Closed

proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
  ## usrsctp upcall for a connection socket. `data` is the `SctpConn`
  ## given to `usrsctp_set_upcall`.
  ##
  ## While connecting, ERROR/WRITE events move the state to Closed or
  ## Connected. Otherwise a READ event reads one message with
  ## `usrsctp_recvv` and queues it in `dataRecv`; notifications are only
  ## logged.
  let
    events = usrsctp_get_events(sock)
    conn = cast[SctpConn](data)
  trace "Handle Upcall", events
  if conn.state == Connecting:
    if bitand(events, SCTP_EVENT_ERROR) != 0:
      warn "Cannot connect", address = conn.address
      conn.state = Closed
    elif bitand(events, SCTP_EVENT_WRITE) != 0:
      conn.state = Connected
      conn.connectEvent.fire()
  elif bitand(events, SCTP_EVENT_READ) != 0:
    var
      message = SctpMessage(data: newSeq[byte](4096))
      address: Sockaddr_storage
      addressLen = sizeof(Sockaddr_storage).SockLen
      rnLen = sizeof(message.info).SockLen
      # Out-params must have exactly the C widths (unsigned int / int).
      infoType: cuint
      recvFlags: cint
    let n = sock.usrsctp_recvv(cast[pointer](addr message.data[0]),
                               message.data.len.uint,
                               cast[ptr SockAddr](addr address),
                               cast[ptr SockLen](addr addressLen),
                               cast[pointer](addr message.info),
                               cast[ptr SockLen](addr rnLen),
                               cast[ptr cuint](addr infoType),
                               cast[ptr cint](addr recvFlags))
    if n < 0:
      perror("usrsctp_recvv")
      return
    elif n > 0:
      # Only `n` bytes were received; drop the unused tail of the buffer.
      message.data.setLen(int(n))
      if infoType.uint == SCTP_RECVV_RCVINFO:
        message.params = SctpMessageParameters(
          #TODO endianness?
          protocolId: message.info.rcv_ppid,
          streamId: message.info.rcv_sid)
      if bitand(recvFlags.int, MSG_NOTIFICATION) != 0:
        trace "Notification received", length = n
      else:
        try:
          conn.dataRecv.addLastNoWait(message)
        except AsyncQueueFullError:
          trace "Queue full, dropping packet"
  else:
    warn "Handle Upcall unexpected event", events

proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
  ## Upcall on the listening socket. `data` is the `Sctp` passed to
  ## `usrsctp_set_upcall` in `listen`.
  ##
  ## Accepts the association and recovers the `SctpConn` from the peer's
  ## `sconn_addr`, which is the pointer registered in `accept`. It then
  ## attaches the new socket to that connection, enables receive info and
  ## the data upcall, and wakes up `accept`.
  trace "Handle Accept"
  var
    sconn: Sockaddr_conn
    slen: Socklen = sizeof(Sockaddr_conn).uint32
  let
    sctp = cast[Sctp](data)
    sctpSocket = usrsctp_accept(sctp.sockServer,
                                cast[ptr SockAddr](addr sconn), addr slen)
  if sctpSocket.isNil:
    perror("usrsctp_accept")
    return
  doAssert 0 == sctpSocket.usrsctp_set_non_blocking(1)
  let conn = cast[SctpConn](sconn.sconn_addr)
  # Without this the connection's write/close operate on a nil socket.
  conn.sctpSocket = sctpSocket
  conn.state = Connected
  # Same socket setup as the (commented-out) client path: deliver
  # sctp_rcvinfo with each message and route events to handleUpcall.
  var recvInfo: cint = 1
  doAssert 0 == sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP,
                                              SCTP_RECVRCVINFO,
                                              addr recvInfo,
                                              sizeof(recvInfo).SockLen)
  doAssert 0 == sctpSocket.usrsctp_set_upcall(handleUpcall, cast[pointer](conn))
  conn.acceptEvent.fire()

# proc getOrCreateConnection(self: Sctp,
#                            udp: DatagramTransport,
#                            address: TransportAddress,
#                            sctpPort: uint16 = 5000): Future[SctpConn] {.async.} =
#                            #TODO remove the = 5000
#   if self.connections.hasKey(address):
#     return self.connections[address]
#   trace "Create Connection", address
#   let
#     sctpSocket = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
#     conn = SctpConn.new(self, udp, address, sctpSocket)
#   var on: int = 1
#   doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP,
#                                                    SCTP_RECVRCVINFO,
#                                                    addr on,
#                                                    sizeof(on).SockLen)
#   doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1)
#   doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, cast[pointer](conn))
#   var sconn: Sockaddr_conn
#   sconn.sconn_family = AF_CONN
#   sconn.sconn_port = htons(sctpPort)
#   sconn.sconn_addr = cast[pointer](self)
#   self.sentConnection = conn
#   self.sentAddress = address
#   let connErr = self.usrsctpAwait:
#     conn.sctpSocket.usrsctp_connect(cast[ptr SockAddr](addr sconn), SockLen(sizeof(sconn)))
#   doAssert 0 == connErr or errno == posix.EINPROGRESS, ($errno)
#   self.connections[address] = conn
#   return conn

proc sendCallback(ctx: pointer,
                  buffer: pointer,
                  length: uint,
                  tos: uint8,
                  set_df: uint8): cint {.cdecl.} =
  ## usrsctp output callback, installed by `usrsctp_init_nothreads`.
  ##
  ## Logs the outbound packet and starts writing it to the DTLS connection
  ## of the `SctpConn` cast from `ctx`. The resulting future is stored in
  ## that connection's `sentFuture` so `usrsctpAwait` can wait for it.
  ## Always returns 0.
  ## NOTE(review): `Sctp.new` also registers the `Sctp` object itself as a
  ## usrsctp address; if usrsctp ever passes that as `ctx`, the cast below
  ## is invalid — confirm.
  let data = usrsctp_dumppacket(buffer, length, SCTP_DUMP_OUTBOUND)
  if data != nil:
    trace "sendCallback", data = data.packetPretty(), length
    usrsctp_freedumpbuffer(data)
  let sctpConn = cast[SctpConn](ctx)
  proc testSend() {.async.} =
    try:
      # Copy the packet before the first await; `buffer` presumably stays
      # valid only for the duration of this callback.
      let buf = @(buffer.makeOpenArray(byte, int(length)))
      trace "Send To", address = sctpConn.address
      await sctpConn.conn.write(buf)
    except CatchableError as exc:
      trace "Send Failed", message = exc.msg
  sctpConn.sentFuture = testSend()

proc timersHandler() {.async.} =
  ## Drives usrsctp's timers every 500 ms (usrsctp runs without its own
  ## threads here, see `usrsctp_init_nothreads` in `Sctp.new`).
  while true:
    await sleepAsync(500.milliseconds)
    usrsctp_handle_timers(500)

proc stopServer*(self: Sctp) =
  ## Stops listening: closes every pending connection socket, then the
  ## listening socket. Does nothing (besides a trace) on a non-server.
  if not self.isServer:
    trace "Try to close a client"
    return
  self.isServer = false
  let pcs = self.pendingConnections
  self.pendingConnections = @[]
  for pc in pcs:
    pc.sctpSocket.usrsctp_close()
  self.sockServer.usrsctp_close()

proc new*(T: typedesc[Sctp], dtls: Dtls, laddr: TransportAddress): T =
  ## Creates the SCTP layer on top of `dtls`. It initialises usrsctp in
  ## thread-less mode with `laddr.port` and `sendCallback` as the output
  ## path, and starts the timer loop.
  let sctp = T(gotConnection: newAsyncEvent(),
               timersHandler: timersHandler(),
               dtls: dtls)
  usrsctp_init_nothreads(laddr.port.uint16, sendCallback, printf)
  discard usrsctp_sysctl_set_sctp_debug_on(SCTP_DEBUG_NONE)
  discard usrsctp_sysctl_set_sctp_ecn_enable(1)
  usrsctp_register_address(cast[pointer](sctp))
  return sctp

#proc new*(T: typedesc[Sctp], port: uint16 = 9899): T =
#  logScope: topics = "webrtc sctp"
#  let sctp = T(gotConnection: newAsyncEvent())
#  proc onReceive(udp: DatagramTransport, raddr: TransportAddress) {.async, gcsafe.} =
#    let
#      msg = udp.getMessage()
#      data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
#    if data != nil:
#      if sctp.isServer:
#        trace "onReceive (server)", data = data.packetPretty(), length = msg.len(), raddr
#      else:
#        trace "onReceive (client)", data = data.packetPretty(), length = msg.len(), raddr
#      usrsctp_freedumpbuffer(data)
#
#    if sctp.isServer:
#      sctp.sentAddress = raddr
#      usrsctp_conninput(cast[pointer](sctp), unsafeAddr msg[0], uint(msg.len), 0)
#    else:
#      let conn = await sctp.getOrCreateConnection(udp, raddr)
#      sctp.sentConnection = conn
#      sctp.sentAddress = raddr
#      usrsctp_conninput(cast[pointer](sctp), unsafeAddr msg[0], uint(msg.len), 0)
#  let
#    localAddr = TransportAddress(family: AddressFamily.IPv4, port: Port(port))
#    laddr = initTAddress("127.0.0.1:" & $port)
#    udp = newDatagramTransport(onReceive, local = laddr)
#  trace "local address", localAddr, laddr
#  sctp.udp = udp
#  sctp.timersHandler = timersHandler()
#
#  usrsctp_init_nothreads(port, sendCallback, printf)
#  discard usrsctp_sysctl_set_sctp_debug_on(SCTP_DEBUG_NONE)
#  discard usrsctp_sysctl_set_sctp_ecn_enable(1)
#  usrsctp_register_address(cast[pointer](sctp))
#
#  return sctp

proc stop*(self: Sctp) {.async.} =
  ## Shuts down usrsctp (`usrsctp_finish`) and closes the UDP transport
  ## if one was set.
  discard self.usrsctpAwait usrsctp_finish()
  # `udp` is not assigned by the DTLS-based constructor, so it may be nil.
  if not self.udp.isNil:
    self.udp.close()

proc readLoopProc(res: SctpConn) {.async.} =
  ## Feeds every record read from the DTLS connection into usrsctp via
  ## `usrsctp_conninput`, using `res` as the connection address. Loops
  ## until a read fails.
  ## NOTE(review): an empty record makes `msg[0]` raise and ends the loop.
  while true:
    trace "Read Loop Proc Before"
    let
      msg = await res.conn.read()
      data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len),
                                SCTP_DUMP_INBOUND)
    trace "Read Loop Proc After", isnil = data.isNil()
    if data != nil:
      trace "Receive connection", remoteAddress = res.conn.raddr,
                                  data = data.packetPretty()
      usrsctp_freedumpbuffer(data)
    res.sctp.sentConnection = res
    usrsctp_conninput(cast[pointer](res), unsafeAddr msg[0], uint(msg.len), 0)

proc accept*(self: Sctp): Future[SctpConn] {.async.} =
  ## Waits for a DTLS connection, registers it with usrsctp and starts its
  ## read loop. It then waits until the listening socket's upcall fires the
  ## connection's `acceptEvent`. Raises `SctpError` if `listen` was not
  ## called.
  if not self.isServer:
    raise newSctpError("Not a server")
  let
    dtlsConn = await self.dtls.accept()
    res = SctpConn.new(dtlsConn, self)
  usrsctp_register_address(cast[pointer](res))
  res.readLoop = res.readLoopProc()
  res.acceptEvent.clear()
  await res.acceptEvent.wait()
  return res

proc listen*(self: Sctp, sctpPort: uint16 = 5000) =
  ## Creates the non-blocking listening usrsctp socket on `sctpPort` and
  ## installs the accept upcall. A second call only logs a trace.
  if self.isServer:
    trace "Try to start the server twice"
    return
  self.isServer = true
  trace "Listening", sctpPort
  doAssert 0 == usrsctp_sysctl_set_sctp_blackhole(2)
  doAssert 0 == usrsctp_sysctl_set_sctp_no_csum_on_loopback(0)
  let sock = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP,
                            nil, nil, 0, nil)
  doAssert not sock.isNil, "usrsctp_socket failed"
  doAssert 0 == usrsctp_set_non_blocking(sock, 1)
  var sin: Sockaddr_in
  sin.sin_family = posix.AF_INET.uint16
  sin.sin_port = htons(sctpPort)
  sin.sin_addr.s_addr = htonl(INADDR_ANY)
  doAssert 0 == usrsctp_bind(sock, cast[ptr SockAddr](addr sin),
                             SockLen(sizeof(Sockaddr_in)))
  # usrsctp_listen returns 0 on success and -1 on error; the previous
  # `0 >= ...` check accepted the error value too.
  doAssert 0 == usrsctp_listen(sock, 1)
  doAssert 0 == sock.usrsctp_set_upcall(handleAccept, cast[pointer](self))
  self.sockServer = sock

proc connect*(self: Sctp,
              address: TransportAddress,
              sctpPort: uint16 = 5000): Future[SctpConn] {.async.} =
  ## Client connect is not implemented yet; the future completes with nil.
  discard

# proc connect*(self: Sctp,
#               address: TransportAddress,
#               sctpPort: uint16 = 5000): Future[SctpConn] {.async.} =
#   trace "Connect", address, sctpPort
#   let conn = await self.getOrCreateConnection(self.udp, address, sctpPort)
#   if conn.state == Connected:
#     return conn
#   try:
#     await conn.connectEvent.wait() # TODO: clear?
#   except CancelledError as exc:
#     conn.sctpSocket.usrsctp_close()
#     return nil
#   if conn.state != Connected:
#     raise newSctpError("Cannot connect to " & $address)
#   return conn