diff --git a/chronos.nimble b/chronos.nimble index 1dd823b..9a21e19 100644 --- a/chronos.nimble +++ b/chronos.nimble @@ -1,5 +1,5 @@ packageName = "chronos" -version = "3.0.1" +version = "3.0.2" author = "Status Research & Development GmbH" description = "Chronos" license = "Apache License 2.0 or MIT" diff --git a/chronos/transports/common.nim b/chronos/transports/common.nim index f79ca44..4d8e359 100644 --- a/chronos/transports/common.nim +++ b/chronos/transports/common.nim @@ -10,6 +10,7 @@ {.push raises: [Defect].} import std/[os, strutils, nativesockets, net] +import stew/base10 import ../asyncloop export net @@ -130,51 +131,54 @@ var proc `==`*(lhs, rhs: TransportAddress): bool = ## Compare two transport addresses ``lhs`` and ``rhs``. Return ``true`` if ## addresses are equal. - if lhs.family != lhs.family: + if lhs.family != rhs.family: return false - if lhs.family == AddressFamily.IPv4: - result = equalMem(unsafeAddr lhs.address_v4[0], - unsafeAddr rhs.address_v4[0], sizeof(lhs.address_v4)) and - (lhs.port == rhs.port) - elif lhs.family == AddressFamily.IPv6: - result = equalMem(unsafeAddr lhs.address_v6[0], - unsafeAddr rhs.address_v6[0], sizeof(lhs.address_v6)) and - (lhs.port == rhs.port) - elif lhs.family == AddressFamily.Unix: - result = equalMem(unsafeAddr lhs.address_un[0], - unsafeAddr rhs.address_un[0], sizeof(lhs.address_un)) + case lhs.family + of AddressFamily.None: + true + of AddressFamily.IPv4: + equalMem(unsafeAddr lhs.address_v4[0], + unsafeAddr rhs.address_v4[0], sizeof(lhs.address_v4)) and + (lhs.port == rhs.port) + of AddressFamily.IPv6: + equalMem(unsafeAddr lhs.address_v6[0], + unsafeAddr rhs.address_v6[0], sizeof(lhs.address_v6)) and + (lhs.port == rhs.port) + of AddressFamily.Unix: + equalMem(unsafeAddr lhs.address_un[0], + unsafeAddr rhs.address_un[0], sizeof(lhs.address_un)) proc getDomain*(address: TransportAddress): Domain = ## Returns OS specific Domain from TransportAddress. case address.family of AddressFamily.IPv4: - result = Domain.AF_INET + Domain.AF_INET of AddressFamily.IPv6: - result = Domain.AF_INET6 + Domain.AF_INET6 of AddressFamily.Unix: when defined(windows): - result = cast[Domain](1) + cast[Domain](1) else: - result = Domain.AF_UNIX + Domain.AF_UNIX else: - result = cast[Domain](0) + cast[Domain](0) proc `$`*(address: TransportAddress): string = ## Returns string representation of ``address``. case address.family of AddressFamily.IPv4: - var a = IpAddress( - family: IpAddressFamily.IPv4, - address_v4: address.address_v4 - ) - result = $a - result.add(":") - result.add($int(address.port)) + var a = IpAddress(family: IpAddressFamily.IPv4, + address_v4: address.address_v4) + var res = $a + res.add(":") + res.add(Base10.toString(uint16(address.port))) + res of AddressFamily.IPv6: var a = IpAddress(family: IpAddressFamily.IPv6, address_v6: address.address_v6) - result = "[" & $a & "]:" - result.add($(int(address.port))) + var res = "[" & $a & "]:" + res.add(Base10.toString(uint16(address.port))) + res of AddressFamily.Unix: const length = sizeof(address.address_un) + 1 var buffer: array[length, char] @@ -182,11 +186,11 @@ proc `$`*(address: TransportAddress): string = sizeof(address.address_un)): copyMem(addr buffer[0], unsafeAddr address.address_un[0], sizeof(address.address_un)) - result = $cast[cstring](addr buffer) + $cast[cstring](addr buffer) else: - result = "" + "/" else: - result = "Unknown address family: " & $address.family + "Unknown address family: " & $address.family proc initTAddress*(address: string): TransportAddress {. raises: [Defect, TransportAddressError].} = @@ -198,78 +202,82 @@ proc initTAddress*(address: string): TransportAddress {. ## Unix transport address format is ``/address``. if len(address) > 0: if address[0] == '/': - result = TransportAddress(family: AddressFamily.Unix, port: Port(1)) - let size = if len(address) < (sizeof(result.address_un) - 1): len(address) - else: (sizeof(result.address_un) - 1) - copyMem(addr result.address_un[0], unsafeAddr address[0], size) + var res = TransportAddress(family: AddressFamily.Unix, port: Port(1)) + let size = if len(address) < (sizeof(res.address_un) - 1): len(address) + else: (sizeof(res.address_un) - 1) + copyMem(addr res.address_un[0], unsafeAddr address[0], size) + res else: - var port: int - var parts = address.rsplit(":", maxsplit = 1) - if len(parts) != 2: - raise newException(TransportAddressError, - "Format is
: or
!") - try: - port = parseInt(parts[1]) - except: - raise newException(TransportAddressError, "Illegal port number!") - if port < 0 or port >= 65536: - raise newException(TransportAddressError, "Illegal port number!") - try: - var ipaddr: IpAddress - if parts[0][0] == '[' and parts[0][^1] == ']': - ipaddr = parseIpAddress(parts[0][1..^2]) - else: - ipaddr = parseIpAddress(parts[0]) - if ipaddr.family == IpAddressFamily.IPv4: - result = TransportAddress(family: AddressFamily.IPv4) - result.address_v4 = ipaddr.address_v4 - elif ipaddr.family == IpAddressFamily.IPv6: - result = TransportAddress(family: AddressFamily.IPv6) - result.address_v6 = ipaddr.address_v6 - else: - raise newException(TransportAddressError, "Incorrect address family!") - result.port = Port(port) - except CatchableError as exc: - raise newException(TransportAddressError, exc.msg) + let parts = + block: + let res = address.rsplit(":", maxsplit = 1) + if len(res) != 2: + raise newException(TransportAddressError, + "Format is
:!") + res + let port = + block: + let res = Base10.decode(uint16, parts[1]) + if res.isErr(): + raise newException(TransportAddressError, + "Invalid port number!") + res.get() + + let ipaddr = + try: + if parts[0][0] == '[' and parts[0][^1] == ']': + parseIpAddress(parts[0][1..^2]) + else: + parseIpAddress(parts[0]) + except CatchableError as exc: + raise newException(TransportAddressError, exc.msg) + + case ipaddr.family + of IpAddressFamily.IPv4: + TransportAddress(family: AddressFamily.IPv4, + address_v4: ipaddr.address_v4, port: Port(port)) + of IpAddressFamily.IPv6: + TransportAddress(family: AddressFamily.IPv6, + address_v6: ipaddr.address_v6, port: Port(port)) else: - result = TransportAddress(family: AddressFamily.Unix) + TransportAddress(family: AddressFamily.Unix) proc initTAddress*(address: string, port: Port): TransportAddress {. raises: [Defect, TransportAddressError].} = ## Initialize ``TransportAddress`` with IP (IPv4 or IPv6) address ``address`` ## and port number ``port``. - try: - var ipaddr = parseIpAddress(address) - if ipaddr.family == IpAddressFamily.IPv4: - result = TransportAddress(family: AddressFamily.IPv4, port: port) - result.address_v4 = ipaddr.address_v4 - elif ipaddr.family == IpAddressFamily.IPv6: - result = TransportAddress(family: AddressFamily.IPv6, port: port) - result.address_v6 = ipaddr.address_v6 - else: - raise newException(TransportAddressError, "Incorrect address family!") - except CatchableError as exc: - raise newException(TransportAddressError, exc.msg) + let ipaddr = + try: + parseIpAddress(address) + except CatchableError as exc: + raise newException(TransportAddressError, exc.msg) + + case ipaddr.family + of IpAddressFamily.IPv4: + TransportAddress(family: AddressFamily.IPv4, + address_v4: ipaddr.address_v4, port: port) + of IpAddressFamily.IPv6: + TransportAddress(family: AddressFamily.IPv6, + address_v6: ipaddr.address_v6, port: port) proc initTAddress*(address: string, port: int): TransportAddress {. raises: [Defect, TransportAddressError].} = ## Initialize ``TransportAddress`` with IP (IPv4 or IPv6) address ``address`` ## and port number ``port``. - if port < 0 or port >= 65536: + if port < 0 or port > 65535: raise newException(TransportAddressError, "Illegal port number!") - else: - result = initTAddress(address, Port(port)) + initTAddress(address, Port(port)) proc initTAddress*(address: IpAddress, port: Port): TransportAddress = ## Initialize ``TransportAddress`` with net.nim ``IpAddress`` and ## port number ``port``. case address.family of IpAddressFamily.IPv4: - result = TransportAddress(family: AddressFamily.IPv4, port: port) - result.address_v4 = address.address_v4 + TransportAddress(family: AddressFamily.IPv4, + address_v4: address.address_v4, port: port) of IpAddressFamily.IPv6: - result = TransportAddress(family: AddressFamily.IPv6, port: port) - result.address_v6 = address.address_v6 + TransportAddress(family: AddressFamily.IPv6, + address_v6: address.address_v6, port: port) proc getAddrInfo(address: string, port: Port, domain: Domain, sockType: SockType = SockType.SOCK_STREAM, @@ -278,16 +286,18 @@ proc getAddrInfo(address: string, port: Port, domain: Domain, ## We have this one copy of ``getAddrInfo()`` because of AI_V4MAPPED in ## ``net.nim:getAddrInfo()``, which is not cross-platform. var hints: AddrInfo - result = nil + var res: ptr AddrInfo = nil hints.ai_family = toInt(domain) hints.ai_socktype = toInt(sockType) hints.ai_protocol = toInt(protocol) - var gaiResult = getaddrinfo(address, $port, addr(hints), result) - if gaiResult != 0'i32: + var gaiRes = getaddrinfo(address, Base10.toString(uint16(port)), + addr(hints), res) + if gaiRes != 0'i32: when defined(windows): raise newException(TransportAddressError, osErrorMsg(osLastError())) else: - raise newException(TransportAddressError, $gai_strerror(gaiResult)) + raise newException(TransportAddressError, $gai_strerror(gaiRes)) + res proc fromSAddr*(sa: ptr Sockaddr_storage, sl: Socklen, address: var TransportAddress) = @@ -358,17 +368,81 @@ proc address*(ta: TransportAddress): IpAddress {.raises: [Defect, ValueError].} ## ## Note its impossible to convert ``TransportAddress`` of ``Unix`` family, ## because ``IpAddress`` supports only IPv4, IPv6 addresses. - if ta.family == AddressFamily.IPv4: - result = IpAddress(family: IpAddressFamily.IPv4) - result.address_v4 = ta.address_v4 - elif ta.family == AddressFamily.IPv6: - result = IpAddress(family: IpAddressFamily.IPv6) - result.address_v6 = ta.address_v6 + case ta.family + of AddressFamily.IPv4: + IpAddress(family: IpAddressFamily.IPv4, address_v4: ta.address_v4) + of AddressFamily.IPv6: + IpAddress(family: IpAddressFamily.IPv6, address_v6: ta.address_v6) else: raise newException(ValueError, "IpAddress supports only IPv4/IPv6!") +proc resolveTAddress*(address: string, port: Port, + domain: Domain): seq[TransportAddress] {. + raises: [Defect, TransportAddressError].} = + var res: seq[TransportAddress] + let aiList = getAddrInfo(address, Port(port), domain) + var it = aiList + while not(isNil(it)): + var ta: TransportAddress + fromSAddr(cast[ptr Sockaddr_storage](it.ai_addr), + SockLen(it.ai_addrlen), ta) + # For some reason getAddrInfo() sometimes returns duplicate addresses, + # for example getAddrInfo(`localhost`) returns `127.0.0.1` twice. + if ta notin res: + res.add(ta) + it = it.ai_next + res + +proc resolveTAddress*(address: string, domain: Domain): seq[TransportAddress] {. + raises: [Defect, TransportAddressError].} = + let parts = + block: + let res = address.rsplit(":", maxsplit = 1) + if len(res) != 2: + raise newException(TransportAddressError, "Format is
:!") + res + let port = + block: + let res = Base10.decode(uint16, parts[1]) + if res.isErr(): + raise newException(TransportAddressError, "Invalid port number!") + res.get() + let hostname = + if parts[0][0] == '[' and parts[0][^1] == ']': + # IPv6 numeric addresses must be enclosed with `[]`. + parts[0][1..^2] + else: + parts[0] + resolveTAddress(hostname, Port(port), domain) + +proc resolveTAddress*(address: string): seq[TransportAddress] {. + raises: [Defect, TransportAddressError].} = + ## Resolve string representation of ``address``. + ## + ## Supported formats are: + ## IPv4 numeric address ``a.b.c.d:port`` + ## IPv6 numeric address ``[::]:port`` + ## Hostname address ``hostname:port`` + ## + ## If hostname address is detected, then network address translation via DNS + ## will be performed. + resolveTAddress(address, Domain.AF_UNSPEC) + +proc resolveTAddress*(address: string, port: Port): seq[TransportAddress] {. + raises: [Defect, TransportAddressError].} = + ## Resolve string representation of ``address``. + ## + ## Supported formats are: + ## IPv4 numeric address ``a.b.c.d:port`` + ## IPv6 numeric address ``[::]:port`` + ## Hostname address ``hostname:port`` + ## + ## If hostname address is detected, then network address translation via DNS + ## will be performed. + resolveTAddress(address, port, Domain.AF_UNSPEC) + proc resolveTAddress*(address: string, - family = AddressFamily.IPv4): seq[TransportAddress] {. + family: AddressFamily): seq[TransportAddress] {. raises: [Defect, TransportAddressError].} = ## Resolve string representation of ``address``. ## @@ -379,48 +453,16 @@ proc resolveTAddress*(address: string, ## ## If hostname address is detected, then network address translation via DNS ## will be performed. - var - hostname: string - port: int - - doAssert(family in {AddressFamily.IPv4, AddressFamily.IPv6}) - - result = newSeq[TransportAddress]() - var parts = address.rsplit(":", maxsplit = 1) - if len(parts) != 2: - raise newException(TransportAddressError, "Format is
:!") - - try: - port = parseInt(parts[1]) - except: - raise newException(TransportAddressError, "Illegal port number!") - - if port < 0 or port >= 65536: - raise newException(TransportAddressError, "Illegal port number!") - - if parts[0][0] == '[' and parts[0][^1] == ']': - # IPv6 numeric addresses must be enclosed with `[]`. - hostname = parts[0][1..^2] + case family + of AddressFamily.IPv4: + resolveTAddress(address, Domain.AF_INET) + of AddressFamily.IPv6: + resolveTAddress(address, Domain.AF_INET6) else: - hostname = parts[0] - - var domain = if family == AddressFamily.IPv4: Domain.AF_INET else: - Domain.AF_INET6 - var aiList = getAddrInfo(hostname, Port(port), domain) - var it = aiList - while it != nil: - var ta: TransportAddress - fromSAddr(cast[ptr Sockaddr_storage](it.ai_addr), - SockLen(it.ai_addrlen), ta) - # For some reason getAddrInfo() sometimes returns duplicate addresses, - # for example getAddrInfo(`localhost`) returns `127.0.0.1` twice. - if ta notin result: - result.add(ta) - it = it.ai_next - freeAddrInfo(aiList) + raiseAssert("Unable to resolve non-internet address") proc resolveTAddress*(address: string, port: Port, - family = AddressFamily.IPv4): seq[TransportAddress] {. + family: AddressFamily): seq[TransportAddress] {. raises: [Defect, TransportAddressError].} = ## Resolve string representation of ``address``. ## @@ -428,39 +470,31 @@ proc resolveTAddress*(address: string, port: Port, ## ## If hostname address is detected, then network address translation via DNS ## will be performed. - doAssert(family in {AddressFamily.IPv4, AddressFamily.IPv6}) - - result = newSeq[TransportAddress]() - var domain = if family == AddressFamily.IPv4: Domain.AF_INET else: - Domain.AF_INET6 - var aiList = getAddrInfo(address, port, domain) - var it = aiList - while it != nil: - var ta: TransportAddress - fromSAddr(cast[ptr Sockaddr_storage](it.ai_addr), - SockLen(it.ai_addrlen), ta) - # For some reason getAddrInfo() sometimes returns duplicate addresses, - # for example getAddrInfo(`localhost`) returns `127.0.0.1` twice. - if ta notin result: - result.add(ta) - it = it.ai_next - freeAddrInfo(aiList) + case family + of AddressFamily.IPv4: + resolveTAddress(address, port, Domain.AF_INET) + of AddressFamily.IPv6: + resolveTAddress(address, port, Domain.AF_INET6) + else: + raiseAssert("Unable to resolve non-internet address") proc resolveTAddress*(address: string, family: IpAddressFamily): seq[TransportAddress] {. deprecated, raises: [Defect, TransportAddressError].} = - if family == IpAddressFamily.IPv4: - result = resolveTAddress(address, AddressFamily.IPv4) - elif family == IpAddressFamily.IPv6: - result = resolveTAddress(address, AddressFamily.IPv6) + case family + of IpAddressFamily.IPv4: + resolveTAddress(address, AddressFamily.IPv4) + of IpAddressFamily.IPv6: + resolveTAddress(address, AddressFamily.IPv6) proc resolveTAddress*(address: string, port: Port, family: IpAddressFamily): seq[TransportAddress] {. deprecated, raises: [Defect, TransportAddressError].} = - if family == IpAddressFamily.IPv4: - result = resolveTAddress(address, port, AddressFamily.IPv4) - elif family == IpAddressFamily.IPv6: - result = resolveTAddress(address, port, AddressFamily.IPv6) + case family + of IpAddressFamily.IPv4: + resolveTAddress(address, port, AddressFamily.IPv4) + of IpAddressFamily.IPv6: + resolveTAddress(address, port, AddressFamily.IPv6) proc windowsAnyAddressFix*(a: TransportAddress): TransportAddress = ## BSD Sockets on *nix systems are able to perform connections to @@ -468,16 +502,20 @@ proc windowsAnyAddressFix*(a: TransportAddress): TransportAddress = when defined(windows): if (a.family == AddressFamily.IPv4 and a.address_v4 == AnyAddress.address_v4): - result = try: initTAddress("127.0.0.1", a.port) - except TransportAddressError as exc: raiseAssert exc.msg + try: + initTAddress("127.0.0.1", a.port) + except TransportAddressError as exc: + raiseAssert exc.msg elif (a.family == AddressFamily.IPv6 and a.address_v6 == AnyAddress6.address_v6): - result = try: initTAddress("::1", a.port) - except TransportAddressError as exc: raiseAssert exc.msg + try: + initTAddress("::1", a.port) + except TransportAddressError as exc: + raiseAssert exc.msg else: - result = a + a else: - result = a + a template checkClosed*(t: untyped) = if (ReadClosed in (t).state) or (WriteClosed in (t).state): diff --git a/tests/testaddress.nim b/tests/testaddress.nim index a1a588e..040fc28 100644 --- a/tests/testaddress.nim +++ b/tests/testaddress.nim @@ -151,7 +151,7 @@ suite "TransportAddress test suite": var errcounter = 0 for item in numeric: try: - discard resolveTAddress(item) + discard resolveTAddress(item, AddressFamily.IPv4) except TransportAddressError: inc(errcounter) check errcounter == len(numeric) @@ -168,7 +168,7 @@ suite "TransportAddress test suite": var errcounter = 0 for item in numeric: try: - discard resolveTAddress(item, Port(443)) + discard resolveTAddress(item, Port(443), AddressFamily.IPv4) except TransportAddressError: inc(errcounter) check errcounter == len(numeric)