feat: implement generalized collectibles filter

This commit is contained in:
Dario Gabriel Lipicar 2023-11-01 16:09:10 -03:00 committed by dlipicar
parent c17829bf8d
commit 25d8c52dd5
6 changed files with 348 additions and 10 deletions

View File

@ -19,6 +19,7 @@ import (
"github.com/status-im/status-go/rpc/network"
"github.com/status-im/status-go/services/wallet/activity"
"github.com/status-im/status-go/services/wallet/bridge"
"github.com/status-im/status-go/services/wallet/collectibles"
wcommon "github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/services/wallet/currency"
"github.com/status-im/status-go/services/wallet/history"
@ -347,10 +348,10 @@ func (api *API) RefetchOwnedCollectibles() error {
return nil
}
func (api *API) FilterOwnedCollectiblesAsync(requestID int32, chainIDs []wcommon.ChainID, addresses []common.Address, offset int, limit int) error {
func (api *API) FilterOwnedCollectiblesAsync(requestID int32, chainIDs []wcommon.ChainID, addresses []common.Address, filter collectibles.Filter, offset int, limit int) error {
log.Debug("wallet.api.FilterOwnedCollectiblesAsync", "chainIDs.count", len(chainIDs), "addr.count", len(addresses), "offset", offset, "limit", limit)
api.s.collectibles.FilterOwnedCollectiblesAsync(requestID, chainIDs, addresses, offset, limit)
api.s.collectibles.FilterOwnedCollectiblesAsync(requestID, chainIDs, addresses, filter, offset, limit)
return nil
}

View File

@ -28,11 +28,11 @@ func setupCollectibleDataDBTest(t *testing.T) (*CollectibleDataDB, func()) {
func generateTestCollectiblesData(count int) (result []thirdparty.CollectibleData) {
result = make([]thirdparty.CollectibleData, 0, count)
for i := 0; i < count; i++ {
bigI := big.NewInt(int64(count))
bigI := big.NewInt(int64(i))
newCollectible := thirdparty.CollectibleData{
ID: thirdparty.CollectibleUniqueID{
ContractID: thirdparty.ContractID{
ChainID: w_common.ChainID(i),
ChainID: w_common.ChainID(i % 4),
Address: common.BigToAddress(bigI),
},
TokenID: &bigint.BigInt{Int: bigI},
@ -66,7 +66,10 @@ func generateTestCollectiblesData(count int) (result []thirdparty.CollectibleDat
},
BackgroundColor: fmt.Sprintf("backgroundcolor-%d", i),
TokenURI: fmt.Sprintf("tokenuri-%d", i),
CommunityID: fmt.Sprintf("communityid-%d", i),
CommunityID: fmt.Sprintf("communityid-%d", i%5),
}
if i%5 == 0 {
newCollectible.CommunityID = ""
}
result = append(result, newCollectible)
}
@ -77,7 +80,7 @@ func generateTestCommunityData(count int) []thirdparty.CollectibleCommunityInfo
result := make([]thirdparty.CollectibleCommunityInfo, 0, count)
for i := 0; i < count; i++ {
newCommunityInfo := thirdparty.CollectibleCommunityInfo{
PrivilegesLevel: token.PrivilegesLevel(i),
PrivilegesLevel: token.PrivilegesLevel(i) % (token.CommunityLevel + 1),
}
result = append(result, newCommunityInfo)
}

View File

@ -0,0 +1,98 @@
package collectibles
import (
"context"
"database/sql"
"errors"
// used for embedding the sql query in the binary
_ "embed"
"github.com/ethereum/go-ethereum/common"
"github.com/jmoiron/sqlx"
"github.com/status-im/status-go/protocol/communities/token"
wcommon "github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/services/wallet/thirdparty"
)
func allCommunityIDsFilter() []string {
return []string{}
}
func allCommunityPrivilegesLevelsFilter() []token.PrivilegesLevel {
return []token.PrivilegesLevel{}
}
func allFilter() Filter {
return Filter{
CommunityIDs: allCommunityIDsFilter(),
CommunityPrivilegesLevels: allCommunityPrivilegesLevelsFilter(),
FilterCommunity: All,
}
}
type FilterCommunityType int
const (
All FilterCommunityType = iota
OnlyNonCommunity
OnlyCommunity
)
type Filter struct {
CommunityIDs []string `json:"community_ids"`
CommunityPrivilegesLevels []token.PrivilegesLevel `json:"community_privileges_levels"`
FilterCommunity FilterCommunityType `json:"filter_community"`
}
//go:embed filter.sql
var queryString string
func filterOwnedCollectibles(ctx context.Context, db *sql.DB, chainIDs []wcommon.ChainID, addresses []common.Address, filter Filter, offset int, limit int) ([]thirdparty.CollectibleUniqueID, error) {
if len(addresses) == 0 {
return nil, errors.New("no addresses provided")
}
if len(chainIDs) == 0 {
return nil, errors.New("no chainIDs provided")
}
filterCommunityTypeAll := filter.FilterCommunity == All
filterCommunityTypeOnlyNonCommunity := filter.FilterCommunity == OnlyNonCommunity
filterCommunityTypeOnlyCommunity := filter.FilterCommunity == OnlyCommunity
communityIDFilterDisabled := len(filter.CommunityIDs) == 0
if communityIDFilterDisabled {
// IN clause doesn't work with empty array, so we need to provide a dummy value
filter.CommunityIDs = []string{""}
}
communityPrivilegesLevelDisabled := len(filter.CommunityPrivilegesLevels) == 0
if communityPrivilegesLevelDisabled {
// IN clause doesn't work with empty array, so we need to provide a dummy value
filter.CommunityPrivilegesLevels = []token.PrivilegesLevel{token.PrivilegesLevel(0)}
}
query, args, err := sqlx.In(queryString,
filterCommunityTypeAll, filterCommunityTypeOnlyNonCommunity, filterCommunityTypeOnlyCommunity,
communityIDFilterDisabled, communityPrivilegesLevelDisabled,
chainIDs, addresses, filter.CommunityIDs, filter.CommunityPrivilegesLevels,
limit, offset)
if err != nil {
return nil, err
}
stmt, err := db.Prepare(query)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query(args...)
if err != nil {
return nil, err
}
defer rows.Close()
return thirdparty.RowsToCollectibles(rows)
}

View File

@ -0,0 +1,49 @@
WITH filter_conditions AS (
SELECT
? AS filterCommunityTypeAll,
? AS filterCommunityTypeOnlyNonCommunity,
? AS filterCommunityTypeOnlyCommunity,
? AS communityIDFilterDisabled,
? AS communityPrivilegesLevelDisabled
)
SELECT
ownership.chain_id,
ownership.contract_address,
ownership.token_id
FROM
collectibles_ownership_cache ownership
LEFT JOIN collectible_data_cache data
ON (
ownership.chain_id = data.chain_id
AND ownership.contract_address = data.contract_address
AND ownership.token_id = data.token_id
)
CROSS JOIN filter_conditions
WHERE
ownership.chain_id IN (?)
AND ownership.owner_address IN (?)
AND (
filterCommunityTypeAll
OR (
filterCommunityTypeOnlyNonCommunity
AND data.community_id = ""
)
OR (
filterCommunityTypeOnlyCommunity
AND data.community_id <> ""
)
)
AND (
communityIDFilterDisabled
OR (
data.community_id IN (?)
)
)
AND (
communityPrivilegesLevelDisabled
OR (
data.community_privileges_level IN (?)
)
)
LIMIT
? OFFSET ?

View File

@ -0,0 +1,184 @@
package collectibles
import (
"context"
"database/sql"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/status-im/status-go/protocol/communities/token"
w_common "github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/services/wallet/thirdparty"
"github.com/status-im/status-go/t/helpers"
"github.com/status-im/status-go/walletdatabase"
"github.com/stretchr/testify/require"
)
func setupTestFilterDB(t *testing.T) (db *sql.DB, close func()) {
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
return db, func() {
require.NoError(t, db.Close())
}
}
func TestFilterOwnedCollectibles(t *testing.T) {
db, close := setupTestFilterDB(t)
defer close()
oDB := NewOwnershipDB(db)
cDB := NewCollectibleDataDB(db)
const nData = 50
data := generateTestCollectiblesData(nData)
communityData := generateTestCommunityData(nData)
ownerAddresses := []common.Address{
common.HexToAddress("0x1234"),
common.HexToAddress("0x5678"),
common.HexToAddress("0xABCD"),
}
randomAddress := common.HexToAddress("0xFFFF")
dataPerID := make(map[string]thirdparty.CollectibleData)
communityDataPerID := make(map[string]thirdparty.CollectibleCommunityInfo)
idsPerChainIDAndOwner := make(map[w_common.ChainID]map[common.Address][]thirdparty.CollectibleUniqueID)
var err error
for i := 0; i < nData; i++ {
dataPerID[data[i].ID.HashKey()] = data[i]
communityDataPerID[data[i].ID.HashKey()] = communityData[i]
chainID := data[i].ID.ContractID.ChainID
ownerAddress := ownerAddresses[i%len(ownerAddresses)]
if _, ok := idsPerChainIDAndOwner[chainID]; !ok {
idsPerChainIDAndOwner[chainID] = make(map[common.Address][]thirdparty.CollectibleUniqueID)
}
if _, ok := idsPerChainIDAndOwner[chainID][ownerAddress]; !ok {
idsPerChainIDAndOwner[chainID][ownerAddress] = make([]thirdparty.CollectibleUniqueID, 0, len(data))
}
idsPerChainIDAndOwner[chainID][ownerAddress] = append(idsPerChainIDAndOwner[chainID][ownerAddress], data[i].ID)
}
timestamp := int64(1234567890)
for chainID, idsPerOwner := range idsPerChainIDAndOwner {
for ownerAddress, ids := range idsPerOwner {
err = oDB.Update(chainID, ownerAddress, ids, timestamp)
require.NoError(t, err)
}
}
err = cDB.SetData(data)
require.NoError(t, err)
for i := 0; i < nData; i++ {
err = cDB.SetCommunityInfo(data[i].ID, communityData[i])
require.NoError(t, err)
}
var filter Filter
var filterIDs []thirdparty.CollectibleUniqueID
var expectedIDs []thirdparty.CollectibleUniqueID
var tmpIDs []thirdparty.CollectibleUniqueID
ctx := context.Background()
filterChains := []w_common.ChainID{w_common.ChainID(1), w_common.ChainID(2)}
filterAddresses := []common.Address{ownerAddresses[0], ownerAddresses[1], ownerAddresses[2], randomAddress}
// Test common case
filter = allFilter()
tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData)
require.NoError(t, err)
expectedIDs = tmpIDs
filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData)
require.NoError(t, err)
require.Equal(t, expectedIDs, filterIDs)
// Test only non-community
filter = allFilter()
filter.FilterCommunity = OnlyNonCommunity
tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData)
require.NoError(t, err)
expectedIDs = nil
for _, id := range tmpIDs {
if dataPerID[id.HashKey()].CommunityID == "" {
expectedIDs = append(expectedIDs, id)
}
}
filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData)
require.NoError(t, err)
require.Equal(t, expectedIDs, filterIDs)
// Test only community
filter = allFilter()
filter.FilterCommunity = OnlyCommunity
tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData)
require.NoError(t, err)
expectedIDs = nil
for _, id := range tmpIDs {
if dataPerID[id.HashKey()].CommunityID != "" {
expectedIDs = append(expectedIDs, id)
}
}
filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData)
require.NoError(t, err)
require.Equal(t, expectedIDs, filterIDs)
// Test specific community
communityIDa := data[0].CommunityID
communityIDb := data[1].CommunityID
communityIDs := []string{communityIDa, communityIDb}
filter = allFilter()
filter.CommunityIDs = communityIDs
tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData)
require.NoError(t, err)
expectedIDs = nil
for _, id := range tmpIDs {
if dataPerID[id.HashKey()].CommunityID == communityIDa || dataPerID[id.HashKey()].CommunityID == communityIDb {
expectedIDs = append(expectedIDs, id)
}
}
filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData)
require.NoError(t, err)
require.Equal(t, expectedIDs, filterIDs)
// Test specific privileges level
privilegeLevel := token.PrivilegesLevel(2)
filter = allFilter()
filter.CommunityPrivilegesLevels = []token.PrivilegesLevel{privilegeLevel}
tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData)
require.NoError(t, err)
expectedIDs = nil
for _, id := range tmpIDs {
if communityDataPerID[id.HashKey()].PrivilegesLevel == privilegeLevel {
expectedIDs = append(expectedIDs, id)
}
}
filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData)
require.NoError(t, err)
require.Equal(t, expectedIDs, filterIDs)
}

