chore(communities)_: make member use wallet tokens during permission checking (#5268)

fixes #14913
This commit is contained in:
Godfrain Jacques 2024-06-11 14:00:04 -07:00 committed by GitHub
parent 9ffe842acc
commit 892fcffce4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 223 additions and 17 deletions

View File

@ -243,7 +243,8 @@ type managerOptions struct {
}
type TokenManager interface {
GetBalancesByChain(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big, error)
GetBalancesByChain(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error)
GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error)
FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *token.Token
GetAllChainIDs() ([]uint64, error)
}
@ -301,6 +302,7 @@ func (m *DefaultTokenManager) GetAllChainIDs() ([]uint64, error) {
type CollectiblesManager interface {
FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletcommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error)
FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletcommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error)
GetCollectibleOwnership(id thirdparty.CollectibleUniqueID) ([]thirdparty.AccountBalance, error)
FetchCollectibleOwnersByContractAddress(ctx context.Context, chainID walletcommon.ChainID, contractAddress gethcommon.Address) (*thirdparty.CollectibleContractOwnership, error)
}
@ -315,6 +317,15 @@ func (m *DefaultTokenManager) GetBalancesByChain(ctx context.Context, accounts,
return resp, err
}
func (m *DefaultTokenManager) GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) {
resp, err := m.tokenManager.GetCachedBalancesByChain(accounts, tokenAddresses, chainIDs)
if err != nil {
return resp, err
}
return resp, nil
}
func (m *DefaultTokenManager) FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *token.Token {
return m.tokenManager.FindOrCreateTokenByAddress(ctx, chainID, address)
}

View File

@ -163,6 +163,10 @@ func (m *testCollectiblesManager) FetchCollectibleOwnersByContractAddress(ctx co
return ret, nil
}
func (m *testCollectiblesManager) FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
return m.response[uint64(chainID)][ownerAddress], nil
}
type testTokenManager struct {
response map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big
}
@ -193,6 +197,10 @@ func (m *testTokenManager) GetBalancesByChain(ctx context.Context, accounts, tok
return m.response, nil
}
func (m *testTokenManager) GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) {
return m.response, nil
}
func (m *testTokenManager) FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *token.Token {
return nil
}

View File

