diff --git a/rpc/chain/client.go b/rpc/chain/client.go index 6ad866015..c0d695057 100644 --- a/rpc/chain/client.go +++ b/rpc/chain/client.go @@ -238,7 +238,7 @@ func (c *ClientWithFallback) IsConnected() bool { func (c *ClientWithFallback) makeCall(ctx context.Context, main func() ([]any, error), fallback func() ([]any, error)) ([]any, error) { if c.commonLimiter != nil { if limited, err := c.commonLimiter.IsLimitReached(c.tag); limited { - return nil, fmt.Errorf("rate limit exceeded for %s: %s", c.tag, err) + return nil, fmt.Errorf("tag=%s, %w", c.tag, err) } } diff --git a/rpc/chain/rpc_limiter.go b/rpc/chain/rpc_limiter.go index 38eecda77..b8aee8c1a 100644 --- a/rpc/chain/rpc_limiter.go +++ b/rpc/chain/rpc_limiter.go @@ -5,8 +5,9 @@ import ( "sync" "time" - "github.com/ethereum/go-ethereum/log" "github.com/google/uuid" + + "github.com/ethereum/go-ethereum/log" ) const ( @@ -27,25 +28,29 @@ type callerOnWait struct { } type RequestsStorage interface { - Get(tag string) (RequestData, error) - Set(data RequestData) error + Get(tag string) (*RequestData, error) + Set(data *RequestData) error } -// InMemRequestsStorage is an in-memory dummy implementation of RequestsStorage -type InMemRequestsStorage struct { - data RequestData +type InMemRequestsMapStorage struct { + data sync.Map } -func NewInMemRequestsStorage() *InMemRequestsStorage { - return &InMemRequestsStorage{} +func NewInMemRequestsMapStorage() *InMemRequestsMapStorage { + return &InMemRequestsMapStorage{} } -func (s *InMemRequestsStorage) Get(tag string) (RequestData, error) { - return s.data, nil +func (s *InMemRequestsMapStorage) Get(tag string) (*RequestData, error) { + data, ok := s.data.Load(tag) + if !ok { + return nil, nil + } + + return data.(*RequestData), nil } -func (s *InMemRequestsStorage) Set(data RequestData) error { - s.data = data +func (s *InMemRequestsMapStorage) Set(data *RequestData) error { + s.data.Store(data.Tag, data) return nil } @@ -59,7 +64,7 @@ type RequestData struct { type RequestLimiter interface { SetMaxRequests(tag string, maxRequests int, interval time.Duration) error - GetMaxRequests(tag string) (RequestData, error) + GetMaxRequests(tag string) (*RequestData, error) IsLimitReached(tag string) (bool, error) } @@ -83,18 +88,17 @@ func (rl *RPCRequestLimiter) SetMaxRequests(tag string, maxRequests int, interva return nil } -func (rl *RPCRequestLimiter) GetMaxRequests(tag string) (RequestData, error) { +func (rl *RPCRequestLimiter) GetMaxRequests(tag string) (*RequestData, error) { data, err := rl.storage.Get(tag) if err != nil { - log.Error("Failed to get request data from storage", "error", err, "tag", tag) - return RequestData{}, err + return nil, err } return data, nil } func (rl *RPCRequestLimiter) saveToStorage(tag string, maxRequests int, interval time.Duration, numReqs int, timestamp time.Time) error { - data := RequestData{ + data := &RequestData{ Tag: tag, CreatedAt: timestamp, Period: interval, @@ -117,6 +121,10 @@ func (rl *RPCRequestLimiter) IsLimitReached(tag string) (bool, error) { return false, err } + if data == nil { + return false, nil + } + // Check if a number of requests is over the limit within the interval if time.Since(data.CreatedAt) < data.Period { if data.NumReqs >= data.MaxReqs { diff --git a/rpc/chain/rpc_limiter_test.go b/rpc/chain/rpc_limiter_test.go index 02b75a2af..e922f6ca9 100644 --- a/rpc/chain/rpc_limiter_test.go +++ b/rpc/chain/rpc_limiter_test.go @@ -7,8 +7,8 @@ import ( "github.com/stretchr/testify/require" ) -func setupTest() (*InMemRequestsStorage, RequestLimiter) { - storage := NewInMemRequestsStorage() +func setupTest() (*InMemRequestsMapStorage, RequestLimiter) { + storage := NewInMemRequestsMapStorage() rl := NewRequestLimiter(storage) return storage, rl } @@ -37,7 +37,7 @@ func TestSetMaxRequests(t *testing.T) { func TestGetMaxRequests(t *testing.T) { storage, rl := setupTest() - data := RequestData{ + data := &RequestData{ Tag: "testTag", Period: time.Second, MaxReqs: 10, @@ -63,7 +63,7 @@ func TestIsLimitReachedWithinPeriod(t *testing.T) { interval := time.Second // Set up the storage with test data - data := RequestData{ + data := &RequestData{ Tag: tag, Period: interval, CreatedAt: time.Now(), @@ -95,7 +95,7 @@ func TestIsLimitReachedWhenPeriodPassed(t *testing.T) { interval := time.Second // Set up the storage with test data - data := RequestData{ + data := &RequestData{ Tag: tag, Period: interval, CreatedAt: time.Now().Add(-interval), diff --git a/services/wallet/transfer/commands_sequential.go b/services/wallet/transfer/commands_sequential.go index 156134544..583c9adde 100644 --- a/services/wallet/transfer/commands_sequential.go +++ b/services/wallet/transfer/commands_sequential.go @@ -29,8 +29,8 @@ const ( transferHistoryTag = "transfer_history" newTransferHistoryTag = "new_transfer_history" - transferHistoryMaxRequests = 100 - transferHistoryMaxRequestsPeriod = 10 * time.Second + transferHistoryMaxRequests = 10000 + transferHistoryMaxRequestsPeriod = 24 * time.Hour ) type nonceInfo struct { @@ -1122,7 +1122,7 @@ func (c *loadBlocksAndTransfersCommand) fetchHistoryBlocksForAccount(group *asyn log.Debug("range item", "r", rangeItem, "n", c.chainClient.NetworkID(), "a", account) chainClient := chain.ClientWithTag(c.chainClient, transferHistoryTag) - limiter := chain.NewRequestLimiter(chain.NewInMemRequestsStorage()) + limiter := chain.NewRequestLimiter(chain.NewInMemRequestsMapStorage()) limiter.SetMaxRequests(transferHistoryTag, transferHistoryMaxRequests, transferHistoryMaxRequestsPeriod) chainClient.SetLimiter(limiter) diff --git a/services/wallet/transfer/commands_sequential_test.go b/services/wallet/transfer/commands_sequential_test.go index fe530a182..7f93360d7 100644 --- a/services/wallet/transfer/commands_sequential_test.go +++ b/services/wallet/transfer/commands_sequential_test.go @@ -64,6 +64,24 @@ type TestClient struct { rw sync.RWMutex callsCounter map[string]int currentBlock uint64 + limiter chain.RequestLimiter +} + +var countAndlog = func(tc *TestClient, method string, params ...interface{}) error { + tc.incCounter(method) + if tc.traceAPICalls { + if len(params) > 0 { + tc.t.Log(method, params) + } else { + tc.t.Log(method) + } + } + + return nil +} + +func (tc *TestClient) countAndlog(method string, params ...interface{}) error { + return countAndlog(tc, method, params...) } func (tc *TestClient) incCounter(method string) { @@ -109,42 +127,42 @@ func (tc *TestClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) e } func (tc *TestClient) HeaderByHash(ctx context.Context, hash common.Hash) (*types.Header, error) { - tc.incCounter("HeaderByHash") - if tc.traceAPICalls { - tc.t.Log("HeaderByHash") + err := tc.countAndlog("HeaderByHash") + if err != nil { + return nil, err } return nil, nil } func (tc *TestClient) BlockByHash(ctx context.Context, hash common.Hash) (*types.Block, error) { - tc.incCounter("BlockByHash") - if tc.traceAPICalls { - tc.t.Log("BlockByHash") + err := tc.countAndlog("BlockByHash") + if err != nil { + return nil, err } return nil, nil } func (tc *TestClient) BlockByNumber(ctx context.Context, number *big.Int) (*types.Block, error) { - tc.incCounter("BlockByNumber") - if tc.traceAPICalls { - tc.t.Log("BlockByNumber") + err := tc.countAndlog("BlockByNumber") + if err != nil { + return nil, err } return nil, nil } func (tc *TestClient) NonceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (uint64, error) { - tc.incCounter("NonceAt") nonce := tc.nonceHistory[account][blockNumber.Uint64()] - if tc.traceAPICalls { - tc.t.Log("NonceAt", blockNumber, "result:", nonce) + err := tc.countAndlog("NonceAt", fmt.Sprintf("result: %d", nonce)) + if err != nil { + return nonce, err } return nonce, nil } func (tc *TestClient) FilterLogs(ctx context.Context, q ethereum.FilterQuery) ([]types.Log, error) { - tc.incCounter("FilterLogs") - if tc.traceAPICalls { - tc.t.Log("FilterLogs") + err := tc.countAndlog("FilterLogs") + if err != nil { + return nil, err } // We do not verify addresses for now @@ -242,12 +260,12 @@ func (tc *TestClient) getBalance(address common.Address, blockNumber *big.Int) * } func (tc *TestClient) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) { - tc.incCounter("BalanceAt") balance := tc.getBalance(account, blockNumber) - - if tc.traceAPICalls { - tc.t.Log("BalanceAt", blockNumber, "account:", account, "result:", balance) + err := tc.countAndlog("BalanceAt", fmt.Sprintf("account: %s, result: %d", account, balance)) + if err != nil { + return nil, err } + return balance, nil } @@ -264,14 +282,15 @@ func (tc *TestClient) tokenBalanceAt(account common.Address, token common.Addres } func (tc *TestClient) HeaderByNumber(ctx context.Context, number *big.Int) (*types.Header, error) { - tc.incCounter("HeaderByNumber") if number == nil { number = big.NewInt(int64(tc.currentBlock)) } - if tc.traceAPICalls { - tc.t.Log("HeaderByNumber", number) + err := tc.countAndlog("HeaderByNumber", fmt.Sprintf("number: %d", number)) + if err != nil { + return nil, err } + header := &types.Header{ Number: number, Time: 0, @@ -281,17 +300,17 @@ func (tc *TestClient) HeaderByNumber(ctx context.Context, number *big.Int) (*typ } func (tc *TestClient) CallBlockHashByTransaction(ctx context.Context, blockNumber *big.Int, index uint) (common.Hash, error) { - tc.incCounter("FullTransactionByBlockNumberAndIndex") - if tc.traceAPICalls { - tc.t.Log("FullTransactionByBlockNumberAndIndex") + err := tc.countAndlog("CallBlockHashByTransaction") + if err != nil { + return common.Hash{}, err } return common.BigToHash(blockNumber), nil } func (tc *TestClient) GetBaseFeeFromBlock(ctx context.Context, blockNumber *big.Int) (string, error) { - tc.incCounter("GetBaseFeeFromBlock") - if tc.traceAPICalls { - tc.t.Log("GetBaseFeeFromBlock") + err := tc.countAndlog("GetBaseFeeFromBlock") + if err != nil { + return "", err } return "", nil } @@ -311,10 +330,7 @@ var ethscanAddress = common.HexToAddress("0x000000000000000000000000000000000077 var balanceCheckAddress = common.HexToAddress("0x0000000000000000000000000000000010777333") func (tc *TestClient) CodeAt(ctx context.Context, contract common.Address, blockNumber *big.Int) ([]byte, error) { - tc.incCounter("CodeAt") - if tc.traceAPICalls { - tc.t.Log("CodeAt", contract, blockNumber) - } + tc.countAndlog("CodeAt", fmt.Sprintf("contract: %s, blockNumber: %d", contract, blockNumber)) if ethscanAddress == contract || balanceCheckAddress == contract { return []byte{1}, nil @@ -324,9 +340,9 @@ func (tc *TestClient) CodeAt(ctx context.Context, contract common.Address, block } func (tc *TestClient) CallContract(ctx context.Context, call ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) { - tc.incCounter("CallContract") - if tc.traceAPICalls { - tc.t.Log("CallContract", call, blockNumber, call.To) + err := tc.countAndlog("CallContract", fmt.Sprintf("call: %v, blockNumber: %d, to: %s", call, blockNumber, call.To)) + if err != nil { + return nil, err } if *call.To == ethscanAddress { @@ -557,9 +573,9 @@ func (tc *TestClient) prepareTokenBalanceHistory(toBlock int) { } func (tc *TestClient) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { - tc.incCounter("CallContext") - if tc.traceAPICalls { - tc.t.Log("CallContext") + err := tc.countAndlog("CallContext") + if err != nil { + return err } return nil } @@ -578,91 +594,79 @@ func (tc *TestClient) SetWalletNotifier(notifier func(chainId uint64, message st } func (tc *TestClient) EstimateGas(ctx context.Context, call ethereum.CallMsg) (gas uint64, err error) { - tc.incCounter("EstimateGas") - if tc.traceAPICalls { - tc.t.Log("EstimateGas") + err = tc.countAndlog("EstimateGas") + if err != nil { + return 0, err } return 0, nil } func (tc *TestClient) PendingCodeAt(ctx context.Context, account common.Address) ([]byte, error) { - tc.incCounter("PendingCodeAt") - if tc.traceAPICalls { - tc.t.Log("PendingCodeAt") + err := tc.countAndlog("PendingCodeAt") + if err != nil { + return nil, err } - return nil, nil } func (tc *TestClient) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) { - tc.incCounter("PendingCallContract") - if tc.traceAPICalls { - tc.t.Log("PendingCallContract") + err := tc.countAndlog("PendingCallContract") + if err != nil { + return nil, err } - return nil, nil } func (tc *TestClient) PendingNonceAt(ctx context.Context, account common.Address) (uint64, error) { - tc.incCounter("PendingNonceAt") - if tc.traceAPICalls { - tc.t.Log("PendingNonceAt") + err := tc.countAndlog("PendingNonceAt") + if err != nil { + return 0, err } - return 0, nil } func (tc *TestClient) SuggestGasPrice(ctx context.Context) (*big.Int, error) { - tc.incCounter("SuggestGasPrice") - if tc.traceAPICalls { - tc.t.Log("SuggestGasPrice") + err := tc.countAndlog("SuggestGasPrice") + if err != nil { + return nil, err } - return nil, nil } func (tc *TestClient) SendTransaction(ctx context.Context, tx *types.Transaction) error { - tc.incCounter("SendTransaction") - if tc.traceAPICalls { - tc.t.Log("SendTransaction") + err := tc.countAndlog("SendTransaction") + if err != nil { + return err } - return nil } func (tc *TestClient) SuggestGasTipCap(ctx context.Context) (*big.Int, error) { - tc.incCounter("SuggestGasTipCap") - if tc.traceAPICalls { - tc.t.Log("SuggestGasTipCap") + err := tc.countAndlog("SuggestGasTipCap") + if err != nil { + return nil, err } - return nil, nil } func (tc *TestClient) BatchCallContextIgnoringLocalHandlers(ctx context.Context, b []rpc.BatchElem) error { - tc.incCounter("BatchCallContextIgnoringLocalHandlers") - if tc.traceAPICalls { - tc.t.Log("BatchCallContextIgnoringLocalHandlers") + err := tc.countAndlog("BatchCallContextIgnoringLocalHandlers") + if err != nil { + return err } - return nil } func (tc *TestClient) CallContextIgnoringLocalHandlers(ctx context.Context, result interface{}, method string, args ...interface{}) error { - tc.incCounter("CallContextIgnoringLocalHandlers") - if tc.traceAPICalls { - tc.t.Log("CallContextIgnoringLocalHandlers") + err := tc.countAndlog("CallContextIgnoringLocalHandlers") + if err != nil { + return err } - return nil } func (tc *TestClient) CallRaw(data string) string { - tc.incCounter("CallRaw") - if tc.traceAPICalls { - tc.t.Log("CallRaw") - } - + _ = tc.countAndlog("CallRaw") return "" } @@ -671,38 +675,34 @@ func (tc *TestClient) GetChainID() *big.Int { } func (tc *TestClient) SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuery, ch chan<- types.Log) (ethereum.Subscription, error) { - tc.incCounter("SubscribeFilterLogs") - if tc.traceAPICalls { - tc.t.Log("SubscribeFilterLogs") + err := tc.countAndlog("SubscribeFilterLogs") + if err != nil { + return nil, err } - return nil, nil } func (tc *TestClient) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) { - tc.incCounter("TransactionReceipt") - if tc.traceAPICalls { - tc.t.Log("TransactionReceipt") + err := tc.countAndlog("TransactionReceipt") + if err != nil { + return nil, err } - return nil, nil } func (tc *TestClient) TransactionByHash(ctx context.Context, txHash common.Hash) (*types.Transaction, bool, error) { - tc.incCounter("TransactionByHash") - if tc.traceAPICalls { - tc.t.Log("TransactionByHash") + err := tc.countAndlog("TransactionByHash") + if err != nil { + return nil, false, err } - return nil, false, nil } func (tc *TestClient) BlockNumber(ctx context.Context) (uint64, error) { - tc.incCounter("BlockNumber") - if tc.traceAPICalls { - tc.t.Log("BlockNumber") + err := tc.countAndlog("BlockNumber") + if err != nil { + return 0, err } - return 0, nil } func (tc *TestClient) SetIsConnected(value bool) { @@ -711,7 +711,7 @@ func (tc *TestClient) SetIsConnected(value bool) { } } -func (tc *TestClient) GetIsConnected() bool { +func (tc *TestClient) IsConnected() bool { if tc.traceAPICalls { tc.t.Log("GetIsConnected") } @@ -719,6 +719,14 @@ func (tc *TestClient) GetIsConnected() bool { return true } +func (tc *TestClient) GetLimiter() chain.RequestLimiter { + return tc.limiter +} + +func (tc *TestClient) SetLimiter(limiter chain.RequestLimiter) { + tc.limiter = limiter +} + type testERC20Transfer struct { block *big.Int address common.Address @@ -1010,82 +1018,113 @@ func getCases() []findBlockCase { var tokenTXXAddress = common.HexToAddress("0x53211") var tokenTXYAddress = common.HexToAddress("0x73211") +func setupFindBlocksCommand(t *testing.T, accountAddress common.Address, fromBlock, toBlock *big.Int, rangeSize int, balances map[common.Address][][]int, outgoingERC20Transfers, incomingERC20Transfers, outgoingERC1155SingleTransfers, incomingERC1155SingleTransfers map[common.Address][]testERC20Transfer) (*findBlocksCommand, *TestClient, chan []*DBHeader, *BlockRangeSequentialDAO) { + appdb, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) + require.NoError(t, err) + + db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) + require.NoError(t, err) + + mediaServer, err := server.NewMediaServer(appdb, nil, nil, db) + require.NoError(t, err) + + wdb := NewDB(db) + tc := &TestClient{ + t: t, + balances: balances, + outgoingERC20Transfers: outgoingERC20Transfers, + incomingERC20Transfers: incomingERC20Transfers, + outgoingERC1155SingleTransfers: outgoingERC1155SingleTransfers, + incomingERC1155SingleTransfers: incomingERC1155SingleTransfers, + callsCounter: map[string]int{}, + } + // tc.traceAPICalls = true + // tc.printPreparedData = true + tc.prepareBalanceHistory(100) + tc.prepareTokenBalanceHistory(100) + blockChannel := make(chan []*DBHeader, 100) + + // Reimplement the common function that is called from every method to check for the limit + countAndlog = func(tc *TestClient, method string, params ...interface{}) error { + if tc.GetLimiter() != nil { + if limited, _ := tc.GetLimiter().IsLimitReached(transferHistoryTag); limited { + t.Log("ERROR: requests over limit") + return chain.ErrRequestsOverLimit + } + } + + tc.incCounter(method) + if tc.traceAPICalls { + if len(params) > 0 { + tc.t.Log(method, params) + } else { + tc.t.Log(method) + } + } + + return nil + } + client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db) + client.SetClient(tc.NetworkID(), tc) + tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil) + tokenManager.SetTokens([]*token.Token{ + { + Address: tokenTXXAddress, + Symbol: "TXX", + Decimals: 18, + ChainID: tc.NetworkID(), + Name: "Test Token 1", + Verified: true, + }, + { + Address: tokenTXYAddress, + Symbol: "TXY", + Decimals: 18, + ChainID: tc.NetworkID(), + Name: "Test Token 2", + Verified: true, + }, + }) + accDB, err := accounts.NewDB(appdb) + require.NoError(t, err) + blockRangeDAO := &BlockRangeSequentialDAO{wdb.client} + fbc := &findBlocksCommand{ + accounts: []common.Address{accountAddress}, + db: wdb, + blockRangeDAO: blockRangeDAO, + accountsDB: accDB, + chainClient: tc, + balanceCacher: balance.NewCacherWithTTL(5 * time.Minute), + feed: &event.Feed{}, + noLimit: false, + fromBlockNumber: fromBlock, + toBlockNumber: toBlock, + blocksLoadedCh: blockChannel, + defaultNodeBlockChunkSize: rangeSize, + tokenManager: tokenManager, + } + return fbc, tc, blockChannel, blockRangeDAO +} + func TestFindBlocksCommand(t *testing.T) { for idx, testCase := range getCases() { t.Log("case #", idx+1) - ctx := context.Background() - group := async.NewGroup(ctx) - - appdb, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - require.NoError(t, err) - - db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) - require.NoError(t, err) - tm := &TransactionManager{db, nil, nil, nil, nil, nil, nil, nil, nil, nil} - - mediaServer, err := server.NewMediaServer(appdb, nil, nil, db) - require.NoError(t, err) accountAddress := common.HexToAddress("0x1234") - wdb := NewDB(db) - tc := &TestClient{ - t: t, - balances: map[common.Address][][]int{accountAddress: testCase.balanceChanges}, - outgoingERC20Transfers: map[common.Address][]testERC20Transfer{accountAddress: testCase.outgoingERC20Transfers}, - incomingERC20Transfers: map[common.Address][]testERC20Transfer{accountAddress: testCase.incomingERC20Transfers}, - outgoingERC1155SingleTransfers: map[common.Address][]testERC20Transfer{accountAddress: testCase.outgoingERC1155SingleTransfers}, - incomingERC1155SingleTransfers: map[common.Address][]testERC20Transfer{accountAddress: testCase.incomingERC1155SingleTransfers}, - callsCounter: map[string]int{}, - } - // tc.traceAPICalls = true - // tc.printPreparedData = true - tc.prepareBalanceHistory(100) - tc.prepareTokenBalanceHistory(100) - blockChannel := make(chan []*DBHeader, 100) rangeSize := 20 if testCase.rangeSize != 0 { rangeSize = testCase.rangeSize } - client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db) - client.SetClient(tc.NetworkID(), tc) - tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil) - tokenManager.SetTokens([]*token.Token{ - { - Address: tokenTXXAddress, - Symbol: "TXX", - Decimals: 18, - ChainID: tc.NetworkID(), - Name: "Test Token 1", - Verified: true, - }, - { - Address: tokenTXYAddress, - Symbol: "TXY", - Decimals: 18, - ChainID: tc.NetworkID(), - Name: "Test Token 2", - Verified: true, - }, - }) - accDB, err := accounts.NewDB(appdb) - require.NoError(t, err) - blockRangeDAO := &BlockRangeSequentialDAO{wdb.client} - fbc := &findBlocksCommand{ - accounts: []common.Address{accountAddress}, - db: wdb, - blockRangeDAO: blockRangeDAO, - accountsDB: accDB, - chainClient: tc, - balanceCacher: balance.NewCacherWithTTL(5 * time.Minute), - feed: &event.Feed{}, - noLimit: false, - fromBlockNumber: big.NewInt(testCase.fromBlock), - toBlockNumber: big.NewInt(testCase.toBlock), - transactionManager: tm, - blocksLoadedCh: blockChannel, - defaultNodeBlockChunkSize: rangeSize, - tokenManager: tokenManager, - } + + balances := map[common.Address][][]int{accountAddress: testCase.balanceChanges} + outgoingERC20Transfers := map[common.Address][]testERC20Transfer{accountAddress: testCase.outgoingERC20Transfers} + incomingERC20Transfers := map[common.Address][]testERC20Transfer{accountAddress: testCase.incomingERC20Transfers} + outgoingERC1155SingleTransfers := map[common.Address][]testERC20Transfer{accountAddress: testCase.outgoingERC1155SingleTransfers} + incomingERC1155SingleTransfers := map[common.Address][]testERC20Transfer{accountAddress: testCase.incomingERC1155SingleTransfers} + + fbc, tc, blockChannel, blockRangeDAO := setupFindBlocksCommand(t, accountAddress, big.NewInt(testCase.fromBlock), big.NewInt(testCase.toBlock), rangeSize, balances, outgoingERC20Transfers, incomingERC20Transfers, outgoingERC1155SingleTransfers, incomingERC1155SingleTransfers) + ctx := context.Background() + group := async.NewGroup(ctx) group.Add(fbc.Command()) foundBlocks := []*DBHeader{} @@ -1129,6 +1168,62 @@ func TestFindBlocksCommand(t *testing.T) { } } +func TestFindBlocksCommandWithLimiter(t *testing.T) { + // Set up logging + // handler := log.StreamHandler(os.Stdout, log.TerminalFormat(true)) + // log.Root().SetHandler(handler) + + maxRequests := 1 + rangeSize := 20 + accountAddress := common.HexToAddress("0x1234") + balances := map[common.Address][][]int{accountAddress: {{5, 1, 0}, {20, 2, 0}, {45, 1, 1}, {46, 50, 0}, {75, 0, 1}}} + fbc, tc, blockChannel, _ := setupFindBlocksCommand(t, accountAddress, big.NewInt(0), big.NewInt(20), rangeSize, balances, nil, nil, nil, nil) + + limiter := chain.NewRequestLimiter(chain.NewInMemRequestsMapStorage()) + limiter.SetMaxRequests(transferHistoryTag, maxRequests, time.Hour) + tc.SetLimiter(limiter) + + ctx := context.Background() + group := async.NewAtomicGroup(ctx) + group.Add(fbc.Command(1 * time.Millisecond)) + + select { + case <-ctx.Done(): + t.Log("ERROR") + case <-group.WaitAsync(): + close(blockChannel) + require.Error(t, chain.ErrRequestsOverLimit, group.Error()) + require.Equal(t, maxRequests, tc.getCounter()) + } +} + +func TestFindBlocksCommandWithLimiterTagDifferentThanTransfers(t *testing.T) { + rangeSize := 20 + maxRequests := 1 + accountAddress := common.HexToAddress("0x1234") + balances := map[common.Address][][]int{accountAddress: {{5, 1, 0}, {20, 2, 0}, {45, 1, 1}, {46, 50, 0}, {75, 0, 1}}} + outgoingERC20Transfers := map[common.Address][]testERC20Transfer{accountAddress: {{big.NewInt(6), tokenTXXAddress, big.NewInt(1), walletcommon.Erc20TransferEventType}}} + incomingERC20Transfers := map[common.Address][]testERC20Transfer{accountAddress: {{big.NewInt(6), tokenTXXAddress, big.NewInt(1), walletcommon.Erc20TransferEventType}}} + + fbc, tc, blockChannel, _ := setupFindBlocksCommand(t, accountAddress, big.NewInt(0), big.NewInt(20), rangeSize, balances, outgoingERC20Transfers, incomingERC20Transfers, nil, nil) + limiter := chain.NewRequestLimiter(chain.NewInMemRequestsMapStorage()) + limiter.SetMaxRequests("some-other-tag-than-transfer-history", maxRequests, time.Hour) + tc.SetLimiter(limiter) + + ctx := context.Background() + group := async.NewAtomicGroup(ctx) + group.Add(fbc.Command(1 * time.Millisecond)) + + select { + case <-ctx.Done(): + t.Log("ERROR") + case <-group.WaitAsync(): + close(blockChannel) + require.NoError(t, group.Error()) + require.Greater(t, tc.getCounter(), maxRequests) + } +} + type MockETHClient struct { mock.Mock } @@ -1228,7 +1323,7 @@ func TestFetchTransfersForLoadedBlocks(t *testing.T) { tc.prepareBalanceHistory(int(tc.currentBlock)) tc.prepareTokenBalanceHistory(int(tc.currentBlock)) - tc.traceAPICalls = true + // tc.traceAPICalls = true ctx := context.Background() group := async.NewAtomicGroup(ctx) @@ -1278,7 +1373,6 @@ func TestFetchNewBlocksCommand_findBlocksWithEthTransfers(t *testing.T) { db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) require.NoError(t, err) - tm := &TransactionManager{db, nil, nil, nil, nil, nil, nil, nil, nil, nil} mediaServer, err := server.NewMediaServer(appdb, nil, nil, db) require.NoError(t, err) @@ -1334,7 +1428,6 @@ func TestFetchNewBlocksCommand_findBlocksWithEthTransfers(t *testing.T) { balanceCacher: balance.NewCacherWithTTL(5 * time.Minute), feed: &event.Feed{}, noLimit: false, - transactionManager: tm, tokenManager: tokenManager, blocksLoadedCh: blockChannel, defaultNodeBlockChunkSize: DefaultNodeBlockChunkSize, @@ -1378,7 +1471,6 @@ func TestFetchNewBlocksCommand_nonceDetection(t *testing.T) { db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) require.NoError(t, err) - tm := &TransactionManager{db, nil, nil, nil, nil, nil, nil, nil, nil, nil} mediaServer, err := server.NewMediaServer(appdb, nil, nil, db) require.NoError(t, err) @@ -1405,7 +1497,6 @@ func TestFetchNewBlocksCommand_nonceDetection(t *testing.T) { balanceCacher: balance.NewCacherWithTTL(5 * time.Minute), feed: &event.Feed{}, noLimit: false, - transactionManager: tm, tokenManager: tokenManager, blocksLoadedCh: blockChannel, defaultNodeBlockChunkSize: scanRange, @@ -1464,7 +1555,6 @@ func TestFetchNewBlocksCommand(t *testing.T) { db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) require.NoError(t, err) - tm := &TransactionManager{db, nil, nil, nil, nil, nil, nil, nil, nil, nil} mediaServer, err := server.NewMediaServer(appdb, nil, nil, db) require.NoError(t, err) @@ -1534,7 +1624,6 @@ func TestFetchNewBlocksCommand(t *testing.T) { feed: &event.Feed{}, noLimit: false, fromBlockNumber: big.NewInt(int64(tc.currentBlock)), - transactionManager: tm, tokenManager: tokenManager, blocksLoadedCh: blockChannel, defaultNodeBlockChunkSize: DefaultNodeBlockChunkSize,