From 88c671fcf07f0c2f77b0b1675768114c57e3ef62 Mon Sep 17 00:00:00 2001 From: Patryk Osmaczko Date: Thu, 13 Jun 2024 15:26:52 +0200 Subject: [PATCH] fix(communities)_: correct >1 NFT token requirement evaluation Fixed logic to respect specified NFT quantities. Previously, holding one NFT sufficed, regardless of the required count. fixes: status-im/status-desktop#15122 --- protocol/communities/manager_test.go | 7 +- protocol/communities/permission_checker.go | 295 ++++++++++-------- .../communities/permission_checker_test.go | 158 ++++++++++ .../communities_messenger_helpers_test.go | 39 +-- ...nities_messenger_token_permissions_test.go | 69 +++- 5 files changed, 392 insertions(+), 176 deletions(-) diff --git a/protocol/communities/manager_test.go b/protocol/communities/manager_test.go index f3122e530..f6e7bd02e 100644 --- a/protocol/communities/manager_test.go +++ b/protocol/communities/manager_test.go @@ -296,15 +296,16 @@ func (s *ManagerSuite) TestRetrieveCollectibles() { var tokenBalances []thirdparty.TokenBalance var tokenCriteria = []*protobuf.TokenCriteria{ - &protobuf.TokenCriteria{ + { ContractAddresses: contractAddresses, TokenIds: []uint64{tokenID}, Type: protobuf.CommunityTokenType_ERC721, + AmountInWei: "1", }, } var permissions = []*CommunityTokenPermission{ - &CommunityTokenPermission{ + { CommunityTokenPermission: &protobuf.CommunityTokenPermission{ Id: "some-id", Type: protobuf.CommunityTokenPermission_BECOME_MEMBER, @@ -316,7 +317,7 @@ func (s *ManagerSuite) TestRetrieveCollectibles() { preParsedPermissions := preParsedCommunityPermissionsData(permissions) accountChainIDsCombination := []*AccountChainIDsCombination{ - &AccountChainIDsCombination{ + { Address: gethcommon.HexToAddress("0xD6b912e09E797D291E8D0eA3D3D17F8000e01c32"), ChainIDs: []uint64{chainID}, }, diff --git a/protocol/communities/permission_checker.go b/protocol/communities/permission_checker.go index 8c5f63191..a5a955734 100644 --- a/protocol/communities/permission_checker.go +++ b/protocol/communities/permission_checker.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math/big" + "strconv" "strings" "go.uber.org/zap" @@ -221,6 +222,158 @@ func (p *DefaultPermissionChecker) checkPermissionsOrDefault(permissions []*Comm type ownedERC721TokensGetter = func(walletAddresses []gethcommon.Address, tokenRequirements map[uint64]map[string]*protobuf.TokenCriteria, chainIDs []uint64) (CollectiblesByChain, error) type balancesByChainGetter = func(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) +func (p *DefaultPermissionChecker) checkTokenRequirement( + tokenRequirement *protobuf.TokenCriteria, + accounts []gethcommon.Address, ownedERC20TokenBalances BalancesByChain, ownedERC721Tokens CollectiblesByChain, + accountsChainIDsCombinations map[gethcommon.Address]map[uint64]bool, +) (TokenRequirementResponse, error) { + tokenRequirementResponse := TokenRequirementResponse{TokenCriteria: tokenRequirement} + + switch tokenRequirement.Type { + + case protobuf.CommunityTokenType_ERC721: + + if len(ownedERC721Tokens) == 0 { + return tokenRequirementResponse, nil + } + + // Limit NFTs count to uint32 + requiredCount, err := strconv.ParseUint(tokenRequirement.AmountInWei, 10, 32) + if err != nil { + return tokenRequirementResponse, fmt.Errorf("invalid ERC721 amount: %s", tokenRequirement.AmountInWei) + } + accumulatedCount := uint64(0) + + for chainID, addressStr := range tokenRequirement.ContractAddresses { + contractAddress := gethcommon.HexToAddress(addressStr) + if _, exists := ownedERC721Tokens[chainID]; !exists || len(ownedERC721Tokens[chainID]) == 0 { + continue + } + + for account := range ownedERC721Tokens[chainID] { + if _, exists := ownedERC721Tokens[chainID][account]; !exists { + continue + } + + tokenBalances := ownedERC721Tokens[chainID][account][contractAddress] + accumulatedCount += uint64(len(tokenBalances)) + + if len(tokenBalances) > 0 { + // 'account' owns some TokenID owned from contract 'address' + if _, exists := accountsChainIDsCombinations[account]; !exists { + accountsChainIDsCombinations[account] = make(map[uint64]bool) + } + + // account has balance > 0 on this chain for this token, so let's add it the chain IDs + accountsChainIDsCombinations[account][chainID] = true + + if len(tokenRequirement.TokenIds) == 0 { + // no specific tokenId of this collection is needed + + if accumulatedCount >= requiredCount { + tokenRequirementResponse.Satisfied = true + return tokenRequirementResponse, nil + } + } + + for _, tokenID := range tokenRequirement.TokenIds { + tokenIDBigInt := new(big.Int).SetUint64(tokenID) + + for _, asset := range tokenBalances { + if asset.TokenID.Cmp(tokenIDBigInt) == 0 && asset.Balance.Sign() > 0 { + tokenRequirementResponse.Satisfied = true + return tokenRequirementResponse, nil + } + } + } + } + } + } + + case protobuf.CommunityTokenType_ERC20: + + if len(ownedERC20TokenBalances) == 0 { + return tokenRequirementResponse, nil + } + + accumulatedBalance := new(big.Int) + + chainIDLoopERC20: + for chainID, address := range tokenRequirement.ContractAddresses { + if _, exists := ownedERC20TokenBalances[chainID]; !exists || len(ownedERC20TokenBalances[chainID]) == 0 { + continue chainIDLoopERC20 + } + contractAddress := gethcommon.HexToAddress(address) + for account := range ownedERC20TokenBalances[chainID] { + if _, exists := ownedERC20TokenBalances[chainID][account][contractAddress]; !exists { + continue + } + + value := ownedERC20TokenBalances[chainID][account][contractAddress] + + if _, exists := accountsChainIDsCombinations[account]; !exists { + accountsChainIDsCombinations[account] = make(map[uint64]bool) + } + + if value.ToInt().Cmp(big.NewInt(0)) > 0 { + // account has balance > 0 on this chain for this token, so let's add it the chain IDs + accountsChainIDsCombinations[account][chainID] = true + } + + // check if adding current chain account balance to accumulated balance + // satisfies required amount + prevBalance := accumulatedBalance + accumulatedBalance.Add(prevBalance, value.ToInt()) + + requiredAmount, success := new(big.Int).SetString(tokenRequirement.AmountInWei, 10) + if !success { + return tokenRequirementResponse, fmt.Errorf("amountInWeis value is incorrect: %s", tokenRequirement.AmountInWei) + } + + if accumulatedBalance.Cmp(requiredAmount) != -1 { + tokenRequirementResponse.Satisfied = true + return tokenRequirementResponse, nil + } + } + } + + case protobuf.CommunityTokenType_ENS: + + for _, account := range accounts { + ownedENSNames, err := p.getOwnedENS([]gethcommon.Address{account}) + if err != nil { + return tokenRequirementResponse, err + } + + if _, exists := accountsChainIDsCombinations[account]; !exists { + accountsChainIDsCombinations[account] = make(map[uint64]bool) + } + + if !strings.HasPrefix(tokenRequirement.EnsPattern, "*.") { + for _, ownedENS := range ownedENSNames { + if ownedENS == tokenRequirement.EnsPattern { + accountsChainIDsCombinations[account][walletcommon.EthereumMainnet] = true + tokenRequirementResponse.Satisfied = true + return tokenRequirementResponse, nil + } + } + } else { + parentName := tokenRequirement.EnsPattern[2:] + for _, ownedENS := range ownedENSNames { + if strings.HasSuffix(ownedENS, parentName) { + accountsChainIDsCombinations[account][walletcommon.EthereumMainnet] = true + tokenRequirementResponse.Satisfied = true + return tokenRequirementResponse, nil + } + } + } + } + + } + + return tokenRequirementResponse, nil +} + func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PreParsedCommunityPermissionsData, accountsAndChainIDs []*AccountChainIDsCombination, shortcircuit bool, getOwnedERC721Tokens ownedERC721TokensGetter, getBalancesByChain balancesByChainGetter) (*CheckPermissionsResponse, error) { @@ -281,7 +434,6 @@ func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PrePa accountsChainIDsCombinations := make(map[gethcommon.Address]map[uint64]bool) for _, tokenPermission := range permissionsParsedData.Permissions { - permissionRequirementsMet := true response.Permissions[tokenPermission.Id] = &PermissionTokenCriteriaResult{Role: tokenPermission.Type} @@ -289,146 +441,17 @@ func (p *DefaultPermissionChecker) checkPermissions(permissionsParsedData *PrePa // If only one is not met, the entire permission is marked // as not fulfilled for _, tokenRequirement := range tokenPermission.TokenCriteria { - - tokenRequirementMet := false - tokenRequirementResponse := TokenRequirementResponse{TokenCriteria: tokenRequirement} - - if tokenRequirement.Type == protobuf.CommunityTokenType_ERC721 { - if len(ownedERC721Tokens) == 0 { - - response.Permissions[tokenPermission.Id].TokenRequirements = append(response.Permissions[tokenPermission.Id].TokenRequirements, tokenRequirementResponse) - response.Permissions[tokenPermission.Id].Criteria = append(response.Permissions[tokenPermission.Id].Criteria, false) - continue - } - - chainIDLoopERC721: - for chainID, addressStr := range tokenRequirement.ContractAddresses { - contractAddress := gethcommon.HexToAddress(addressStr) - if _, exists := ownedERC721Tokens[chainID]; !exists || len(ownedERC721Tokens[chainID]) == 0 { - continue chainIDLoopERC721 - } - - for account := range ownedERC721Tokens[chainID] { - if _, exists := ownedERC721Tokens[chainID][account]; !exists { - continue - } - - tokenBalances := ownedERC721Tokens[chainID][account][contractAddress] - if len(tokenBalances) > 0 { - // 'account' owns some TokenID owned from contract 'address' - if _, exists := accountsChainIDsCombinations[account]; !exists { - accountsChainIDsCombinations[account] = make(map[uint64]bool) - } - - if len(tokenRequirement.TokenIds) == 0 { - // no specific tokenId of this collection is needed - tokenRequirementMet = true - accountsChainIDsCombinations[account][chainID] = true - break chainIDLoopERC721 - } - - tokenIDsLoop: - for _, tokenID := range tokenRequirement.TokenIds { - tokenIDBigInt := new(big.Int).SetUint64(tokenID) - - for _, asset := range tokenBalances { - if asset.TokenID.Cmp(tokenIDBigInt) == 0 && asset.Balance.Sign() > 0 { - tokenRequirementMet = true - accountsChainIDsCombinations[account][chainID] = true - break tokenIDsLoop - } - } - } - } - } - } - } else if tokenRequirement.Type == protobuf.CommunityTokenType_ERC20 { - if len(ownedERC20TokenBalances) == 0 { - response.Permissions[tokenPermission.Id].TokenRequirements = append(response.Permissions[tokenPermission.Id].TokenRequirements, tokenRequirementResponse) - response.Permissions[tokenPermission.Id].Criteria = append(response.Permissions[tokenPermission.Id].Criteria, false) - continue - } - - accumulatedBalance := new(big.Int) - - chainIDLoopERC20: - for chainID, address := range tokenRequirement.ContractAddresses { - if _, exists := ownedERC20TokenBalances[chainID]; !exists || len(ownedERC20TokenBalances[chainID]) == 0 { - continue chainIDLoopERC20 - } - contractAddress := gethcommon.HexToAddress(address) - for account := range ownedERC20TokenBalances[chainID] { - if _, exists := ownedERC20TokenBalances[chainID][account][contractAddress]; !exists { - continue - } - - value := ownedERC20TokenBalances[chainID][account][contractAddress] - - if _, exists := accountsChainIDsCombinations[account]; !exists { - accountsChainIDsCombinations[account] = make(map[uint64]bool) - } - - if value.ToInt().Cmp(big.NewInt(0)) > 0 { - // account has balance > 0 on this chain for this token, so let's add it the chain IDs - accountsChainIDsCombinations[account][chainID] = true - } - - // check if adding current chain account balance to accumulated balance - // satisfies required amount - prevBalance := accumulatedBalance - accumulatedBalance.Add(prevBalance, value.ToInt()) - - requiredAmount, success := new(big.Int).SetString(tokenRequirement.AmountInWei, 10) - if !success { - return nil, fmt.Errorf("amountInWeis value is incorrect: %s", tokenRequirement.AmountInWei) - } - - if accumulatedBalance.Cmp(requiredAmount) != -1 { - tokenRequirementMet = true - if shortcircuit { - break chainIDLoopERC20 - } - } - } - } - - } else if tokenRequirement.Type == protobuf.CommunityTokenType_ENS { - - for _, account := range accounts { - ownedENSNames, err := p.getOwnedENS([]gethcommon.Address{account}) - if err != nil { - return nil, err - } - - if _, exists := accountsChainIDsCombinations[account]; !exists { - accountsChainIDsCombinations[account] = make(map[uint64]bool) - } - - if !strings.HasPrefix(tokenRequirement.EnsPattern, "*.") { - for _, ownedENS := range ownedENSNames { - if ownedENS == tokenRequirement.EnsPattern { - tokenRequirementMet = true - accountsChainIDsCombinations[account][walletcommon.EthereumMainnet] = true - } - } - } else { - parentName := tokenRequirement.EnsPattern[2:] - for _, ownedENS := range ownedENSNames { - if strings.HasSuffix(ownedENS, parentName) { - tokenRequirementMet = true - accountsChainIDsCombinations[account][walletcommon.EthereumMainnet] = true - } - } - } - } + tokenRequirementResponse, err := p.checkTokenRequirement(tokenRequirement, accounts, ownedERC20TokenBalances, ownedERC721Tokens, accountsChainIDsCombinations) + if err != nil { + p.logger.Error("failed to check token requirement", zap.Error(err)) } - if !tokenRequirementMet { + + if !tokenRequirementResponse.Satisfied { permissionRequirementsMet = false } - tokenRequirementResponse.Satisfied = tokenRequirementMet response.Permissions[tokenPermission.Id].TokenRequirements = append(response.Permissions[tokenPermission.Id].TokenRequirements, tokenRequirementResponse) - response.Permissions[tokenPermission.Id].Criteria = append(response.Permissions[tokenPermission.Id].Criteria, tokenRequirementMet) + response.Permissions[tokenPermission.Id].Criteria = append(response.Permissions[tokenPermission.Id].Criteria, tokenRequirementResponse.Satisfied) } response.Permissions[tokenPermission.Id].ID = tokenPermission.Id diff --git a/protocol/communities/permission_checker_test.go b/protocol/communities/permission_checker_test.go index c3b785ae3..c3a93ccf2 100644 --- a/protocol/communities/permission_checker_test.go +++ b/protocol/communities/permission_checker_test.go @@ -1,11 +1,21 @@ package communities import ( + "context" + "errors" + "fmt" + "math/big" + "strconv" "testing" "github.com/stretchr/testify/suite" + "github.com/status-im/status-go/protocol/protobuf" + "github.com/status-im/status-go/services/wallet/bigint" + "github.com/status-im/status-go/services/wallet/thirdparty" + gethcommon "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" ) func TestPermissionCheckerSuite(t *testing.T) { @@ -56,3 +66,151 @@ func (s *PermissionCheckerSuite) TestMergeValidCombinations() { } } + +func (s *PermissionCheckerSuite) TestCheckPermissions() { + testCases := []struct { + name string + amountInWei func(t protobuf.CommunityTokenType) string + requiredAmountInWei func(t protobuf.CommunityTokenType) string + shouldSatisfy bool + }{ + { + name: "account does not meet criteria", + amountInWei: func(t protobuf.CommunityTokenType) string { + if t == protobuf.CommunityTokenType_ERC721 { + return "1" + } + return "1000000000000000000" + }, + requiredAmountInWei: func(t protobuf.CommunityTokenType) string { + if t == protobuf.CommunityTokenType_ERC721 { + return "2" + } + return "2000000000000000000" + }, + shouldSatisfy: false, + }, + { + name: "account does exactly meet criteria", + amountInWei: func(t protobuf.CommunityTokenType) string { + if t == protobuf.CommunityTokenType_ERC721 { + return "2" + } + return "2000000000000000000" + }, + requiredAmountInWei: func(t protobuf.CommunityTokenType) string { + if t == protobuf.CommunityTokenType_ERC721 { + return "2" + } + return "2000000000000000000" + }, + shouldSatisfy: true, + }, + { + name: "account does meet criteria", + amountInWei: func(t protobuf.CommunityTokenType) string { + if t == protobuf.CommunityTokenType_ERC721 { + return "3" + } + return "3000000000000000000" + }, + requiredAmountInWei: func(t protobuf.CommunityTokenType) string { + if t == protobuf.CommunityTokenType_ERC721 { + return "2" + } + return "2000000000000000000" + }, + shouldSatisfy: true, + }, + } + + permissionChecker := DefaultPermissionChecker{} + chainID := uint64(1) + contractAddress := gethcommon.HexToAddress("0x3d6afaa395c31fcd391fe3d562e75fe9e8ec7e6a") + walletAddress := gethcommon.HexToAddress("0xD6b912e09E797D291E8D0eA3D3D17F8000e01c32") + + for _, tc := range testCases { + for _, tokenType := range [](protobuf.CommunityTokenType){protobuf.CommunityTokenType_ERC20, protobuf.CommunityTokenType_ERC721} { + s.Run(fmt.Sprintf("%s_%s", tc.name, tokenType.String()), func() { + decimals := uint64(0) + if tokenType == protobuf.CommunityTokenType_ERC20 { + decimals = 18 + } + permissions := map[string]*CommunityTokenPermission{ + "p1": { + CommunityTokenPermission: &protobuf.CommunityTokenPermission{ + Id: "p1", + Type: protobuf.CommunityTokenPermission_BECOME_MEMBER, + TokenCriteria: []*protobuf.TokenCriteria{ + { + ContractAddresses: map[uint64]string{ + chainID: contractAddress.String(), + }, + Type: tokenType, + Symbol: "STT", + TokenIds: []uint64{}, + Decimals: decimals, + AmountInWei: tc.requiredAmountInWei(tokenType), + }, + }, + }, + }, + } + + permissionsData, _ := PreParsePermissionsData(permissions) + accountsAndChainIDs := []*AccountChainIDsCombination{ + { + Address: walletAddress, + ChainIDs: []uint64{chainID}, + }, + } + + var getOwnedERC721Tokens ownedERC721TokensGetter = func(walletAddresses []gethcommon.Address, tokenRequirements map[uint64]map[string]*protobuf.TokenCriteria, chainIDs []uint64) (CollectiblesByChain, error) { + amount, err := strconv.ParseUint(tc.amountInWei(protobuf.CommunityTokenType_ERC721), 10, 64) + if err != nil { + return nil, err + } + + balances := []thirdparty.TokenBalance{} + for i := uint64(0); i < amount; i++ { + balances = append(balances, thirdparty.TokenBalance{ + TokenID: &bigint.BigInt{ + Int: new(big.Int).SetUint64(i + 1), + }, + Balance: &bigint.BigInt{ + Int: new(big.Int).SetUint64(1), + }, + }) + } + + return CollectiblesByChain{ + chainID: { + walletAddress: { + contractAddress: balances, + }, + }, + }, nil + } + + var getBalancesByChain balancesByChainGetter = func(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) { + balance, ok := new(big.Int).SetString(tc.amountInWei(protobuf.CommunityTokenType_ERC20), 10) + if !ok { + return nil, errors.New("invalid conversion") + } + + return BalancesByChain{ + chainID: { + walletAddress: { + contractAddress: (*hexutil.Big)(balance), + }, + }, + }, nil + } + + response, err := permissionChecker.checkPermissions(permissionsData[protobuf.CommunityTokenPermission_BECOME_MEMBER], accountsAndChainIDs, true, getOwnedERC721Tokens, getBalancesByChain) + s.Require().NoError(err) + s.Require().Equal(tc.shouldSatisfy, response.Satisfied) + }) + } + } +} diff --git a/protocol/communities_messenger_helpers_test.go b/protocol/communities_messenger_helpers_test.go index be4437182..0af773d9f 100644 --- a/protocol/communities_messenger_helpers_test.go +++ b/protocol/communities_messenger_helpers_test.go @@ -5,7 +5,6 @@ import ( "crypto/ecdsa" "encoding/json" "errors" - "math/big" "sync" "time" @@ -25,7 +24,6 @@ import ( "github.com/status-im/status-go/protocol/communities/token" "github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/requests" - "github.com/status-im/status-go/services/wallet/bigint" walletCommon "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/services/wallet/thirdparty" walletToken "github.com/status-im/status-go/services/wallet/token" @@ -55,7 +53,7 @@ func (m *AccountManagerMock) DeleteAccount(address types.Address) error { } type TokenManagerMock struct { - Balances *map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big + Balances *communities.BalancesByChain } func (m *TokenManagerMock) GetAllChainIDs() ([]uint64, error) { @@ -82,7 +80,7 @@ func (m *TokenManagerMock) FindOrCreateTokenByAddress(ctx context.Context, chain } type CollectiblesManagerMock struct { - Balances *map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big + Collectibles *communities.CollectiblesByChain collectibleOwnershipResponse map[string][]thirdparty.AccountBalance } @@ -94,7 +92,7 @@ func (m *CollectiblesManagerMock) FetchCachedBalancesByOwnerAndContractAddress(c func (m *CollectiblesManagerMock) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) { ret := make(thirdparty.TokenBalancesPerContractAddress) - accountsBalances, ok := (*m.Balances)[uint64(chainID)] + accountsBalances, ok := (*m.Collectibles)[uint64(chainID)] if !ok { return ret, nil } @@ -107,14 +105,7 @@ func (m *CollectiblesManagerMock) FetchBalancesByOwnerAndContractAddress(ctx con for _, contractAddress := range contractAddresses { balance, ok := balances[contractAddress] if ok { - ret[contractAddress] = []thirdparty.TokenBalance{ - { - TokenID: &bigint.BigInt{}, - Balance: &bigint.BigInt{ - Int: (*big.Int)(balance), - }, - }, - } + ret[contractAddress] = balance } } @@ -135,24 +126,17 @@ func (m *CollectiblesManagerMock) FetchCollectibleOwnersByContractAddress(ctx co ContractAddress: contractAddress, Owners: []thirdparty.CollectibleOwner{}, } - accountsBalances, ok := (*m.Balances)[uint64(chainID)] + accountsBalances, ok := (*m.Collectibles)[uint64(chainID)] if !ok { return ret, nil } - for wallet, collectiblesBalance := range accountsBalances { - balance, ok := collectiblesBalance[contractAddress] + for wallet, balances := range accountsBalances { + balance, ok := balances[contractAddress] if ok { ret.Owners = append(ret.Owners, thirdparty.CollectibleOwner{ - OwnerAddress: wallet, - TokenBalances: []thirdparty.TokenBalance{ - { - TokenID: &bigint.BigInt{}, - Balance: &bigint.BigInt{ - Int: (*big.Int)(balance), - }, - }, - }, + OwnerAddress: wallet, + TokenBalances: balance, }) } } @@ -256,7 +240,8 @@ type testCommunitiesMessengerConfig struct { password string walletAddresses []string - mockedBalances *map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big + mockedBalances *communities.BalancesByChain + mockedCollectibles *communities.CollectiblesByChain collectiblesService communities.CommunityTokensServiceInterface } @@ -323,7 +308,7 @@ func newTestCommunitiesMessenger(s *suite.Suite, waku types.Waku, config testCom } collectiblesManagerMock := &CollectiblesManagerMock{ - Balances: config.mockedBalances, + Collectibles: config.mockedCollectibles, } options := []Option{ diff --git a/protocol/communities_messenger_token_permissions_test.go b/protocol/communities_messenger_token_permissions_test.go index 7dbf3593f..563cc0871 100644 --- a/protocol/communities_messenger_token_permissions_test.go +++ b/protocol/communities_messenger_token_permissions_test.go @@ -32,6 +32,8 @@ import ( "github.com/status-im/status-go/protocol/requests" "github.com/status-im/status-go/protocol/transport" "github.com/status-im/status-go/protocol/tt" + "github.com/status-im/status-go/services/wallet/bigint" + "github.com/status-im/status-go/services/wallet/thirdparty" ) const testChainID1 = 1 @@ -145,7 +147,8 @@ type MessengerCommunitiesTokenPermissionsSuite struct { logger *zap.Logger - mockedBalances map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big // chainID, account, token, balance + mockedBalances communities.BalancesByChain + mockedCollectibles communities.CollectiblesByChain collectiblesServiceMock *CollectiblesServiceMock } @@ -212,6 +215,7 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) newMessenger(password string password: password, walletAddresses: walletAddresses, mockedBalances: &s.mockedBalances, + mockedCollectibles: &s.mockedCollectibles, collectiblesService: s.collectiblesServiceMock, }) } @@ -245,10 +249,35 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) sendChatMessage(sender *Mess func (s *MessengerCommunitiesTokenPermissionsSuite) makeAddressSatisfyTheCriteria(chainID uint64, address string, criteria *protobuf.TokenCriteria) { walletAddress := gethcommon.HexToAddress(address) contractAddress := gethcommon.HexToAddress(criteria.ContractAddresses[chainID]) - balance, ok := new(big.Int).SetString(criteria.AmountInWei, 10) - s.Require().True(ok) - s.mockedBalances[chainID][walletAddress][contractAddress] = (*hexutil.Big)(balance) + switch criteria.Type { + case protobuf.CommunityTokenType_ERC20: + balance, ok := new(big.Int).SetString(criteria.AmountInWei, 10) + s.Require().True(ok) + + s.mockedBalances[chainID][walletAddress][contractAddress] = (*hexutil.Big)(balance) + + case protobuf.CommunityTokenType_ERC721: + amount, err := strconv.ParseUint(criteria.AmountInWei, 10, 32) + s.Require().NoError(err) + + balances := []thirdparty.TokenBalance{} + for i := uint64(0); i < amount; i++ { + balances = append(balances, thirdparty.TokenBalance{ + TokenID: &bigint.BigInt{ + Int: new(big.Int).SetUint64(i + 1), + }, + Balance: &bigint.BigInt{ + Int: new(big.Int).SetUint64(1), + }, + }) + } + + s.mockedCollectibles[chainID][walletAddress][contractAddress] = balances + + case protobuf.CommunityTokenType_ENS: + // not implemented + } } func (s *MessengerCommunitiesTokenPermissionsSuite) resetMockedBalances() { @@ -257,6 +286,12 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) resetMockedBalances() { s.mockedBalances[testChainID1][gethcommon.HexToAddress(aliceAddress1)] = make(map[gethcommon.Address]*hexutil.Big) s.mockedBalances[testChainID1][gethcommon.HexToAddress(aliceAddress2)] = make(map[gethcommon.Address]*hexutil.Big) s.mockedBalances[testChainID1][gethcommon.HexToAddress(bobAddress)] = make(map[gethcommon.Address]*hexutil.Big) + + s.mockedCollectibles = make(communities.CollectiblesByChain) + s.mockedCollectibles[testChainID1] = make(map[gethcommon.Address]thirdparty.TokenBalancesPerContractAddress) + s.mockedCollectibles[testChainID1][gethcommon.HexToAddress(aliceAddress1)] = make(thirdparty.TokenBalancesPerContractAddress) + s.mockedCollectibles[testChainID1][gethcommon.HexToAddress(aliceAddress2)] = make(thirdparty.TokenBalancesPerContractAddress) + s.mockedCollectibles[testChainID1][gethcommon.HexToAddress(bobAddress)] = make(thirdparty.TokenBalancesPerContractAddress) } func (s *MessengerCommunitiesTokenPermissionsSuite) waitOnKeyDistribution(condition func(*CommunityAndKeyActions) bool) <-chan error { @@ -1604,6 +1639,13 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) TestMemberRoleGetUpdatedWhen func (s *MessengerCommunitiesTokenPermissionsSuite) testReevaluateMemberPrivilegedRoleInOpenCommunity(permissionType protobuf.CommunityTokenPermission_Type, tokenType protobuf.CommunityTokenType) { community, _ := s.createCommunity() + amountInWei := "100000000000000000000" + decimals := uint64(18) + if tokenType == protobuf.CommunityTokenType_ERC721 { + amountInWei = "1" + decimals = 0 + } + createTokenPermission := &requests.CreateCommunityTokenPermission{ CommunityID: community.ID(), Type: permissionType, @@ -1612,8 +1654,8 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) testReevaluateMemberPrivileg Type: tokenType, ContractAddresses: map[uint64]string{testChainID1: "0x123"}, Symbol: "TEST", - AmountInWei: "100000000000000000000", - Decimals: uint64(18), + AmountInWei: amountInWei, + Decimals: decimals, }, }, } @@ -1723,6 +1765,13 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) TestReevaluateMemberTokenMas func (s *MessengerCommunitiesTokenPermissionsSuite) testReevaluateMemberPrivilegedRoleInClosedCommunity(permissionType protobuf.CommunityTokenPermission_Type, tokenType protobuf.CommunityTokenType) { community, _ := s.createCommunity() + amountInWei := "100000000000000000000" + decimals := uint64(18) + if tokenType == protobuf.CommunityTokenType_ERC721 { + amountInWei = "1" + decimals = 0 + } + createTokenPermission := &requests.CreateCommunityTokenPermission{ CommunityID: community.ID(), Type: permissionType, @@ -1731,8 +1780,8 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) testReevaluateMemberPrivileg Type: tokenType, ContractAddresses: map[uint64]string{testChainID1: "0x123"}, Symbol: "TEST", - AmountInWei: "100000000000000000000", - Decimals: uint64(18), + AmountInWei: amountInWei, + Decimals: decimals, }, }, } @@ -1751,8 +1800,8 @@ func (s *MessengerCommunitiesTokenPermissionsSuite) testReevaluateMemberPrivileg Type: tokenType, ContractAddresses: map[uint64]string{testChainID1: "0x124"}, Symbol: "TEST2", - AmountInWei: "100000000000000000000", - Decimals: uint64(18), + AmountInWei: amountInWei, + Decimals: decimals, }, }, }