DTLS support
This commit is contained in:
parent
d5db29a8b7
commit
b33af62270
|
@ -0,0 +1,211 @@
|
|||
# Nim-WebRTC
|
||||
# Copyright (c) 2023 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import std/[openssl, os]
|
||||
import posix
|
||||
import chronos, chronicles
|
||||
import stew/byteutils
|
||||
|
||||
export chronicles
|
||||
|
||||
logScope:
|
||||
topics = "webrtc dtls"
|
||||
|
||||
# Missing openssl procs things
|
||||
const
|
||||
BIO_NOCLOSE = 0x0
|
||||
#BIO_CLOSE = 0x1
|
||||
BIO_CTRL_DGRAM_SET_CONNECTED = 32
|
||||
BIO_C_SET_FD = 104
|
||||
|
||||
proc DTLS_client_method(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc DTLS_server_method(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc BIO_new_dgram(fd: SocketHandle, closeFlag: int): BIO {.cdecl, dynlib: DLLUtilName, importc.}
|
||||
proc DTLSv1_listen(ssl: SslPtr, peer: ptr): int {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSL_CTX_set_cookie_generate_cb(ctx: SslCtx, cb: proc (ssl: SslPtr, cookie: ptr byte, cookieLen: ptr int): int {.cdecl.}) {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSL_CTX_set_cookie_verify_cb(ctx: SslCtx, cb: proc (ssl: SslPtr, cookie: ptr byte, cookieLen: ptr int): int {.cdecl.}) {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
# --- openssl
|
||||
|
||||
type
|
||||
DtlsSocket = ref object
|
||||
udp: DatagramTransport
|
||||
gotData: AsyncEvent
|
||||
sslCtx: SslCtx
|
||||
ctxIsView: bool
|
||||
ssl: SslPtr
|
||||
|
||||
template wrapSslCallRes(dtlsSocket, call: untyped): untyped =
|
||||
block:
|
||||
var err: type(call)
|
||||
while true:
|
||||
err = call
|
||||
if err <= 0:
|
||||
let openSslErr = SSL_get_error(dtlsSocket.ssl, cint(err))
|
||||
if openSslErr == SSL_ERROR_WANT_READ:
|
||||
dtlsSocket.gotData.clear()
|
||||
#TODO timeouts?
|
||||
await dtlsSocket.gotData.wait()#.withTimeout(1.seconds)
|
||||
continue
|
||||
elif openSslErr == SSL_ERROR_SYSCALL:
|
||||
let err = osLastError()
|
||||
if cint(err) == EAGAIN:
|
||||
dtlsSocket.gotData.clear()
|
||||
await dtlsSocket.gotData.wait()#.withTimeout(1.seconds)
|
||||
continue
|
||||
raiseTransportOsError(err)
|
||||
echo ERR_error_string(culong(ERR_peek_last_error()), nil)
|
||||
raise ValueError.newException("openssl error" & $openSslErr)
|
||||
break
|
||||
err
|
||||
|
||||
template wrapSslCall(dtlsSocket, call: untyped) =
|
||||
discard wrapSslCallRes(dtlsSocket, call)
|
||||
|
||||
proc generateSslCookie(ssl: SslPtr, cookie: ptr byte, cookieLen: ptr int): int {.cdecl.} =
|
||||
#TODO
|
||||
cookieLen[] = 30
|
||||
1
|
||||
|
||||
proc verifySslCookie(ssl: SslPtr, cookie: ptr byte, cookieLen: ptr int): int {.cdecl.} =
|
||||
#TODO
|
||||
1
|
||||
|
||||
proc createDtlsSocket(
|
||||
localAddress = AnyAddress,
|
||||
remoteAddress = AnyAddress,
|
||||
flags: set[ServerFlags] = {NoAutoRead}): DtlsSocket =
|
||||
|
||||
let gotData = newAsyncEvent()
|
||||
proc callback(transp: DatagramTransport, remote: TransportAddress) {.async.} = discard
|
||||
proc callback2(udata: pointer) =
|
||||
gotData.fire()
|
||||
let datagram = newDatagramTransport(
|
||||
callback,
|
||||
local = localAddress,
|
||||
remote = remoteAddress,
|
||||
flags = flags)
|
||||
addReader(datagram.fd, callback2)
|
||||
return DtlsSocket(udp: datagram, gotData: gotData)
|
||||
|
||||
|
||||
proc createDtlsServer*(host: TransportAddress): Future[DtlsSocket] {.async.} =
|
||||
result = createDtlsSocket(
|
||||
localAddress = host,
|
||||
flags = {NoAutoRead, ReuseAddr}
|
||||
)
|
||||
|
||||
result.sslCtx = SSL_CTX_new(DTLS_server_method())
|
||||
#TODO if we close the server with connections alive,
|
||||
#they have a ref to this ctx
|
||||
|
||||
#TODO handle certificates
|
||||
echo SSL_CTX_use_certificate_file(result.sslCtx, "certs/server-cert.pem", SSL_FILETYPE_PEM)
|
||||
echo SSL_CTX_use_PrivateKey_file(result.sslCtx, "certs/server-key.pem", SSL_FILETYPE_PEM)
|
||||
SSL_CTX_set_cookie_generate_cb(result.sslCtx, generateSslCookie)
|
||||
SSL_CTX_set_cookie_verify_cb(result.sslCtx, verifySslCookie)
|
||||
|
||||
proc accept*(sock: DtlsSocket): Future[DtlsSocket] {.async.} =
|
||||
let
|
||||
ctx = sock.sslCtx
|
||||
ssl = SSL_new(ctx)
|
||||
bio = BIO_new_dgram(SocketHandle(sock.udp.fd), BIO_NOCLOSE)
|
||||
|
||||
sslSetBio(ssl, bio, bio)
|
||||
|
||||
var
|
||||
clientSockAddr: Sockaddr_storage
|
||||
clientAddr: TransportAddress
|
||||
doAssert isNil(sock.ssl)
|
||||
sock.ssl = ssl
|
||||
wrapSslCall(sock, DTLSv1_listen(ssl, addr clientSockAddr))
|
||||
sock.ssl = nil
|
||||
let size =
|
||||
if int(clientSockAddr.ss_family) == ord(Domain.AF_INET):
|
||||
sizeof(Sockaddr_in)
|
||||
elif int(clientSockAddr.ss_family) == ord(Domain.AF_INET6):
|
||||
sizeof(Sockaddr_in6)
|
||||
elif int(clientSockAddr.ss_family) == ord(Domain.AF_UNIX):
|
||||
sizeof(Sockaddr_storage)
|
||||
else: doAssert(false); -1
|
||||
fromSAddr(addr clientSockAddr, SockLen(size), clientAddr)
|
||||
|
||||
# create new socket
|
||||
result = createDtlsSocket(
|
||||
localAddress = sock.udp.localAddress,
|
||||
remoteAddress = clientAddr,
|
||||
flags = {NoAutoRead, ReuseAddr}
|
||||
)
|
||||
|
||||
let sockHandle = SocketHandle(result.udp.fd)
|
||||
doAssert BIO_ctrl(bio, BIO_C_SET_FD, 0, cast[cstring](addr sockHandle)) > 0
|
||||
doAssert BIO_ctrl(bio, BIO_CTRL_DGRAM_SET_CONNECTED, 0, cast[cstring](addr clientSockAddr)) > 0
|
||||
|
||||
result.sslCtx = ctx
|
||||
result.ssl = ssl
|
||||
result.ctxIsView = true
|
||||
wrapSslCall(result, SSL_accept(ssl))
|
||||
|
||||
proc connect*(address: TransportAddress): Future[DtlsSocket] {.async.} =
|
||||
result = createDtlsSocket(
|
||||
remoteAddress = address
|
||||
)
|
||||
|
||||
let
|
||||
ctx = SSL_CTX_new(DTLS_client_method())
|
||||
ssl = SSL_new(ctx)
|
||||
bio = BIO_new_dgram(SocketHandle(result.udp.fd), BIO_NOCLOSE)
|
||||
|
||||
#TODO handle certs
|
||||
echo SSL_CTX_use_certificate_file(ctx, "certs/client-cert.pem", SSL_FILETYPE_PEM)
|
||||
echo SSL_CTX_use_PrivateKey_file(ctx, "certs/client-key.pem", SSL_FILETYPE_PEM)
|
||||
echo SSL_CTX_check_private_key(ctx)
|
||||
|
||||
result.sslCtx = ctx
|
||||
result.ssl = ssl
|
||||
var slen: SockLen
|
||||
var remoteSaddr: Sockaddr_storage
|
||||
toSAddr(address, remoteSaddr, slen)
|
||||
doAssert BIO_ctrl(bio, BIO_CTRL_DGRAM_SET_CONNECTED, 0, cast[cstring](addr remoteSaddr)) > 0
|
||||
sslSetBio(ssl, bio, bio)
|
||||
wrapSslCall(result, SSL_connect(ssl))
|
||||
|
||||
proc write*(sock: DtlsSocket, data: seq[byte]) {.async.} =
|
||||
wrapSslCall(sock, SSL_write(sock.ssl, cast[cstring](addr data[0]), data.len))
|
||||
|
||||
proc read*(sock: DtlsSocket): Future[seq[byte]] {.async.} =
|
||||
result = newSeq[byte](1000)
|
||||
let length = wrapSslCallRes(sock, SSL_read(sock.ssl, cast[cstring](addr result[0]), result.len))
|
||||
result.setLen(length)
|
||||
|
||||
proc close*(sock: DtlsSocket) {.async.} =
|
||||
if not isNil(sock.ssl):
|
||||
let shutdownRes = SSL_shutdown(sock.ssl)
|
||||
if shutdownRes == 0:
|
||||
wrapSslCall(sock, SSL_shutdown(sock.ssl))
|
||||
SSL_free(sock.ssl)
|
||||
if not isNil(sock.sslCtx) and not sock.ctxIsView:
|
||||
SSL_CTX_free(sock.sslCtx)
|
||||
sock.udp.close()
|
||||
|
||||
proc main {.async.} =
|
||||
let
|
||||
address = initTAddress("127.0.0.1:8090")
|
||||
server = await createDtlsServer(address)
|
||||
client = connect(address)
|
||||
|
||||
let
|
||||
servConn = await server.accept()
|
||||
clientConn = await client
|
||||
await clientConn.write("Hello world!".toBytes())
|
||||
echo string.fromBytes(await servConn.read())
|
||||
|
||||
await allFutures(servConn.close(), clientConn.close())
|
||||
await server.close()
|
||||
|
||||
waitFor(main())
|
Loading…
Reference in New Issue