refactor(nat): enhance NAT manager with locking and atomic operations

This commit is contained in:
Dmitriy Ryajov 2025-05-13 18:22:46 -06:00
parent ffda691ea7
commit 2381225b70
No known key found for this signature in database
GPG Key ID: DA8C680CE7C657A4

View File

@ -9,7 +9,7 @@
{.push raises: [].}
import
std/[options, os, strutils, times, net],
std/[options, os, strutils, times, net, locks, atomics],
stew/shims/net as stewNet,
stew/[objects, results],
nat_traversal/[miniupnpc, natpmp],
@ -44,11 +44,12 @@ type NatManager* = ref object
portMappings: seq[PortMapping]
thread: Thread[ptr NatManager]
config: NatConfig
upnp: Miniupnp
npmp: NatPmp
lock: Lock
upnp {.guard: lock.}: Miniupnp
npmp {.guard: lock.}: NatPmp
strategy: NatStrategy
threadStarted: bool
natCloseChan: Channel[bool]
natClosed: Atomic[bool]
export natutils
@ -69,67 +70,68 @@ proc getExternalIP*(
): Option[IpAddress] =
var externalIP: IpAddress
if natStrategy == NatStrategy.NatAny or natStrategy == NatStrategy.NatUpnp:
if self.upnp == nil:
self.upnp = newMiniupnp()
withLock(self.lock):
if natStrategy == NatStrategy.NatAny or natStrategy == NatStrategy.NatUpnp:
if self.upnp == nil:
self.upnp = newMiniupnp()
self.upnp.discoverDelay = UPNP_TIMEOUT
let dres = self.upnp.discover()
if dres.isErr:
debug "UPnP", msg = dres.error
else:
var
msg: cstring
canContinue = true
case self.upnp.selectIGD()
of IGDNotFound:
msg = "Internet Gateway Device not found. Giving up."
canContinue = false
of IGDFound:
msg = "Internet Gateway Device found."
of IGDNotConnected:
msg = "Internet Gateway Device found but it's not connected. Trying anyway."
of NotAnIGD:
msg =
"Some device found, but it's not recognised as an Internet Gateway Device. Trying anyway."
of IGDIpNotRoutable:
msg =
"Internet Gateway Device found and is connected, but with a reserved or non-routable IP. Trying anyway."
if not quiet:
debug "UPnP", msg
if canContinue:
let ires = self.upnp.externalIPAddress()
if ires.isErr:
debug "UPnP", msg = ires.error
self.upnp.discoverDelay = UPNP_TIMEOUT
let dres = self.upnp.discover()
if dres.isErr:
debug "UPnP", msg = dres.error
else:
var
msg: cstring
canContinue = true
case self.upnp.selectIGD()
of IGDNotFound:
msg = "Internet Gateway Device not found. Giving up."
canContinue = false
of IGDFound:
msg = "Internet Gateway Device found."
of IGDNotConnected:
msg = "Internet Gateway Device found but it's not connected. Trying anyway."
of NotAnIGD:
msg =
"Some device found, but it's not recognised as an Internet Gateway Device. Trying anyway."
of IGDIpNotRoutable:
msg =
"Internet Gateway Device found and is connected, but with a reserved or non-routable IP. Trying anyway."
if not quiet:
debug "UPnP", msg
if canContinue:
let ires = self.upnp.externalIPAddress()
if ires.isErr:
debug "UPnP", msg = ires.error
else:
# if we got this far, UPnP is working and we don't need to try NAT-PMP
try:
externalIP = parseIpAddress(ires.value)
self.strategy = NatStrategy.NatUpnp
return some(externalIP)
except ValueError as e:
error "parseIpAddress() exception", err = e.msg
return
if natStrategy == NatStrategy.NatAny or natStrategy == NatStrategy.NatPmp:
if self.npmp == nil:
self.npmp = newNatPmp()
let nres = self.npmp.init()
if nres.isErr:
debug "NAT-PMP", msg = nres.error
else:
let nires = self.npmp.externalIPAddress()
if nires.isErr:
debug "NAT-PMP", msg = nires.error
else:
# if we got this far, UPnP is working and we don't need to try NAT-PMP
try:
externalIP = parseIpAddress(ires.value)
self.strategy = NatStrategy.NatUpnp
externalIP = parseIpAddress($(nires.value))
self.strategy = NatPmp
return some(externalIP)
except ValueError as e:
error "parseIpAddress() exception", err = e.msg
return
if natStrategy == NatStrategy.NatAny or natStrategy == NatStrategy.NatPmp:
if self.npmp == nil:
self.npmp = newNatPmp()
let nres = self.npmp.init()
if nres.isErr:
debug "NAT-PMP", msg = nres.error
else:
let nires = self.npmp.externalIPAddress()
if nires.isErr:
debug "NAT-PMP", msg = nires.error
else:
try:
externalIP = parseIpAddress($(nires.value))
self.strategy = NatPmp
return some(externalIP)
except ValueError as e:
error "parseIpAddress() exception", err = e.msg
return
# This queries the routing table to get the "preferred source" attribute and
# checks if it's a public IP. If so, then it's our public IP.
#
@ -176,62 +178,64 @@ proc doPortMapping(
extTcpPort: Port
extUdpPort: Port
if strategy == NatStrategy.NatUpnp:
for t in [(tcpPort, UPNPProtocol.TCP), (udpPort, UPNPProtocol.UDP)]:
let
(port, protocol) = t
pmres = self.upnp.addPortMapping(
externalPort = $port,
protocol = protocol,
internalHost = self.upnp.lanAddr,
internalPort = $port,
desc = description,
leaseDuration = 0,
)
if pmres.isErr:
error "UPnP port mapping", msg = pmres.error, port
return
else:
# let's check it
let cres =
self.upnp.getSpecificPortMapping(externalPort = $port, protocol = protocol)
if cres.isErr:
warn "UPnP port mapping check failed. Assuming the check itself is broken and the port mapping was done.",
msg = cres.error
withLock(self.lock):
if strategy == NatStrategy.NatUpnp:
for t in [(tcpPort, UPNPProtocol.TCP), (udpPort, UPNPProtocol.UDP)]:
let
(port, protocol) = t
pmres = self.upnp.addPortMapping(
externalPort = $port,
protocol = protocol,
internalHost = self.upnp.lanAddr,
internalPort = $port,
desc = description,
leaseDuration = 0,
)
if pmres.isErr:
error "UPnP port mapping", msg = pmres.error, port
return
else:
# let's check it
let cres =
self.upnp.getSpecificPortMapping(externalPort = $port, protocol = protocol)
if cres.isErr:
warn "UPnP port mapping check failed. Assuming the check itself is broken and the port mapping was done.",
msg = cres.error
info "UPnP: added port mapping",
externalPort = port, internalPort = port, protocol = protocol
case protocol
of UPNPProtocol.TCP:
extTcpPort = port
of UPNPProtocol.UDP:
extUdpPort = port
elif strategy == NatStrategy.NatPmp:
for t in [(tcpPort, NatPmpProtocol.TCP), (udpPort, NatPmpProtocol.UDP)]:
let
(port, protocol) = t
pmres = self.npmp.addPortMapping(
eport = port.cushort,
iport = port.cushort,
protocol = protocol,
lifetime = NATPMP_LIFETIME,
)
if pmres.isErr:
error "NAT-PMP port mapping", msg = pmres.error, port
return
else:
let extPort = Port(pmres.value)
info "NAT-PMP: added port mapping",
externalPort = extPort, internalPort = port, protocol = protocol
case protocol
of NatPmpProtocol.TCP:
extTcpPort = extPort
of NatPmpProtocol.UDP:
extUdpPort = extPort
info "UPnP: added port mapping",
externalPort = port, internalPort = port, protocol = protocol
case protocol
of UPNPProtocol.TCP:
extTcpPort = port
of UPNPProtocol.UDP:
extUdpPort = port
elif strategy == NatStrategy.NatPmp:
for t in [(tcpPort, NatPmpProtocol.TCP), (udpPort, NatPmpProtocol.UDP)]:
let
(port, protocol) = t
pmres = self.npmp.addPortMapping(
eport = port.cushort,
iport = port.cushort,
protocol = protocol,
lifetime = NATPMP_LIFETIME,
)
if pmres.isErr:
error "NAT-PMP port mapping", msg = pmres.error, port
return
else:
let extPort = Port(pmres.value)
info "NAT-PMP: added port mapping",
externalPort = extPort, internalPort = port, protocol = protocol
case protocol
of NatPmpProtocol.TCP:
extTcpPort = extPort
of NatPmpProtocol.UDP:
extUdpPort = extPort
return some((extTcpPort, extUdpPort))
proc repeatPortMapping(self: ptr NatManager) {.thread, raises: [ValueError].} =
ignoreSignalsInThread()
let
interval = initDuration(seconds = PORT_MAPPING_INTERVAL)
sleepDuration = 1_000 # in ms, also the maximum delay after pressing Ctrl-C
@ -242,31 +246,22 @@ proc repeatPortMapping(self: ptr NatManager) {.thread, raises: [ValueError].} =
# C pointers with other instances that have already been garbage collected, so
# we use threadvars instead and initialise them again with getExternalIP(),
# even though we don't need the external IP's value.
let ipres = getExternalIP(self[], self[].strategy, quiet = true)
if ipres.isSome:
while true:
# we're being silly here with this channel polling because we can't
# select on Nim channels like on Go ones
let (dataAvailable, _) =
try:
self[].natCloseChan.tryRecv()
except Exception:
(false, false)
if dataAvailable:
return
else:
let currTime = now()
if currTime >= (lastUpdate + interval):
for entry in self[].portMappings:
discard doPortMapping(
self[],
self[].strategy,
entry.externalTcpPort,
entry.externalUdpPort,
entry.description,
)
lastUpdate = currTime
sleep(sleepDuration)
while self.natClosed.load() == false:
# we're being silly here with this channel polling because we can't
# select on Nim channels like on Go ones
let currTime = now()
if currTime >= (lastUpdate + interval):
for entry in self[].portMappings:
discard doPortMapping(
self[],
self[].strategy,
entry.externalTcpPort,
entry.externalUdpPort,
entry.description,
)
lastUpdate = currTime
sleep(sleepDuration)
proc stop*(self: NatManager) {.async.} =
# stop the thread
@ -276,9 +271,8 @@ proc stop*(self: NatManager) {.async.} =
return
try:
self.natCloseChan.send(true)
self.natClosed.store(true)
self.thread.joinThread()
self.natCloseChan.close()
except Exception as exc:
warn "Failed to stop NAT port mapping renewal thread", exc = exc.msg
@ -290,8 +284,7 @@ proc stop*(self: NatManager) {.async.} =
# In Windows, a new thread is created for the signal handler, so we need to
# initialise our threadvars again.
let ipres = getExternalIP(self, self.strategy, quiet = true)
if ipres.isSome:
withLock(self.lock):
if self.strategy == NatStrategy.NatUpnp:
for entry in self.portMappings:
for t in [
@ -406,7 +399,7 @@ proc setupAddress*(
proc startPortMappingThread*(self: NatManager) =
if self.portMappings.len > 0:
self.natCloseChan.open()
self.natClosed.store(false)
try:
self.thread.createThread(repeatPortMapping, (self.addr))
self.threadStarted = true
@ -417,10 +410,10 @@ proc nattedAddress*(
self: NatManager, addrs: seq[MultiAddress], udpPort: Port
): tuple[libp2p, discovery: seq[MultiAddress]] =
## Takes a NAT configuration, sequence of multiaddresses and UDP port and returns:
## - Modified multiaddresses with NAT-mapped addresses for libp2p
## - Modified multiaddresses with NAT-mapped addresses for libp2p
## - Discovery addresses with NAT-mapped UDP ports
var discoveryAddrs = newSeq[MultiAddress](0)
var discoveryAddrs: seq[MultiAddress]
let newAddrs = addrs.mapIt:
block:
# Extract IP address and port from the multiaddress
@ -445,7 +438,7 @@ proc nattedAddress*(
startPortMappingThread(self)
(newAddrs, discoveryAddrs)
proc new*(_: type NatManager, config: NatConfig): NatManager =
func new*(_: type NatManager, config: NatConfig): NatManager =
result = NatManager(
portMappings: @[],
config: config,