diff --git a/waku/v2/node/wakunode2.go b/waku/v2/node/wakunode2.go index 5032c823..4eefa7e2 100644 --- a/waku/v2/node/wakunode2.go +++ b/waku/v2/node/wakunode2.go @@ -272,7 +272,7 @@ func New(opts ...WakuNodeOption) (*WakuNode, error) { } } - w.peerExchange, err = peer_exchange.NewWakuPeerExchange(w.DiscV5(), w.opts.clusterID, w.peerConnector, w.peermanager, w.opts.prometheusReg, w.log) + w.peerExchange, err = peer_exchange.NewWakuPeerExchange(w.DiscV5(), w.opts.clusterID, w.peerConnector, w.peermanager, w.opts.prometheusReg, w.log, w.opts.peerExchangeOptions...) if err != nil { return nil, err } diff --git a/waku/v2/node/wakunode2_test.go b/waku/v2/node/wakunode2_test.go index d4ca453c..144ce681 100644 --- a/waku/v2/node/wakunode2_test.go +++ b/waku/v2/node/wakunode2_test.go @@ -13,6 +13,7 @@ import ( "time" wenr "github.com/waku-org/go-waku/waku/v2/protocol/enr" + "github.com/waku-org/go-waku/waku/v2/protocol/peer_exchange" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/p2p/enode" @@ -540,3 +541,62 @@ func TestStaticShardingLimits(t *testing.T) { tests.WaitForMsg(t, 2*time.Second, &wg, s2.Ch) } + +func TestPeerExchangeRatelimit(t *testing.T) { + log := utils.Logger() + + if os.Getenv("RUN_FLAKY_TESTS") != "true" { + + log.Info("Skipping", zap.String("test", t.Name()), + zap.String("reason", "RUN_FLAKY_TESTS environment variable is not set to true")) + t.SkipNow() + } + + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second) + defer cancel() + + testClusterID := uint16(21) + + // Node1 with Relay + hostAddr1, err := net.ResolveTCPAddr("tcp", "0.0.0.0:0") + require.NoError(t, err) + wakuNode1, err := New( + WithHostAddress(hostAddr1), + WithWakuRelay(), + WithClusterID(testClusterID), + WithPeerExchange(peer_exchange.WithRateLimiter(1, 1)), + ) + require.NoError(t, err) + err = wakuNode1.Start(ctx) + require.NoError(t, err) + defer wakuNode1.Stop() + + // Node2 with Relay + hostAddr2, err := net.ResolveTCPAddr("tcp", "0.0.0.0:0") + require.NoError(t, err) + wakuNode2, err := New( + WithHostAddress(hostAddr2), + WithWakuRelay(), + WithClusterID(testClusterID), + WithPeerExchange(peer_exchange.WithRateLimiter(1, 1)), + ) + require.NoError(t, err) + err = wakuNode2.Start(ctx) + require.NoError(t, err) + defer wakuNode2.Stop() + + err = wakuNode2.DialPeer(ctx, wakuNode1.ListenAddresses()[0].String()) + require.NoError(t, err) + + //time.Sleep(1 * time.Second) + + err = wakuNode1.PeerExchange().Request(ctx, 1) + require.NoError(t, err) + + err = wakuNode1.PeerExchange().Request(ctx, 1) + require.Error(t, err) + + time.Sleep(1 * time.Second) + err = wakuNode1.PeerExchange().Request(ctx, 1) + require.NoError(t, err) +} diff --git a/waku/v2/node/wakuoptions.go b/waku/v2/node/wakuoptions.go index 82d96461..2e34ace7 100644 --- a/waku/v2/node/wakuoptions.go +++ b/waku/v2/node/wakuoptions.go @@ -31,6 +31,7 @@ import ( "github.com/waku-org/go-waku/waku/v2/protocol/legacy_store" "github.com/waku-org/go-waku/waku/v2/protocol/lightpush" "github.com/waku-org/go-waku/waku/v2/protocol/pb" + "github.com/waku-org/go-waku/waku/v2/protocol/peer_exchange" "github.com/waku-org/go-waku/waku/v2/rendezvous" "github.com/waku-org/go-waku/waku/v2/timesource" "github.com/waku-org/go-waku/waku/v2/utils" @@ -102,7 +103,8 @@ type WakuNodeParameters struct { discV5bootnodes []*enode.Node discV5autoUpdate bool - enablePeerExchange bool + enablePeerExchange bool + peerExchangeOptions []peer_exchange.Option enableRLN bool rlnRelayMemIndex *uint @@ -411,9 +413,10 @@ func WithDiscoveryV5(udpPort uint, bootnodes []*enode.Node, autoUpdate bool) Wak } // WithPeerExchange is a WakuOption used to enable Peer Exchange -func WithPeerExchange() WakuNodeOption { +func WithPeerExchange(options ...peer_exchange.Option) WakuNodeOption { return func(params *WakuNodeParameters) error { params.enablePeerExchange = true + params.peerExchangeOptions = options return nil } } diff --git a/waku/v2/protocol/peer_exchange/protocol.go b/waku/v2/protocol/peer_exchange/protocol.go index c02cdca6..5f103e12 100644 --- a/waku/v2/protocol/peer_exchange/protocol.go +++ b/waku/v2/protocol/peer_exchange/protocol.go @@ -100,6 +100,9 @@ func (wakuPX *WakuPeerExchange) onRequest() func(network.Stream) { wakuPX.metrics.RecordError(rateLimitFailure) wakuPX.log.Error("exceeds the rate limit") // TODO: peer exchange protocol should contain an err field + if err := stream.Reset(); err != nil { + wakuPX.log.Error("resetting connection", zap.Error(err)) + } return }