feat(wallet)_: more fixes for rpc_limiter_db and chain client, more

tests
This commit is contained in:
Ivan Belyakov 2024-05-23 14:01:55 +04:00 committed by IvanBelyakoff
parent cec11e9313
commit 4e51b5ba24
8 changed files with 46 additions and 20 deletions

View File

@ -80,7 +80,7 @@ func ClientWithTag(chainClient ClientInterface, tag, groupTag string) ClientInte
if tagIface, ok := chainClient.(Tagger); ok { if tagIface, ok := chainClient.(Tagger); ok {
tagIface = DeepCopyTagger(tagIface) tagIface = DeepCopyTagger(tagIface)
tagIface.SetTag(tag) tagIface.SetTag(tag)
tagIface.SetGroupTag(tag) tagIface.SetGroupTag(groupTag)
newClient = tagIface.(ClientInterface) newClient = tagIface.(ClientInterface)
} }

View File

@ -80,7 +80,7 @@ func (s *LimitsDBStorage) Set(data *LimitData) error {
} }
limit, err := s.db.GetRPCLimit(data.Tag) limit, err := s.db.GetRPCLimit(data.Tag)
if err != nil { if err != nil && err != sql.ErrNoRows {
return err return err
} }
@ -166,13 +166,8 @@ func (rl *RPCRequestLimiter) Allow(tag string) (bool, error) {
return true, nil return true, nil
} }
// Check if period is forever
if data.Period.Milliseconds() == LimitInfinitely {
return false, nil
}
// Check if a number of requests is over the limit within the interval // Check if a number of requests is over the limit within the interval
if time.Since(data.CreatedAt) < data.Period { if time.Since(data.CreatedAt) < data.Period || data.Period.Milliseconds() == LimitInfinitely {
if data.NumReqs >= data.MaxReqs { if data.NumReqs >= data.MaxReqs {
return false, nil return false, nil
} }

View File

@ -2,6 +2,7 @@ package chain
import ( import (
"database/sql" "database/sql"
"time"
) )
type RPCLimiterDB struct { type RPCLimiterDB struct {
@ -16,7 +17,7 @@ func NewRPCLimiterDB(db *sql.DB) *RPCLimiterDB {
func (r *RPCLimiterDB) CreateRPCLimit(limit LimitData) error { func (r *RPCLimiterDB) CreateRPCLimit(limit LimitData) error {
query := `INSERT INTO rpc_limits (tag, created_at, period, max_requests, counter) VALUES (?, ?, ?, ?, ?)` query := `INSERT INTO rpc_limits (tag, created_at, period, max_requests, counter) VALUES (?, ?, ?, ?, ?)`
_, err := r.db.Exec(query, limit.Tag, limit.CreatedAt, limit.Period, limit.MaxReqs, limit.NumReqs) _, err := r.db.Exec(query, limit.Tag, limit.CreatedAt.Unix(), limit.Period, limit.MaxReqs, limit.NumReqs)
if err != nil { if err != nil {
return err return err
} }
@ -27,16 +28,19 @@ func (r *RPCLimiterDB) GetRPCLimit(tag string) (*LimitData, error) {
query := `SELECT tag, created_at, period, max_requests, counter FROM rpc_limits WHERE tag = ?` query := `SELECT tag, created_at, period, max_requests, counter FROM rpc_limits WHERE tag = ?`
row := r.db.QueryRow(query, tag) row := r.db.QueryRow(query, tag)
limit := &LimitData{} limit := &LimitData{}
err := row.Scan(limit.Tag, limit.CreatedAt, limit.Period, limit.MaxReqs, limit.NumReqs) createdAtSecs := int64(0)
if err != nil || err == sql.ErrNoRows { err := row.Scan(&limit.Tag, &createdAtSecs, &limit.Period, &limit.MaxReqs, &limit.NumReqs)
if err != nil {
return nil, err return nil, err
} }
limit.CreatedAt = time.Unix(createdAtSecs, 0)
return limit, nil return limit, nil
} }
func (r *RPCLimiterDB) UpdateRPCLimit(limit LimitData) error { func (r *RPCLimiterDB) UpdateRPCLimit(limit LimitData) error {
query := `UPDATE rpc_limits SET created_at = ?, period = ?, limit = ?, counter = ? WHERE tag = ?` query := `UPDATE rpc_limits SET created_at = ?, period = ?, max_requests = ?, counter = ? WHERE tag = ?`
_, err := r.db.Exec(query, limit.CreatedAt, limit.Period, limit.MaxReqs, limit.NumReqs, limit.Tag) _, err := r.db.Exec(query, limit.CreatedAt.Unix(), limit.Period, limit.MaxReqs, limit.NumReqs, limit.Tag)
if err != nil { if err != nil {
return err return err
} }
@ -46,7 +50,7 @@ func (r *RPCLimiterDB) UpdateRPCLimit(limit LimitData) error {
func (r *RPCLimiterDB) DeleteRPCLimit(tag string) error { func (r *RPCLimiterDB) DeleteRPCLimit(tag string) error {
query := `DELETE FROM rpc_limits WHERE tag = ?` query := `DELETE FROM rpc_limits WHERE tag = ?`
_, err := r.db.Exec(query, tag) _, err := r.db.Exec(query, tag)
if err != nil || err == sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return err
} }
return nil return nil

View File

@ -140,3 +140,29 @@ func TestAllowRestrictInfinitelyWhenLimitReached(t *testing.T) {
// Verify the result // Verify the result
require.False(t, allow) require.False(t, allow)
} }
func TestAllowWhenLimitNotReachedForInfinitePeriod(t *testing.T) {
storage, rl := setupTest()
// Define test inputs
tag := "testTag"
maxRequests := 10
// Set up the storage with test data
data := &LimitData{
Tag: tag,
Period: LimitInfinitely,
CreatedAt: time.Now(),
MaxReqs: maxRequests,
NumReqs: maxRequests - 1,
}
err := storage.Set(data)
require.NoError(t, err)
// Call the Allow method
allow, err := rl.Allow(tag)
require.NoError(t, err)
// Verify the result
require.True(t, allow)
}

View File

@ -769,7 +769,7 @@ func (tm *Manager) GetBalancesAtByChain(parent context.Context, clients map[uint
BlockNumber: atBlock, BlockNumber: atBlock,
}, accounts) }, accounts)
if err != nil { if err != nil {
log.Error("can't fetch chain balance 5", err) log.Error("can't fetch chain balance 5", "err", err)
return nil return nil
} }
for idx, account := range accounts { for idx, account := range accounts {
@ -804,7 +804,7 @@ func (tm *Manager) GetBalancesAtByChain(parent context.Context, clients map[uint
} }
if len(res) != len(chunk) { if len(res) != len(chunk) {
log.Error("can't fetch erc20 token balance 7", "account", account, "error response not complete") log.Error("can't fetch erc20 token balance 7", "account", account, "error", "response not complete")
return nil return nil
} }
@ -823,7 +823,7 @@ func (tm *Manager) GetBalancesAtByChain(parent context.Context, clients map[uint
balance, err := tm.GetTokenBalanceAt(ctx, client, account, token, atBlock) balance, err := tm.GetTokenBalanceAt(ctx, client, account, token, atBlock)
if err != nil { if err != nil {
if err != bind.ErrNoCode { if err != bind.ErrNoCode {
log.Error("can't fetch erc20 token balance 8", "account", account, "token", token, "error on fetching token balance") log.Error("can't fetch erc20 token balance 8", "account", account, "token", token, "error", "on fetching token balance")
return nil return nil
} }

View File

@ -1130,6 +1130,7 @@ func (c *loadBlocksAndTransfersCommand) fetchHistoryBlocksForAccount(group *asyn
// Check if limit is already reached, then skip the comamnd // Check if limit is already reached, then skip the comamnd
if allow, _ := limiter.Allow(accountTag); !allow { if allow, _ := limiter.Allow(accountTag); !allow {
log.Debug("fetchHistoryBlocksForAccount limit reached", "account", account, "chain", c.chainClient.NetworkID())
continue continue
} }

View File

@ -540,7 +540,7 @@ func removeGasOnlyEthTransfer(creator statementCreator, t transferDBFields) erro
if err != nil { if err != nil {
return err return err
} }
log.Debug("removeGasOnlyEthTransfer row deleted ", count) log.Debug("removeGasOnlyEthTransfer rows deleted", "count", count)
return nil return nil
} }

View File

@ -533,7 +533,7 @@ func (d *ERC20TransfersDownloader) blocksFromLogs(parent context.Context, logs [
// time to get logs for 100000 blocks = 1.144686979s. with 249 events in the result set. // time to get logs for 100000 blocks = 1.144686979s. with 249 events in the result set.
func (d *ERC20TransfersDownloader) GetHeadersInRange(parent context.Context, from, to *big.Int) ([]*DBHeader, error) { func (d *ERC20TransfersDownloader) GetHeadersInRange(parent context.Context, from, to *big.Int) ([]*DBHeader, error) {
start := time.Now() start := time.Now()
log.Debug("get erc20 transfers in range start", "chainID", d.client.NetworkID(), "from", from, "to", to) log.Debug("get erc20 transfers in range start", "chainID", d.client.NetworkID(), "from", from, "to", to, "accounts", d.accounts)
headers := []*DBHeader{} headers := []*DBHeader{}
ctx := context.Background() ctx := context.Background()
var err error var err error
@ -596,7 +596,7 @@ func (d *ERC20TransfersDownloader) GetHeadersInRange(parent context.Context, fro
} }
log.Debug("get erc20 transfers in range end", "chainID", d.client.NetworkID(), log.Debug("get erc20 transfers in range end", "chainID", d.client.NetworkID(),
"from", from, "to", to, "headers", len(headers), "took", time.Since(start)) "from", from, "to", to, "headers", len(headers), "accounts", d.accounts, "took", time.Since(start))
return headers, nil return headers, nil
} }