chore(communities)_: make member use wallet tokens during permission checking (#5268)
fixes #14913
This commit is contained in:
parent
9ffe842acc
commit
892fcffce4
|
@ -243,7 +243,8 @@ type managerOptions struct {
|
|||
}
|
||||
|
||||
type TokenManager interface {
|
||||
GetBalancesByChain(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big, error)
|
||||
GetBalancesByChain(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error)
|
||||
GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error)
|
||||
FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *token.Token
|
||||
GetAllChainIDs() ([]uint64, error)
|
||||
}
|
||||
|
@ -301,6 +302,7 @@ func (m *DefaultTokenManager) GetAllChainIDs() ([]uint64, error) {
|
|||
|
||||
type CollectiblesManager interface {
|
||||
FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletcommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error)
|
||||
FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletcommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error)
|
||||
GetCollectibleOwnership(id thirdparty.CollectibleUniqueID) ([]thirdparty.AccountBalance, error)
|
||||
FetchCollectibleOwnersByContractAddress(ctx context.Context, chainID walletcommon.ChainID, contractAddress gethcommon.Address) (*thirdparty.CollectibleContractOwnership, error)
|
||||
}
|
||||
|
@ -315,6 +317,15 @@ func (m *DefaultTokenManager) GetBalancesByChain(ctx context.Context, accounts,
|
|||
return resp, err
|
||||
}
|
||||
|
||||
func (m *DefaultTokenManager) GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) {
|
||||
resp, err := m.tokenManager.GetCachedBalancesByChain(accounts, tokenAddresses, chainIDs)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (m *DefaultTokenManager) FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *token.Token {
|
||||
return m.tokenManager.FindOrCreateTokenByAddress(ctx, chainID, address)
|
||||
}
|
||||
|
|
|
@ -163,6 +163,10 @@ func (m *testCollectiblesManager) FetchCollectibleOwnersByContractAddress(ctx co
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func (m *testCollectiblesManager) FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
|
||||
return m.response[uint64(chainID)][ownerAddress], nil
|
||||
}
|
||||
|
||||
type testTokenManager struct {
|
||||
response map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big
|
||||
}
|
||||
|
@ -193,6 +197,10 @@ func (m *testTokenManager) GetBalancesByChain(ctx context.Context, accounts, tok
|
|||
return m.response, nil
|
||||
}
|
||||
|
||||
func (m *testTokenManager) GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) {
|
||||
return m.response, nil
|
||||
}
|
||||
|
||||
func (m *testTokenManager) FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *token.Token {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -176,7 +176,13 @@ func (p *DefaultPermissionChecker) CheckPermissionToJoin(community *Community, a
|
|||
}
|
||||
// If there are any admin or token master permissions, combine result.
|
||||
preParsedPermissions := preParsedCommunityPermissionsData(adminOrTokenMasterPermissionsToJoin)
|
||||
adminOrTokenPermissionsResponse, err := p.CheckPermissions(preParsedPermissions, accountsAndChainIDs, false)
|
||||
var adminOrTokenPermissionsResponse *CheckPermissionsResponse
|
||||
|
||||
if community.IsControlNode() {
|
||||
adminOrTokenPermissionsResponse, err = p.CheckPermissions(preParsedPermissions, accountsAndChainIDs, false)
|
||||
} else {
|
||||
adminOrTokenPermissionsResponse, err = p.CheckCachedPermissions(preParsedPermissions, accountsAndChainIDs, false)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -209,12 +215,14 @@ func (p *DefaultPermissionChecker) checkPermissionsOrDefault(permissions []*Comm
|
|||
}
|
||||
|
||||
preParsedPermissions := preParsedCommunityPermissionsData(permissions)
|
||||
return p.CheckPermissions(preParsedPermissions, accountsAndChainIDs, false)
|
||||
return p.CheckCachedPermissions(preParsedPermissions, accountsAndChainIDs, false)
|
||||
}
|
||||
|
||||
type ownedERC721TokensGetter = func(walletAddresses []gethcommon.Address, tokenRequirements map[uint64]map[string]*protobuf.TokenCriteria, chainIDs []uint64) (CollectiblesByChain, error)
|
||||
type balancesByChainGetter = func(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error)
|
||||
|
||||
func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool, getOwnedERC721Tokens ownedERC721TokensGetter) (*CheckPermissionsResponse, error) {
|
||||
func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool,
|
||||
getOwnedERC721Tokens ownedERC721TokensGetter, getBalancesByChain balancesByChainGetter) (*CheckPermissionsResponse, error) {
|
||||
|
||||
response := &CheckPermissionsResponse{
|
||||
Satisfied: false,
|
||||
|
@ -254,7 +262,7 @@ func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PrePa
|
|||
ownedERC20TokenBalances := make(map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big, 0)
|
||||
if len(chainIDsForERC20) > 0 {
|
||||
// this only returns balances for the networks we're actually interested in
|
||||
balances, err := p.tokenManager.GetBalancesByChain(context.Background(), accounts, erc20TokenAddresses, chainIDsForERC20)
|
||||
balances, err := getBalancesByChain(context.Background(), accounts, erc20TokenAddresses, chainIDsForERC20)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -448,15 +456,28 @@ func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PrePa
|
|||
return response, nil
|
||||
}
|
||||
|
||||
type balancesByOwnerAndContractAddressGetter = func(ctx context.Context, chainID walletcommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (map[gethcommon.Address][]thirdparty.TokenBalance, error)
|
||||
|
||||
func (p *DefaultPermissionChecker) handlePermissionsCheck(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool,
|
||||
getBalancesByOwnerAndContractAddress balancesByOwnerAndContractAddressGetter,
|
||||
getBalancesByChain balancesByChainGetter) (*CheckPermissionsResponse, error) {
|
||||
|
||||
var getOwnedERC721Tokens ownedERC721TokensGetter = func(walletAddresses []gethcommon.Address, tokenRequirements map[uint64]map[string]*protobuf.TokenCriteria, chainIDs []uint64) (CollectiblesByChain, error) {
|
||||
return p.getOwnedERC721Tokens(walletAddresses, tokenRequirements, chainIDs, getBalancesByOwnerAndContractAddress)
|
||||
}
|
||||
|
||||
return p.checkPermissions(permissionsParsedData, accountsAndChainIDs, shortcircuit, getOwnedERC721Tokens, getBalancesByChain)
|
||||
}
|
||||
|
||||
func (p *DefaultPermissionChecker) CheckCachedPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool) (*CheckPermissionsResponse, error) {
|
||||
return p.handlePermissionsCheck(permissionsParsedData, accountsAndChainIDs, shortcircuit, p.collectiblesManager.FetchCachedBalancesByOwnerAndContractAddress, p.tokenManager.GetCachedBalancesByChain)
|
||||
}
|
||||
|
||||
// CheckPermissions will retrieve balances and check whether the user has
|
||||
// permission to join the community, if shortcircuit is true, it will stop as soon
|
||||
// as we know the answer
|
||||
func (p *DefaultPermissionChecker) CheckPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool) (*CheckPermissionsResponse, error) {
|
||||
var getOwnedERC721Tokens ownedERC721TokensGetter = func(walletAddresses []gethcommon.Address, tokenRequirements map[uint64]map[string]*protobuf.TokenCriteria, chainIDs []uint64) (CollectiblesByChain, error) {
|
||||
return p.getOwnedERC721Tokens(walletAddresses, tokenRequirements, chainIDs, p.collectiblesManager.FetchBalancesByOwnerAndContractAddress)
|
||||
}
|
||||
|
||||
return p.checkPermissions(permissionsParsedData, accountsAndChainIDs, shortcircuit, getOwnedERC721Tokens)
|
||||
return p.handlePermissionsCheck(permissionsParsedData, accountsAndChainIDs, shortcircuit, p.collectiblesManager.FetchBalancesByOwnerAndContractAddress, p.tokenManager.GetBalancesByChain)
|
||||
}
|
||||
|
||||
type CollectiblesOwners = map[walletcommon.ChainID]map[gethcommon.Address]*thirdparty.CollectibleContractOwnership
|
||||
|
@ -492,7 +513,7 @@ func (p *DefaultPermissionChecker) CheckPermissionsWithPreFetchedData(permission
|
|||
return p.getOwnedERC721Tokens(walletAddresses, tokenRequirements, chainIDs, getCollectiblesBalances)
|
||||
}
|
||||
|
||||
return p.checkPermissions(permissionsParsedData, accountsAndChainIDs, shortcircuit, getOwnedERC721Tokens)
|
||||
return p.checkPermissions(permissionsParsedData, accountsAndChainIDs, shortcircuit, getOwnedERC721Tokens, p.tokenManager.GetBalancesByChain)
|
||||
}
|
||||
|
||||
func preParsedPermissionsData(permissions []*CommunityTokenPermission) *PreParsedPermissionsData {
|
||||
|
|
|
@ -71,6 +71,11 @@ func (m *TokenManagerMock) GetBalancesByChain(ctx context.Context, accounts, tok
|
|||
return *m.Balances, nil
|
||||
}
|
||||
|
||||
func (m *TokenManagerMock) GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big, error) {
|
||||
time.Sleep(100 * time.Millisecond) // simulate response time
|
||||
return *m.Balances, nil
|
||||
}
|
||||
|
||||
func (m *TokenManagerMock) FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *walletToken.Token {
|
||||
time.Sleep(100 * time.Millisecond) // simulate response time
|
||||
return nil
|
||||
|
@ -81,6 +86,11 @@ type CollectiblesManagerMock struct {
|
|||
collectibleOwnershipResponse map[string][]thirdparty.AccountBalance
|
||||
}
|
||||
|
||||
func (m *CollectiblesManagerMock) FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID,
|
||||
ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
|
||||
return m.FetchBalancesByOwnerAndContractAddress(ctx, chainID, ownerAddress, contractAddresses)
|
||||
}
|
||||
|
||||
func (m *CollectiblesManagerMock) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID,
|
||||
ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
|
||||
ret := make(thirdparty.TokenBalancesPerContractAddress)
|
||||
|
|
|
@ -284,6 +284,12 @@ func (api *API) GetCryptoOnRamps(ctx context.Context) ([]onramp.CryptoOnRamp, er
|
|||
Collectibles API Start
|
||||
*/
|
||||
|
||||
func (api *API) FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID wcommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
|
||||
log.Debug("call to FetchCachedBalancesByOwnerAndContractAddress")
|
||||
|
||||
return api.s.collectiblesManager.FetchCachedBalancesByOwnerAndContractAddress(ctx, chainID, ownerAddress, contractAddresses)
|
||||
}
|
||||
|
||||
func (api *API) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID wcommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
|
||||
log.Debug("call to FetchBalancesByOwnerAndContractAddress")
|
||||
|
||||
|
|
|
@ -125,6 +125,42 @@ func (o *Manager) doContentTypeRequest(ctx context.Context, url string) (string,
|
|||
return resp.Header.Get("Content-Type"), nil
|
||||
}
|
||||
|
||||
func (o *Manager) getTokenBalancesByOwnerAddress(collectibles *thirdparty.CollectibleContractOwnership, ownerAddress common.Address) map[common.Address][]thirdparty.TokenBalance {
|
||||
ret := make(map[common.Address][]thirdparty.TokenBalance)
|
||||
|
||||
for _, nftOwner := range collectibles.Owners {
|
||||
if nftOwner.OwnerAddress == ownerAddress {
|
||||
ret[collectibles.ContractAddress] = nftOwner.TokenBalances
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (o *Manager) FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
|
||||
ret := make(map[common.Address][]thirdparty.TokenBalance)
|
||||
|
||||
for _, contractAddress := range contractAddresses {
|
||||
ret[contractAddress] = make([]thirdparty.TokenBalance, 0)
|
||||
}
|
||||
|
||||
for _, contractAddress := range contractAddresses {
|
||||
ownership, err := o.ownershipDB.FetchCachedCollectibleOwnersByContractAddress(chainID, contractAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := o.getTokenBalancesByOwnerAddress(ownership, ownerAddress)
|
||||
|
||||
for address, tokenBalances := range t {
|
||||
ret[address] = append(ret[address], tokenBalances...)
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// Need to combine different providers to support all needed ChainIDs
|
||||
func (o *Manager) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
|
||||
ret := make(thirdparty.TokenBalancesPerContractAddress)
|
||||
|
@ -142,12 +178,8 @@ func (o *Manager) FetchBalancesByOwnerAndContractAddress(ctx context.Context, ch
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, nftOwner := range ownership.Owners {
|
||||
if nftOwner.OwnerAddress == ownerAddress {
|
||||
ret[contractAddress] = nftOwner.TokenBalances
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
ret = o.getTokenBalancesByOwnerAddress(ownership, ownerAddress)
|
||||
}
|
||||
} else if err == nil {
|
||||
// Account ownership providers succeeded
|
||||
|
|
|
@ -34,6 +34,8 @@ const unknownUpdateTimestamp = int64(math.MaxInt64)
|
|||
|
||||
const selectOwnershipColumns = "chain_id, contract_address, token_id"
|
||||
|
||||
const collectiblesOwnershipColumns = "token_id, owner_address, balance"
|
||||
|
||||
const ownershipTimestampColumns = "owner_address, chain_id, timestamp"
|
||||
const selectOwnershipTimestampColumns = "timestamp"
|
||||
|
||||
|
@ -420,6 +422,61 @@ func (o *OwnershipDB) GetOwnedCollectibles(chainIDs []w_common.ChainID, ownerAdd
|
|||
return thirdparty.RowsToCollectibles(rows)
|
||||
}
|
||||
|
||||
func (o *OwnershipDB) FetchCachedCollectibleOwnersByContractAddress(chainID w_common.ChainID, contractAddress common.Address) (*thirdparty.CollectibleContractOwnership, error) {
|
||||
query, args, err := sqlx.In(fmt.Sprintf(`SELECT %s
|
||||
FROM collectibles_ownership_cache
|
||||
WHERE chain_id = ? AND contract_address = ?`, collectiblesOwnershipColumns), chainID, contractAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret thirdparty.CollectibleContractOwnership
|
||||
|
||||
stmt, err := o.db.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query(args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
tokenID := &bigint.BigInt{Int: big.NewInt(0)}
|
||||
var ownerAddress common.Address
|
||||
balance := &bigint.BigInt{Int: big.NewInt(0)}
|
||||
var tokenBalances []thirdparty.TokenBalance
|
||||
|
||||
for rows.Next() {
|
||||
err = rows.Scan(
|
||||
(*bigint.SQLBigIntBytes)(tokenID.Int),
|
||||
&ownerAddress,
|
||||
(*bigint.SQLBigIntBytes)(balance.Int),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokenBalance := thirdparty.TokenBalance{
|
||||
TokenID: tokenID,
|
||||
Balance: balance,
|
||||
}
|
||||
tokenBalances = append(tokenBalances, tokenBalance)
|
||||
|
||||
collectibleOwner := thirdparty.CollectibleOwner{
|
||||
OwnerAddress: ownerAddress,
|
||||
TokenBalances: tokenBalances,
|
||||
}
|
||||
|
||||
ret.ContractAddress = contractAddress
|
||||
ret.Owners = append(ret.Owners, collectibleOwner)
|
||||
}
|
||||
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
func (o *OwnershipDB) GetOwnedCollectible(chainID w_common.ChainID, ownerAddresses common.Address, contractAddress common.Address, tokenID *big.Int) (*thirdparty.CollectibleUniqueID, error) {
|
||||
query := fmt.Sprintf(`SELECT %s
|
||||
FROM collectibles_ownership_cache
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -964,3 +965,63 @@ func (tm *Manager) onAccountsChange(changedAddresses []common.Address, eventType
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *Manager) GetCachedBalancesByChain(accounts, tokenAddresses []common.Address, chainIDs []uint64) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error) {
|
||||
accountStrings := make([]string, len(accounts))
|
||||
for i, account := range accounts {
|
||||
accountStrings[i] = fmt.Sprintf("'%s'", account.Hex())
|
||||
}
|
||||
|
||||
tokenAddressStrings := make([]string, len(tokenAddresses))
|
||||
for i, tokenAddress := range tokenAddresses {
|
||||
tokenAddressStrings[i] = fmt.Sprintf("'%s'", tokenAddress.Hex())
|
||||
}
|
||||
|
||||
chainIDStrings := make([]string, len(chainIDs))
|
||||
for i, chainID := range chainIDs {
|
||||
chainIDStrings[i] = fmt.Sprintf("%d", chainID)
|
||||
}
|
||||
|
||||
query := `SELECT chain_id, user_address, token_address, balance
|
||||
FROM token_balances
|
||||
WHERE user_address IN (` + strings.Join(accountStrings, ",") + `)
|
||||
AND token_address IN (` + strings.Join(tokenAddressStrings, ",") + `)
|
||||
AND chain_id IN (` + strings.Join(chainIDStrings, ",") + `)`
|
||||
|
||||
rows, err := tm.db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
ret := make(map[uint64]map[common.Address]map[common.Address]*hexutil.Big)
|
||||
|
||||
for rows.Next() {
|
||||
var chainID uint64
|
||||
var userAddressStr, tokenAddressStr string
|
||||
var balanceStr string
|
||||
|
||||
err := rows.Scan(&chainID, &userAddressStr, &tokenAddressStr, &balanceStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
num := new(hexutil.Big)
|
||||
_, ok := num.ToInt().SetString(balanceStr, 0)
|
||||
if !ok {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
if ret[chainID] == nil {
|
||||
ret[chainID] = make(map[common.Address]map[common.Address]*hexutil.Big)
|
||||
}
|
||||
|
||||
if ret[chainID][common.HexToAddress(userAddressStr)] == nil {
|
||||
ret[chainID][common.HexToAddress(userAddressStr)] = make(map[common.Address]*hexutil.Big)
|
||||
}
|
||||
|
||||
ret[chainID][common.HexToAddress(userAddressStr)][common.HexToAddress(tokenAddressStr)] = num
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue