From 25d8c52dd5d46873a04022b9d732ca3f42bb9d95 Mon Sep 17 00:00:00 2001 From: Dario Gabriel Lipicar Date: Wed, 1 Nov 2023 16:09:10 -0300 Subject: [PATCH] feat: implement generalized collectibles filter --- services/wallet/api.go | 5 +- .../collectibles/collectible_data_db_test.go | 11 +- services/wallet/collectibles/filter.go | 98 ++++++++++ services/wallet/collectibles/filter.sql | 49 +++++ services/wallet/collectibles/filter_test.go | 184 ++++++++++++++++++ services/wallet/collectibles/service.go | 11 +- 6 files changed, 348 insertions(+), 10 deletions(-) create mode 100644 services/wallet/collectibles/filter.go create mode 100644 services/wallet/collectibles/filter.sql create mode 100644 services/wallet/collectibles/filter_test.go diff --git a/services/wallet/api.go b/services/wallet/api.go index 885f1ff8b..356013bc8 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -19,6 +19,7 @@ import ( "github.com/status-im/status-go/rpc/network" "github.com/status-im/status-go/services/wallet/activity" "github.com/status-im/status-go/services/wallet/bridge" + "github.com/status-im/status-go/services/wallet/collectibles" wcommon "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/services/wallet/currency" "github.com/status-im/status-go/services/wallet/history" @@ -347,10 +348,10 @@ func (api *API) RefetchOwnedCollectibles() error { return nil } -func (api *API) FilterOwnedCollectiblesAsync(requestID int32, chainIDs []wcommon.ChainID, addresses []common.Address, offset int, limit int) error { +func (api *API) FilterOwnedCollectiblesAsync(requestID int32, chainIDs []wcommon.ChainID, addresses []common.Address, filter collectibles.Filter, offset int, limit int) error { log.Debug("wallet.api.FilterOwnedCollectiblesAsync", "chainIDs.count", len(chainIDs), "addr.count", len(addresses), "offset", offset, "limit", limit) - api.s.collectibles.FilterOwnedCollectiblesAsync(requestID, chainIDs, addresses, offset, limit) + api.s.collectibles.FilterOwnedCollectiblesAsync(requestID, chainIDs, addresses, filter, offset, limit) return nil } diff --git a/services/wallet/collectibles/collectible_data_db_test.go b/services/wallet/collectibles/collectible_data_db_test.go index 15bad9ec1..dd3c24fa4 100644 --- a/services/wallet/collectibles/collectible_data_db_test.go +++ b/services/wallet/collectibles/collectible_data_db_test.go @@ -28,11 +28,11 @@ func setupCollectibleDataDBTest(t *testing.T) (*CollectibleDataDB, func()) { func generateTestCollectiblesData(count int) (result []thirdparty.CollectibleData) { result = make([]thirdparty.CollectibleData, 0, count) for i := 0; i < count; i++ { - bigI := big.NewInt(int64(count)) + bigI := big.NewInt(int64(i)) newCollectible := thirdparty.CollectibleData{ ID: thirdparty.CollectibleUniqueID{ ContractID: thirdparty.ContractID{ - ChainID: w_common.ChainID(i), + ChainID: w_common.ChainID(i % 4), Address: common.BigToAddress(bigI), }, TokenID: &bigint.BigInt{Int: bigI}, @@ -66,7 +66,10 @@ func generateTestCollectiblesData(count int) (result []thirdparty.CollectibleDat }, BackgroundColor: fmt.Sprintf("backgroundcolor-%d", i), TokenURI: fmt.Sprintf("tokenuri-%d", i), - CommunityID: fmt.Sprintf("communityid-%d", i), + CommunityID: fmt.Sprintf("communityid-%d", i%5), + } + if i%5 == 0 { + newCollectible.CommunityID = "" } result = append(result, newCollectible) } @@ -77,7 +80,7 @@ func generateTestCommunityData(count int) []thirdparty.CollectibleCommunityInfo result := make([]thirdparty.CollectibleCommunityInfo, 0, count) for i := 0; i < count; i++ { newCommunityInfo := thirdparty.CollectibleCommunityInfo{ - PrivilegesLevel: token.PrivilegesLevel(i), + PrivilegesLevel: token.PrivilegesLevel(i) % (token.CommunityLevel + 1), } result = append(result, newCommunityInfo) } diff --git a/services/wallet/collectibles/filter.go b/services/wallet/collectibles/filter.go new file mode 100644 index 000000000..83badee7f --- /dev/null +++ b/services/wallet/collectibles/filter.go @@ -0,0 +1,98 @@ +package collectibles + +import ( + "context" + "database/sql" + "errors" + + // used for embedding the sql query in the binary + _ "embed" + + "github.com/ethereum/go-ethereum/common" + + "github.com/jmoiron/sqlx" + + "github.com/status-im/status-go/protocol/communities/token" + wcommon "github.com/status-im/status-go/services/wallet/common" + "github.com/status-im/status-go/services/wallet/thirdparty" +) + +func allCommunityIDsFilter() []string { + return []string{} +} + +func allCommunityPrivilegesLevelsFilter() []token.PrivilegesLevel { + return []token.PrivilegesLevel{} +} + +func allFilter() Filter { + return Filter{ + CommunityIDs: allCommunityIDsFilter(), + CommunityPrivilegesLevels: allCommunityPrivilegesLevelsFilter(), + FilterCommunity: All, + } +} + +type FilterCommunityType int + +const ( + All FilterCommunityType = iota + OnlyNonCommunity + OnlyCommunity +) + +type Filter struct { + CommunityIDs []string `json:"community_ids"` + CommunityPrivilegesLevels []token.PrivilegesLevel `json:"community_privileges_levels"` + + FilterCommunity FilterCommunityType `json:"filter_community"` +} + +//go:embed filter.sql +var queryString string + +func filterOwnedCollectibles(ctx context.Context, db *sql.DB, chainIDs []wcommon.ChainID, addresses []common.Address, filter Filter, offset int, limit int) ([]thirdparty.CollectibleUniqueID, error) { + if len(addresses) == 0 { + return nil, errors.New("no addresses provided") + } + if len(chainIDs) == 0 { + return nil, errors.New("no chainIDs provided") + } + + filterCommunityTypeAll := filter.FilterCommunity == All + filterCommunityTypeOnlyNonCommunity := filter.FilterCommunity == OnlyNonCommunity + filterCommunityTypeOnlyCommunity := filter.FilterCommunity == OnlyCommunity + communityIDFilterDisabled := len(filter.CommunityIDs) == 0 + if communityIDFilterDisabled { + // IN clause doesn't work with empty array, so we need to provide a dummy value + filter.CommunityIDs = []string{""} + } + communityPrivilegesLevelDisabled := len(filter.CommunityPrivilegesLevels) == 0 + if communityPrivilegesLevelDisabled { + // IN clause doesn't work with empty array, so we need to provide a dummy value + filter.CommunityPrivilegesLevels = []token.PrivilegesLevel{token.PrivilegesLevel(0)} + } + + query, args, err := sqlx.In(queryString, + filterCommunityTypeAll, filterCommunityTypeOnlyNonCommunity, filterCommunityTypeOnlyCommunity, + communityIDFilterDisabled, communityPrivilegesLevelDisabled, + chainIDs, addresses, filter.CommunityIDs, filter.CommunityPrivilegesLevels, + limit, offset) + if err != nil { + return nil, err + } + + stmt, err := db.Prepare(query) + if err != nil { + return nil, err + } + defer stmt.Close() + + rows, err := stmt.Query(args...) + if err != nil { + return nil, err + } + defer rows.Close() + + return thirdparty.RowsToCollectibles(rows) +} diff --git a/services/wallet/collectibles/filter.sql b/services/wallet/collectibles/filter.sql new file mode 100644 index 000000000..0cafa141b --- /dev/null +++ b/services/wallet/collectibles/filter.sql @@ -0,0 +1,49 @@ +WITH filter_conditions AS ( + SELECT + ? AS filterCommunityTypeAll, + ? AS filterCommunityTypeOnlyNonCommunity, + ? AS filterCommunityTypeOnlyCommunity, + ? AS communityIDFilterDisabled, + ? AS communityPrivilegesLevelDisabled +) +SELECT + ownership.chain_id, + ownership.contract_address, + ownership.token_id +FROM + collectibles_ownership_cache ownership + LEFT JOIN collectible_data_cache data + ON ( + ownership.chain_id = data.chain_id + AND ownership.contract_address = data.contract_address + AND ownership.token_id = data.token_id + ) + CROSS JOIN filter_conditions +WHERE + ownership.chain_id IN (?) + AND ownership.owner_address IN (?) + AND ( + filterCommunityTypeAll + OR ( + filterCommunityTypeOnlyNonCommunity + AND data.community_id = "" + ) + OR ( + filterCommunityTypeOnlyCommunity + AND data.community_id <> "" + ) + ) + AND ( + communityIDFilterDisabled + OR ( + data.community_id IN (?) + ) + ) + AND ( + communityPrivilegesLevelDisabled + OR ( + data.community_privileges_level IN (?) + ) + ) +LIMIT + ? OFFSET ? \ No newline at end of file diff --git a/services/wallet/collectibles/filter_test.go b/services/wallet/collectibles/filter_test.go new file mode 100644 index 000000000..3e5464b5e --- /dev/null +++ b/services/wallet/collectibles/filter_test.go @@ -0,0 +1,184 @@ +package collectibles + +import ( + "context" + "database/sql" + "testing" + + "github.com/ethereum/go-ethereum/common" + + "github.com/status-im/status-go/protocol/communities/token" + w_common "github.com/status-im/status-go/services/wallet/common" + "github.com/status-im/status-go/services/wallet/thirdparty" + "github.com/status-im/status-go/t/helpers" + "github.com/status-im/status-go/walletdatabase" + + "github.com/stretchr/testify/require" +) + +func setupTestFilterDB(t *testing.T) (db *sql.DB, close func()) { + db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) + require.NoError(t, err) + + return db, func() { + require.NoError(t, db.Close()) + } +} + +func TestFilterOwnedCollectibles(t *testing.T) { + db, close := setupTestFilterDB(t) + defer close() + + oDB := NewOwnershipDB(db) + cDB := NewCollectibleDataDB(db) + + const nData = 50 + data := generateTestCollectiblesData(nData) + communityData := generateTestCommunityData(nData) + + ownerAddresses := []common.Address{ + common.HexToAddress("0x1234"), + common.HexToAddress("0x5678"), + common.HexToAddress("0xABCD"), + } + randomAddress := common.HexToAddress("0xFFFF") + + dataPerID := make(map[string]thirdparty.CollectibleData) + communityDataPerID := make(map[string]thirdparty.CollectibleCommunityInfo) + idsPerChainIDAndOwner := make(map[w_common.ChainID]map[common.Address][]thirdparty.CollectibleUniqueID) + + var err error + + for i := 0; i < nData; i++ { + dataPerID[data[i].ID.HashKey()] = data[i] + communityDataPerID[data[i].ID.HashKey()] = communityData[i] + + chainID := data[i].ID.ContractID.ChainID + ownerAddress := ownerAddresses[i%len(ownerAddresses)] + + if _, ok := idsPerChainIDAndOwner[chainID]; !ok { + idsPerChainIDAndOwner[chainID] = make(map[common.Address][]thirdparty.CollectibleUniqueID) + } + if _, ok := idsPerChainIDAndOwner[chainID][ownerAddress]; !ok { + idsPerChainIDAndOwner[chainID][ownerAddress] = make([]thirdparty.CollectibleUniqueID, 0, len(data)) + } + + idsPerChainIDAndOwner[chainID][ownerAddress] = append(idsPerChainIDAndOwner[chainID][ownerAddress], data[i].ID) + } + + timestamp := int64(1234567890) + + for chainID, idsPerOwner := range idsPerChainIDAndOwner { + for ownerAddress, ids := range idsPerOwner { + err = oDB.Update(chainID, ownerAddress, ids, timestamp) + require.NoError(t, err) + } + } + + err = cDB.SetData(data) + require.NoError(t, err) + for i := 0; i < nData; i++ { + err = cDB.SetCommunityInfo(data[i].ID, communityData[i]) + require.NoError(t, err) + } + + var filter Filter + var filterIDs []thirdparty.CollectibleUniqueID + var expectedIDs []thirdparty.CollectibleUniqueID + var tmpIDs []thirdparty.CollectibleUniqueID + + ctx := context.Background() + + filterChains := []w_common.ChainID{w_common.ChainID(1), w_common.ChainID(2)} + filterAddresses := []common.Address{ownerAddresses[0], ownerAddresses[1], ownerAddresses[2], randomAddress} + + // Test common case + filter = allFilter() + + tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData) + require.NoError(t, err) + + expectedIDs = tmpIDs + + filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData) + require.NoError(t, err) + require.Equal(t, expectedIDs, filterIDs) + + // Test only non-community + filter = allFilter() + filter.FilterCommunity = OnlyNonCommunity + + tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData) + require.NoError(t, err) + + expectedIDs = nil + for _, id := range tmpIDs { + if dataPerID[id.HashKey()].CommunityID == "" { + expectedIDs = append(expectedIDs, id) + } + } + + filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData) + require.NoError(t, err) + require.Equal(t, expectedIDs, filterIDs) + + // Test only community + filter = allFilter() + filter.FilterCommunity = OnlyCommunity + + tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData) + require.NoError(t, err) + + expectedIDs = nil + for _, id := range tmpIDs { + if dataPerID[id.HashKey()].CommunityID != "" { + expectedIDs = append(expectedIDs, id) + } + } + + filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData) + require.NoError(t, err) + require.Equal(t, expectedIDs, filterIDs) + + // Test specific community + communityIDa := data[0].CommunityID + communityIDb := data[1].CommunityID + communityIDs := []string{communityIDa, communityIDb} + + filter = allFilter() + filter.CommunityIDs = communityIDs + + tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData) + require.NoError(t, err) + + expectedIDs = nil + for _, id := range tmpIDs { + if dataPerID[id.HashKey()].CommunityID == communityIDa || dataPerID[id.HashKey()].CommunityID == communityIDb { + expectedIDs = append(expectedIDs, id) + } + } + + filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData) + require.NoError(t, err) + require.Equal(t, expectedIDs, filterIDs) + + // Test specific privileges level + privilegeLevel := token.PrivilegesLevel(2) + + filter = allFilter() + filter.CommunityPrivilegesLevels = []token.PrivilegesLevel{privilegeLevel} + + tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData) + require.NoError(t, err) + + expectedIDs = nil + for _, id := range tmpIDs { + if communityDataPerID[id.HashKey()].PrivilegesLevel == privilegeLevel { + expectedIDs = append(expectedIDs, id) + } + } + + filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData) + require.NoError(t, err) + require.Equal(t, expectedIDs, filterIDs) +} diff --git a/services/wallet/collectibles/service.go b/services/wallet/collectibles/service.go index 7f4a853b5..be247e007 100644 --- a/services/wallet/collectibles/service.go +++ b/services/wallet/collectibles/service.go @@ -47,6 +47,7 @@ var ( type Service struct { manager *Manager controller *Controller + db *sql.DB ownershipDB *OwnershipDB communityDB *community.DataDB walletFeed *event.Feed @@ -64,6 +65,7 @@ func NewService( return &Service{ manager: manager, controller: NewController(db, walletFeed, accountsDB, accountsFeed, settingsFeed, networkManager, manager), + db: db, ownershipDB: NewOwnershipDB(db), communityDB: community.NewDataDB(db), walletFeed: walletFeed, @@ -111,9 +113,9 @@ type filterOwnedCollectiblesTaskReturnType struct { // FilterOwnedCollectiblesResponse allows only one filter task to run at a time // and it cancels the current one if a new one is started // All calls will trigger an EventOwnedCollectiblesFilteringDone event with the result of the filtering -func (s *Service) FilterOwnedCollectiblesAsync(requestID int32, chainIDs []walletCommon.ChainID, addresses []common.Address, offset int, limit int) { +func (s *Service) FilterOwnedCollectiblesAsync(requestID int32, chainIDs []walletCommon.ChainID, addresses []common.Address, filter Filter, offset int, limit int) { s.scheduler.Enqueue(requestID, filterOwnedCollectiblesTask, func(ctx context.Context) (interface{}, error) { - collectibles, hasMore, err := s.GetOwnedCollectibles(chainIDs, addresses, offset, limit) + collectibles, hasMore, err := s.FilterOwnedCollectibles(chainIDs, addresses, filter, offset, limit) if err != nil { return nil, err } @@ -212,9 +214,10 @@ func (s *Service) sendResponseEvent(requestID *int32, eventType walletevent.Even s.walletFeed.Send(event) } -func (s *Service) GetOwnedCollectibles(chainIDs []walletCommon.ChainID, owners []common.Address, offset int, limit int) ([]thirdparty.CollectibleUniqueID, bool, error) { +func (s *Service) FilterOwnedCollectibles(chainIDs []walletCommon.ChainID, owners []common.Address, filter Filter, offset int, limit int) ([]thirdparty.CollectibleUniqueID, bool, error) { + ctx := context.Background() // Request one more than limit, to check if DB has more available - ids, err := s.ownershipDB.GetOwnedCollectibles(chainIDs, owners, offset, limit+1) + ids, err := filterOwnedCollectibles(ctx, s.db, chainIDs, owners, filter, offset, limit+1) if err != nil { return nil, false, err }