From 4e51b5ba24ea05cb2778ac621871058dd82e9068 Mon Sep 17 00:00:00 2001 From: Ivan Belyakov Date: Thu, 23 May 2024 14:01:55 +0400 Subject: [PATCH] feat(wallet)_: more fixes for rpc_limiter_db and chain client, more tests --- rpc/chain/client.go | 2 +- rpc/chain/rpc_limiter.go | 9 ++----- rpc/chain/rpc_limiter_db.go | 16 +++++++----- rpc/chain/rpc_limiter_test.go | 26 +++++++++++++++++++ services/wallet/token/token.go | 6 ++--- .../wallet/transfer/commands_sequential.go | 1 + services/wallet/transfer/database.go | 2 +- services/wallet/transfer/downloader.go | 4 +-- 8 files changed, 46 insertions(+), 20 deletions(-) diff --git a/rpc/chain/client.go b/rpc/chain/client.go index 3250a4e99..8e10347ae 100644 --- a/rpc/chain/client.go +++ b/rpc/chain/client.go @@ -80,7 +80,7 @@ func ClientWithTag(chainClient ClientInterface, tag, groupTag string) ClientInte if tagIface, ok := chainClient.(Tagger); ok { tagIface = DeepCopyTagger(tagIface) tagIface.SetTag(tag) - tagIface.SetGroupTag(tag) + tagIface.SetGroupTag(groupTag) newClient = tagIface.(ClientInterface) } diff --git a/rpc/chain/rpc_limiter.go b/rpc/chain/rpc_limiter.go index e2ed41f0f..369e45a72 100644 --- a/rpc/chain/rpc_limiter.go +++ b/rpc/chain/rpc_limiter.go @@ -80,7 +80,7 @@ func (s *LimitsDBStorage) Set(data *LimitData) error { } limit, err := s.db.GetRPCLimit(data.Tag) - if err != nil { + if err != nil && err != sql.ErrNoRows { return err } @@ -166,13 +166,8 @@ func (rl *RPCRequestLimiter) Allow(tag string) (bool, error) { return true, nil } - // Check if period is forever - if data.Period.Milliseconds() == LimitInfinitely { - return false, nil - } - // Check if a number of requests is over the limit within the interval - if time.Since(data.CreatedAt) < data.Period { + if time.Since(data.CreatedAt) < data.Period || data.Period.Milliseconds() == LimitInfinitely { if data.NumReqs >= data.MaxReqs { return false, nil } diff --git a/rpc/chain/rpc_limiter_db.go b/rpc/chain/rpc_limiter_db.go index ff4e85f04..55bab9086 100644 --- a/rpc/chain/rpc_limiter_db.go +++ b/rpc/chain/rpc_limiter_db.go @@ -2,6 +2,7 @@ package chain import ( "database/sql" + "time" ) type RPCLimiterDB struct { @@ -16,7 +17,7 @@ func NewRPCLimiterDB(db *sql.DB) *RPCLimiterDB { func (r *RPCLimiterDB) CreateRPCLimit(limit LimitData) error { query := `INSERT INTO rpc_limits (tag, created_at, period, max_requests, counter) VALUES (?, ?, ?, ?, ?)` - _, err := r.db.Exec(query, limit.Tag, limit.CreatedAt, limit.Period, limit.MaxReqs, limit.NumReqs) + _, err := r.db.Exec(query, limit.Tag, limit.CreatedAt.Unix(), limit.Period, limit.MaxReqs, limit.NumReqs) if err != nil { return err } @@ -27,16 +28,19 @@ func (r *RPCLimiterDB) GetRPCLimit(tag string) (*LimitData, error) { query := `SELECT tag, created_at, period, max_requests, counter FROM rpc_limits WHERE tag = ?` row := r.db.QueryRow(query, tag) limit := &LimitData{} - err := row.Scan(limit.Tag, limit.CreatedAt, limit.Period, limit.MaxReqs, limit.NumReqs) - if err != nil || err == sql.ErrNoRows { + createdAtSecs := int64(0) + err := row.Scan(&limit.Tag, &createdAtSecs, &limit.Period, &limit.MaxReqs, &limit.NumReqs) + if err != nil { return nil, err } + + limit.CreatedAt = time.Unix(createdAtSecs, 0) return limit, nil } func (r *RPCLimiterDB) UpdateRPCLimit(limit LimitData) error { - query := `UPDATE rpc_limits SET created_at = ?, period = ?, limit = ?, counter = ? WHERE tag = ?` - _, err := r.db.Exec(query, limit.CreatedAt, limit.Period, limit.MaxReqs, limit.NumReqs, limit.Tag) + query := `UPDATE rpc_limits SET created_at = ?, period = ?, max_requests = ?, counter = ? WHERE tag = ?` + _, err := r.db.Exec(query, limit.CreatedAt.Unix(), limit.Period, limit.MaxReqs, limit.NumReqs, limit.Tag) if err != nil { return err } @@ -46,7 +50,7 @@ func (r *RPCLimiterDB) UpdateRPCLimit(limit LimitData) error { func (r *RPCLimiterDB) DeleteRPCLimit(tag string) error { query := `DELETE FROM rpc_limits WHERE tag = ?` _, err := r.db.Exec(query, tag) - if err != nil || err == sql.ErrNoRows { + if err != nil && err != sql.ErrNoRows { return err } return nil diff --git a/rpc/chain/rpc_limiter_test.go b/rpc/chain/rpc_limiter_test.go index 7d1f95f7d..142ca9a29 100644 --- a/rpc/chain/rpc_limiter_test.go +++ b/rpc/chain/rpc_limiter_test.go @@ -140,3 +140,29 @@ func TestAllowRestrictInfinitelyWhenLimitReached(t *testing.T) { // Verify the result require.False(t, allow) } + +func TestAllowWhenLimitNotReachedForInfinitePeriod(t *testing.T) { + storage, rl := setupTest() + + // Define test inputs + tag := "testTag" + maxRequests := 10 + + // Set up the storage with test data + data := &LimitData{ + Tag: tag, + Period: LimitInfinitely, + CreatedAt: time.Now(), + MaxReqs: maxRequests, + NumReqs: maxRequests - 1, + } + err := storage.Set(data) + require.NoError(t, err) + + // Call the Allow method + allow, err := rl.Allow(tag) + require.NoError(t, err) + + // Verify the result + require.True(t, allow) +} diff --git a/services/wallet/token/token.go b/services/wallet/token/token.go index 6d9fed924..896b2781a 100644 --- a/services/wallet/token/token.go +++ b/services/wallet/token/token.go @@ -769,7 +769,7 @@ func (tm *Manager) GetBalancesAtByChain(parent context.Context, clients map[uint BlockNumber: atBlock, }, accounts) if err != nil { - log.Error("can't fetch chain balance 5", err) + log.Error("can't fetch chain balance 5", "err", err) return nil } for idx, account := range accounts { @@ -804,7 +804,7 @@ func (tm *Manager) GetBalancesAtByChain(parent context.Context, clients map[uint } if len(res) != len(chunk) { - log.Error("can't fetch erc20 token balance 7", "account", account, "error response not complete") + log.Error("can't fetch erc20 token balance 7", "account", account, "error", "response not complete") return nil } @@ -823,7 +823,7 @@ func (tm *Manager) GetBalancesAtByChain(parent context.Context, clients map[uint balance, err := tm.GetTokenBalanceAt(ctx, client, account, token, atBlock) if err != nil { if err != bind.ErrNoCode { - log.Error("can't fetch erc20 token balance 8", "account", account, "token", token, "error on fetching token balance") + log.Error("can't fetch erc20 token balance 8", "account", account, "token", token, "error", "on fetching token balance") return nil } diff --git a/services/wallet/transfer/commands_sequential.go b/services/wallet/transfer/commands_sequential.go index ce9787003..abbae60a4 100644 --- a/services/wallet/transfer/commands_sequential.go +++ b/services/wallet/transfer/commands_sequential.go @@ -1130,6 +1130,7 @@ func (c *loadBlocksAndTransfersCommand) fetchHistoryBlocksForAccount(group *asyn // Check if limit is already reached, then skip the comamnd if allow, _ := limiter.Allow(accountTag); !allow { + log.Debug("fetchHistoryBlocksForAccount limit reached", "account", account, "chain", c.chainClient.NetworkID()) continue } diff --git a/services/wallet/transfer/database.go b/services/wallet/transfer/database.go index 7589fecd3..69b15b1aa 100644 --- a/services/wallet/transfer/database.go +++ b/services/wallet/transfer/database.go @@ -540,7 +540,7 @@ func removeGasOnlyEthTransfer(creator statementCreator, t transferDBFields) erro if err != nil { return err } - log.Debug("removeGasOnlyEthTransfer row deleted ", count) + log.Debug("removeGasOnlyEthTransfer rows deleted", "count", count) return nil } diff --git a/services/wallet/transfer/downloader.go b/services/wallet/transfer/downloader.go index 7aca30363..00e59e8c1 100644 --- a/services/wallet/transfer/downloader.go +++ b/services/wallet/transfer/downloader.go @@ -533,7 +533,7 @@ func (d *ERC20TransfersDownloader) blocksFromLogs(parent context.Context, logs [ // time to get logs for 100000 blocks = 1.144686979s. with 249 events in the result set. func (d *ERC20TransfersDownloader) GetHeadersInRange(parent context.Context, from, to *big.Int) ([]*DBHeader, error) { start := time.Now() - log.Debug("get erc20 transfers in range start", "chainID", d.client.NetworkID(), "from", from, "to", to) + log.Debug("get erc20 transfers in range start", "chainID", d.client.NetworkID(), "from", from, "to", to, "accounts", d.accounts) headers := []*DBHeader{} ctx := context.Background() var err error @@ -596,7 +596,7 @@ func (d *ERC20TransfersDownloader) GetHeadersInRange(parent context.Context, fro } log.Debug("get erc20 transfers in range end", "chainID", d.client.NetworkID(), - "from", from, "to", to, "headers", len(headers), "took", time.Since(start)) + "from", from, "to", to, "headers", len(headers), "accounts", d.accounts, "took", time.Since(start)) return headers, nil }