Fix Windows/Unix behavior on datagram CONNRESET.

Add test for CONNRESET
This commit is contained in:
cheatfate 2018-06-15 03:28:02 +03:00
parent 708e581c62
commit a0c724e9d8
4 changed files with 55 additions and 21 deletions

View File

@ -227,7 +227,6 @@ template processCallbacks(loop: untyped) =
when defined(windows) or defined(nimdoc):
import winlean, sets, hashes
type
WSAPROC_TRANSMITFILE = proc(hSocket: SocketHandle, hFile: Handle,
nNumberOfBytesToWrite: DWORD,
@ -341,10 +340,10 @@ when defined(windows) or defined(nimdoc):
loop.processCallbacks()
proc getFunc(s: SocketHandle, fun: var pointer, guid: var GUID): bool =
var bytesRet: Dword
var bytesRet: DWORD
fun = nil
result = WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, addr guid,
sizeof(GUID).Dword, addr fun, sizeof(pointer).Dword,
sizeof(GUID).DWORD, addr fun, sizeof(pointer).DWORD,
addr bytesRet, nil, nil) == 0
proc initAPI() =
@ -366,20 +365,24 @@ when defined(windows) or defined(nimdoc):
var funcPointer: pointer = nil
if not getFunc(sock, funcPointer, WSAID_CONNECTEX):
let err = osLastError()
close(sock)
raiseOSError(osLastError())
raiseOSError(err)
loop.connectEx = cast[WSAPROC_CONNECTEX](funcPointer)
if not getFunc(sock, funcPointer, WSAID_ACCEPTEX):
let err = osLastError()
close(sock)
raiseOSError(osLastError())
raiseOSError(err)
loop.acceptEx = cast[WSAPROC_ACCEPTEX](funcPointer)
if not getFunc(sock, funcPointer, WSAID_GETACCEPTEXSOCKADDRS):
let err = osLastError()
close(sock)
raiseOSError(osLastError())
raiseOSError(err)
loop.getAcceptExSockAddrs = cast[WSAPROC_GETACCEPTEXSOCKADDRS](funcPointer)
if not getFunc(sock, funcPointer, WSAID_TRANSMITFILE):
let err = osLastError()
close(sock)
raiseOSError(osLastError())
raiseOSError(err)
loop.transmitFile = cast[WSAPROC_TRANSMITFILE](funcPointer)
close(sock)

View File

@ -63,6 +63,9 @@ template setWriterWSABuffer(t, v: untyped) =
(t).wwsabuf.len = cast[int32](v.buflen)
when defined(windows):
const
IOC_VENDOR = DWORD(0x18000000)
SIO_UDP_CONNRESET = DWORD(winlean.IOC_IN) or IOC_VENDOR or DWORD(12)
proc writeDatagramLoop(udata: pointer) =
var bytesCount: int32
@ -213,10 +216,10 @@ when defined(windows):
localSock = createAsyncSocket(Domain.AF_INET6, SockType.SOCK_DGRAM,
Protocol.IPPROTO_UDP)
if localSock == asyncInvalidSocket:
raiseOsError(osLastError())
raiseTransportOsError(osLastError())
else:
if not setSocketBlocking(SocketHandle(sock), false):
raiseOsError(osLastError())
raiseTransportOsError(osLastError())
localSock = sock
register(localSock)
@ -226,7 +229,15 @@ when defined(windows):
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(localSock)
raiseOsError(err)
raiseTransportOsError(err)
## Fix for Q263823.
var bytesRet: DWORD
var bval = WINBOOL(0)
if WSAIoctl(SocketHandle(localSock), SIO_UDP_CONNRESET, addr bval,
sizeof(WINBOOL).DWORD, nil, DWORD(0),
addr bytesRet, nil, nil) != 0:
raiseTransportOsError(osLastError())
if local.port != Port(0):
var saddr: Sockaddr_storage
@ -237,7 +248,7 @@ when defined(windows):
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(localSock)
raiseOsError(err)
raiseTransportOsError(err)
result.local = local
else:
var saddr: Sockaddr_storage
@ -253,7 +264,7 @@ when defined(windows):
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(localSock)
raiseOsError(err)
raiseTransportOsError(err)
if remote.port != Port(0):
var saddr: Sockaddr_storage
@ -264,7 +275,7 @@ when defined(windows):
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(localSock)
raiseOsError(err)
raiseTransportOsError(err)
result.remote = remote
result.fd = localSock
@ -299,6 +310,7 @@ when defined(windows):
GC_unref(transp)
else:
# Linux/BSD/MacOS part
proc readDatagramLoop(udata: pointer) =
var
@ -403,10 +415,10 @@ else:
localSock = createAsyncSocket(Domain.AF_INET6, SockType.SOCK_DGRAM,
Protocol.IPPROTO_UDP)
if localSock == asyncInvalidSocket:
raiseOsError(osLastError())
raiseTransportOsError(osLastError())
else:
if not setSocketBlocking(SocketHandle(sock), false):
raiseOsError(osLastError())
raiseTransportOsError(osLastError())
localSock = sock
register(localSock)
@ -416,7 +428,7 @@ else:
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(localSock)
raiseOsError(err)
raiseTransportOsError(err)
if local.port != Port(0):
var saddr: Sockaddr_storage
@ -427,7 +439,7 @@ else:
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(localSock)
raiseOsError(err)
raiseTransportOsError(err)
result.local = local
if remote.port != Port(0):
@ -439,7 +451,7 @@ else:
let err = osLastError()
if sock == asyncInvalidSocket:
closeAsyncSocket(localSock)
raiseOsError(err)
raiseTransportOsError(err)
result.remote = remote
result.fd = localSock

View File

@ -111,7 +111,7 @@ proc remoteAddress*(transp: StreamTransport): TransportAddress =
var slen = SockLen(sizeof(saddr))
if getpeername(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr),
addr slen) != 0:
raiseOsError(osLastError())
raiseTransportOsError(osLastError())
fromSockAddr(saddr, slen, transp.remote.address, transp.remote.port)
result = transp.remote
@ -124,7 +124,7 @@ proc localAddress*(transp: StreamTransport): TransportAddress =
var slen = SockLen(sizeof(saddr))
if getsockname(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr),
addr slen) != 0:
raiseOsError(osLastError())
raiseTransportOsError(osLastError())
fromSockAddr(saddr, slen, transp.local.address, transp.local.port)
result = transp.local

View File

@ -429,6 +429,23 @@ proc test3(bounded: bool): Future[int] {.async.} =
for i in 0..<ClientsCount:
result += counters[i]
proc client20(transp: DatagramTransport,
raddr: TransportAddress): Future[void] {.async.} =
var counterPtr = cast[ptr int](transp.udata)
counterPtr[] = 1
transp.close()
proc testConnReset(): Future[bool] {.async.} =
var ta = initTAddress("127.0.0.1:65000")
var counter = 0
var dgram1 = newDatagramTransport(client1, local = ta)
dgram1.close()
var dgram2 = newDatagramTransport(client20, udata = addr counter)
var data = "MESSAGE"
discard dgram2.sendTo(data, ta)
await sleepAsync(1000)
result = (counter == 0)
when isMainModule:
const
m1 = "sendTo(pointer) test (" & $TestsCount & " messages)"
@ -458,3 +475,5 @@ when isMainModule:
check waitFor(test3(false)) == ClientsCount * MessagesCount
test m8:
check waitFor(test3(true)) == ClientsCount * MessagesCount
test "Datagram connection reset test":
check waitFor(testConnReset()) == true