diff --git a/libp2p/builders.nim b/libp2p/builders.nim index c48370c2a..90f217ead 100644 --- a/libp2p/builders.nim +++ b/libp2p/builders.nim @@ -28,6 +28,8 @@ type maxConnsPerPeer: int protoVersion: Option[string] agentVersion: Option[string] + externalAddress: Option[MultiAddress] + addressProvider: Option[CurrentAddressProvider] proc init*(_: type[SwitchBuilder]): SwitchBuilder = SwitchBuilder( @@ -104,6 +106,14 @@ proc withAgentVersion*(b: SwitchBuilder, agentVersion: string): SwitchBuilder = b.agentVersion = some(agentVersion) b +proc withExternalAddress*(b: SwitchBuilder, address: MultiAddress): SwitchBuilder = + b.externalAddress = some(address) + b + +proc withAddressProvider*(b: SwitchBuilder, provider: CurrentAddressProvider): SwitchBuilder = + b.addressProvider = some(provider) + b + proc build*(b: SwitchBuilder): Switch = let inTimeout = b.inTimeout @@ -134,8 +144,13 @@ proc build*(b: SwitchBuilder): Switch = transports &= Transport(TcpTransport.init(b.tcpTransportFlags.get())) transports muxers = {MplexCodec: mplexProvider}.toTable - identify = newIdentify(peerInfo) - + identify = block: + if b.addressProvider.isSome(): + newIdentify(peerInfo, b.addressProvider.get()) + elif b.externalAddress.isSome(): + newIdentify(peerInfo, proc(): MultiAddress = b.externalAddress.get()) + else: + newIdentify(peerInfo) if b.secureManagers.len == 0: b.secureManagers &= SecureProtocol.Noise diff --git a/libp2p/protocols/identify.nim b/libp2p/protocols/identify.nim index 017a48258..67cb354ea 100644 --- a/libp2p/protocols/identify.nim +++ b/libp2p/protocols/identify.nim @@ -42,8 +42,13 @@ type agentVersion*: Option[string] protos*: seq[string] + CurrentAddressProvider* = + proc(): MultiAddress {.raises: [Defect], gcsafe.} + Identify* = ref object of LPProtocol peerInfo*: PeerInfo + provider*: CurrentAddressProvider + proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer = result = initProtoBuffer() @@ -101,16 +106,17 @@ proc decodeMsg*(buf: seq[byte]): Option[IdentifyInfo] = trace "decodeMsg: failed to decode received message" none[IdentifyInfo]() -proc newIdentify*(peerInfo: PeerInfo): Identify = +proc newIdentify*(peerInfo: PeerInfo, provider: CurrentAddressProvider = nil): Identify = new result result.peerInfo = peerInfo + result.provider = provider result.init() method init*(p: Identify) = proc handle(conn: Connection, proto: string) {.async, gcsafe, closure.} = try: trace "handling identify request", conn - var pb = encodeMsg(p.peerInfo, conn.observedAddr) + var pb = encodeMsg(p.peerInfo, (if not isNil(p.provider): p.provider() else: conn.observedAddr)) await conn.writeLp(pb.buffer) except CancelledError as exc: raise exc @@ -151,5 +157,5 @@ proc identify*(p: Identify, proc push*(p: Identify, conn: Connection) {.async.} = await conn.write(IdentifyPushCodec) - var pb = encodeMsg(p.peerInfo, conn.observedAddr) + var pb = encodeMsg(p.peerInfo, (if not isNil(p.provider): p.provider() else: conn.observedAddr)) await conn.writeLp(pb.buffer) diff --git a/tests/testidentify.nim b/tests/testidentify.nim index 5b87be552..3b27e6702 100644 --- a/tests/testidentify.nim +++ b/tests/testidentify.nim @@ -31,6 +31,7 @@ suite "Identify": msListen {.threadvar.}: MultistreamSelect msDial {.threadvar.}: MultistreamSelect conn {.threadvar.}: Connection + exposedAddr {.threadvar.}: MultiAddress asyncSetup: ma = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() @@ -41,7 +42,9 @@ suite "Identify": transport1 = TcpTransport.init() transport2 = TcpTransport.init() - identifyProto1 = newIdentify(remotePeerInfo) + exposedAddr = Multiaddress.init("/ip4/192.168.1.1/tcp/1337").get() + + identifyProto1 = newIdentify(remotePeerInfo, proc(): MultiAddress = exposedAddr) identifyProto2 = newIdentify(remotePeerInfo) msListen = newMultistream() @@ -92,6 +95,7 @@ suite "Identify": check id.pubKey.get() == remoteSecKey.getKey().get() check id.addrs[0] == ma + check id.observedAddr.get() == exposedAddr check id.protoVersion.get() == ProtoVersion check id.agentVersion.get() == customAgentVersion check id.protos == @["/test/proto1/1.0.0", "/test/proto2/1.0.0"] @@ -117,3 +121,23 @@ suite "Identify": let pi2 = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get()) discard await msDial.select(conn, IdentifyCodec) discard await identifyProto2.identify(conn, pi2) + + asyncTest "external address provider": + msListen.addHandler(IdentifyCodec, identifyProto1) + serverFut = transport1.start(ma) + proc acceptHandler(): Future[void] {.async, gcsafe.} = + let c = await transport1.accept() + await msListen.handle(c) + + acceptFut = acceptHandler() + conn = await transport2.dial(transport1.ma) + + discard await msDial.select(conn, IdentifyCodec) + let id = await identifyProto2.identify(conn, remotePeerInfo) + + check id.pubKey.get() == remoteSecKey.getKey().get() + check id.addrs[0] == ma + check id.observedAddr.get() == exposedAddr + check id.protoVersion.get() == ProtoVersion + check id.agentVersion.get() == AgentVersion + check id.protos == @["/test/proto1/1.0.0", "/test/proto2/1.0.0"]