Start working on #450

This commit is contained in:
Giovanni 2021-04-01 10:08:40 +09:00
parent cde444d490
commit 896cde52c6
3 changed files with 51 additions and 6 deletions

View File

@ -28,6 +28,8 @@ type
maxConnsPerPeer: int maxConnsPerPeer: int
protoVersion: Option[string] protoVersion: Option[string]
agentVersion: Option[string] agentVersion: Option[string]
externalAddress: Option[MultiAddress]
addressProvider: Option[CurrentAddressProvider]
proc init*(_: type[SwitchBuilder]): SwitchBuilder = proc init*(_: type[SwitchBuilder]): SwitchBuilder =
SwitchBuilder( SwitchBuilder(
@ -104,6 +106,14 @@ proc withAgentVersion*(b: SwitchBuilder, agentVersion: string): SwitchBuilder =
b.agentVersion = some(agentVersion) b.agentVersion = some(agentVersion)
b b
proc withExternalAddress*(b: SwitchBuilder, address: MultiAddress): SwitchBuilder =
b.externalAddress = some(address)
b
proc withAddressProvider*(b: SwitchBuilder, provider: CurrentAddressProvider): SwitchBuilder =
b.addressProvider = some(provider)
b
proc build*(b: SwitchBuilder): Switch = proc build*(b: SwitchBuilder): Switch =
let let
inTimeout = b.inTimeout inTimeout = b.inTimeout
@ -134,8 +144,13 @@ proc build*(b: SwitchBuilder): Switch =
transports &= Transport(TcpTransport.init(b.tcpTransportFlags.get())) transports &= Transport(TcpTransport.init(b.tcpTransportFlags.get()))
transports transports
muxers = {MplexCodec: mplexProvider}.toTable muxers = {MplexCodec: mplexProvider}.toTable
identify = newIdentify(peerInfo) identify = block:
if b.addressProvider.isSome():
newIdentify(peerInfo, b.addressProvider.get())
elif b.externalAddress.isSome():
newIdentify(peerInfo, proc(): MultiAddress = b.externalAddress.get())
else:
newIdentify(peerInfo)
if b.secureManagers.len == 0: if b.secureManagers.len == 0:
b.secureManagers &= SecureProtocol.Noise b.secureManagers &= SecureProtocol.Noise

View File

@ -42,8 +42,13 @@ type
agentVersion*: Option[string] agentVersion*: Option[string]
protos*: seq[string] protos*: seq[string]
CurrentAddressProvider* =
proc(): MultiAddress {.raises: [Defect], gcsafe.}
Identify* = ref object of LPProtocol Identify* = ref object of LPProtocol
peerInfo*: PeerInfo peerInfo*: PeerInfo
provider*: CurrentAddressProvider
proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer = proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer =
result = initProtoBuffer() result = initProtoBuffer()
@ -101,16 +106,17 @@ proc decodeMsg*(buf: seq[byte]): Option[IdentifyInfo] =
trace "decodeMsg: failed to decode received message" trace "decodeMsg: failed to decode received message"
none[IdentifyInfo]() none[IdentifyInfo]()
proc newIdentify*(peerInfo: PeerInfo): Identify = proc newIdentify*(peerInfo: PeerInfo, provider: CurrentAddressProvider = nil): Identify =
new result new result
result.peerInfo = peerInfo result.peerInfo = peerInfo
result.provider = provider
result.init() result.init()
method init*(p: Identify) = method init*(p: Identify) =
proc handle(conn: Connection, proto: string) {.async, gcsafe, closure.} = proc handle(conn: Connection, proto: string) {.async, gcsafe, closure.} =
try: try:
trace "handling identify request", conn trace "handling identify request", conn
var pb = encodeMsg(p.peerInfo, conn.observedAddr) var pb = encodeMsg(p.peerInfo, (if not isNil(p.provider): p.provider() else: conn.observedAddr))
await conn.writeLp(pb.buffer) await conn.writeLp(pb.buffer)
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
@ -151,5 +157,5 @@ proc identify*(p: Identify,
proc push*(p: Identify, conn: Connection) {.async.} = proc push*(p: Identify, conn: Connection) {.async.} =
await conn.write(IdentifyPushCodec) await conn.write(IdentifyPushCodec)
var pb = encodeMsg(p.peerInfo, conn.observedAddr) var pb = encodeMsg(p.peerInfo, (if not isNil(p.provider): p.provider() else: conn.observedAddr))
await conn.writeLp(pb.buffer) await conn.writeLp(pb.buffer)

View File

@ -31,6 +31,7 @@ suite "Identify":
msListen {.threadvar.}: MultistreamSelect msListen {.threadvar.}: MultistreamSelect
msDial {.threadvar.}: MultistreamSelect msDial {.threadvar.}: MultistreamSelect
conn {.threadvar.}: Connection conn {.threadvar.}: Connection
exposedAddr {.threadvar.}: MultiAddress
asyncSetup: asyncSetup:
ma = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() ma = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
@ -41,7 +42,9 @@ suite "Identify":
transport1 = TcpTransport.init() transport1 = TcpTransport.init()
transport2 = TcpTransport.init() transport2 = TcpTransport.init()
identifyProto1 = newIdentify(remotePeerInfo) exposedAddr = Multiaddress.init("/ip4/192.168.1.1/tcp/1337").get()
identifyProto1 = newIdentify(remotePeerInfo, proc(): MultiAddress = exposedAddr)
identifyProto2 = newIdentify(remotePeerInfo) identifyProto2 = newIdentify(remotePeerInfo)
msListen = newMultistream() msListen = newMultistream()
@ -92,6 +95,7 @@ suite "Identify":
check id.pubKey.get() == remoteSecKey.getKey().get() check id.pubKey.get() == remoteSecKey.getKey().get()
check id.addrs[0] == ma check id.addrs[0] == ma
check id.observedAddr.get() == exposedAddr
check id.protoVersion.get() == ProtoVersion check id.protoVersion.get() == ProtoVersion
check id.agentVersion.get() == customAgentVersion check id.agentVersion.get() == customAgentVersion
check id.protos == @["/test/proto1/1.0.0", "/test/proto2/1.0.0"] check id.protos == @["/test/proto1/1.0.0", "/test/proto2/1.0.0"]
@ -117,3 +121,23 @@ suite "Identify":
let pi2 = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get()) let pi2 = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get())
discard await msDial.select(conn, IdentifyCodec) discard await msDial.select(conn, IdentifyCodec)
discard await identifyProto2.identify(conn, pi2) discard await identifyProto2.identify(conn, pi2)
asyncTest "external address provider":
msListen.addHandler(IdentifyCodec, identifyProto1)
serverFut = transport1.start(ma)
proc acceptHandler(): Future[void] {.async, gcsafe.} =
let c = await transport1.accept()
await msListen.handle(c)
acceptFut = acceptHandler()
conn = await transport2.dial(transport1.ma)
discard await msDial.select(conn, IdentifyCodec)
let id = await identifyProto2.identify(conn, remotePeerInfo)
check id.pubKey.get() == remoteSecKey.getKey().get()
check id.addrs[0] == ma
check id.observedAddr.get() == exposedAddr
check id.protoVersion.get() == ProtoVersion
check id.agentVersion.get() == AgentVersion
check id.protos == @["/test/proto1/1.0.0", "/test/proto2/1.0.0"]