status-go/services/wallet/iterative_test.go

119 lines
2.8 KiB
Go

package wallet
import (
"context"
"errors"
"math/big"
"testing"
"github.com/stretchr/testify/require"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
)
// transfersFixture is an in-memory list of transfers that the tests below plug
// into IterativeDownloader as its downloader.
type transfersFixture []Transfer

// GetTransfersInRange returns every fixture transfer whose BlockNumber lies in
// the closed interval [from, to], preserving fixture order. The context is not
// consulted. The result is never nil: it is an empty slice when nothing
// matches, and the error is always nil.
func (f transfersFixture) GetTransfersInRange(ctx context.Context, from, to *big.Int) ([]Transfer, error) {
	matched := make([]Transfer, 0)
	for i := range f {
		num := f[i].BlockNumber
		// Skip anything below the lower bound or above the upper bound.
		if num.Cmp(from) < 0 || num.Cmp(to) > 0 {
			continue
		}
		matched = append(matched, f[i])
	}
	return matched, nil
}
// TestIterFinished checks that a downloader whose from and to headers carry
// the same block number reports itself as finished.
func TestIterFinished(t *testing.T) {
	downloader := IterativeDownloader{
		to:   &DBHeader{Number: big.NewInt(10)},
		from: &DBHeader{Number: big.NewInt(10)},
	}
	finished := downloader.Finished()
	require.True(t, finished)
}
// TestIterNotFinished checks that a downloader whose from block (2) is below
// its to block (5) does not report itself as finished.
func TestIterNotFinished(t *testing.T) {
	downloader := IterativeDownloader{
		to:   &DBHeader{Number: big.NewInt(5)},
		from: &DBHeader{Number: big.NewInt(2)},
	}
	finished := downloader.Finished()
	require.False(t, finished)
}
// TestIterRevert checks that Revert turns a finished downloader (from == to)
// back into an unfinished one when a previous header is recorded.
func TestIterRevert(t *testing.T) {
	downloader := IterativeDownloader{
		from:     &DBHeader{Number: big.NewInt(12)},
		to:       &DBHeader{Number: big.NewInt(12)},
		previous: &DBHeader{Number: big.NewInt(9)},
	}

	// Before reverting: from and to coincide, so the range is exhausted.
	require.True(t, downloader.Finished())

	// After reverting: work remains. Presumably Revert rolls `from` back
	// using `previous` — confirm against IterativeDownloader.Revert.
	downloader.Revert()
	require.False(t, downloader.Finished())
}
// TestIterProgress drives an IterativeDownloader over a 10-block chain with a
// batch size of 5. Each block has exactly one transfer. The test expects two
// batches (of 6 and then 5 transfers), after which the downloader is finished.
func TestIterProgress(t *testing.T) {
	var chain headers = genHeadersChain(10, 1)

	// One transfer per header, pointing at that header's number and hash.
	transfers := make(transfersFixture, len(chain))
	for i, header := range chain {
		transfers[i] = Transfer{
			BlockNumber: header.Number,
			BlockHash:   header.Hash(),
		}
	}

	downloader := &IterativeDownloader{
		client:     chain,
		downloader: transfers,
		batchSize:  big.NewInt(5),
		from:       &DBHeader{Number: big.NewInt(0)},
		to:         &DBHeader{Number: big.NewInt(9)},
	}
	ctx := context.TODO()

	// 6 + 5 exceeds the 10 distinct transfers, so presumably one boundary
	// block is included in both batches — confirm against Next.
	first, err := downloader.Next(ctx)
	require.NoError(t, err)
	require.Len(t, first, 6)

	second, err := downloader.Next(ctx)
	require.NoError(t, err)
	require.Len(t, second, 5)

	require.True(t, downloader.Finished())
}
// headers is an in-memory header chain that the tests use as the
// IterativeDownloader client.
type headers []*types.Header

// HeaderByHash returns the first header whose Hash() equals hash. It returns
// an error with the message "not found" when no header matches. The context
// is not consulted.
func (h headers) HeaderByHash(ctx context.Context, hash common.Hash) (*types.Header, error) {
	for i := 0; i < len(h); i++ {
		if candidate := h[i]; candidate.Hash() == hash {
			return candidate, nil
		}
	}
	return nil, errors.New("not found")
}
// HeaderByNumber returns the first header whose Number equals number. It
// returns an error with the message "not found" when no header matches. The
// context is not consulted.
func (h headers) HeaderByNumber(ctx context.Context, number *big.Int) (*types.Header, error) {
	for i := 0; i < len(h); i++ {
		if candidate := h[i]; candidate.Number.Cmp(number) == 0 {
			return candidate, nil
		}
	}
	return nil, errors.New("not found")
}
// BalanceAt is not supported by this fixture. It ignores its arguments and
// always returns a nil balance with a "not implemented" error.
func (h headers) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) {
	unsupported := errors.New("not implemented")
	return nil, unsupported
}
// genHeadersChain builds size headers numbered 0..size-1. Each header has the
// given difficulty and Time set to 1. Every header after the first has its
// ParentHash set to the hash of its predecessor, so the slice forms a linked
// chain starting at block 0.
func genHeadersChain(size, difficulty int) []*types.Header {
	chain := make([]*types.Header, size)
	var parent *types.Header
	for n := range chain {
		hdr := &types.Header{
			Number:     big.NewInt(int64(n)),
			Difficulty: big.NewInt(int64(difficulty)),
			Time:       1,
		}
		// The genesis header (n == 0) has no parent to link to.
		if parent != nil {
			hdr.ParentHash = parent.Hash()
		}
		chain[n] = hdr
		parent = hdr
	}
	return chain
}