use di to inject network interface provider

This commit is contained in:
Diego 2024-05-26 22:24:14 +02:00
parent d34575675c
commit fdff1ebec5
No known key found for this signature in database
GPG Key ID: C9DAC9BF68D1F806
3 changed files with 45 additions and 29 deletions

View File

@ -30,6 +30,7 @@ import
nameresolving/nameresolver,
errors, utility
import services/wildcardresolverservice
import ../di/di
export
switch, peerid, peerinfo, connection, multiaddress, crypto, errors
@ -62,6 +63,7 @@ type
services: seq[Service]
observedAddrManager: ObservedAddrManager
enableWildcardResolver: bool
container*: Container
proc new*(T: type[SwitchBuilder]): T {.public.} =
## Creates a SwitchBuilder
@ -70,7 +72,7 @@ proc new*(T: type[SwitchBuilder]): T {.public.} =
.init("/ip4/127.0.0.1/tcp/0")
.expect("Should initialize to default")
SwitchBuilder(
let sb = SwitchBuilder(
privKey: none(PrivateKey),
addresses: @[address],
secureManagers: @[],
@ -79,7 +81,12 @@ proc new*(T: type[SwitchBuilder]): T {.public.} =
maxOut: -1,
maxConnsPerPeer: MaxConnectionsPerPeer,
protoVersion: ProtoVersion,
agentVersion: AgentVersion)
agentVersion: AgentVersion,
container: Container())
register[NetworkInterfaceProvider](sb.container, networkInterfaceProvider)
sb
proc withPrivateKey*(b: SwitchBuilder, privateKey: PrivateKey): SwitchBuilder {.public.} =
## Set the private key of the switch. Will be used to
@ -211,6 +218,10 @@ proc withObservedAddrManager*(b: SwitchBuilder, observedAddrManager: ObservedAdd
b.observedAddrManager = observedAddrManager
b
proc withBinding*[T](b: SwitchBuilder, binding: proc(): T {.gcsafe, raises: [].}): SwitchBuilder =
register[T](b.container, binding)
b
proc build*(b: SwitchBuilder): Switch
{.raises: [LPError], public.} =
@ -263,8 +274,11 @@ proc build*(b: SwitchBuilder): Switch
else:
PeerStore.new(identify)
if b.enableWildcardResolver:
b.services.add(WildcardAddressResolverService.new())
try:
let networkInterfaceProvider = resolve[NetworkInterfaceProvider](b.container)
b.services.add(WildcardAddressResolverService.new(networkInterfaceProvider))
except BindingNotFoundError as e:
raise newException(LPError, "Cannot resolve NetworkInterfaceProvider", e)
let switch = newSwitch(
peerInfo = peerInfo,

View File

@ -32,14 +32,19 @@ type
## and the machine has 2 interfaces with IPs 172.217.11.174 and 64.233.177.113, the address mapper will
## expand the wildcard address to 172.217.11.174:4001 and 64.233.177.113:4001.
NetworkInterfaceProvider* =
proc(addrFamily: AddressFamily): seq[InterfaceAddress] {.gcsafe, raises: [].}
NetworkInterfaceProvider* = ref object of RootObj
proc isLoopbackOrUp(networkInterface: NetworkInterface): bool =
if (networkInterface.ifType == IfSoftwareLoopback) or
(networkInterface.state == StatusUp): true else: false
proc getAddresses(addrFamily: AddressFamily): seq[InterfaceAddress] =
proc networkInterfaceProvider*(): NetworkInterfaceProvider =
## Returns a new instance of `NetworkInterfaceProvider`.
return NetworkInterfaceProvider()
method getAddresses*(
networkInterfaceProvider: NetworkInterfaceProvider, addrFamily: AddressFamily
): seq[InterfaceAddress] {.base.} =
## This method retrieves the addresses of network interfaces based on the specified address family.
##
## The `getAddresses` method filters the available network interfaces to include only
@ -47,10 +52,12 @@ proc getAddresses(addrFamily: AddressFamily): seq[InterfaceAddress] =
## interfaces and filters them to match the provided address family.
##
## Parameters:
## - `networkInterfaceProvider`: A provider that offers access to network interfaces.
## - `addrFamily`: The address family to filter the network addresses (e.g., `AddressFamily.IPv4` or `AddressFamily.IPv6`).
##
## Returns:
## - A sequence of `InterfaceAddress` objects that match the specified address family.
echo "Getting addresses for address family: ", addrFamily
let
interfaces = getInterfaces().filterIt(it.isLoopbackOrUp())
flatInterfaceAddresses = concat(interfaces.mapIt(it.addresses))
@ -60,7 +67,7 @@ proc getAddresses(addrFamily: AddressFamily): seq[InterfaceAddress] =
proc new*(
T: typedesc[WildcardAddressResolverService],
networkInterfaceProvider: NetworkInterfaceProvider = getAddresses,
networkInterfaceProvider: NetworkInterfaceProvider = new(NetworkInterfaceProvider),
): T =
## This procedure initializes a new `WildcardAddressResolverService` with the provided network interface provider.
##
@ -106,7 +113,7 @@ proc getWildcardAddress(
var addresses: seq[MultiAddress]
maddress.getProtocolArgument(multiCodec).withValue(address):
if address == anyAddr:
let filteredInterfaceAddresses = networkInterfaceProvider(addrFamily)
let filteredInterfaceAddresses = networkInterfaceProvider.getAddresses(addrFamily)
addresses.add(
getWildcardMultiAddresses(filteredInterfaceAddresses, IPPROTO_TCP, port)
)
@ -171,6 +178,7 @@ method setup*(
let hasBeenSetup = await procCall Service(self).setup(switch)
if hasBeenSetup:
switch.peerInfo.addressMappers.add(self.addressMapper)
await self.run(switch)
return hasBeenSetup
method run*(self: WildcardAddressResolverService, switch: Switch) {.async, public.} =

View File

@ -17,10 +17,14 @@ import ../libp2p/[builders, switch]
import ../libp2p/services/wildcardresolverservice
import ../libp2p/[multiaddress, multicodec]
import ./helpers
import ../di/di
proc getAddressesMock(
addrFamily: AddressFamily
type NetworkInterfaceProviderMock* = ref object of NetworkInterfaceProvider
method getAddresses(
networkInterfaceProvider: NetworkInterfaceProviderMock, addrFamily: AddressFamily
): seq[InterfaceAddress] {.gcsafe, raises: [].} =
echo "getAddressesMock"
try:
if addrFamily == AddressFamily.IPv4:
return
@ -38,7 +42,10 @@ proc getAddressesMock(
echo "Error: " & $e.msg
fail()
proc createSwitch(svc: Service): Switch =
proc networkInterfaceProviderMock(): NetworkInterfaceProvider =
NetworkInterfaceProviderMock.new()
proc createSwitch(): Switch =
SwitchBuilder
.new()
.withRng(newRng())
@ -51,32 +58,19 @@ proc createSwitch(svc: Service): Switch =
.withTcpTransport()
.withMplex()
.withNoise()
.withServices(@[svc])
.withBinding(networkInterfaceProviderMock)
.build()
suite "WildcardAddressResolverService":
teardown:
checkTrackers()
proc setupWildcardService(): Future[
tuple[svc: Service, switch: Switch, tcpIp4: MultiAddress, tcpIp6: MultiAddress]
] {.async.} =
let svc: Service =
WildcardAddressResolverService.new(networkInterfaceProvider = getAddressesMock)
let switch = createSwitch(svc)
asyncTest "WildcardAddressResolverService must resolve wildcard addresses and stop doing so when stopped":
let switch = createSwitch()
await switch.start()
let tcpIp4 = switch.peerInfo.addrs[0][multiCodec("tcp")].get # tcp port for ip4
let tcpIp6 = switch.peerInfo.addrs[1][multiCodec("tcp")].get # tcp port for ip6
return (svc, switch, tcpIp4, tcpIp6)
let tcpIp6 = switch.peerInfo.addrs[2][multiCodec("tcp")].get # tcp port for ip6
asyncTest "WildcardAddressResolverService must resolve wildcard addresses and stop doing so when stopped":
let (svc, switch, tcpIp4, tcpIp6) = await setupWildcardService()
check switch.peerInfo.addrs ==
@[
MultiAddress.init("/ip4/0.0.0.0" & $tcpIp4).get,
MultiAddress.init("/ip6/::" & $tcpIp6).get,
]
await svc.run(switch)
check switch.peerInfo.addrs ==
@[
MultiAddress.init("/ip4/127.0.0.1" & $tcpIp4).get,