fix: for light node do not check for matching shards but only clusterID (#1154)

This commit is contained in:
Prem Chaitanya Prathi 2024-07-09 18:50:44 +05:30 committed by GitHub
parent 7c13021a32
commit 221cbf6599
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 70 additions and 33 deletions

View File

@ -18,6 +18,7 @@ import (
"github.com/waku-org/go-waku/waku/v2/protocol"
"github.com/waku-org/go-waku/waku/v2/protocol/enr"
"github.com/waku-org/go-waku/waku/v2/protocol/metadata/pb"
"github.com/waku-org/go-waku/waku/v2/protocol/relay"
"go.uber.org/zap"
)
@ -83,6 +84,7 @@ func (wakuM *WakuMetadata) RelayShard() (*protocol.RelayShards, error) {
}
func (wakuM *WakuMetadata) ClusterAndShards() (*uint32, []uint32, error) {
shard, err := wakuM.RelayShard()
if err != nil {
return nil, nil, err
@ -100,7 +102,7 @@ func (wakuM *WakuMetadata) ClusterAndShards() (*uint32, []uint32, error) {
return &u32ClusterID, shards, nil
}
func (wakuM *WakuMetadata) Request(ctx context.Context, peerID peer.ID) (*protocol.RelayShards, error) {
func (wakuM *WakuMetadata) Request(ctx context.Context, peerID peer.ID) (*pb.WakuMetadataResponse, error) {
logger := wakuM.log.With(logging.HostID("peer", peerID))
stream, err := wakuM.h.NewStream(ctx, peerID, MetadataID_v1)
@ -149,31 +151,7 @@ func (wakuM *WakuMetadata) Request(ctx context.Context, peerID peer.ID) (*protoc
stream.Close()
logger.Debug("received metadata response")
if response.ClusterId == nil {
return nil, errors.New("node did not provide a waku clusterid")
}
rClusterID := uint16(*response.ClusterId)
var rShardIDs []uint16
if len(response.Shards) != 0 {
for _, i := range response.Shards {
rShardIDs = append(rShardIDs, uint16(i))
}
} else {
// TODO: remove with nwaku 0.28 deployment
for _, i := range response.ShardsDeprecated { // nolint: staticcheck
rShardIDs = append(rShardIDs, uint16(i))
}
}
logger.Debug("getting remote cluster and shards")
rs, err := protocol.NewRelayShards(rClusterID, rShardIDs...)
if err != nil {
return nil, err
}
return &rs, nil
return response, nil
}
func (wakuM *WakuMetadata) onRequest(ctx context.Context) func(network.Stream) {
@ -259,14 +237,49 @@ func (wakuM *WakuMetadata) Connected(n network.Network, cc network.Conn) {
}
peerID := cc.RemotePeer()
shard, err := wakuM.Request(wakuM.ctx, peerID)
response, err := wakuM.Request(wakuM.ctx, peerID)
if err != nil {
wakuM.disconnectPeer(peerID, err)
return
}
if response.ClusterId == nil {
wakuM.disconnectPeer(peerID, errors.New("node did not provide a waku clusterid"))
return
}
rClusterID := uint16(*response.ClusterId)
var rs protocol.RelayShards
if _, err = wakuM.h.Peerstore().SupportsProtocols(peerID, relay.WakuRelayID_v200); err == nil {
wakuM.log.Debug("light peer only checking clusterID")
if rClusterID != wakuM.clusterID {
wakuM.disconnectPeer(peerID, errors.New("different clusterID reported"))
}
return
}
wakuM.log.Debug("relay peer checking cluster and shards")
var rShardIDs []uint16
if len(response.Shards) != 0 {
for _, i := range response.Shards {
rShardIDs = append(rShardIDs, uint16(i))
}
} else {
// TODO: remove with nwaku 0.28 deployment
for _, i := range response.ShardsDeprecated { // nolint: staticcheck
rShardIDs = append(rShardIDs, uint16(i))
}
}
wakuM.log.Debug("getting remote cluster and shards")
//if peer supports relay, then check for both clusterID and shards.
rs, err = protocol.NewRelayShards(rClusterID, rShardIDs...)
if err != nil {
wakuM.disconnectPeer(peerID, err)
return
}
if shard.ClusterID != wakuM.clusterID {
if rs.ClusterID != wakuM.clusterID {
wakuM.disconnectPeer(peerID, errors.New("different clusterID reported"))
return
}
@ -274,7 +287,7 @@ func (wakuM *WakuMetadata) Connected(n network.Network, cc network.Conn) {
// Store shards so they're used to verify if a relay peer supports the same shards we do
wakuM.peerShardsMutex.Lock()
defer wakuM.peerShardsMutex.Unlock()
wakuM.peerShards[peerID] = shard.ShardIDs
wakuM.peerShards[peerID] = rs.ShardIDs
}()
}

View File

@ -17,6 +17,7 @@ import (
"github.com/waku-org/go-waku/tests"
"github.com/waku-org/go-waku/waku/v2/protocol"
"github.com/waku-org/go-waku/waku/v2/protocol/enr"
"github.com/waku-org/go-waku/waku/v2/protocol/relay"
"github.com/waku-org/go-waku/waku/v2/utils"
)
@ -68,13 +69,28 @@ func TestWakuMetadataRequest(t *testing.T) {
m_noRS := createWakuMetadata(t, nil)
m16_1.h.Peerstore().AddAddrs(m16_2.h.ID(), m16_2.h.Network().ListenAddresses(), peerstore.PermanentAddrTTL)
err = m16_1.h.Peerstore().AddProtocols(m16_2.h.ID(), relay.WakuRelayID_v200)
require.NoError(t, err)
err = m16_2.h.Peerstore().AddProtocols(m16_1.h.ID(), relay.WakuRelayID_v200)
require.NoError(t, err)
m16_1.h.Peerstore().AddAddrs(m_noRS.h.ID(), m_noRS.h.Network().ListenAddresses(), peerstore.PermanentAddrTTL)
// Query a peer that is subscribed to a shard
result, err := m16_1.Request(context.Background(), m16_2.h.ID())
require.NoError(t, err)
require.Equal(t, testShard16, result.ClusterID)
require.Equal(t, rs16_2.ShardIDs, result.ShardIDs)
var rShardIDs []uint16
if len(result.Shards) != 0 {
for _, i := range result.Shards {
rShardIDs = append(rShardIDs, uint16(i))
}
}
rs, err := protocol.NewRelayShards(uint16(*result.ClusterId), rShardIDs...)
require.NoError(t, err)
require.Equal(t, testShard16, rs.ClusterID)
require.Equal(t, rs16_2.ShardIDs, rs.ShardIDs)
// Updating the peer shards
rs16_2.ShardIDs = append(rs16_2.ShardIDs, 3, 4)
@ -84,8 +100,16 @@ func TestWakuMetadataRequest(t *testing.T) {
// Query same peer, after that peer subscribes to more shards
result, err = m16_1.Request(context.Background(), m16_2.h.ID())
require.NoError(t, err)
require.Equal(t, testShard16, result.ClusterID)
require.ElementsMatch(t, rs16_2.ShardIDs, result.ShardIDs)
rShardIDs = make([]uint16, 0)
if len(result.Shards) != 0 {
for _, i := range result.Shards {
rShardIDs = append(rShardIDs, uint16(i))
}
}
rs, err = protocol.NewRelayShards(uint16(*result.ClusterId), rShardIDs...)
require.NoError(t, err)
require.Equal(t, testShard16, rs.ClusterID)
require.ElementsMatch(t, rs16_2.ShardIDs, rs.ShardIDs)
// Query a peer not subscribed to any shard
_, err = m16_1.Request(context.Background(), m_noRS.h.ID())