nim-libp2p/libp2p/transports/tcptransport.nim

214 lines
6.1 KiB
Nim
Raw Normal View History

## Nim-LibP2P
## Copyright (c) 2019 Status Research & Development GmbH
## Licensed under either of
##  * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
##  * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
import oids
import chronos, chronicles, sequtils
import transport,
../errors,
../wire,
../multiaddress,
../multicodec,
../stream/connection,
../stream/chronosstream
2019-09-12 00:15:04 +00:00
# Default chronicles scope for this module: every log statement below is
# tagged with the "tcptransport" topic.
logScope:
  topics = "tcptransport"
const TcpTransportTrackerName* = "libp2p.tcptransport"
  ## Key under which the TCP transport tracker is registered/looked up.
type
  TcpTransport* = ref object of Transport
    ## TCP implementation of the libp2p `Transport` interface.
    server*: StreamServer         ## listening server; nil until `listen`
    clients: seq[StreamTransport] ## client sockets currently tracked
    flags: set[ServerFlags]       ## flags passed to `createStreamServer`
    cleanups*: seq[Future[void]]  ## cancelled/drained on `close`
                                  ## NOTE(review): nothing in this file
                                  ## appears to append to it — confirm
    handlers*: seq[Future[void]]  ## futures of started inbound handlers

  TcpTransportTracker* = ref object of TrackerBase
    ## Counts opened vs. closed TCP transports for leak detection.
    opened*, closed*: uint64
# Forward declaration: `getTcpTransportTracker` calls this before its
# definition, and the definition in turn refers to `dumpTracking` and
# `leakTransport`, which call `getTcpTransportTracker`.
proc setupTcpTransportTracker(): TcpTransportTracker {.gcsafe.}
proc getTcpTransportTracker(): TcpTransportTracker {.gcsafe.} =
  ## Return the registered TCP transport tracker, creating and registering
  ## it on first use.
  let existing = cast[TcpTransportTracker](getTracker(TcpTransportTrackerName))
  if isNil(existing):
    setupTcpTransportTracker()
  else:
    existing
proc dumpTracking(): string {.gcsafe.} =
  ## Render the opened/closed counters of the TCP transport tracker as two
  ## newline-separated lines.
  let tracker = getTcpTransportTracker()
  let openedLine = "Opened tcp transports: " & $tracker.opened
  let closedLine = "Closed tcp transports: " & $tracker.closed
  openedLine & "\n" & closedLine
proc leakTransport(): bool {.gcsafe.} =
  ## A leak is reported whenever the opened and closed counters disagree.
  let tracker = getTcpTransportTracker()
  tracker.opened != tracker.closed
proc setupTcpTransportTracker(): TcpTransportTracker =
  ## Build a tracker with zeroed counters, attach its dump/leak callbacks
  ## and register it under `TcpTransportTrackerName`.
  let tracker = TcpTransportTracker(opened: 0, closed: 0)
  tracker.dump = dumpTracking
  tracker.isLeaked = leakTransport
  addTracker(TcpTransportTrackerName, tracker)
  tracker
proc connHandler*(t: TcpTransport,
                  client: StreamTransport,
                  initiator: bool): Connection =
  ## Wrap an established TCP `client` socket in a `Connection`.
  ##
  ## The connection's `observedAddr` is set from the client's remote
  ## address. For inbound connections (`initiator == false`) the
  ## transport's handler, if set, is started on the connection and its
  ## future appended to `t.handlers`. The client is added to `t.clients`,
  ## and a background task closes whichever of `conn`/`client` is still
  ## open once either of them finishes, removing the client from
  ## `t.clients`.
  ##
  ## Raises `CatchableError` (after closing `client`) when the remote
  ## address cannot be turned into a `MultiAddress`. In that case the
  ## client is neither registered nor handed to the handler.
  debug "Handling tcp connection", address = $client.remoteAddress,
                                   initiator = initiator,
                                   clients = t.clients.len

  let conn = Connection(
    ChronosStream.init(
      client,
      dir = if initiator:
        Direction.Out
      else:
        Direction.In))

  # Resolve the observed address before anything else sees `conn`: the
  # handler runs synchronously up to its first `await` and may read it,
  # and on failure nothing must have been registered or spawned yet.
  try:
    conn.observedAddr = MultiAddress.init(client.remoteAddress).tryGet()
  except CatchableError as exc:
    trace "Unable to get remote address", exc = exc.msg
    if not isNil(client):
      client.close()
    raise exc

  if not initiator:
    if not isNil(t.handler):
      t.handlers &= t.handler(conn)

  proc cleanup() {.async.} =
    try:
      await client.join() or conn.join()
      trace "Cleaning up client", addrs = $client.remoteAddress,
                                  conn = $conn.oid

      t.clients.keepItIf( it != client )
      # Close only what is present and still open; `not` binds to the
      # nil check alone so a nil value is never dereferenced.
      if not(isNil(conn)) and not(conn.closed()):
        await conn.close()

      if not(isNil(client)) and not(client.closed()):
        await client.closeWait()

      trace "Cleaned up client", addrs = $client.remoteAddress,
                                 conn = $conn.oid

    except CatchableError as exc:
      let useExc {.used.} = exc
      debug "Error cleaning up client", errMsg = exc.msg, s = conn

  t.clients.add(client)
  # All the errors are handled inside `cleanup()` procedure.
  asyncSpawn cleanup()

  return conn
