From 0ee9a148c70fec40bb4195c92a2b6c354ce4abc6 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Sun, 10 Jun 2018 03:55:19 +0300 Subject: [PATCH] Fix for TransportAddress resolveTAddress behavior. Added more tests for TransportAddress. --- asyncdispatch2/transports/common.nim | 158 +++++++++++++-------- tests/testaddress.nim | 201 +++++++++++++++++++-------- 2 files changed, 246 insertions(+), 113 deletions(-) diff --git a/asyncdispatch2/transports/common.nim b/asyncdispatch2/transports/common.nim index 1c81549..df2a3ea 100644 --- a/asyncdispatch2/transports/common.nim +++ b/asyncdispatch2/transports/common.nim @@ -7,11 +7,16 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) -import net, nativesockets, strutils +import os, net, strutils +from nativesockets import toInt import ../asyncloop - export net +when defined(windows): + import winlean +else: + import posix + const DefaultStreamBufferSize* = 4096 ## Default buffer size for stream ## transports @@ -85,6 +90,7 @@ type ## Transport's `incomplete data received` exception TransportLimitError* = object of TransportError ## Transport's `data limit reached` exception + TransportAddressError* = object of TransportError TransportState* = enum ## Transport's state @@ -135,26 +141,60 @@ proc initTAddress*(address: string): TransportAddress = ## IPv4 transport address format is ``a.b.c.d:port``. ## IPv6 transport address format is ``[::]:port``. var parts = address.rsplit(":", maxsplit = 1) - doAssert(len(parts) == 2, "Format is
:!") - let port = parseInt(parts[1]) - doAssert(port >= 0 and port < 65536, "Illegal port number!") - result.port = Port(port) - if parts[0][0] == '[' and parts[0][^1] == ']': - result.address = parseIpAddress(parts[0][1..^2]) - else: - result.address = parseIpAddress(parts[0]) + if len(parts) != 2: + raise newException(TransportAddressError, "Format is
:!") + + try: + let port = parseInt(parts[1]) + doAssert(port > 0 and port < 65536) + result.port = Port(port) + except: + raise newException(TransportAddressError, "Illegal port number!") + + try: + if parts[0][0] == '[' and parts[0][^1] == ']': + result.address = parseIpAddress(parts[0][1..^2]) + else: + result.address = parseIpAddress(parts[0]) + except: + raise newException(TransportAddressError, getCurrentException().msg) proc initTAddress*(address: string, port: Port): TransportAddress = ## Initialize ``TransportAddress`` with IP address ``address`` and ## port number ``port``. - result.address = parseIpAddress(address) - result.port = port + try: + result.address = parseIpAddress(address) + result.port = port + except: + raise newException(TransportAddressError, getCurrentException().msg) proc initTAddress*(address: string, port: int): TransportAddress = ## Initialize ``TransportAddress`` with IP address ``address`` and ## port number ``port``. - result.address = parseIpAddress(address) - result.port = Port(port and 0xFFFF) + if port < 0 or port >= 65536: + raise newException(TransportAddressError, "Illegal port number!") + try: + result.address = parseIpAddress(address) + result.port = Port(port) + except: + raise newException(TransportAddressError, getCurrentException().msg) + +proc getAddrInfo(address: string, port: Port, domain: Domain, + sockType: SockType = SockType.SOCK_STREAM, + protocol: Protocol = Protocol.IPPROTO_TCP): ptr AddrInfo = + ## We have this one copy of ``getAddrInfo()`` because of AI_V4MAPPED in + ## ``net.nim:getAddrInfo()``, which is not cross-platform. + var hints: AddrInfo + result = nil + hints.ai_family = toInt(domain) + hints.ai_socktype = toInt(sockType) + hints.ai_protocol = toInt(protocol) + var gaiResult = getaddrinfo(address, $port, addr(hints), result) + if gaiResult != 0'i32: + when defined(windows): + raise newException(TransportAddressError, osErrorMsg(osLastError())) + else: + raise newException(TransportAddressError, $gai_strerror(gaiResult)) proc resolveTAddress*(address: string, family = IpAddressFamily.IPv4): seq[TransportAddress] = @@ -167,63 +207,65 @@ proc resolveTAddress*(address: string, ## ## If hostname address is detected, then network address translation via DNS ## will be performed. + var + hostname: string + port: int + result = newSeq[TransportAddress]() var parts = address.rsplit(":", maxsplit = 1) - doAssert(len(parts) == 2, "Format is
:!") - let port = parseInt(parts[1]) - doAssert(port >= 0 and port < 65536, "Illegal port number!") + if len(parts) != 2: + raise newException(TransportAddressError, "Format is
:!") + + try: + port = parseInt(parts[1]) + doAssert(port > 0 and port < 65536) + except: + raise newException(TransportAddressError, "Illegal port number!") + if parts[0][0] == '[' and parts[0][^1] == ']': - let ta = TransportAddress(address: parseIpAddress(parts[0][1..^2]), - port: Port(port)) - result.add(ta) + # IPv6 numeric addresses must be enclosed with `[]`. + hostname = parts[0][1..^2] else: - if isIpAddress(parts[0]): - let ta = TransportAddress(address: parseIpAddress(parts[0]), - port: Port(port)) + hostname = parts[0] + + var domain = if family == IpAddressFamily.IPv4: Domain.AF_INET else: + Domain.AF_INET6 + var aiList = getAddrInfo(hostname, Port(port), domain) + var it = aiList + while it != nil: + var ta: TransportAddress + fromSockAddr(cast[ptr Sockaddr_storage](it.ai_addr)[], + SockLen(it.ai_addrlen), ta.address, ta.port) + # For some reason getAddrInfo() sometimes returns duplicate addresses, + # for example getAddrInfo(`localhost`) returns `127.0.0.1` twice. + if ta notin result: result.add(ta) - else: - var domain = if family == IpAddressFamily.IPv4: Domain(AF_INET) else: - Domain(AF_INET6) - var aiList = getAddrInfo(parts[0], Port(port), domain) - var it = aiList - while it != nil: - var ta: TransportAddress - fromSockAddr(cast[ptr Sockaddr_storage](it.ai_addr)[], - SockLen(it.ai_addrlen), ta.address, ta.port) - # For some reason getAddrInfo() sometimes returns duplicate addresses, - # for example getAddrInfo(`localhost`) returns `127.0.0.1` twice. - if ta notin result: - result.add(ta) - it = it.ai_next - freeAddrInfo(aiList) + it = it.ai_next + freeAddrInfo(aiList) proc resolveTAddress*(address: string, port: Port, family = IpAddressFamily.IPv4): seq[TransportAddress] = ## Resolve string representation of ``address``. - ## + ## ## ``address`` could be dot IPv4/IPv6 address or hostname. - ## + ## ## If hostname address is detected, then network address translation via DNS ## will be performed. result = newSeq[TransportAddress]() - if isIpAddress(address): - let ta = TransportAddress(address: parseIpAddress(address), port: port) - result.add(ta) - else: - var domain = if family == IpAddressFamily.IPv4: Domain(AF_INET) else: - Domain(AF_INET6) - var aiList = getAddrInfo(address, port, domain) - var it = aiList - while it != nil: - var ta: TransportAddress - fromSockAddr(cast[ptr Sockaddr_storage](it.ai_addr)[], - SockLen(it.ai_addrlen), ta.address, ta.port) - # For some reason getAddrInfo() sometimes returns duplicate addresses, - # for example getAddrInfo(`localhost`) returns `127.0.0.1` twice. - if ta notin result: - result.add(ta) - it = it.ai_next - freeAddrInfo(aiList) + var domain = if family == IpAddressFamily.IPv4: Domain.AF_INET else: + Domain.AF_INET6 + var aiList = getAddrInfo(address, port, domain) + var it = aiList + while it != nil: + var ta: TransportAddress + fromSockAddr(cast[ptr Sockaddr_storage](it.ai_addr)[], + SockLen(it.ai_addrlen), ta.address, ta.port) + # For some reason getAddrInfo() sometimes returns duplicate addresses, + # for example getAddrInfo(`localhost`) returns `127.0.0.1` twice. + if ta notin result: + result.add(ta) + it = it.ai_next + freeAddrInfo(aiList) template checkClosed*(t: untyped) = if (ReadClosed in (t).state) or (WriteClosed in (t).state): diff --git a/tests/testaddress.nim b/tests/testaddress.nim index 88650b7..b576d1f 100644 --- a/tests/testaddress.nim +++ b/tests/testaddress.nim @@ -12,11 +12,12 @@ import ../asyncdispatch2 when isMainModule: suite "TransportAddress test suite": test "initTAddress(string)": - check $initTAddress("0.0.0.0:0") == "0.0.0.0:0" + check $initTAddress("0.0.0.0:1") == "0.0.0.0:1" check $initTAddress("255.255.255.255:65535") == "255.255.255.255:65535" - check $initTAddress("[::]:0") == "[::]:0" + check $initTAddress("[::]:1") == "[::]:1" check $initTAddress("[FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF]:65535") == "[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535" + test "initTAddress(string, Port)": check $initTAddress("0.0.0.0", Port(0)) == "0.0.0.0:0" check $initTAddress("255.255.255.255", Port(65535)) == @@ -25,31 +26,21 @@ when isMainModule: check $initTAddress("FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF", Port(65535)) == "[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535" + test "initTAddress(string, int)": - check $initTAddress("0.0.0.0", 0) == "0.0.0.0:0" + check $initTAddress("0.0.0.0", 1) == "0.0.0.0:1" check $initTAddress("255.255.255.255", 65535) == "255.255.255.255:65535" check $initTAddress("::", 0) == "[::]:0" check $initTAddress("FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF", 65535) == "[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535" - test "resolveTAddress(string)": - var numeric = [ - "0.0.0.0:0", - "255.0.0.255:54321", - "128.128.128.128:12345", - "255.255.255.255:65535", - "[::]:0", - "[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535", - "[aaaa:bbbb:cccc:dddd:eeee:ffff::1111]:12345", - "[aaaa:bbbb:cccc:dddd:eeee:ffff::]:12345", - "[a:b:c:d:e:f::]:12345", - "[2222:3333:4444:5555:6666:7777:8888:9999]:56789" - ] - var hostnames = [ - "www.google.com:443", - "www.github.com:443", - "localhost:443" - ] + + test "resolveTAddress(string, IPv4)": + var numeric = ["0.0.0.0:1", "255.0.0.255:54321", "128.128.128.128:12345", + "255.255.255.255:65535"] + var hostnames = ["www.google.com:443", "www.github.com:443", + "localhost:443"] + for item in numeric: var taseq = resolveTAddress(item) check len(taseq) == 1 @@ -58,15 +49,43 @@ when isMainModule: for item in hostnames: var taseq = resolveTAddress(item) check len(taseq) >= 1 - test "resolveTAddress(string, Port)": - var numeric4 = [ - "0.0.0.0", - "255.0.0.255", - "128.128.128.128", - "255.255.255.255" - ] - var numeric6 = [ + test "resolveTAddress(string, IPv6)": + var numeric = [ + "[::]:1", + "[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535", + "[aaaa:bbbb:cccc:dddd:eeee:ffff::1111]:12345", + "[aaaa:bbbb:cccc:dddd:eeee:ffff::]:12345", + "[a:b:c:d:e:f::]:12345", + "[2222:3333:4444:5555:6666:7777:8888:9999]:56789" + ] + var hostnames = ["localhost:443"] + + for item in numeric: + var taseq = resolveTAddress(item, IpAddressFamily.IPv6) + check len(taseq) == 1 + check $taseq[0] == item + + for item in hostnames: + var taseq = resolveTAddress(item, IpAddressFamily.IPv6) + check len(taseq) >= 1 + + test "resolveTAddress(string, Port, IPv4)": + var numeric = ["0.0.0.0", "255.0.0.255", "128.128.128.128", + "255.255.255.255"] + var hostnames = ["www.google.com", "www.github.com", "localhost"] + + for item in numeric: + var taseq = resolveTAddress(item, Port(443)) + check len(taseq) == 1 + check $taseq[0] == item & ":443" + + for item in hostnames: + var taseq = resolveTAddress(item, Port(443)) + check len(taseq) >= 1 + + test "resolveTAddress(string, Port, IPv6)": + var numeric = [ "::", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", "aaaa:bbbb:cccc:dddd:eeee:ffff::1111", @@ -74,33 +93,105 @@ when isMainModule: "a:b:c:d:e:f::", "2222:3333:4444:5555:6666:7777:8888:9999" ] - var hostnames = [ - "www.google.com", - "www.github.com", - "localhost" - ] - for item in numeric4: - var taseq = resolveTAddress(item, Port(443)) - check len(taseq) == 1 - check $taseq[0] == item & ":443" - - for item in numeric6: - var taseq = resolveTAddress(item, Port(443)) - check len(taseq) == 1 - check $taseq[0] == "[" & item & "]" & ":443" - - for item in hostnames: - var taseq = resolveTAddress(item, Port(443)) - check len(taseq) >= 1 - - test "resolveTAddress(string) (IPv6 only)": - var hostnames = ["localhost:443"] - for item in hostnames: - var taseq = resolveTAddress(item, IpAddressFamily.IPv6) - check len(taseq) >= 1 - - test "resolveTAddress(string, Port) (IPv6 only)": var hostnames = ["localhost"] + for item in numeric: + var taseq = resolveTAddress(item, Port(443), IpAddressFamily.IPv6) + check len(taseq) == 1 + check $taseq[0] == "[" & item & "]:443" + for item in hostnames: var taseq = resolveTAddress(item, Port(443), IpAddressFamily.IPv6) check len(taseq) >= 1 + + test "Faulty initTAddress(string)": + var tests = [ + "z:1", + "256.256.256.256:65534", + "127.0.0.1:65536" + ] + var errcounter = 0 + for item in tests: + try: + var ta = initTAddress(item) + except TransportAddressError: + inc(errcounter) + check errcounter == len(tests) + + test "Faulty initTAddress(string, Port)": + var tests = [ + ":::", + "999.999.999.999", + "gggg:aaaa:bbbb:gggg:aaaa:bbbb:gggg:aaaa", + "hostname" + ] + var errcounter = 0 + for item in tests: + try: + var ta = initTAddress(item, Port(443)) + except TransportAddressError: + inc(errcounter) + check errcounter == len(tests) + + test "Faulty initTAddress(string, Port)": + var errcounter = 0 + try: + var ta = initTAddress("127.0.0.1", 100000) + except TransportAddressError: + inc(errcounter) + check errcounter == 1 + + test "Faulty resolveTAddress(string, IPv4) for IPv6 address": + var numeric = [ + "[::]:1", + "[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535", + "[aaaa:bbbb:cccc:dddd:eeee:ffff::1111]:12345", + "[aaaa:bbbb:cccc:dddd:eeee:ffff::]:12345", + "[a:b:c:d:e:f::]:12345", + "[2222:3333:4444:5555:6666:7777:8888:9999]:56789" + ] + var errcounter = 0 + for item in numeric: + try: + var taseq = resolveTAddress(item) + except TransportAddressError: + inc(errcounter) + check errcounter == len(numeric) + + test "Faulty resolveTAddress(string, Port, IPv4) for IPv6 address": + var numeric = [ + "::", + "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", + "aaaa:bbbb:cccc:dddd:eeee:ffff::1111", + "aaaa:bbbb:cccc:dddd:eeee:ffff::", + "a:b:c:d:e:f::", + "2222:3333:4444:5555:6666:7777:8888:9999" + ] + var errcounter = 0 + for item in numeric: + try: + var taseq = resolveTAddress(item, Port(443)) + except TransportAddressError: + inc(errcounter) + check errcounter == len(numeric) + + test "Faulty resolveTAddress(string, IPv6) for IPv4 address": + var numeric = ["0.0.0.0:0", "255.0.0.255:54321", "128.128.128.128:12345", + "255.255.255.255:65535"] + var errcounter = 0 + for item in numeric: + try: + var taseq = resolveTAddress(item, IpAddressFamily.IPv6) + except TransportAddressError: + inc(errcounter) + check errcounter == len(numeric) + + test "Faulty resolveTAddress(string, Port, IPv6) for IPv4 address": + var numeric = ["0.0.0.0", "255.0.0.255", "128.128.128.128", + "255.255.255.255"] + var errcounter = 0 + for item in numeric: + try: + var taseq = resolveTAddress(item, Port(443), IpAddressFamily.IPv6) + except TransportAddressError: + inc(errcounter) + check errcounter == len(numeric)