From f6f1b56cf736327a37d7d6acb75dfb01e824d3a5 Mon Sep 17 00:00:00 2001 From: Samuel Hawksby-Robinson Date: Mon, 10 Jun 2024 16:15:41 +0100 Subject: [PATCH] test(router_validation)_: Added validation to prevent all excluded networks --- services/wallet/router/router_v2.go | 64 ++++++++++++++++++++++++----- 1 file changed, 53 insertions(+), 11 deletions(-) diff --git a/services/wallet/router/router_v2.go b/services/wallet/router/router_v2.go index a5c9fc425..690af2df1 100644 --- a/services/wallet/router/router_v2.go +++ b/services/wallet/router/router_v2.go @@ -2,6 +2,7 @@ package router import ( "context" + "errors" "fmt" "math" "math/big" @@ -10,7 +11,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/status-im/status-go/errors" + sErrors "github.com/status-im/status-go/errors" "github.com/status-im/status-go/params" "github.com/status-im/status-go/services/ens" "github.com/status-im/status-go/services/wallet/async" @@ -429,6 +430,47 @@ func validateInputData(input *RouteInputParams) error { } } + return validateFromLockedAmount(input.FromLockedAmount, input.testnetMode) +} + +func validateFromLockedAmount(fromLockedAmount map[uint64]*hexutil.Big, isTestnetMode bool) error { + if fromLockedAmount == nil || len(fromLockedAmount) == 0 { + return nil + } + + chainIDSet := make(map[uint64]bool) + excludedChainCount := 0 + + for chainID, amount := range fromLockedAmount { + if isTestnetMode { + if !supportedTestNetworks[chainID] { + return errors.New("locked amount is not supported for the selected network") + } + } else { + if !supportedNetworks[chainID] { + return errors.New("locked amount is not supported for the selected network") + } + } + + // Check locked amount is not negative + if amount == nil || amount.ToInt().Sign() < 0 { + return errors.New("locked amount must not be negative") + } + + // Check if locked chain ID is a duplicate + if _, exists := chainIDSet[chainID]; exists { + // Handle duplicate chain ID + return fmt.Errorf("a chain ID may only appear once, duplicate chain ID found '%d'", chainID) + } + chainIDSet[chainID] = amount.ToInt().Sign() > 0 + if !chainIDSet[chainID] { + excludedChainCount++ + } + } + if (!isTestnetMode && excludedChainCount == len(supportedNetworks)) || + (isTestnetMode && excludedChainCount == len(supportedTestNetworks)) { + return errors.New("all supported chains are excluded, routing impossible") + } return nil } @@ -439,7 +481,7 @@ func (r *Router) SuggestedRoutesV2Async(input *RouteInputParams) { if err != nil { errResponse := &ErrorResponseWithUUID{ Uuid: input.Uuid, - ErrorResponse: errors.CreateErrorResponseFromError(err), + ErrorResponse: sErrors.CreateErrorResponseFromError(err), } signal.SendWalletEvent(signal.SuggestedRoutes, errResponse) return @@ -455,7 +497,7 @@ func (r *Router) StopSuggestedRoutesV2AsyncCalcualtion() { func (r *Router) SuggestedRoutesV2(ctx context.Context, input *RouteInputParams) (*SuggestedRoutesV2, error) { testnetMode, err := r.rpcClient.NetworkManager.GetTestNetworksEnabled() if err != nil { - return nil, errors.CreateErrorResponseFromError(err) + return nil, sErrors.CreateErrorResponseFromError(err) } input.testnetMode = testnetMode @@ -469,12 +511,12 @@ func (r *Router) SuggestedRoutesV2(ctx context.Context, input *RouteInputParams) err = validateInputData(input) if err != nil { - return nil, errors.CreateErrorResponseFromError(err) + return nil, sErrors.CreateErrorResponseFromError(err) } candidates, err := r.resolveCandidates(ctx, input) if err != nil { - return nil, errors.CreateErrorResponseFromError(err) + return nil, sErrors.CreateErrorResponseFromError(err) } return r.resolveRoutes(ctx, input, candidates) @@ -488,7 +530,7 @@ func (r *Router) resolveCandidates(ctx context.Context, input *RouteInputParams) networks, err = r.rpcClient.NetworkManager.Get(false) if err != nil { - return nil, errors.CreateErrorResponseFromError(err) + return nil, sErrors.CreateErrorResponseFromError(err) } var ( @@ -711,12 +753,12 @@ func (r *Router) checkBalancesForTheBestRoute(ctx context.Context, bestRoute []* if input.SendType == ERC1155Transfer { tokenBalance, err = r.getERC1155Balance(ctx, path.FromChain, path.FromToken, input.AddrFrom) if err != nil { - return errors.CreateErrorResponseFromError(err) + return sErrors.CreateErrorResponseFromError(err) } } else if input.SendType != ERC721Transfer { tokenBalance, err = r.getBalance(ctx, path.FromChain, path.FromToken, input.AddrFrom) if err != nil { - return errors.CreateErrorResponseFromError(err) + return sErrors.CreateErrorResponseFromError(err) } } } @@ -737,7 +779,7 @@ func (r *Router) checkBalancesForTheBestRoute(ctx context.Context, bestRoute []* nativeBalance, err = r.getBalance(ctx, path.FromChain, nativeToken, input.AddrFrom) if err != nil { - return errors.CreateErrorResponseFromError(err) + return sErrors.CreateErrorResponseFromError(err) } } @@ -783,7 +825,7 @@ func (r *Router) resolveRoutes(ctx context.Context, input *RouteInputParams, can } else { prices, err = input.SendType.FetchPrices(r.marketManager, input.TokenID) if err != nil { - return nil, errors.CreateErrorResponseFromError(err) + return nil, sErrors.CreateErrorResponseFromError(err) } } @@ -806,7 +848,7 @@ func (r *Router) resolveRoutes(ctx context.Context, input *RouteInputParams, can allRoutes = removeBestRouteFromAllRouters(allRoutes, best) continue } else { - return suggestedRoutes, errors.CreateErrorResponseFromError(err) + return suggestedRoutes, sErrors.CreateErrorResponseFromError(err) } }