use a Transport.serverFlags attribute

This commit is contained in:
Ștefan Talpalaru 2020-05-05 17:55:02 +02:00 committed by Dmitriy Ryajov
parent c480e65055
commit 313f9b0952
4 changed files with 13 additions and 15 deletions

View File

@ -6,6 +6,7 @@ const
import import
options, tables, options, tables,
chronos,
switch, peer, peerinfo, connection, multiaddress, switch, peer, peerinfo, connection, multiaddress,
crypto/crypto, transports/[transport, tcptransport], crypto/crypto, transports/[transport, tcptransport],
muxers/[muxer, mplex/mplex, mplex/types], muxers/[muxer, mplex/mplex, mplex/types],
@ -18,14 +19,15 @@ else:
import protocols/secure/secio import protocols/secure/secio
export export
switch, peer, peerinfo, connection, multiaddress, crypto switch, peer, peerinfo, connection, multiaddress, crypto, ServerFlags
proc newStandardSwitch*(privKey = none(PrivateKey), proc newStandardSwitch*(privKey = none(PrivateKey),
address = MultiAddress.init("/ip4/127.0.0.1/tcp/0"), address = MultiAddress.init("/ip4/127.0.0.1/tcp/0"),
triggerSelf = false, triggerSelf = false,
gossip = false, gossip = false,
verifySignature = libp2p_pubsub_verify, verifySignature = libp2p_pubsub_verify,
sign = libp2p_pubsub_sign): Switch = sign = libp2p_pubsub_sign,
serverFlags: set[ServerFlags] = {}): Switch =
proc createMplex(conn: Connection): Muxer = proc createMplex(conn: Connection): Muxer =
result = newMplex(conn) result = newMplex(conn)
@ -33,7 +35,7 @@ proc newStandardSwitch*(privKey = none(PrivateKey),
seckey = privKey.get(otherwise = PrivateKey.random(ECDSA)) seckey = privKey.get(otherwise = PrivateKey.random(ECDSA))
peerInfo = PeerInfo.init(seckey, [address]) peerInfo = PeerInfo.init(seckey, [address])
mplexProvider = newMuxerProvider(createMplex, MplexCodec) mplexProvider = newMuxerProvider(createMplex, MplexCodec)
transports = @[Transport(newTransport(TcpTransport))] transports = @[Transport(newTransport(TcpTransport, serverFlags))]
muxers = {MplexCodec: mplexProvider}.toTable muxers = {MplexCodec: mplexProvider}.toTable
identify = newIdentify(peerInfo) identify = newIdentify(peerInfo)
when libp2p_secure == "noise": when libp2p_secure == "noise":

View File

@ -22,8 +22,6 @@ import connection,
errors, errors,
peer peer
export ServerFlags
logScope: logScope:
topic = "Switch" topic = "Switch"
@ -284,7 +282,7 @@ proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} =
s.ms.addHandler(proto.codec, proto) s.ms.addHandler(proto.codec, proto)
proc start*(s: Switch, serverFlags: set[ServerFlags] = {}): Future[seq[Future[void]]] {.async, gcsafe.} = proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
trace "starting switch" trace "starting switch"
proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} = proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} =
@ -300,7 +298,7 @@ proc start*(s: Switch, serverFlags: set[ServerFlags] = {}): Future[seq[Future[vo
for t in s.transports: # for each transport for t in s.transports: # for each transport
for i, a in s.peerInfo.addrs: for i, a in s.peerInfo.addrs:
if t.handles(a): # check if it handles the multiaddr if t.handles(a): # check if it handles the multiaddr
var server = await t.listen(a, handle, serverFlags) var server = await t.listen(a, handle)
s.peerInfo.addrs[i] = t.ma # update peer's address s.peerInfo.addrs[i] = t.ma # update peer's address
startFuts.add(server) startFuts.add(server)

View File

@ -16,8 +16,6 @@ import transport,
../multicodec, ../multicodec,
../stream/chronosstream ../stream/chronosstream
export ServerFlags
logScope: logScope:
topic = "TcpTransport" topic = "TcpTransport"
@ -124,13 +122,12 @@ method close*(t: TcpTransport) {.async, gcsafe.} =
method listen*(t: TcpTransport, method listen*(t: TcpTransport,
ma: MultiAddress, ma: MultiAddress,
handler: ConnHandler, handler: ConnHandler):
serverFlags: set[ServerFlags] = {}):
Future[Future[void]] {.async, gcsafe.} = Future[Future[void]] {.async, gcsafe.} =
discard await procCall Transport(t).listen(ma, handler) # call base discard await procCall Transport(t).listen(ma, handler) # call base
## listen on the transport ## listen on the transport
t.server = createStreamServer(t.ma, connCb, serverFlags, t) t.server = createStreamServer(t.ma, connCb, t.serverFlags, t)
t.server.start() t.server.start()
# always get the resolved address in case we're bound to 0.0.0.0:0 # always get the resolved address in case we're bound to 0.0.0.0:0

View File

@ -24,13 +24,15 @@ type
connections*: seq[Connection] connections*: seq[Connection]
handler*: ConnHandler handler*: ConnHandler
multicodec*: MultiCodec multicodec*: MultiCodec
serverFlags*: set[ServerFlags]
method init*(t: Transport) {.base, gcsafe.} = method init*(t: Transport) {.base, gcsafe.} =
## perform protocol initialization ## perform protocol initialization
discard discard
proc newTransport*(t: typedesc[Transport]): t {.gcsafe.} = proc newTransport*(t: typedesc[Transport], serverFlags: set[ServerFlags] = {}): t {.gcsafe.} =
new result new result
result.serverFlags = serverFlags
result.init() result.init()
method close*(t: Transport) {.base, async, gcsafe.} = method close*(t: Transport) {.base, async, gcsafe.} =
@ -41,8 +43,7 @@ method close*(t: Transport) {.base, async, gcsafe.} =
method listen*(t: Transport, method listen*(t: Transport,
ma: MultiAddress, ma: MultiAddress,
handler: ConnHandler, handler: ConnHandler):
serverFlags: set[ServerFlags] = {}):
Future[Future[void]] {.base, async, gcsafe.} = Future[Future[void]] {.base, async, gcsafe.} =
## listen for incoming connections ## listen for incoming connections
t.ma = ma t.ma = ma