From 2e7a82e130b587eb086fa78628f6733001432f4d Mon Sep 17 00:00:00 2001 From: Prem Chaitanya Prathi Date: Fri, 2 Feb 2024 18:47:09 +0530 Subject: [PATCH] feat: peer exchange filter by shard (#1026) --- waku/v2/protocol/peer_exchange/client.go | 21 +++- .../waku_peer_exchange_option.go | 11 +++ .../peer_exchange/waku_peer_exchange_test.go | 97 +++++++++++++++++++ waku/v2/protocol/shard.go | 8 ++ 4 files changed, 135 insertions(+), 2 deletions(-) diff --git a/waku/v2/protocol/peer_exchange/client.go b/waku/v2/protocol/peer_exchange/client.go index 0f9037b9..6f629b06 100644 --- a/waku/v2/protocol/peer_exchange/client.go +++ b/waku/v2/protocol/peer_exchange/client.go @@ -12,6 +12,7 @@ import ( "github.com/libp2p/go-msgio/pbio" "github.com/waku-org/go-waku/waku/v2/peermanager" "github.com/waku-org/go-waku/waku/v2/peerstore" + "github.com/waku-org/go-waku/waku/v2/protocol" wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr" "github.com/waku-org/go-waku/waku/v2/protocol/peer_exchange/pb" "github.com/waku-org/go-waku/waku/v2/service" @@ -43,10 +44,16 @@ func (wakuPX *WakuPeerExchange) Request(ctx context.Context, numPeers int, opts } if params.pm != nil && params.selectedPeer == "" { + pubsubTopics := []string{} + if params.clusterID != 0 { + pubsubTopics = append(pubsubTopics, + protocol.NewStaticShardingPubsubTopic(uint16(params.clusterID), uint16(params.shard)).String()) + } selectedPeers, err := wakuPX.pm.SelectPeers( peermanager.PeerSelectionCriteria{ SelectionType: params.peerSelectionType, Proto: PeerExchangeID_v20alpha1, + PubsubTopics: pubsubTopics, SpecificPeers: params.preferredPeers, Ctx: ctx, }, @@ -93,10 +100,10 @@ func (wakuPX *WakuPeerExchange) Request(ctx context.Context, numPeers int, opts stream.Close() - return wakuPX.handleResponse(ctx, responseRPC.Response) + return wakuPX.handleResponse(ctx, responseRPC.Response, params) } -func (wakuPX *WakuPeerExchange) handleResponse(ctx context.Context, response *pb.PeerExchangeResponse) error { +func (wakuPX 
*WakuPeerExchange) handleResponse(ctx context.Context, response *pb.PeerExchangeResponse, params *PeerExchangeParameters) error {
 	var discoveredPeers []struct {
 		addrInfo peer.AddrInfo
 		enr      *enode.Node
@@ -112,6 +119,16 @@ func (wakuPX *WakuPeerExchange) handleResponse(ctx context.Context, response *pb
 			return err
 		}
 
+		if params.clusterID != 0 {
+			wakuPX.log.Debug("clusterID is non-zero, filtering by shard")
+			rs, err := wenr.RelaySharding(enrRecord)
+			if err != nil || rs == nil || !rs.Contains(uint16(params.clusterID), uint16(params.shard)) {
+				wakuPX.log.Debug("peer doesn't match filter", zap.Int("shard", params.shard))
+				continue
+			}
+			wakuPX.log.Debug("peer matches filter", zap.Int("shard", params.shard))
+		}
+
 		enodeRecord, err := enode.New(enode.ValidSchemes, enrRecord)
 		if err != nil {
 			wakuPX.log.Error("creating enode record", zap.Error(err))
diff --git a/waku/v2/protocol/peer_exchange/waku_peer_exchange_option.go b/waku/v2/protocol/peer_exchange/waku_peer_exchange_option.go
index 55702ad5..220bd4f9 100644
--- a/waku/v2/protocol/peer_exchange/waku_peer_exchange_option.go
+++ b/waku/v2/protocol/peer_exchange/waku_peer_exchange_option.go
@@ -18,6 +18,8 @@ type PeerExchangeParameters struct {
 	preferredPeers    peer.IDSlice
 	pm                *peermanager.PeerManager
 	log               *zap.Logger
+	shard             int
+	clusterID         int
 }
 
 type PeerExchangeOption func(*PeerExchangeParameters) error
@@ -77,3 +79,12 @@ func DefaultOptions(host host.Host) []PeerExchangeOption {
 		WithAutomaticPeerSelection(),
 	}
 }
+
+// FilterByShard returns a PeerExchangeOption that restricts peer exchange to peers that advertise the given cluster and shard.
+func FilterByShard(clusterID int, shard int) PeerExchangeOption {
+	return func(params *PeerExchangeParameters) error {
+		params.shard = shard
+		params.clusterID = clusterID
+		return nil
+	}
+}
diff --git a/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go b/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go
index 5d2fbaa1..4aa0f13a 100644
--- a/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go
+++ 
b/waku/v2/protocol/peer_exchange/waku_peer_exchange_test.go @@ -6,11 +6,13 @@ import ( "time" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "github.com/waku-org/go-waku/tests" "github.com/waku-org/go-waku/waku/v2/discv5" + "github.com/waku-org/go-waku/waku/v2/protocol" wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr" "github.com/waku-org/go-waku/waku/v2/utils" ) @@ -88,3 +90,98 @@ func TestRetrieveProvidePeerExchangePeers(t *testing.T) { px1.Stop() px3.Stop() } + +func TestRetrieveFilteredPeerExchangePeers(t *testing.T) { + // H1 + host1, _, prvKey1 := tests.CreateHost(t) + udpPort1, err := tests.FindFreePort(t, "127.0.0.1", 3) + require.NoError(t, err) + ip1, _ := tests.ExtractIP(host1.Addrs()[0]) + l1, err := tests.NewLocalnode(prvKey1, ip1, udpPort1, wenr.NewWakuEnrBitfield(false, false, false, true), nil, utils.Logger()) + require.NoError(t, err) + + discv5PeerConn1 := discv5.NewTestPeerDiscoverer() + d1, err := discv5.NewDiscoveryV5(prvKey1, l1, discv5PeerConn1, prometheus.DefaultRegisterer, utils.Logger(), discv5.WithUDPPort(uint(udpPort1))) + require.NoError(t, err) + d1.SetHost(host1) + + // H2 + host2, _, prvKey2 := tests.CreateHost(t) + ip2, _ := tests.ExtractIP(host2.Addrs()[0]) + udpPort2, err := tests.FindFreePort(t, "127.0.0.1", 3) + require.NoError(t, err) + l2, err := tests.NewLocalnode(prvKey2, ip2, udpPort2, wenr.NewWakuEnrBitfield(false, false, false, true), nil, utils.Logger()) + require.NoError(t, err) + rs, err := protocol.NewRelayShards(1, 2) + require.NoError(t, err) + l2.Set(enr.WithEntry(wenr.ShardingBitVectorEnrField, rs.BitVector())) + discv5PeerConn2 := discv5.NewTestPeerDiscoverer() + d2, err := discv5.NewDiscoveryV5(prvKey2, l2, discv5PeerConn2, prometheus.DefaultRegisterer, utils.Logger(), discv5.WithUDPPort(uint(udpPort2)), 
discv5.WithBootnodes([]*enode.Node{d1.Node()})) + require.NoError(t, err) + d2.SetHost(host2) + + // H3 + host3, _, _ := tests.CreateHost(t) + + defer d1.Stop() + defer d2.Stop() + defer host1.Close() + defer host2.Close() + defer host3.Close() + + err = d1.Start(context.Background()) + require.NoError(t, err) + + err = d2.Start(context.Background()) + require.NoError(t, err) + + time.Sleep(3 * time.Second) // Wait some time for peers to be discovered + + // mount peer exchange + pxPeerConn1 := discv5.NewTestPeerDiscoverer() + px1, err := NewWakuPeerExchange(d1, pxPeerConn1, nil, prometheus.DefaultRegisterer, utils.Logger()) + require.NoError(t, err) + px1.SetHost(host1) + + pxPeerConn3 := discv5.NewTestPeerDiscoverer() + px3, err := NewWakuPeerExchange(nil, pxPeerConn3, nil, prometheus.DefaultRegisterer, utils.Logger()) + require.NoError(t, err) + px3.SetHost(host3) + + err = px1.Start(context.Background()) + require.NoError(t, err) + + err = px3.Start(context.Background()) + require.NoError(t, err) + + host3.Peerstore().AddAddrs(host1.ID(), host1.Addrs(), peerstore.PermanentAddrTTL) + err = host3.Peerstore().AddProtocols(host1.ID(), PeerExchangeID_v20alpha1) + require.NoError(t, err) + + //Try with shard that is not registered. 
+	err = px3.Request(context.Background(), 1, WithPeer(host1.ID()), FilterByShard(1, 3))
+	require.NoError(t, err)
+
+	time.Sleep(3 * time.Second) // Give the algorithm some time to work its magic
+
+	require.False(t, pxPeerConn3.HasPeer(host2.ID()))
+
+	// Try without shard filtering
+
+	err = px3.Request(context.Background(), 1, WithPeer(host1.ID()))
+	require.NoError(t, err)
+
+	time.Sleep(3 * time.Second) // Give the algorithm some time to work its magic
+
+	require.True(t, pxPeerConn3.HasPeer(host2.ID()))
+
+	err = px3.Request(context.Background(), 1, WithPeer(host1.ID()), FilterByShard(1, 2))
+	require.NoError(t, err)
+
+	time.Sleep(3 * time.Second) // Give the algorithm some time to work its magic
+
+	require.True(t, pxPeerConn3.HasPeer(host2.ID()))
+
+	px1.Stop()
+	px3.Stop()
+}
diff --git a/waku/v2/protocol/shard.go b/waku/v2/protocol/shard.go
index 66ec5fdc..7438ac49 100644
--- a/waku/v2/protocol/shard.go
+++ b/waku/v2/protocol/shard.go
@@ -268,3 +268,11 @@ func GeneratePubsubToContentTopicMap(pubsubTopic string, contentTopics []string)
 	}
 	return pubSubTopicMap, nil
 }
+
+func ShardsToTopics(clusterID int, shards []int) []string {
+	pubsubTopics := make([]string, len(shards))
+	for i, shard := range shards {
+		pubsubTopics[i] = NewStaticShardingPubsubTopic(uint16(clusterID), uint16(shard)).String()
+	}
+	return pubsubTopics
+}