diff --git a/codex/nat.nim b/codex/nat.nim index 0f295be8..f947b13d 100644 --- a/codex/nat.nim +++ b/codex/nat.nim @@ -9,7 +9,7 @@ {.push raises: [].} import - std/[options, os, strutils, times, net], + std/[options, os, strutils, times, net, locks, atomics], stew/shims/net as stewNet, stew/[objects, results], nat_traversal/[miniupnpc, natpmp], @@ -44,11 +44,12 @@ type NatManager* = ref object portMappings: seq[PortMapping] thread: Thread[ptr NatManager] config: NatConfig - upnp: Miniupnp - npmp: NatPmp + lock: Lock + upnp {.guard: lock.}: Miniupnp + npmp {.guard: lock.}: NatPmp strategy: NatStrategy threadStarted: bool - natCloseChan: Channel[bool] + natClosed: Atomic[bool] export natutils @@ -69,67 +70,68 @@ proc getExternalIP*( ): Option[IpAddress] = var externalIP: IpAddress - if natStrategy == NatStrategy.NatAny or natStrategy == NatStrategy.NatUpnp: - if self.upnp == nil: - self.upnp = newMiniupnp() + withLock(self.lock): + if natStrategy == NatStrategy.NatAny or natStrategy == NatStrategy.NatUpnp: + if self.upnp == nil: + self.upnp = newMiniupnp() - self.upnp.discoverDelay = UPNP_TIMEOUT - let dres = self.upnp.discover() - if dres.isErr: - debug "UPnP", msg = dres.error - else: - var - msg: cstring - canContinue = true - case self.upnp.selectIGD() - of IGDNotFound: - msg = "Internet Gateway Device not found. Giving up." - canContinue = false - of IGDFound: - msg = "Internet Gateway Device found." - of IGDNotConnected: - msg = "Internet Gateway Device found but it's not connected. Trying anyway." - of NotAnIGD: - msg = - "Some device found, but it's not recognised as an Internet Gateway Device. Trying anyway." - of IGDIpNotRoutable: - msg = - "Internet Gateway Device found and is connected, but with a reserved or non-routable IP. Trying anyway." - if not quiet: - debug "UPnP", msg - if canContinue: - let ires = self.upnp.externalIPAddress() - if ires.isErr: - debug "UPnP", msg = ires.error + self.upnp.discoverDelay = UPNP_TIMEOUT + let dres = self.upnp.discover() + if dres.isErr: + debug "UPnP", msg = dres.error + else: + var + msg: cstring + canContinue = true + case self.upnp.selectIGD() + of IGDNotFound: + msg = "Internet Gateway Device not found. Giving up." + canContinue = false + of IGDFound: + msg = "Internet Gateway Device found." + of IGDNotConnected: + msg = "Internet Gateway Device found but it's not connected. Trying anyway." + of NotAnIGD: + msg = + "Some device found, but it's not recognised as an Internet Gateway Device. Trying anyway." + of IGDIpNotRoutable: + msg = + "Internet Gateway Device found and is connected, but with a reserved or non-routable IP. Trying anyway." + if not quiet: + debug "UPnP", msg + if canContinue: + let ires = self.upnp.externalIPAddress() + if ires.isErr: + debug "UPnP", msg = ires.error + else: + # if we got this far, UPnP is working and we don't need to try NAT-PMP + try: + externalIP = parseIpAddress(ires.value) + self.strategy = NatStrategy.NatUpnp + return some(externalIP) + except ValueError as e: + error "parseIpAddress() exception", err = e.msg + return + + if natStrategy == NatStrategy.NatAny or natStrategy == NatStrategy.NatPmp: + if self.npmp == nil: + self.npmp = newNatPmp() + let nres = self.npmp.init() + if nres.isErr: + debug "NAT-PMP", msg = nres.error + else: + let nires = self.npmp.externalIPAddress() + if nires.isErr: + debug "NAT-PMP", msg = nires.error else: - # if we got this far, UPnP is working and we don't need to try NAT-PMP try: - externalIP = parseIpAddress(ires.value) - self.strategy = NatStrategy.NatUpnp + externalIP = parseIpAddress($(nires.value)) + self.strategy = NatPmp return some(externalIP) except ValueError as e: error "parseIpAddress() exception", err = e.msg return - if natStrategy == NatStrategy.NatAny or natStrategy == NatStrategy.NatPmp: - if self.npmp == nil: - self.npmp = newNatPmp() - let nres = self.npmp.init() - if nres.isErr: - debug "NAT-PMP", msg = nres.error - else: - let nires = self.npmp.externalIPAddress() - if nires.isErr: - debug "NAT-PMP", msg = nires.error - else: - try: - externalIP = parseIpAddress($(nires.value)) - self.strategy = NatPmp - return some(externalIP) - except ValueError as e: - error "parseIpAddress() exception", err = e.msg - return - # This queries the routing table to get the "preferred source" attribute and # checks if it's a public IP. If so, then it's our public IP. # @@ -176,62 +178,64 @@ proc doPortMapping( extTcpPort: Port extUdpPort: Port - if strategy == NatStrategy.NatUpnp: - for t in [(tcpPort, UPNPProtocol.TCP), (udpPort, UPNPProtocol.UDP)]: - let - (port, protocol) = t - pmres = self.upnp.addPortMapping( - externalPort = $port, - protocol = protocol, - internalHost = self.upnp.lanAddr, - internalPort = $port, - desc = description, - leaseDuration = 0, - ) - if pmres.isErr: - error "UPnP port mapping", msg = pmres.error, port - return - else: - # let's check it - let cres = - self.upnp.getSpecificPortMapping(externalPort = $port, protocol = protocol) - if cres.isErr: - warn "UPnP port mapping check failed. Assuming the check itself is broken and the port mapping was done.", - msg = cres.error + withLock(self.lock): + if strategy == NatStrategy.NatUpnp: + for t in [(tcpPort, UPNPProtocol.TCP), (udpPort, UPNPProtocol.UDP)]: + let + (port, protocol) = t + pmres = self.upnp.addPortMapping( + externalPort = $port, + protocol = protocol, + internalHost = self.upnp.lanAddr, + internalPort = $port, + desc = description, + leaseDuration = 0, + ) + if pmres.isErr: + error "UPnP port mapping", msg = pmres.error, port + return + else: + # let's check it + let cres = + self.upnp.getSpecificPortMapping(externalPort = $port, protocol = protocol) + if cres.isErr: + warn "UPnP port mapping check failed. Assuming the check itself is broken and the port mapping was done.", + msg = cres.error - info "UPnP: added port mapping", - externalPort = port, internalPort = port, protocol = protocol - case protocol - of UPNPProtocol.TCP: - extTcpPort = port - of UPNPProtocol.UDP: - extUdpPort = port - elif strategy == NatStrategy.NatPmp: - for t in [(tcpPort, NatPmpProtocol.TCP), (udpPort, NatPmpProtocol.UDP)]: - let - (port, protocol) = t - pmres = self.npmp.addPortMapping( - eport = port.cushort, - iport = port.cushort, - protocol = protocol, - lifetime = NATPMP_LIFETIME, - ) - if pmres.isErr: - error "NAT-PMP port mapping", msg = pmres.error, port - return - else: - let extPort = Port(pmres.value) - info "NAT-PMP: added port mapping", - externalPort = extPort, internalPort = port, protocol = protocol - case protocol - of NatPmpProtocol.TCP: - extTcpPort = extPort - of NatPmpProtocol.UDP: - extUdpPort = extPort + info "UPnP: added port mapping", + externalPort = port, internalPort = port, protocol = protocol + case protocol + of UPNPProtocol.TCP: + extTcpPort = port + of UPNPProtocol.UDP: + extUdpPort = port + elif strategy == NatStrategy.NatPmp: + for t in [(tcpPort, NatPmpProtocol.TCP), (udpPort, NatPmpProtocol.UDP)]: + let + (port, protocol) = t + pmres = self.npmp.addPortMapping( + eport = port.cushort, + iport = port.cushort, + protocol = protocol, + lifetime = NATPMP_LIFETIME, + ) + if pmres.isErr: + error "NAT-PMP port mapping", msg = pmres.error, port + return + else: + let extPort = Port(pmres.value) + info "NAT-PMP: added port mapping", + externalPort = extPort, internalPort = port, protocol = protocol + case protocol + of NatPmpProtocol.TCP: + extTcpPort = extPort + of NatPmpProtocol.UDP: + extUdpPort = extPort return some((extTcpPort, extUdpPort)) proc repeatPortMapping(self: ptr NatManager) {.thread, raises: [ValueError].} = ignoreSignalsInThread() + let interval = initDuration(seconds = PORT_MAPPING_INTERVAL) sleepDuration = 1_000 # in ms, also the maximum delay after pressing Ctrl-C @@ -242,31 +246,22 @@ proc repeatPortMapping(self: ptr NatManager) {.thread, raises: [ValueError].} = # C pointers with other instances that have already been garbage collected, so # we use threadvars instead and initialise them again with getExternalIP(), # even though we don't need the external IP's value. - let ipres = getExternalIP(self[], self[].strategy, quiet = true) - if ipres.isSome: - while true: - # we're being silly here with this channel polling because we can't - # select on Nim channels like on Go ones - let (dataAvailable, _) = - try: - self[].natCloseChan.tryRecv() - except Exception: - (false, false) - if dataAvailable: - return - else: - let currTime = now() - if currTime >= (lastUpdate + interval): - for entry in self[].portMappings: - discard doPortMapping( - self[], - self[].strategy, - entry.externalTcpPort, - entry.externalUdpPort, - entry.description, - ) - lastUpdate = currTime - sleep(sleepDuration) + while self.natClosed.load() == false: + # we're being silly here with this channel polling because we can't + # select on Nim channels like on Go ones + let currTime = now() + if currTime >= (lastUpdate + interval): + for entry in self[].portMappings: + discard doPortMapping( + self[], + self[].strategy, + entry.externalTcpPort, + entry.externalUdpPort, + entry.description, + ) + lastUpdate = currTime + + sleep(sleepDuration) proc stop*(self: NatManager) {.async.} = # stop the thread @@ -276,9 +271,8 @@ proc stop*(self: NatManager) {.async.} = return try: - self.natCloseChan.send(true) + self.natClosed.store(true) self.thread.joinThread() - self.natCloseChan.close() except Exception as exc: warn "Failed to stop NAT port mapping renewal thread", exc = exc.msg @@ -290,8 +284,7 @@ proc stop*(self: NatManager) {.async.} = # In Windows, a new thread is created for the signal handler, so we need to # initialise our threadvars again. - let ipres = getExternalIP(self, self.strategy, quiet = true) - if ipres.isSome: + withLock(self.lock): if self.strategy == NatStrategy.NatUpnp: for entry in self.portMappings: for t in [ @@ -406,7 +399,7 @@ proc setupAddress*( proc startPortMappingThread*(self: NatManager) = if self.portMappings.len > 0: - self.natCloseChan.open() + self.natClosed.store(false) try: self.thread.createThread(repeatPortMapping, (self.addr)) self.threadStarted = true @@ -417,10 +410,10 @@ proc nattedAddress*( self: NatManager, addrs: seq[MultiAddress], udpPort: Port ): tuple[libp2p, discovery: seq[MultiAddress]] = ## Takes a NAT configuration, sequence of multiaddresses and UDP port and returns: - ## - Modified multiaddresses with NAT-mapped addresses for libp2p + ## - Modified multiaddresses with NAT-mapped addresses for libp2p ## - Discovery addresses with NAT-mapped UDP ports - var discoveryAddrs = newSeq[MultiAddress](0) + var discoveryAddrs: seq[MultiAddress] let newAddrs = addrs.mapIt: block: # Extract IP address and port from the multiaddress @@ -445,7 +438,7 @@ proc nattedAddress*( startPortMappingThread(self) (newAddrs, discoveryAddrs) -proc new*(_: type NatManager, config: NatConfig): NatManager = +func new*(_: type NatManager, config: NatConfig): NatManager = result = NatManager( portMappings: @[], config: config,