proc connCb(server: StreamServer,
            client: StreamTransport) {.async, gcsafe.} =
  ## Accept callback passed to `createStreamServer` by `listen`. The owning
  ## `TcpTransport` travels as the server's udata; each inbound `client` is
  ## handed to its `connHandler`. If setup fails the client socket is
  ## closed; cancellation is re-raised after closing.
  trace "incoming connection", address = $client.remoteAddress
  let owner = cast[TcpTransport](server.udata)
  try:
    # connHandler registers the Connection itself, so its result is unused
    discard owner.connHandler(client, false)
  except CancelledError as cancelled:
    debug "Connection setup cancelled", exc = cancelled.msg
    await client.closeWait()
    raise cancelled
  except CatchableError as failure:
    debug "Connection setup failed", exc = failure.msg
    await client.closeWait()
proc init*(T: type TcpTransport, flags: set[ServerFlags] = {}): T =
  ## Construct a TCP transport with the given server `flags` and run
  ## `initTransport` on it before returning it.
  let transport = T(flags: flags)
  transport.initTransport()
  transport
method initTransport*(t: TcpTransport) =
  ## Tag the transport with the "tcp" multicodec and count it as opened
  ## in the TCP transport tracker.
  let tracker = getTcpTransportTracker()
  t.multicodec = multiCodec("tcp")
  inc tracker.opened
method close*(t: TcpTransport) {.async, gcsafe.} =
  ## Stop the transport.
  ##
  ## Calls the base `close`, then:
  ## 1. stops the listening server (if any) so it accepts no more clients;
  ## 2. closes every tracked client socket;
  ## 3. closes the server;
  ## 4. cancels and drains the handler and cleanup futures.
  ##
  ## The tracker's `closed` counter is incremented only when all of this
  ## succeeds. `CancelledError` is re-raised; any other `CatchableError`
  ## is logged and swallowed, so shutdown stays best-effort.
  try:
    trace "stopping transport"
    await procCall Transport(t).close() # call base

    # server can be nil (e.g. when `listen` was never called).
    # Stop accepting first: otherwise a client accepted while the existing
    # ones are being closed below would be added after the sweep and leak.
    if not isNil(t.server):
      t.server.stop()

    checkFutures(await allFinished(
      t.clients.mapIt(it.closeWait())))

    if not isNil(t.server):
      await t.server.closeWait()

    t.server = nil

    for fut in t.handlers:
      if not fut.finished:
        fut.cancel()

    checkFutures(
      await allFinished(t.handlers))
    t.handlers = @[]

    for fut in t.cleanups:
      if not fut.finished:
        fut.cancel()

    checkFutures(
      await allFinished(t.cleanups))
    t.cleanups = @[]

    trace "transport stopped"
    inc getTcpTransportTracker().closed
  except CancelledError as exc:
    raise exc
  except CatchableError as exc:
    trace "error shutting down tcp transport", exc = exc.msg
method listen*(t: TcpTransport,
               ma: MultiAddress,
               handler: ConnHandler):
               Future[Future[void]] {.async, gcsafe.} =
  ## Listen on the transport.
  ##
  ## After the base `listen`, this creates and starts a stream server on
  ## `t.ma` that dispatches inbound clients through `connCb`. `t.ma` is then
  ## replaced with the socket's actual local address, so a wildcard bind
  ## such as 0.0.0.0:0 reports the real ip/port. Returns the server's
  ## `join()` future.
  discard await procCall Transport(t).listen(ma, handler) # call base

  let server = createStreamServer(t.ma, connCb, t.flags, t)
  t.server = server
  server.start()

  # always get the resolved address in case we're bound to 0.0.0.0:0
  let boundAddress = server.sock.getLocalAddress()
  t.ma = MultiAddress.init(boundAddress).tryGet()

  trace "started node on", address = t.ma
  return server.join()
method dial*(t: TcpTransport,
             address: MultiAddress):
             Future[Connection] {.async, gcsafe.} =
  ## Dial a peer at `address` and wrap the socket in an outbound
  ## `Connection` via `connHandler`.
  ##
  ## Errors from `connect` are logged and re-raised. If `connHandler`
  ## fails, the freshly dialed socket is closed (and awaited) before the
  ## error is re-raised, mirroring what `connCb` does for inbound clients.
  trace "dialing remote peer", address = $address
  var client: StreamTransport
  try:
    client = await connect(address)
  except CatchableError as exc:
    # `connect` failed, so `client` was never assigned: nothing to release.
    trace "Exception dialing peer", exc = exc.msg
    raise exc

  try:
    return t.connHandler(client, true)
  except CatchableError as exc:
    trace "Exception setting up dialed connection", exc = exc.msg
    await client.closeWait()
    raise exc
method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} =
  ## True when the base transport accepts `address` and its protocol list
  ## includes the "tcp" codec.
  if procCall Transport(t).handles(address):
    address.protocols.tryGet().anyIt(it == multiCodec("tcp"))
  else:
    false