@ -176,7 +176,13 @@ func (p *DefaultPermissionChecker) CheckPermissionToJoin(community *Community, a
}
// If there are any admin or token master permissions, combine result.
preParsedPermissions := preParsedCommunityPermissionsData(adminOrTokenMasterPermissionsToJoin)
adminOrTokenPermissionsResponse, err := p.CheckPermissions(preParsedPermissions, accountsAndChainIDs, false)
var adminOrTokenPermissionsResponse *CheckPermissionsResponse
if community.IsControlNode() {
adminOrTokenPermissionsResponse, err = p.CheckPermissions(preParsedPermissions, accountsAndChainIDs, false)
} else {
adminOrTokenPermissionsResponse, err = p.CheckCachedPermissions(preParsedPermissions, accountsAndChainIDs, false)
}
if err != nil {
return nil, err
}
@ -209,12 +215,14 @@ func (p *DefaultPermissionChecker) checkPermissionsOrDefault(permissions []*Comm
}
preParsedPermissions := preParsedCommunityPermissionsData(permissions)
return p.CheckPermissions(preParsedPermissions, accountsAndChainIDs, false)
return p.CheckCachedPermissions(preParsedPermissions, accountsAndChainIDs, false)
}
type ownedERC721TokensGetter = func(walletAddresses []gethcommon.Address, tokenRequirements map[uint64]map[string]*protobuf.TokenCriteria, chainIDs []uint64) (CollectiblesByChain, error)
type balancesByChainGetter = func(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error)
func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool, getOwnedERC721Tokens ownedERC721TokensGetter) (*CheckPermissionsResponse, error) {
func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool,
getOwnedERC721Tokens ownedERC721TokensGetter, getBalancesByChain balancesByChainGetter) (*CheckPermissionsResponse, error) {
response := &CheckPermissionsResponse{
Satisfied: false,
@ -254,7 +262,7 @@ func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PrePa
ownedERC20TokenBalances := make(map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big, 0)
if len(chainIDsForERC20) > 0 {
// this only returns balances for the networks we're actually interested in
balances, err := p.tokenManager.GetBalancesByChain(context.Background(), accounts, erc20TokenAddresses, chainIDsForERC20)
balances, err := getBalancesByChain(context.Background(), accounts, erc20TokenAddresses, chainIDsForERC20)
if err != nil {
return nil, err
}
@ -448,15 +456,28 @@ func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PrePa
return response, nil
}
type balancesByOwnerAndContractAddressGetter = func(ctx context.Context, chainID walletcommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (map[gethcommon.Address][]thirdparty.TokenBalance, error)
func (p *DefaultPermissionChecker) handlePermissionsCheck(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool,
getBalancesByOwnerAndContractAddress balancesByOwnerAndContractAddressGetter,
getBalancesByChain balancesByChainGetter) (*CheckPermissionsResponse, error) {
var getOwnedERC721Tokens ownedERC721TokensGetter = func(walletAddresses []gethcommon.Address, tokenRequirements map[uint64]map[string]*protobuf.TokenCriteria, chainIDs []uint64) (CollectiblesByChain, error) {
return p.getOwnedERC721Tokens(walletAddresses, tokenRequirements, chainIDs, getBalancesByOwnerAndContractAddress)
}
return p.checkPermissions(permissionsParsedData, accountsAndChainIDs, shortcircuit, getOwnedERC721Tokens, getBalancesByChain)
}
func (p *DefaultPermissionChecker) CheckCachedPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool) (*CheckPermissionsResponse, error) {
return p.handlePermissionsCheck(permissionsParsedData, accountsAndChainIDs, shortcircuit, p.collectiblesManager.FetchCachedBalancesByOwnerAndContractAddress, p.tokenManager.GetCachedBalancesByChain)
}
// CheckPermissions will retrieve balances and check whether the user has
// permission to join the community, if shortcircuit is true, it will stop as soon
// as we know the answer
func (p *DefaultPermissionChecker) CheckPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool) (*CheckPermissionsResponse, error) {
var getOwnedERC721Tokens ownedERC721TokensGetter = func(walletAddresses []gethcommon.Address, tokenRequirements map[uint64]map[string]*protobuf.TokenCriteria, chainIDs []uint64) (CollectiblesByChain, error) {
return p.getOwnedERC721Tokens(walletAddresses, tokenRequirements, chainIDs, p.collectiblesManager.FetchBalancesByOwnerAndContractAddress)
}
return p.checkPermissions(permissionsParsedData, accountsAndChainIDs, shortcircuit, getOwnedERC721Tokens)
return p.handlePermissionsCheck(permissionsParsedData, accountsAndChainIDs, shortcircuit, p.collectiblesManager.FetchBalancesByOwnerAndContractAddress, p.tokenManager.GetBalancesByChain)
}
type CollectiblesOwners = map[walletcommon.ChainID]map[gethcommon.Address]*thirdparty.CollectibleContractOwnership
@ -492,7 +513,7 @@ func (p *DefaultPermissionChecker) CheckPermissionsWithPreFetchedData(permission
return p.getOwnedERC721Tokens(walletAddresses, tokenRequirements, chainIDs, getCollectiblesBalances)
}
return p.checkPermissions(permissionsParsedData, accountsAndChainIDs, shortcircuit, getOwnedERC721Tokens)
return p.checkPermissions(permissionsParsedData, accountsAndChainIDs, shortcircuit, getOwnedERC721Tokens, p.tokenManager.GetBalancesByChain)
}
func preParsedPermissionsData(permissions []*CommunityTokenPermission) *PreParsedPermissionsData {

View File

@ -71,6 +71,11 @@ func (m *TokenManagerMock) GetBalancesByChain(ctx context.Context, accounts, tok
return *m.Balances, nil
}
func (m *TokenManagerMock) GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big, error) {
time.Sleep(100 * time.Millisecond) // simulate response time
return *m.Balances, nil
}
func (m *TokenManagerMock) FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *walletToken.Token {
time.Sleep(100 * time.Millisecond) // simulate response time
return nil
@ -81,6 +86,11 @@ type CollectiblesManagerMock struct {
collectibleOwnershipResponse map[string][]thirdparty.AccountBalance
}
func (m *CollectiblesManagerMock) FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID,
ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
return m.FetchBalancesByOwnerAndContractAddress(ctx, chainID, ownerAddress, contractAddresses)
}
func (m *CollectiblesManagerMock) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID,
ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
ret := make(thirdparty.TokenBalancesPerContractAddress)

