feat(wallet)_: more fixes for rpc_limiter_db and chain client, more
tests
This commit is contained in:
parent
cec11e9313
commit
4e51b5ba24
|
@ -80,7 +80,7 @@ func ClientWithTag(chainClient ClientInterface, tag, groupTag string) ClientInte
|
|||
if tagIface, ok := chainClient.(Tagger); ok {
|
||||
tagIface = DeepCopyTagger(tagIface)
|
||||
tagIface.SetTag(tag)
|
||||
tagIface.SetGroupTag(tag)
|
||||
tagIface.SetGroupTag(groupTag)
|
||||
newClient = tagIface.(ClientInterface)
|
||||
}
|
||||
|
||||
|
|
|
@ -80,7 +80,7 @@ func (s *LimitsDBStorage) Set(data *LimitData) error {
|
|||
}
|
||||
|
||||
limit, err := s.db.GetRPCLimit(data.Tag)
|
||||
if err != nil {
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -166,13 +166,8 @@ func (rl *RPCRequestLimiter) Allow(tag string) (bool, error) {
|
|||
return true, nil
|
||||
}
|
||||
|
||||
// Check if period is forever
|
||||
if data.Period.Milliseconds() == LimitInfinitely {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check if a number of requests is over the limit within the interval
|
||||
if time.Since(data.CreatedAt) < data.Period {
|
||||
if time.Since(data.CreatedAt) < data.Period || data.Period.Milliseconds() == LimitInfinitely {
|
||||
if data.NumReqs >= data.MaxReqs {
|
||||
return false, nil
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package chain
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RPCLimiterDB struct {
|
||||
|
@ -16,7 +17,7 @@ func NewRPCLimiterDB(db *sql.DB) *RPCLimiterDB {
|
|||
|
||||
func (r *RPCLimiterDB) CreateRPCLimit(limit LimitData) error {
|
||||
query := `INSERT INTO rpc_limits (tag, created_at, period, max_requests, counter) VALUES (?, ?, ?, ?, ?)`
|
||||
_, err := r.db.Exec(query, limit.Tag, limit.CreatedAt, limit.Period, limit.MaxReqs, limit.NumReqs)
|
||||
_, err := r.db.Exec(query, limit.Tag, limit.CreatedAt.Unix(), limit.Period, limit.MaxReqs, limit.NumReqs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -27,16 +28,19 @@ func (r *RPCLimiterDB) GetRPCLimit(tag string) (*LimitData, error) {
|
|||
query := `SELECT tag, created_at, period, max_requests, counter FROM rpc_limits WHERE tag = ?`
|
||||
row := r.db.QueryRow(query, tag)
|
||||
limit := &LimitData{}
|
||||
err := row.Scan(limit.Tag, limit.CreatedAt, limit.Period, limit.MaxReqs, limit.NumReqs)
|
||||
if err != nil || err == sql.ErrNoRows {
|
||||
createdAtSecs := int64(0)
|
||||
err := row.Scan(&limit.Tag, &createdAtSecs, &limit.Period, &limit.MaxReqs, &limit.NumReqs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
limit.CreatedAt = time.Unix(createdAtSecs, 0)
|
||||
return limit, nil
|
||||
}
|
||||
|
||||
func (r *RPCLimiterDB) UpdateRPCLimit(limit LimitData) error {
|
||||
query := `UPDATE rpc_limits SET created_at = ?, period = ?, limit = ?, counter = ? WHERE tag = ?`
|
||||
_, err := r.db.Exec(query, limit.CreatedAt, limit.Period, limit.MaxReqs, limit.NumReqs, limit.Tag)
|
||||
query := `UPDATE rpc_limits SET created_at = ?, period = ?, max_requests = ?, counter = ? WHERE tag = ?`
|
||||
_, err := r.db.Exec(query, limit.CreatedAt.Unix(), limit.Period, limit.MaxReqs, limit.NumReqs, limit.Tag)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -46,7 +50,7 @@ func (r *RPCLimiterDB) UpdateRPCLimit(limit LimitData) error {
|
|||
func (r *RPCLimiterDB) DeleteRPCLimit(tag string) error {
|
||||
query := `DELETE FROM rpc_limits WHERE tag = ?`
|
||||
_, err := r.db.Exec(query, tag)
|
||||
if err != nil || err == sql.ErrNoRows {
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -140,3 +140,29 @@ func TestAllowRestrictInfinitelyWhenLimitReached(t *testing.T) {
|
|||
// Verify the result
|
||||
require.False(t, allow)
|
||||
}
|
||||
|
||||
func TestAllowWhenLimitNotReachedForInfinitePeriod(t *testing.T) {
|
||||
storage, rl := setupTest()
|
||||
|
||||
// Define test inputs
|
||||
tag := "testTag"
|
||||
maxRequests := 10
|
||||
|
||||
// Set up the storage with test data
|
||||
data := &LimitData{
|
||||
Tag: tag,
|
||||
Period: LimitInfinitely,
|
||||
CreatedAt: time.Now(),
|
||||
MaxReqs: maxRequests,
|
||||
NumReqs: maxRequests - 1,
|
||||
}
|
||||
err := storage.Set(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Call the Allow method
|
||||
allow, err := rl.Allow(tag)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the result
|
||||
require.True(t, allow)
|
||||
}
|
||||
|
|
|
@ -769,7 +769,7 @@ func (tm *Manager) GetBalancesAtByChain(parent context.Context, clients map[uint
|
|||
BlockNumber: atBlock,
|
||||
}, accounts)
|
||||
if err != nil {
|
||||
log.Error("can't fetch chain balance 5", err)
|
||||
log.Error("can't fetch chain balance 5", "err", err)
|
||||
return nil
|
||||
}
|
||||
for idx, account := range accounts {
|
||||
|
@ -804,7 +804,7 @@ func (tm *Manager) GetBalancesAtByChain(parent context.Context, clients map[uint
|
|||
}
|
||||
|
||||
if len(res) != len(chunk) {
|
||||
log.Error("can't fetch erc20 token balance 7", "account", account, "error response not complete")
|
||||
log.Error("can't fetch erc20 token balance 7", "account", account, "error", "response not complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -823,7 +823,7 @@ func (tm *Manager) GetBalancesAtByChain(parent context.Context, clients map[uint
|
|||
balance, err := tm.GetTokenBalanceAt(ctx, client, account, token, atBlock)
|
||||
if err != nil {
|
||||
if err != bind.ErrNoCode {
|
||||
log.Error("can't fetch erc20 token balance 8", "account", account, "token", token, "error on fetching token balance")
|
||||
log.Error("can't fetch erc20 token balance 8", "account", account, "token", token, "error", "on fetching token balance")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1130,6 +1130,7 @@ func (c *loadBlocksAndTransfersCommand) fetchHistoryBlocksForAccount(group *asyn
|
|||
|
||||
// Check if limit is already reached, then skip the comamnd
|
||||
if allow, _ := limiter.Allow(accountTag); !allow {
|
||||
log.Debug("fetchHistoryBlocksForAccount limit reached", "account", account, "chain", c.chainClient.NetworkID())
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
|
@ -540,7 +540,7 @@ func removeGasOnlyEthTransfer(creator statementCreator, t transferDBFields) erro
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug("removeGasOnlyEthTransfer row deleted ", count)
|
||||
log.Debug("removeGasOnlyEthTransfer rows deleted", "count", count)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -533,7 +533,7 @@ func (d *ERC20TransfersDownloader) blocksFromLogs(parent context.Context, logs [
|
|||
// time to get logs for 100000 blocks = 1.144686979s. with 249 events in the result set.
|
||||
func (d *ERC20TransfersDownloader) GetHeadersInRange(parent context.Context, from, to *big.Int) ([]*DBHeader, error) {
|
||||
start := time.Now()
|
||||
log.Debug("get erc20 transfers in range start", "chainID", d.client.NetworkID(), "from", from, "to", to)
|
||||
log.Debug("get erc20 transfers in range start", "chainID", d.client.NetworkID(), "from", from, "to", to, "accounts", d.accounts)
|
||||
headers := []*DBHeader{}
|
||||
ctx := context.Background()
|
||||
var err error
|
||||
|
@ -596,7 +596,7 @@ func (d *ERC20TransfersDownloader) GetHeadersInRange(parent context.Context, fro
|
|||
}
|
||||
|
||||
log.Debug("get erc20 transfers in range end", "chainID", d.client.NetworkID(),
|
||||
"from", from, "to", to, "headers", len(headers), "took", time.Since(start))
|
||||
"from", from, "to", to, "headers", len(headers), "accounts", d.accounts, "took", time.Since(start))
|
||||
return headers, nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue