diff --git a/services/wallet/router/filter.go b/services/wallet/router/filter.go new file mode 100644 index 000000000..365c85e7b --- /dev/null +++ b/services/wallet/router/filter.go @@ -0,0 +1,104 @@ +package router + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/common/hexutil" +) + +func filterRoutesV2(routes [][]*PathV2, amountIn *big.Int, fromLockedAmount map[uint64]*hexutil.Big) [][]*PathV2 { + if len(fromLockedAmount) == 0 { + return routes + } + + routesAfterNetworkCompliance := filterNetworkComplianceV2(routes, fromLockedAmount) + return filterCapacityValidationV2(routesAfterNetworkCompliance, amountIn, fromLockedAmount) +} + +// filterNetworkComplianceV2 performs the first level of filtering based on network inclusion/exclusion criteria. +func filterNetworkComplianceV2(routes [][]*PathV2, fromLockedAmount map[uint64]*hexutil.Big) [][]*PathV2 { + var filteredRoutes [][]*PathV2 + + for _, route := range routes { + if isValidForNetworkComplianceV2(route, fromLockedAmount) { + filteredRoutes = append(filteredRoutes, route) + } + } + return filteredRoutes +} + +// isValidForNetworkComplianceV2 checks if a route complies with network inclusion/exclusion criteria. +func isValidForNetworkComplianceV2(route []*PathV2, fromLockedAmount map[uint64]*hexutil.Big) bool { + fromIncluded, fromExcluded := setupRouteValidationMapsV2(fromLockedAmount) + + for _, path := range route { + if fromExcluded[path.From.ChainID] { + return false + } + if _, ok := fromIncluded[path.From.ChainID]; ok { + fromIncluded[path.From.ChainID] = true + } + } + + for _, included := range fromIncluded { + if !included { + return false + } + } + + return true +} + +// setupRouteValidationMapsV2 initializes maps for network inclusion and exclusion based on locked amounts. +func setupRouteValidationMapsV2(fromLockedAmount map[uint64]*hexutil.Big) (map[uint64]bool, map[uint64]bool) { + fromIncluded := make(map[uint64]bool) + fromExcluded := make(map[uint64]bool) + + for chainID, amount := range fromLockedAmount { + if amount.ToInt().Cmp(zero) == 0 { + fromExcluded[chainID] = true + } else { + fromIncluded[chainID] = false + } + } + return fromIncluded, fromExcluded +} + +// filterCapacityValidationV2 performs the second level of filtering based on amount and capacity validation. +func filterCapacityValidationV2(routes [][]*PathV2, amountIn *big.Int, fromLockedAmount map[uint64]*hexutil.Big) [][]*PathV2 { + var filteredRoutes [][]*PathV2 + + for _, route := range routes { + if hasSufficientCapacityV2(route, amountIn, fromLockedAmount) { + filteredRoutes = append(filteredRoutes, route) + } + } + return filteredRoutes +} + +// hasSufficientCapacityV2 checks if a route has sufficient capacity to handle the required amount. +func hasSufficientCapacityV2(route []*PathV2, amountIn *big.Int, fromLockedAmount map[uint64]*hexutil.Big) bool { + totalRestAmount := calculateTotalRestAmountV2(route) + + for _, path := range route { + if amount, ok := fromLockedAmount[path.From.ChainID]; ok { + requiredAmountIn := new(big.Int).Sub(amountIn, amount.ToInt()) + if totalRestAmount.Cmp(requiredAmountIn) < 0 { + return false + } + path.AmountIn = amount + path.AmountInLocked = true + totalRestAmount.Sub(totalRestAmount, amount.ToInt()) + } + } + return true +} + +// calculateTotalRestAmountV2 calculates the total maximum amount that can be used from all paths in the route. +func calculateTotalRestAmountV2(route []*PathV2) *big.Int { + total := big.NewInt(0) + for _, path := range route { + total.Add(total, path.AmountIn.ToInt()) + } + return total +} diff --git a/services/wallet/router/router_v2.go b/services/wallet/router/router_v2.go index 0410d711f..a5bbef099 100644 --- a/services/wallet/router/router_v2.go +++ b/services/wallet/router/router_v2.go @@ -190,76 +190,6 @@ func (n NodeV2) buildAllRoutesV2() [][]*PathV2 { return res } -func filterRoutesV2(routes [][]*PathV2, amountIn *big.Int, fromLockedAmount map[uint64]*hexutil.Big) [][]*PathV2 { - if len(fromLockedAmount) == 0 { - return routes - } - - filteredRoutesLevel1 := make([][]*PathV2, 0) - for _, route := range routes { - routeOk := true - fromIncluded := make(map[uint64]bool) - fromExcluded := make(map[uint64]bool) - for chainID, amount := range fromLockedAmount { - if amount.ToInt().Cmp(zero) == 0 { - fromExcluded[chainID] = false - } else { - fromIncluded[chainID] = false - } - - } - for _, path := range route { - if _, ok := fromExcluded[path.From.ChainID]; ok { - routeOk = false - break - } - if _, ok := fromIncluded[path.From.ChainID]; ok { - fromIncluded[path.From.ChainID] = true - } - } - for _, value := range fromIncluded { - if !value { - routeOk = false - break - } - } - - if routeOk { - filteredRoutesLevel1 = append(filteredRoutesLevel1, route) - } - } - - filteredRoutesLevel2 := make([][]*PathV2, 0) - for _, route := range filteredRoutesLevel1 { - routeOk := true - for _, path := range route { - if amount, ok := fromLockedAmount[path.From.ChainID]; ok { - requiredAmountIn := new(big.Int).Sub(amountIn, amount.ToInt()) - restAmountIn := big.NewInt(0) - - for _, otherPath := range route { - if path.Equal(otherPath) { - continue - } - restAmountIn = new(big.Int).Add(otherPath.AmountIn.ToInt(), restAmountIn) - } - if restAmountIn.Cmp(requiredAmountIn) >= 0 { - path.AmountIn = amount - path.AmountInLocked = true - } else { - routeOk = false - break - } - } - } - if routeOk { - filteredRoutesLevel2 = append(filteredRoutesLevel2, route) - } - } - - return filteredRoutesLevel2 -} - func findBestV2(routes [][]*PathV2, tokenPrice float64, nativeChainTokenPrice float64) []*PathV2 { var best []*PathV2 bestCost := big.NewFloat(math.Inf(1))