feat(wallet)_: added and fixed tests for findBlocksCommand with limiter.

This commit is contained in:
Ivan Belyakov 2024-05-21 15:05:04 +02:00 committed by IvanBelyakoff
parent 78f05f60b2
commit 9fe87657d6
5 changed files with 292 additions and 195 deletions

View File

@ -238,7 +238,7 @@ func (c *ClientWithFallback) IsConnected() bool {
func (c *ClientWithFallback) makeCall(ctx context.Context, main func() ([]any, error), fallback func() ([]any, error)) ([]any, error) {
if c.commonLimiter != nil {
if limited, err := c.commonLimiter.IsLimitReached(c.tag); limited {
return nil, fmt.Errorf("rate limit exceeded for %s: %s", c.tag, err)
return nil, fmt.Errorf("tag=%s, %w", c.tag, err)
}
}

View File

@ -5,8 +5,9 @@ import (
"sync"
"time"
"github.com/ethereum/go-ethereum/log"
"github.com/google/uuid"
"github.com/ethereum/go-ethereum/log"
)
const (
@ -27,25 +28,29 @@ type callerOnWait struct {
}
type RequestsStorage interface {
Get(tag string) (RequestData, error)
Set(data RequestData) error
Get(tag string) (*RequestData, error)
Set(data *RequestData) error
}
// InMemRequestsStorage is an in-memory dummy implementation of RequestsStorage
type InMemRequestsStorage struct {
data RequestData
type InMemRequestsMapStorage struct {
data sync.Map
}
func NewInMemRequestsStorage() *InMemRequestsStorage {
return &InMemRequestsStorage{}
func NewInMemRequestsMapStorage() *InMemRequestsMapStorage {
return &InMemRequestsMapStorage{}
}
func (s *InMemRequestsStorage) Get(tag string) (RequestData, error) {
return s.data, nil
func (s *InMemRequestsMapStorage) Get(tag string) (*RequestData, error) {
data, ok := s.data.Load(tag)
if !ok {
return nil, nil
}
return data.(*RequestData), nil
}
func (s *InMemRequestsStorage) Set(data RequestData) error {
s.data = data
func (s *InMemRequestsMapStorage) Set(data *RequestData) error {
s.data.Store(data.Tag, data)
return nil
}
@ -59,7 +64,7 @@ type RequestData struct {
type RequestLimiter interface {
SetMaxRequests(tag string, maxRequests int, interval time.Duration) error
GetMaxRequests(tag string) (RequestData, error)
GetMaxRequests(tag string) (*RequestData, error)
IsLimitReached(tag string) (bool, error)
}
@ -83,18 +88,17 @@ func (rl *RPCRequestLimiter) SetMaxRequests(tag string, maxRequests int, interva
return nil
}
func (rl *RPCRequestLimiter) GetMaxRequests(tag string) (RequestData, error) {
func (rl *RPCRequestLimiter) GetMaxRequests(tag string) (*RequestData, error) {
data, err := rl.storage.Get(tag)
if err != nil {
log.Error("Failed to get request data from storage", "error", err, "tag", tag)
return RequestData{}, err
return nil, err
}
return data, nil
}
func (rl *RPCRequestLimiter) saveToStorage(tag string, maxRequests int, interval time.Duration, numReqs int, timestamp time.Time) error {
data := RequestData{
data := &RequestData{
Tag: tag,
CreatedAt: timestamp,
Period: interval,
@ -117,6 +121,10 @@ func (rl *RPCRequestLimiter) IsLimitReached(tag string) (bool, error) {
return false, err
}
if data == nil {
return false, nil
}
// Check if a number of requests is over the limit within the interval
if time.Since(data.CreatedAt) < data.Period {
if data.NumReqs >= data.MaxReqs {

View File

@ -7,8 +7,8 @@ import (
"github.com/stretchr/testify/require"
)
func setupTest() (*InMemRequestsStorage, RequestLimiter) {
storage := NewInMemRequestsStorage()
func setupTest() (*InMemRequestsMapStorage, RequestLimiter) {
storage := NewInMemRequestsMapStorage()
rl := NewRequestLimiter(storage)
return storage, rl
}
@ -37,7 +37,7 @@ func TestSetMaxRequests(t *testing.T) {
func TestGetMaxRequests(t *testing.T) {
storage, rl := setupTest()
data := RequestData{
data := &RequestData{
Tag: "testTag",
Period: time.Second,
MaxReqs: 10,
@ -63,7 +63,7 @@ func TestIsLimitReachedWithinPeriod(t *testing.T) {
interval := time.Second
// Set up the storage with test data
data := RequestData{
data := &RequestData{
Tag: tag,
Period: interval,
CreatedAt: time.Now(),
@ -95,7 +95,7 @@ func TestIsLimitReachedWhenPeriodPassed(t *testing.T) {
interval := time.Second
// Set up the storage with test data
data := RequestData{
data := &RequestData{
Tag: tag,
Period: interval,
CreatedAt: time.Now().Add(-interval),

View File

@ -29,8 +29,8 @@ const (
transferHistoryTag = "transfer_history"
newTransferHistoryTag = "new_transfer_history"
transferHistoryMaxRequests = 100
transferHistoryMaxRequestsPeriod = 10 * time.Second
transferHistoryMaxRequests = 10000
transferHistoryMaxRequestsPeriod = 24 * time.Hour
)
type nonceInfo struct {
@ -1122,7 +1122,7 @@ func (c *loadBlocksAndTransfersCommand) fetchHistoryBlocksForAccount(group *asyn
log.Debug("range item", "r", rangeItem, "n", c.chainClient.NetworkID(), "a", account)
chainClient := chain.ClientWithTag(c.chainClient, transferHistoryTag)
limiter := chain.NewRequestLimiter(chain.NewInMemRequestsStorage())
limiter := chain.NewRequestLimiter(chain.NewInMemRequestsMapStorage())
limiter.SetMaxRequests(transferHistoryTag, transferHistoryMaxRequests, transferHistoryMaxRequestsPeriod)
chainClient.SetLimiter(limiter)

View File

@ -64,6 +64,24 @@ type TestClient struct {
rw sync.RWMutex
callsCounter map[string]int
currentBlock uint64
limiter chain.RequestLimiter
}
var countAndlog = func(tc *TestClient, method string, params ...interface{}) error {
tc.incCounter(method)
if tc.traceAPICalls {
if len(params) > 0 {
tc.t.Log(method, params)
} else {
tc.t.Log(method)
}
}
return nil
}
func (tc *TestClient) countAndlog(method string, params ...interface{}) error {
return countAndlog(tc, method, params...)
}
func (tc *TestClient) incCounter(method string) {
@ -109,42 +127,42 @@ func (tc *TestClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) e
}
func (tc *TestClient) HeaderByHash(ctx context.Context, hash common.Hash) (*types.Header, error) {
tc.incCounter("HeaderByHash")
if tc.traceAPICalls {
tc.t.Log("HeaderByHash")
err := tc.countAndlog("HeaderByHash")
if err != nil {
return nil, err
}
return nil, nil
}
func (tc *TestClient) BlockByHash(ctx context.Context, hash common.Hash) (*types.Block, error) {
tc.incCounter("BlockByHash")
if tc.traceAPICalls {
tc.t.Log("BlockByHash")
err := tc.countAndlog("BlockByHash")
if err != nil {
return nil, err
}
return nil, nil
}
func (tc *TestClient) BlockByNumber(ctx context.Context, number *big.Int) (*types.Block, error) {
tc.incCounter("BlockByNumber")
if tc.traceAPICalls {
tc.t.Log("BlockByNumber")
err := tc.countAndlog("BlockByNumber")
if err != nil {
return nil, err
}
return nil, nil
}
func (tc *TestClient) NonceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (uint64, error) {
tc.incCounter("NonceAt")
nonce := tc.nonceHistory[account][blockNumber.Uint64()]
if tc.traceAPICalls {
tc.t.Log("NonceAt", blockNumber, "result:", nonce)
err := tc.countAndlog("NonceAt", fmt.Sprintf("result: %d", nonce))
if err != nil {
return nonce, err
}
return nonce, nil
}
func (tc *TestClient) FilterLogs(ctx context.Context, q ethereum.FilterQuery) ([]types.Log, error) {
tc.incCounter("FilterLogs")
if tc.traceAPICalls {
tc.t.Log("FilterLogs")
err := tc.countAndlog("FilterLogs")
if err != nil {
return nil, err
}
// We do not verify addresses for now
@ -242,12 +260,12 @@ func (tc *TestClient) getBalance(address common.Address, blockNumber *big.Int) *
}
func (tc *TestClient) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) {
tc.incCounter("BalanceAt")
balance := tc.getBalance(account, blockNumber)
if tc.traceAPICalls {
tc.t.Log("BalanceAt", blockNumber, "account:", account, "result:", balance)
err := tc.countAndlog("BalanceAt", fmt.Sprintf("account: %s, result: %d", account, balance))
if err != nil {
return nil, err
}
return balance, nil
}
@ -264,14 +282,15 @@ func (tc *TestClient) tokenBalanceAt(account common.Address, token common.Addres
}
func (tc *TestClient) HeaderByNumber(ctx context.Context, number *big.Int) (*types.Header, error) {
tc.incCounter("HeaderByNumber")
if number == nil {
number = big.NewInt(int64(tc.currentBlock))
}
if tc.traceAPICalls {
tc.t.Log("HeaderByNumber", number)
err := tc.countAndlog("HeaderByNumber", fmt.Sprintf("number: %d", number))
if err != nil {
return nil, err
}
header := &types.Header{
Number: number,
Time: 0,
@ -281,17 +300,17 @@ func (tc *TestClient) HeaderByNumber(ctx context.Context, number *big.Int) (*typ
}
func (tc *TestClient) CallBlockHashByTransaction(ctx context.Context, blockNumber *big.Int, index uint) (common.Hash, error) {
tc.incCounter("FullTransactionByBlockNumberAndIndex")
if tc.traceAPICalls {
tc.t.Log("FullTransactionByBlockNumberAndIndex")
err := tc.countAndlog("CallBlockHashByTransaction")
if err != nil {
return common.Hash{}, err
}
return common.BigToHash(blockNumber), nil
}
func (tc *TestClient) GetBaseFeeFromBlock(ctx context.Context, blockNumber *big.Int) (string, error) {
tc.incCounter("GetBaseFeeFromBlock")
if tc.traceAPICalls {
tc.t.Log("GetBaseFeeFromBlock")
err := tc.countAndlog("GetBaseFeeFromBlock")
if err != nil {
return "", err
}
return "", nil
}
@ -311,10 +330,7 @@ var ethscanAddress = common.HexToAddress("0x000000000000000000000000000000000077
var balanceCheckAddress = common.HexToAddress("0x0000000000000000000000000000000010777333")
func (tc *TestClient) CodeAt(ctx context.Context, contract common.Address, blockNumber *big.Int) ([]byte, error) {
tc.incCounter("CodeAt")
if tc.traceAPICalls {
tc.t.Log("CodeAt", contract, blockNumber)
}
tc.countAndlog("CodeAt", fmt.Sprintf("contract: %s, blockNumber: %d", contract, blockNumber))
if ethscanAddress == contract || balanceCheckAddress == contract {
return []byte{1}, nil
@ -324,9 +340,9 @@ func (tc *TestClient) CodeAt(ctx context.Context, contract common.Address, block
}
func (tc *TestClient) CallContract(ctx context.Context, call ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) {
tc.incCounter("CallContract")
if tc.traceAPICalls {
tc.t.Log("CallContract", call, blockNumber, call.To)
err := tc.countAndlog("CallContract", fmt.Sprintf("call: %v, blockNumber: %d, to: %s", call, blockNumber, call.To))
if err != nil {
return nil, err
}
if *call.To == ethscanAddress {
@ -557,9 +573,9 @@ func (tc *TestClient) prepareTokenBalanceHistory(toBlock int) {
}
func (tc *TestClient) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error {
tc.incCounter("CallContext")
if tc.traceAPICalls {
tc.t.Log("CallContext")
err := tc.countAndlog("CallContext")
if err != nil {
return err
}
return nil
}
@ -578,91 +594,79 @@ func (tc *TestClient) SetWalletNotifier(notifier func(chainId uint64, message st
}
func (tc *TestClient) EstimateGas(ctx context.Context, call ethereum.CallMsg) (gas uint64, err error) {
tc.incCounter("EstimateGas")
if tc.traceAPICalls {
tc.t.Log("EstimateGas")
err = tc.countAndlog("EstimateGas")
if err != nil {
return 0, err
}
return 0, nil
}
func (tc *TestClient) PendingCodeAt(ctx context.Context, account common.Address) ([]byte, error) {
tc.incCounter("PendingCodeAt")
if tc.traceAPICalls {
tc.t.Log("PendingCodeAt")
err := tc.countAndlog("PendingCodeAt")
if err != nil {
return nil, err
}
return nil, nil
}
func (tc *TestClient) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) {
tc.incCounter("PendingCallContract")
if tc.traceAPICalls {
tc.t.Log("PendingCallContract")
err := tc.countAndlog("PendingCallContract")
if err != nil {
return nil, err
}
return nil, nil
}
func (tc *TestClient) PendingNonceAt(ctx context.Context, account common.Address) (uint64, error) {
tc.incCounter("PendingNonceAt")
if tc.traceAPICalls {
tc.t.Log("PendingNonceAt")
err := tc.countAndlog("PendingNonceAt")
if err != nil {
return 0, err
}
return 0, nil
}
func (tc *TestClient) SuggestGasPrice(ctx context.Context) (*big.Int, error) {
tc.incCounter("SuggestGasPrice")
if tc.traceAPICalls {
tc.t.Log("SuggestGasPrice")
err := tc.countAndlog("SuggestGasPrice")
if err != nil {
return nil, err
}
return nil, nil
}
func (tc *TestClient) SendTransaction(ctx context.Context, tx *types.Transaction) error {
tc.incCounter("SendTransaction")
if tc.traceAPICalls {
tc.t.Log("SendTransaction")
err := tc.countAndlog("SendTransaction")
if err != nil {
return err
}
return nil
}
func (tc *TestClient) SuggestGasTipCap(ctx context.Context) (*big.Int, error) {
tc.incCounter("SuggestGasTipCap")
if tc.traceAPICalls {
tc.t.Log("SuggestGasTipCap")
err := tc.countAndlog("SuggestGasTipCap")
if err != nil {
return nil, err
}
return nil, nil
}
func (tc *TestClient) BatchCallContextIgnoringLocalHandlers(ctx context.Context, b []rpc.BatchElem) error {
tc.incCounter("BatchCallContextIgnoringLocalHandlers")
if tc.traceAPICalls {
tc.t.Log("BatchCallContextIgnoringLocalHandlers")
err := tc.countAndlog("BatchCallContextIgnoringLocalHandlers")
if err != nil {
return err
}
return nil
}
func (tc *TestClient) CallContextIgnoringLocalHandlers(ctx context.Context, result interface{}, method string, args ...interface{}) error {
tc.incCounter("CallContextIgnoringLocalHandlers")
if tc.traceAPICalls {
tc.t.Log("CallContextIgnoringLocalHandlers")
err := tc.countAndlog("CallContextIgnoringLocalHandlers")
if err != nil {
return err
}
return nil
}
func (tc *TestClient) CallRaw(data string) string {
tc.incCounter("CallRaw")
if tc.traceAPICalls {
tc.t.Log("CallRaw")
}
_ = tc.countAndlog("CallRaw")
return ""
}
@ -671,38 +675,34 @@ func (tc *TestClient) GetChainID() *big.Int {
}
func (tc *TestClient) SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuery, ch chan<- types.Log) (ethereum.Subscription, error) {
tc.incCounter("SubscribeFilterLogs")
if tc.traceAPICalls {
tc.t.Log("SubscribeFilterLogs")
err := tc.countAndlog("SubscribeFilterLogs")
if err != nil {
return nil, err
}
return nil, nil
}
func (tc *TestClient) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) {
tc.incCounter("TransactionReceipt")
if tc.traceAPICalls {
tc.t.Log("TransactionReceipt")
err := tc.countAndlog("TransactionReceipt")
if err != nil {
return nil, err
}
return nil, nil
}
func (tc *TestClient) TransactionByHash(ctx context.Context, txHash common.Hash) (*types.Transaction, bool, error) {
tc.incCounter("TransactionByHash")
if tc.traceAPICalls {
tc.t.Log("TransactionByHash")
err := tc.countAndlog("TransactionByHash")
if err != nil {
return nil, false, err
}
return nil, false, nil
}
func (tc *TestClient) BlockNumber(ctx context.Context) (uint64, error) {
tc.incCounter("BlockNumber")
if tc.traceAPICalls {
tc.t.Log("BlockNumber")
err := tc.countAndlog("BlockNumber")
if err != nil {
return 0, err
}
return 0, nil
}
func (tc *TestClient) SetIsConnected(value bool) {
@ -711,7 +711,7 @@ func (tc *TestClient) SetIsConnected(value bool) {
}
}
func (tc *TestClient) GetIsConnected() bool {
func (tc *TestClient) IsConnected() bool {
if tc.traceAPICalls {
tc.t.Log("GetIsConnected")
}
@ -719,6 +719,14 @@ func (tc *TestClient) GetIsConnected() bool {
return true
}
func (tc *TestClient) GetLimiter() chain.RequestLimiter {
return tc.limiter
}
func (tc *TestClient) SetLimiter(limiter chain.RequestLimiter) {
tc.limiter = limiter
}
type testERC20Transfer struct {
block *big.Int
address common.Address
@ -1010,82 +1018,113 @@ func getCases() []findBlockCase {
var tokenTXXAddress = common.HexToAddress("0x53211")
var tokenTXYAddress = common.HexToAddress("0x73211")
func setupFindBlocksCommand(t *testing.T, accountAddress common.Address, fromBlock, toBlock *big.Int, rangeSize int, balances map[common.Address][][]int, outgoingERC20Transfers, incomingERC20Transfers, outgoingERC1155SingleTransfers, incomingERC1155SingleTransfers map[common.Address][]testERC20Transfer) (*findBlocksCommand, *TestClient, chan []*DBHeader, *BlockRangeSequentialDAO) {
appdb, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
require.NoError(t, err)
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
mediaServer, err := server.NewMediaServer(appdb, nil, nil, db)
require.NoError(t, err)
wdb := NewDB(db)
tc := &TestClient{
t: t,
balances: balances,
outgoingERC20Transfers: outgoingERC20Transfers,
incomingERC20Transfers: incomingERC20Transfers,
outgoingERC1155SingleTransfers: outgoingERC1155SingleTransfers,
incomingERC1155SingleTransfers: incomingERC1155SingleTransfers,
callsCounter: map[string]int{},
}
// tc.traceAPICalls = true
// tc.printPreparedData = true
tc.prepareBalanceHistory(100)
tc.prepareTokenBalanceHistory(100)
blockChannel := make(chan []*DBHeader, 100)
// Reimplement the common function that is called from every method to check for the limit
countAndlog = func(tc *TestClient, method string, params ...interface{}) error {
if tc.GetLimiter() != nil {
if limited, _ := tc.GetLimiter().IsLimitReached(transferHistoryTag); limited {
t.Log("ERROR: requests over limit")
return chain.ErrRequestsOverLimit
}
}
tc.incCounter(method)
if tc.traceAPICalls {
if len(params) > 0 {
tc.t.Log(method, params)
} else {
tc.t.Log(method)
}
}
return nil
}
client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db)
client.SetClient(tc.NetworkID(), tc)
tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil)
tokenManager.SetTokens([]*token.Token{
{
Address: tokenTXXAddress,
Symbol: "TXX",
Decimals: 18,
ChainID: tc.NetworkID(),
Name: "Test Token 1",
Verified: true,
},
{
Address: tokenTXYAddress,
Symbol: "TXY",
Decimals: 18,
ChainID: tc.NetworkID(),
Name: "Test Token 2",
Verified: true,
},
})
accDB, err := accounts.NewDB(appdb)
require.NoError(t, err)
blockRangeDAO := &BlockRangeSequentialDAO{wdb.client}
fbc := &findBlocksCommand{
accounts: []common.Address{accountAddress},
db: wdb,
blockRangeDAO: blockRangeDAO,
accountsDB: accDB,
chainClient: tc,
balanceCacher: balance.NewCacherWithTTL(5 * time.Minute),
feed: &event.Feed{},
noLimit: false,
fromBlockNumber: fromBlock,
toBlockNumber: toBlock,
blocksLoadedCh: blockChannel,
defaultNodeBlockChunkSize: rangeSize,
tokenManager: tokenManager,
}
return fbc, tc, blockChannel, blockRangeDAO
}
func TestFindBlocksCommand(t *testing.T) {
for idx, testCase := range getCases() {
t.Log("case #", idx+1)
ctx := context.Background()
group := async.NewGroup(ctx)
appdb, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
require.NoError(t, err)
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
tm := &TransactionManager{db, nil, nil, nil, nil, nil, nil, nil, nil, nil}
mediaServer, err := server.NewMediaServer(appdb, nil, nil, db)
require.NoError(t, err)
accountAddress := common.HexToAddress("0x1234")
wdb := NewDB(db)
tc := &TestClient{
t: t,
balances: map[common.Address][][]int{accountAddress: testCase.balanceChanges},
outgoingERC20Transfers: map[common.Address][]testERC20Transfer{accountAddress: testCase.outgoingERC20Transfers},
incomingERC20Transfers: map[common.Address][]testERC20Transfer{accountAddress: testCase.incomingERC20Transfers},
outgoingERC1155SingleTransfers: map[common.Address][]testERC20Transfer{accountAddress: testCase.outgoingERC1155SingleTransfers},
incomingERC1155SingleTransfers: map[common.Address][]testERC20Transfer{accountAddress: testCase.incomingERC1155SingleTransfers},
callsCounter: map[string]int{},
}
// tc.traceAPICalls = true
// tc.printPreparedData = true
tc.prepareBalanceHistory(100)
tc.prepareTokenBalanceHistory(100)
blockChannel := make(chan []*DBHeader, 100)
rangeSize := 20
if testCase.rangeSize != 0 {
rangeSize = testCase.rangeSize
}
client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db)
client.SetClient(tc.NetworkID(), tc)
tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil)
tokenManager.SetTokens([]*token.Token{
{
Address: tokenTXXAddress,
Symbol: "TXX",
Decimals: 18,
ChainID: tc.NetworkID(),
Name: "Test Token 1",
Verified: true,
},
{
Address: tokenTXYAddress,
Symbol: "TXY",
Decimals: 18,
ChainID: tc.NetworkID(),
Name: "Test Token 2",
Verified: true,
},
})
accDB, err := accounts.NewDB(appdb)
require.NoError(t, err)
blockRangeDAO := &BlockRangeSequentialDAO{wdb.client}
fbc := &findBlocksCommand{
accounts: []common.Address{accountAddress},
db: wdb,
blockRangeDAO: blockRangeDAO,
accountsDB: accDB,
chainClient: tc,
balanceCacher: balance.NewCacherWithTTL(5 * time.Minute),
feed: &event.Feed{},
noLimit: false,
fromBlockNumber: big.NewInt(testCase.fromBlock),
toBlockNumber: big.NewInt(testCase.toBlock),
transactionManager: tm,
blocksLoadedCh: blockChannel,
defaultNodeBlockChunkSize: rangeSize,
tokenManager: tokenManager,
}
balances := map[common.Address][][]int{accountAddress: testCase.balanceChanges}
outgoingERC20Transfers := map[common.Address][]testERC20Transfer{accountAddress: testCase.outgoingERC20Transfers}
incomingERC20Transfers := map[common.Address][]testERC20Transfer{accountAddress: testCase.incomingERC20Transfers}
outgoingERC1155SingleTransfers := map[common.Address][]testERC20Transfer{accountAddress: testCase.outgoingERC1155SingleTransfers}
incomingERC1155SingleTransfers := map[common.Address][]testERC20Transfer{accountAddress: testCase.incomingERC1155SingleTransfers}
fbc, tc, blockChannel, blockRangeDAO := setupFindBlocksCommand(t, accountAddress, big.NewInt(testCase.fromBlock), big.NewInt(testCase.toBlock), rangeSize, balances, outgoingERC20Transfers, incomingERC20Transfers, outgoingERC1155SingleTransfers, incomingERC1155SingleTransfers)
ctx := context.Background()
group := async.NewGroup(ctx)
group.Add(fbc.Command())
foundBlocks := []*DBHeader{}
@ -1129,6 +1168,62 @@ func TestFindBlocksCommand(t *testing.T) {
}
}
func TestFindBlocksCommandWithLimiter(t *testing.T) {
// Set up logging
// handler := log.StreamHandler(os.Stdout, log.TerminalFormat(true))
// log.Root().SetHandler(handler)
maxRequests := 1
rangeSize := 20
accountAddress := common.HexToAddress("0x1234")
balances := map[common.Address][][]int{accountAddress: {{5, 1, 0}, {20, 2, 0}, {45, 1, 1}, {46, 50, 0}, {75, 0, 1}}}
fbc, tc, blockChannel, _ := setupFindBlocksCommand(t, accountAddress, big.NewInt(0), big.NewInt(20), rangeSize, balances, nil, nil, nil, nil)
limiter := chain.NewRequestLimiter(chain.NewInMemRequestsMapStorage())
limiter.SetMaxRequests(transferHistoryTag, maxRequests, time.Hour)
tc.SetLimiter(limiter)
ctx := context.Background()
group := async.NewAtomicGroup(ctx)
group.Add(fbc.Command(1 * time.Millisecond))
select {
case <-ctx.Done():
t.Log("ERROR")
case <-group.WaitAsync():
close(blockChannel)
require.Error(t, chain.ErrRequestsOverLimit, group.Error())
require.Equal(t, maxRequests, tc.getCounter())
}
}
func TestFindBlocksCommandWithLimiterTagDifferentThanTransfers(t *testing.T) {
rangeSize := 20
maxRequests := 1
accountAddress := common.HexToAddress("0x1234")
balances := map[common.Address][][]int{accountAddress: {{5, 1, 0}, {20, 2, 0}, {45, 1, 1}, {46, 50, 0}, {75, 0, 1}}}
outgoingERC20Transfers := map[common.Address][]testERC20Transfer{accountAddress: {{big.NewInt(6), tokenTXXAddress, big.NewInt(1), walletcommon.Erc20TransferEventType}}}
incomingERC20Transfers := map[common.Address][]testERC20Transfer{accountAddress: {{big.NewInt(6), tokenTXXAddress, big.NewInt(1), walletcommon.Erc20TransferEventType}}}
fbc, tc, blockChannel, _ := setupFindBlocksCommand(t, accountAddress, big.NewInt(0), big.NewInt(20), rangeSize, balances, outgoingERC20Transfers, incomingERC20Transfers, nil, nil)
limiter := chain.NewRequestLimiter(chain.NewInMemRequestsMapStorage())
limiter.SetMaxRequests("some-other-tag-than-transfer-history", maxRequests, time.Hour)
tc.SetLimiter(limiter)
ctx := context.Background()
group := async.NewAtomicGroup(ctx)
group.Add(fbc.Command(1 * time.Millisecond))
select {
case <-ctx.Done():
t.Log("ERROR")
case <-group.WaitAsync():
close(blockChannel)
require.NoError(t, group.Error())
require.Greater(t, tc.getCounter(), maxRequests)
}
}
type MockETHClient struct {
mock.Mock
}
@ -1228,7 +1323,7 @@ func TestFetchTransfersForLoadedBlocks(t *testing.T) {
tc.prepareBalanceHistory(int(tc.currentBlock))
tc.prepareTokenBalanceHistory(int(tc.currentBlock))
tc.traceAPICalls = true
// tc.traceAPICalls = true
ctx := context.Background()
group := async.NewAtomicGroup(ctx)
@ -1278,7 +1373,6 @@ func TestFetchNewBlocksCommand_findBlocksWithEthTransfers(t *testing.T) {
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
tm := &TransactionManager{db, nil, nil, nil, nil, nil, nil, nil, nil, nil}
mediaServer, err := server.NewMediaServer(appdb, nil, nil, db)
require.NoError(t, err)
@ -1334,7 +1428,6 @@ func TestFetchNewBlocksCommand_findBlocksWithEthTransfers(t *testing.T) {
balanceCacher: balance.NewCacherWithTTL(5 * time.Minute),
feed: &event.Feed{},
noLimit: false,
transactionManager: tm,
tokenManager: tokenManager,
blocksLoadedCh: blockChannel,
defaultNodeBlockChunkSize: DefaultNodeBlockChunkSize,
@ -1378,7 +1471,6 @@ func TestFetchNewBlocksCommand_nonceDetection(t *testing.T) {
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
tm := &TransactionManager{db, nil, nil, nil, nil, nil, nil, nil, nil, nil}
mediaServer, err := server.NewMediaServer(appdb, nil, nil, db)
require.NoError(t, err)
@ -1405,7 +1497,6 @@ func TestFetchNewBlocksCommand_nonceDetection(t *testing.T) {
balanceCacher: balance.NewCacherWithTTL(5 * time.Minute),
feed: &event.Feed{},
noLimit: false,
transactionManager: tm,
tokenManager: tokenManager,
blocksLoadedCh: blockChannel,
defaultNodeBlockChunkSize: scanRange,
@ -1464,7 +1555,6 @@ func TestFetchNewBlocksCommand(t *testing.T) {
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
tm := &TransactionManager{db, nil, nil, nil, nil, nil, nil, nil, nil, nil}
mediaServer, err := server.NewMediaServer(appdb, nil, nil, db)
require.NoError(t, err)
@ -1534,7 +1624,6 @@ func TestFetchNewBlocksCommand(t *testing.T) {
feed: &event.Feed{},
noLimit: false,
fromBlockNumber: big.NewInt(int64(tc.currentBlock)),
transactionManager: tm,
tokenManager: tokenManager,
blocksLoadedCh: blockChannel,
defaultNodeBlockChunkSize: DefaultNodeBlockChunkSize,