mirror of https://github.com/status-im/go-waku.git
fix: do not start metadata protocol unless required (#920)
This commit is contained in:
parent
d7249fc123
commit
cf8c36f85d
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/ethereum/go-ethereum/p2p/enode"
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
|
@ -57,6 +56,7 @@ func (wakuM *WakuMetadata) SetHost(h host.Host) {
|
|||
func (wakuM *WakuMetadata) Start(ctx context.Context) error {
|
||||
if wakuM.clusterID == 0 {
|
||||
wakuM.log.Warn("no clusterID is specified. Protocol will not be initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
@ -135,16 +135,21 @@ func (wakuM *WakuMetadata) Request(ctx context.Context, peerID peer.ID) (*protoc
|
|||
stream.Close()
|
||||
|
||||
if response.ClusterId == nil {
|
||||
return nil, nil // Node is not using sharding
|
||||
return nil, errors.New("node did not provide a waku clusterid")
|
||||
}
|
||||
|
||||
result := &protocol.RelayShards{}
|
||||
result.ClusterID = uint16(*response.ClusterId)
|
||||
rClusterID := uint16(*response.ClusterId)
|
||||
var rShardIDs []uint16
|
||||
for _, i := range response.Shards {
|
||||
result.ShardIDs = append(result.ShardIDs, uint16(i))
|
||||
rShardIDs = append(rShardIDs, uint16(i))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
rs, err := protocol.NewRelayShards(rClusterID, rShardIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &rs, nil
|
||||
}
|
||||
|
||||
func (wakuM *WakuMetadata) onRequest(ctx context.Context) func(network.Stream) {
|
||||
|
@ -209,6 +214,15 @@ func (wakuM *WakuMetadata) ListenClose(n network.Network, m multiaddr.Multiaddr)
|
|||
// Do nothing
|
||||
}
|
||||
|
||||
func (wakuM *WakuMetadata) disconnectPeer(peerID peer.ID, reason error) {
|
||||
logger := wakuM.log.With(logging.HostID("peerID", peerID))
|
||||
logger.Error("disconnecting from peer", zap.Error(reason))
|
||||
wakuM.h.Peerstore().RemovePeer(peerID)
|
||||
if err := wakuM.h.Network().ClosePeer(peerID); err != nil {
|
||||
logger.Error("could not disconnect from peer", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Connected is called when a connection is opened
|
||||
func (wakuM *WakuMetadata) Connected(n network.Network, cc network.Conn) {
|
||||
go func() {
|
||||
|
@ -219,30 +233,14 @@ func (wakuM *WakuMetadata) Connected(n network.Network, cc network.Conn) {
|
|||
|
||||
peerID := cc.RemotePeer()
|
||||
|
||||
logger := wakuM.log.With(logging.HostID("peerID", peerID))
|
||||
|
||||
shouldDisconnect := true
|
||||
shard, err := wakuM.Request(wakuM.ctx, peerID)
|
||||
if err == nil {
|
||||
if shard == nil {
|
||||
err = errors.New("no shard reported")
|
||||
} else if shard.ClusterID != wakuM.clusterID {
|
||||
err = errors.New("different clusterID reported")
|
||||
}
|
||||
} else {
|
||||
// Only disconnect from peers if they support the protocol
|
||||
// TODO: open a PR in go-libp2p to create a var with this error to not have to compare strings but use errors.Is instead
|
||||
if strings.Contains(err.Error(), "protocols not supported") {
|
||||
shouldDisconnect = false
|
||||
}
|
||||
if err != nil {
|
||||
wakuM.disconnectPeer(peerID, err)
|
||||
return
|
||||
}
|
||||
|
||||
if shouldDisconnect && err != nil {
|
||||
logger.Error("disconnecting from peer", zap.Error(err))
|
||||
wakuM.h.Peerstore().RemovePeer(peerID)
|
||||
if err := wakuM.h.Network().ClosePeer(peerID); err != nil {
|
||||
logger.Error("could not disconnect from peer", zap.Error(err))
|
||||
}
|
||||
if shard.ClusterID != wakuM.clusterID {
|
||||
wakuM.disconnectPeer(peerID, errors.New("different clusterID reported"))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -3,12 +3,15 @@ package metadata
|
|||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gcrypto "github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/core/peerstore"
|
||||
libp2pProtocol "github.com/libp2p/go-libp2p/core/protocol"
|
||||
"github.com/multiformats/go-multistream"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/waku-org/go-waku/tests"
|
||||
"github.com/waku-org/go-waku/waku/v2/protocol"
|
||||
|
@ -42,6 +45,11 @@ func createWakuMetadata(t *testing.T, rs *protocol.RelayShards) *WakuMetadata {
|
|||
return m1
|
||||
}
|
||||
|
||||
func isProtocolNotSupported(err error) bool {
|
||||
notSupportedErr := multistream.ErrNotSupported[libp2pProtocol.ID]{}
|
||||
return errors.Is(err, notSupportedErr)
|
||||
}
|
||||
|
||||
func TestWakuMetadataRequest(t *testing.T) {
|
||||
testShard16 := uint16(16)
|
||||
|
||||
|
@ -79,11 +87,9 @@ func TestWakuMetadataRequest(t *testing.T) {
|
|||
require.Equal(t, testShard16, result.ClusterID)
|
||||
require.ElementsMatch(t, rs16_2.ShardIDs, result.ShardIDs)
|
||||
|
||||
// Query a peer not subscribed to a shard
|
||||
result, err = m16_1.Request(context.Background(), m_noRS.h.ID())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(0), result.ClusterID)
|
||||
require.Len(t, result.ShardIDs, 0)
|
||||
// Query a peer not subscribed to any shard
|
||||
_, err = m16_1.Request(context.Background(), m_noRS.h.ID())
|
||||
require.True(t, isProtocolNotSupported(err))
|
||||
}
|
||||
|
||||
func TestNoNetwork(t *testing.T) {
|
||||
|
@ -93,7 +99,7 @@ func TestNoNetwork(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
m1 := createWakuMetadata(t, &rs1)
|
||||
|
||||
// host2 does not support metadata protocol
|
||||
// host2 does not support metadata protocol, so it should be dropped
|
||||
port, err := tests.FindFreePort(t, "", 5)
|
||||
require.NoError(t, err)
|
||||
host2, err := tests.MakeHost(context.Background(), port, rand.Reader)
|
||||
|
@ -106,12 +112,10 @@ func TestNoNetwork(t *testing.T) {
|
|||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Verifying peer connections
|
||||
require.Len(t, m1.h.Network().Peers(), 1)
|
||||
require.Len(t, host2.Network().Peers(), 1)
|
||||
require.Len(t, m1.h.Network().Peers(), 0)
|
||||
require.Len(t, host2.Network().Peers(), 0)
|
||||
}
|
||||
|
||||
// go test -timeout 300s -run TestDropConnectionOnDiffNetworks github.com/waku-org/go-waku/waku/v2/protocol/metadata -count 1 -v
|
||||
|
||||
func TestDropConnectionOnDiffNetworks(t *testing.T) {
|
||||
cluster1 := uint16(1)
|
||||
cluster2 := uint16(2)
|
||||
|
|
Loading…
Reference in New Issue