From 40193f4d1df62f89bc137bd511f7f531ccba09f4 Mon Sep 17 00:00:00 2001 From: Samuel Hawksby-Robinson Date: Wed, 22 May 2024 10:16:17 +0100 Subject: [PATCH] test_: Implemented some fixes to the filter logic --- services/wallet/router/filter.go | 42 ++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/services/wallet/router/filter.go b/services/wallet/router/filter.go index 066643633..9ac98eb9c 100644 --- a/services/wallet/router/filter.go +++ b/services/wallet/router/filter.go @@ -17,10 +17,17 @@ func filterRoutesV2(routes [][]*PathV2, amountIn *big.Int, fromLockedAmount map[ // filterNetworkComplianceV2 performs the first level of filtering based on network inclusion/exclusion criteria. func filterNetworkComplianceV2(routes [][]*PathV2, fromLockedAmount map[uint64]*hexutil.Big) [][]*PathV2 { - var filteredRoutes [][]*PathV2 + filteredRoutes := make([][]*PathV2, 0) + if routes == nil || fromLockedAmount == nil { + return filteredRoutes + } + fromIncluded, fromExcluded := setupRouteValidationMapsV2(fromLockedAmount) for _, route := range routes { + if route == nil { + continue + } if isValidForNetworkComplianceV2(route, fromIncluded, fromExcluded) { filteredRoutes = append(filteredRoutes, route) } @@ -31,7 +38,10 @@ func filterNetworkComplianceV2(routes [][]*PathV2, fromLockedAmount map[uint64]* // isValidForNetworkComplianceV2 checks if a route complies with network inclusion/exclusion criteria. func isValidForNetworkComplianceV2(route []*PathV2, fromIncluded, fromExcluded map[uint64]bool) bool { for _, path := range route { - if fromExcluded[path.From.ChainID] { + if path == nil || path.From == nil || path.To == nil { + return false + } + if _, ok := fromExcluded[path.From.ChainID]; ok { return false } if _, ok := fromIncluded[path.From.ChainID]; ok { @@ -55,7 +65,7 @@ func setupRouteValidationMapsV2(fromLockedAmount map[uint64]*hexutil.Big) (map[u for chainID, amount := range fromLockedAmount { if amount.ToInt().Cmp(zero) == 0 { - fromExcluded[chainID] = true + fromExcluded[chainID] = false } else { fromIncluded[chainID] = false } @@ -65,7 +75,7 @@ func setupRouteValidationMapsV2(fromLockedAmount map[uint64]*hexutil.Big) (map[u // filterCapacityValidationV2 performs the second level of filtering based on amount and capacity validation. func filterCapacityValidationV2(routes [][]*PathV2, amountIn *big.Int, fromLockedAmount map[uint64]*hexutil.Big) [][]*PathV2 { - var filteredRoutes [][]*PathV2 + filteredRoutes := make([][]*PathV2, 0) for _, route := range routes { if hasSufficientCapacityV2(route, amountIn, fromLockedAmount) { @@ -77,27 +87,29 @@ func filterCapacityValidationV2(routes [][]*PathV2, amountIn *big.Int, fromLocke // hasSufficientCapacityV2 checks if a route has sufficient capacity to handle the required amount. func hasSufficientCapacityV2(route []*PathV2, amountIn *big.Int, fromLockedAmount map[uint64]*hexutil.Big) bool { - totalRestAmount := calculateTotalRestAmountV2(route) - for _, path := range route { if amount, ok := fromLockedAmount[path.From.ChainID]; ok { requiredAmountIn := new(big.Int).Sub(amountIn, amount.ToInt()) - if totalRestAmount.Cmp(requiredAmountIn) < 0 { + restAmountIn := calculateRestAmountInV2(route, path) + + if restAmountIn.Cmp(requiredAmountIn) >= 0 { + path.AmountIn = amount + path.AmountInLocked = true + } else { return false } - path.AmountIn = amount - path.AmountInLocked = true - totalRestAmount.Sub(totalRestAmount, amount.ToInt()) } } return true } -// calculateTotalRestAmountV2 calculates the total maximum amount that can be used from all paths in the route. -func calculateTotalRestAmountV2(route []*PathV2) *big.Int { - total := big.NewInt(0) +// calculateRestAmountIn calculates the remaining amount in for the route excluding the specified path +func calculateRestAmountInV2(route []*PathV2, excludePath *PathV2) *big.Int { + restAmountIn := big.NewInt(0) for _, path := range route { - total.Add(total, path.AmountIn.ToInt()) + if path != excludePath { + restAmountIn.Add(restAmountIn, path.AmountIn.ToInt()) + } } - return total + return restAmountIn }