View File

@ -284,6 +284,12 @@ func (api *API) GetCryptoOnRamps(ctx context.Context) ([]onramp.CryptoOnRamp, er
Collectibles API Start
*/
func (api *API) FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID wcommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
log.Debug("call to FetchCachedBalancesByOwnerAndContractAddress")
return api.s.collectiblesManager.FetchCachedBalancesByOwnerAndContractAddress(ctx, chainID, ownerAddress, contractAddresses)
}
func (api *API) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID wcommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
log.Debug("call to FetchBalancesByOwnerAndContractAddress")

View File

@ -125,6 +125,42 @@ func (o *Manager) doContentTypeRequest(ctx context.Context, url string) (string,
return resp.Header.Get("Content-Type"), nil
}
func (o *Manager) getTokenBalancesByOwnerAddress(collectibles *thirdparty.CollectibleContractOwnership, ownerAddress common.Address) map[common.Address][]thirdparty.TokenBalance {
ret := make(map[common.Address][]thirdparty.TokenBalance)
for _, nftOwner := range collectibles.Owners {
if nftOwner.OwnerAddress == ownerAddress {
ret[collectibles.ContractAddress] = nftOwner.TokenBalances
break
}
}
return ret
}
func (o *Manager) FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
ret := make(map[common.Address][]thirdparty.TokenBalance)
for _, contractAddress := range contractAddresses {
ret[contractAddress] = make([]thirdparty.TokenBalance, 0)
}
for _, contractAddress := range contractAddresses {
ownership, err := o.ownershipDB.FetchCachedCollectibleOwnersByContractAddress(chainID, contractAddress)
if err != nil {
return nil, err
}
t := o.getTokenBalancesByOwnerAddress(ownership, ownerAddress)
for address, tokenBalances := range t {
ret[address] = append(ret[address], tokenBalances...)
}
}
return ret, nil
}
// Need to combine different providers to support all needed ChainIDs
func (o *Manager) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
ret := make(thirdparty.TokenBalancesPerContractAddress)
@ -142,12 +178,8 @@ func (o *Manager) FetchBalancesByOwnerAndContractAddress(ctx context.Context, ch
if err != nil {
return nil, err
}
for _, nftOwner := range ownership.Owners {
if nftOwner.OwnerAddress == ownerAddress {
ret[contractAddress] = nftOwner.TokenBalances
break
}
}
ret = o.getTokenBalancesByOwnerAddress(ownership, ownerAddress)
}
} else if err == nil {
// Account ownership providers succeeded

View File

@ -34,6 +34,8 @@ const unknownUpdateTimestamp = int64(math.MaxInt64)
const selectOwnershipColumns = "chain_id, contract_address, token_id"
const collectiblesOwnershipColumns = "token_id, owner_address, balance"
const ownershipTimestampColumns = "owner_address, chain_id, timestamp"
const selectOwnershipTimestampColumns = "timestamp"
@ -420,6 +422,61 @@ func (o *OwnershipDB) GetOwnedCollectibles(chainIDs []w_common.ChainID, ownerAdd
return thirdparty.RowsToCollectibles(rows)
}
func (o *OwnershipDB) FetchCachedCollectibleOwnersByContractAddress(chainID w_common.ChainID, contractAddress common.Address) (*thirdparty.CollectibleContractOwnership, error) {
query, args, err := sqlx.In(fmt.Sprintf(`SELECT %s
FROM collectibles_ownership_cache
WHERE chain_id = ? AND contract_address = ?`, collectiblesOwnershipColumns), chainID, contractAddress)
if err != nil {
return nil, err
}
var ret thirdparty.CollectibleContractOwnership
stmt, err := o.db.Prepare(query)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query(args...)
if err != nil {
return nil, err
}
defer rows.Close()
tokenID := &bigint.BigInt{Int: big.NewInt(0)}
var ownerAddress common.Address
balance := &bigint.BigInt{Int: big.NewInt(0)}
var tokenBalances []thirdparty.TokenBalance
for rows.Next() {
err = rows.Scan(
(*bigint.SQLBigIntBytes)(tokenID.Int),
&ownerAddress,
(*bigint.SQLBigIntBytes)(balance.Int),
)
if err != nil {
return nil, err
}
tokenBalance := thirdparty.TokenBalance{
TokenID: tokenID,
Balance: balance,
}
tokenBalances = append(tokenBalances, tokenBalance)
collectibleOwner := thirdparty.CollectibleOwner{
OwnerAddress: ownerAddress,
TokenBalances: tokenBalances,
}
ret.ContractAddress = contractAddress
ret.Owners = append(ret.Owners, collectibleOwner)
}
return &ret, nil
}
func (o *OwnershipDB) GetOwnedCollectible(chainID w_common.ChainID, ownerAddresses common.Address, contractAddress common.Address, tokenID *big.Int) (*thirdparty.CollectibleUniqueID, error) {
query := fmt.Sprintf(`SELECT %s
FROM collectibles_ownership_cache

View File

@ -5,6 +5,7 @@ import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"math/big"
"strconv"
"strings"
@ -964,3 +965,63 @@ func (tm *Manager) onAccountsChange(changedAddresses []common.Address, eventType
}
}
}
func (tm *Manager) GetCachedBalancesByChain(accounts, tokenAddresses []common.Address, chainIDs []uint64) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error) {
accountStrings := make([]string, len(accounts))
for i, account := range accounts {
accountStrings[i] = fmt.Sprintf("'%s'", account.Hex())
}
tokenAddressStrings := make([]string, len(tokenAddresses))
for i, tokenAddress := range tokenAddresses {
tokenAddressStrings[i] = fmt.Sprintf("'%s'", tokenAddress.Hex())
}
chainIDStrings := make([]string, len(chainIDs))
for i, chainID := range chainIDs {
chainIDStrings[i] = fmt.Sprintf("%d", chainID)
}
query := `SELECT chain_id, user_address, token_address, balance
FROM token_balances
WHERE user_address IN (` + strings.Join(accountStrings, ",") + `)
AND token_address IN (` + strings.Join(tokenAddressStrings, ",") + `)
AND chain_id IN (` + strings.Join(chainIDStrings, ",") + `)`
rows, err := tm.db.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
ret := make(map[uint64]map[common.Address]map[common.Address]*hexutil.Big)
for rows.Next() {
var chainID uint64
var userAddressStr, tokenAddressStr string
var balanceStr string
err := rows.Scan(&chainID, &userAddressStr, &tokenAddressStr, &balanceStr)
if err != nil {
return nil, err
}
num := new(hexutil.Big)
_, ok := num.ToInt().SetString(balanceStr, 0)
if !ok {
return ret, nil
}
if ret[chainID] == nil {
ret[chainID] = make(map[common.Address]map[common.Address]*hexutil.Big)
}
if ret[chainID][common.HexToAddress(userAddressStr)] == nil {
ret[chainID][common.HexToAddress(userAddressStr)] = make(map[common.Address]*hexutil.Big)
}
ret[chainID][common.HexToAddress(userAddressStr)][common.HexToAddress(tokenAddressStr)] = num
}
return ret, nil
}