A lot of fixes
This commit is contained in:
parent
9a6657922a
commit
2591a158ba
|
@ -72,13 +72,17 @@ type
|
||||||
DataChannelConnection* = ref object
|
DataChannelConnection* = ref object
|
||||||
readLoopFut: Future[void]
|
readLoopFut: Future[void]
|
||||||
streams: Table[uint16, DataChannelStream]
|
streams: Table[uint16, DataChannelStream]
|
||||||
|
streamId: uint16
|
||||||
conn*: SctpConn
|
conn*: SctpConn
|
||||||
incomingStreams: AsyncQueue[DataChannelStream]
|
incomingStreams: AsyncQueue[DataChannelStream]
|
||||||
|
|
||||||
proc read*(stream: DataChannelStream): Future[seq[byte]] {.async.} =
|
proc read*(stream: DataChannelStream): Future[seq[byte]] {.async.} =
|
||||||
return await stream.receivedData.popLast()
|
let x = await stream.receivedData.popFirst()
|
||||||
|
trace "read", length=x.len(), id=stream.id
|
||||||
|
return x
|
||||||
|
|
||||||
proc write*(stream: DataChannelStream, buf: seq[byte]) {.async.} =
|
proc write*(stream: DataChannelStream, buf: seq[byte]) {.async.} =
|
||||||
|
trace "write", length=buf.len(), id=stream.id
|
||||||
var
|
var
|
||||||
sendInfo = SctpMessageParameters(
|
sendInfo = SctpMessageParameters(
|
||||||
streamId: stream.id,
|
streamId: stream.id,
|
||||||
|
@ -105,14 +109,23 @@ proc sendControlMessage(stream: DataChannelStream, msg: DataChannelMessage) {.as
|
||||||
endOfRecord: true,
|
endOfRecord: true,
|
||||||
protocolId: uint32(WebRtcDcep)
|
protocolId: uint32(WebRtcDcep)
|
||||||
)
|
)
|
||||||
|
trace "send control message", msg
|
||||||
|
|
||||||
await stream.conn.write(encoded, sendInfo)
|
await stream.conn.write(encoded, sendInfo)
|
||||||
|
|
||||||
proc openStream*(
|
proc openStream*(
|
||||||
conn: DataChannelConnection,
|
conn: DataChannelConnection,
|
||||||
streamId: uint16,
|
noiseHandshake: bool,
|
||||||
reliability = Reliable, reliabilityParameter: uint32 = 0): Future[DataChannelStream] {.async.} =
|
reliability = Reliable, reliabilityParameter: uint32 = 0): Future[DataChannelStream] {.async.} =
|
||||||
|
let streamId: uint16 =
|
||||||
|
if not noiseHandshake:
|
||||||
|
let res = conn.streamId
|
||||||
|
conn.streamId += 2
|
||||||
|
res
|
||||||
|
else:
|
||||||
|
0
|
||||||
|
|
||||||
|
trace "open stream", streamId
|
||||||
if reliability in [Reliable, ReliableUnordered] and reliabilityParameter != 0:
|
if reliability in [Reliable, ReliableUnordered] and reliabilityParameter != 0:
|
||||||
raise newException(ValueError, "reliabilityParameter should be 0")
|
raise newException(ValueError, "reliabilityParameter should be 0")
|
||||||
|
|
||||||
|
@ -144,6 +157,7 @@ proc openStream*(
|
||||||
|
|
||||||
proc handleData(conn: DataChannelConnection, msg: SctpMessage) =
|
proc handleData(conn: DataChannelConnection, msg: SctpMessage) =
|
||||||
let streamId = msg.params.streamId
|
let streamId = msg.params.streamId
|
||||||
|
trace "handle data message", streamId, ppid = msg.params.protocolId, data = msg.data
|
||||||
|
|
||||||
if streamId notin conn.streams:
|
if streamId notin conn.streams:
|
||||||
raise newException(ValueError, "got data for unknown streamid")
|
raise newException(ValueError, "got data for unknown streamid")
|
||||||
|
@ -162,6 +176,7 @@ proc handleControl(conn: DataChannelConnection, msg: SctpMessage) {.async.} =
|
||||||
decoded = Binary.decode(msg.data, DataChannelMessage)
|
decoded = Binary.decode(msg.data, DataChannelMessage)
|
||||||
streamId = msg.params.streamId
|
streamId = msg.params.streamId
|
||||||
|
|
||||||
|
trace "handle control message", decoded, streamId = msg.params.streamId
|
||||||
if decoded.messageType == Ack:
|
if decoded.messageType == Ack:
|
||||||
if streamId notin conn.streams:
|
if streamId notin conn.streams:
|
||||||
raise newException(ValueError, "got ack for unknown streamid")
|
raise newException(ValueError, "got ack for unknown streamid")
|
||||||
|
@ -178,6 +193,7 @@ proc handleControl(conn: DataChannelConnection, msg: SctpMessage) {.async.} =
|
||||||
)
|
)
|
||||||
|
|
||||||
conn.streams[streamId] = stream
|
conn.streams[streamId] = stream
|
||||||
|
conn.incomingStreams.addLastNoWait(stream)
|
||||||
|
|
||||||
await stream.sendControlMessage(DataChannelMessage(messageType: Ack))
|
await stream.sendControlMessage(DataChannelMessage(messageType: Ack))
|
||||||
|
|
||||||
|
@ -185,6 +201,7 @@ proc readLoop(conn: DataChannelConnection) {.async.} =
|
||||||
try:
|
try:
|
||||||
while true:
|
while true:
|
||||||
let message = await conn.conn.read()
|
let message = await conn.conn.read()
|
||||||
|
# TODO: might be necessary to check the others protocolId at some point
|
||||||
if message.params.protocolId == uint32(WebRtcDcep):
|
if message.params.protocolId == uint32(WebRtcDcep):
|
||||||
#TODO should we really await?
|
#TODO should we really await?
|
||||||
await conn.handleControl(message)
|
await conn.handleControl(message)
|
||||||
|
@ -195,12 +212,12 @@ proc readLoop(conn: DataChannelConnection) {.async.} =
|
||||||
discard
|
discard
|
||||||
|
|
||||||
proc accept*(conn: DataChannelConnection): Future[DataChannelStream] {.async.} =
|
proc accept*(conn: DataChannelConnection): Future[DataChannelStream] {.async.} =
|
||||||
if isNil(conn.readLoopFut):
|
|
||||||
conn.readLoopFut = conn.readLoop()
|
|
||||||
return await conn.incomingStreams.popFirst()
|
return await conn.incomingStreams.popFirst()
|
||||||
|
|
||||||
proc new*(_: type DataChannelConnection, conn: SctpConn): DataChannelConnection =
|
proc new*(_: type DataChannelConnection, conn: SctpConn): DataChannelConnection =
|
||||||
DataChannelConnection(
|
result = DataChannelConnection(
|
||||||
conn: conn,
|
conn: conn,
|
||||||
incomingStreams: newAsyncQueue[DataChannelStream]()
|
incomingStreams: newAsyncQueue[DataChannelStream](),
|
||||||
|
streamId: 1'u16 # TODO: Serveur == 1, client == 2
|
||||||
)
|
)
|
||||||
|
conn.readLoopFut = conn.readLoop()
|
||||||
|
|
|
@ -8,9 +8,10 @@
|
||||||
# those terms.
|
# those terms.
|
||||||
|
|
||||||
import tables, bitops, posix, strutils, sequtils
|
import tables, bitops, posix, strutils, sequtils
|
||||||
import chronos, chronicles, stew/[ranges/ptr_arith, byteutils]
|
import chronos, chronicles, stew/[ranges/ptr_arith, byteutils, endians2]
|
||||||
import usrsctp
|
import usrsctp
|
||||||
import dtls/dtls
|
import dtls/dtls
|
||||||
|
import binary_serialization
|
||||||
|
|
||||||
export chronicles
|
export chronicles
|
||||||
|
|
||||||
|
@ -37,7 +38,7 @@ type
|
||||||
|
|
||||||
SctpMessage* = ref object
|
SctpMessage* = ref object
|
||||||
data*: seq[byte]
|
data*: seq[byte]
|
||||||
info: sctp_rcvinfo
|
info: sctp_recvv_rn
|
||||||
params*: SctpMessageParameters
|
params*: SctpMessageParameters
|
||||||
|
|
||||||
SctpConn* = ref object
|
SctpConn* = ref object
|
||||||
|
@ -67,6 +68,19 @@ type
|
||||||
sentAddress: TransportAddress
|
sentAddress: TransportAddress
|
||||||
sentFuture: Future[void]
|
sentFuture: Future[void]
|
||||||
|
|
||||||
|
# Those two objects are only here for debugging purpose
|
||||||
|
SctpChunk = object
|
||||||
|
chunkType: uint8
|
||||||
|
flag: uint8
|
||||||
|
length {.bin_value: it.data.len() + 4.}: uint16
|
||||||
|
data {.bin_len: it.length - 4.}: seq[byte]
|
||||||
|
|
||||||
|
SctpPacketStructure = object
|
||||||
|
srcPort: uint16
|
||||||
|
dstPort: uint16
|
||||||
|
verifTag: uint32
|
||||||
|
checksum: uint32
|
||||||
|
|
||||||
const
|
const
|
||||||
IPPROTO_SCTP = 132
|
IPPROTO_SCTP = 132
|
||||||
|
|
||||||
|
@ -86,6 +100,19 @@ template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped =
|
||||||
proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".}
|
proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".}
|
||||||
proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.}
|
proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.}
|
||||||
|
|
||||||
|
proc printSctpPacket(buffer: seq[byte]) =
|
||||||
|
let s = Binary.decode(buffer, SctpPacketStructure)
|
||||||
|
echo " => \e[31;1mStructure\e[0m: ", s
|
||||||
|
var size = sizeof(SctpPacketStructure)
|
||||||
|
var i = 1
|
||||||
|
while size < buffer.len:
|
||||||
|
let c = Binary.decode(buffer[size..^1], SctpChunk)
|
||||||
|
echo " ===> \e[32;1mChunk ", i, "\e[0m ", c
|
||||||
|
i.inc()
|
||||||
|
size.inc(c.length.int)
|
||||||
|
while size mod 4 != 0:
|
||||||
|
size.inc()
|
||||||
|
|
||||||
proc packetPretty(packet: cstring): string =
|
proc packetPretty(packet: cstring): string =
|
||||||
let data = $packet
|
let data = $packet
|
||||||
let ctn = data[23..^16]
|
let ctn = data[23..^16]
|
||||||
|
@ -137,22 +164,23 @@ proc write*(
|
||||||
self.sctp.sentConnection = self
|
self.sctp.sentConnection = self
|
||||||
self.sctp.sentAddress = self.address
|
self.sctp.sentAddress = self.address
|
||||||
|
|
||||||
let
|
var cpy = buf
|
||||||
|
var
|
||||||
(sendInfo, infoType) =
|
(sendInfo, infoType) =
|
||||||
if sendParams != default(SctpMessageParameters):
|
if sendParams != default(SctpMessageParameters):
|
||||||
(sctp_sndinfo(
|
(sctp_sndinfo(
|
||||||
snd_sid: sendParams.streamId,
|
snd_sid: sendParams.streamId,
|
||||||
snd_ppid: sendParams.protocolId,
|
snd_ppid: sendParams.protocolId.swapBytes(),
|
||||||
snd_flags: sendParams.toFlags
|
snd_flags: sendParams.toFlags
|
||||||
), cuint(SCTP_SENDV_SNDINFO))
|
), cuint(SCTP_SENDV_SNDINFO))
|
||||||
else:
|
else:
|
||||||
(default(sctp_sndinfo), cuint(SCTP_SENDV_NOINFO))
|
(default(sctp_sndinfo), cuint(SCTP_SENDV_NOINFO))
|
||||||
sendvErr = self.usrsctpAwait:
|
sendvErr = self.usrsctpAwait:
|
||||||
self.sctpSocket.usrsctp_sendv(unsafeAddr buf[0], buf.len.uint,
|
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len.uint, nil, 0,
|
||||||
nil, 0, unsafeAddr sendInfo, sizeof(sendInfo).SockLen,
|
cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen,
|
||||||
infoType, 0)
|
infoType, 0)
|
||||||
if sendvErr < 0:
|
if sendvErr < 0:
|
||||||
perror("usrsctp_sendv")
|
perror("usrsctp_sendv") # TODO: throw an exception
|
||||||
trace "write sendv error?", sendvErr, sendParams
|
trace "write sendv error?", sendvErr, sendParams
|
||||||
|
|
||||||
proc write*(self: SctpConn, s: string) {.async.} =
|
proc write*(self: SctpConn, s: string) {.async.} =
|
||||||
|
@ -182,7 +210,7 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
|
||||||
address: Sockaddr_storage
|
address: Sockaddr_storage
|
||||||
rn: sctp_recvv_rn
|
rn: sctp_recvv_rn
|
||||||
addressLen = sizeof(Sockaddr_storage).SockLen
|
addressLen = sizeof(Sockaddr_storage).SockLen
|
||||||
rnLen = sizeof(message.info).SockLen
|
rnLen = sizeof(sctp_recvv_rn).SockLen
|
||||||
infotype: uint
|
infotype: uint
|
||||||
flags: int
|
flags: int
|
||||||
trace "recv from", sockuint64=cast[uint64](sock)
|
trace "recv from", sockuint64=cast[uint64](sock)
|
||||||
|
@ -197,11 +225,12 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
|
||||||
perror("usrsctp_recvv")
|
perror("usrsctp_recvv")
|
||||||
return
|
return
|
||||||
elif n > 0:
|
elif n > 0:
|
||||||
if infotype == SCTP_RECVV_RCVINFO:
|
# It might be necessary to check if infotype == SCTP_RECVV_RCVINFO
|
||||||
|
message.data.delete(n..<message.data.len())
|
||||||
|
trace "message info from handle upcall", msginfo = message.info
|
||||||
message.params = SctpMessageParameters(
|
message.params = SctpMessageParameters(
|
||||||
#TODO endianness?
|
protocolId: message.info.recvv_rcvinfo.rcv_ppid.swapBytes(),
|
||||||
protocolId: message.info.rcv_ppid,
|
streamId: message.info.recvv_rcvinfo.rcv_sid
|
||||||
streamId: message.info.rcv_sid
|
|
||||||
)
|
)
|
||||||
if bitand(flags, MSG_NOTIFICATION) != 0:
|
if bitand(flags, MSG_NOTIFICATION) != 0:
|
||||||
trace "Notification received", length = n
|
trace "Notification received", length = n
|
||||||
|
@ -229,11 +258,12 @@ proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
|
||||||
conn.sctpSocket = sctpSocket
|
conn.sctpSocket = sctpSocket
|
||||||
conn.state = Connected
|
conn.state = Connected
|
||||||
var nodelay: uint32 = 1
|
var nodelay: uint32 = 1
|
||||||
|
var recvinfo: uint32 = 1
|
||||||
doAssert 0 == conn.sctpSocket.usrsctp_set_upcall(handleUpcall, cast[pointer](conn))
|
doAssert 0 == conn.sctpSocket.usrsctp_set_upcall(handleUpcall, cast[pointer](conn))
|
||||||
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP,
|
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY,
|
||||||
SCTP_NODELAY,
|
addr nodelay, sizeof(nodelay).SockLen)
|
||||||
addr nodelay,
|
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO,
|
||||||
sizeof(nodelay).SockLen)
|
addr recvinfo, sizeof(recvinfo).SockLen)
|
||||||
conn.acceptEvent.fire()
|
conn.acceptEvent.fire()
|
||||||
|
|
||||||
# proc getOrCreateConnection(self: Sctp,
|
# proc getOrCreateConnection(self: Sctp,
|
||||||
|
@ -276,10 +306,11 @@ proc sendCallback(ctx: pointer,
|
||||||
trace "sendCallback", data = data.packetPretty(), length
|
trace "sendCallback", data = data.packetPretty(), length
|
||||||
usrsctp_freedumpbuffer(data)
|
usrsctp_freedumpbuffer(data)
|
||||||
let sctpConn = cast[SctpConn](ctx)
|
let sctpConn = cast[SctpConn](ctx)
|
||||||
|
let buf = @(buffer.makeOpenArray(byte, int(length)))
|
||||||
proc testSend() {.async.} =
|
proc testSend() {.async.} =
|
||||||
try:
|
try:
|
||||||
let buf = @(buffer.makeOpenArray(byte, int(length)))
|
|
||||||
trace "Send To", address = sctpConn.address
|
trace "Send To", address = sctpConn.address
|
||||||
|
# TODO: defined it printSctpPacket(buf)
|
||||||
await sctpConn.conn.write(buf)
|
await sctpConn.conn.write(buf)
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
trace "Send Failed", message = exc.msg
|
trace "Send Failed", message = exc.msg
|
||||||
|
@ -355,11 +386,9 @@ proc stop*(self: Sctp) {.async.} =
|
||||||
|
|
||||||
proc readLoopProc(res: SctpConn) {.async.} =
|
proc readLoopProc(res: SctpConn) {.async.} =
|
||||||
while true:
|
while true:
|
||||||
trace "Read Loop Proc Before"
|
|
||||||
let
|
let
|
||||||
msg = await res.conn.read()
|
msg = await res.conn.read()
|
||||||
data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
|
data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
|
||||||
trace "Read Loop Proc Before", isnil=data.isNil()
|
|
||||||
if not data.isNil():
|
if not data.isNil():
|
||||||
trace "Receive data", remoteAddress = res.conn.raddr, data = data.packetPretty()
|
trace "Receive data", remoteAddress = res.conn.raddr, data = data.packetPretty()
|
||||||
usrsctp_freedumpbuffer(data)
|
usrsctp_freedumpbuffer(data)
|
||||||
|
@ -384,6 +413,7 @@ proc listen*(self: Sctp, sctpPort: uint16 = 5000) =
|
||||||
trace "Listening", sctpPort
|
trace "Listening", sctpPort
|
||||||
doAssert 0 == usrsctp_sysctl_set_sctp_blackhole(2)
|
doAssert 0 == usrsctp_sysctl_set_sctp_blackhole(2)
|
||||||
doAssert 0 == usrsctp_sysctl_set_sctp_no_csum_on_loopback(0)
|
doAssert 0 == usrsctp_sysctl_set_sctp_no_csum_on_loopback(0)
|
||||||
|
doAssert 0 == usrsctp_sysctl_set_sctp_delayed_sack_time_default(0)
|
||||||
let sock = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
|
let sock = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
|
||||||
var on: int = 1
|
var on: int = 1
|
||||||
doAssert 0 == usrsctp_set_non_blocking(sock, 1)
|
doAssert 0 == usrsctp_set_non_blocking(sock, 1)
|
||||||
|
|
|
@ -21,6 +21,7 @@ proc handles(self: StunConn) {.async.} =
|
||||||
while true: # TODO: while not self.conn.atEof()
|
while true: # TODO: while not self.conn.atEof()
|
||||||
let (msg, raddr) = await self.conn.read()
|
let (msg, raddr) = await self.conn.read()
|
||||||
if Stun.isMessage(msg):
|
if Stun.isMessage(msg):
|
||||||
|
echo "\e[35;1m<STUN>\e[0m"
|
||||||
let res = Stun.getResponse(msg, self.laddr)
|
let res = Stun.getResponse(msg, self.laddr)
|
||||||
if res.isSome():
|
if res.isSome():
|
||||||
await self.conn.write(raddr, res.get())
|
await self.conn.write(raddr, res.get())
|
||||||
|
|
|
@ -24,7 +24,7 @@ proc init*(self: UdpConn, laddr: TransportAddress) =
|
||||||
|
|
||||||
proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} =
|
proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} =
|
||||||
let msg = udp.getMessage()
|
let msg = udp.getMessage()
|
||||||
echo "\e[33m<UDP>\e[0;1m onReceive\e[0m: ", msg.len()
|
echo "\e[33m<UDP>\e[0;1m onReceive\e[0m"
|
||||||
self.dataRecv.addLastNoWait((msg, address))
|
self.dataRecv.addLastNoWait((msg, address))
|
||||||
|
|
||||||
self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()
|
self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()
|
||||||
|
|
Loading…
Reference in New Issue