From 892fcffce479fa315136a7be15b6d566473649a0 Mon Sep 17 00:00:00 2001 From: Godfrain Jacques Date: Tue, 11 Jun 2024 14:00:04 -0700 Subject: [PATCH] chore(communities)_: make member use wallet tokens during permission checking (#5268) fixes #14913 --- protocol/communities/manager.go | 13 +++- protocol/communities/manager_test.go | 8 +++ protocol/communities/permission_checker.go | 41 ++++++++++--- .../communities_messenger_helpers_test.go | 10 +++ services/wallet/api.go | 6 ++ services/wallet/collectibles/manager.go | 44 +++++++++++-- services/wallet/collectibles/ownership_db.go | 57 +++++++++++++++++ services/wallet/token/token.go | 61 +++++++++++++++++++ 8 files changed, 223 insertions(+), 17 deletions(-) diff --git a/protocol/communities/manager.go b/protocol/communities/manager.go index faaf7e9fc..3cac3ab86 100644 --- a/protocol/communities/manager.go +++ b/protocol/communities/manager.go @@ -243,7 +243,8 @@ type managerOptions struct { } type TokenManager interface { - GetBalancesByChain(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big, error) + GetBalancesByChain(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) + GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *token.Token GetAllChainIDs() ([]uint64, error) } @@ -301,6 +302,7 @@ func (m *DefaultTokenManager) GetAllChainIDs() ([]uint64, error) { type CollectiblesManager interface { FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletcommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) + FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletcommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) GetCollectibleOwnership(id thirdparty.CollectibleUniqueID) ([]thirdparty.AccountBalance, error) FetchCollectibleOwnersByContractAddress(ctx context.Context, chainID walletcommon.ChainID, contractAddress gethcommon.Address) (*thirdparty.CollectibleContractOwnership, error) } @@ -315,6 +317,15 @@ func (m *DefaultTokenManager) GetBalancesByChain(ctx context.Context, accounts, return resp, err } +func (m *DefaultTokenManager) GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) { + resp, err := m.tokenManager.GetCachedBalancesByChain(accounts, tokenAddresses, chainIDs) + if err != nil { + return resp, err + } + + return resp, nil +} + func (m *DefaultTokenManager) FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *token.Token { return m.tokenManager.FindOrCreateTokenByAddress(ctx, chainID, address) } diff --git a/protocol/communities/manager_test.go b/protocol/communities/manager_test.go index 1bae36ea1..f3122e530 100644 --- a/protocol/communities/manager_test.go +++ b/protocol/communities/manager_test.go @@ -163,6 +163,10 @@ func (m *testCollectiblesManager) FetchCollectibleOwnersByContractAddress(ctx co return ret, nil } +func (m *testCollectiblesManager) FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) { + return m.response[uint64(chainID)][ownerAddress], nil +} + type testTokenManager struct { response map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big } @@ -193,6 +197,10 @@ func (m *testTokenManager) GetBalancesByChain(ctx context.Context, accounts, tok return m.response, nil } +func (m *testTokenManager) GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) { + return m.response, nil +} + func (m *testTokenManager) FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *token.Token { return nil } diff --git a/protocol/communities/permission_checker.go b/protocol/communities/permission_checker.go index 08980164c..8c5f63191 100644 --- a/protocol/communities/permission_checker.go +++ b/protocol/communities/permission_checker.go @@ -176,7 +176,13 @@ func (p *DefaultPermissionChecker) CheckPermissionToJoin(community *Community, a } // If there are any admin or token master permissions, combine result. preParsedPermissions := preParsedCommunityPermissionsData(adminOrTokenMasterPermissionsToJoin) - adminOrTokenPermissionsResponse, err := p.CheckPermissions(preParsedPermissions, accountsAndChainIDs, false) + var adminOrTokenPermissionsResponse *CheckPermissionsResponse + + if community.IsControlNode() { + adminOrTokenPermissionsResponse, err = p.CheckPermissions(preParsedPermissions, accountsAndChainIDs, false) + } else { + adminOrTokenPermissionsResponse, err = p.CheckCachedPermissions(preParsedPermissions, accountsAndChainIDs, false) + } if err != nil { return nil, err } @@ -209,12 +215,14 @@ func (p *DefaultPermissionChecker) checkPermissionsOrDefault(permissions []*Comm } preParsedPermissions := preParsedCommunityPermissionsData(permissions) - return p.CheckPermissions(preParsedPermissions, accountsAndChainIDs, false) + return p.CheckCachedPermissions(preParsedPermissions, accountsAndChainIDs, false) } type ownedERC721TokensGetter = func(walletAddresses []gethcommon.Address, tokenRequirements map[uint64]map[string]*protobuf.TokenCriteria, chainIDs []uint64) (CollectiblesByChain, error) +type balancesByChainGetter = func(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) -func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool, getOwnedERC721Tokens ownedERC721TokensGetter) (*CheckPermissionsResponse, error) { +func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool, + getOwnedERC721Tokens ownedERC721TokensGetter, getBalancesByChain balancesByChainGetter) (*CheckPermissionsResponse, error) { response := &CheckPermissionsResponse{ Satisfied: false, @@ -254,7 +262,7 @@ func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PrePa ownedERC20TokenBalances := make(map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big, 0) if len(chainIDsForERC20) > 0 { // this only returns balances for the networks we're actually interested in - balances, err := p.tokenManager.GetBalancesByChain(context.Background(), accounts, erc20TokenAddresses, chainIDsForERC20) + balances, err := getBalancesByChain(context.Background(), accounts, erc20TokenAddresses, chainIDsForERC20) if err != nil { return nil, err } @@ -448,15 +456,28 @@ func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PrePa return response, nil } +type balancesByOwnerAndContractAddressGetter = func(ctx context.Context, chainID walletcommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (map[gethcommon.Address][]thirdparty.TokenBalance, error) + +func (p *DefaultPermissionChecker) handlePermissionsCheck(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool, + getBalancesByOwnerAndContractAddress balancesByOwnerAndContractAddressGetter, + getBalancesByChain balancesByChainGetter) (*CheckPermissionsResponse, error) { + + var getOwnedERC721Tokens ownedERC721TokensGetter = func(walletAddresses []gethcommon.Address, tokenRequirements map[uint64]map[string]*protobuf.TokenCriteria, chainIDs []uint64) (CollectiblesByChain, error) { + return p.getOwnedERC721Tokens(walletAddresses, tokenRequirements, chainIDs, getBalancesByOwnerAndContractAddress) + } + + return p.checkPermissions(permissionsParsedData, accountsAndChainIDs, shortcircuit, getOwnedERC721Tokens, getBalancesByChain) +} + +func (p *DefaultPermissionChecker) CheckCachedPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool) (*CheckPermissionsResponse, error) { + return p.handlePermissionsCheck(permissionsParsedData, accountsAndChainIDs, shortcircuit, p.collectiblesManager.FetchCachedBalancesByOwnerAndContractAddress, p.tokenManager.GetCachedBalancesByChain) +} + // CheckPermissions will retrieve balances and check whether the user has // permission to join the community, if shortcircuit is true, it will stop as soon // as we know the answer func (p *DefaultPermissionChecker) CheckPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool) (*CheckPermissionsResponse, error) { - var getOwnedERC721Tokens ownedERC721TokensGetter = func(walletAddresses []gethcommon.Address, tokenRequirements map[uint64]map[string]*protobuf.TokenCriteria, chainIDs []uint64) (CollectiblesByChain, error) { - return p.getOwnedERC721Tokens(walletAddresses, tokenRequirements, chainIDs, p.collectiblesManager.FetchBalancesByOwnerAndContractAddress) - } - - return p.checkPermissions(permissionsParsedData, accountsAndChainIDs, shortcircuit, getOwnedERC721Tokens) + return p.handlePermissionsCheck(permissionsParsedData, accountsAndChainIDs, shortcircuit, p.collectiblesManager.FetchBalancesByOwnerAndContractAddress, p.tokenManager.GetBalancesByChain) } type CollectiblesOwners = map[walletcommon.ChainID]map[gethcommon.Address]*thirdparty.CollectibleContractOwnership @@ -492,7 +513,7 @@ func (p *DefaultPermissionChecker) CheckPermissionsWithPreFetchedData(permission return p.getOwnedERC721Tokens(walletAddresses, tokenRequirements, chainIDs, getCollectiblesBalances) } - return p.checkPermissions(permissionsParsedData, accountsAndChainIDs, shortcircuit, getOwnedERC721Tokens) + return p.checkPermissions(permissionsParsedData, accountsAndChainIDs, shortcircuit, getOwnedERC721Tokens, p.tokenManager.GetBalancesByChain) } func preParsedPermissionsData(permissions []*CommunityTokenPermission) *PreParsedPermissionsData { diff --git a/protocol/communities_messenger_helpers_test.go b/protocol/communities_messenger_helpers_test.go index 14975231b..be4437182 100644 --- a/protocol/communities_messenger_helpers_test.go +++ b/protocol/communities_messenger_helpers_test.go @@ -71,6 +71,11 @@ func (m *TokenManagerMock) GetBalancesByChain(ctx context.Context, accounts, tok return *m.Balances, nil } +func (m *TokenManagerMock) GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big, error) { + time.Sleep(100 * time.Millisecond) // simulate response time + return *m.Balances, nil +} + func (m *TokenManagerMock) FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *walletToken.Token { time.Sleep(100 * time.Millisecond) // simulate response time return nil @@ -81,6 +86,11 @@ type CollectiblesManagerMock struct { collectibleOwnershipResponse map[string][]thirdparty.AccountBalance } +func (m *CollectiblesManagerMock) FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, + ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) { + return m.FetchBalancesByOwnerAndContractAddress(ctx, chainID, ownerAddress, contractAddresses) +} + func (m *CollectiblesManagerMock) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) { ret := make(thirdparty.TokenBalancesPerContractAddress) diff --git a/services/wallet/api.go b/services/wallet/api.go index f89d12835..7a9527334 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -284,6 +284,12 @@ func (api *API) GetCryptoOnRamps(ctx context.Context) ([]onramp.CryptoOnRamp, er Collectibles API Start */ +func (api *API) FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID wcommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) { + log.Debug("call to FetchCachedBalancesByOwnerAndContractAddress") + + return api.s.collectiblesManager.FetchCachedBalancesByOwnerAndContractAddress(ctx, chainID, ownerAddress, contractAddresses) +} + func (api *API) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID wcommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) { log.Debug("call to FetchBalancesByOwnerAndContractAddress") diff --git a/services/wallet/collectibles/manager.go b/services/wallet/collectibles/manager.go index d3811543d..457300eb4 100644 --- a/services/wallet/collectibles/manager.go +++ b/services/wallet/collectibles/manager.go @@ -125,6 +125,42 @@ func (o *Manager) doContentTypeRequest(ctx context.Context, url string) (string, return resp.Header.Get("Content-Type"), nil } +func (o *Manager) getTokenBalancesByOwnerAddress(collectibles *thirdparty.CollectibleContractOwnership, ownerAddress common.Address) map[common.Address][]thirdparty.TokenBalance { + ret := make(map[common.Address][]thirdparty.TokenBalance) + + for _, nftOwner := range collectibles.Owners { + if nftOwner.OwnerAddress == ownerAddress { + ret[collectibles.ContractAddress] = nftOwner.TokenBalances + break + } + } + + return ret +} + +func (o *Manager) FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) { + ret := make(map[common.Address][]thirdparty.TokenBalance) + + for _, contractAddress := range contractAddresses { + ret[contractAddress] = make([]thirdparty.TokenBalance, 0) + } + + for _, contractAddress := range contractAddresses { + ownership, err := o.ownershipDB.FetchCachedCollectibleOwnersByContractAddress(chainID, contractAddress) + if err != nil { + return nil, err + } + + t := o.getTokenBalancesByOwnerAddress(ownership, ownerAddress) + + for address, tokenBalances := range t { + ret[address] = append(ret[address], tokenBalances...) + } + } + + return ret, nil +} + // Need to combine different providers to support all needed ChainIDs func (o *Manager) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) { ret := make(thirdparty.TokenBalancesPerContractAddress) @@ -142,12 +178,8 @@ func (o *Manager) FetchBalancesByOwnerAndContractAddress(ctx context.Context, ch if err != nil { return nil, err } - for _, nftOwner := range ownership.Owners { - if nftOwner.OwnerAddress == ownerAddress { - ret[contractAddress] = nftOwner.TokenBalances - break - } - } + + ret = o.getTokenBalancesByOwnerAddress(ownership, ownerAddress) } } else if err == nil { // Account ownership providers succeeded diff --git a/services/wallet/collectibles/ownership_db.go b/services/wallet/collectibles/ownership_db.go index 50e5b65e8..94253b751 100644 --- a/services/wallet/collectibles/ownership_db.go +++ b/services/wallet/collectibles/ownership_db.go @@ -34,6 +34,8 @@ const unknownUpdateTimestamp = int64(math.MaxInt64) const selectOwnershipColumns = "chain_id, contract_address, token_id" +const collectiblesOwnershipColumns = "token_id, owner_address, balance" + const ownershipTimestampColumns = "owner_address, chain_id, timestamp" const selectOwnershipTimestampColumns = "timestamp" @@ -420,6 +422,61 @@ func (o *OwnershipDB) GetOwnedCollectibles(chainIDs []w_common.ChainID, ownerAdd return thirdparty.RowsToCollectibles(rows) } +func (o *OwnershipDB) FetchCachedCollectibleOwnersByContractAddress(chainID w_common.ChainID, contractAddress common.Address) (*thirdparty.CollectibleContractOwnership, error) { + query, args, err := sqlx.In(fmt.Sprintf(`SELECT %s + FROM collectibles_ownership_cache + WHERE chain_id = ? AND contract_address = ?`, collectiblesOwnershipColumns), chainID, contractAddress) + if err != nil { + return nil, err + } + + var ret thirdparty.CollectibleContractOwnership + + stmt, err := o.db.Prepare(query) + if err != nil { + return nil, err + } + defer stmt.Close() + + rows, err := stmt.Query(args...) + if err != nil { + return nil, err + } + defer rows.Close() + + tokenID := &bigint.BigInt{Int: big.NewInt(0)} + var ownerAddress common.Address + balance := &bigint.BigInt{Int: big.NewInt(0)} + var tokenBalances []thirdparty.TokenBalance + + for rows.Next() { + err = rows.Scan( + (*bigint.SQLBigIntBytes)(tokenID.Int), + &ownerAddress, + (*bigint.SQLBigIntBytes)(balance.Int), + ) + if err != nil { + return nil, err + } + + tokenBalance := thirdparty.TokenBalance{ + TokenID: tokenID, + Balance: balance, + } + tokenBalances = append(tokenBalances, tokenBalance) + + collectibleOwner := thirdparty.CollectibleOwner{ + OwnerAddress: ownerAddress, + TokenBalances: tokenBalances, + } + + ret.ContractAddress = contractAddress + ret.Owners = append(ret.Owners, collectibleOwner) + } + + return &ret, nil +} + func (o *OwnershipDB) GetOwnedCollectible(chainID w_common.ChainID, ownerAddresses common.Address, contractAddress common.Address, tokenID *big.Int) (*thirdparty.CollectibleUniqueID, error) { query := fmt.Sprintf(`SELECT %s FROM collectibles_ownership_cache diff --git a/services/wallet/token/token.go b/services/wallet/token/token.go index 896b2781a..efc92274c 100644 --- a/services/wallet/token/token.go +++ b/services/wallet/token/token.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "errors" + "fmt" "math/big" "strconv" "strings" @@ -964,3 +965,63 @@ func (tm *Manager) onAccountsChange(changedAddresses []common.Address, eventType } } } + +func (tm *Manager) GetCachedBalancesByChain(accounts, tokenAddresses []common.Address, chainIDs []uint64) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error) { + accountStrings := make([]string, len(accounts)) + for i, account := range accounts { + accountStrings[i] = fmt.Sprintf("'%s'", account.Hex()) + } + + tokenAddressStrings := make([]string, len(tokenAddresses)) + for i, tokenAddress := range tokenAddresses { + tokenAddressStrings[i] = fmt.Sprintf("'%s'", tokenAddress.Hex()) + } + + chainIDStrings := make([]string, len(chainIDs)) + for i, chainID := range chainIDs { + chainIDStrings[i] = fmt.Sprintf("%d", chainID) + } + + query := `SELECT chain_id, user_address, token_address, balance + FROM token_balances + WHERE user_address IN (` + strings.Join(accountStrings, ",") + `) + AND token_address IN (` + strings.Join(tokenAddressStrings, ",") + `) + AND chain_id IN (` + strings.Join(chainIDStrings, ",") + `)` + + rows, err := tm.db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + + ret := make(map[uint64]map[common.Address]map[common.Address]*hexutil.Big) + + for rows.Next() { + var chainID uint64 + var userAddressStr, tokenAddressStr string + var balanceStr string + + err := rows.Scan(&chainID, &userAddressStr, &tokenAddressStr, &balanceStr) + if err != nil { + return nil, err + } + + num := new(hexutil.Big) + _, ok := num.ToInt().SetString(balanceStr, 0) + if !ok { + return ret, nil + } + + if ret[chainID] == nil { + ret[chainID] = make(map[common.Address]map[common.Address]*hexutil.Big) + } + + if ret[chainID][common.HexToAddress(userAddressStr)] == nil { + ret[chainID][common.HexToAddress(userAddressStr)] = make(map[common.Address]*hexutil.Big) + } + + ret[chainID][common.HexToAddress(userAddressStr)][common.HexToAddress(tokenAddressStr)] = num + } + + return ret, nil +}