diff --git a/libp2p/multiaddress.nim b/libp2p/multiaddress.nim index 75c118d19..14a32e7cc 100644 --- a/libp2p/multiaddress.nim +++ b/libp2p/multiaddress.nim @@ -416,6 +416,9 @@ const MAProtocol( mcodec: multiCodec("wss"), kind: Marker, size: 0 ), + MAProtocol( + mcodec: multiCodec("tls"), kind: Marker, size: 0 + ), MAProtocol( mcodec: multiCodec("ipfs"), kind: Length, size: 0, coder: TranscoderP2P @@ -468,7 +471,7 @@ const IP* = mapOr(IP4, IP6) DNS_OR_IP* = mapOr(DNS, IP) TCP_DNS* = mapAnd(DNS, mapEq("tcp")) - TCP_IP* =mapAnd(IP, mapEq("tcp")) + TCP_IP* = mapAnd(IP, mapEq("tcp")) TCP* = mapOr(TCP_DNS, TCP_IP) UDP_DNS* = mapAnd(DNS, mapEq("udp")) UDP_IP* = mapAnd(IP, mapEq("udp")) @@ -479,9 +482,10 @@ const WS_DNS* = mapAnd(TCP_DNS, mapEq("ws")) WS_IP* = mapAnd(TCP_IP, mapEq("ws")) WS* = mapAnd(TCP, mapEq("ws")) - WSS_DNS* = mapAnd(TCP_DNS, mapEq("wss")) - WSS_IP* = mapAnd(TCP_IP, mapEq("wss")) - WSS* = mapAnd(TCP, mapEq("wss")) + TLS_WS* = mapOr(mapEq("wss"), mapAnd(mapEq("tls"), mapEq("ws"))) + WSS_DNS* = mapAnd(TCP_DNS, TLS_WS) + WSS_IP* = mapAnd(TCP_IP, TLS_WS) + WSS* = mapAnd(TCP, TLS_WS) WebSockets_DNS* = mapOr(WS_DNS, WSS_DNS) WebSockets_IP* = mapOr(WS_IP, WSS_IP) WebSockets* = mapOr(WS, WSS) diff --git a/libp2p/multicodec.nim b/libp2p/multicodec.nim index 538e52ce8..184da5712 100644 --- a/libp2p/multicodec.nim +++ b/libp2p/multicodec.nim @@ -191,9 +191,10 @@ const MultiCodecList = [ ("p2p", 0x01A5), ("http", 0x01E0), ("https", 0x01BB), + ("tls", 0x01C0), ("quic", 0x01CC), ("ws", 0x01DD), - ("wss", 0x01DE), # not in multicodec list + ("wss", 0x01DE), ("p2p-websocket-star", 0x01DF), # not in multicodec list ("p2p-webrtc-star", 0x0113), # not in multicodec list ("p2p-webrtc-direct", 0x0114), # not in multicodec list diff --git a/libp2p/transports/wstransport.nim b/libp2p/transports/wstransport.nim index d77a532aa..f135d4977 100644 --- a/libp2p/transports/wstransport.nim +++ b/libp2p/transports/wstransport.nim @@ -156,8 +156,12 @@ method start*( self.httpservers &= httpserver - let codec = if isWss: - MultiAddress.init("/wss") + let codec = + if isWss: + if ma.contains(multiCodec("tls")) == MaResult[bool].ok(true): + MultiAddress.init("/tls/ws") + else: + MultiAddress.init("/wss") else: MultiAddress.init("/ws") diff --git a/tests/testmultiaddress.nim b/tests/testmultiaddress.nim index 1f98126c9..025909e48 100644 --- a/tests/testmultiaddress.nim +++ b/tests/testmultiaddress.nim @@ -64,6 +64,8 @@ const "/ip4/127.0.0.1/ipfs/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC/tcp/1234", "/ip4/127.0.0.1/p2p/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC", "/ip4/127.0.0.1/p2p/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC/tcp/1234", + "/ip4/127.0.0.1/tcp/8000/wss/p2p/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC", + "/ip4/127.0.0.1/tcp/8000/tls/ws/p2p/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC", "/unix/a/b/c/d/e", "/unix/stdio", "/ip4/1.2.3.4/tcp/80/unix/a/b/c/d/e/f", diff --git a/tests/testwstransport.nim b/tests/testwstransport.nim index b53f27308..d56fb36ba 100644 --- a/tests/testwstransport.nim +++ b/tests/testwstransport.nim @@ -86,7 +86,9 @@ suite "WebSocket transport": let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0/wss").tryGet()] let transport1 = WsTransport.new(Upgrade(), TLSPrivateKey.init(SecureKey), TLSCertificate.init(SecureCert), {TLSFlags.NoVerifyHost}) + const correctPattern = mapAnd(TCP, mapEq("wss")) await transport1.start(ma) + check correctPattern.match(transport1.addrs[0]) proc acceptHandler() {.async, gcsafe.} = while true: let conn = await transport1.accept() @@ -108,3 +110,21 @@ suite "WebSocket transport": await handlerWait.cancelAndWait() await transport1.stop() + + asyncTest "handles tls/ws": + let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0/tls/ws").tryGet()] + let transport1 = wsSecureTranspProvider() + const correctPattern = mapAnd(TCP, mapEq("tls"), mapEq("ws")) + await transport1.start(ma) + check transport1.handles(transport1.addrs[0]) + check correctPattern.match(transport1.addrs[0]) + + # Would raise somewhere if this wasn't handled: + let + inboundConn = transport1.accept() + outboundConn = await transport1.dial(transport1.addrs[0]) + closing = outboundConn.close() + await (await inboundConn).close() + await closing + + await transport1.stop()