diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index 4415128..a8e293e 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -48,11 +48,12 @@ type discard ConnEventHandler* = - proc(peerId: PeerID, event: ConnEvent): Future[void] + proc(peerInfo: PeerInfo, event: ConnEvent): Future[void] {.gcsafe, raises: [Defect].} PeerEventKind* {.pure.} = enum Left, + Identified, Joined PeerEvent* = object @@ -63,7 +64,7 @@ type discard PeerEventHandler* = - proc(peerId: PeerID, event: PeerEvent): Future[void] {.gcsafe.} + proc(peerInfo: PeerInfo, event: PeerEvent): Future[void] {.gcsafe.} MuxerHolder = object muxer: Muxer @@ -132,22 +133,22 @@ proc removeConnEventHandler*(c: ConnManager, raiseAssert exc.msg proc triggerConnEvent*(c: ConnManager, - peerId: PeerID, + peerInfo: PeerInfo, event: ConnEvent) {.async, gcsafe.} = try: - trace "About to trigger connection events", peer = peerId + trace "About to trigger connection events", peer = peerInfo.peerId if c.connEvents[event.kind].len() > 0: - trace "triggering connection events", peer = peerId, event = $event.kind + trace "triggering connection events", peer = peerInfo.peerId, event = $event.kind var connEvents: seq[Future[void]] for h in c.connEvents[event.kind]: - connEvents.add(h(peerId, event)) + connEvents.add(h(peerInfo, event)) checkFutures(await allFinished(connEvents)) except CancelledError as exc: raise exc except CatchableError as exc: warn "Exception in triggerConnEvents", - msg = exc.msg, peerId, event = $event + msg = exc.msg, peer = peerInfo.peerId, event = $event proc addPeerEventHandler*(c: ConnManager, handler: PeerEventHandler, @@ -178,33 +179,33 @@ proc removePeerEventHandler*(c: ConnManager, raiseAssert exc.msg proc triggerPeerEvents*(c: ConnManager, - peerId: PeerID, + peerInfo: PeerInfo, event: PeerEvent) {.async, gcsafe.} = - trace "About to trigger peer events", peer = peerId + trace "About to trigger peer events", peer = peerInfo.peerId if c.peerEvents[event.kind].len == 0: return try: - let count = c.connCount(peerId) + let count = c.connCount(peerInfo.peerId) if event.kind == PeerEventKind.Joined and count != 1: - trace "peer already joined", peerId, event = $event + trace "peer already joined", peer = peerInfo.peerId, event = $event return elif event.kind == PeerEventKind.Left and count != 0: - trace "peer still connected or already left", peerId, event = $event + trace "peer still connected or already left", peer = peerInfo.peerId, event = $event return - trace "triggering peer events", peerId, event = $event + trace "triggering peer events", peer = peerInfo.peerId, event = $event var peerEvents: seq[Future[void]] for h in c.peerEvents[event.kind]: - peerEvents.add(h(peerId, event)) + peerEvents.add(h(peerInfo, event)) checkFutures(await allFinished(peerEvents)) except CancelledError as exc: raise exc except CatchableError as exc: # handlers should not raise! - warn "Exception in triggerPeerEvents", exc = exc.msg, peerId + warn "Exception in triggerPeerEvents", exc = exc.msg, peer = peerInfo.peerId proc contains*(c: ConnManager, conn: Connection): bool = ## checks if a connection is being tracked by the @@ -292,12 +293,12 @@ proc onConnUpgraded(c: ConnManager, conn: Connection) {.async.} = trace "Triggering connect events", conn conn.upgrade() - let peerId = conn.peerInfo.peerId + let peerInfo = conn.peerInfo await c.triggerPeerEvents( - peerId, PeerEvent(kind: PeerEventKind.Joined, initiator: conn.dir == Direction.Out)) + peerInfo, PeerEvent(kind: PeerEventKind.Joined, initiator: conn.dir == Direction.Out)) await c.triggerConnEvent( - peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: conn.dir == Direction.In)) + peerInfo, ConnEvent(kind: ConnEventKind.Connected, incoming: conn.dir == Direction.In)) except CatchableError as exc: # This is top-level procedure which will work as separate task, so it # do not need to propagate CancelledError and should handle other errors @@ -307,10 +308,10 @@ proc onConnUpgraded(c: ConnManager, conn: Connection) {.async.} = proc peerCleanup(c: ConnManager, conn: Connection) {.async.} = try: trace "Triggering disconnect events", conn - let peerId = conn.peerInfo.peerId + let peerInfo = conn.peerInfo await c.triggerConnEvent( - peerId, ConnEvent(kind: ConnEventKind.Disconnected)) - await c.triggerPeerEvents(peerId, PeerEvent(kind: PeerEventKind.Left)) + peerInfo, ConnEvent(kind: ConnEventKind.Disconnected)) + await c.triggerPeerEvents(peerInfo, PeerEvent(kind: PeerEventKind.Left)) except CatchableError as exc: # This is top-level procedure which will work as separate task, so it # do not need to propagate CancelledError and should handle other errors diff --git a/libp2p/peerstore.nim b/libp2p/peerstore.nim index 85787e8..c4fc34b 100644 --- a/libp2p/peerstore.nim +++ b/libp2p/peerstore.nim @@ -10,9 +10,9 @@ {.push raises: [Defect].} import - std/[tables, sets, sequtils], + std/[tables, sets, sequtils, options], ./crypto/crypto, - ./peerid, + ./peerid, ./peerinfo, ./multiaddress type @@ -34,23 +34,25 @@ type PeerBook*[T] = object of RootObj book*: Table[PeerID, T] changeHandlers: seq[PeerBookChangeHandler[T]] + + SetPeerBook*[T] = object of PeerBook[HashSet[T]] - AddressBook* = object of PeerBook[HashSet[MultiAddress]] - ProtoBook* = object of PeerBook[HashSet[string]] + AddressBook* = object of SetPeerBook[MultiAddress] + ProtoBook* = object of SetPeerBook[string] KeyBook* = object of PeerBook[PublicKey] #################### # Peer store types # #################### - PeerStore* = ref object of RootObj + PeerStore* = ref object addressBook*: AddressBook protoBook*: ProtoBook keyBook*: KeyBook StoredInfo* = object # Collates stored info about a peer - peerId*: PeerID + peerId*: PeerID addrs*: HashSet[MultiAddress] protos*: HashSet[string] publicKey*: PublicKey @@ -93,39 +95,23 @@ proc delete*[T](peerBook: var PeerBook[T], peerBook.book.del(peerId) return true -#################### -# Address Book API # -#################### +################ +# Set Book API # +################ -proc add*(addressBook: var AddressBook, - peerId: PeerID, - multiaddr: MultiAddress) = - ## Add known multiaddr of a given peer. If the peer is not known, - ## it will be set with the provided multiaddr. +proc add*[T]( + peerBook: var SetPeerBook[T], + peerId: PeerID, + entry: T) = + ## Add entry to a given peer. If the peer is not known, + ## it will be set with the provided entry. - addressBook.book.mgetOrPut(peerId, - initHashSet[MultiAddress]()).incl(multiaddr) + peerBook.book.mgetOrPut(peerId, + initHashSet[T]()).incl(entry) # Notify clients - for handler in addressBook.changeHandlers: - handler(peerId, addressBook.get(peerId)) - -##################### -# Protocol Book API # -##################### - -proc add*(protoBook: var ProtoBook, - peerId: PeerID, - protocol: string) = - ## Adds known protocol codec for a given peer. If the peer is not known, - ## it will be set with the provided protocol. - - protoBook.book.mgetOrPut(peerId, - initHashSet[string]()).incl(protocol) - - # Notify clients - for handler in protoBook.changeHandlers: - handler(peerId, protoBook.get(peerId)) + for handler in peerBook.changeHandlers: + handler(peerId, peerBook.get(peerId)) ################## # Peer Store API # @@ -160,6 +146,21 @@ proc get*(peerStore: PeerStore, publicKey: peerStore.keyBook.get(peerId) ) +proc update*(peerStore: PeerStore, peerInfo: PeerInfo) = + for address in peerInfo.addrs: + peerStore.addressBook.add(peerInfo.peerId, address) + for proto in peerInfo.protocols: + peerStore.protoBook.add(peerInfo.peerId, proto) + let pKey = peerInfo.publicKey() + if pKey.isSome: + peerStore.keyBook.set(peerInfo.peerId, pKey.get()) + +proc replace*(peerStore: PeerStore, peerInfo: PeerInfo) = + discard peerStore.addressBook.delete(peerInfo.peerId) + discard peerStore.protoBook.delete(peerInfo.peerId) + discard peerStore.keyBook.delete(peerInfo.peerId) + peerStore.update(peerInfo) + proc peers*(peerStore: PeerStore): seq[StoredInfo] = ## Get all the stored information of every peer. diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index f6c6cc5..d777182 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -568,11 +568,11 @@ proc init*[PubParams: object | bool]( parameters: parameters, topicsHigh: int.high) - proc peerEventHandler(peerId: PeerID, event: PeerEvent) {.async.} = + proc peerEventHandler(peerInfo: PeerInfo, event: PeerEvent) {.async.} = if event.kind == PeerEventKind.Joined: - pubsub.subscribePeer(peerId) + pubsub.subscribePeer(peerInfo.peerId) else: - pubsub.unsubscribePeer(peerId) + pubsub.unsubscribePeer(peerInfo.peerId) switch.addPeerEventHandler(peerEventHandler, PeerEventKind.Joined) switch.addPeerEventHandler(peerEventHandler, PeerEventKind.Left) diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 759543f..66bea4e 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -33,10 +33,11 @@ import stream/connection, utils/semaphore, connmanager, peerid, + peerstore, errors, dialer -export connmanager, upgrade, dialer +export connmanager, upgrade, dialer, peerstore logScope: topics = "libp2p switch" @@ -60,6 +61,7 @@ type ms*: MultistreamSelect acceptFuts: seq[Future[void]] dialer*: Dial + peerStore*: PeerStore proc addConnEventHandler*(s: Switch, handler: ConnEventHandler, @@ -212,6 +214,11 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = s.acceptFuts.add(s.accept(t)) startFuts.add(server) + proc peerIdentifiedHandler(peerInfo: PeerInfo, event: PeerEvent) {.async.} = + s.peerStore.replace(peerInfo) + + s.connManager.addPeerEventHandler(peerIdentifiedHandler, PeerEventKind.Identified) + debug "Started libp2p node", peer = s.peerInfo return startFuts # listen for incoming connections @@ -259,6 +266,7 @@ proc newSwitch*(peerInfo: PeerInfo, ms: ms, transports: transports, connManager: connManager, + peerStore: PeerStore.new(), dialer: Dialer.new(peerInfo, connManager, transports, ms)) switch.mount(identity) diff --git a/libp2p/upgrademngrs/upgrade.nim b/libp2p/upgrademngrs/upgrade.nim index 68da6d0..abdfef2 100644 --- a/libp2p/upgrademngrs/upgrade.nim +++ b/libp2p/upgrademngrs/upgrade.nim @@ -91,4 +91,5 @@ proc identify*( if info.protos.len > 0: conn.peerInfo.protocols = info.protos + await self.connManager.triggerPeerEvents(conn.peerInfo, PeerEvent(kind: PeerEventKind.Identified)) trace "identified remote peer", conn, peerInfo = shortLog(conn.peerInfo) diff --git a/tests/testpeerstore.nim b/tests/testpeerstore.nim index 4493987..db75f89 100644 --- a/tests/testpeerstore.nim +++ b/tests/testpeerstore.nim @@ -1,5 +1,6 @@ import - std/[unittest2, tables, sequtils, sets], + unittest2, + std/[tables, sequtils, sets], ../libp2p/crypto/crypto, ../libp2p/multiaddress, ../libp2p/peerid, diff --git a/tests/testswitch.nim b/tests/testswitch.nim index 282defb..b912a4c 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -1,6 +1,6 @@ {.used.} -import options, sequtils +import options, sequtils, sets import chronos import stew/byteutils import nimcrypto/sysrand @@ -246,18 +246,18 @@ suite "Switch": var step = 0 var kinds: set[ConnEventKind] - proc hook(peerId: PeerID, event: ConnEvent) {.async, gcsafe.} = + proc hook(peerInfo: PeerInfo, event: ConnEvent) {.async, gcsafe.} = kinds = kinds + {event.kind} case step: of 0: check: event.kind == ConnEventKind.Connected - peerId == switch1.peerInfo.peerId + peerInfo.peerId == switch1.peerInfo.peerId of 1: check: event.kind == ConnEventKind.Disconnected - check peerId == switch1.peerInfo.peerId + check peerInfo.peerId == switch1.peerInfo.peerId else: check false @@ -301,18 +301,18 @@ suite "Switch": var step = 0 var kinds: set[ConnEventKind] - proc hook(peerId: PeerID, event: ConnEvent) {.async, gcsafe.} = + proc hook(peerInfo: PeerInfo, event: ConnEvent) {.async, gcsafe.} = kinds = kinds + {event.kind} case step: of 0: check: event.kind == ConnEventKind.Connected - peerId == switch2.peerInfo.peerId + peerInfo.peerId == switch2.peerInfo.peerId of 1: check: event.kind == ConnEventKind.Disconnected - check peerId == switch2.peerInfo.peerId + check peerInfo.peerId == switch2.peerInfo.peerId else: check false @@ -356,17 +356,17 @@ suite "Switch": var step = 0 var kinds: set[PeerEventKind] - proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} = + proc handler(peerInfo: PeerInfo, event: PeerEvent) {.async, gcsafe.} = kinds = kinds + {event.kind} case step: of 0: check: event.kind == PeerEventKind.Joined - peerId == switch2.peerInfo.peerId + peerInfo.peerId == switch2.peerInfo.peerId of 1: check: event.kind == PeerEventKind.Left - peerId == switch2.peerInfo.peerId + peerInfo.peerId == switch2.peerInfo.peerId else: check false @@ -410,17 +410,17 @@ suite "Switch": var step = 0 var kinds: set[PeerEventKind] - proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} = + proc handler(peerInfo: PeerInfo, event: PeerEvent) {.async, gcsafe.} = kinds = kinds + {event.kind} case step: of 0: check: event.kind == PeerEventKind.Joined - peerId == switch1.peerInfo.peerId + peerInfo.peerId == switch1.peerInfo.peerId of 1: check: event.kind == PeerEventKind.Left - peerId == switch1.peerInfo.peerId + peerInfo.peerId == switch1.peerInfo.peerId else: check false @@ -474,7 +474,7 @@ suite "Switch": var step = 0 var kinds: set[PeerEventKind] - proc handler(peerId: PeerID, event: PeerEvent) {.async, gcsafe.} = + proc handler(peerInfo: PeerInfo, event: PeerEvent) {.async, gcsafe.} = kinds = kinds + {event.kind} case step: of 0: @@ -535,7 +535,7 @@ suite "Switch": var switches: seq[Switch] var done = newFuture[void]() var onConnect: Future[void] - proc hook(peerId: PeerID, event: ConnEvent) {.async, gcsafe.} = + proc hook(peerInfo: PeerInfo, event: ConnEvent) {.async, gcsafe.} = case event.kind: of ConnEventKind.Connected: await onConnect @@ -577,7 +577,7 @@ suite "Switch": var switches: seq[Switch] var done = newFuture[void]() var onConnect: Future[void] - proc hook(peerId: PeerID, event: ConnEvent) {.async, gcsafe.} = + proc hook(peerInfo2: PeerInfo, event: ConnEvent) {.async, gcsafe.} = case event.kind: of ConnEventKind.Connected: if conns == 5: @@ -844,3 +844,60 @@ suite "Switch": await allFuturesThrowing( allFutures(switches.mapIt( it.stop() ))) await allFuturesThrowing(awaiters) + + asyncTest "e2e peer store": + let done = newFuture[void]() + proc handle(conn: Connection, proto: string) {.async, gcsafe.} = + try: + let msg = string.fromBytes(await conn.readLp(1024)) + check "Hello!" == msg + await conn.writeLp("Hello!") + finally: + await conn.close() + done.complete() + + let testProto = new TestProto + testProto.codec = TestCodec + testProto.handler = handle + + let switch1 = newStandardSwitch() + switch1.mount(testProto) + + let switch2 = newStandardSwitch() + var awaiters: seq[Future[void]] + awaiters.add(await switch1.start()) + awaiters.add(await switch2.start()) + + let conn = await switch2.dial(switch1.peerInfo.peerId, switch1.peerInfo.addrs, TestCodec) + + check switch1.isConnected(switch2.peerInfo.peerId) + check switch2.isConnected(switch1.peerInfo.peerId) + + await conn.writeLp("Hello!") + let msg = string.fromBytes(await conn.readLp(1024)) + check "Hello!" == msg + await conn.close() + + await allFuturesThrowing( + done.wait(5.seconds), + switch1.stop(), + switch2.stop()) + + # this needs to go at end + await allFuturesThrowing(awaiters) + + check not switch1.isConnected(switch2.peerInfo.peerId) + check not switch2.isConnected(switch1.peerInfo.peerId) + + let storedInfo1 = switch1.peerStore.get(switch2.peerInfo.peerId) + let storedInfo2 = switch2.peerStore.get(switch1.peerInfo.peerId) + + check: + storedInfo1.peerId == switch2.peerInfo.peerId + storedInfo2.peerId == switch1.peerInfo.peerId + + storedInfo1.addrs.toSeq() == switch2.peerInfo.addrs + storedInfo2.addrs.toSeq() == switch1.peerInfo.addrs + + storedInfo1.protos.toSeq() == switch2.peerInfo.protocols + storedInfo2.protos.toSeq() == switch1.peerInfo.protocols