first draft
This commit is contained in:
parent
474f3d30ad
commit
0fa09ba6f7
|
@ -84,14 +84,16 @@ proc init*(self: DtlsConn, conn: StunConn, laddr: TransportAddress) {.async.} =
|
|||
proc write*(self: DtlsConn, msg: seq[byte]) {.async.} =
|
||||
trace "Dtls write", length = msg.len()
|
||||
var buf = msg
|
||||
discard mbedtls_ssl_write(addr self.ssl, cast[ptr byte](addr buf[0]), buf.len().uint)
|
||||
# TODO: exception catching
|
||||
discard mb_ssl_write(self.ssl, buf)
|
||||
|
||||
proc read*(self: DtlsConn): Future[seq[byte]] {.async.} =
|
||||
var res = newSeq[byte](8192)
|
||||
while true:
|
||||
let tmp = await self.dataRecv.popFirst()
|
||||
self.dataRecv.addFirstNoWait(tmp)
|
||||
let length = mbedtls_ssl_read(addr self.ssl, cast[ptr byte](addr res[0]), res.len().uint)
|
||||
# TODO: exception catching
|
||||
let length = mb_ssl_read(self.ssl, res)
|
||||
if length == MBEDTLS_ERR_SSL_WANT_READ:
|
||||
continue
|
||||
if length < 0:
|
||||
|
@ -191,10 +193,37 @@ proc serverHandshake(self: DtlsConn) {.async.} =
|
|||
continue
|
||||
elif res != 0:
|
||||
raise newException(DtlsError, $(res.mbedtls_high_level_strerr()))
|
||||
# var remoteCertPtr = mbedtls_ssl_get_peer_cert(addr self.ssl)
|
||||
# let remoteCert = remoteCertPtr[]
|
||||
# self.remoteCert = newSeq[byte](remoteCert.raw.len)
|
||||
# copyMem(addr self.remoteCert[0], remoteCert.raw.p, remoteCert.raw.len)
|
||||
|
||||
proc clientHandshake(self: DtlsConn) {.async.} =
|
||||
var shouldRead = false
|
||||
while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER:
|
||||
if shouldRead:
|
||||
case self.raddr.family
|
||||
of AddressFamily.IPv4:
|
||||
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4)
|
||||
of AddressFamily.IPv6:
|
||||
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6)
|
||||
else:
|
||||
raise newException(DtlsError, "Remote address isn't an IP address")
|
||||
let tmp = await self.dataRecv.popFirst()
|
||||
self.dataRecv.addFirstNoWait(tmp)
|
||||
self.sendFuture = nil
|
||||
let res = mb_ssl_handshake_step(self.ssl)
|
||||
if not self.sendFuture.isNil(): await self.sendFuture
|
||||
shouldRead = false
|
||||
if res == MBEDTLS_ERR_SSL_WANT_WRITE:
|
||||
continue
|
||||
elif res == MBEDTLS_ERR_SSL_WANT_READ:
|
||||
# or self.ssl.private_state == MBEDTLS_SSL_SERVER_KEY_EXCHANGE:
|
||||
# TODO: Might need to check directly on mbedtls C source
|
||||
shouldRead = true
|
||||
continue
|
||||
elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
|
||||
mb_ssl_session_reset(self.ssl)
|
||||
shouldRead = true
|
||||
continue
|
||||
elif res != 0:
|
||||
raise newException(DtlsError, $(res.mbedtls_high_level_strerr()))
|
||||
|
||||
proc remoteCertificate*(conn: DtlsConn): seq[byte] =
|
||||
conn.remoteCert
|
||||
|
@ -245,8 +274,8 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
|
|||
mb_ssl_set_timer_cb(res.ssl, res.timer)
|
||||
mb_ssl_setup(res.ssl, res.config)
|
||||
mb_ssl_session_reset(res.ssl)
|
||||
mbedtls_ssl_set_verify(addr res.ssl, verify, cast[pointer](res))
|
||||
mbedtls_ssl_conf_authmode(addr res.config, MBEDTLS_SSL_VERIFY_OPTIONAL) # TODO: create template
|
||||
mb_ssl_set_verify(res.ssl, verify, res)
|
||||
mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
|
||||
mb_ssl_set_bio(res.ssl, cast[pointer](res),
|
||||
dtlsSend, dtlsRecv, nil)
|
||||
while true:
|
||||
|
@ -263,20 +292,46 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
|
|||
continue
|
||||
return res
|
||||
|
||||
proc dial*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} =
|
||||
discard
|
||||
proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} =
|
||||
var
|
||||
selfvar = self
|
||||
res = DtlsConn()
|
||||
|
||||
#import ../udp_connection
|
||||
#import stew/byteutils
|
||||
#proc main() {.async.} =
|
||||
# let laddr = initTAddress("127.0.0.1:4433")
|
||||
# let udp = UdpConn()
|
||||
# await udp.init(laddr)
|
||||
# let stun = StunConn()
|
||||
# await stun.init(udp, laddr)
|
||||
# let dtls = Dtls()
|
||||
# dtls.start(stun, laddr)
|
||||
# let x = await dtls.accept()
|
||||
# echo "Recv: <", string.fromBytes(await x.read()), ">"
|
||||
#
|
||||
#waitFor(main())
|
||||
await res.init(self.conn, self.laddr)
|
||||
mb_ssl_init(res.ssl)
|
||||
mb_ssl_config_init(res.config)
|
||||
|
||||
var pkey = res.ctr_drbg.generateKey()
|
||||
var srvcert = res.ctr_drbg.generateCertificate(pkey)
|
||||
res.localCert = newSeq[byte](srvcert.raw.len)
|
||||
copyMem(addr res.localCert[0], srvcert.raw.p, srvcert.raw.len)
|
||||
|
||||
mb_ctr_drbg_init(res.ctr_drbg)
|
||||
mb_entropy_init(res.entropy)
|
||||
mb_ctr_drbg_seed(res.ctr_drbg, mbedtls_entropy_func, res.entropy, nil, 0)
|
||||
|
||||
mb_ssl_config_defaults(res.config,
|
||||
MBEDTLS_SSL_IS_CLIENT,
|
||||
MBEDTLS_SSL_TRANSPORT_DATAGRAM,
|
||||
MBEDTLS_SSL_PRESET_DEFAULT)
|
||||
mb_ssl_conf_rng(res.config, mbedtls_ctr_drbg_random, res.ctr_drbg)
|
||||
mb_ssl_conf_read_timeout(res.config, 10000) # in milliseconds
|
||||
mb_ssl_conf_ca_chain(res.config, srvcert.next, nil)
|
||||
mb_ssl_set_timer_cb(res.ssl, res.timer)
|
||||
mb_ssl_setup(res.ssl, res.config)
|
||||
mb_ssl_set_verify(res.ssl, verify, res)
|
||||
mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
|
||||
mb_ssl_set_bio(res.ssl, cast[pointer](res),
|
||||
dtlsSend, dtlsRecv, nil)
|
||||
|
||||
res.raddr = raddr
|
||||
self.connections[raddr] = res
|
||||
|
||||
try:
|
||||
await res.clientHandshake()
|
||||
except CatchableError as exc:
|
||||
trace "Handshake fail", remoteAddress = raddr, error = exc.msg
|
||||
self.connections.del(raddr)
|
||||
raise exc
|
||||
|
||||
return res
|
||||
|
|
|
@ -398,7 +398,7 @@ proc readLoopProc(res: SctpConn) {.async.} =
|
|||
proc accept*(self: Sctp): Future[SctpConn] {.async.} =
|
||||
if not self.isServer:
|
||||
raise newSctpError("Not a server")
|
||||
var res = SctpConn.new(await self.dtls.accept, self)
|
||||
var res = SctpConn.new(await self.dtls.accept(), self)
|
||||
usrsctp_register_address(cast[pointer](res))
|
||||
res.readLoop = res.readLoopProc()
|
||||
res.acceptEvent.clear()
|
||||
|
@ -429,20 +429,21 @@ proc listen*(self: Sctp, sctpPort: uint16 = 5000) =
|
|||
proc connect*(self: Sctp,
|
||||
address: TransportAddress,
|
||||
sctpPort: uint16 = 5000): Future[SctpConn] {.async.} =
|
||||
discard
|
||||
let
|
||||
sctpSocket = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
|
||||
res = SctpConn.new(await self.dtls.connect(address), self)
|
||||
|
||||
# proc connect*(self: Sctp,
|
||||
# address: TransportAddress,
|
||||
# sctpPort: uint16 = 5000): Future[SctpConn] {.async.} =
|
||||
# trace "Connect", address, sctpPort
|
||||
# let conn = await self.getOrCreateConnection(self.udp, address, sctpPort)
|
||||
# if conn.state == Connected:
|
||||
# return conn
|
||||
# try:
|
||||
# await conn.connectEvent.wait() # TODO: clear?
|
||||
# except CancelledError as exc:
|
||||
# conn.sctpSocket.usrsctp_close()
|
||||
# return nil
|
||||
# if conn.state != Connected:
|
||||
# raise newSctpError("Cannot connect to " & $address)
|
||||
# return conn
|
||||
#usrsctp_register_address(cast[pointer](res))
|
||||
|
||||
# trace "Connect", address, sctpPort
|
||||
# let conn = await self.getOrCreateConnection(self.udp, address, sctpPort)
|
||||
# if conn.state == Connected:
|
||||
# return conn
|
||||
# try:
|
||||
# await conn.connectEvent.wait() # TODO: clear?
|
||||
# except CancelledError as exc:
|
||||
# conn.sctpSocket.usrsctp_close()
|
||||
# return nil
|
||||
# if conn.state != Connected:
|
||||
# raise newSctpError("Cannot connect to " & $address)
|
||||
# return conn
|
||||
|
|
Loading…
Reference in New Issue