View File

@ -47,6 +47,7 @@ var (
type Service struct {
manager *Manager
controller *Controller
db *sql.DB
ownershipDB *OwnershipDB
communityDB *community.DataDB
walletFeed *event.Feed
@ -64,6 +65,7 @@ func NewService(
return &Service{
manager: manager,
controller: NewController(db, walletFeed, accountsDB, accountsFeed, settingsFeed, networkManager, manager),
db: db,
ownershipDB: NewOwnershipDB(db),
communityDB: community.NewDataDB(db),
walletFeed: walletFeed,
@ -111,9 +113,9 @@ type filterOwnedCollectiblesTaskReturnType struct {
// FilterOwnedCollectiblesResponse allows only one filter task to run at a time
// and it cancels the current one if a new one is started
// All calls will trigger an EventOwnedCollectiblesFilteringDone event with the result of the filtering
func (s *Service) FilterOwnedCollectiblesAsync(requestID int32, chainIDs []walletCommon.ChainID, addresses []common.Address, offset int, limit int) {
func (s *Service) FilterOwnedCollectiblesAsync(requestID int32, chainIDs []walletCommon.ChainID, addresses []common.Address, filter Filter, offset int, limit int) {
s.scheduler.Enqueue(requestID, filterOwnedCollectiblesTask, func(ctx context.Context) (interface{}, error) {
collectibles, hasMore, err := s.GetOwnedCollectibles(chainIDs, addresses, offset, limit)
collectibles, hasMore, err := s.FilterOwnedCollectibles(chainIDs, addresses, filter, offset, limit)
if err != nil {
return nil, err
}
@ -212,9 +214,10 @@ func (s *Service) sendResponseEvent(requestID *int32, eventType walletevent.Even
s.walletFeed.Send(event)
}
func (s *Service) GetOwnedCollectibles(chainIDs []walletCommon.ChainID, owners []common.Address, offset int, limit int) ([]thirdparty.CollectibleUniqueID, bool, error) {
func (s *Service) FilterOwnedCollectibles(chainIDs []walletCommon.ChainID, owners []common.Address, filter Filter, offset int, limit int) ([]thirdparty.CollectibleUniqueID, bool, error) {
ctx := context.Background()
// Request one more than limit, to check if DB has more available
ids, err := s.ownershipDB.GetOwnedCollectibles(chainIDs, owners, offset, limit+1)
ids, err := filterOwnedCollectibles(ctx, s.db, chainIDs, owners, filter, offset, limit+1)
if err != nil {
return nil, false, err
}