feat(wallet): keep multi-transactions relation after transaction is complete

Updates status-desktop #7663
This commit is contained in:
Stefan 2023-02-15 14:28:19 +02:00 committed by Stefan Dunca
parent 226fa7d696
commit 7c7c3a1f13
13 changed files with 254 additions and 187 deletions

View File

@ -114,7 +114,7 @@ func TestTransactionNotification(t *testing.T) {
Nonce: &nonce,
}
require.NoError(t, walletDb.ProcessBlocks(1777, header.Address, big.NewInt(1), lastBlock, []*transfer.DBHeader{header}))
require.NoError(t, walletDb.ProcessTranfers(1777, transfers, []*transfer.DBHeader{}))
require.NoError(t, walletDb.ProcessTransfers(1777, transfers, []*transfer.DBHeader{}))
feed.Send(walletevent.Event{
Type: transfer.EventRecentHistoryReady,

View File

@ -224,54 +224,54 @@ func (api *API) DeleteSavedAddress(ctx context.Context, address common.Address)
return err
}
func (api *API) GetPendingTransactions(ctx context.Context) ([]*PendingTransaction, error) {
func (api *API) GetPendingTransactions(ctx context.Context) ([]*transfer.PendingTransaction, error) {
log.Debug("call to get pending transactions")
rst, err := api.s.transactionManager.getAllPendings([]uint64{api.s.rpcClient.UpstreamChainID})
rst, err := api.s.transactionManager.GetAllPending([]uint64{api.s.rpcClient.UpstreamChainID})
log.Debug("result from database for pending transactions", "len", len(rst))
return rst, err
}
func (api *API) GetPendingTransactionsByChainIDs(ctx context.Context, chainIDs []uint64) ([]*PendingTransaction, error) {
func (api *API) GetPendingTransactionsByChainIDs(ctx context.Context, chainIDs []uint64) ([]*transfer.PendingTransaction, error) {
log.Debug("call to get pending transactions")
rst, err := api.s.transactionManager.getAllPendings(chainIDs)
rst, err := api.s.transactionManager.GetAllPending(chainIDs)
log.Debug("result from database for pending transactions", "len", len(rst))
return rst, err
}
func (api *API) GetPendingOutboundTransactionsByAddress(ctx context.Context, address common.Address) ([]*PendingTransaction, error) {
func (api *API) GetPendingOutboundTransactionsByAddress(ctx context.Context, address common.Address) ([]*transfer.PendingTransaction, error) {
log.Debug("call to get pending outbound transactions by address")
rst, err := api.s.transactionManager.getPendingByAddress([]uint64{api.s.rpcClient.UpstreamChainID}, address)
rst, err := api.s.transactionManager.GetPendingByAddress([]uint64{api.s.rpcClient.UpstreamChainID}, address)
log.Debug("result from database for pending transactions by address", "len", len(rst))
return rst, err
}
func (api *API) GetPendingOutboundTransactionsByAddressAndChainID(ctx context.Context, chainIDs []uint64, address common.Address) ([]*PendingTransaction, error) {
func (api *API) GetPendingOutboundTransactionsByAddressAndChainID(ctx context.Context, chainIDs []uint64, address common.Address) ([]*transfer.PendingTransaction, error) {
log.Debug("call to get pending outbound transactions by address")
rst, err := api.s.transactionManager.getPendingByAddress(chainIDs, address)
rst, err := api.s.transactionManager.GetPendingByAddress(chainIDs, address)
log.Debug("result from database for pending transactions by address", "len", len(rst))
return rst, err
}
func (api *API) StorePendingTransaction(ctx context.Context, trx PendingTransaction) error {
func (api *API) StorePendingTransaction(ctx context.Context, trx transfer.PendingTransaction) error {
log.Debug("call to create or edit pending transaction")
if trx.ChainID == 0 {
trx.ChainID = api.s.rpcClient.UpstreamChainID
}
err := api.s.transactionManager.addPending(trx)
err := api.s.transactionManager.AddPending(trx)
log.Debug("result from database for creating or editing a pending transaction", "err", err)
return err
}
func (api *API) DeletePendingTransaction(ctx context.Context, transactionHash common.Hash) error {
log.Debug("call to remove pending transaction")
err := api.s.transactionManager.deletePending(api.s.rpcClient.UpstreamChainID, transactionHash)
err := api.s.transactionManager.DeletePending(api.s.rpcClient.UpstreamChainID, transactionHash)
log.Debug("result from database for remove pending transaction", "err", err)
return err
}
func (api *API) DeletePendingTransactionByChainID(ctx context.Context, chainID uint64, transactionHash common.Hash) error {
log.Debug("call to remove pending transaction")
err := api.s.transactionManager.deletePending(chainID, transactionHash)
err := api.s.transactionManager.DeletePending(chainID, transactionHash)
log.Debug("result from database for remove pending transaction", "err", err)
return err
}
@ -281,7 +281,7 @@ func (api *API) WatchTransaction(ctx context.Context, transactionHash common.Has
if err != nil {
return err
}
return api.s.transactionManager.watch(ctx, transactionHash, chainClient)
return api.s.transactionManager.Watch(ctx, transactionHash, chainClient)
}
func (api *API) WatchTransactionByChainID(ctx context.Context, chainID uint64, transactionHash common.Hash) error {
@ -289,7 +289,7 @@ func (api *API) WatchTransactionByChainID(ctx context.Context, chainID uint64, t
if err != nil {
return err
}
return api.s.transactionManager.watch(ctx, transactionHash, chainClient)
return api.s.transactionManager.Watch(ctx, transactionHash, chainClient)
}
func (api *API) GetCryptoOnRamps(ctx context.Context) ([]CryptoOnRamp, error) {
@ -545,9 +545,9 @@ func (api *API) getDerivedAddress(id string, derivedPath string) (*DerivedAddres
return address, nil
}
func (api *API) CreateMultiTransaction(ctx context.Context, multiTransaction *MultiTransaction, data []*bridge.TransactionBridge, password string) (*MultiTransactionResult, error) {
func (api *API) CreateMultiTransaction(ctx context.Context, multiTransaction *transfer.MultiTransaction, data []*bridge.TransactionBridge, password string) (*transfer.MultiTransactionResult, error) {
log.Debug("[WalletAPI:: CreateMultiTransaction] create multi transaction")
return api.s.transactionManager.createMultiTransaction(ctx, multiTransaction, data, api.router.bridges, password)
return api.s.transactionManager.CreateMultiTransaction(ctx, multiTransaction, data, api.router.bridges, password)
}
func (api *API) GetCachedCurrencyFormats() (currency.FormatPerSymbol, error) {

View File

@ -61,8 +61,8 @@ func NewService(
}
tokenManager := token.NewTokenManager(db, rpcClient, rpcClient.NetworkManager)
savedAddressesManager := &SavedAddressesManager{db: db}
transactionManager := &TransactionManager{db: db, transactor: transactor, gethManager: gethManager, config: config, accountsDB: accountsDB}
transferController := transfer.NewTransferController(db, rpcClient, accountFeed, walletFeed)
transactionManager := transfer.NewTransactionManager(db, gethManager, transactor, config, accountsDB)
transferController := transfer.NewTransferController(db, rpcClient, accountFeed, walletFeed, transactionManager)
cryptoCompare := cryptocompare.NewClient()
coingecko := coingecko.NewClient()
marketManager := market.NewManager(cryptoCompare, coingecko)
@ -100,7 +100,7 @@ type Service struct {
rpcClient *rpc.Client
savedAddressesManager *SavedAddressesManager
tokenManager *token.Manager
transactionManager *TransactionManager
transactionManager *transfer.TransactionManager
cryptoOnRampManager *CryptoOnRampManager
transferController *transfer.Controller
feesManager *FeeManager

View File

@ -2,6 +2,7 @@ package transfer
import (
"context"
"database/sql"
"math/big"
"strings"
"time"
@ -168,10 +169,11 @@ type controlCommand struct {
feed *event.Feed
errorsCount int
nonArchivalRPCNode bool
transactionManager *TransactionManager
}
func (c *controlCommand) LoadTransfers(ctx context.Context, downloader *ETHDownloader, limit int) (map[common.Address][]Transfer, error) {
return loadTransfers(ctx, c.accounts, c.block, c.db, c.chainClient, limit, make(map[common.Address][]*big.Int))
return loadTransfers(ctx, c.accounts, c.block, c.db, c.chainClient, limit, make(map[common.Address][]*big.Int), c.transactionManager)
}
func (c *controlCommand) Run(parent context.Context) error {
@ -262,6 +264,7 @@ func (c *controlCommand) Run(parent context.Context) error {
signer: types.NewLondonSigner(c.chainClient.ToBigInt()),
db: c.db,
}
_, err = c.LoadTransfers(parent, downloader, 40)
if err != nil {
if c.NewError(err) {
@ -333,12 +336,13 @@ func (c *controlCommand) Command() async.Command {
}
type transfersCommand struct {
db *Database
eth *ETHDownloader
block *big.Int
address common.Address
chainClient *chain.ClientWithFallback
fetchedTransfers []Transfer
db *Database
eth *ETHDownloader
block *big.Int
address common.Address
chainClient *chain.ClientWithFallback
fetchedTransfers []Transfer
transactionManager *TransactionManager
}
func (c *transfersCommand) Command() async.Command {
@ -355,10 +359,36 @@ func (c *transfersCommand) Run(ctx context.Context) (err error) {
return err
}
err = c.db.SaveTranfers(c.chainClient.ChainID, c.address, allTransfers, []*big.Int{c.block})
if err != nil {
log.Error("SaveTranfers error", "error", err)
return err
// Update MultiTransactionID from pending entry
for index := range allTransfers {
transfer := &allTransfers[index]
if transfer.MultiTransactionID == NoMultiTransactionID {
entry, err := c.transactionManager.GetPendingEntry(c.chainClient.ChainID, transfer.ID)
if err != nil {
if err == sql.ErrNoRows {
log.Warn("Pending transaction not found for", c.chainClient.ChainID, transfer.ID)
} else {
return err
}
} else {
transfer.MultiTransactionID = entry.MultiTransactionID
if transfer.Receipt != nil && transfer.Receipt.Status == types.ReceiptStatusSuccessful {
// TODO: Nim logic was deleting pending previously, should we notify UI about it?
err := c.transactionManager.DeletePending(c.chainClient.ChainID, transfer.ID)
if err != nil {
return err
}
}
}
}
}
if len(allTransfers) > 0 {
err = c.db.SaveTransfers(c.chainClient.ChainID, c.address, allTransfers, []*big.Int{c.block})
if err != nil {
log.Error("SaveTransfers error", "error", err)
return err
}
}
c.fetchedTransfers = allTransfers
@ -373,6 +403,7 @@ type loadTransfersCommand struct {
chainClient *chain.ClientWithFallback
blocksByAddress map[common.Address][]*big.Int
foundTransfersByAddress map[common.Address][]Transfer
transactionManager *TransactionManager
}
func (c *loadTransfersCommand) Command() async.Command {
@ -382,8 +413,8 @@ func (c *loadTransfersCommand) Command() async.Command {
}.Run
}
func (c *loadTransfersCommand) LoadTransfers(ctx context.Context, downloader *ETHDownloader, limit int, blocksByAddress map[common.Address][]*big.Int) (map[common.Address][]Transfer, error) {
return loadTransfers(ctx, c.accounts, c.block, c.db, c.chainClient, limit, blocksByAddress)
func (c *loadTransfersCommand) LoadTransfers(ctx context.Context, downloader *ETHDownloader, limit int, blocksByAddress map[common.Address][]*big.Int, transactionManager *TransactionManager) (map[common.Address][]Transfer, error) {
return loadTransfers(ctx, c.accounts, c.block, c.db, c.chainClient, limit, blocksByAddress, c.transactionManager)
}
func (c *loadTransfersCommand) Run(parent context.Context) (err error) {
@ -393,7 +424,7 @@ func (c *loadTransfersCommand) Run(parent context.Context) (err error) {
signer: types.NewLondonSigner(c.chainClient.ToBigInt()),
db: c.db,
}
transfersByAddress, err := c.LoadTransfers(parent, downloader, 40, c.blocksByAddress)
transfersByAddress, err := c.LoadTransfers(parent, downloader, 40, c.blocksByAddress, c.transactionManager)
if err != nil {
return err
}
@ -570,7 +601,7 @@ func (c *findAndCheckBlockRangeCommand) fastIndexErc20(ctx context.Context, from
}
}
func loadTransfers(ctx context.Context, accounts []common.Address, block *Block, db *Database, chainClient *chain.ClientWithFallback, limit int, blocksByAddress map[common.Address][]*big.Int) (map[common.Address][]Transfer, error) {
func loadTransfers(ctx context.Context, accounts []common.Address, block *Block, db *Database, chainClient *chain.ClientWithFallback, limit int, blocksByAddress map[common.Address][]*big.Int, transactionManager *TransactionManager) (map[common.Address][]Transfer, error) {
start := time.Now()
group := async.NewGroup(ctx)
@ -592,7 +623,8 @@ func loadTransfers(ctx context.Context, accounts []common.Address, block *Block,
signer: types.NewLondonSigner(chainClient.ToBigInt()),
db: db,
},
block: block,
block: block,
transactionManager: transactionManager,
}
commands = append(commands, transfers)
group.Add(transfers.Command())

View File

@ -17,24 +17,26 @@ import (
)
type Controller struct {
db *Database
rpcClient *rpc.Client
block *Block
reactor *Reactor
accountFeed *event.Feed
TransferFeed *event.Feed
group *async.Group
balanceCache *balanceCache
db *Database
rpcClient *rpc.Client
block *Block
reactor *Reactor
accountFeed *event.Feed
TransferFeed *event.Feed
group *async.Group
balanceCache *balanceCache
transactionManager *TransactionManager
}
func NewTransferController(db *sql.DB, rpcClient *rpc.Client, accountFeed *event.Feed, transferFeed *event.Feed) *Controller {
func NewTransferController(db *sql.DB, rpcClient *rpc.Client, accountFeed *event.Feed, transferFeed *event.Feed, transactionManager *TransactionManager) *Controller {
block := &Block{db}
return &Controller{
db: NewDB(db),
block: block,
rpcClient: rpcClient,
accountFeed: accountFeed,
TransferFeed: transferFeed,
db: NewDB(db),
block: block,
rpcClient: rpcClient,
accountFeed: accountFeed,
TransferFeed: transferFeed,
transactionManager: transactionManager,
}
}
@ -99,9 +101,10 @@ func (c *Controller) CheckRecentHistory(chainIDs []uint64, accounts []common.Add
}
c.reactor = &Reactor{
db: c.db,
feed: c.TransferFeed,
block: c.block,
db: c.db,
feed: c.TransferFeed,
block: c.block,
transactionManager: c.transactionManager,
}
err = c.reactor.start(chainClients, accounts)
if err != nil {
@ -114,7 +117,7 @@ func (c *Controller) CheckRecentHistory(chainIDs []uint64, accounts []common.Add
return nil
}
// watchAccountsChanges subsribes to a feed and watches for changes in accounts list. If there are new or removed accounts
// watchAccountsChanges subscribes to a feed and watches for changes in accounts list. If there are new or removed accounts
// reactor will be restarted.
func watchAccountsChanges(ctx context.Context, accountFeed *event.Feed, reactor *Reactor, chainClients []*chain.ClientWithFallback, initial []common.Address) error {
accounts := make(chan []*accounts.Account, 1) // it may block if the rate of updates will be significantly higher
@ -184,7 +187,7 @@ func (c *Controller) LoadTransferByHash(ctx context.Context, rpcClient *rpc.Clie
}
blocks := []*big.Int{transfer.BlockNumber}
err = c.db.SaveTranfers(rpcClient.UpstreamChainID, address, transfers, blocks)
err = c.db.SaveTransfers(rpcClient.UpstreamChainID, address, transfers, blocks)
if err != nil {
return err
}
@ -260,10 +263,11 @@ func (c *Controller) GetTransfersByAddress(ctx context.Context, chainID uint64,
log.Info("checking blocks again", "blocks", len(blocks))
if len(blocks) > 0 {
txCommand := &loadTransfersCommand{
accounts: []common.Address{address},
db: c.db,
block: c.block,
chainClient: chainClient,
accounts: []common.Address{address},
db: c.db,
block: c.block,
chainClient: chainClient,
transactionManager: c.transactionManager,
}
err = txCommand.Command()(ctx)

View File

@ -137,8 +137,8 @@ func (db *Database) SaveBlocks(chainID uint64, account common.Address, headers [
return
}
// ProcessTranfers atomically adds/removes blocks and adds new tranfers.
func (db *Database) ProcessTranfers(chainID uint64, transfers []Transfer, removed []*DBHeader) (err error) {
// ProcessTransfers atomically adds/removes blocks and adds new transfers.
func (db *Database) ProcessTransfers(chainID uint64, transfers []Transfer, removed []*DBHeader) (err error) {
var (
tx *sql.Tx
)
@ -153,10 +153,12 @@ func (db *Database) ProcessTranfers(chainID uint64, transfers []Transfer, remove
}
_ = tx.Rollback()
}()
err = deleteHeaders(tx, removed)
if err != nil {
return
}
err = updateOrInsertTransfers(chainID, tx, transfers)
if err != nil {
return
@ -164,11 +166,9 @@ func (db *Database) ProcessTranfers(chainID uint64, transfers []Transfer, remove
return
}
// SaveTranfers
func (db *Database) SaveTranfers(chainID uint64, address common.Address, transfers []Transfer, blocks []*big.Int) (err error) {
var (
tx *sql.Tx
)
// SaveTransfers
func (db *Database) SaveTransfers(chainID uint64, address common.Address, transfers []Transfer, blocks []*big.Int) (err error) {
var tx *sql.Tx
tx, err = db.client.Begin()
if err != nil {
return err
@ -389,8 +389,8 @@ func insertBlocksWithTransactions(chainID uint64, creator statementCreator, acco
}
insertTx, err := creator.Prepare(`INSERT OR IGNORE
INTO transfers (network_id, address, sender, hash, blk_number, blk_hash, type, timestamp, log, loaded)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 0)`)
INTO transfers (network_id, address, sender, hash, blk_number, blk_hash, type, timestamp, log, loaded, multi_transaction_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 0, ?)`)
if err != nil {
return err
}
@ -414,7 +414,7 @@ func insertBlocksWithTransactions(chainID uint64, creator statementCreator, acco
continue
}
_, err = insertTx.Exec(chainID, account, account, transfer.ID, (*bigint.SQLBigInt)(header.Number), header.Hash, erc20Transfer, transfer.Timestamp, &JSONBlob{transfer.Log})
_, err = insertTx.Exec(chainID, account, account, transfer.ID, (*bigint.SQLBigInt)(header.Number), header.Hash, erc20Transfer, transfer.Timestamp, &JSONBlob{transfer.Log}, transfer.MultiTransactionID)
if err != nil {
log.Error("error saving erc20transfer", "err", err)
return err
@ -434,9 +434,9 @@ func updateOrInsertTransfers(chainID uint64, creator statementCreator, transfers
}
insert, err := creator.Prepare(`INSERT OR IGNORE INTO transfers
(network_id, hash, blk_hash, blk_number, timestamp, address, tx, sender, receipt, log, type, loaded, base_gas_fee)
(network_id, hash, blk_hash, blk_number, timestamp, address, tx, sender, receipt, log, type, loaded, base_gas_fee, multi_transaction_id)
VALUES
(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 1, ?)`)
(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 1, ?, ?)`)
if err != nil {
return err
}
@ -454,7 +454,7 @@ func updateOrInsertTransfers(chainID uint64, creator statementCreator, transfers
continue
}
_, err = insert.Exec(chainID, t.ID, t.BlockHash, (*bigint.SQLBigInt)(t.BlockNumber), t.Timestamp, t.Address, &JSONBlob{t.Transaction}, t.From, &JSONBlob{t.Receipt}, &JSONBlob{t.Log}, t.Type, t.BaseGasFees)
_, err = insert.Exec(chainID, t.ID, t.BlockHash, (*bigint.SQLBigInt)(t.BlockNumber), t.Timestamp, t.Address, &JSONBlob{t.Transaction}, t.From, &JSONBlob{t.Receipt}, &JSONBlob{t.Log}, t.Type, t.BaseGasFees, t.MultiTransactionID)
if err != nil {
log.Error("can't save transfer", "b-hash", t.BlockHash, "b-n", t.BlockNumber, "a", t.Address, "h", t.ID)
return err

View File

@ -83,7 +83,7 @@ func TestDBProcessBlocks(t *testing.T) {
From: common.Address{1},
},
}
require.NoError(t, db.SaveTranfers(777, address, transfers, []*big.Int{big.NewInt(1), big.NewInt(2)}))
require.NoError(t, db.SaveTransfers(777, address, transfers, []*big.Int{big.NewInt(1), big.NewInt(2)}))
}
func TestDBProcessTransfer(t *testing.T) {
@ -97,13 +97,14 @@ func TestDBProcessTransfer(t *testing.T) {
tx := types.NewTransaction(1, common.Address{1}, nil, 10, big.NewInt(10), nil)
transfers := []Transfer{
{
ID: common.Hash{1},
Type: ethTransfer,
BlockHash: header.Hash,
BlockNumber: header.Number,
Transaction: tx,
Receipt: types.NewReceipt(nil, false, 100),
Address: common.Address{1},
ID: common.Hash{1},
Type: ethTransfer,
BlockHash: header.Hash,
BlockNumber: header.Number,
Transaction: tx,
Receipt: types.NewReceipt(nil, false, 100),
Address: common.Address{1},
MultiTransactionID: 0,
},
}
nonce := int64(0)
@ -113,7 +114,7 @@ func TestDBProcessTransfer(t *testing.T) {
Nonce: &nonce,
}
require.NoError(t, db.ProcessBlocks(777, common.Address{1}, big.NewInt(1), lastBlock, []*DBHeader{header}))
require.NoError(t, db.ProcessTranfers(777, transfers, []*DBHeader{}))
require.NoError(t, db.ProcessTransfers(777, transfers, []*DBHeader{}))
}
func TestDBReorgTransfers(t *testing.T) {
@ -140,8 +141,8 @@ func TestDBReorgTransfers(t *testing.T) {
Nonce: &nonce,
}
require.NoError(t, db.ProcessBlocks(777, original.Address, original.Number, lastBlock, []*DBHeader{original}))
require.NoError(t, db.ProcessTranfers(777, []Transfer{
{ethTransfer, common.Hash{1}, *originalTX.To(), original.Number, original.Hash, 100, originalTX, true, 1777, common.Address{1}, rcpt, nil, "2100"},
require.NoError(t, db.ProcessTransfers(777, []Transfer{
{ethTransfer, common.Hash{1}, *originalTX.To(), original.Number, original.Hash, 100, originalTX, true, 1777, common.Address{1}, rcpt, nil, "2100", NoMultiTransactionID},
}, []*DBHeader{}))
nonce = int64(0)
lastBlock = &LastKnownBlock{
@ -150,8 +151,8 @@ func TestDBReorgTransfers(t *testing.T) {
Nonce: &nonce,
}
require.NoError(t, db.ProcessBlocks(777, replaced.Address, replaced.Number, lastBlock, []*DBHeader{replaced}))
require.NoError(t, db.ProcessTranfers(777, []Transfer{
{ethTransfer, common.Hash{2}, *replacedTX.To(), replaced.Number, replaced.Hash, 100, replacedTX, true, 1777, common.Address{1}, rcpt, nil, "2100"},
require.NoError(t, db.ProcessTransfers(777, []Transfer{
{ethTransfer, common.Hash{2}, *replacedTX.To(), replaced.Number, replaced.Hash, 100, replacedTX, true, 1777, common.Address{1}, rcpt, nil, "2100", NoMultiTransactionID},
}, []*DBHeader{original}))
all, err := db.GetTransfers(777, big.NewInt(0), nil)
@ -193,7 +194,7 @@ func TestDBGetTransfersFromBlock(t *testing.T) {
Nonce: &nonce,
}
require.NoError(t, db.ProcessBlocks(777, headers[0].Address, headers[0].Number, lastBlock, headers))
require.NoError(t, db.ProcessTranfers(777, transfers, []*DBHeader{}))
require.NoError(t, db.ProcessTransfers(777, transfers, []*DBHeader{}))
rst, err := db.GetTransfers(777, big.NewInt(7), nil)
require.NoError(t, err)
require.Len(t, rst, 3)

View File

@ -18,12 +18,15 @@ import (
// Type type of the asset that was transferred.
type Type string
type MultiTransactionIDType int64
const (
ethTransfer Type = "eth"
erc20Transfer Type = "erc20"
erc20TransferEventSignature = "Transfer(address,address,uint256)"
NoMultiTransactionID = MultiTransactionIDType(0)
)
var (
@ -49,6 +52,8 @@ type Transfer struct {
// Log that was used to generate erc20 transfer. Nil for eth transfer.
Log *types.Log `json:"log"`
BaseGasFees string
// Internal field that is used to track multi-transaction transfers.
MultiTransactionID MultiTransactionIDType `json:"multi_transaction_id"`
}
// ETHDownloader downloads regular eth transfers.

View File

@ -10,7 +10,7 @@ import (
"github.com/status-im/status-go/services/wallet/bigint"
)
const baseTransfersQuery = "SELECT hash, type, blk_hash, blk_number, timestamp, address, tx, sender, receipt, log, network_id, base_gas_fee FROM transfers"
const baseTransfersQuery = "SELECT hash, type, blk_hash, blk_number, timestamp, address, tx, sender, receipt, log, network_id, base_gas_fee, COALESCE(multi_transaction_id, 0) FROM transfers"
func newTransfersQuery() *transfersQuery {
buf := bytes.NewBuffer(nil)
@ -119,7 +119,7 @@ func (q *transfersQuery) Scan(rows *sql.Rows) (rst []Transfer, err error) {
err = rows.Scan(
&transfer.ID, &transfer.Type, &transfer.BlockHash,
(*bigint.SQLBigInt)(transfer.BlockNumber), &transfer.Timestamp, &transfer.Address,
&JSONBlob{transfer.Transaction}, &transfer.From, &JSONBlob{transfer.Receipt}, &JSONBlob{transfer.Log}, &transfer.NetworkID, &transfer.BaseGasFees)
&JSONBlob{transfer.Transaction}, &transfer.From, &JSONBlob{transfer.Receipt}, &JSONBlob{transfer.Log}, &transfer.NetworkID, &transfer.BaseGasFees, &transfer.MultiTransactionID)
if err != nil {
return nil, err
}

View File

@ -30,11 +30,12 @@ type BalanceReader interface {
// Reactor listens to new blocks and stores transfers into the database.
type Reactor struct {
db *Database
block *Block
feed *event.Feed
mu sync.Mutex
group *async.Group
db *Database
block *Block
feed *event.Feed
mu sync.Mutex
group *async.Group
transactionManager *TransactionManager
}
func (r *Reactor) newControlCommand(chainClient *chain.ClientWithFallback, accounts []common.Address) *controlCommand {
@ -50,9 +51,10 @@ func (r *Reactor) newControlCommand(chainClient *chain.ClientWithFallback, accou
signer: signer,
db: r.db,
},
erc20: NewERC20TransfersDownloader(chainClient, accounts, signer),
feed: r.feed,
errorsCount: 0,
erc20: NewERC20TransfersDownloader(chainClient, accounts, signer),
feed: r.feed,
errorsCount: 0,
transactionManager: r.transactionManager,
}
return ctl

View File

@ -1,4 +1,4 @@
package wallet
package transfer
import (
"context"
@ -31,6 +31,17 @@ type TransactionManager struct {
accountsDB *accounts.Database
}
func NewTransactionManager(db *sql.DB, gethManager *account.GethManager, transactor *transactions.Transactor,
config *params.NodeConfig, accountsDB *accounts.Database) *TransactionManager {
return &TransactionManager{
db: db,
gethManager: gethManager,
transactor: transactor,
config: config,
accountsDB: accountsDB,
}
}
type MultiTransactionType uint8
const (
@ -67,44 +78,28 @@ const (
)
type PendingTransaction struct {
Hash common.Hash `json:"hash"`
Timestamp uint64 `json:"timestamp"`
Value bigint.BigInt `json:"value"`
From common.Address `json:"from"`
To common.Address `json:"to"`
Data string `json:"data"`
Symbol string `json:"symbol"`
GasPrice bigint.BigInt `json:"gasPrice"`
GasLimit bigint.BigInt `json:"gasLimit"`
Type PendingTrxType `json:"type"`
AdditionalData string `json:"additionalData"`
ChainID uint64 `json:"network_id"`
MultiTransactionID int64 `json:"multi_transaction_id"`
Hash common.Hash `json:"hash"`
Timestamp uint64 `json:"timestamp"`
Value bigint.BigInt `json:"value"`
From common.Address `json:"from"`
To common.Address `json:"to"`
Data string `json:"data"`
Symbol string `json:"symbol"`
GasPrice bigint.BigInt `json:"gasPrice"`
GasLimit bigint.BigInt `json:"gasLimit"`
Type PendingTrxType `json:"type"`
AdditionalData string `json:"additionalData"`
ChainID uint64 `json:"network_id"`
MultiTransactionID MultiTransactionIDType `json:"multi_transaction_id"`
}
func (tm *TransactionManager) getAllPendings(chainIDs []uint64) ([]*PendingTransaction, error) {
if len(chainIDs) == 0 {
return nil, errors.New("at least 1 chainID is required")
}
const selectFromPending = `SELECT hash, timestamp, value, from_address, to_address, data,
symbol, gas_price, gas_limit, type, additional_data,
network_id, COALESCE(multi_transaction_id, 0)
FROM pending_transactions
`
inVector := strings.Repeat("?, ", len(chainIDs)-1) + "?"
var parameters []interface{}
for _, c := range chainIDs {
parameters = append(parameters, c)
}
rows, err := tm.db.Query(fmt.Sprintf(`SELECT hash, timestamp, value, from_address, to_address, data,
symbol, gas_price, gas_limit, type, additional_data,
network_id
FROM pending_transactions
WHERE network_id in (%s)`, inVector), parameters...)
if err != nil {
return nil, err
}
defer rows.Close()
var transactions []*PendingTransaction
func rowsToTransactions(rows *sql.Rows) (transactions []*PendingTransaction, err error) {
for rows.Next() {
transaction := &PendingTransaction{
Value: bigint.BigInt{Int: new(big.Int)},
@ -123,6 +118,7 @@ func (tm *TransactionManager) getAllPendings(chainIDs []uint64) ([]*PendingTrans
&transaction.Type,
&transaction.AdditionalData,
&transaction.ChainID,
&transaction.MultiTransactionID,
)
if err != nil {
return nil, err
@ -130,11 +126,30 @@ func (tm *TransactionManager) getAllPendings(chainIDs []uint64) ([]*PendingTrans
transactions = append(transactions, transaction)
}
return transactions, nil
}
func (tm *TransactionManager) getPendingByAddress(chainIDs []uint64, address common.Address) ([]*PendingTransaction, error) {
func (tm *TransactionManager) GetAllPending(chainIDs []uint64) ([]*PendingTransaction, error) {
if len(chainIDs) == 0 {
return nil, errors.New("at least 1 chainID is required")
}
inVector := strings.Repeat("?, ", len(chainIDs)-1) + "?"
var parameters []interface{}
for _, c := range chainIDs {
parameters = append(parameters, c)
}
rows, err := tm.db.Query(fmt.Sprintf(selectFromPending+"WHERE network_id in (%s)", inVector), parameters...)
if err != nil {
return nil, err
}
defer rows.Close()
return rowsToTransactions(rows)
}
func (tm *TransactionManager) GetPendingByAddress(chainIDs []uint64, address common.Address) ([]*PendingTransaction, error) {
if len(chainIDs) == 0 {
return nil, errors.New("at least 1 chainID is required")
}
@ -147,52 +162,56 @@ func (tm *TransactionManager) getPendingByAddress(chainIDs []uint64, address com
parameters = append(parameters, address)
rows, err := tm.db.Query(fmt.Sprintf(`SELECT hash, timestamp, value, from_address, to_address, data,
symbol, gas_price, gas_limit, type, additional_data,
network_id
FROM pending_transactions
WHERE network_id in (%s) AND from_address = ?`, inVector), parameters...)
rows, err := tm.db.Query(fmt.Sprintf(selectFromPending+"WHERE network_id in (%s) AND from_address = ?", inVector), parameters...)
if err != nil {
return nil, err
}
defer rows.Close()
var transactions []*PendingTransaction
for rows.Next() {
transaction := &PendingTransaction{
Value: bigint.BigInt{Int: new(big.Int)},
GasPrice: bigint.BigInt{Int: new(big.Int)},
GasLimit: bigint.BigInt{Int: new(big.Int)},
}
err := rows.Scan(&transaction.Hash,
&transaction.Timestamp,
(*bigint.SQLBigIntBytes)(transaction.Value.Int),
&transaction.From,
&transaction.To,
&transaction.Data,
&transaction.Symbol,
(*bigint.SQLBigIntBytes)(transaction.GasPrice.Int),
(*bigint.SQLBigIntBytes)(transaction.GasLimit.Int),
&transaction.Type,
&transaction.AdditionalData,
&transaction.ChainID,
)
if err != nil {
return nil, err
}
transactions = append(transactions, transaction)
}
return transactions, nil
return rowsToTransactions(rows)
}
func (tm *TransactionManager) addPending(transaction PendingTransaction) error {
// GetPendingEntry returns sql.ErrNoRows if no pending transaction is found for the given identity
func (tm *TransactionManager) GetPendingEntry(chainID uint64, hash common.Hash) (*PendingTransaction, error) {
row := tm.db.QueryRow(`SELECT timestamp, value, from_address, to_address, data,
symbol, gas_price, gas_limit, type, additional_data,
network_id, COALESCE(multi_transaction_id, 0)
FROM pending_transactions
WHERE network_id = ? AND hash = ?`, chainID, hash)
transaction := &PendingTransaction{
Hash: hash,
Value: bigint.BigInt{Int: new(big.Int)},
GasPrice: bigint.BigInt{Int: new(big.Int)},
GasLimit: bigint.BigInt{Int: new(big.Int)},
ChainID: chainID,
}
err := row.Scan(
&transaction.Timestamp,
(*bigint.SQLBigIntBytes)(transaction.Value.Int),
&transaction.From,
&transaction.To,
&transaction.Data,
&transaction.Symbol,
(*bigint.SQLBigIntBytes)(transaction.GasPrice.Int),
(*bigint.SQLBigIntBytes)(transaction.GasLimit.Int),
&transaction.Type,
&transaction.AdditionalData,
&transaction.ChainID,
&transaction.MultiTransactionID,
)
if err != nil {
return nil, err
}
return transaction, nil
}
func (tm *TransactionManager) AddPending(transaction PendingTransaction) error {
insert, err := tm.db.Prepare(`INSERT OR REPLACE INTO pending_transactions
(network_id, hash, timestamp, value, from_address, to_address,
data, symbol, gas_price, gas_limit, type, additional_data)
data, symbol, gas_price, gas_limit, type, additional_data, multi_transaction_id)
VALUES
(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
if err != nil {
return err
}
@ -209,16 +228,17 @@ func (tm *TransactionManager) addPending(transaction PendingTransaction) error {
(*bigint.SQLBigIntBytes)(transaction.GasLimit.Int),
transaction.Type,
transaction.AdditionalData,
transaction.MultiTransactionID,
)
return err
}
func (tm *TransactionManager) deletePending(chainID uint64, hash common.Hash) error {
func (tm *TransactionManager) DeletePending(chainID uint64, hash common.Hash) error {
_, err := tm.db.Exec(`DELETE FROM pending_transactions WHERE network_id = ? AND hash = ?`, chainID, hash)
return err
}
func (tm *TransactionManager) watch(ctx context.Context, transactionHash common.Hash, client *chain.ClientWithFallback) error {
func (tm *TransactionManager) Watch(ctx context.Context, transactionHash common.Hash, client *chain.ClientWithFallback) error {
watchTxCommand := &watchTransactionCommand{
hash: transactionHash,
client: client,
@ -230,7 +250,7 @@ func (tm *TransactionManager) watch(ctx context.Context, transactionHash common.
return watchTxCommand.Command()(commandContext)
}
func (tm *TransactionManager) createMultiTransaction(ctx context.Context, multiTransaction *MultiTransaction, data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, password string) (*MultiTransactionResult, error) {
func (tm *TransactionManager) CreateMultiTransaction(ctx context.Context, multiTransaction *MultiTransaction, data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, password string) (*MultiTransactionResult, error) {
selectedAccount, err := tm.getVerifiedWalletAccount(multiTransaction.FromAddress.Hex(), password)
if err != nil {
return nil, err
@ -267,7 +287,7 @@ func (tm *TransactionManager) createMultiTransaction(ctx context.Context, multiT
if err != nil {
return nil, err
}
err = tm.addPending(PendingTransaction{
pendingTransaction := PendingTransaction{
Hash: common.Hash(hash),
Timestamp: uint64(time.Now().Unix()),
Value: bigint.BigInt{Int: multiTransaction.FromAmount.ToInt()},
@ -276,9 +296,10 @@ func (tm *TransactionManager) createMultiTransaction(ctx context.Context, multiT
Data: tx.Data().String(),
Type: WalletTransfer,
ChainID: tx.ChainID,
MultiTransactionID: multiTransactionID,
MultiTransactionID: MultiTransactionIDType(multiTransactionID),
Symbol: multiTransaction.FromAsset,
})
}
err = tm.AddPending(pendingTransaction)
if err != nil {
return nil, err
}

View File

@ -1,4 +1,4 @@
package wallet
package transfer
import (
"io/ioutil"
@ -42,39 +42,39 @@ func TestPendingTransactions(t *testing.T) {
ChainID: 777,
}
rst, err := manager.getAllPendings([]uint64{777})
rst, err := manager.GetAllPending([]uint64{777})
require.NoError(t, err)
require.Nil(t, rst)
rst, err = manager.getPendingByAddress([]uint64{777}, trx.From)
rst, err = manager.GetPendingByAddress([]uint64{777}, trx.From)
require.NoError(t, err)
require.Nil(t, rst)
err = manager.addPending(trx)
err = manager.AddPending(trx)
require.NoError(t, err)
rst, err = manager.getPendingByAddress([]uint64{777}, trx.From)
rst, err = manager.GetPendingByAddress([]uint64{777}, trx.From)
require.NoError(t, err)
require.Equal(t, 1, len(rst))
require.Equal(t, trx, *rst[0])
rst, err = manager.getAllPendings([]uint64{777})
rst, err = manager.GetAllPending([]uint64{777})
require.NoError(t, err)
require.Equal(t, 1, len(rst))
require.Equal(t, trx, *rst[0])
rst, err = manager.getPendingByAddress([]uint64{777}, common.Address{2})
rst, err = manager.GetPendingByAddress([]uint64{777}, common.Address{2})
require.NoError(t, err)
require.Nil(t, rst)
err = manager.deletePending(777, trx.Hash)
err = manager.DeletePending(777, trx.Hash)
require.NoError(t, err)
rst, err = manager.getPendingByAddress([]uint64{777}, trx.From)
rst, err = manager.GetPendingByAddress([]uint64{777}, trx.From)
require.NoError(t, err)
require.Equal(t, 0, len(rst))
rst, err = manager.getAllPendings([]uint64{777})
rst, err = manager.GetAllPending([]uint64{777})
require.NoError(t, err)
require.Equal(t, 0, len(rst))
}

View File

@ -34,7 +34,7 @@ type View struct {
To common.Address `json:"to"`
Contract common.Address `json:"contract"`
NetworkID uint64 `json:"networkId"`
MultiTransactionID int64 `json:"multi_transaction_id"`
MultiTransactionID int64 `json:"multiTransactionID"`
BaseGasFees string `json:"base_gas_fee"`
}
@ -87,6 +87,8 @@ func CastToTransferView(t Transfer) View {
from, to, amount := parseLog(t.Log)
view.From, view.To, view.Value = from, to, (*hexutil.Big)(amount)
}
view.MultiTransactionID = int64(t.MultiTransactionID)
return view
}