feat(wallet)_: added and fixed tests for findBlocksCommand with limiter.

This commit is contained in:
Ivan Belyakov 2024-05-21 15:05:04 +02:00 committed by IvanBelyakoff
parent 78f05f60b2
commit 9fe87657d6
5 changed files with 292 additions and 195 deletions

View File

@ -238,7 +238,7 @@ func (c *ClientWithFallback) IsConnected() bool {
func (c *ClientWithFallback) makeCall(ctx context.Context, main func() ([]any, error), fallback func() ([]any, error)) ([]any, error) { func (c *ClientWithFallback) makeCall(ctx context.Context, main func() ([]any, error), fallback func() ([]any, error)) ([]any, error) {
if c.commonLimiter != nil { if c.commonLimiter != nil {
if limited, err := c.commonLimiter.IsLimitReached(c.tag); limited { if limited, err := c.commonLimiter.IsLimitReached(c.tag); limited {
return nil, fmt.Errorf("rate limit exceeded for %s: %s", c.tag, err) return nil, fmt.Errorf("tag=%s, %w", c.tag, err)
} }
} }

View File

@ -5,8 +5,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/ethereum/go-ethereum/log"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/ethereum/go-ethereum/log"
) )
const ( const (
@ -27,25 +28,29 @@ type callerOnWait struct {
} }
type RequestsStorage interface { type RequestsStorage interface {
Get(tag string) (RequestData, error) Get(tag string) (*RequestData, error)
Set(data RequestData) error Set(data *RequestData) error
} }
// InMemRequestsStorage is an in-memory dummy implementation of RequestsStorage type InMemRequestsMapStorage struct {
type InMemRequestsStorage struct { data sync.Map
data RequestData
} }
func NewInMemRequestsStorage() *InMemRequestsStorage { func NewInMemRequestsMapStorage() *InMemRequestsMapStorage {
return &InMemRequestsStorage{} return &InMemRequestsMapStorage{}
} }
func (s *InMemRequestsStorage) Get(tag string) (RequestData, error) { func (s *InMemRequestsMapStorage) Get(tag string) (*RequestData, error) {
return s.data, nil data, ok := s.data.Load(tag)
if !ok {
return nil, nil
}
return data.(*RequestData), nil
} }
func (s *InMemRequestsStorage) Set(data RequestData) error { func (s *InMemRequestsMapStorage) Set(data *RequestData) error {
s.data = data s.data.Store(data.Tag, data)
return nil return nil
} }
@ -59,7 +64,7 @@ type RequestData struct {
type RequestLimiter interface { type RequestLimiter interface {
SetMaxRequests(tag string, maxRequests int, interval time.Duration) error SetMaxRequests(tag string, maxRequests int, interval time.Duration) error
GetMaxRequests(tag string) (RequestData, error) GetMaxRequests(tag string) (*RequestData, error)
IsLimitReached(tag string) (bool, error) IsLimitReached(tag string) (bool, error)
} }
@ -83,18 +88,17 @@ func (rl *RPCRequestLimiter) SetMaxRequests(tag string, maxRequests int, interva
return nil return nil
} }
func (rl *RPCRequestLimiter) GetMaxRequests(tag string) (RequestData, error) { func (rl *RPCRequestLimiter) GetMaxRequests(tag string) (*RequestData, error) {
data, err := rl.storage.Get(tag) data, err := rl.storage.Get(tag)
if err != nil { if err != nil {
log.Error("Failed to get request data from storage", "error", err, "tag", tag) return nil, err
return RequestData{}, err
} }
return data, nil return data, nil
} }
func (rl *RPCRequestLimiter) saveToStorage(tag string, maxRequests int, interval time.Duration, numReqs int, timestamp time.Time) error { func (rl *RPCRequestLimiter) saveToStorage(tag string, maxRequests int, interval time.Duration, numReqs int, timestamp time.Time) error {
data := RequestData{ data := &RequestData{
Tag: tag, Tag: tag,
CreatedAt: timestamp, CreatedAt: timestamp,
Period: interval, Period: interval,
@ -117,6 +121,10 @@ func (rl *RPCRequestLimiter) IsLimitReached(tag string) (bool, error) {
return false, err return false, err
} }
if data == nil {
return false, nil
}
// Check if a number of requests is over the limit within the interval // Check if a number of requests is over the limit within the interval
if time.Since(data.CreatedAt) < data.Period { if time.Since(data.CreatedAt) < data.Period {
if data.NumReqs >= data.MaxReqs { if data.NumReqs >= data.MaxReqs {

View File

@ -7,8 +7,8 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func setupTest() (*InMemRequestsStorage, RequestLimiter) { func setupTest() (*InMemRequestsMapStorage, RequestLimiter) {
storage := NewInMemRequestsStorage() storage := NewInMemRequestsMapStorage()
rl := NewRequestLimiter(storage) rl := NewRequestLimiter(storage)
return storage, rl return storage, rl
} }
@ -37,7 +37,7 @@ func TestSetMaxRequests(t *testing.T) {
func TestGetMaxRequests(t *testing.T) { func TestGetMaxRequests(t *testing.T) {
storage, rl := setupTest() storage, rl := setupTest()
data := RequestData{ data := &RequestData{
Tag: "testTag", Tag: "testTag",
Period: time.Second, Period: time.Second,
MaxReqs: 10, MaxReqs: 10,
@ -63,7 +63,7 @@ func TestIsLimitReachedWithinPeriod(t *testing.T) {
interval := time.Second interval := time.Second
// Set up the storage with test data // Set up the storage with test data
data := RequestData{ data := &RequestData{
Tag: tag, Tag: tag,
Period: interval, Period: interval,
CreatedAt: time.Now(), CreatedAt: time.Now(),
@ -95,7 +95,7 @@ func TestIsLimitReachedWhenPeriodPassed(t *testing.T) {
interval := time.Second interval := time.Second
// Set up the storage with test data // Set up the storage with test data
data := RequestData{ data := &RequestData{
Tag: tag, Tag: tag,
Period: interval, Period: interval,
CreatedAt: time.Now().Add(-interval), CreatedAt: time.Now().Add(-interval),

View File

@ -29,8 +29,8 @@ const (
transferHistoryTag = "transfer_history" transferHistoryTag = "transfer_history"
newTransferHistoryTag = "new_transfer_history" newTransferHistoryTag = "new_transfer_history"
transferHistoryMaxRequests = 100 transferHistoryMaxRequests = 10000
transferHistoryMaxRequestsPeriod = 10 * time.Second transferHistoryMaxRequestsPeriod = 24 * time.Hour
) )
type nonceInfo struct { type nonceInfo struct {
@ -1122,7 +1122,7 @@ func (c *loadBlocksAndTransfersCommand) fetchHistoryBlocksForAccount(group *asyn
log.Debug("range item", "r", rangeItem, "n", c.chainClient.NetworkID(), "a", account) log.Debug("range item", "r", rangeItem, "n", c.chainClient.NetworkID(), "a", account)
chainClient := chain.ClientWithTag(c.chainClient, transferHistoryTag) chainClient := chain.ClientWithTag(c.chainClient, transferHistoryTag)
limiter := chain.NewRequestLimiter(chain.NewInMemRequestsStorage()) limiter := chain.NewRequestLimiter(chain.NewInMemRequestsMapStorage())
limiter.SetMaxRequests(transferHistoryTag, transferHistoryMaxRequests, transferHistoryMaxRequestsPeriod) limiter.SetMaxRequests(transferHistoryTag, transferHistoryMaxRequests, transferHistoryMaxRequestsPeriod)
chainClient.SetLimiter(limiter) chainClient.SetLimiter(limiter)

View File

@ -64,6 +64,24 @@ type TestClient struct {
rw sync.RWMutex rw sync.RWMutex
callsCounter map[string]int callsCounter map[string]int
currentBlock uint64 currentBlock uint64
limiter chain.RequestLimiter
}
var countAndlog = func(tc *TestClient, method string, params ...interface{}) error {
tc.incCounter(method)
if tc.traceAPICalls {
if len(params) > 0 {
tc.t.Log(method, params)
} else {
tc.t.Log(method)
}
}
return nil
}
func (tc *TestClient) countAndlog(method string, params ...interface{}) error {
return countAndlog(tc, method, params...)
} }
func (tc *TestClient) incCounter(method string) { func (tc *TestClient) incCounter(method string) {
@ -109,42 +127,42 @@ func (tc *TestClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) e
} }
func (tc *TestClient) HeaderByHash(ctx context.Context, hash common.Hash) (*types.Header, error) { func (tc *TestClient) HeaderByHash(ctx context.Context, hash common.Hash) (*types.Header, error) {
tc.incCounter("HeaderByHash") err := tc.countAndlog("HeaderByHash")
if tc.traceAPICalls { if err != nil {
tc.t.Log("HeaderByHash") return nil, err
} }
return nil, nil return nil, nil
} }
func (tc *TestClient) BlockByHash(ctx context.Context, hash common.Hash) (*types.Block, error) { func (tc *TestClient) BlockByHash(ctx context.Context, hash common.Hash) (*types.Block, error) {
tc.incCounter("BlockByHash") err := tc.countAndlog("BlockByHash")
if tc.traceAPICalls { if err != nil {
tc.t.Log("BlockByHash") return nil, err
} }
return nil, nil return nil, nil
} }
func (tc *TestClient) BlockByNumber(ctx context.Context, number *big.Int) (*types.Block, error) { func (tc *TestClient) BlockByNumber(ctx context.Context, number *big.Int) (*types.Block, error) {
tc.incCounter("BlockByNumber") err := tc.countAndlog("BlockByNumber")
if tc.traceAPICalls { if err != nil {
tc.t.Log("BlockByNumber") return nil, err
} }
return nil, nil return nil, nil
} }
func (tc *TestClient) NonceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (uint64, error) { func (tc *TestClient) NonceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (uint64, error) {
tc.incCounter("NonceAt")
nonce := tc.nonceHistory[account][blockNumber.Uint64()] nonce := tc.nonceHistory[account][blockNumber.Uint64()]
if tc.traceAPICalls { err := tc.countAndlog("NonceAt", fmt.Sprintf("result: %d", nonce))
tc.t.Log("NonceAt", blockNumber, "result:", nonce) if err != nil {
return nonce, err
} }
return nonce, nil return nonce, nil
} }
func (tc *TestClient) FilterLogs(ctx context.Context, q ethereum.FilterQuery) ([]types.Log, error) { func (tc *TestClient) FilterLogs(ctx context.Context, q ethereum.FilterQuery) ([]types.Log, error) {
tc.incCounter("FilterLogs") err := tc.countAndlog("FilterLogs")
if tc.traceAPICalls { if err != nil {
tc.t.Log("FilterLogs") return nil, err
} }
// We do not verify addresses for now // We do not verify addresses for now
@ -242,12 +260,12 @@ func (tc *TestClient) getBalance(address common.Address, blockNumber *big.Int) *
} }
func (tc *TestClient) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) { func (tc *TestClient) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) {
tc.incCounter("BalanceAt")
balance := tc.getBalance(account, blockNumber) balance := tc.getBalance(account, blockNumber)
err := tc.countAndlog("BalanceAt", fmt.Sprintf("account: %s, result: %d", account, balance))
if tc.traceAPICalls { if err != nil {
tc.t.Log("BalanceAt", blockNumber, "account:", account, "result:", balance) return nil, err
} }
return balance, nil return balance, nil
} }
@ -264,14 +282,15 @@ func (tc *TestClient) tokenBalanceAt(account common.Address, token common.Addres
} }
func (tc *TestClient) HeaderByNumber(ctx context.Context, number *big.Int) (*types.Header, error) { func (tc *TestClient) HeaderByNumber(ctx context.Context, number *big.Int) (*types.Header, error) {
tc.incCounter("HeaderByNumber")
if number == nil { if number == nil {
number = big.NewInt(int64(tc.currentBlock)) number = big.NewInt(int64(tc.currentBlock))
} }
if tc.traceAPICalls { err := tc.countAndlog("HeaderByNumber", fmt.Sprintf("number: %d", number))
tc.t.Log("HeaderByNumber", number) if err != nil {
return nil, err
} }
header := &types.Header{ header := &types.Header{
Number: number, Number: number,
Time: 0, Time: 0,
@ -281,17 +300,17 @@ func (tc *TestClient) HeaderByNumber(ctx context.Context, number *big.Int) (*typ
} }
func (tc *TestClient) CallBlockHashByTransaction(ctx context.Context, blockNumber *big.Int, index uint) (common.Hash, error) { func (tc *TestClient) CallBlockHashByTransaction(ctx context.Context, blockNumber *big.Int, index uint) (common.Hash, error) {
tc.incCounter("FullTransactionByBlockNumberAndIndex") err := tc.countAndlog("CallBlockHashByTransaction")
if tc.traceAPICalls { if err != nil {
tc.t.Log("FullTransactionByBlockNumberAndIndex") return common.Hash{}, err
} }
return common.BigToHash(blockNumber), nil return common.BigToHash(blockNumber), nil
} }
func (tc *TestClient) GetBaseFeeFromBlock(ctx context.Context, blockNumber *big.Int) (string, error) { func (tc *TestClient) GetBaseFeeFromBlock(ctx context.Context, blockNumber *big.Int) (string, error) {
tc.incCounter("GetBaseFeeFromBlock") err := tc.countAndlog("GetBaseFeeFromBlock")
if tc.traceAPICalls { if err != nil {
tc.t.Log("GetBaseFeeFromBlock") return "", err
} }
return "", nil return "", nil
} }
@ -311,10 +330,7 @@ var ethscanAddress = common.HexToAddress("0x000000000000000000000000000000000077
var balanceCheckAddress = common.HexToAddress("0x0000000000000000000000000000000010777333") var balanceCheckAddress = common.HexToAddress("0x0000000000000000000000000000000010777333")
func (tc *TestClient) CodeAt(ctx context.Context, contract common.Address, blockNumber *big.Int) ([]byte, error) { func (tc *TestClient) CodeAt(ctx context.Context, contract common.Address, blockNumber *big.Int) ([]byte, error) {
tc.incCounter("CodeAt") tc.countAndlog("CodeAt", fmt.Sprintf("contract: %s, blockNumber: %d", contract, blockNumber))
if tc.traceAPICalls {
tc.t.Log("CodeAt", contract, blockNumber)
}
if ethscanAddress == contract || balanceCheckAddress == contract { if ethscanAddress == contract || balanceCheckAddress == contract {
return []byte{1}, nil return []byte{1}, nil
@ -324,9 +340,9 @@ func (tc *TestClient) CodeAt(ctx context.Context, contract common.Address, block
} }
func (tc *TestClient) CallContract(ctx context.Context, call ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) { func (tc *TestClient) CallContract(ctx context.Context, call ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) {
tc.incCounter("CallContract") err := tc.countAndlog("CallContract", fmt.Sprintf("call: %v, blockNumber: %d, to: %s", call, blockNumber, call.To))
if tc.traceAPICalls { if err != nil {
tc.t.Log("CallContract", call, blockNumber, call.To) return nil, err
} }
if *call.To == ethscanAddress { if *call.To == ethscanAddress {
@ -557,9 +573,9 @@ func (tc *TestClient) prepareTokenBalanceHistory(toBlock int) {
} }
func (tc *TestClient) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { func (tc *TestClient) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error {
tc.incCounter("CallContext") err := tc.countAndlog("CallContext")
if tc.traceAPICalls { if err != nil {
tc.t.Log("CallContext") return err
} }
return nil return nil
} }
@ -578,91 +594,79 @@ func (tc *TestClient) SetWalletNotifier(notifier func(chainId uint64, message st
} }
func (tc *TestClient) EstimateGas(ctx context.Context, call ethereum.CallMsg) (gas uint64, err error) { func (tc *TestClient) EstimateGas(ctx context.Context, call ethereum.CallMsg) (gas uint64, err error) {
tc.incCounter("EstimateGas") err = tc.countAndlog("EstimateGas")
if tc.traceAPICalls { if err != nil {
tc.t.Log("EstimateGas") return 0, err
} }
return 0, nil return 0, nil
} }
func (tc *TestClient) PendingCodeAt(ctx context.Context, account common.Address) ([]byte, error) { func (tc *TestClient) PendingCodeAt(ctx context.Context, account common.Address) ([]byte, error) {
tc.incCounter("PendingCodeAt") err := tc.countAndlog("PendingCodeAt")
if tc.traceAPICalls { if err != nil {
tc.t.Log("PendingCodeAt") return nil, err
} }
return nil, nil return nil, nil
} }
func (tc *TestClient) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) { func (tc *TestClient) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) {
tc.incCounter("PendingCallContract") err := tc.countAndlog("PendingCallContract")
if tc.traceAPICalls { if err != nil {
tc.t.Log("PendingCallContract") return nil, err
} }
return nil, nil return nil, nil
} }
func (tc *TestClient) PendingNonceAt(ctx context.Context, account common.Address) (uint64, error) { func (tc *TestClient) PendingNonceAt(ctx context.Context, account common.Address) (uint64, error) {
tc.incCounter("PendingNonceAt") err := tc.countAndlog("PendingNonceAt")
if tc.traceAPICalls { if err != nil {
tc.t.Log("PendingNonceAt") return 0, err
} }
return 0, nil return 0, nil
} }
func (tc *TestClient) SuggestGasPrice(ctx context.Context) (*big.Int, error) { func (tc *TestClient) SuggestGasPrice(ctx context.Context) (*big.Int, error) {
tc.incCounter("SuggestGasPrice") err := tc.countAndlog("SuggestGasPrice")
if tc.traceAPICalls { if err != nil {
tc.t.Log("SuggestGasPrice") return nil, err
} }
return nil, nil return nil, nil
} }
func (tc *TestClient) SendTransaction(ctx context.Context, tx *types.Transaction) error { func (tc *TestClient) SendTransaction(ctx context.Context, tx *types.Transaction) error {
tc.incCounter("SendTransaction") err := tc.countAndlog("SendTransaction")
if tc.traceAPICalls { if err != nil {
tc.t.Log("SendTransaction") return err
} }
return nil return nil
} }
func (tc *TestClient) SuggestGasTipCap(ctx context.Context) (*big.Int, error) { func (tc *TestClient) SuggestGasTipCap(ctx context.Context) (*big.Int, error) {
tc.incCounter("SuggestGasTipCap") err := tc.countAndlog("SuggestGasTipCap")
if tc.traceAPICalls { if err != nil {
tc.t.Log("SuggestGasTipCap") return nil, err
} }
return nil, nil return nil, nil
} }
func (tc *TestClient) BatchCallContextIgnoringLocalHandlers(ctx context.Context, b []rpc.BatchElem) error { func (tc *TestClient) BatchCallContextIgnoringLocalHandlers(ctx context.Context, b []rpc.BatchElem) error {
tc.incCounter("BatchCallContextIgnoringLocalHandlers") err := tc.countAndlog("BatchCallContextIgnoringLocalHandlers")
if tc.traceAPICalls { if err != nil {
tc.t.Log("BatchCallContextIgnoringLocalHandlers") return err
} }
return nil return nil
} }
func (tc *TestClient) CallContextIgnoringLocalHandlers(ctx context.Context, result interface{}, method string, args ...interface{}) error { func (tc *TestClient) CallContextIgnoringLocalHandlers(ctx context.Context, result interface{}, method string, args ...interface{}) error {
tc.incCounter("CallContextIgnoringLocalHandlers") err := tc.countAndlog("CallContextIgnoringLocalHandlers")
if tc.traceAPICalls { if err != nil {
tc.t.Log("CallContextIgnoringLocalHandlers") return err
} }
return nil return nil
} }
func (tc *TestClient) CallRaw(data string) string { func (tc *TestClient) CallRaw(data string) string {
tc.incCounter("CallRaw") _ = tc.countAndlog("CallRaw")
if tc.traceAPICalls {
tc.t.Log("CallRaw")
}
return "" return ""
} }
@ -671,38 +675,34 @@ func (tc *TestClient) GetChainID() *big.Int {
} }
func (tc *TestClient) SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuery, ch chan<- types.Log) (ethereum.Subscription, error) { func (tc *TestClient) SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuery, ch chan<- types.Log) (ethereum.Subscription, error) {
tc.incCounter("SubscribeFilterLogs") err := tc.countAndlog("SubscribeFilterLogs")
if tc.traceAPICalls { if err != nil {
tc.t.Log("SubscribeFilterLogs") return nil, err
} }
return nil, nil return nil, nil
} }
func (tc *TestClient) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) { func (tc *TestClient) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) {
tc.incCounter("TransactionReceipt") err := tc.countAndlog("TransactionReceipt")
if tc.traceAPICalls { if err != nil {
tc.t.Log("TransactionReceipt") return nil, err
} }
return nil, nil return nil, nil
} }
func (tc *TestClient) TransactionByHash(ctx context.Context, txHash common.Hash) (*types.Transaction, bool, error) { func (tc *TestClient) TransactionByHash(ctx context.Context, txHash common.Hash) (*types.Transaction, bool, error) {
tc.incCounter("TransactionByHash") err := tc.countAndlog("TransactionByHash")
if tc.traceAPICalls { if err != nil {
tc.t.Log("TransactionByHash") return nil, false, err
} }
return nil, false, nil return nil, false, nil
} }
func (tc *TestClient) BlockNumber(ctx context.Context) (uint64, error) { func (tc *TestClient) BlockNumber(ctx context.Context) (uint64, error) {
tc.incCounter("BlockNumber") err := tc.countAndlog("BlockNumber")
if tc.traceAPICalls { if err != nil {
tc.t.Log("BlockNumber") return 0, err
} }
return 0, nil return 0, nil
} }
func (tc *TestClient) SetIsConnected(value bool) { func (tc *TestClient) SetIsConnected(value bool) {
@ -711,7 +711,7 @@ func (tc *TestClient) SetIsConnected(value bool) {
} }
} }
func (tc *TestClient) GetIsConnected() bool { func (tc *TestClient) IsConnected() bool {
if tc.traceAPICalls { if tc.traceAPICalls {
tc.t.Log("GetIsConnected") tc.t.Log("GetIsConnected")
} }
@ -719,6 +719,14 @@ func (tc *TestClient) GetIsConnected() bool {
return true return true
} }
func (tc *TestClient) GetLimiter() chain.RequestLimiter {
return tc.limiter
}
func (tc *TestClient) SetLimiter(limiter chain.RequestLimiter) {
tc.limiter = limiter
}
type testERC20Transfer struct { type testERC20Transfer struct {
block *big.Int block *big.Int
address common.Address address common.Address
@ -1010,82 +1018,113 @@ func getCases() []findBlockCase {
var tokenTXXAddress = common.HexToAddress("0x53211") var tokenTXXAddress = common.HexToAddress("0x53211")
var tokenTXYAddress = common.HexToAddress("0x73211") var tokenTXYAddress = common.HexToAddress("0x73211")
func setupFindBlocksCommand(t *testing.T, accountAddress common.Address, fromBlock, toBlock *big.Int, rangeSize int, balances map[common.Address][][]int, outgoingERC20Transfers, incomingERC20Transfers, outgoingERC1155SingleTransfers, incomingERC1155SingleTransfers map[common.Address][]testERC20Transfer) (*findBlocksCommand, *TestClient, chan []*DBHeader, *BlockRangeSequentialDAO) {
appdb, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
require.NoError(t, err)
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
mediaServer, err := server.NewMediaServer(appdb, nil, nil, db)
require.NoError(t, err)
wdb := NewDB(db)
tc := &TestClient{
t: t,
balances: balances,
outgoingERC20Transfers: outgoingERC20Transfers,
incomingERC20Transfers: incomingERC20Transfers,
outgoingERC1155SingleTransfers: outgoingERC1155SingleTransfers,
incomingERC1155SingleTransfers: incomingERC1155SingleTransfers,
callsCounter: map[string]int{},
}
// tc.traceAPICalls = true
// tc.printPreparedData = true
tc.prepareBalanceHistory(100)
tc.prepareTokenBalanceHistory(100)
blockChannel := make(chan []*DBHeader, 100)
// Reimplement the common function that is called from every method to check for the limit
countAndlog = func(tc *TestClient, method string, params ...interface{}) error {
if tc.GetLimiter() != nil {
if limited, _ := tc.GetLimiter().IsLimitReached(transferHistoryTag); limited {
t.Log("ERROR: requests over limit")
return chain.ErrRequestsOverLimit
}
}
tc.incCounter(method)
if tc.traceAPICalls {
if len(params) > 0 {
tc.t.Log(method, params)
} else {
tc.t.Log(method)
}
}
return nil
}
client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db)
client.SetClient(tc.NetworkID(), tc)
tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil)
tokenManager.SetTokens([]*token.Token{
{
Address: tokenTXXAddress,
Symbol: "TXX",
Decimals: 18,
ChainID: tc.NetworkID(),
Name: "Test Token 1",
Verified: true,
},
{
Address: tokenTXYAddress,
Symbol: "TXY",
Decimals: 18,
ChainID: tc.NetworkID(),
Name: "Test Token 2",
Verified: true,
},
})
accDB, err := accounts.NewDB(appdb)
require.NoError(t, err)
blockRangeDAO := &BlockRangeSequentialDAO{wdb.client}
fbc := &findBlocksCommand{
accounts: []common.Address{accountAddress},
db: wdb,
blockRangeDAO: blockRangeDAO,
accountsDB: accDB,
chainClient: tc,
balanceCacher: balance.NewCacherWithTTL(5 * time.Minute),
feed: &event.Feed{},
noLimit: false,
fromBlockNumber: fromBlock,
toBlockNumber: toBlock,
blocksLoadedCh: blockChannel,
defaultNodeBlockChunkSize: rangeSize,
tokenManager: tokenManager,
}
return fbc, tc, blockChannel, blockRangeDAO
}
func TestFindBlocksCommand(t *testing.T) { func TestFindBlocksCommand(t *testing.T) {
for idx, testCase := range getCases() { for idx, testCase := range getCases() {
t.Log("case #", idx+1) t.Log("case #", idx+1)
ctx := context.Background()
group := async.NewGroup(ctx)
appdb, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
require.NoError(t, err)
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
tm := &TransactionManager{db, nil, nil, nil, nil, nil, nil, nil, nil, nil}
mediaServer, err := server.NewMediaServer(appdb, nil, nil, db)
require.NoError(t, err)
accountAddress := common.HexToAddress("0x1234") accountAddress := common.HexToAddress("0x1234")
wdb := NewDB(db)
tc := &TestClient{
t: t,
balances: map[common.Address][][]int{accountAddress: testCase.balanceChanges},
outgoingERC20Transfers: map[common.Address][]testERC20Transfer{accountAddress: testCase.outgoingERC20Transfers},
incomingERC20Transfers: map[common.Address][]testERC20Transfer{accountAddress: testCase.incomingERC20Transfers},
outgoingERC1155SingleTransfers: map[common.Address][]testERC20Transfer{accountAddress: testCase.outgoingERC1155SingleTransfers},
incomingERC1155SingleTransfers: map[common.Address][]testERC20Transfer{accountAddress: testCase.incomingERC1155SingleTransfers},
callsCounter: map[string]int{},
}
// tc.traceAPICalls = true
// tc.printPreparedData = true
tc.prepareBalanceHistory(100)
tc.prepareTokenBalanceHistory(100)
blockChannel := make(chan []*DBHeader, 100)
rangeSize := 20 rangeSize := 20
if testCase.rangeSize != 0 { if testCase.rangeSize != 0 {
rangeSize = testCase.rangeSize rangeSize = testCase.rangeSize
} }
client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db)
client.SetClient(tc.NetworkID(), tc) balances := map[common.Address][][]int{accountAddress: testCase.balanceChanges}
tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil) outgoingERC20Transfers := map[common.Address][]testERC20Transfer{accountAddress: testCase.outgoingERC20Transfers}
tokenManager.SetTokens([]*token.Token{ incomingERC20Transfers := map[common.Address][]testERC20Transfer{accountAddress: testCase.incomingERC20Transfers}
{ outgoingERC1155SingleTransfers := map[common.Address][]testERC20Transfer{accountAddress: testCase.outgoingERC1155SingleTransfers}
Address: tokenTXXAddress, incomingERC1155SingleTransfers := map[common.Address][]testERC20Transfer{accountAddress: testCase.incomingERC1155SingleTransfers}
Symbol: "TXX",
Decimals: 18, fbc, tc, blockChannel, blockRangeDAO := setupFindBlocksCommand(t, accountAddress, big.NewInt(testCase.fromBlock), big.NewInt(testCase.toBlock), rangeSize, balances, outgoingERC20Transfers, incomingERC20Transfers, outgoingERC1155SingleTransfers, incomingERC1155SingleTransfers)
ChainID: tc.NetworkID(), ctx := context.Background()
Name: "Test Token 1", group := async.NewGroup(ctx)
Verified: true,
},
{
Address: tokenTXYAddress,
Symbol: "TXY",
Decimals: 18,
ChainID: tc.NetworkID(),
Name: "Test Token 2",
Verified: true,
},
})
accDB, err := accounts.NewDB(appdb)
require.NoError(t, err)
blockRangeDAO := &BlockRangeSequentialDAO{wdb.client}
fbc := &findBlocksCommand{
accounts: []common.Address{accountAddress},
db: wdb,
blockRangeDAO: blockRangeDAO,
accountsDB: accDB,
chainClient: tc,
balanceCacher: balance.NewCacherWithTTL(5 * time.Minute),
feed: &event.Feed{},
noLimit: false,
fromBlockNumber: big.NewInt(testCase.fromBlock),
toBlockNumber: big.NewInt(testCase.toBlock),
transactionManager: tm,
blocksLoadedCh: blockChannel,
defaultNodeBlockChunkSize: rangeSize,
tokenManager: tokenManager,
}
group.Add(fbc.Command()) group.Add(fbc.Command())
foundBlocks := []*DBHeader{} foundBlocks := []*DBHeader{}
@ -1129,6 +1168,62 @@ func TestFindBlocksCommand(t *testing.T) {
} }
} }
func TestFindBlocksCommandWithLimiter(t *testing.T) {
// Set up logging
// handler := log.StreamHandler(os.Stdout, log.TerminalFormat(true))
// log.Root().SetHandler(handler)
maxRequests := 1
rangeSize := 20
accountAddress := common.HexToAddress("0x1234")
balances := map[common.Address][][]int{accountAddress: {{5, 1, 0}, {20, 2, 0}, {45, 1, 1}, {46, 50, 0}, {75, 0, 1}}}
fbc, tc, blockChannel, _ := setupFindBlocksCommand(t, accountAddress, big.NewInt(0), big.NewInt(20), rangeSize, balances, nil, nil, nil, nil)
limiter := chain.NewRequestLimiter(chain.NewInMemRequestsMapStorage())
limiter.SetMaxRequests(transferHistoryTag, maxRequests, time.Hour)
tc.SetLimiter(limiter)
ctx := context.Background()
group := async.NewAtomicGroup(ctx)
group.Add(fbc.Command(1 * time.Millisecond))
select {
case <-ctx.Done():
t.Log("ERROR")
case <-group.WaitAsync():
close(blockChannel)
require.Error(t, chain.ErrRequestsOverLimit, group.Error())
require.Equal(t, maxRequests, tc.getCounter())
}
}
func TestFindBlocksCommandWithLimiterTagDifferentThanTransfers(t *testing.T) {
rangeSize := 20
maxRequests := 1
accountAddress := common.HexToAddress("0x1234")
balances := map[common.Address][][]int{accountAddress: {{5, 1, 0}, {20, 2, 0}, {45, 1, 1}, {46, 50, 0}, {75, 0, 1}}}
outgoingERC20Transfers := map[common.Address][]testERC20Transfer{accountAddress: {{big.NewInt(6), tokenTXXAddress, big.NewInt(1), walletcommon.Erc20TransferEventType}}}
incomingERC20Transfers := map[common.Address][]testERC20Transfer{accountAddress: {{big.NewInt(6), tokenTXXAddress, big.NewInt(1), walletcommon.Erc20TransferEventType}}}
fbc, tc, blockChannel, _ := setupFindBlocksCommand(t, accountAddress, big.NewInt(0), big.NewInt(20), rangeSize, balances, outgoingERC20Transfers, incomingERC20Transfers, nil, nil)
limiter := chain.NewRequestLimiter(chain.NewInMemRequestsMapStorage())
limiter.SetMaxRequests("some-other-tag-than-transfer-history", maxRequests, time.Hour)
tc.SetLimiter(limiter)
ctx := context.Background()
group := async.NewAtomicGroup(ctx)
group.Add(fbc.Command(1 * time.Millisecond))
select {
case <-ctx.Done():
t.Log("ERROR")
case <-group.WaitAsync():
close(blockChannel)
require.NoError(t, group.Error())
require.Greater(t, tc.getCounter(), maxRequests)
}
}
type MockETHClient struct { type MockETHClient struct {
mock.Mock mock.Mock
} }
@ -1228,7 +1323,7 @@ func TestFetchTransfersForLoadedBlocks(t *testing.T) {
tc.prepareBalanceHistory(int(tc.currentBlock)) tc.prepareBalanceHistory(int(tc.currentBlock))
tc.prepareTokenBalanceHistory(int(tc.currentBlock)) tc.prepareTokenBalanceHistory(int(tc.currentBlock))
tc.traceAPICalls = true // tc.traceAPICalls = true
ctx := context.Background() ctx := context.Background()
group := async.NewAtomicGroup(ctx) group := async.NewAtomicGroup(ctx)
@ -1278,7 +1373,6 @@ func TestFetchNewBlocksCommand_findBlocksWithEthTransfers(t *testing.T) {
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err) require.NoError(t, err)
tm := &TransactionManager{db, nil, nil, nil, nil, nil, nil, nil, nil, nil}
mediaServer, err := server.NewMediaServer(appdb, nil, nil, db) mediaServer, err := server.NewMediaServer(appdb, nil, nil, db)
require.NoError(t, err) require.NoError(t, err)
@ -1334,7 +1428,6 @@ func TestFetchNewBlocksCommand_findBlocksWithEthTransfers(t *testing.T) {
balanceCacher: balance.NewCacherWithTTL(5 * time.Minute), balanceCacher: balance.NewCacherWithTTL(5 * time.Minute),
feed: &event.Feed{}, feed: &event.Feed{},
noLimit: false, noLimit: false,
transactionManager: tm,
tokenManager: tokenManager, tokenManager: tokenManager,
blocksLoadedCh: blockChannel, blocksLoadedCh: blockChannel,
defaultNodeBlockChunkSize: DefaultNodeBlockChunkSize, defaultNodeBlockChunkSize: DefaultNodeBlockChunkSize,
@ -1378,7 +1471,6 @@ func TestFetchNewBlocksCommand_nonceDetection(t *testing.T) {
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err) require.NoError(t, err)
tm := &TransactionManager{db, nil, nil, nil, nil, nil, nil, nil, nil, nil}
mediaServer, err := server.NewMediaServer(appdb, nil, nil, db) mediaServer, err := server.NewMediaServer(appdb, nil, nil, db)
require.NoError(t, err) require.NoError(t, err)
@ -1405,7 +1497,6 @@ func TestFetchNewBlocksCommand_nonceDetection(t *testing.T) {
balanceCacher: balance.NewCacherWithTTL(5 * time.Minute), balanceCacher: balance.NewCacherWithTTL(5 * time.Minute),
feed: &event.Feed{}, feed: &event.Feed{},
noLimit: false, noLimit: false,
transactionManager: tm,
tokenManager: tokenManager, tokenManager: tokenManager,
blocksLoadedCh: blockChannel, blocksLoadedCh: blockChannel,
defaultNodeBlockChunkSize: scanRange, defaultNodeBlockChunkSize: scanRange,
@ -1464,7 +1555,6 @@ func TestFetchNewBlocksCommand(t *testing.T) {
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err) require.NoError(t, err)
tm := &TransactionManager{db, nil, nil, nil, nil, nil, nil, nil, nil, nil}
mediaServer, err := server.NewMediaServer(appdb, nil, nil, db) mediaServer, err := server.NewMediaServer(appdb, nil, nil, db)
require.NoError(t, err) require.NoError(t, err)
@ -1534,7 +1624,6 @@ func TestFetchNewBlocksCommand(t *testing.T) {
feed: &event.Feed{}, feed: &event.Feed{},
noLimit: false, noLimit: false,
fromBlockNumber: big.NewInt(int64(tc.currentBlock)), fromBlockNumber: big.NewInt(int64(tc.currentBlock)),
transactionManager: tm,
tokenManager: tokenManager, tokenManager: tokenManager,
blocksLoadedCh: blockChannel, blocksLoadedCh: blockChannel,
defaultNodeBlockChunkSize: DefaultNodeBlockChunkSize, defaultNodeBlockChunkSize: DefaultNodeBlockChunkSize,