From e17d4606b145787803d2d6f9604c6f045942d937 Mon Sep 17 00:00:00 2001 From: Dario Gabriel Lipicar Date: Tue, 14 Nov 2023 14:16:39 -0300 Subject: [PATCH] fix: implement cancellable collectibles requests --- protocol/communities/manager.go | 6 +- protocol/communities/manager_test.go | 2 +- protocol/communities/permission_checker.go | 4 +- services/wallet/activity/service.go | 38 ++++++---- services/wallet/activity/service_test.go | 3 +- services/wallet/api.go | 14 ++-- services/wallet/collectibles/commands.go | 14 +--- services/wallet/collectibles/controller.go | 4 +- services/wallet/collectibles/manager.go | 75 +++++++++---------- services/wallet/collectibles/service.go | 4 +- services/wallet/common/utils.go | 13 ++++ services/wallet/thirdparty/alchemy/client.go | 49 ++++++------ .../wallet/thirdparty/collectible_types.go | 11 +-- .../wallet/thirdparty/opensea/client_v2.go | 43 ++++++----- .../wallet/thirdparty/opensea/http_client.go | 5 +- 15 files changed, 157 insertions(+), 128 deletions(-) create mode 100644 services/wallet/common/utils.go diff --git a/protocol/communities/manager.go b/protocol/communities/manager.go index 2514349bc..79c7c38d3 100644 --- a/protocol/communities/manager.go +++ b/protocol/communities/manager.go @@ -156,7 +156,7 @@ func (m *DefaultTokenManager) GetAllChainIDs() ([]uint64, error) { } type CollectiblesManager interface { - FetchBalancesByOwnerAndContractAddress(chainID walletcommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) + FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletcommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) } func (m *DefaultTokenManager) GetBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) { @@ -2639,6 +2639,8 @@ func (m *Manager) GetOwnedERC721Tokens(walletAddresses []gethcommon.Address, tok return nil, errors.New("no collectibles manager") } + ctx := context.Background() + ownedERC721Tokens := make(CollectiblesByChain) for chainID, erc721Tokens := range tokenRequirements { @@ -2664,7 +2666,7 @@ func (m *Manager) GetOwnedERC721Tokens(walletAddresses []gethcommon.Address, tok } for _, owner := range walletAddresses { - balances, err := m.collectiblesManager.FetchBalancesByOwnerAndContractAddress(walletcommon.ChainID(chainID), owner, contractAddresses) + balances, err := m.collectiblesManager.FetchBalancesByOwnerAndContractAddress(ctx, walletcommon.ChainID(chainID), owner, contractAddresses) if err != nil { m.logger.Info("couldn't fetch owner assets", zap.Error(err)) return nil, err diff --git a/protocol/communities/manager_test.go b/protocol/communities/manager_test.go index a1af7efd3..0c59879b7 100644 --- a/protocol/communities/manager_test.go +++ b/protocol/communities/manager_test.go @@ -110,7 +110,7 @@ func (m *testCollectiblesManager) setResponse(chainID uint64, walletAddress geth m.response[chainID][walletAddress][contractAddress] = balances } -func (m *testCollectiblesManager) FetchBalancesByOwnerAndContractAddress(chainID walletCommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) { +func (m *testCollectiblesManager) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error) { return m.response[uint64(chainID)][ownerAddress], nil } diff --git a/protocol/communities/permission_checker.go b/protocol/communities/permission_checker.go index 1bd2b6ad9..cd7013d0d 100644 --- a/protocol/communities/permission_checker.go +++ b/protocol/communities/permission_checker.go @@ -53,6 +53,8 @@ func (p *DefaultPermissionChecker) GetOwnedERC721Tokens(walletAddresses []gethco return nil, errors.New("no collectibles manager") } + ctx := context.Background() + ownedERC721Tokens := make(CollectiblesByChain) for chainID, erc721Tokens := range tokenRequirements { @@ -78,7 +80,7 @@ func (p *DefaultPermissionChecker) GetOwnedERC721Tokens(walletAddresses []gethco } for _, owner := range walletAddresses { - balances, err := p.collectiblesManager.FetchBalancesByOwnerAndContractAddress(walletcommon.ChainID(chainID), owner, contractAddresses) + balances, err := p.collectiblesManager.FetchBalancesByOwnerAndContractAddress(ctx, walletcommon.ChainID(chainID), owner, contractAddresses) if err != nil { p.logger.Info("couldn't fetch owner assets", zap.Error(err)) return nil, err diff --git a/services/wallet/activity/service.go b/services/wallet/activity/service.go index cd265ee3c..9dc4f8e0d 100644 --- a/services/wallet/activity/service.go +++ b/services/wallet/activity/service.go @@ -111,13 +111,25 @@ func (s *Service) FilterActivityAsync(requestID int32, addresses []common.Addres s.sendResponseEvent(&requestID, EventActivityFilteringDone, res, err) - // Report details post-response to ensure updates have a match - if res.Activities != nil { - go s.lazyLoadDetails(requestID, res.Activities) - } + s.getActivityDetailsAsync(requestID, res.Activities) }) } +func (s *Service) getActivityDetailsAsync(requestID int32, entries []Entry) { + if len(entries) == 0 { + return + } + + ctx := context.Background() + + go func() { + activityData, err := s.getActivityDetails(ctx, entries) + if len(activityData) != 0 { + s.sendResponseEvent(&requestID, EventActivityFilteringUpdate, activityData, err) + } + }() +} + type CollectibleHeader struct { ID thirdparty.CollectibleUniqueID `json:"id"` Name string `json:"name"` @@ -141,7 +153,7 @@ func (s *Service) GetActivityCollectiblesAsync(requestID int32, chainIDs []w_com return nil, err } - data, err := s.collectibles.FetchAssetsByCollectibleUniqueID(collectibles) + data, err := s.collectibles.FetchAssetsByCollectibleUniqueID(ctx, collectibles) if err != nil { return nil, err } @@ -184,8 +196,8 @@ func (s *Service) GetTxDetails(ctx context.Context, id string) (*EntryDetails, e return getTxDetails(ctx, s.db, id) } -// lazyLoadDetails check if any of the entries have details that are not loaded then fetch and emit result -func (s *Service) lazyLoadDetails(requestID int32, entries []Entry) { +// getActivityDetails check if any of the entries have details that are not loaded then fetch and emit result +func (s *Service) getActivityDetails(ctx context.Context, entries []Entry) ([]*EntryData, error) { res := make([]*EntryData, 0) var err error ids := make([]thirdparty.CollectibleUniqueID, 0) @@ -205,15 +217,15 @@ func (s *Service) lazyLoadDetails(requestID int32, entries []Entry) { } if len(ids) == 0 { - return + return nil, nil } - log.Debug("wallet.activity.Service lazyLoadDetails", "requestID", requestID, "entries.len", len(entries), "ids.len", len(ids)) + log.Debug("wallet.activity.Service lazyLoadDetails", "entries.len", len(entries), "ids.len", len(ids)) - colData, err := s.collectibles.FetchAssetsByCollectibleUniqueID(ids) + colData, err := s.collectibles.FetchAssetsByCollectibleUniqueID(ctx, ids) if err != nil { log.Error("Error fetching collectible details", "error", err) - return + return nil, err } for _, col := range colData { @@ -236,9 +248,7 @@ func (s *Service) lazyLoadDetails(requestID int32, entries []Entry) { res = append(res, data) } - if len(res) > 0 { - s.sendResponseEvent(&requestID, EventActivityFilteringUpdate, res, err) - } + return res, nil } type GetRecipientsResponse struct { diff --git a/services/wallet/activity/service_test.go b/services/wallet/activity/service_test.go index 1dde48c8a..fa3a559e3 100644 --- a/services/wallet/activity/service_test.go +++ b/services/wallet/activity/service_test.go @@ -1,6 +1,7 @@ package activity import ( + "context" "database/sql" "encoding/json" "math/big" @@ -28,7 +29,7 @@ type mockCollectiblesManager struct { mock.Mock } -func (m *mockCollectiblesManager) FetchAssetsByCollectibleUniqueID(uniqueIDs []thirdparty.CollectibleUniqueID) ([]thirdparty.FullCollectibleData, error) { +func (m *mockCollectiblesManager) FetchAssetsByCollectibleUniqueID(ctx context.Context, uniqueIDs []thirdparty.CollectibleUniqueID) ([]thirdparty.FullCollectibleData, error) { args := m.Called(uniqueIDs) res := args.Get(0) if res == nil { diff --git a/services/wallet/api.go b/services/wallet/api.go index ed69ed3a7..716111f4e 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -307,10 +307,10 @@ func (api *API) GetCryptoOnRamps(ctx context.Context) ([]CryptoOnRamp, error) { Collectibles API Start */ -func (api *API) FetchBalancesByOwnerAndContractAddress(chainID wcommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) { +func (api *API) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID wcommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) { log.Debug("call to FetchBalancesByOwnerAndContractAddress") - return api.s.collectiblesManager.FetchBalancesByOwnerAndContractAddress(chainID, ownerAddress, contractAddresses) + return api.s.collectiblesManager.FetchBalancesByOwnerAndContractAddress(ctx, chainID, ownerAddress, contractAddresses) } func (api *API) RefetchOwnedCollectibles() error { @@ -337,24 +337,24 @@ func (api *API) GetCollectiblesDetailsAsync(requestID int32, uniqueIDs []thirdpa // @deprecated func (api *API) GetCollectiblesByOwnerWithCursor(ctx context.Context, chainID wcommon.ChainID, owner common.Address, cursor string, limit int) (*thirdparty.FullCollectibleDataContainer, error) { log.Debug("call to GetCollectiblesByOwnerWithCursor") - return api.s.collectiblesManager.FetchAllAssetsByOwner(chainID, owner, cursor, limit, thirdparty.FetchFromAnyProvider) + return api.s.collectiblesManager.FetchAllAssetsByOwner(ctx, chainID, owner, cursor, limit, thirdparty.FetchFromAnyProvider) } // @deprecated func (api *API) GetCollectiblesByOwnerAndContractAddressWithCursor(ctx context.Context, chainID wcommon.ChainID, owner common.Address, contractAddresses []common.Address, cursor string, limit int) (*thirdparty.FullCollectibleDataContainer, error) { log.Debug("call to GetCollectiblesByOwnerAndContractAddressWithCursor") - return api.s.collectiblesManager.FetchAllAssetsByOwnerAndContractAddress(chainID, owner, contractAddresses, cursor, limit, thirdparty.FetchFromAnyProvider) + return api.s.collectiblesManager.FetchAllAssetsByOwnerAndContractAddress(ctx, chainID, owner, contractAddresses, cursor, limit, thirdparty.FetchFromAnyProvider) } // @deprecated func (api *API) GetCollectiblesByUniqueID(ctx context.Context, uniqueIDs []thirdparty.CollectibleUniqueID) ([]thirdparty.FullCollectibleData, error) { log.Debug("call to GetCollectiblesByUniqueID") - return api.s.collectiblesManager.FetchAssetsByCollectibleUniqueID(uniqueIDs) + return api.s.collectiblesManager.FetchAssetsByCollectibleUniqueID(ctx, uniqueIDs) } -func (api *API) GetCollectibleOwnersByContractAddress(chainID wcommon.ChainID, contractAddress common.Address) (*thirdparty.CollectibleContractOwnership, error) { +func (api *API) GetCollectibleOwnersByContractAddress(ctx context.Context, chainID wcommon.ChainID, contractAddress common.Address) (*thirdparty.CollectibleContractOwnership, error) { log.Debug("call to GetCollectibleOwnersByContractAddress") - return api.s.collectiblesManager.FetchCollectibleOwnersByContractAddress(chainID, contractAddress) + return api.s.collectiblesManager.FetchCollectibleOwnersByContractAddress(ctx, chainID, contractAddress) } /* diff --git a/services/wallet/collectibles/commands.go b/services/wallet/collectibles/commands.go index 4829dea80..24b0ceb5f 100644 --- a/services/wallet/collectibles/commands.go +++ b/services/wallet/collectibles/commands.go @@ -198,7 +198,7 @@ func (c *loadOwnedCollectiblesCommand) Run(parent context.Context) (err error) { initialFetch := lastFetchTimestamp == InvalidTimestamp // Fetch collectibles in chunks for { - if shouldCancel(parent) { + if walletCommon.ShouldCancel(parent) { c.err = errors.New("context cancelled") break } @@ -206,7 +206,7 @@ func (c *loadOwnedCollectiblesCommand) Run(parent context.Context) (err error) { pageStart := time.Now() log.Debug("start loadOwnedCollectiblesCommand", "chain", c.chainID, "account", c.account, "page", pageNr) - partialOwnership, err := c.manager.FetchCollectibleOwnershipByOwner(c.chainID, c.account, cursor, fetchLimit, providerID) + partialOwnership, err := c.manager.FetchCollectibleOwnershipByOwner(parent, c.chainID, c.account, cursor, fetchLimit, providerID) if err != nil { log.Error("failed loadOwnedCollectiblesCommand", "chain", c.chainID, "account", c.account, "page", pageNr, "error", err) @@ -263,13 +263,3 @@ func (c *loadOwnedCollectiblesCommand) Run(parent context.Context) (err error) { log.Debug("end loadOwnedCollectiblesCommand", "chain", c.chainID, "account", c.account, "in", time.Since(start)) return nil } - -// shouldCancel returns true if the context has been cancelled and task should be aborted -func shouldCancel(ctx context.Context) bool { - select { - case <-ctx.Done(): - return true - default: - } - return false -} diff --git a/services/wallet/collectibles/controller.go b/services/wallet/collectibles/controller.go index 0b0288307..e80212ebb 100644 --- a/services/wallet/collectibles/controller.go +++ b/services/wallet/collectibles/controller.go @@ -400,7 +400,9 @@ func (c *Controller) stopSettingsWatcher() { } func (c *Controller) notifyCommunityCollectiblesReceived(ownedCollectibles OwnedCollectibles) { - collectiblesData, err := c.manager.FetchAssetsByCollectibleUniqueID(ownedCollectibles.ids) + ctx := context.Background() + + collectiblesData, err := c.manager.FetchAssetsByCollectibleUniqueID(ctx, ownedCollectibles.ids) if err != nil { log.Error("Error fetching collectibles data", "error", err) return diff --git a/services/wallet/collectibles/manager.go b/services/wallet/collectibles/manager.go index a60a1986a..e6d3a722d 100644 --- a/services/wallet/collectibles/manager.go +++ b/services/wallet/collectibles/manager.go @@ -44,7 +44,7 @@ var ( ) type ManagerInterface interface { - FetchAssetsByCollectibleUniqueID(uniqueIDs []thirdparty.CollectibleUniqueID) ([]thirdparty.FullCollectibleData, error) + FetchAssetsByCollectibleUniqueID(ctx context.Context, uniqueIDs []thirdparty.CollectibleUniqueID) ([]thirdparty.FullCollectibleData, error) } type Manager struct { @@ -177,8 +177,8 @@ func makeContractOwnershipCall(main func() (any, error), fallback func() (any, e } } -func (o *Manager) doContentTypeRequest(url string) (string, error) { - req, err := http.NewRequest(http.MethodHead, url, nil) +func (o *Manager) doContentTypeRequest(ctx context.Context, url string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err != nil { return "", err } @@ -202,7 +202,7 @@ func (o *Manager) SetCommunityInfoProvider(communityInfoProvider thirdparty.Coll } // Need to combine different providers to support all needed ChainIDs -func (o *Manager) FetchBalancesByOwnerAndContractAddress(chainID walletCommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) { +func (o *Manager) FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, ownerAddress common.Address, contractAddresses []common.Address) (thirdparty.TokenBalancesPerContractAddress, error) { ret := make(thirdparty.TokenBalancesPerContractAddress) for _, contractAddress := range contractAddresses { @@ -210,11 +210,11 @@ func (o *Manager) FetchBalancesByOwnerAndContractAddress(chainID walletCommon.Ch } // Try with account ownership providers first - assetsContainer, err := o.FetchAllAssetsByOwnerAndContractAddress(chainID, ownerAddress, contractAddresses, thirdparty.FetchFromStartCursor, thirdparty.FetchNoLimit, thirdparty.FetchFromAnyProvider) + assetsContainer, err := o.FetchAllAssetsByOwnerAndContractAddress(ctx, chainID, ownerAddress, contractAddresses, thirdparty.FetchFromStartCursor, thirdparty.FetchNoLimit, thirdparty.FetchFromAnyProvider) if err == ErrNoProvidersAvailableForChainID { // Use contract ownership providers for _, contractAddress := range contractAddresses { - ownership, err := o.FetchCollectibleOwnersByContractAddress(chainID, contractAddress) + ownership, err := o.FetchCollectibleOwnersByContractAddress(ctx, chainID, contractAddress) if err != nil { return nil, err } @@ -243,7 +243,7 @@ func (o *Manager) FetchBalancesByOwnerAndContractAddress(chainID walletCommon.Ch return ret, nil } -func (o *Manager) FetchAllAssetsByOwnerAndContractAddress(chainID walletCommon.ChainID, owner common.Address, contractAddresses []common.Address, cursor string, limit int, providerID string) (*thirdparty.FullCollectibleDataContainer, error) { +func (o *Manager) FetchAllAssetsByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, owner common.Address, contractAddresses []common.Address, cursor string, limit int, providerID string) (*thirdparty.FullCollectibleDataContainer, error) { defer o.checkConnectionStatus(chainID) anyProviderAvailable := false @@ -256,13 +256,13 @@ func (o *Manager) FetchAllAssetsByOwnerAndContractAddress(chainID walletCommon.C continue } - assetContainer, err := provider.FetchAllAssetsByOwnerAndContractAddress(chainID, owner, contractAddresses, cursor, limit) + assetContainer, err := provider.FetchAllAssetsByOwnerAndContractAddress(ctx, chainID, owner, contractAddresses, cursor, limit) if err != nil { log.Error("FetchAllAssetsByOwnerAndContractAddress failed for", "provider", provider.ID(), "chainID", chainID, "err", err) continue } - err = o.processFullCollectibleData(assetContainer.Items) + err = o.processFullCollectibleData(ctx, assetContainer.Items) if err != nil { return nil, err } @@ -276,7 +276,7 @@ func (o *Manager) FetchAllAssetsByOwnerAndContractAddress(chainID walletCommon.C return nil, ErrNoProvidersAvailableForChainID } -func (o *Manager) FetchAllAssetsByOwner(chainID walletCommon.ChainID, owner common.Address, cursor string, limit int, providerID string) (*thirdparty.FullCollectibleDataContainer, error) { +func (o *Manager) FetchAllAssetsByOwner(ctx context.Context, chainID walletCommon.ChainID, owner common.Address, cursor string, limit int, providerID string) (*thirdparty.FullCollectibleDataContainer, error) { defer o.checkConnectionStatus(chainID) anyProviderAvailable := false @@ -289,13 +289,13 @@ func (o *Manager) FetchAllAssetsByOwner(chainID walletCommon.ChainID, owner comm continue } - assetContainer, err := provider.FetchAllAssetsByOwner(chainID, owner, cursor, limit) + assetContainer, err := provider.FetchAllAssetsByOwner(ctx, chainID, owner, cursor, limit) if err != nil { log.Error("FetchAllAssetsByOwner failed for", "provider", provider.ID(), "chainID", chainID, "err", err) continue } - err = o.processFullCollectibleData(assetContainer.Items) + err = o.processFullCollectibleData(ctx, assetContainer.Items) if err != nil { return nil, err } @@ -309,10 +309,10 @@ func (o *Manager) FetchAllAssetsByOwner(chainID walletCommon.ChainID, owner comm return nil, ErrNoProvidersAvailableForChainID } -func (o *Manager) FetchCollectibleOwnershipByOwner(chainID walletCommon.ChainID, owner common.Address, cursor string, limit int, providerID string) (*thirdparty.CollectibleOwnershipContainer, error) { +func (o *Manager) FetchCollectibleOwnershipByOwner(ctx context.Context, chainID walletCommon.ChainID, owner common.Address, cursor string, limit int, providerID string) (*thirdparty.CollectibleOwnershipContainer, error) { // We don't yet have an API that will return only Ownership data // Use the full Ownership + Metadata endpoint and use the data we need - assetContainer, err := o.FetchAllAssetsByOwner(chainID, owner, cursor, limit, providerID) + assetContainer, err := o.FetchAllAssetsByOwner(ctx, chainID, owner, cursor, limit, providerID) if err != nil { return nil, err } @@ -321,7 +321,7 @@ func (o *Manager) FetchCollectibleOwnershipByOwner(chainID walletCommon.ChainID, return &ret, nil } -func (o *Manager) FetchAssetsByCollectibleUniqueID(uniqueIDs []thirdparty.CollectibleUniqueID) ([]thirdparty.FullCollectibleData, error) { +func (o *Manager) FetchAssetsByCollectibleUniqueID(ctx context.Context, uniqueIDs []thirdparty.CollectibleUniqueID) ([]thirdparty.FullCollectibleData, error) { missingIDs, err := o.collectiblesDataDB.GetIDsNotInDB(uniqueIDs) if err != nil { return nil, err @@ -337,13 +337,13 @@ func (o *Manager) FetchAssetsByCollectibleUniqueID(uniqueIDs []thirdparty.Collec continue } - fetchedAssets, err := provider.FetchAssetsByCollectibleUniqueID(idsToFetch) + fetchedAssets, err := provider.FetchAssetsByCollectibleUniqueID(ctx, idsToFetch) if err != nil { log.Error("FetchAssetsByCollectibleUniqueID failed for", "provider", provider.ID(), "chainID", chainID, "err", err) continue } - err = o.processFullCollectibleData(fetchedAssets) + err = o.processFullCollectibleData(ctx, fetchedAssets) if err != nil { return nil, err } @@ -355,7 +355,7 @@ func (o *Manager) FetchAssetsByCollectibleUniqueID(uniqueIDs []thirdparty.Collec return o.getCacheFullCollectibleData(uniqueIDs) } -func (o *Manager) FetchCollectionsDataByContractID(ids []thirdparty.ContractID) ([]thirdparty.CollectionData, error) { +func (o *Manager) FetchCollectionsDataByContractID(ctx context.Context, ids []thirdparty.ContractID) ([]thirdparty.CollectionData, error) { missingIDs, err := o.collectionsDataDB.GetIDsNotInDB(ids) if err != nil { return nil, err @@ -371,13 +371,13 @@ func (o *Manager) FetchCollectionsDataByContractID(ids []thirdparty.ContractID) continue } - fetchedCollections, err := provider.FetchCollectionsDataByContractID(idsToFetch) + fetchedCollections, err := provider.FetchCollectionsDataByContractID(ctx, idsToFetch) if err != nil { log.Error("FetchCollectionsDataByContractID failed for", "provider", provider.ID(), "chainID", chainID, "err", err) continue } - err = o.processCollectionData(fetchedCollections) + err = o.processCollectionData(ctx, fetchedCollections) if err != nil { return nil, err } @@ -413,12 +413,12 @@ func (o *Manager) getContractOwnershipProviders(chainID walletCommon.ChainID) (m return } -func getCollectibleOwnersByContractAddressFunc(chainID walletCommon.ChainID, contractAddress common.Address, provider thirdparty.CollectibleContractOwnershipProvider) func() (any, error) { +func getCollectibleOwnersByContractAddressFunc(ctx context.Context, chainID walletCommon.ChainID, contractAddress common.Address, provider thirdparty.CollectibleContractOwnershipProvider) func() (any, error) { if provider == nil { return nil } return func() (any, error) { - res, err := provider.FetchCollectibleOwnersByContractAddress(chainID, contractAddress) + res, err := provider.FetchCollectibleOwnersByContractAddress(ctx, chainID, contractAddress) if err != nil { log.Error("FetchCollectibleOwnersByContractAddress failed for", "provider", provider.ID(), "chainID", chainID, "err", err) } @@ -426,7 +426,7 @@ func getCollectibleOwnersByContractAddressFunc(chainID walletCommon.ChainID, con } } -func (o *Manager) FetchCollectibleOwnersByContractAddress(chainID walletCommon.ChainID, contractAddress common.Address) (*thirdparty.CollectibleContractOwnership, error) { +func (o *Manager) FetchCollectibleOwnersByContractAddress(ctx context.Context, chainID walletCommon.ChainID, contractAddress common.Address) (*thirdparty.CollectibleContractOwnership, error) { defer o.checkConnectionStatus(chainID) mainProvider, fallbackProvider := o.getContractOwnershipProviders(chainID) @@ -434,8 +434,8 @@ func (o *Manager) FetchCollectibleOwnersByContractAddress(chainID walletCommon.C return nil, ErrNoProvidersAvailableForChainID } - mainFn := getCollectibleOwnersByContractAddressFunc(chainID, contractAddress, mainProvider) - fallbackFn := getCollectibleOwnersByContractAddressFunc(chainID, contractAddress, fallbackProvider) + mainFn := getCollectibleOwnersByContractAddressFunc(ctx, chainID, contractAddress, mainProvider) + fallbackFn := getCollectibleOwnersByContractAddressFunc(ctx, chainID, contractAddress, fallbackProvider) owners, err := makeContractOwnershipCall(mainFn, fallbackFn) if err != nil { @@ -451,7 +451,7 @@ func isMetadataEmpty(asset thirdparty.CollectibleData) bool { asset.ImageURL == "" } -func (o *Manager) fetchTokenURI(id thirdparty.CollectibleUniqueID) (string, error) { +func (o *Manager) fetchTokenURI(ctx context.Context, id thirdparty.CollectibleUniqueID) (string, error) { if id.TokenID == nil { return "", errors.New("empty token ID") } @@ -465,11 +465,8 @@ func (o *Manager) fetchTokenURI(id thirdparty.CollectibleUniqueID) (string, erro return "", err } - timeoutContext, timeoutCancel := context.WithTimeout(context.Background(), requestTimeout) - defer timeoutCancel() - tokenURI, err := caller.TokenURI(&bind.CallOpts{ - Context: timeoutContext, + Context: ctx, }, id.TokenID.Int) if err != nil { @@ -485,7 +482,7 @@ func (o *Manager) fetchTokenURI(id thirdparty.CollectibleUniqueID) (string, erro return tokenURI, err } -func (o *Manager) processFullCollectibleData(assets []thirdparty.FullCollectibleData) error { +func (o *Manager) processFullCollectibleData(ctx context.Context, assets []thirdparty.FullCollectibleData) error { fullyFetchedAssets := make(map[string]*thirdparty.FullCollectibleData) communityCollectibles := make(map[string][]*thirdparty.FullCollectibleData) @@ -499,7 +496,7 @@ func (o *Manager) processFullCollectibleData(assets []thirdparty.FullCollectible for _, asset := range fullyFetchedAssets { // Only check community ownership if metadata is empty if isMetadataEmpty(asset.CollectibleData) { - err := o.fillTokenURI(asset) + err := o.fillTokenURI(ctx, asset) if err != nil { log.Error("fillTokenURI failed", "err", err) delete(fullyFetchedAssets, asset.CollectibleData.ID.HashKey()) @@ -536,7 +533,7 @@ func (o *Manager) processFullCollectibleData(assets []thirdparty.FullCollectible } for _, asset := range fullyFetchedAssets { - err := o.fillAnimationMediatype(asset) + err := o.fillAnimationMediatype(ctx, asset) if err != nil { log.Error("fillAnimationMediatype failed", "err", err) delete(fullyFetchedAssets, asset.CollectibleData.ID.HashKey()) @@ -581,7 +578,7 @@ func (o *Manager) processFullCollectibleData(assets []thirdparty.FullCollectible if len(missingCollectionIDs) > 0 { // Calling this ensures collection data is fetched and cached (if not already available) - _, err := o.FetchCollectionsDataByContractID(missingCollectionIDs) + _, err := o.FetchCollectionsDataByContractID(ctx, missingCollectionIDs) if err != nil { return err } @@ -590,13 +587,13 @@ func (o *Manager) processFullCollectibleData(assets []thirdparty.FullCollectible return nil } -func (o *Manager) fillTokenURI(asset *thirdparty.FullCollectibleData) error { +func (o *Manager) fillTokenURI(ctx context.Context, asset *thirdparty.FullCollectibleData) error { id := asset.CollectibleData.ID tokenURI := asset.CollectibleData.TokenURI // Only need to fetch it from contract if it was empty if tokenURI == "" { - tokenURI, err := o.fetchTokenURI(id) + tokenURI, err := o.fetchTokenURI(ctx, id) if err != nil { return err @@ -642,9 +639,9 @@ func (o *Manager) fillCommunityInfo(communityID string, communityAssets []*third return nil } -func (o *Manager) fillAnimationMediatype(asset *thirdparty.FullCollectibleData) error { +func (o *Manager) fillAnimationMediatype(ctx context.Context, asset *thirdparty.FullCollectibleData) error { if len(asset.CollectibleData.AnimationURL) > 0 { - contentType, err := o.doContentTypeRequest(asset.CollectibleData.AnimationURL) + contentType, err := o.doContentTypeRequest(ctx, asset.CollectibleData.AnimationURL) if err != nil { asset.CollectibleData.AnimationURL = "" } @@ -653,7 +650,7 @@ func (o *Manager) fillAnimationMediatype(asset *thirdparty.FullCollectibleData) return nil } -func (o *Manager) processCollectionData(collections []thirdparty.CollectionData) error { +func (o *Manager) processCollectionData(ctx context.Context, collections []thirdparty.CollectionData) error { return o.collectionsDataDB.SetData(collections) } diff --git a/services/wallet/collectibles/service.go b/services/wallet/collectibles/service.go index be247e007..60d2fa51b 100644 --- a/services/wallet/collectibles/service.go +++ b/services/wallet/collectibles/service.go @@ -119,7 +119,7 @@ func (s *Service) FilterOwnedCollectiblesAsync(requestID int32, chainIDs []walle if err != nil { return nil, err } - data, err := s.manager.FetchAssetsByCollectibleUniqueID(collectibles) + data, err := s.manager.FetchAssetsByCollectibleUniqueID(ctx, collectibles) if err != nil { return nil, err } @@ -156,7 +156,7 @@ func (s *Service) FilterOwnedCollectiblesAsync(requestID int32, chainIDs []walle func (s *Service) GetCollectiblesDetailsAsync(requestID int32, uniqueIDs []thirdparty.CollectibleUniqueID) { s.scheduler.Enqueue(requestID, getCollectiblesDataTask, func(ctx context.Context) (interface{}, error) { - collectibles, err := s.manager.FetchAssetsByCollectibleUniqueID(uniqueIDs) + collectibles, err := s.manager.FetchAssetsByCollectibleUniqueID(ctx, uniqueIDs) if err != nil { return nil, err } diff --git a/services/wallet/common/utils.go b/services/wallet/common/utils.go new file mode 100644 index 000000000..34d8c633a --- /dev/null +++ b/services/wallet/common/utils.go @@ -0,0 +1,13 @@ +package common + +import "context" + +// ShouldCancel returns true if the context has been cancelled and task should be aborted +func ShouldCancel(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + } + return false +} diff --git a/services/wallet/thirdparty/alchemy/client.go b/services/wallet/thirdparty/alchemy/client.go index 0534c4897..e2bb422ef 100644 --- a/services/wallet/thirdparty/alchemy/client.go +++ b/services/wallet/thirdparty/alchemy/client.go @@ -1,6 +1,7 @@ package alchemy import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -94,8 +95,8 @@ func NewClient(apiKeys map[uint64]string) *Client { } } -func (o *Client) doQuery(url string) (*http.Response, error) { - req, err := http.NewRequest(http.MethodGet, url, nil) +func (o *Client) doQuery(ctx context.Context, url string) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } @@ -103,7 +104,7 @@ func (o *Client) doQuery(url string) (*http.Response, error) { return o.doWithRetries(req) } -func (o *Client) doPostWithJSON(url string, payload any) (*http.Response, error) { +func (o *Client) doPostWithJSON(ctx context.Context, url string, payload any) (*http.Response, error) { payloadJSON, err := json.Marshal(payload) if err != nil { return nil, err @@ -112,7 +113,7 @@ func (o *Client) doPostWithJSON(url string, payload any) (*http.Response, error) payloadString := string(payloadJSON) payloadReader := strings.NewReader(payloadString) - req, err := http.NewRequest("POST", url, payloadReader) + req, err := http.NewRequestWithContext(ctx, "POST", url, payloadReader) if err != nil { return nil, err } @@ -154,7 +155,7 @@ func (o *Client) doWithRetries(req *http.Request) (*http.Response, error) { return backoff.RetryWithData(op, &b) } -func (o *Client) FetchCollectibleOwnersByContractAddress(chainID walletCommon.ChainID, contractAddress common.Address) (*thirdparty.CollectibleContractOwnership, error) { +func (o *Client) FetchCollectibleOwnersByContractAddress(ctx context.Context, chainID walletCommon.ChainID, contractAddress common.Address) (*thirdparty.CollectibleContractOwnership, error) { ownership := thirdparty.CollectibleContractOwnership{ ContractAddress: contractAddress, Owners: make([]thirdparty.CollectibleOwner, 0), @@ -174,9 +175,11 @@ func (o *Client) FetchCollectibleOwnersByContractAddress(chainID walletCommon.Ch for { url := fmt.Sprintf("%s/getOwnersForContract?%s", baseURL, queryParams.Encode()) - resp, err := o.doQuery(url) + resp, err := o.doQuery(ctx, url) if err != nil { - o.connectionStatus.SetIsConnected(false) + if ctx.Err() == nil { + o.connectionStatus.SetIsConnected(false) + } return nil, err } o.connectionStatus.SetIsConnected(true) @@ -206,23 +209,23 @@ func (o *Client) FetchCollectibleOwnersByContractAddress(chainID walletCommon.Ch return &ownership, nil } -func (o *Client) FetchAllAssetsByOwner(chainID walletCommon.ChainID, owner common.Address, cursor string, limit int) (*thirdparty.FullCollectibleDataContainer, error) { +func (o *Client) FetchAllAssetsByOwner(ctx context.Context, chainID walletCommon.ChainID, owner common.Address, cursor string, limit int) (*thirdparty.FullCollectibleDataContainer, error) { queryParams := url.Values{} - return o.fetchOwnedAssets(chainID, owner, queryParams, cursor, limit) + return o.fetchOwnedAssets(ctx, chainID, owner, queryParams, cursor, limit) } -func (o *Client) FetchAllAssetsByOwnerAndContractAddress(chainID walletCommon.ChainID, owner common.Address, contractAddresses []common.Address, cursor string, limit int) (*thirdparty.FullCollectibleDataContainer, error) { +func (o *Client) FetchAllAssetsByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, owner common.Address, contractAddresses []common.Address, cursor string, limit int) (*thirdparty.FullCollectibleDataContainer, error) { queryParams := url.Values{} for _, contractAddress := range contractAddresses { queryParams.Add("contractAddresses", contractAddress.String()) } - return o.fetchOwnedAssets(chainID, owner, queryParams, cursor, limit) + return o.fetchOwnedAssets(ctx, chainID, owner, queryParams, cursor, limit) } -func (o *Client) fetchOwnedAssets(chainID walletCommon.ChainID, owner common.Address, queryParams url.Values, cursor string, limit int) (*thirdparty.FullCollectibleDataContainer, error) { +func (o *Client) fetchOwnedAssets(ctx context.Context, chainID walletCommon.ChainID, owner common.Address, queryParams url.Values, cursor string, limit int) (*thirdparty.FullCollectibleDataContainer, error) { assets := new(thirdparty.FullCollectibleDataContainer) queryParams["owner"] = []string{owner.String()} @@ -243,9 +246,11 @@ func (o *Client) fetchOwnedAssets(chainID walletCommon.ChainID, owner common.Add for { url := fmt.Sprintf("%s/getNFTsForOwner?%s", baseURL, queryParams.Encode()) - resp, err := o.doQuery(url) + resp, err := o.doQuery(ctx, url) if err != nil { - o.connectionStatus.SetIsConnected(false) + if ctx.Err() == nil { + o.connectionStatus.SetIsConnected(false) + } return nil, err } o.connectionStatus.SetIsConnected(true) @@ -313,7 +318,7 @@ func getCollectibleUniqueIDBatches(ids []thirdparty.CollectibleUniqueID) []Batch return batches } -func (o *Client) fetchAssetsByBatchTokenIDs(chainID walletCommon.ChainID, batchIDs BatchTokenIDs) ([]thirdparty.FullCollectibleData, error) { +func (o *Client) fetchAssetsByBatchTokenIDs(ctx context.Context, chainID walletCommon.ChainID, batchIDs BatchTokenIDs) ([]thirdparty.FullCollectibleData, error) { baseURL, err := getNFTBaseURL(chainID, o.apiKeys[uint64(chainID)]) if err != nil { return nil, err @@ -321,7 +326,7 @@ func (o *Client) fetchAssetsByBatchTokenIDs(chainID walletCommon.ChainID, batchI url := fmt.Sprintf("%s/getNFTMetadataBatch", baseURL) - resp, err := o.doPostWithJSON(url, batchIDs) + resp, err := o.doPostWithJSON(ctx, url, batchIDs) if err != nil { return nil, err } @@ -349,7 +354,7 @@ func (o *Client) fetchAssetsByBatchTokenIDs(chainID walletCommon.ChainID, batchI return ret, nil } -func (o *Client) FetchAssetsByCollectibleUniqueID(uniqueIDs []thirdparty.CollectibleUniqueID) ([]thirdparty.FullCollectibleData, error) { +func (o *Client) FetchAssetsByCollectibleUniqueID(ctx context.Context, uniqueIDs []thirdparty.CollectibleUniqueID) ([]thirdparty.FullCollectibleData, error) { ret := make([]thirdparty.FullCollectibleData, 0, len(uniqueIDs)) idsPerChainID := thirdparty.GroupCollectibleUIDsByChainID(uniqueIDs) @@ -357,7 +362,7 @@ func (o *Client) FetchAssetsByCollectibleUniqueID(uniqueIDs []thirdparty.Collect for chainID, ids := range idsPerChainID { batches := getCollectibleUniqueIDBatches(ids) for _, batch := range batches { - assets, err := o.fetchAssetsByBatchTokenIDs(chainID, batch) + assets, err := o.fetchAssetsByBatchTokenIDs(ctx, chainID, batch) if err != nil { return nil, err } @@ -393,7 +398,7 @@ func getContractAddressBatches(ids []thirdparty.ContractID) []BatchContractAddre return batches } -func (o *Client) fetchCollectionsDataByBatchContractAddresses(chainID walletCommon.ChainID, batchAddresses BatchContractAddresses) ([]thirdparty.CollectionData, error) { +func (o *Client) fetchCollectionsDataByBatchContractAddresses(ctx context.Context, chainID walletCommon.ChainID, batchAddresses BatchContractAddresses) ([]thirdparty.CollectionData, error) { baseURL, err := getNFTBaseURL(chainID, o.apiKeys[uint64(chainID)]) if err != nil { return nil, err @@ -401,7 +406,7 @@ func (o *Client) fetchCollectionsDataByBatchContractAddresses(chainID walletComm url := fmt.Sprintf("%s/getContractMetadataBatch", baseURL) - resp, err := o.doPostWithJSON(url, batchAddresses) + resp, err := o.doPostWithJSON(ctx, url, batchAddresses) if err != nil { return nil, err } @@ -429,7 +434,7 @@ func (o *Client) fetchCollectionsDataByBatchContractAddresses(chainID walletComm return ret, nil } -func (o *Client) FetchCollectionsDataByContractID(contractIDs []thirdparty.ContractID) ([]thirdparty.CollectionData, error) { +func (o *Client) FetchCollectionsDataByContractID(ctx context.Context, contractIDs []thirdparty.ContractID) ([]thirdparty.CollectionData, error) { ret := make([]thirdparty.CollectionData, 0, len(contractIDs)) idsPerChainID := thirdparty.GroupContractIDsByChainID(contractIDs) @@ -437,7 +442,7 @@ func (o *Client) FetchCollectionsDataByContractID(contractIDs []thirdparty.Contr for chainID, ids := range idsPerChainID { batches := getContractAddressBatches(ids) for _, batch := range batches { - contractsData, err := o.fetchCollectionsDataByBatchContractAddresses(chainID, batch) + contractsData, err := o.fetchCollectionsDataByBatchContractAddresses(ctx, chainID, batch) if err != nil { return nil, err } diff --git a/services/wallet/thirdparty/collectible_types.go b/services/wallet/thirdparty/collectible_types.go index 5f30e0d71..9ae86f3d8 100644 --- a/services/wallet/thirdparty/collectible_types.go +++ b/services/wallet/thirdparty/collectible_types.go @@ -1,6 +1,7 @@ package thirdparty import ( + "context" "database/sql" "errors" "fmt" @@ -196,23 +197,23 @@ type CollectibleContractOwnership struct { type CollectibleContractOwnershipProvider interface { CollectibleProvider - FetchCollectibleOwnersByContractAddress(chainID w_common.ChainID, contractAddress common.Address) (*CollectibleContractOwnership, error) + FetchCollectibleOwnersByContractAddress(ctx context.Context, chainID w_common.ChainID, contractAddress common.Address) (*CollectibleContractOwnership, error) } type CollectibleAccountOwnershipProvider interface { CollectibleProvider - FetchAllAssetsByOwner(chainID w_common.ChainID, owner common.Address, cursor string, limit int) (*FullCollectibleDataContainer, error) - FetchAllAssetsByOwnerAndContractAddress(chainID w_common.ChainID, owner common.Address, contractAddresses []common.Address, cursor string, limit int) (*FullCollectibleDataContainer, error) + FetchAllAssetsByOwner(ctx context.Context, chainID w_common.ChainID, owner common.Address, cursor string, limit int) (*FullCollectibleDataContainer, error) + FetchAllAssetsByOwnerAndContractAddress(ctx context.Context, chainID w_common.ChainID, owner common.Address, contractAddresses []common.Address, cursor string, limit int) (*FullCollectibleDataContainer, error) } type CollectibleDataProvider interface { CollectibleProvider - FetchAssetsByCollectibleUniqueID(uniqueIDs []CollectibleUniqueID) ([]FullCollectibleData, error) + FetchAssetsByCollectibleUniqueID(ctx context.Context, uniqueIDs []CollectibleUniqueID) ([]FullCollectibleData, error) } type CollectionDataProvider interface { CollectibleProvider - FetchCollectionsDataByContractID(ids []ContractID) ([]CollectionData, error) + FetchCollectionsDataByContractID(ctx context.Context, ids []ContractID) ([]CollectionData, error) } type CollectibleCommunityInfoProvider interface { diff --git a/services/wallet/thirdparty/opensea/client_v2.go b/services/wallet/thirdparty/opensea/client_v2.go index 0e2e2905b..2d6510e03 100644 --- a/services/wallet/thirdparty/opensea/client_v2.go +++ b/services/wallet/thirdparty/opensea/client_v2.go @@ -1,6 +1,7 @@ package opensea import ( + "context" "encoding/json" "fmt" "net/url" @@ -71,7 +72,7 @@ func NewClientV2(apiKey string, httpClient *HTTPClient) *ClientV2 { } } -func (o *ClientV2) FetchAllAssetsByOwnerAndContractAddress(chainID walletCommon.ChainID, owner common.Address, contractAddresses []common.Address, cursor string, limit int) (*thirdparty.FullCollectibleDataContainer, error) { +func (o *ClientV2) FetchAllAssetsByOwnerAndContractAddress(ctx context.Context, chainID walletCommon.ChainID, owner common.Address, contractAddresses []common.Address, cursor string, limit int) (*thirdparty.FullCollectibleDataContainer, error) { // No dedicated endpoint to filter owned assets by contract address. // Will probably be available at some point, for now do the filtering ourselves. assets := new(thirdparty.FullCollectibleDataContainer) @@ -91,7 +92,7 @@ func (o *ClientV2) FetchAllAssetsByOwnerAndContractAddress(chainID walletCommon. assets.Provider = o.ID() for { - assetsPage, err := o.FetchAllAssetsByOwner(chainID, owner, assets.NextCursor, assetLimitV2) + assetsPage, err := o.FetchAllAssetsByOwner(ctx, chainID, owner, assets.NextCursor, assetLimitV2) if err != nil { return nil, err } @@ -116,7 +117,7 @@ func (o *ClientV2) FetchAllAssetsByOwnerAndContractAddress(chainID walletCommon. return assets, nil } -func (o *ClientV2) FetchAllAssetsByOwner(chainID walletCommon.ChainID, owner common.Address, cursor string, limit int) (*thirdparty.FullCollectibleDataContainer, error) { +func (o *ClientV2) FetchAllAssetsByOwner(ctx context.Context, chainID walletCommon.ChainID, owner common.Address, cursor string, limit int) (*thirdparty.FullCollectibleDataContainer, error) { pathParams := []string{ "chain", chainIDToChainString(chainID), "account", owner.String(), @@ -125,14 +126,14 @@ func (o *ClientV2) FetchAllAssetsByOwner(chainID walletCommon.ChainID, owner com queryParams := url.Values{} - return o.fetchAssets(chainID, pathParams, queryParams, limit, cursor) + return o.fetchAssets(ctx, chainID, pathParams, queryParams, limit, cursor) } -func (o *ClientV2) FetchAssetsByCollectibleUniqueID(uniqueIDs []thirdparty.CollectibleUniqueID) ([]thirdparty.FullCollectibleData, error) { - return o.fetchDetailedAssets(uniqueIDs) +func (o *ClientV2) FetchAssetsByCollectibleUniqueID(ctx context.Context, uniqueIDs []thirdparty.CollectibleUniqueID) ([]thirdparty.FullCollectibleData, error) { + return o.fetchDetailedAssets(ctx, uniqueIDs) } -func (o *ClientV2) fetchAssets(chainID walletCommon.ChainID, pathParams []string, queryParams url.Values, limit int, cursor string) (*thirdparty.FullCollectibleDataContainer, error) { +func (o *ClientV2) fetchAssets(ctx context.Context, chainID walletCommon.ChainID, pathParams []string, queryParams url.Values, limit int, cursor string) (*thirdparty.FullCollectibleDataContainer, error) { assets := new(thirdparty.FullCollectibleDataContainer) tmpLimit := assetLimitV2 @@ -154,7 +155,7 @@ func (o *ClientV2) fetchAssets(chainID walletCommon.ChainID, pathParams []string return nil, err } - body, err := o.client.doGetRequest(url, o.apiKey) + body, err := o.client.doGetRequest(ctx, url, o.apiKey) if err != nil { o.connectionStatus.SetIsConnected(false) return nil, err @@ -198,7 +199,7 @@ func (o *ClientV2) fetchAssets(chainID walletCommon.ChainID, pathParams []string return assets, nil } -func (o *ClientV2) fetchDetailedAssets(uniqueIDs []thirdparty.CollectibleUniqueID) ([]thirdparty.FullCollectibleData, error) { +func (o *ClientV2) fetchDetailedAssets(ctx context.Context, uniqueIDs []thirdparty.CollectibleUniqueID) ([]thirdparty.FullCollectibleData, error) { assets := make([]thirdparty.FullCollectibleData, 0, len(uniqueIDs)) for _, id := range uniqueIDs { @@ -208,9 +209,11 @@ func (o *ClientV2) fetchDetailedAssets(uniqueIDs []thirdparty.CollectibleUniqueI return nil, err } - body, err := o.client.doGetRequest(url, o.apiKey) + body, err := o.client.doGetRequest(ctx, url, o.apiKey) if err != nil { - o.connectionStatus.SetIsConnected(false) + if ctx.Err() == nil { + o.connectionStatus.SetIsConnected(false) + } return nil, err } o.connectionStatus.SetIsConnected(true) @@ -232,16 +235,18 @@ func (o *ClientV2) fetchDetailedAssets(uniqueIDs []thirdparty.CollectibleUniqueI return assets, nil } -func (o *ClientV2) fetchContractDataByContractID(id thirdparty.ContractID) (*ContractData, error) { +func (o *ClientV2) fetchContractDataByContractID(ctx context.Context, id thirdparty.ContractID) (*ContractData, error) { path := fmt.Sprintf("chain/%s/contract/%s", chainIDToChainString(id.ChainID), id.Address.String()) url, err := o.urlGetter(id.ChainID, path) if err != nil { return nil, err } - body, err := o.client.doGetRequest(url, o.apiKey) + body, err := o.client.doGetRequest(ctx, url, o.apiKey) if err != nil { - o.connectionStatus.SetIsConnected(false) + if ctx.Err() == nil { + o.connectionStatus.SetIsConnected(false) + } return nil, err } o.connectionStatus.SetIsConnected(true) @@ -260,14 +265,14 @@ func (o *ClientV2) fetchContractDataByContractID(id thirdparty.ContractID) (*Con return &contract, nil } -func (o *ClientV2) fetchCollectionDataBySlug(chainID walletCommon.ChainID, slug string) (*CollectionData, error) { +func (o *ClientV2) fetchCollectionDataBySlug(ctx context.Context, chainID walletCommon.ChainID, slug string) (*CollectionData, error) { path := fmt.Sprintf("collections/%s", slug) url, err := o.urlGetter(chainID, path) if err != nil { return nil, err } - body, err := o.client.doGetRequest(url, o.apiKey) + body, err := o.client.doGetRequest(ctx, url, o.apiKey) if err != nil { o.connectionStatus.SetIsConnected(false) return nil, err @@ -288,11 +293,11 @@ func (o *ClientV2) fetchCollectionDataBySlug(chainID walletCommon.ChainID, slug return &collection, nil } -func (o *ClientV2) FetchCollectionsDataByContractID(contractIDs []thirdparty.ContractID) ([]thirdparty.CollectionData, error) { +func (o *ClientV2) FetchCollectionsDataByContractID(ctx context.Context, contractIDs []thirdparty.ContractID) ([]thirdparty.CollectionData, error) { ret := make([]thirdparty.CollectionData, 0, len(contractIDs)) for _, id := range contractIDs { - contractData, err := o.fetchContractDataByContractID(id) + contractData, err := o.fetchContractDataByContractID(ctx, id) if err != nil { return nil, err } @@ -301,7 +306,7 @@ func (o *ClientV2) FetchCollectionsDataByContractID(contractIDs []thirdparty.Con continue } - collectionData, err := o.fetchCollectionDataBySlug(id.ChainID, contractData.Collection) + collectionData, err := o.fetchCollectionDataBySlug(ctx, id.ChainID, contractData.Collection) if err != nil { return nil, err } diff --git a/services/wallet/thirdparty/opensea/http_client.go b/services/wallet/thirdparty/opensea/http_client.go index 51cfd764d..c038c7206 100644 --- a/services/wallet/thirdparty/opensea/http_client.go +++ b/services/wallet/thirdparty/opensea/http_client.go @@ -1,6 +1,7 @@ package opensea import ( + "context" "fmt" "io/ioutil" "net/http" @@ -27,7 +28,7 @@ func NewHTTPClient() *HTTPClient { } } -func (o *HTTPClient) doGetRequest(url string, apiKey string) ([]byte, error) { +func (o *HTTPClient) doGetRequest(ctx context.Context, url string, apiKey string) ([]byte, error) { // Ensure only one thread makes a request at a time o.getRequestLock.Lock() defer o.getRequestLock.Unlock() @@ -39,7 +40,7 @@ func (o *HTTPClient) doGetRequest(url string, apiKey string) ([]byte, error) { tmpAPIKey := "" for { - req, err := http.NewRequest(http.MethodGet, url, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err }