mirror of
https://github.com/status-im/nim-chronos.git
synced 2025-02-22 16:08:23 +00:00
Fix Windows/Unix behavior on datagram CONNRESET.
Add test for CONNRESET
This commit is contained in:
parent
708e581c62
commit
a0c724e9d8
@ -227,7 +227,6 @@ template processCallbacks(loop: untyped) =
|
|||||||
|
|
||||||
when defined(windows) or defined(nimdoc):
|
when defined(windows) or defined(nimdoc):
|
||||||
import winlean, sets, hashes
|
import winlean, sets, hashes
|
||||||
|
|
||||||
type
|
type
|
||||||
WSAPROC_TRANSMITFILE = proc(hSocket: SocketHandle, hFile: Handle,
|
WSAPROC_TRANSMITFILE = proc(hSocket: SocketHandle, hFile: Handle,
|
||||||
nNumberOfBytesToWrite: DWORD,
|
nNumberOfBytesToWrite: DWORD,
|
||||||
@ -341,10 +340,10 @@ when defined(windows) or defined(nimdoc):
|
|||||||
loop.processCallbacks()
|
loop.processCallbacks()
|
||||||
|
|
||||||
proc getFunc(s: SocketHandle, fun: var pointer, guid: var GUID): bool =
|
proc getFunc(s: SocketHandle, fun: var pointer, guid: var GUID): bool =
|
||||||
var bytesRet: Dword
|
var bytesRet: DWORD
|
||||||
fun = nil
|
fun = nil
|
||||||
result = WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, addr guid,
|
result = WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, addr guid,
|
||||||
sizeof(GUID).Dword, addr fun, sizeof(pointer).Dword,
|
sizeof(GUID).DWORD, addr fun, sizeof(pointer).DWORD,
|
||||||
addr bytesRet, nil, nil) == 0
|
addr bytesRet, nil, nil) == 0
|
||||||
|
|
||||||
proc initAPI() =
|
proc initAPI() =
|
||||||
@ -366,20 +365,24 @@ when defined(windows) or defined(nimdoc):
|
|||||||
|
|
||||||
var funcPointer: pointer = nil
|
var funcPointer: pointer = nil
|
||||||
if not getFunc(sock, funcPointer, WSAID_CONNECTEX):
|
if not getFunc(sock, funcPointer, WSAID_CONNECTEX):
|
||||||
|
let err = osLastError()
|
||||||
close(sock)
|
close(sock)
|
||||||
raiseOSError(osLastError())
|
raiseOSError(err)
|
||||||
loop.connectEx = cast[WSAPROC_CONNECTEX](funcPointer)
|
loop.connectEx = cast[WSAPROC_CONNECTEX](funcPointer)
|
||||||
if not getFunc(sock, funcPointer, WSAID_ACCEPTEX):
|
if not getFunc(sock, funcPointer, WSAID_ACCEPTEX):
|
||||||
|
let err = osLastError()
|
||||||
close(sock)
|
close(sock)
|
||||||
raiseOSError(osLastError())
|
raiseOSError(err)
|
||||||
loop.acceptEx = cast[WSAPROC_ACCEPTEX](funcPointer)
|
loop.acceptEx = cast[WSAPROC_ACCEPTEX](funcPointer)
|
||||||
if not getFunc(sock, funcPointer, WSAID_GETACCEPTEXSOCKADDRS):
|
if not getFunc(sock, funcPointer, WSAID_GETACCEPTEXSOCKADDRS):
|
||||||
|
let err = osLastError()
|
||||||
close(sock)
|
close(sock)
|
||||||
raiseOSError(osLastError())
|
raiseOSError(err)
|
||||||
loop.getAcceptExSockAddrs = cast[WSAPROC_GETACCEPTEXSOCKADDRS](funcPointer)
|
loop.getAcceptExSockAddrs = cast[WSAPROC_GETACCEPTEXSOCKADDRS](funcPointer)
|
||||||
if not getFunc(sock, funcPointer, WSAID_TRANSMITFILE):
|
if not getFunc(sock, funcPointer, WSAID_TRANSMITFILE):
|
||||||
|
let err = osLastError()
|
||||||
close(sock)
|
close(sock)
|
||||||
raiseOSError(osLastError())
|
raiseOSError(err)
|
||||||
loop.transmitFile = cast[WSAPROC_TRANSMITFILE](funcPointer)
|
loop.transmitFile = cast[WSAPROC_TRANSMITFILE](funcPointer)
|
||||||
close(sock)
|
close(sock)
|
||||||
|
|
||||||
|
@ -63,6 +63,9 @@ template setWriterWSABuffer(t, v: untyped) =
|
|||||||
(t).wwsabuf.len = cast[int32](v.buflen)
|
(t).wwsabuf.len = cast[int32](v.buflen)
|
||||||
|
|
||||||
when defined(windows):
|
when defined(windows):
|
||||||
|
const
|
||||||
|
IOC_VENDOR = DWORD(0x18000000)
|
||||||
|
SIO_UDP_CONNRESET = DWORD(winlean.IOC_IN) or IOC_VENDOR or DWORD(12)
|
||||||
|
|
||||||
proc writeDatagramLoop(udata: pointer) =
|
proc writeDatagramLoop(udata: pointer) =
|
||||||
var bytesCount: int32
|
var bytesCount: int32
|
||||||
@ -213,10 +216,10 @@ when defined(windows):
|
|||||||
localSock = createAsyncSocket(Domain.AF_INET6, SockType.SOCK_DGRAM,
|
localSock = createAsyncSocket(Domain.AF_INET6, SockType.SOCK_DGRAM,
|
||||||
Protocol.IPPROTO_UDP)
|
Protocol.IPPROTO_UDP)
|
||||||
if localSock == asyncInvalidSocket:
|
if localSock == asyncInvalidSocket:
|
||||||
raiseOsError(osLastError())
|
raiseTransportOsError(osLastError())
|
||||||
else:
|
else:
|
||||||
if not setSocketBlocking(SocketHandle(sock), false):
|
if not setSocketBlocking(SocketHandle(sock), false):
|
||||||
raiseOsError(osLastError())
|
raiseTransportOsError(osLastError())
|
||||||
localSock = sock
|
localSock = sock
|
||||||
register(localSock)
|
register(localSock)
|
||||||
|
|
||||||
@ -226,7 +229,15 @@ when defined(windows):
|
|||||||
let err = osLastError()
|
let err = osLastError()
|
||||||
if sock == asyncInvalidSocket:
|
if sock == asyncInvalidSocket:
|
||||||
closeAsyncSocket(localSock)
|
closeAsyncSocket(localSock)
|
||||||
raiseOsError(err)
|
raiseTransportOsError(err)
|
||||||
|
|
||||||
|
## Fix for Q263823.
|
||||||
|
var bytesRet: DWORD
|
||||||
|
var bval = WINBOOL(0)
|
||||||
|
if WSAIoctl(SocketHandle(localSock), SIO_UDP_CONNRESET, addr bval,
|
||||||
|
sizeof(WINBOOL).DWORD, nil, DWORD(0),
|
||||||
|
addr bytesRet, nil, nil) != 0:
|
||||||
|
raiseTransportOsError(osLastError())
|
||||||
|
|
||||||
if local.port != Port(0):
|
if local.port != Port(0):
|
||||||
var saddr: Sockaddr_storage
|
var saddr: Sockaddr_storage
|
||||||
@ -237,7 +248,7 @@ when defined(windows):
|
|||||||
let err = osLastError()
|
let err = osLastError()
|
||||||
if sock == asyncInvalidSocket:
|
if sock == asyncInvalidSocket:
|
||||||
closeAsyncSocket(localSock)
|
closeAsyncSocket(localSock)
|
||||||
raiseOsError(err)
|
raiseTransportOsError(err)
|
||||||
result.local = local
|
result.local = local
|
||||||
else:
|
else:
|
||||||
var saddr: Sockaddr_storage
|
var saddr: Sockaddr_storage
|
||||||
@ -253,7 +264,7 @@ when defined(windows):
|
|||||||
let err = osLastError()
|
let err = osLastError()
|
||||||
if sock == asyncInvalidSocket:
|
if sock == asyncInvalidSocket:
|
||||||
closeAsyncSocket(localSock)
|
closeAsyncSocket(localSock)
|
||||||
raiseOsError(err)
|
raiseTransportOsError(err)
|
||||||
|
|
||||||
if remote.port != Port(0):
|
if remote.port != Port(0):
|
||||||
var saddr: Sockaddr_storage
|
var saddr: Sockaddr_storage
|
||||||
@ -264,7 +275,7 @@ when defined(windows):
|
|||||||
let err = osLastError()
|
let err = osLastError()
|
||||||
if sock == asyncInvalidSocket:
|
if sock == asyncInvalidSocket:
|
||||||
closeAsyncSocket(localSock)
|
closeAsyncSocket(localSock)
|
||||||
raiseOsError(err)
|
raiseTransportOsError(err)
|
||||||
result.remote = remote
|
result.remote = remote
|
||||||
|
|
||||||
result.fd = localSock
|
result.fd = localSock
|
||||||
@ -299,6 +310,7 @@ when defined(windows):
|
|||||||
GC_unref(transp)
|
GC_unref(transp)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
# Linux/BSD/MacOS part
|
||||||
|
|
||||||
proc readDatagramLoop(udata: pointer) =
|
proc readDatagramLoop(udata: pointer) =
|
||||||
var
|
var
|
||||||
@ -403,10 +415,10 @@ else:
|
|||||||
localSock = createAsyncSocket(Domain.AF_INET6, SockType.SOCK_DGRAM,
|
localSock = createAsyncSocket(Domain.AF_INET6, SockType.SOCK_DGRAM,
|
||||||
Protocol.IPPROTO_UDP)
|
Protocol.IPPROTO_UDP)
|
||||||
if localSock == asyncInvalidSocket:
|
if localSock == asyncInvalidSocket:
|
||||||
raiseOsError(osLastError())
|
raiseTransportOsError(osLastError())
|
||||||
else:
|
else:
|
||||||
if not setSocketBlocking(SocketHandle(sock), false):
|
if not setSocketBlocking(SocketHandle(sock), false):
|
||||||
raiseOsError(osLastError())
|
raiseTransportOsError(osLastError())
|
||||||
localSock = sock
|
localSock = sock
|
||||||
register(localSock)
|
register(localSock)
|
||||||
|
|
||||||
@ -416,7 +428,7 @@ else:
|
|||||||
let err = osLastError()
|
let err = osLastError()
|
||||||
if sock == asyncInvalidSocket:
|
if sock == asyncInvalidSocket:
|
||||||
closeAsyncSocket(localSock)
|
closeAsyncSocket(localSock)
|
||||||
raiseOsError(err)
|
raiseTransportOsError(err)
|
||||||
|
|
||||||
if local.port != Port(0):
|
if local.port != Port(0):
|
||||||
var saddr: Sockaddr_storage
|
var saddr: Sockaddr_storage
|
||||||
@ -427,7 +439,7 @@ else:
|
|||||||
let err = osLastError()
|
let err = osLastError()
|
||||||
if sock == asyncInvalidSocket:
|
if sock == asyncInvalidSocket:
|
||||||
closeAsyncSocket(localSock)
|
closeAsyncSocket(localSock)
|
||||||
raiseOsError(err)
|
raiseTransportOsError(err)
|
||||||
result.local = local
|
result.local = local
|
||||||
|
|
||||||
if remote.port != Port(0):
|
if remote.port != Port(0):
|
||||||
@ -439,7 +451,7 @@ else:
|
|||||||
let err = osLastError()
|
let err = osLastError()
|
||||||
if sock == asyncInvalidSocket:
|
if sock == asyncInvalidSocket:
|
||||||
closeAsyncSocket(localSock)
|
closeAsyncSocket(localSock)
|
||||||
raiseOsError(err)
|
raiseTransportOsError(err)
|
||||||
result.remote = remote
|
result.remote = remote
|
||||||
|
|
||||||
result.fd = localSock
|
result.fd = localSock
|
||||||
|
@ -111,7 +111,7 @@ proc remoteAddress*(transp: StreamTransport): TransportAddress =
|
|||||||
var slen = SockLen(sizeof(saddr))
|
var slen = SockLen(sizeof(saddr))
|
||||||
if getpeername(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr),
|
if getpeername(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr),
|
||||||
addr slen) != 0:
|
addr slen) != 0:
|
||||||
raiseOsError(osLastError())
|
raiseTransportOsError(osLastError())
|
||||||
fromSockAddr(saddr, slen, transp.remote.address, transp.remote.port)
|
fromSockAddr(saddr, slen, transp.remote.address, transp.remote.port)
|
||||||
result = transp.remote
|
result = transp.remote
|
||||||
|
|
||||||
@ -124,7 +124,7 @@ proc localAddress*(transp: StreamTransport): TransportAddress =
|
|||||||
var slen = SockLen(sizeof(saddr))
|
var slen = SockLen(sizeof(saddr))
|
||||||
if getsockname(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr),
|
if getsockname(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr),
|
||||||
addr slen) != 0:
|
addr slen) != 0:
|
||||||
raiseOsError(osLastError())
|
raiseTransportOsError(osLastError())
|
||||||
fromSockAddr(saddr, slen, transp.local.address, transp.local.port)
|
fromSockAddr(saddr, slen, transp.local.address, transp.local.port)
|
||||||
result = transp.local
|
result = transp.local
|
||||||
|
|
||||||
|
@ -429,6 +429,23 @@ proc test3(bounded: bool): Future[int] {.async.} =
|
|||||||
for i in 0..<ClientsCount:
|
for i in 0..<ClientsCount:
|
||||||
result += counters[i]
|
result += counters[i]
|
||||||
|
|
||||||
|
proc client20(transp: DatagramTransport,
|
||||||
|
raddr: TransportAddress): Future[void] {.async.} =
|
||||||
|
var counterPtr = cast[ptr int](transp.udata)
|
||||||
|
counterPtr[] = 1
|
||||||
|
transp.close()
|
||||||
|
|
||||||
|
proc testConnReset(): Future[bool] {.async.} =
|
||||||
|
var ta = initTAddress("127.0.0.1:65000")
|
||||||
|
var counter = 0
|
||||||
|
var dgram1 = newDatagramTransport(client1, local = ta)
|
||||||
|
dgram1.close()
|
||||||
|
var dgram2 = newDatagramTransport(client20, udata = addr counter)
|
||||||
|
var data = "MESSAGE"
|
||||||
|
discard dgram2.sendTo(data, ta)
|
||||||
|
await sleepAsync(1000)
|
||||||
|
result = (counter == 0)
|
||||||
|
|
||||||
when isMainModule:
|
when isMainModule:
|
||||||
const
|
const
|
||||||
m1 = "sendTo(pointer) test (" & $TestsCount & " messages)"
|
m1 = "sendTo(pointer) test (" & $TestsCount & " messages)"
|
||||||
@ -458,3 +475,5 @@ when isMainModule:
|
|||||||
check waitFor(test3(false)) == ClientsCount * MessagesCount
|
check waitFor(test3(false)) == ClientsCount * MessagesCount
|
||||||
test m8:
|
test m8:
|
||||||
check waitFor(test3(true)) == ClientsCount * MessagesCount
|
check waitFor(test3(true)) == ClientsCount * MessagesCount
|
||||||
|
test "Datagram connection reset test":
|
||||||
|
check waitFor(testConnReset()) == true
|
||||||
|
Loading…
x
Reference in New Issue
Block a user