fix(communities)_: correct >1 NFT token requirement evaluation

Fixed logic to respect specified NFT quantities. Previously, holding one
NFT sufficed, regardless of the required count.

fixes: status-im/status-desktop#15122
This commit is contained in:
Patryk Osmaczko 2024-06-13 15:26:52 +02:00 committed by osmaczko
parent d351acbba5
commit 88c671fcf0
5 changed files with 392 additions and 176 deletions

View File

@ -296,15 +296,16 @@ func (s *ManagerSuite) TestRetrieveCollectibles() {
var tokenBalances []thirdparty.TokenBalance var tokenBalances []thirdparty.TokenBalance
var tokenCriteria = []*protobuf.TokenCriteria{ var tokenCriteria = []*protobuf.TokenCriteria{
&protobuf.TokenCriteria{ {
ContractAddresses: contractAddresses, ContractAddresses: contractAddresses,
TokenIds: []uint64{tokenID}, TokenIds: []uint64{tokenID},
Type: protobuf.CommunityTokenType_ERC721, Type: protobuf.CommunityTokenType_ERC721,
AmountInWei: "1",
}, },
} }
var permissions = []*CommunityTokenPermission{ var permissions = []*CommunityTokenPermission{
&CommunityTokenPermission{ {
CommunityTokenPermission: &protobuf.CommunityTokenPermission{ CommunityTokenPermission: &protobuf.CommunityTokenPermission{
Id: "some-id", Id: "some-id",
Type: protobuf.CommunityTokenPermission_BECOME_MEMBER, Type: protobuf.CommunityTokenPermission_BECOME_MEMBER,
@ -316,7 +317,7 @@ func (s *ManagerSuite) TestRetrieveCollectibles() {
preParsedPermissions := preParsedCommunityPermissionsData(permissions) preParsedPermissions := preParsedCommunityPermissionsData(permissions)
accountChainIDsCombination := []*AccountChainIDsCombination{ accountChainIDsCombination := []*AccountChainIDsCombination{
&AccountChainIDsCombination{ {
Address: gethcommon.HexToAddress("0xD6b912e09E797D291E8D0eA3D3D17F8000e01c32"), Address: gethcommon.HexToAddress("0xD6b912e09E797D291E8D0eA3D3D17F8000e01c32"),
ChainIDs: []uint64{chainID}, ChainIDs: []uint64{chainID},
}, },

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"math/big" "math/big"
"strconv"
"strings" "strings"
"go.uber.org/zap" "go.uber.org/zap"
@ -221,6 +222,158 @@ func (p *DefaultPermissionChecker) checkPermissionsOrDefault(permissions []*Comm
type ownedERC721TokensGetter = func(walletAddresses []gethcommon.Address, tokenRequirements map[uint64]map[string]*protobuf.TokenCriteria, chainIDs []uint64) (CollectiblesByChain, error) type ownedERC721TokensGetter = func(walletAddresses []gethcommon.Address, tokenRequirements map[uint64]map[string]*protobuf.TokenCriteria, chainIDs []uint64) (CollectiblesByChain, error)
type balancesByChainGetter = func(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) type balancesByChainGetter = func(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error)
func (p *DefaultPermissionChecker) checkTokenRequirement(
tokenRequirement *protobuf.TokenCriteria,
accounts []gethcommon.Address, ownedERC20TokenBalances BalancesByChain, ownedERC721Tokens CollectiblesByChain,
accountsChainIDsCombinations map[gethcommon.Address]map[uint64]bool,
) (TokenRequirementResponse, error) {
tokenRequirementResponse := TokenRequirementResponse{TokenCriteria: tokenRequirement}
switch tokenRequirement.Type {
case protobuf.CommunityTokenType_ERC721:
if len(ownedERC721Tokens) == 0 {
return tokenRequirementResponse, nil
}
// Limit NFTs count to uint32
requiredCount, err := strconv.ParseUint(tokenRequirement.AmountInWei, 10, 32)
if err != nil {
return tokenRequirementResponse, fmt.Errorf("invalid ERC721 amount: %s", tokenRequirement.AmountInWei)
}
accumulatedCount := uint64(0)
for chainID, addressStr := range tokenRequirement.ContractAddresses {
contractAddress := gethcommon.HexToAddress(addressStr)
if _, exists := ownedERC721Tokens[chainID]; !exists || len(ownedERC721Tokens[chainID]) == 0 {
continue
}
for account := range ownedERC721Tokens[chainID] {
if _, exists := ownedERC721Tokens[chainID][account]; !exists {
continue
}
tokenBalances := ownedERC721Tokens[chainID][account][contractAddress]
accumulatedCount += uint64(len(tokenBalances))
if len(tokenBalances) > 0 {
// 'account' owns some TokenID owned from contract 'address'
if _, exists := accountsChainIDsCombinations[account]; !exists {
accountsChainIDsCombinations[account] = make(map[uint64]bool)
}
// account has balance > 0 on this chain for this token, so let's add it the chain IDs
accountsChainIDsCombinations[account][chainID] = true
if len(tokenRequirement.TokenIds) == 0 {
// no specific tokenId of this collection is needed
if accumulatedCount >= requiredCount {
tokenRequirementResponse.Satisfied = true
return tokenRequirementResponse, nil
}
}
for _, tokenID := range tokenRequirement.TokenIds {
tokenIDBigInt := new(big.Int).SetUint64(tokenID)
for _, asset := range tokenBalances {
if asset.TokenID.Cmp(tokenIDBigInt) == 0 && asset.Balance.Sign() > 0 {
tokenRequirementResponse.Satisfied = true
return tokenRequirementResponse, nil
}
}
}
}
}
}
case protobuf.CommunityTokenType_ERC20:
if len(ownedERC20TokenBalances) == 0 {
return tokenRequirementResponse, nil
}
accumulatedBalance := new(big.Int)
chainIDLoopERC20:
for chainID, address := range tokenRequirement.ContractAddresses {
if _, exists := ownedERC20TokenBalances[chainID]; !exists || len(ownedERC20TokenBalances[chainID]) == 0 {
continue chainIDLoopERC20
}
contractAddress := gethcommon.HexToAddress(address)
for account := range ownedERC20TokenBalances[chainID] {
if _, exists := ownedERC20TokenBalances[chainID][account][contractAddress]; !exists {
continue
}
value := ownedERC20TokenBalances[chainID][account][contractAddress]
if _, exists := accountsChainIDsCombinations[account]; !exists {
accountsChainIDsCombinations[account] = make(map[uint64]bool)
}
if value.ToInt().Cmp(big.NewInt(0)) > 0 {
// account has balance > 0 on this chain for this token, so let's add it the chain IDs
accountsChainIDsCombinations[account][chainID] = true
}
// check if adding current chain account balance to accumulated balance
// satisfies required amount
prevBalance := accumulatedBalance
accumulatedBalance.Add(prevBalance, value.ToInt())
requiredAmount, success := new(big.Int).SetString(tokenRequirement.AmountInWei, 10)
if !success {
return tokenRequirementResponse, fmt.Errorf("amountInWeis value is incorrect: %s", tokenRequirement.AmountInWei)
}
if accumulatedBalance.Cmp(requiredAmount) != -1 {
tokenRequirementResponse.Satisfied = true
return tokenRequirementResponse, nil
}
}
}
case protobuf.CommunityTokenType_ENS:
for _, account := range accounts {
ownedENSNames, err := p.getOwnedENS([]gethcommon.Address{account})
if err != nil {
return tokenRequirementResponse, err
}
if _, exists := accountsChainIDsCombinations[account]; !exists {
accountsChainIDsCombinations[account] = make(map[uint64]bool)
}
if !strings.HasPrefix(tokenRequirement.EnsPattern, "*.") {
for _, ownedENS := range ownedENSNames {
if ownedENS == tokenRequirement.EnsPattern {
accountsChainIDsCombinations[account][walletcommon.EthereumMainnet] = true
tokenRequirementResponse.Satisfied = true
return tokenRequirementResponse, nil
}
}
} else {
parentName := tokenRequirement.EnsPattern[2:]
for _, ownedENS := range ownedENSNames {
if strings.HasSuffix(ownedENS, parentName) {
accountsChainIDsCombinations[account][walletcommon.EthereumMainnet] = true
tokenRequirementResponse.Satisfied = true
return tokenRequirementResponse, nil
}
}
}
}
}
return tokenRequirementResponse, nil
}
func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool, func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool,
getOwnedERC721Tokens ownedERC721TokensGetter, getBalancesByChain balancesByChainGetter) (*CheckPermissionsResponse, error) { getOwnedERC721Tokens ownedERC721TokensGetter, getBalancesByChain balancesByChainGetter) (*CheckPermissionsResponse, error) {
@ -281,7 +434,6 @@ func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PrePa
accountsChainIDsCombinations := make(map[gethcommon.Address]map[uint64]bool) accountsChainIDsCombinations := make(map[gethcommon.Address]map[uint64]bool)
for _, tokenPermission := range permissionsParsedData.Permissions { for _, tokenPermission := range permissionsParsedData.Permissions {
permissionRequirementsMet := true permissionRequirementsMet := true
response.Permissions[tokenPermission.Id] = &PermissionTokenCriteriaResult{Role: tokenPermission.Type} response.Permissions[tokenPermission.Id] = &PermissionTokenCriteriaResult{Role: tokenPermission.Type}
@ -289,146 +441,17 @@ func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PrePa
// If only one is not met, the entire permission is marked // If only one is not met, the entire permission is marked
// as not fulfilled // as not fulfilled
for _, tokenRequirement := range tokenPermission.TokenCriteria { for _, tokenRequirement := range tokenPermission.TokenCriteria {
tokenRequirementResponse, err := p.checkTokenRequirement(tokenRequirement, accounts, ownedERC20TokenBalances, ownedERC721Tokens, accountsChainIDsCombinations)
tokenRequirementMet := false if err != nil {
tokenRequirementResponse := TokenRequirementResponse{TokenCriteria: tokenRequirement} p.logger.Error("failed to check token requirement", zap.Error(err))
if tokenRequirement.Type == protobuf.CommunityTokenType_ERC721 {
if len(ownedERC721Tokens) == 0 {
response.Permissions[tokenPermission.Id].TokenRequirements = append(response.Permissions[tokenPermission.Id].TokenRequirements, tokenRequirementResponse)
response.Permissions[tokenPermission.Id].Criteria = append(response.Permissions[tokenPermission.Id].Criteria, false)
continue
}
chainIDLoopERC721:
for chainID, addressStr := range tokenRequirement.ContractAddresses {
contractAddress := gethcommon.HexToAddress(addressStr)
if _, exists := ownedERC721Tokens[chainID]; !exists || len(ownedERC721Tokens[chainID]) == 0 {
continue chainIDLoopERC721
}
for account := range ownedERC721Tokens[chainID] {
if _, exists := ownedERC721Tokens[chainID][account]; !exists {
continue
}
tokenBalances := ownedERC721Tokens[chainID][account][contractAddress]
if len(tokenBalances) > 0 {
// 'account' owns some TokenID owned from contract 'address'
if _, exists := accountsChainIDsCombinations[account]; !exists {
accountsChainIDsCombinations[account] = make(map[uint64]bool)
}
if len(tokenRequirement.TokenIds) == 0 {
// no specific tokenId of this collection is needed
tokenRequirementMet = true
accountsChainIDsCombinations[account][chainID] = true
break chainIDLoopERC721
}
tokenIDsLoop:
for _, tokenID := range tokenRequirement.TokenIds {
tokenIDBigInt := new(big.Int).SetUint64(tokenID)
for _, asset := range tokenBalances {
if asset.TokenID.Cmp(tokenIDBigInt) == 0 && asset.Balance.Sign() > 0 {
tokenRequirementMet = true
accountsChainIDsCombinations[account][chainID] = true
break tokenIDsLoop
}
}
}
}
}
}
} else if tokenRequirement.Type == protobuf.CommunityTokenType_ERC20 {
if len(ownedERC20TokenBalances) == 0 {
response.Permissions[tokenPermission.Id].TokenRequirements = append(response.Permissions[tokenPermission.Id].TokenRequirements, tokenRequirementResponse)
response.Permissions[tokenPermission.Id].Criteria = append(response.Permissions[tokenPermission.Id].Criteria, false)
continue
}
accumulatedBalance := new(big.Int)
chainIDLoopERC20:
for chainID, address := range tokenRequirement.ContractAddresses {
if _, exists := ownedERC20TokenBalances[chainID]; !exists || len(ownedERC20TokenBalances[chainID]) == 0 {
continue chainIDLoopERC20
}
contractAddress := gethcommon.HexToAddress(address)
for account := range ownedERC20TokenBalances[chainID] {
if _, exists := ownedERC20TokenBalances[chainID][account][contractAddress]; !exists {
continue
}
value := ownedERC20TokenBalances[chainID][account][contractAddress]
if _, exists := accountsChainIDsCombinations[account]; !exists {
accountsChainIDsCombinations[account] = make(map[uint64]bool)
}
if value.ToInt().Cmp(big.NewInt(0)) > 0 {
// account has balance > 0 on this chain for this token, so let's add it the chain IDs
accountsChainIDsCombinations[account][chainID] = true
}
// check if adding current chain account balance to accumulated balance
// satisfies required amount
prevBalance := accumulatedBalance
accumulatedBalance.Add(prevBalance, value.ToInt())
requiredAmount, success := new(big.Int).SetString(tokenRequirement.AmountInWei, 10)
if !success {
return nil, fmt.Errorf("amountInWeis value is incorrect: %s", tokenRequirement.AmountInWei)
}
if accumulatedBalance.Cmp(requiredAmount) != -1 {
tokenRequirementMet = true
if shortcircuit {
break chainIDLoopERC20
}
}
}
}
} else if tokenRequirement.Type == protobuf.CommunityTokenType_ENS {
for _, account := range accounts {
ownedENSNames, err := p.getOwnedENS([]gethcommon.Address{account})
if err != nil {
return nil, err
}
if _, exists := accountsChainIDsCombinations[account]; !exists {
accountsChainIDsCombinations[account] = make(map[uint64]bool)
}
if !strings.HasPrefix(tokenRequirement.EnsPattern, "*.") {
for _, ownedENS := range ownedENSNames {
if ownedENS == tokenRequirement.EnsPattern {
tokenRequirementMet = true
accountsChainIDsCombinations[account][walletcommon.EthereumMainnet] = true
}
}
} else {
parentName := tokenRequirement.EnsPattern[2:]
for _, ownedENS := range ownedENSNames {
if strings.HasSuffix(ownedENS, parentName) {
tokenRequirementMet = true
accountsChainIDsCombinations[account][walletcommon.EthereumMainnet] = true
}
}
}
}
} }
if !tokenRequirementMet {
if !tokenRequirementResponse.Satisfied {
permissionRequirementsMet = false permissionRequirementsMet = false
} }
tokenRequirementResponse.Satisfied = tokenRequirementMet
response.Permissions[tokenPermission.Id].TokenRequirements = append(response.Permissions[tokenPermission.Id].TokenRequirements, tokenRequirementResponse) response.Permissions[tokenPermission.Id].TokenRequirements = append(response.Permissions[tokenPermission.Id].TokenRequirements, tokenRequirementResponse)
response.Permissions[tokenPermission.Id].Criteria = append(response.Permissions[tokenPermission.Id].Criteria, tokenRequirementMet) response.Permissions[tokenPermission.Id].Criteria = append(response.Permissions[tokenPermission.Id].Criteria, tokenRequirementResponse.Satisfied)
} }
response.Permissions[tokenPermission.Id].ID = tokenPermission.Id response.Permissions[tokenPermission.Id].ID = tokenPermission.Id

View File

@ -1,11 +1,21 @@
package communities package communities
import ( import (
"context"
"errors"
"fmt"
"math/big"
"strconv"
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/services/wallet/bigint"
"github.com/status-im/status-go/services/wallet/thirdparty"
gethcommon "github.com/ethereum/go-ethereum/common" gethcommon "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
) )
func TestPermissionCheckerSuite(t *testing.T) { func TestPermissionCheckerSuite(t *testing.T) {
@ -56,3 +66,151 @@ func (s *PermissionCheckerSuite) TestMergeValidCombinations() {
} }
} }
func (s *PermissionCheckerSuite) TestCheckPermissions() {
testCases := []struct {
name string
amountInWei func(t protobuf.CommunityTokenType) string
requiredAmountInWei func(t protobuf.CommunityTokenType) string
shouldSatisfy bool
}{
{
name: "account does not meet criteria",
amountInWei: func(t protobuf.CommunityTokenType) string {
if t == protobuf.CommunityTokenType_ERC721 {
return "1"
}
return "1000000000000000000"
},
requiredAmountInWei: func(t protobuf.CommunityTokenType) string {
if t == protobuf.CommunityTokenType_ERC721 {
return "2"
}
return "2000000000000000000"
},
shouldSatisfy: false,
},
{
name: "account does exactly meet criteria",
amountInWei: func(t protobuf.CommunityTokenType) string {
if t == protobuf.CommunityTokenType_ERC721 {
return "2"
}
return "2000000000000000000"
},
requiredAmountInWei: func(t protobuf.CommunityTokenType) string {
if t == protobuf.CommunityTokenType_ERC721 {
return "2"
}
return "2000000000000000000"
},
shouldSatisfy: true,
},
{
name: "account does meet criteria",
amountInWei: func(t protobuf.CommunityTokenType) string {
if t == protobuf.CommunityTokenType_ERC721 {
return "3"
}
return "3000000000000000000"
},
requiredAmountInWei: func(t protobuf.CommunityTokenType) string {
if t == protobuf.CommunityTokenType_ERC721 {
return "2"
}
return "2000000000000000000"
},
shouldSatisfy: true,
},
}
permissionChecker := DefaultPermissionChecker{}
chainID := uint64(1)
contractAddress := gethcommon.HexToAddress("0x3d6afaa395c31fcd391fe3d562e75fe9e8ec7e6a")
walletAddress := gethcommon.HexToAddress("0xD6b912e09E797D291E8D0eA3D3D17F8000e01c32")
for _, tc := range testCases {
for _, tokenType := range [](protobuf.CommunityTokenType){protobuf.CommunityTokenType_ERC20, protobuf.CommunityTokenType_ERC721} {
s.Run(fmt.Sprintf("%s_%s", tc.name, tokenType.String()), func() {
decimals := uint64(0)
if tokenType == protobuf.CommunityTokenType_ERC20 {
decimals = 18
}
permissions := map[string]*CommunityTokenPermission{
"p1": {
CommunityTokenPermission: &protobuf.CommunityTokenPermission{
Id: "p1",
Type: protobuf.CommunityTokenPermission_BECOME_MEMBER,
TokenCriteria: []*protobuf.TokenCriteria{
{
ContractAddresses: map[uint64]string{
chainID: contractAddress.String(),
},
Type: tokenType,
Symbol: "STT",
TokenIds: []uint64{},
Decimals: decimals,
AmountInWei: tc.requiredAmountInWei(tokenType),
},
},
},
},
}
permissionsData, _ := PreParsePermissionsData(permissions)
accountsAndChainIDs := []*AccountChainIDsCombination{
{
Address: walletAddress,
ChainIDs: []uint64{chainID},
},
}
var getOwnedERC721Tokens ownedERC721TokensGetter = func(walletAddresses []gethcommon.Address, tokenRequirements map[uint64]map[string]*protobuf.TokenCriteria, chainIDs []uint64) (CollectiblesByChain, error) {
amount, err := strconv.ParseUint(tc.amountInWei(protobuf.CommunityTokenType_ERC721), 10, 64)
if err != nil {
return nil, err
}
balances := []thirdparty.TokenBalance{}
for i := uint64(0); i < amount; i++ {
balances = append(balances, thirdparty.TokenBalance{
TokenID: &bigint.BigInt{
Int: new(big.Int).SetUint64(i + 1),
},
Balance: &bigint.BigInt{
Int: new(big.Int).SetUint64(1),
},
})
}
return CollectiblesByChain{
chainID: {
walletAddress: {
contractAddress: balances,
},
},
}, nil
}
var getBalancesByChain balancesByChainGetter = func(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) {
balance, ok := new(big.Int).SetString(tc.amountInWei(protobuf.CommunityTokenType_ERC20), 10)
if !ok {
return nil, errors.New("invalid conversion")
}
return BalancesByChain{
chainID: {
walletAddress: {
contractAddress: (*hexutil.Big)(balance),
},
},
}, nil
}
response, err := permissionChecker.checkPermissions(permissionsData[protobuf.CommunityTokenPermission_BECOME_MEMBER], accountsAndChainIDs, true, getOwnedERC721Tokens, getBalancesByChain)
s.Require().NoError(err)
s.Require().Equal(tc.shouldSatisfy, response.Satisfied)
})
}
}
}

View File

@ -5,7 +5,6 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"encoding/json" "encoding/json"
"errors" "errors"
"math/big"
"sync" "sync"
"time" "time"
@ -25,7 +24,6 @@ import (
"github.com/status-im/status-go/protocol/communities/token" "github.com/status-im/status-go/protocol/communities/token"
"github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/protocol/requests" "github.com/status-im/status-go/protocol/requests"
"github.com/status-im/status-go/services/wallet/bigint"
walletCommon "github.com/status-im/status-go/services/wallet/common" walletCommon "github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/services/wallet/thirdparty" "github.com/status-im/status-go/services/wallet/thirdparty"
walletToken "github.com/status-im/status-go/services/wallet/token" walletToken "github.com/status-im/status-go/services/wallet/token"
@ -55,7 +53,7 @@ func (m *AccountManagerMock) DeleteAccount(address types.Address) error {
} }
type TokenManagerMock struct { type TokenManagerMock struct {
Balances *map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big Balances *communities.BalancesByChain
} }
func (m *TokenManagerMock) GetAllChainIDs() ([]uint64, error) { func (m *TokenManagerMock) GetAllChainIDs() ([]uint64, error) {
@ -82,7 +80,7 @@ func (m *TokenManagerMock) FindOrCreateTokenByAddress(ctx context.Context, chain
} }
type CollectiblesManagerMock struct { type CollectiblesManagerMock struct {
Balances *map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big Collectibles *communities.CollectiblesByChain
collectibleOwnershipResponse map[string][]thirdparty.AccountBalance collectibleOwnershipResponse map[string][]thirdparty.AccountBalance
} }
@ -94,7 +92,7 @@ func (m *CollectiblesManagerMock) FetchCachedBalancesByOwnerAndContractAddress(c
func (m *CollectiblesManagerMock) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, func (m *CollectiblesManagerMock) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID,
ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) { ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) {
ret := make(thirdparty.TokenBalancesPerContractAddress) ret := make(thirdparty.TokenBalancesPerContractAddress)
accountsBalances, ok := (*m.Balances)[uint64(chainID)] accountsBalances, ok := (*m.Collectibles)[uint64(chainID)]
if !ok { if !ok {
return ret, nil return ret, nil
} }
@ -107,14 +105,7 @@ func (m *CollectiblesManagerMock) FetchBalancesByOwnerAndContractAddress(ctx con
for _, contractAddress := range contractAddresses { for _, contractAddress := range contractAddresses {
balance, ok := balances[contractAddress] balance, ok := balances[contractAddress]
if ok { if ok {
ret[contractAddress] = []thirdparty.TokenBalance{ ret[contractAddress] = balance
{
TokenID: &bigint.BigInt{},
Balance: &bigint.BigInt{
Int: (*big.Int)(balance),
},
},
}
} }
} }
@ -135,24 +126,17 @@ func (m *CollectiblesManagerMock) FetchCollectibleOwnersByContractAddress(ctx co
ContractAddress: contractAddress, ContractAddress: contractAddress,
Owners: []thirdparty.CollectibleOwner{}, Owners: []thirdparty.CollectibleOwner{},
} }
accountsBalances, ok := (*m.Balances)[uint64(chainID)] accountsBalances, ok := (*m.Collectibles)[uint64(chainID)]
if !ok { if !ok {
return ret, nil return ret, nil
} }
for wallet, collectiblesBalance := range accountsBalances { for wallet, balances := range accountsBalances {
balance, ok := collectiblesBalance[contractAddress] balance, ok := balances[contractAddress]
if ok { if ok {
ret.Owners = append(ret.Owners, thirdparty.CollectibleOwner{ ret.Owners = append(ret.Owners, thirdparty.CollectibleOwner{
OwnerAddress: wallet, OwnerAddress: wallet,
TokenBalances: []thirdparty.TokenBalance{ TokenBalances: balance,
{
TokenID: &bigint.BigInt{},
Balance: &bigint.BigInt{
Int: (*big.Int)(balance),
},
},
},
}) })
} }
} }
@ -256,7 +240,8 @@ type testCommunitiesMessengerConfig struct {
password string password string
walletAddresses []string walletAddresses []string
mockedBalances *map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big mockedBalances *communities.BalancesByChain
mockedCollectibles *communities.CollectiblesByChain
collectiblesService communities.CommunityTokensServiceInterface collectiblesService communities.CommunityTokensServiceInterface
} }
@ -323,7 +308,7 @@ func newTestCommunitiesMessenger(s *suite.Suite, waku types.Waku, config testCom
} }
collectiblesManagerMock := &CollectiblesManagerMock{ collectiblesManagerMock := &CollectiblesManagerMock{
Balances: config.mockedBalances, Collectibles: config.mockedCollectibles,
} }
options := []Option{ options := []Option{

View File

@ -32,6 +32,8 @@ import (
"github.com/status-im/status-go/protocol/requests" "github.com/status-im/status-go/protocol/requests"
"github.com/status-im/status-go/protocol/transport" "github.com/status-im/status-go/protocol/transport"
"github.com/status-im/status-go/protocol/tt" "github.com/status-im/status-go/protocol/tt"
"github.com/status-im/status-go/services/wallet/bigint"
"github.com/status-im/status-go/services/wallet/thirdparty"
) )
const testChainID1 = 1 const testChainID1 = 1
@ -145,7 +147,8 @@ type MessengerCommunitiesTokenPermissionsSuite struct {
logger *zap.Logger logger *zap.Logger
mockedBalances map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big // chainID, account, token, balance mockedBalances communities.BalancesByChain
mockedCollectibles communities.CollectiblesByChain
collectiblesServiceMock *CollectiblesServiceMock collectiblesServiceMock *CollectiblesServiceMock
} }
@ -212,6 +215,7 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) newMessenger(password string
password: password, password: password,
walletAddresses: walletAddresses, walletAddresses: walletAddresses,
mockedBalances: &s.mockedBalances, mockedBalances: &s.mockedBalances,
mockedCollectibles: &s.mockedCollectibles,
collectiblesService: s.collectiblesServiceMock, collectiblesService: s.collectiblesServiceMock,
}) })
} }
@ -245,10 +249,35 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) sendChatMessage(sender *Mess
func (s *MessengerCommunitiesTokenPermissionsSuite) makeAddressSatisfyTheCriteria(chainID uint64, address string, criteria *protobuf.TokenCriteria) { func (s *MessengerCommunitiesTokenPermissionsSuite) makeAddressSatisfyTheCriteria(chainID uint64, address string, criteria *protobuf.TokenCriteria) {
walletAddress := gethcommon.HexToAddress(address) walletAddress := gethcommon.HexToAddress(address)
contractAddress := gethcommon.HexToAddress(criteria.ContractAddresses[chainID]) contractAddress := gethcommon.HexToAddress(criteria.ContractAddresses[chainID])
balance, ok := new(big.Int).SetString(criteria.AmountInWei, 10)
s.Require().True(ok)
s.mockedBalances[chainID][walletAddress][contractAddress] = (*hexutil.Big)(balance) switch criteria.Type {
case protobuf.CommunityTokenType_ERC20:
balance, ok := new(big.Int).SetString(criteria.AmountInWei, 10)
s.Require().True(ok)
s.mockedBalances[chainID][walletAddress][contractAddress] = (*hexutil.Big)(balance)
case protobuf.CommunityTokenType_ERC721:
amount, err := strconv.ParseUint(criteria.AmountInWei, 10, 32)
s.Require().NoError(err)
balances := []thirdparty.TokenBalance{}
for i := uint64(0); i < amount; i++ {
balances = append(balances, thirdparty.TokenBalance{
TokenID: &bigint.BigInt{
Int: new(big.Int).SetUint64(i + 1),
},
Balance: &bigint.BigInt{
Int: new(big.Int).SetUint64(1),
},
})
}
s.mockedCollectibles[chainID][walletAddress][contractAddress] = balances
case protobuf.CommunityTokenType_ENS:
// not implemented
}
} }
func (s *MessengerCommunitiesTokenPermissionsSuite) resetMockedBalances() { func (s *MessengerCommunitiesTokenPermissionsSuite) resetMockedBalances() {
@ -257,6 +286,12 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) resetMockedBalances() {
s.mockedBalances[testChainID1][gethcommon.HexToAddress(aliceAddress1)] = make(map[gethcommon.Address]*hexutil.Big) s.mockedBalances[testChainID1][gethcommon.HexToAddress(aliceAddress1)] = make(map[gethcommon.Address]*hexutil.Big)
s.mockedBalances[testChainID1][gethcommon.HexToAddress(aliceAddress2)] = make(map[gethcommon.Address]*hexutil.Big) s.mockedBalances[testChainID1][gethcommon.HexToAddress(aliceAddress2)] = make(map[gethcommon.Address]*hexutil.Big)
s.mockedBalances[testChainID1][gethcommon.HexToAddress(bobAddress)] = make(map[gethcommon.Address]*hexutil.Big) s.mockedBalances[testChainID1][gethcommon.HexToAddress(bobAddress)] = make(map[gethcommon.Address]*hexutil.Big)
s.mockedCollectibles = make(communities.CollectiblesByChain)
s.mockedCollectibles[testChainID1] = make(map[gethcommon.Address]thirdparty.TokenBalancesPerContractAddress)
s.mockedCollectibles[testChainID1][gethcommon.HexToAddress(aliceAddress1)] = make(thirdparty.TokenBalancesPerContractAddress)
s.mockedCollectibles[testChainID1][gethcommon.HexToAddress(aliceAddress2)] = make(thirdparty.TokenBalancesPerContractAddress)
s.mockedCollectibles[testChainID1][gethcommon.HexToAddress(bobAddress)] = make(thirdparty.TokenBalancesPerContractAddress)
} }
func (s *MessengerCommunitiesTokenPermissionsSuite) waitOnKeyDistribution(condition func(*CommunityAndKeyActions) bool) <-chan error { func (s *MessengerCommunitiesTokenPermissionsSuite) waitOnKeyDistribution(condition func(*CommunityAndKeyActions) bool) <-chan error {
@ -1604,6 +1639,13 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) TestMemberRoleGetUpdatedWhen
func (s *MessengerCommunitiesTokenPermissionsSuite) testReevaluateMemberPrivilegedRoleInOpenCommunity(permissionType protobuf.CommunityTokenPermission_Type, tokenType protobuf.CommunityTokenType) { func (s *MessengerCommunitiesTokenPermissionsSuite) testReevaluateMemberPrivilegedRoleInOpenCommunity(permissionType protobuf.CommunityTokenPermission_Type, tokenType protobuf.CommunityTokenType) {
community, _ := s.createCommunity() community, _ := s.createCommunity()
amountInWei := "100000000000000000000"
decimals := uint64(18)
if tokenType == protobuf.CommunityTokenType_ERC721 {
amountInWei = "1"
decimals = 0
}
createTokenPermission := &requests.CreateCommunityTokenPermission{ createTokenPermission := &requests.CreateCommunityTokenPermission{
CommunityID: community.ID(), CommunityID: community.ID(),
Type: permissionType, Type: permissionType,
@ -1612,8 +1654,8 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) testReevaluateMemberPrivileg
Type: tokenType, Type: tokenType,
ContractAddresses: map[uint64]string{testChainID1: "0x123"}, ContractAddresses: map[uint64]string{testChainID1: "0x123"},
Symbol: "TEST", Symbol: "TEST",
AmountInWei: "100000000000000000000", AmountInWei: amountInWei,
Decimals: uint64(18), Decimals: decimals,
}, },
}, },
} }
@ -1723,6 +1765,13 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) TestReevaluateMemberTokenMas
func (s *MessengerCommunitiesTokenPermissionsSuite) testReevaluateMemberPrivilegedRoleInClosedCommunity(permissionType protobuf.CommunityTokenPermission_Type, tokenType protobuf.CommunityTokenType) { func (s *MessengerCommunitiesTokenPermissionsSuite) testReevaluateMemberPrivilegedRoleInClosedCommunity(permissionType protobuf.CommunityTokenPermission_Type, tokenType protobuf.CommunityTokenType) {
community, _ := s.createCommunity() community, _ := s.createCommunity()
amountInWei := "100000000000000000000"
decimals := uint64(18)
if tokenType == protobuf.CommunityTokenType_ERC721 {
amountInWei = "1"
decimals = 0
}
createTokenPermission := &requests.CreateCommunityTokenPermission{ createTokenPermission := &requests.CreateCommunityTokenPermission{
CommunityID: community.ID(), CommunityID: community.ID(),
Type: permissionType, Type: permissionType,
@ -1731,8 +1780,8 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) testReevaluateMemberPrivileg
Type: tokenType, Type: tokenType,
ContractAddresses: map[uint64]string{testChainID1: "0x123"}, ContractAddresses: map[uint64]string{testChainID1: "0x123"},
Symbol: "TEST", Symbol: "TEST",
AmountInWei: "100000000000000000000", AmountInWei: amountInWei,
Decimals: uint64(18), Decimals: decimals,
}, },
}, },
} }
@ -1751,8 +1800,8 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) testReevaluateMemberPrivileg
Type: tokenType, Type: tokenType,
ContractAddresses: map[uint64]string{testChainID1: "0x124"}, ContractAddresses: map[uint64]string{testChainID1: "0x124"},
Symbol: "TEST2", Symbol: "TEST2",
AmountInWei: "100000000000000000000", AmountInWei: amountInWei,
Decimals: uint64(18), Decimals: decimals,
}, },
}, },
} }