diff --git a/rpc/chain/client.go b/rpc/chain/client.go index a75f04ac4..6ad866015 100644 --- a/rpc/chain/client.go +++ b/rpc/chain/client.go @@ -55,6 +55,8 @@ type ClientInterface interface { bind.ContractCaller bind.ContractTransactor bind.ContractFilterer + GetLimiter() RequestLimiter + SetLimiter(RequestLimiter) } type Tagger interface { @@ -85,6 +87,7 @@ type ClientWithFallback struct { fallback *ethclient.Client mainLimiter *RPCRpsLimiter fallbackLimiter *RPCRpsLimiter + commonLimiter RequestLimiter mainRPC *rpc.Client fallbackRPC *rpc.Client @@ -233,6 +236,12 @@ func (c *ClientWithFallback) IsConnected() bool { } func (c *ClientWithFallback) makeCall(ctx context.Context, main func() ([]any, error), fallback func() ([]any, error)) ([]any, error) { + if c.commonLimiter != nil { + if limited, err := c.commonLimiter.IsLimitReached(c.tag); limited { + return nil, fmt.Errorf("rate limit exceeded for %s: %s", c.tag, err) + } + } + resultChan := make(chan CommandResult, 1) c.LastCheckedAt = time.Now().Unix() errChan := hystrix.Go(c.circuitBreakerCmdName, func() error { @@ -1005,3 +1014,11 @@ func (c *ClientWithFallback) DeepCopyTag() Tagger { copy := *c return © } + +func (c *ClientWithFallback) GetLimiter() RequestLimiter { + return c.commonLimiter +} + +func (c *ClientWithFallback) SetLimiter(limiter RequestLimiter) { + c.commonLimiter = limiter +} diff --git a/rpc/chain/rpc_limiter.go b/rpc/chain/rpc_limiter.go index be81921df..38eecda77 100644 --- a/rpc/chain/rpc_limiter.go +++ b/rpc/chain/rpc_limiter.go @@ -31,15 +31,36 @@ type RequestsStorage interface { Set(data RequestData) error } +// InMemRequestsStorage is an in-memory dummy implementation of RequestsStorage +type InMemRequestsStorage struct { + data RequestData +} + +func NewInMemRequestsStorage() *InMemRequestsStorage { + return &InMemRequestsStorage{} +} + +func (s *InMemRequestsStorage) Get(tag string) (RequestData, error) { + return s.data, nil +} + +func (s *InMemRequestsStorage) Set(data RequestData) error { + s.data = data + return nil +} + type RequestData struct { Tag string CreatedAt time.Time Period time.Duration + MaxReqs int + NumReqs int } type RequestLimiter interface { - SetMaxRequests(tag string, maxRequests int, interval time.Duration) - IsLimitReached(tag string) bool + SetMaxRequests(tag string, maxRequests int, interval time.Duration) error + GetMaxRequests(tag string) (RequestData, error) + IsLimitReached(tag string) (bool, error) } type RPCRequestLimiter struct { @@ -52,39 +73,71 @@ func NewRequestLimiter(storage RequestsStorage) *RPCRequestLimiter { } } -func (rl *RPCRequestLimiter) SetMaxRequests(tag string, maxRequests int, interval time.Duration) { - err := rl.saveToStorage(tag, maxRequests, interval) +func (rl *RPCRequestLimiter) SetMaxRequests(tag string, maxRequests int, interval time.Duration) error { + err := rl.saveToStorage(tag, maxRequests, interval, 0, time.Now()) if err != nil { log.Error("Failed to save request data to storage", "error", err) - return - } - - // Set max requests logic here -} - -func (rl *RPCRequestLimiter) saveToStorage(tag string, maxRequests int, interval time.Duration) error { - data := RequestData{ - Tag: tag, - CreatedAt: time.Now(), - Period: interval, - } - - err := rl.storage.Set(data) - if err != nil { return err } return nil } -func (rl *RPCRequestLimiter) IsLimitReached(tag string) bool { +func (rl *RPCRequestLimiter) GetMaxRequests(tag string) (RequestData, error) { data, err := rl.storage.Get(tag) if err != nil { log.Error("Failed to get request data from storage", "error", err, "tag", tag) - return false + return RequestData{}, err } - return time.Since(data.CreatedAt) >= data.Period + return data, nil +} + +func (rl *RPCRequestLimiter) saveToStorage(tag string, maxRequests int, interval time.Duration, numReqs int, timestamp time.Time) error { + data := RequestData{ + Tag: tag, + CreatedAt: timestamp, + Period: interval, + MaxReqs: maxRequests, + NumReqs: numReqs, + } + + err := rl.storage.Set(data) + if err != nil { + log.Error("Failed to save request data to storage", "error", err) + return err + } + + return nil +} + +func (rl *RPCRequestLimiter) IsLimitReached(tag string) (bool, error) { + data, err := rl.storage.Get(tag) + if err != nil { + return false, err + } + + // Check if a number of requests is over the limit within the interval + if time.Since(data.CreatedAt) < data.Period { + if data.NumReqs >= data.MaxReqs { + return true, nil + } + + err := rl.saveToStorage(tag, data.MaxReqs, data.Period, data.NumReqs+1, data.CreatedAt) + if err != nil { + return false, err + } + + return false, nil + } + + // Reset the number of requests if the interval has passed + err = rl.saveToStorage(tag, data.MaxReqs, data.Period, 0, time.Now()) + if err != nil { + return false, err + } + + return false, nil } type RPCRpsLimiter struct { diff --git a/rpc/chain/rpc_limiter_test.go b/rpc/chain/rpc_limiter_test.go new file mode 100644 index 000000000..02b75a2af --- /dev/null +++ b/rpc/chain/rpc_limiter_test.go @@ -0,0 +1,113 @@ +package chain + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func setupTest() (*InMemRequestsStorage, RequestLimiter) { + storage := NewInMemRequestsStorage() + rl := NewRequestLimiter(storage) + return storage, rl +} + +func TestSetMaxRequests(t *testing.T) { + storage, rl := setupTest() + + // Define test inputs + tag := "testTag" + maxRequests := 10 + interval := time.Second + + // Call the SetMaxRequests method + err := rl.SetMaxRequests(tag, maxRequests, interval) + require.NoError(t, err) + + // Verify that the data was saved to storage correctly + data, err := storage.Get(tag) + require.NoError(t, err) + require.Equal(t, tag, data.Tag) + require.Equal(t, interval, data.Period) + require.Equal(t, maxRequests, data.MaxReqs) + require.Equal(t, 0, data.NumReqs) +} + +func TestGetMaxRequests(t *testing.T) { + storage, rl := setupTest() + + data := RequestData{ + Tag: "testTag", + Period: time.Second, + MaxReqs: 10, + NumReqs: 1, + } + // Define test inputs + storage.Set(data) + + // Call the GetMaxRequests method + ret, err := rl.GetMaxRequests(data.Tag) + require.NoError(t, err) + + // Verify the returned data + require.Equal(t, data, ret) +} + +func TestIsLimitReachedWithinPeriod(t *testing.T) { + storage, rl := setupTest() + + // Define test inputs + tag := "testTag" + maxRequests := 10 + interval := time.Second + + // Set up the storage with test data + data := RequestData{ + Tag: tag, + Period: interval, + CreatedAt: time.Now(), + MaxReqs: maxRequests, + } + storage.Set(data) + + // Call the IsLimitReached method + for i := 0; i < maxRequests; i++ { + limitReached, err := rl.IsLimitReached(tag) + require.NoError(t, err) + + // Verify the result + require.False(t, limitReached) + } + + // Call the IsLimitReached method again + limitReached, err := rl.IsLimitReached(tag) + require.NoError(t, err) + require.True(t, limitReached) +} + +func TestIsLimitReachedWhenPeriodPassed(t *testing.T) { + storage, rl := setupTest() + + // Define test inputs + tag := "testTag" + maxRequests := 10 + interval := time.Second + + // Set up the storage with test data + data := RequestData{ + Tag: tag, + Period: interval, + CreatedAt: time.Now().Add(-interval), + MaxReqs: maxRequests, + NumReqs: maxRequests, + } + storage.Set(data) + + // Call the IsLimitReached method + limitReached, err := rl.IsLimitReached(tag) + require.NoError(t, err) + + // Verify the result + require.False(t, limitReached) +} diff --git a/services/wallet/transfer/commands_sequential.go b/services/wallet/transfer/commands_sequential.go index 64556ac25..156134544 100644 --- a/services/wallet/transfer/commands_sequential.go +++ b/services/wallet/transfer/commands_sequential.go @@ -25,8 +25,13 @@ import ( var findBlocksRetryInterval = 5 * time.Second -const transferHistoryTag = "transfer_history" -const newTransferHistoryTag = "new_transfer_history" +const ( + transferHistoryTag = "transfer_history" + newTransferHistoryTag = "new_transfer_history" + + transferHistoryMaxRequests = 100 + transferHistoryMaxRequestsPeriod = 10 * time.Second +) type nonceInfo struct { nonce *int64 @@ -1117,6 +1122,10 @@ func (c *loadBlocksAndTransfersCommand) fetchHistoryBlocksForAccount(group *asyn log.Debug("range item", "r", rangeItem, "n", c.chainClient.NetworkID(), "a", account) chainClient := chain.ClientWithTag(c.chainClient, transferHistoryTag) + limiter := chain.NewRequestLimiter(chain.NewInMemRequestsStorage()) + limiter.SetMaxRequests(transferHistoryTag, transferHistoryMaxRequests, transferHistoryMaxRequestsPeriod) + chainClient.SetLimiter(limiter) + fbc := &findBlocksCommand{ accounts: []common.Address{account}, db: c.db,