Refactor common.nim and add more resolve procedures. (#177)

* Refactor common.nim to remove `result` usage.
Fix comparison of TransportAddress issue.
Add resolveTAddress procedures for both IPv4 and IPv6 addresses.
Fix tests.

* Bump version to 3.0.2.
This commit is contained in:
Eugene Kabanov 2021-04-10 00:39:54 +03:00 committed by GitHub
parent 895fc53193
commit aab1e30a72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 203 additions and 165 deletions

View File

@ -1,5 +1,5 @@
packageName = "chronos"
version = "3.0.1"
version = "3.0.2"
author = "Status Research & Development GmbH"
description = "Chronos"
license = "Apache License 2.0 or MIT"

View File

@ -10,6 +10,7 @@
{.push raises: [Defect].}
import std/[os, strutils, nativesockets, net]
import stew/base10
import ../asyncloop
export net
@ -130,51 +131,54 @@ var
proc `==`*(lhs, rhs: TransportAddress): bool =
## Compare two transport addresses ``lhs`` and ``rhs``. Return ``true`` if
## addresses are equal.
if lhs.family != lhs.family:
if lhs.family != rhs.family:
return false
if lhs.family == AddressFamily.IPv4:
result = equalMem(unsafeAddr lhs.address_v4[0],
unsafeAddr rhs.address_v4[0], sizeof(lhs.address_v4)) and
(lhs.port == rhs.port)
elif lhs.family == AddressFamily.IPv6:
result = equalMem(unsafeAddr lhs.address_v6[0],
unsafeAddr rhs.address_v6[0], sizeof(lhs.address_v6)) and
(lhs.port == rhs.port)
elif lhs.family == AddressFamily.Unix:
result = equalMem(unsafeAddr lhs.address_un[0],
unsafeAddr rhs.address_un[0], sizeof(lhs.address_un))
case lhs.family
of AddressFamily.None:
true
of AddressFamily.IPv4:
equalMem(unsafeAddr lhs.address_v4[0],
unsafeAddr rhs.address_v4[0], sizeof(lhs.address_v4)) and
(lhs.port == rhs.port)
of AddressFamily.IPv6:
equalMem(unsafeAddr lhs.address_v6[0],
unsafeAddr rhs.address_v6[0], sizeof(lhs.address_v6)) and
(lhs.port == rhs.port)
of AddressFamily.Unix:
equalMem(unsafeAddr lhs.address_un[0],
unsafeAddr rhs.address_un[0], sizeof(lhs.address_un))
proc getDomain*(address: TransportAddress): Domain =
## Returns OS specific Domain from TransportAddress.
case address.family
of AddressFamily.IPv4:
result = Domain.AF_INET
Domain.AF_INET
of AddressFamily.IPv6:
result = Domain.AF_INET6
Domain.AF_INET6
of AddressFamily.Unix:
when defined(windows):
result = cast[Domain](1)
cast[Domain](1)
else:
result = Domain.AF_UNIX
Domain.AF_UNIX
else:
result = cast[Domain](0)
cast[Domain](0)
proc `$`*(address: TransportAddress): string =
## Returns string representation of ``address``.
case address.family
of AddressFamily.IPv4:
var a = IpAddress(
family: IpAddressFamily.IPv4,
address_v4: address.address_v4
)
result = $a
result.add(":")
result.add($int(address.port))
var a = IpAddress(family: IpAddressFamily.IPv4,
address_v4: address.address_v4)
var res = $a
res.add(":")
res.add(Base10.toString(uint16(address.port)))
res
of AddressFamily.IPv6:
var a = IpAddress(family: IpAddressFamily.IPv6,
address_v6: address.address_v6)
result = "[" & $a & "]:"
result.add($(int(address.port)))
var res = "[" & $a & "]:"
res.add(Base10.toString(uint16(address.port)))
res
of AddressFamily.Unix:
const length = sizeof(address.address_un) + 1
var buffer: array[length, char]
@ -182,11 +186,11 @@ proc `$`*(address: TransportAddress): string =
sizeof(address.address_un)):
copyMem(addr buffer[0], unsafeAddr address.address_un[0],
sizeof(address.address_un))
result = $cast[cstring](addr buffer)
$cast[cstring](addr buffer)
else:
result = ""
"/"
else:
result = "Unknown address family: " & $address.family
"Unknown address family: " & $address.family
proc initTAddress*(address: string): TransportAddress {.
raises: [Defect, TransportAddressError].} =
@ -198,78 +202,82 @@ proc initTAddress*(address: string): TransportAddress {.
## Unix transport address format is ``/address``.
if len(address) > 0:
if address[0] == '/':
result = TransportAddress(family: AddressFamily.Unix, port: Port(1))
let size = if len(address) < (sizeof(result.address_un) - 1): len(address)
else: (sizeof(result.address_un) - 1)
copyMem(addr result.address_un[0], unsafeAddr address[0], size)
var res = TransportAddress(family: AddressFamily.Unix, port: Port(1))
let size = if len(address) < (sizeof(res.address_un) - 1): len(address)
else: (sizeof(res.address_un) - 1)
copyMem(addr res.address_un[0], unsafeAddr address[0], size)
res
else:
var port: int
var parts = address.rsplit(":", maxsplit = 1)
if len(parts) != 2:
raise newException(TransportAddressError,
"Format is <address>:<port> or </address>!")
try:
port = parseInt(parts[1])
except:
raise newException(TransportAddressError, "Illegal port number!")
if port < 0 or port >= 65536:
raise newException(TransportAddressError, "Illegal port number!")
try:
var ipaddr: IpAddress
if parts[0][0] == '[' and parts[0][^1] == ']':
ipaddr = parseIpAddress(parts[0][1..^2])
else:
ipaddr = parseIpAddress(parts[0])
if ipaddr.family == IpAddressFamily.IPv4:
result = TransportAddress(family: AddressFamily.IPv4)
result.address_v4 = ipaddr.address_v4
elif ipaddr.family == IpAddressFamily.IPv6:
result = TransportAddress(family: AddressFamily.IPv6)
result.address_v6 = ipaddr.address_v6
else:
raise newException(TransportAddressError, "Incorrect address family!")
result.port = Port(port)
except CatchableError as exc:
raise newException(TransportAddressError, exc.msg)
let parts =
block:
let res = address.rsplit(":", maxsplit = 1)
if len(res) != 2:
raise newException(TransportAddressError,
"Format is <address>:<port>!")
res
let port =
block:
let res = Base10.decode(uint16, parts[1])
if res.isErr():
raise newException(TransportAddressError,
"Invalid port number!")
res.get()
let ipaddr =
try:
if parts[0][0] == '[' and parts[0][^1] == ']':
parseIpAddress(parts[0][1..^2])
else:
parseIpAddress(parts[0])
except CatchableError as exc:
raise newException(TransportAddressError, exc.msg)
case ipaddr.family
of IpAddressFamily.IPv4:
TransportAddress(family: AddressFamily.IPv4,
address_v4: ipaddr.address_v4, port: Port(port))
of IpAddressFamily.IPv6:
TransportAddress(family: AddressFamily.IPv6,
address_v6: ipaddr.address_v6, port: Port(port))
else:
result = TransportAddress(family: AddressFamily.Unix)
TransportAddress(family: AddressFamily.Unix)
proc initTAddress*(address: string, port: Port): TransportAddress {.
raises: [Defect, TransportAddressError].} =
## Initialize ``TransportAddress`` with IP (IPv4 or IPv6) address ``address``
## and port number ``port``.
try:
var ipaddr = parseIpAddress(address)
if ipaddr.family == IpAddressFamily.IPv4:
result = TransportAddress(family: AddressFamily.IPv4, port: port)
result.address_v4 = ipaddr.address_v4
elif ipaddr.family == IpAddressFamily.IPv6:
result = TransportAddress(family: AddressFamily.IPv6, port: port)
result.address_v6 = ipaddr.address_v6
else:
raise newException(TransportAddressError, "Incorrect address family!")
except CatchableError as exc:
raise newException(TransportAddressError, exc.msg)
let ipaddr =
try:
parseIpAddress(address)
except CatchableError as exc:
raise newException(TransportAddressError, exc.msg)
case ipaddr.family
of IpAddressFamily.IPv4:
TransportAddress(family: AddressFamily.IPv4,
address_v4: ipaddr.address_v4, port: port)
of IpAddressFamily.IPv6:
TransportAddress(family: AddressFamily.IPv6,
address_v6: ipaddr.address_v6, port: port)
proc initTAddress*(address: string, port: int): TransportAddress {.
raises: [Defect, TransportAddressError].} =
## Initialize ``TransportAddress`` with IP (IPv4 or IPv6) address ``address``
## and port number ``port``.
if port < 0 or port >= 65536:
if port < 0 or port > 65535:
raise newException(TransportAddressError, "Illegal port number!")
else:
result = initTAddress(address, Port(port))
initTAddress(address, Port(port))
proc initTAddress*(address: IpAddress, port: Port): TransportAddress =
## Initialize ``TransportAddress`` with net.nim ``IpAddress`` and
## port number ``port``.
case address.family
of IpAddressFamily.IPv4:
result = TransportAddress(family: AddressFamily.IPv4, port: port)
result.address_v4 = address.address_v4
TransportAddress(family: AddressFamily.IPv4,
address_v4: address.address_v4, port: port)
of IpAddressFamily.IPv6:
result = TransportAddress(family: AddressFamily.IPv6, port: port)
result.address_v6 = address.address_v6
TransportAddress(family: AddressFamily.IPv6,
address_v6: address.address_v6, port: port)
proc getAddrInfo(address: string, port: Port, domain: Domain,
sockType: SockType = SockType.SOCK_STREAM,
@ -278,16 +286,18 @@ proc getAddrInfo(address: string, port: Port, domain: Domain,
## We have this one copy of ``getAddrInfo()`` because of AI_V4MAPPED in
## ``net.nim:getAddrInfo()``, which is not cross-platform.
var hints: AddrInfo
result = nil
var res: ptr AddrInfo = nil
hints.ai_family = toInt(domain)
hints.ai_socktype = toInt(sockType)
hints.ai_protocol = toInt(protocol)
var gaiResult = getaddrinfo(address, $port, addr(hints), result)
if gaiResult != 0'i32:
var gaiRes = getaddrinfo(address, Base10.toString(uint16(port)),
addr(hints), res)
if gaiRes != 0'i32:
when defined(windows):
raise newException(TransportAddressError, osErrorMsg(osLastError()))
else:
raise newException(TransportAddressError, $gai_strerror(gaiResult))
raise newException(TransportAddressError, $gai_strerror(gaiRes))
res
proc fromSAddr*(sa: ptr Sockaddr_storage, sl: Socklen,
address: var TransportAddress) =
@ -358,17 +368,81 @@ proc address*(ta: TransportAddress): IpAddress {.raises: [Defect, ValueError].}
##
## Note its impossible to convert ``TransportAddress`` of ``Unix`` family,
## because ``IpAddress`` supports only IPv4, IPv6 addresses.
if ta.family == AddressFamily.IPv4:
result = IpAddress(family: IpAddressFamily.IPv4)
result.address_v4 = ta.address_v4
elif ta.family == AddressFamily.IPv6:
result = IpAddress(family: IpAddressFamily.IPv6)
result.address_v6 = ta.address_v6
case ta.family
of AddressFamily.IPv4:
IpAddress(family: IpAddressFamily.IPv4, address_v4: ta.address_v4)
of AddressFamily.IPv6:
IpAddress(family: IpAddressFamily.IPv6, address_v6: ta.address_v6)
else:
raise newException(ValueError, "IpAddress supports only IPv4/IPv6!")
proc resolveTAddress*(address: string, port: Port,
domain: Domain): seq[TransportAddress] {.
raises: [Defect, TransportAddressError].} =
var res: seq[TransportAddress]
let aiList = getAddrInfo(address, Port(port), domain)
var it = aiList
while not(isNil(it)):
var ta: TransportAddress
fromSAddr(cast[ptr Sockaddr_storage](it.ai_addr),
SockLen(it.ai_addrlen), ta)
# For some reason getAddrInfo() sometimes returns duplicate addresses,
# for example getAddrInfo(`localhost`) returns `127.0.0.1` twice.
if ta notin res:
res.add(ta)
it = it.ai_next
res
proc resolveTAddress*(address: string, domain: Domain): seq[TransportAddress] {.
raises: [Defect, TransportAddressError].} =
let parts =
block:
let res = address.rsplit(":", maxsplit = 1)
if len(res) != 2:
raise newException(TransportAddressError, "Format is <address>:<port>!")
res
let port =
block:
let res = Base10.decode(uint16, parts[1])
if res.isErr():
raise newException(TransportAddressError, "Invalid port number!")
res.get()
let hostname =
if parts[0][0] == '[' and parts[0][^1] == ']':
# IPv6 numeric addresses must be enclosed with `[]`.
parts[0][1..^2]
else:
parts[0]
resolveTAddress(hostname, Port(port), domain)
proc resolveTAddress*(address: string): seq[TransportAddress] {.
raises: [Defect, TransportAddressError].} =
## Resolve string representation of ``address``.
##
## Supported formats are:
## IPv4 numeric address ``a.b.c.d:port``
## IPv6 numeric address ``[::]:port``
## Hostname address ``hostname:port``
##
## If hostname address is detected, then network address translation via DNS
## will be performed.
resolveTAddress(address, Domain.AF_UNSPEC)
proc resolveTAddress*(address: string, port: Port): seq[TransportAddress] {.
raises: [Defect, TransportAddressError].} =
## Resolve string representation of ``address``.
##
## Supported formats are:
## IPv4 numeric address ``a.b.c.d:port``
## IPv6 numeric address ``[::]:port``
## Hostname address ``hostname:port``
##
## If hostname address is detected, then network address translation via DNS
## will be performed.
resolveTAddress(address, port, Domain.AF_UNSPEC)
proc resolveTAddress*(address: string,
family = AddressFamily.IPv4): seq[TransportAddress] {.
family: AddressFamily): seq[TransportAddress] {.
raises: [Defect, TransportAddressError].} =
## Resolve string representation of ``address``.
##
@ -379,48 +453,16 @@ proc resolveTAddress*(address: string,
##
## If hostname address is detected, then network address translation via DNS
## will be performed.
var
hostname: string
port: int
doAssert(family in {AddressFamily.IPv4, AddressFamily.IPv6})
result = newSeq[TransportAddress]()
var parts = address.rsplit(":", maxsplit = 1)
if len(parts) != 2:
raise newException(TransportAddressError, "Format is <address>:<port>!")
try:
port = parseInt(parts[1])
except:
raise newException(TransportAddressError, "Illegal port number!")
if port < 0 or port >= 65536:
raise newException(TransportAddressError, "Illegal port number!")
if parts[0][0] == '[' and parts[0][^1] == ']':
# IPv6 numeric addresses must be enclosed with `[]`.
hostname = parts[0][1..^2]
case family
of AddressFamily.IPv4:
resolveTAddress(address, Domain.AF_INET)
of AddressFamily.IPv6:
resolveTAddress(address, Domain.AF_INET6)
else:
hostname = parts[0]
var domain = if family == AddressFamily.IPv4: Domain.AF_INET else:
Domain.AF_INET6
var aiList = getAddrInfo(hostname, Port(port), domain)
var it = aiList
while it != nil:
var ta: TransportAddress
fromSAddr(cast[ptr Sockaddr_storage](it.ai_addr),
SockLen(it.ai_addrlen), ta)
# For some reason getAddrInfo() sometimes returns duplicate addresses,
# for example getAddrInfo(`localhost`) returns `127.0.0.1` twice.
if ta notin result:
result.add(ta)
it = it.ai_next
freeAddrInfo(aiList)
raiseAssert("Unable to resolve non-internet address")
proc resolveTAddress*(address: string, port: Port,
family = AddressFamily.IPv4): seq[TransportAddress] {.
family: AddressFamily): seq[TransportAddress] {.
raises: [Defect, TransportAddressError].} =
## Resolve string representation of ``address``.
##
@ -428,39 +470,31 @@ proc resolveTAddress*(address: string, port: Port,
##
## If hostname address is detected, then network address translation via DNS
## will be performed.
doAssert(family in {AddressFamily.IPv4, AddressFamily.IPv6})
result = newSeq[TransportAddress]()
var domain = if family == AddressFamily.IPv4: Domain.AF_INET else:
Domain.AF_INET6
var aiList = getAddrInfo(address, port, domain)
var it = aiList
while it != nil:
var ta: TransportAddress
fromSAddr(cast[ptr Sockaddr_storage](it.ai_addr),
SockLen(it.ai_addrlen), ta)
# For some reason getAddrInfo() sometimes returns duplicate addresses,
# for example getAddrInfo(`localhost`) returns `127.0.0.1` twice.
if ta notin result:
result.add(ta)
it = it.ai_next
freeAddrInfo(aiList)
case family
of AddressFamily.IPv4:
resolveTAddress(address, port, Domain.AF_INET)
of AddressFamily.IPv6:
resolveTAddress(address, port, Domain.AF_INET6)
else:
raiseAssert("Unable to resolve non-internet address")
proc resolveTAddress*(address: string,
family: IpAddressFamily): seq[TransportAddress] {.
deprecated, raises: [Defect, TransportAddressError].} =
if family == IpAddressFamily.IPv4:
result = resolveTAddress(address, AddressFamily.IPv4)
elif family == IpAddressFamily.IPv6:
result = resolveTAddress(address, AddressFamily.IPv6)
case family
of IpAddressFamily.IPv4:
resolveTAddress(address, AddressFamily.IPv4)
of IpAddressFamily.IPv6:
resolveTAddress(address, AddressFamily.IPv6)
proc resolveTAddress*(address: string, port: Port,
family: IpAddressFamily): seq[TransportAddress] {.
deprecated, raises: [Defect, TransportAddressError].} =
if family == IpAddressFamily.IPv4:
result = resolveTAddress(address, port, AddressFamily.IPv4)
elif family == IpAddressFamily.IPv6:
result = resolveTAddress(address, port, AddressFamily.IPv6)
case family
of IpAddressFamily.IPv4:
resolveTAddress(address, port, AddressFamily.IPv4)
of IpAddressFamily.IPv6:
resolveTAddress(address, port, AddressFamily.IPv6)
proc windowsAnyAddressFix*(a: TransportAddress): TransportAddress =
## BSD Sockets on *nix systems are able to perform connections to
@ -468,16 +502,20 @@ proc windowsAnyAddressFix*(a: TransportAddress): TransportAddress =
when defined(windows):
if (a.family == AddressFamily.IPv4 and
a.address_v4 == AnyAddress.address_v4):
result = try: initTAddress("127.0.0.1", a.port)
except TransportAddressError as exc: raiseAssert exc.msg
try:
initTAddress("127.0.0.1", a.port)
except TransportAddressError as exc:
raiseAssert exc.msg
elif (a.family == AddressFamily.IPv6 and
a.address_v6 == AnyAddress6.address_v6):
result = try: initTAddress("::1", a.port)
except TransportAddressError as exc: raiseAssert exc.msg
try:
initTAddress("::1", a.port)
except TransportAddressError as exc:
raiseAssert exc.msg
else:
result = a
a
else:
result = a
a
template checkClosed*(t: untyped) =
if (ReadClosed in (t).state) or (WriteClosed in (t).state):

View File

@ -151,7 +151,7 @@ suite "TransportAddress test suite":
var errcounter = 0
for item in numeric:
try:
discard resolveTAddress(item)
discard resolveTAddress(item, AddressFamily.IPv4)
except TransportAddressError:
inc(errcounter)
check errcounter == len(numeric)
@ -168,7 +168,7 @@ suite "TransportAddress test suite":
var errcounter = 0
for item in numeric:
try:
discard resolveTAddress(item, Port(443))
discard resolveTAddress(item, Port(443), AddressFamily.IPv4)
except TransportAddressError:
inc(errcounter)
check errcounter == len(numeric)