diff --git a/waku/v2/node/connectedness_test.go b/waku/v2/node/connectedness_test.go index 2f061d84..49c7eb32 100644 --- a/waku/v2/node/connectedness_test.go +++ b/waku/v2/node/connectedness_test.go @@ -49,6 +49,7 @@ func TestConnectionStatusChanges(t *testing.T) { node1, err := New( WithHostAddress(hostAddr1), WithWakuRelay(), + WithClusterID(16), WithTopicHealthStatusChannel(topicHealthStatusChan), ) require.NoError(t, err) @@ -118,6 +119,7 @@ func startNodeAndSubscribe(t *testing.T, ctx context.Context) *WakuNode { node, err := New( WithHostAddress(hostAddr), WithWakuRelay(), + WithClusterID(16), ) require.NoError(t, err) err = node.Start(ctx) diff --git a/waku/v2/node/wakunode2.go b/waku/v2/node/wakunode2.go index dda7c092..7f2a7cbe 100644 --- a/waku/v2/node/wakunode2.go +++ b/waku/v2/node/wakunode2.go @@ -250,10 +250,11 @@ func New(opts ...WakuNodeOption) (*WakuNode, error) { w.log.Error("creating localnode", zap.Error(err)) } - w.metadata = metadata.NewWakuMetadata(w.opts.clusterID, w.localNode, w.log) + metadata := metadata.NewWakuMetadata(w.opts.clusterID, w.localNode, w.log) + w.metadata = metadata //Initialize peer manager. - w.peermanager = peermanager.NewPeerManager(w.opts.maxPeerConnections, w.opts.peerStoreCapacity, w.log) + w.peermanager = peermanager.NewPeerManager(w.opts.maxPeerConnections, w.opts.peerStoreCapacity, metadata, w.log) w.peerConnector, err = peermanager.NewPeerConnectionStrategy(w.peermanager, discoveryConnectTimeout, w.log) if err != nil { diff --git a/waku/v2/peermanager/peer_manager.go b/waku/v2/peermanager/peer_manager.go index 54cba27b..e2c79ddf 100644 --- a/waku/v2/peermanager/peer_manager.go +++ b/waku/v2/peermanager/peer_manager.go @@ -20,6 +20,7 @@ import ( wps "github.com/waku-org/go-waku/waku/v2/peerstore" waku_proto "github.com/waku-org/go-waku/waku/v2/protocol" wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr" + "github.com/waku-org/go-waku/waku/v2/protocol/metadata" "github.com/waku-org/go-waku/waku/v2/protocol/relay" "github.com/waku-org/go-waku/waku/v2/service" @@ -68,6 +69,7 @@ type WakuProtoInfo struct { // PeerManager applies various controls and manage connections towards peers. type PeerManager struct { peerConnector *PeerConnectionStrategy + metadata *metadata.WakuMetadata maxPeers int maxRelayPeers int logger *zap.Logger @@ -167,7 +169,7 @@ func (pm *PeerManager) TopicHealth(pubsubTopic string) (TopicHealth, error) { } // NewPeerManager creates a new peerManager instance. -func NewPeerManager(maxConnections int, maxPeers int, logger *zap.Logger) *PeerManager { +func NewPeerManager(maxConnections int, maxPeers int, metadata *metadata.WakuMetadata, logger *zap.Logger) *PeerManager { maxRelayPeers, _ := relayAndServicePeers(maxConnections) inRelayPeersTarget, outRelayPeersTarget := inAndOutRelayPeers(maxRelayPeers) @@ -178,6 +180,7 @@ func NewPeerManager(maxConnections int, maxPeers int, logger *zap.Logger) *PeerM pm := &PeerManager{ logger: logger.Named("peer-manager"), + metadata: metadata, maxRelayPeers: maxRelayPeers, InRelayPeersTarget: inRelayPeersTarget, OutRelayPeersTarget: outRelayPeersTarget, diff --git a/waku/v2/peermanager/peer_manager_test.go b/waku/v2/peermanager/peer_manager_test.go index f85a80a3..41dd2769 100644 --- a/waku/v2/peermanager/peer_manager_test.go +++ b/waku/v2/peermanager/peer_manager_test.go @@ -46,7 +46,7 @@ func initTest(t *testing.T) (context.Context, *PeerManager, func()) { require.NoError(t, err) // host 1 is used by peer manager - pm := NewPeerManager(10, 20, utils.Logger()) + pm := NewPeerManager(10, 20, nil, utils.Logger()) pm.SetHost(h1) return ctx, pm, func() { @@ -269,7 +269,7 @@ func createHostWithDiscv5AndPM(t *testing.T, hostName string, topic string, enrF err = wenr.Update(localNode, wenr.WithWakuRelaySharding(rs[0])) require.NoError(t, err) - pm := NewPeerManager(10, 20, logger) + pm := NewPeerManager(10, 20, nil, logger) pm.SetHost(host) peerconn, err := NewPeerConnectionStrategy(pm, 30*time.Second, logger) require.NoError(t, err) diff --git a/waku/v2/peermanager/topic_event_handler.go b/waku/v2/peermanager/topic_event_handler.go index 1a39fee2..fff6de03 100644 --- a/waku/v2/peermanager/topic_event_handler.go +++ b/waku/v2/peermanager/topic_event_handler.go @@ -2,6 +2,7 @@ package peermanager import ( "context" + "time" pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/libp2p/go-libp2p/core/event" @@ -120,11 +121,29 @@ func (pm *PeerManager) handlerPeerTopicEvent(peerEvt relay.EvtPeerTopic) { wps := pm.host.Peerstore().(*wps.WakuPeerstoreImpl) peerID := peerEvt.PeerID if peerEvt.State == relay.PEER_JOINED { - err := wps.AddPubSubTopic(peerID, peerEvt.PubsubTopic) + rs, err := pm.metadata.RelayShard() + if err != nil { + pm.logger.Error("could not obtain the cluster and shards of wakunode", zap.Error(err)) + return + } else if rs == nil { + pm.logger.Info("not using sharding") + return + } + + if pm.metadata != nil && rs.ClusterID != 0 { + ctx, cancel := context.WithTimeout(pm.ctx, 7*time.Second) + defer cancel() + if err := pm.metadata.DisconnectPeerOnShardMismatch(ctx, peerEvt.PeerID); err != nil { + return + } + } + + err = wps.AddPubSubTopic(peerID, peerEvt.PubsubTopic) if err != nil { pm.logger.Error("failed to add pubSubTopic for peer", logging.HostID("peerID", peerID), zap.String("topic", peerEvt.PubsubTopic), zap.Error(err)) } + pm.topicMutex.RLock() defer pm.topicMutex.RUnlock() pm.checkAndUpdateTopicHealth(pm.subRelayTopics[peerEvt.PubsubTopic]) diff --git a/waku/v2/protocol/lightpush/waku_lightpush_test.go b/waku/v2/protocol/lightpush/waku_lightpush_test.go index 5775bbd2..08c134ee 100644 --- a/waku/v2/protocol/lightpush/waku_lightpush_test.go +++ b/waku/v2/protocol/lightpush/waku_lightpush_test.go @@ -254,7 +254,7 @@ func TestWakuLightPushCornerCases(t *testing.T) { testContentTopic := "/test/10/my-lp-app/proto" // Prepare peer manager instance to include in test - pm := peermanager.NewPeerManager(10, 10, utils.Logger()) + pm := peermanager.NewPeerManager(10, 10, nil, utils.Logger()) node1, sub1, host1 := makeWakuRelay(t, testTopic) defer node1.Stop() diff --git a/waku/v2/protocol/metadata/waku_metadata.go b/waku/v2/protocol/metadata/waku_metadata.go index 74aeb7e8..228f4487 100644 --- a/waku/v2/protocol/metadata/waku_metadata.go +++ b/waku/v2/protocol/metadata/waku_metadata.go @@ -4,6 +4,8 @@ import ( "context" "errors" "math" + "sync" + "time" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/libp2p/go-libp2p/core/host" @@ -32,6 +34,9 @@ type WakuMetadata struct { clusterID uint16 localnode *enode.LocalNode + peerShardsMutex sync.RWMutex + peerShards map[peer.ID][]uint16 + log *zap.Logger } @@ -63,16 +68,22 @@ func (wakuM *WakuMetadata) Start(ctx context.Context) error { wakuM.ctx = ctx wakuM.cancel = cancel + wakuM.peerShards = make(map[peer.ID][]uint16) + + wakuM.h.SetStreamHandlerMatch(MetadataID_v1, protocol.PrefixTextMatch(string(MetadataID_v1)), wakuM.onRequest(ctx)) wakuM.h.Network().Notify(wakuM) - wakuM.h.SetStreamHandlerMatch(MetadataID_v1, protocol.PrefixTextMatch(string(MetadataID_v1)), wakuM.onRequest(ctx)) wakuM.log.Info("metadata protocol started") return nil } -func (wakuM *WakuMetadata) getClusterAndShards() (*uint32, []uint32, error) { - shard, err := enr.RelaySharding(wakuM.localnode.Node().Record()) +func (wakuM *WakuMetadata) RelayShard() (*protocol.RelayShards, error) { + return enr.RelaySharding(wakuM.localnode.Node().Record()) +} + +func (wakuM *WakuMetadata) ClusterAndShards() (*uint32, []uint32, error) { + shard, err := wakuM.RelayShard() if err != nil { return nil, nil, err } @@ -98,7 +109,7 @@ func (wakuM *WakuMetadata) Request(ctx context.Context, peerID peer.ID) (*protoc return nil, err } - clusterID, shards, err := wakuM.getClusterAndShards() + clusterID, shards, err := wakuM.ClusterAndShards() if err != nil { if err := stream.Reset(); err != nil { wakuM.log.Error("resetting connection", zap.Error(err)) @@ -180,7 +191,7 @@ func (wakuM *WakuMetadata) onRequest(ctx context.Context) func(network.Stream) { response := new(pb.WakuMetadataResponse) - clusterID, shards, err := wakuM.getClusterAndShards() + clusterID, shards, err := wakuM.ClusterAndShards() if err != nil { logger.Error("obtaining shard info", zap.Error(err)) } else { @@ -252,11 +263,79 @@ func (wakuM *WakuMetadata) Connected(n network.Network, cc network.Conn) { if shard.ClusterID != wakuM.clusterID { wakuM.disconnectPeer(peerID, errors.New("different clusterID reported")) + return } + + // Store shards so they're used to verify if a relay peer supports the same shards we do + wakuM.peerShardsMutex.Lock() + defer wakuM.peerShardsMutex.Unlock() + wakuM.peerShards[peerID] = shard.ShardIDs }() } // Disconnected is called when a connection closed func (wakuM *WakuMetadata) Disconnected(n network.Network, cc network.Conn) { - // Do nothing + // We no longer need the shard info for that peer + wakuM.peerShardsMutex.Lock() + defer wakuM.peerShardsMutex.Unlock() + delete(wakuM.peerShards, cc.RemotePeer()) +} + +func (wakuM *WakuMetadata) GetPeerShards(ctx context.Context, peerID peer.ID) ([]uint16, error) { + // Already connected and we got the shard info, return immediatly + wakuM.peerShardsMutex.RLock() + shards, ok := wakuM.peerShards[peerID] + wakuM.peerShardsMutex.RUnlock() + if ok { + return shards, nil + } + + // Shard info pending. Let's wait + t := time.NewTicker(200 * time.Millisecond) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-t.C: + wakuM.peerShardsMutex.RLock() + shards, ok := wakuM.peerShards[peerID] + wakuM.peerShardsMutex.RUnlock() + if ok { + return shards, nil + } + } + } +} + +func (wakuM *WakuMetadata) disconnect(peerID peer.ID) { + wakuM.h.Peerstore().RemovePeer(peerID) + err := wakuM.h.Network().ClosePeer(peerID) + if err != nil { + wakuM.log.Error("disconnecting peer", logging.HostID("peerID", peerID), zap.Error(err)) + } +} + +func (wakuM *WakuMetadata) DisconnectPeerOnShardMismatch(ctx context.Context, peerID peer.ID) error { + peerShards, err := wakuM.GetPeerShards(ctx, peerID) + if err != nil { + wakuM.log.Error("could not obtain peer shards", zap.Error(err), logging.HostID("peerID", peerID)) + wakuM.disconnect(peerID) + return err + } + + rs, err := wakuM.RelayShard() + if err != nil { + wakuM.log.Error("could not obtain shards", zap.Error(err)) + wakuM.disconnect(peerID) + return err + } + + if !rs.ContainsAnyShard(rs.ClusterID, peerShards) { + wakuM.log.Info("shard mismatch", logging.HostID("peerID", peerID), zap.Uint16("clusterID", rs.ClusterID), zap.Uint16s("ourShardIDs", rs.ShardIDs), zap.Uint16s("theirShardIDs", peerShards)) + wakuM.disconnect(peerID) + return errors.New("shard mismatch") + } + + return nil } diff --git a/waku/v2/protocol/metadata/waku_metadata_test.go b/waku/v2/protocol/metadata/waku_metadata_test.go index 1ccb90a3..b384d6d5 100644 --- a/waku/v2/protocol/metadata/waku_metadata_test.go +++ b/waku/v2/protocol/metadata/waku_metadata_test.go @@ -62,11 +62,6 @@ func TestWakuMetadataRequest(t *testing.T) { m16_2 := createWakuMetadata(t, &rs16_2) m_noRS := createWakuMetadata(t, nil) - // Removing notifee to test metadata protocol functionality without having the peers being disconnected by the notify process - m16_1.h.Network().StopNotify(m16_1) - m16_2.h.Network().StopNotify(m16_2) - m_noRS.h.Network().StopNotify(m_noRS) - m16_1.h.Peerstore().AddAddrs(m16_2.h.ID(), m16_2.h.Network().ListenAddresses(), peerstore.PermanentAddrTTL) m16_1.h.Peerstore().AddAddrs(m_noRS.h.ID(), m_noRS.h.Network().ListenAddresses(), peerstore.PermanentAddrTTL) diff --git a/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go b/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go index 972ff8a6..37139c36 100644 --- a/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go +++ b/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go @@ -3,6 +3,9 @@ package peer_exchange import ( "context" "fmt" + "testing" + "time" + "github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem" @@ -11,8 +14,6 @@ import ( wps "github.com/waku-org/go-waku/waku/v2/peerstore" "go.uber.org/zap" "golang.org/x/exp/slices" - "testing" - "time" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" @@ -290,7 +291,7 @@ func TestRetrieveProvidePeerExchangeWithPMAndPeerAddr(t *testing.T) { require.NoError(t, err) // Prepare peer manager for host3 - pm3 := peermanager.NewPeerManager(10, 20, log) + pm3 := peermanager.NewPeerManager(10, 20, nil, log) pm3.SetHost(host3) pxPeerConn3, err := peermanager.NewPeerConnectionStrategy(pm3, 30*time.Second, utils.Logger()) require.NoError(t, err) @@ -365,7 +366,7 @@ func TestRetrieveProvidePeerExchangeWithPMOnly(t *testing.T) { require.NoError(t, err) // Prepare peer manager for host3 - pm3 := peermanager.NewPeerManager(10, 20, log) + pm3 := peermanager.NewPeerManager(10, 20, nil, log) pm3.SetHost(host3) pxPeerConn3, err := peermanager.NewPeerConnectionStrategy(pm3, 30*time.Second, utils.Logger()) require.NoError(t, err) diff --git a/waku/v2/protocol/shard.go b/waku/v2/protocol/shard.go index 7438ac49..8b101484 100644 --- a/waku/v2/protocol/shard.go +++ b/waku/v2/protocol/shard.go @@ -64,21 +64,27 @@ func (rs RelayShards) Topics() []WakuPubSubTopic { return result } -func (rs RelayShards) Contains(cluster uint16, index uint16) bool { +func (rs RelayShards) ContainsAnyShard(cluster uint16, indexes []uint16) bool { if rs.ClusterID != cluster { return false } found := false - for _, idx := range rs.ShardIDs { - if idx == index { - found = true + for _, rsIdx := range rs.ShardIDs { + for _, idx := range indexes { + if rsIdx == idx { + return true + } } } return found } +func (rs RelayShards) Contains(cluster uint16, index uint16) bool { + return rs.ContainsAnyShard(cluster, []uint16{index}) +} + func (rs RelayShards) ContainsShardPubsubTopic(topic WakuPubSubTopic) bool { if shardedTopic, err := ToShardPubsubTopic(topic); err != nil { return false diff --git a/waku/v2/protocol/store/waku_store_client_test.go b/waku/v2/protocol/store/waku_store_client_test.go index e92056b5..fc44bde9 100644 --- a/waku/v2/protocol/store/waku_store_client_test.go +++ b/waku/v2/protocol/store/waku_store_client_test.go @@ -3,6 +3,8 @@ package store import ( "context" "crypto/rand" + "testing" + "github.com/libp2p/go-libp2p/core/peerstore" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" @@ -13,7 +15,6 @@ import ( "github.com/waku-org/go-waku/waku/v2/timesource" "github.com/waku-org/go-waku/waku/v2/utils" "google.golang.org/protobuf/proto" - "testing" ) func TestQueryOptions(t *testing.T) { @@ -35,7 +36,7 @@ func TestQueryOptions(t *testing.T) { require.NoError(t, err) // Let peer manager reside at host - pm := peermanager.NewPeerManager(5, 5, utils.Logger()) + pm := peermanager.NewPeerManager(5, 5, nil, utils.Logger()) pm.SetHost(host) // Add host2 to peerstore