diff --git a/libp2p/multistream.nim b/libp2p/multistream.nim index a1f3fbef1..991e56bbf 100644 --- a/libp2p/multistream.nim +++ b/libp2p/multistream.nim @@ -7,8 +7,8 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import strutils -import chronos, chronicles +import strutils, sequtils +import chronos, chronicles, stew/byteutils import vbuffer, protocols/protocol, streams/[connection, @@ -37,9 +37,10 @@ type MultistreamSelect* = object handlers*: seq[HandlerHolder] - codec*: string - na: string - ls: string + codec*: seq[byte] + na*: seq[byte] + ls*: seq[byte] + lp: LenPrefixed MultistreamHandshakeException* = object of CatchableError @@ -47,27 +48,30 @@ proc newMultistreamHandshakeException*(): ref Exception {.inline.} = result = newException(MultistreamHandshakeException, "could not perform multistream handshake") -proc append(item: Future[seq[byte]]): Future[seq[byte]] {.async.} = - result = await item - result.add(byte('\n')) - var appendNl: Through[seq[byte]] = proc (i: Source[seq[byte]]): Source[seq[byte]] {.gcsafe.} = + proc append(item: Future[seq[byte]]): Future[seq[byte]] {.async.} = + result = await item + result.add(byte('\n')) + return iterator(): Future[seq[byte]] {.closure.} = for item in i: yield append(item) -proc strip(item: Future[seq[byte]]): Future[seq[byte]] {.async.} = - result = await item - if result[^1] == byte('\n'): - result.setLen(result.high) - var stripNl: Through[seq[byte]] = proc (i: Source[seq[byte]]): Source[seq[byte]] {.gcsafe.} = + proc strip(item: Future[seq[byte]]): Future[seq[byte]] {.async.} = + result = await item + if result.len > 0 and result[^1] == byte('\n'): + result.setLen(result.high) + return iterator(): Future[seq[byte]] {.closure.} = for item in i: yield strip(item) proc init*(M: type[MultistreamSelect]): MultistreamSelect = - M(codec: Codec, ls: Ls, na: Na) + M(codec: toSeq(Codec).mapIt( it.byte ), + ls: Ls.toBytes(), + na: Na.toBytes(), + lp: LenPrefixed.init()) proc select*(m: MultistreamSelect, conn: Connection, @@ -75,29 +79,26 @@ proc select*(m: MultistreamSelect, Future[string] {.async.} = trace "initiating handshake", codec = m.codec var pushable = Pushable[seq[byte]].init() # pushable source - var lp = LenPrefixed.init() - var sink = pipe(pushable, - appendNl, - lp.encoder, - conn) - - let source = pipe(conn, - lp.decoder, - stripNl) + var source = pipe(pushable, + appendNl, + m.lp.encoder, + conn.toThrough, + m.lp.decoder, + stripNl) # handshake first - await pushable.push(cast[seq[byte]](m.codec)) + await pushable.push(m.codec) # (common optimization) if we've got # protos send the first one out immediately # without waiting for the handshake response if protos.len > 0: - await pushable.push(cast[seq[byte]](protos[0])) + await pushable.push(protos[0].toBytes()) # check for handshake result - result = cast[string](await source()) - if result != m.codec: + var res = await source() + if res != m.codec: error "handshake failed", codec = result.toHex() raise newMultistreamHandshakeException() @@ -106,18 +107,17 @@ proc select*(m: MultistreamSelect, while i < protos.len: # first read because we've the outstanding requirest above trace "reading requested proto" - for chunk in source: - result = cast[string](await chunk) + res = await source() - if result == protos[i]: - trace "succesfully selected ", proto = proto - break + var protoBytes = protos[i].toBytes() + if res == protoBytes: + trace "succesfully selected ", proto = protos[i] + return protos[i] if i > 0: - trace "selecting proto", proto = proto - await pushable.push(cast[seq[byte]](protos[i])) # select proto + trace "selecting proto", proto = protos[i] + await pushable.push(protoBytes) # select proto i.inc() - await sink proc select*(m: MultistreamSelect, conn: Connection, @@ -130,92 +130,98 @@ proc select*(m: MultistreamSelect, proc select*(m: MultistreamSelect, conn: Connection): Future[bool] = m.select(conn, "") -# proc list*(m: MultistreamSelect, -# conn: Connection): Future[seq[string]] {.async.} = -# ## list remote protos requests on connection -# if not await m.select(conn): -# return +proc list*(m: MultistreamSelect, + conn: Connection): Future[seq[string]] {.async.} = + ## list remote protos requests on connection + if not await m.select(conn): + return -# await conn.write(m.ls) # send ls + var pushable = Pushable[seq[byte]].init() + var source = pipe(pushable, + appendNl, + m.lp.encoder, + conn.toThrough, + m.lp.decoder, + stripNl) -# var list = newSeq[string]() -# let ms = cast[string]((await conn.readLp())) -# for s in ms.split("\n"): -# if s.len() > 0: -# list.add(s) + await pushable.push(m.ls) # send ls -# result = list + var list = newSeq[string]() + for chunk in source: + var msg = string.fromBytes((await chunk)) + for s in msg.split("\n"): + if s.len() > 0: + list.add(s) -# proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} = -# trace "handle: starting multistream handling" -# try: -# while not conn.closed: -# var ms = cast[string]((await conn.readLp())) -# ms.removeSuffix("\n") + result = list -# trace "handle: got request for ", ms -# if ms.len() <= 0: -# trace "handle: invalid proto" -# await conn.write(m.na) +proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} = + trace "handle: starting multistream handling" + try: + var pushable = Pushable[seq[byte]].init() + var source = pipe(pushable, + appendNl, + m.lp.encoder, + conn.toThrough, + m.lp.decoder, + stripNl) -# if m.handlers.len() == 0: -# trace "handle: sending `na` for protocol ", protocol = ms -# await conn.write(m.na) -# continue + for chunk in source: + var msg = string.fromBytes((await chunk)) + trace "got request for ", msg + if msg.len <= 0: + trace "invalid proto" + await pushable.push(m.na) -# case ms: -# of "ls": -# trace "handle: listing protos" -# var protos = "" -# for h in m.handlers: -# protos &= (h.proto & "\n") -# await conn.writeLp(protos) -# of Codec: -# await conn.write(m.codec) -# else: -# for h in m.handlers: -# if (not isNil(h.match) and h.match(ms)) or ms == h.proto: -# trace "found handler for", protocol = ms -# await conn.writeLp((h.proto & "\n")) -# try: -# await h.protocol.handler(conn, ms) -# return -# except CatchableError as exc: -# warn "exception while handling", msg = exc.msg -# return -# warn "no handlers for ", protocol = ms -# await conn.write(m.na) -# except CatchableError as exc: -# trace "Exception occurred", exc = exc.msg -# finally: -# trace "leaving multistream loop" + if m.handlers.len() == 0: + trace "sending `na` for protocol ", protocol = msg + await pushable.push(m.na) + continue -# proc addHandler*[T: LPProtocol](m: MultistreamSelect, -# codec: string, -# protocol: T, -# matcher: Matcher = nil) = -# ## register a protocol -# # TODO: This is a bug in chronicles, -# # it break if I uncoment this line. -# # Which is almost the same as the -# # one on the next override of addHandler -# # -# # trace "registering protocol", codec = codec -# m.handlers.add(HandlerHolder(proto: codec, -# protocol: protocol, -# match: matcher)) + case msg: + of Ls: + trace "listing protos" + for h in m.handlers: + await pushable.push(h.proto.toBytes()) + of Codec: + trace "handling handshake" + await pushable.push(m.codec) + else: + for h in m.handlers: + if (not isNil(h.match) and h.match(msg)) or msg == h.proto: + trace "found handler for", protocol = msg + await pushable.push(h.proto.toBytes()) + try: + await h.protocol.handler(conn, msg) + return + except CatchableError as exc: + warn "exception while handling", msg = exc.msg + return + warn "no handlers for ", protocol = msg + await pushable.push(m.na) + except CatchableError as exc: + trace "Exception occurred", exc = exc.msg + finally: + trace "leaving multistream loop" -# proc addHandler*[T: LPProtoHandler](m: MultistreamSelect, -# codec: string, -# handler: T, -# matcher: Matcher = nil) = -# ## helper to allow registering pure handlers +proc addHandler*(m: var MultistreamSelect, + codec: string, + protocol: LPProtocol, + matcher: Matcher = nil) = + ## register a protocol + trace "registering protocol", codec = codec + m.handlers.add(HandlerHolder(proto: codec, + protocol: protocol, + match: matcher)) -# trace "registering proto handler", codec = codec -# let protocol = new LPProtocol -# protocol.codec = codec -# protocol.handler = handler +proc addHandler*(m: var MultistreamSelect, + codec: string, + handler: LPProtoHandler, + matcher: Matcher = nil) = + ## helper to allow registering pure handlers -# m.handlers.add(HandlerHolder(proto: codec, -# protocol: protocol, -# match: matcher)) + trace "registering proto handler", codec = codec + let protocol = LPProtocol(codec: codec, handler: handler) + m.handlers.add(HandlerHolder(proto: codec, + protocol: protocol, + match: matcher))