Fixes and tweaks related to the beacon node integration

* Bugfix: Dialing an already connected peer may lead to crash

* Introduced a standard_setup module allowing to instantiate
  the `Switch` object in an easier manner.

* Added `Switch.disconnect(peer)`

* Trailing space removed (sorry about polluting the diff)
This commit is contained in:
Zahary Karadjov 2019-12-08 23:06:58 +02:00
parent 31aaa2c8ec
commit 454f658ba8
No known key found for this signature in database
GPG Key ID: C8936F8A3073D609
15 changed files with 128 additions and 131 deletions

View File

@ -126,8 +126,8 @@ proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) =
let mplexProvider = newMuxerProvider(createMplex, MplexCodec) # create multiplexer let mplexProvider = newMuxerProvider(createMplex, MplexCodec) # create multiplexer
let transports = @[Transport(newTransport(TcpTransport))] # add all transports (tcp only for now, but can be anything in the future) let transports = @[Transport(newTransport(TcpTransport))] # add all transports (tcp only for now, but can be anything in the future)
let muxers = [(MplexCodec, mplexProvider)].toTable() # add all muxers let muxers = {MplexCodec: mplexProvider}.toTable() # add all muxers
let secureManagers = [(SecioCodec, Secure(newSecio(seckey)))].toTable() # setup the secio and any other secure provider let secureManagers = {SecioCodec: Secure(newSecio(seckey))}.toTable() # setup the secio and any other secure provider
# create the switch # create the switch
let switch = newSwitch(peerInfo, let switch = newSwitch(peerInfo,

View File

@ -139,7 +139,7 @@ method init(p: ChatProto) {.gcsafe.} =
if p.connected and not p.conn.closed: if p.connected and not p.conn.closed:
echo "a chat session is already in progress - disconnecting!" echo "a chat session is already in progress - disconnecting!"
await stream.close() await stream.close()
else:
p.conn = stream p.conn = stream
p.connected = true p.connected = true

View File

@ -6,7 +6,12 @@
## at your option. ## at your option.
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import libp2p/daemon/[daemonapi, transpool]
import libp2p/protobuf/minprotobuf import
import libp2p/varint libp2p/daemon/[daemonapi, transpool],
export daemonapi, minprotobuf, varint, transpool libp2p/protobuf/minprotobuf,
libp2p/varint
export
daemonapi, transpool, minprotobuf, varint

View File

@ -427,7 +427,7 @@ proc `$`*(sig: Signature): string =
## Get string representation of signature ``sig``. ## Get string representation of signature ``sig``.
result = toHex(sig.data) result = toHex(sig.data)
proc sign*(key: PrivateKey, data: openarray[byte]): Signature = proc sign*(key: PrivateKey, data: openarray[byte]): Signature {.gcsafe.} =
## Sign message ``data`` using private key ``key`` and return generated ## Sign message ``data`` using private key ``key`` and return generated
## signature in raw binary form. ## signature in raw binary form.
if key.scheme == RSA: if key.scheme == RSA:

View File

@ -863,7 +863,7 @@ proc getSecret*(pubkey: EcPublicKey, seckey: EcPrivateKey): seq[byte] =
copyMem(addr result[0], addr data[0], res) copyMem(addr result[0], addr data[0], res)
proc sign*[T: byte|char](seckey: EcPrivateKey, proc sign*[T: byte|char](seckey: EcPrivateKey,
message: openarray[T]): EcSignature = message: openarray[T]): EcSignature {.gcsafe.} =
## Get ECDSA signature of data ``message`` using private key ``seckey``. ## Get ECDSA signature of data ``message`` using private key ``seckey``.
doAssert(not isNil(seckey)) doAssert(not isNil(seckey))
var hc: BrHashCompatContext var hc: BrHashCompatContext

View File

@ -1836,7 +1836,7 @@ proc clear*(pair: var EdKeyPair) =
burnMem(pair.pubkey.data) burnMem(pair.pubkey.data)
proc sign*[T: byte|char](key: EdPrivateKey, proc sign*[T: byte|char](key: EdPrivateKey,
message: openarray[T]): EdSignature {.noinit.} = message: openarray[T]): EdSignature {.gcsafe, noinit.} =
## Create ED25519 signature of data ``message`` using private key ``key``. ## Create ED25519 signature of data ``message`` using private key ``key``.
var ctx: sha512 var ctx: sha512
var r: GeP3 var r: GeP3

View File

@ -723,13 +723,13 @@ proc `==`*(a, b: RsaPublicKey): bool =
result = r1 and r2 result = r1 and r2
proc sign*[T: byte|char](key: RsaPrivateKey, proc sign*[T: byte|char](key: RsaPrivateKey,
message: openarray[T]): RsaSignature = message: openarray[T]): RsaSignature {.gcsafe.} =
## Get RSA PKCS1.5 signature of data ``message`` using SHA256 and private ## Get RSA PKCS1.5 signature of data ``message`` using SHA256 and private
## key ``key``. ## key ``key``.
doAssert(not isNil(key)) doAssert(not isNil(key))
var hc: BrHashCompatContext var hc: BrHashCompatContext
var hash: array[32, byte] var hash: array[32, byte]
var impl = BrRsaPkcs1SignGetDefault() let impl = BrRsaPkcs1SignGetDefault()
result = new RsaSignature result = new RsaSignature
result.buffer = newSeq[byte]((key.seck.nBitlen + 7) shr 3) result.buffer = newSeq[byte]((key.seck.nBitlen + 7) shr 3)
var kv = addr sha256Vtable var kv = addr sha256Vtable

View File

@ -339,7 +339,7 @@ proc `$`*(sig: SkSignature): string =
discard sig.toBytes(ssig) discard sig.toBytes(ssig)
result = toHex(ssig) result = toHex(ssig)
proc sign*[T: byte|char](key: SkPrivateKey, msg: openarray[T]): SkSignature = proc sign*[T: byte|char](key: SkPrivateKey, msg: openarray[T]): SkSignature {.gcsafe.} =
## Sign message `msg` using private key `key` and return signature object. ## Sign message `msg` using private key `key` and return signature object.
let ctx = getContext() let ctx = getContext()
var hash = sha256.digest(msg) var hash = sha256.digest(msg)

35
libp2p/standard_setup.nim Normal file
View File

@ -0,0 +1,35 @@
import
options, tables,
switch, peer, peerinfo, connection, multiaddress,
crypto/crypto, transports/[transport, tcptransport],
muxers/[muxer, mplex/mplex, mplex/types],
protocols/[identify, secure/secure, secure/secio],
protocols/pubsub/[pubsub, gossipsub, floodsub]
export
switch, peer, peerinfo, connection, multiaddress, crypto
proc newStandardSwitch*(privKey = none(PrivateKey),
address = MultiAddress.init("/ip4/127.0.0.1/tcp/0"),
triggerSelf = false, gossip = false): Switch =
proc createMplex(conn: Connection): Muxer =
result = newMplex(conn)
let
seckey = privKey.get(otherwise = PrivateKey.random(RSA))
peerInfo = PeerInfo.init(seckey, @[address])
mplexProvider = newMuxerProvider(createMplex, MplexCodec)
transports = @[Transport(newTransport(TcpTransport))]
muxers = {MplexCodec: mplexProvider}.toTable
identify = newIdentify(peerInfo)
secureManagers = {SecioCodec: Secure(newSecio seckey)}.toTable
pubSub = if gossip: PubSub newPubSub(GossipSub, peerInfo, triggerSelf)
else: PubSub newPubSub(FloodSub, peerInfo, triggerSelf)
result = newSwitch(peerInfo,
transports,
identify,
muxers,
secureManagers = secureManagers,
pubSub = some(pubSub))

View File

@ -110,11 +110,9 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} =
let handlerFut = muxer.handle() let handlerFut = muxer.handle()
# add muxer handler cleanup proc # add muxer handler cleanup proc
handlerFut.addCallback( handlerFut.addCallback do (udata: pointer = nil):
proc(udata: pointer = nil) {.gcsafe.} =
trace "muxer handler completed for peer", trace "muxer handler completed for peer",
peer = conn.peerInfo.get().id peer = conn.peerInfo.get().id
)
# do identify first, so that we have a # do identify first, so that we have a
# PeerInfo in case we didn't before # PeerInfo in case we didn't before
@ -141,6 +139,11 @@ proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} =
await s.connections[id].close() await s.connections[id].close()
s.connections.del(id) s.connections.del(id)
proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} =
let conn = s.connections.getOrDefault(peer.id)
if conn != nil:
await s.cleanupConn(conn)
proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Option[Connection]] {.async, gcsafe.} = proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Option[Connection]] {.async, gcsafe.} =
# if there is a muxer for the connection # if there is a muxer for the connection
# use it instead to create a muxed stream # use it instead to create a muxed stream
@ -194,34 +197,33 @@ proc dial*(s: Switch,
proto: string = ""): proto: string = ""):
Future[Connection] {.async.} = Future[Connection] {.async.} =
let id = peer.id let id = peer.id
trace "dialing peer", peer = id trace "Dialing peer", peer = id
result = s.connections.getOrDefault(id)
if result == nil or result.closed:
for t in s.transports: # for each transport for t in s.transports: # for each transport
for a in peer.addrs: # for each address for a in peer.addrs: # for each address
if t.handles(a): # check if it can dial it if t.handles(a): # check if it can dial it
if id notin s.connections: trace "Dialing address", address = $a
trace "dialing address", address = $a
result = await t.dial(a) result = await t.dial(a)
# make sure to assign the peer to the connection # make sure to assign the peer to the connection
result.peerInfo = some(peer) result.peerInfo = some peer
result = await s.upgradeOutgoing(result) result = await s.upgradeOutgoing(result)
result.closeEvent.wait().addCallback( result.closeEvent.wait().addCallback do (udata: pointer):
proc(udata: pointer) =
asyncCheck s.cleanupConn(result) asyncCheck s.cleanupConn(result)
) break
else:
trace "Reusing existing connection"
if proto.len > 0 and not result.closed: if proto.len > 0 and not result.closed:
let stream = await s.getMuxedStream(peer) let stream = await s.getMuxedStream(peer)
if stream.isSome: if stream.isSome:
trace "connection is muxed, return muxed stream" trace "Connection is muxed, return muxed stream"
result = stream.get() result = stream.get()
trace "attempting to select remote", proto = proto trace "Attempting to select remote", proto = proto
if not (await s.ms.select(result, proto)): if not await s.ms.select(result, proto):
error "unable to select protocol: ", proto = proto error "Unable to select sub-protocol", proto = proto
raise newException(CatchableError, raise newException(CatchableError, &"unable to select protocol: {proto}")
&"unable to select protocol: {proto}")
break # don't dial more than one addr on the same transport
proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} = proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} =
if isNil(proto.handler): if isNil(proto.handler):
@ -337,3 +339,4 @@ proc newSwitch*(peerInfo: PeerInfo,
if pubSub.isSome: if pubSub.isSome:
result.pubSub = pubSub result.pubSub = pubSub
result.mount(pubSub.get()) result.mount(pubSub.get())

View File

@ -72,7 +72,7 @@ suite "FloodSub":
var nodes: seq[Switch] = newSeq[Switch]() var nodes: seq[Switch] = newSeq[Switch]()
for i in 0..<10: for i in 0..<10:
nodes.add(createNode()) nodes.add(newStandardSwitch())
var awaitters: seq[Future[void]] var awaitters: seq[Future[void]]
for node in nodes: for node in nodes:
@ -104,7 +104,7 @@ suite "FloodSub":
var nodes: seq[Switch] = newSeq[Switch]() var nodes: seq[Switch] = newSeq[Switch]()
for i in 0..<10: for i in 0..<10:
nodes.add(createNode(none(PrivateKey), "/ip4/127.0.0.1/tcp/0", true)) nodes.add newStandardSwitch(triggerSelf = true)
var awaitters: seq[Future[void]] var awaitters: seq[Future[void]]
for node in nodes: for node in nodes:

View File

@ -63,7 +63,7 @@ suite "GossipSub":
var nodes: seq[Switch] = newSeq[Switch]() var nodes: seq[Switch] = newSeq[Switch]()
for i in 0..<2: for i in 0..<2:
nodes.add(createNode(gossip = true)) nodes.add newStandardSwitch(gossip = true)
var awaitters: seq[Future[void]] var awaitters: seq[Future[void]]
for node in nodes: for node in nodes:
@ -143,7 +143,7 @@ suite "GossipSub":
var nodes: seq[Switch] = newSeq[Switch]() var nodes: seq[Switch] = newSeq[Switch]()
for i in 0..<2: for i in 0..<2:
nodes.add(createNode(gossip = true)) nodes.add newStandardSwitch(gossip = true)
var awaitters: seq[Future[void]] var awaitters: seq[Future[void]]
for node in nodes: for node in nodes:
@ -384,7 +384,7 @@ suite "GossipSub":
var awaitters: seq[Future[void]] var awaitters: seq[Future[void]]
for i in 0..<10: for i in 0..<10:
nodes.add(createNode(none(PrivateKey), "/ip4/127.0.0.1/tcp/0", true, true)) nodes.add newStandardSwitch(triggerSelf = true, gossip = true)
awaitters.add((await nodes[i].start())) awaitters.add((await nodes[i].start()))
var seen: Table[string, int] var seen: Table[string, int]

View File

@ -1,57 +1,11 @@
import options, tables import options, tables
import chronos import chronos
import ../../libp2p/[switch, import ../../libp2p/standard_setup
peer, export standard_setup
connection,
multiaddress,
peerinfo,
muxers/muxer,
crypto/crypto,
muxers/mplex/mplex,
muxers/mplex/types,
protocols/identify,
transports/transport,
transports/tcptransport,
protocols/secure/secure,
protocols/secure/secio,
protocols/pubsub/pubsub,
protocols/pubsub/gossipsub,
protocols/pubsub/floodsub]
proc createMplex(conn: Connection): Muxer =
result = newMplex(conn)
proc createNode*(privKey: Option[PrivateKey] = none(PrivateKey),
address: string = "/ip4/127.0.0.1/tcp/0",
triggerSelf: bool = false,
gossip: bool = false): Switch =
var seckey = privKey
if privKey.isNone:
seckey = some(PrivateKey.random(RSA))
var peerInfo = PeerInfo.init(seckey.get(), @[Multiaddress.init(address)])
let mplexProvider = newMuxerProvider(createMplex, MplexCodec)
let transports = @[Transport(newTransport(TcpTransport))]
let muxers = [(MplexCodec, mplexProvider)].toTable()
let identify = newIdentify(peerInfo)
let secureManagers = [(SecioCodec, Secure(newSecio(seckey.get())))].toTable()
var pubSub: Option[PubSub]
if gossip:
pubSub = some(PubSub(newPubSub(GossipSub, peerInfo, triggerSelf)))
else:
pubSub = some(PubSub(newPubSub(FloodSub, peerInfo, triggerSelf)))
result = newSwitch(peerInfo,
transports,
identify,
muxers,
secureManagers = secureManagers,
pubSub = pubSub)
proc generateNodes*(num: Natural, gossip: bool = false): seq[Switch] = proc generateNodes*(num: Natural, gossip: bool = false): seq[Switch] =
for i in 0..<num: for i in 0..<num:
result.add(createNode(gossip = gossip)) result.add(newStandardSwitch(gossip = gossip))
proc subscribeNodes*(nodes: seq[Switch]) {.async.} = proc subscribeNodes*(nodes: seq[Switch]) {.async.} =
for dialer in nodes: for dialer in nodes: