From 1211ffbb5c555d4b3abfac14373c0db4d83f0c77 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Mon, 19 Nov 2018 04:52:11 +0200 Subject: [PATCH] Add daemon api sources. --- .travis.yml | 3 +- libp2p.nim | 11 + libp2p.nimble | 12 + libp2p/daemon/daemonapi.nim | 866 ++++++++++++++++++++++++++++++++ libp2p/daemon/transpool.nim | 141 ++++++ libp2p/protobuf/minprotobuf.nim | 276 ++++++++++ libp2p/protobuf/varint.nim | 278 ++++++++++ 7 files changed, 1585 insertions(+), 2 deletions(-) create mode 100644 libp2p.nim create mode 100644 libp2p.nimble create mode 100644 libp2p/daemon/daemonapi.nim create mode 100644 libp2p/daemon/transpool.nim create mode 100644 libp2p/protobuf/minprotobuf.nim create mode 100644 libp2p/protobuf/varint.nim diff --git a/.travis.yml b/.travis.yml index 92f155036..c4dd98423 100644 --- a/.travis.yml +++ b/.travis.yml @@ -39,11 +39,10 @@ install: cd ../.. ; }" - "export PATH=$PWD/nim/$NIMVER/bin:$PATH" - - echo $TRAVIS_GO_VERSION - go get -v github.com/libp2p/go-libp2p-daemon - cd $GOPATH/src/github.com/libp2p/go-libp2p-daemon - make script: - # - nimble install -y + - nimble install -y # - nimble test diff --git a/libp2p.nim b/libp2p.nim new file mode 100644 index 000000000..01d183410 --- /dev/null +++ b/libp2p.nim @@ -0,0 +1,11 @@ +## Nim-LibP2P +## Copyright (c) 2018 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. +import libp2p/daemon/[daemonapi, transpool] +import libp2p/protobuf/[minprotobuf, varint] +export daemonapi, minprotobuf, varint, transpool \ No newline at end of file diff --git a/libp2p.nimble b/libp2p.nimble new file mode 100644 index 000000000..78c887a8a --- /dev/null +++ b/libp2p.nimble @@ -0,0 +1,12 @@ +mode = ScriptMode.Verbose + +packageName = "nim-libp2p" +version = "0.0.1" +author = "Status Research & Development GmbH" +description = "LibP2P implementation" +license = "MIT" +skipDirs = @["tests", "Nim"] + +requires "nim > 0.18.0" + +# task tests, "Runs the test suite": diff --git a/libp2p/daemon/daemonapi.nim b/libp2p/daemon/daemonapi.nim new file mode 100644 index 000000000..8b541cf92 --- /dev/null +++ b/libp2p/daemon/daemonapi.nim @@ -0,0 +1,866 @@ +## Nim-LibP2P +## Copyright (c) 2018 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +## This module implementes API for `go-libp2p-daemon`. +import os, osproc, strutils, tables, streams +import asyncdispatch2 +import ../protobuf/varint, ../protobuf/minprotobuf, transpool + +when not defined(windows): + import posix + +const + DefaultSocketPath* = "/tmp/p2pd.sock" + DefaultDaemonFile* = "p2pd" + +type + RequestType* {.pure.} = enum + IDENTITY = 0, + CONNECT = 1, + STREAM_OPEN = 2, + STREAM_HANDLER = 3, + DHT = 4, + LIST_PEERS = 5, + CONNMANAGER = 6, + DISCONNECT = 7 + + DHTRequestType* {.pure.} = enum + FIND_PEER = 0, + FIND_PEERS_CONNECTED_TO_PEER = 1, + FIND_PROVIDERS = 2, + GET_CLOSEST_PEERS = 3, + GET_PUBLIC_KEY = 4, + GET_VALUE = 5, + SEARCH_VALUE = 6, + PUT_VALUE = 7, + PROVIDE = 8 + + ConnManagerRequestType* {.pure.} = enum + TAG_PEER = 0, + UNTAG_PEER = 1, + TRIM = 2 + + ResponseKind* = enum + Malformed, + Error, + Success + + ResponseType* {.pure.} = enum + ERROR = 2, + STREAMINFO = 3, + IDENTITY = 4, + DHT = 5, + PEERINFO = 6 + + DHTResponseType* {.pure.} = enum + BEGIN = 0, + VALUE = 1, + END = 2 + + PeerID* = seq[byte] + MultiProtocol* = string + MultiAddress* = seq[byte] + CID* = seq[byte] + LibP2PPublicKey* = seq[byte] + DHTValue* = seq[byte] + + P2PStreamFlags* {.pure.} = enum + None, Closed, Inbound, Outbound + + P2PDaemonFlags* {.pure.} = enum + DHTClient, DHTFull, Bootstrap + + P2PStream* = ref object + flags*: set[P2PStreamFlags] + peer*: PeerID + raddress*: MultiAddress + protocol*: string + transp*: StreamTransport + + DaemonAPI* = ref object + pool*: TransportPool + flags*: set[P2PDaemonFlags] + address*: TransportAddress + sockname*: string + pattern*: string + ucounter*: int + process*: Process + handlers*: Table[string, P2PStreamCallback] + servers*: seq[StreamServer] + + PeerInfo* = object + peer*: PeerID + addresses: seq[MultiAddress] + + P2PStreamCallback* = proc(api: DaemonAPI, + stream: P2PStream): Future[void] {.gcsafe.} + + DaemonRemoteError* = object of Exception + DaemonLocalError* = object of Exception + +proc requestIdentity(): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go + ## Processing function `doIdentify(req *pb.Request)`. + result = initProtoBuffer({WithVarintLength}) + result.write(initProtoField(1, cast[uint](RequestType.IDENTITY))) + result.finish() + +proc requestConnect(peerid: PeerID, + addresses: openarray[MultiAddress]): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go + ## Processing function `doConnect(req *pb.Request)`. + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, peerid)) + for item in addresses: + msg.write(initProtoField(2, item)) + result.write(initProtoField(1, cast[uint](RequestType.CONNECT))) + result.write(initProtoField(2, msg)) + result.finish() + +proc requestDisconnect(peerid: PeerID): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go + ## Processing function `doDisconnect(req *pb.Request)`. + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, peerid)) + result.write(initProtoField(1, cast[uint](RequestType.DISCONNECT))) + result.write(initProtoField(7, msg)) + result.finish() + +proc requestStreamOpen(peerid: PeerID, + protocols: openarray[string]): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go + ## Processing function `doStreamOpen(req *pb.Request)`. + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, peerid)) + for item in protocols: + msg.write(initProtoField(2, item)) + result.write(initProtoField(1, cast[uint](RequestType.STREAM_OPEN))) + result.write(initProtoField(3, msg)) + result.finish() + +proc requestStreamHandler(path: string, + protocols: openarray[MultiProtocol]): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go + ## Processing function `doStreamHandler(req *pb.Request)`. + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, path)) + for item in protocols: + msg.write(initProtoField(2, item)) + result.write(initProtoField(1, cast[uint](RequestType.STREAM_HANDLER))) + result.write(initProtoField(4, msg)) + result.finish() + +proc requestListPeers(): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go + ## Processing function `doListPeers(req *pb.Request)` + result = initProtoBuffer({WithVarintLength}) + result.write(initProtoField(1, cast[uint](RequestType.LIST_PEERS))) + result.finish() + +proc requestDHTFindPeer(peer: PeerID, timeout = 0): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go + ## Processing function `doDHTFindPeer(req *pb.DHTRequest)`. + let msgid = cast[uint](DHTRequestType.FIND_PEER) + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, msgid)) + msg.write(initProtoField(2, peer)) + if timeout > 0: + msg.write(initProtoField(7, uint(timeout))) + msg.finish() + result.write(initProtoField(1, cast[uint](RequestType.DHT))) + result.write(initProtoField(5, msg)) + result.finish() + +proc requestDHTFindPeersConnectedToPeer(peer: PeerID, + timeout = 0): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go + ## Processing function `doDHTFindPeersConnectedToPeer(req *pb.DHTRequest)`. + let msgid = cast[uint](DHTRequestType.FIND_PEERS_CONNECTED_TO_PEER) + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, msgid)) + msg.write(initProtoField(2, peer)) + if timeout > 0: + msg.write(initProtoField(7, uint(timeout))) + msg.finish() + result.write(initProtoField(1, cast[uint](RequestType.DHT))) + result.write(initProtoField(5, msg)) + result.finish() + +proc requestDHTFindProviders(cid: CID, + count: uint32, timeout = 0): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go + ## Processing function `doDHTFindProviders(req *pb.DHTRequest)`. + let msgid = cast[uint](DHTRequestType.FIND_PROVIDERS) + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, msgid)) + msg.write(initProtoField(3, cid)) + msg.write(initProtoField(6, count)) + if timeout > 0: + msg.write(initProtoField(7, uint(timeout))) + msg.finish() + result.write(initProtoField(1, cast[uint](RequestType.DHT))) + result.write(initProtoField(5, msg)) + result.finish() + +proc requestDHTGetClosestPeers(key: string, timeout = 0): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go + ## Processing function `doDHTGetClosestPeers(req *pb.DHTRequest)`. + let msgid = cast[uint](DHTRequestType.GET_CLOSEST_PEERS) + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, msgid)) + msg.write(initProtoField(4, key)) + if timeout > 0: + msg.write(initProtoField(7, uint(timeout))) + msg.finish() + result.write(initProtoField(1, cast[uint](RequestType.DHT))) + result.write(initProtoField(5, msg)) + result.finish() + +proc requestDHTGetPublicKey(peer: PeerID, timeout = 0): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go + ## Processing function `doDHTGetPublicKey(req *pb.DHTRequest)`. + let msgid = cast[uint](DHTRequestType.GET_PUBLIC_KEY) + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, msgid)) + msg.write(initProtoField(2, peer)) + if timeout > 0: + msg.write(initProtoField(7, uint(timeout))) + msg.finish() + result.write(initProtoField(1, cast[uint](RequestType.DHT))) + result.write(initProtoField(5, msg)) + result.finish() + +proc requestDHTGetValue(key: string, timeout = 0): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go + ## Processing function `doDHTGetValue(req *pb.DHTRequest)`. + let msgid = cast[uint](DHTRequestType.GET_VALUE) + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, msgid)) + msg.write(initProtoField(4, key)) + if timeout > 0: + msg.write(initProtoField(7, uint(timeout))) + msg.finish() + result.write(initProtoField(1, cast[uint](RequestType.DHT))) + result.write(initProtoField(5, msg)) + result.finish() + +proc requestDHTSearchValue(key: string, timeout = 0): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go + ## Processing function `doDHTSearchValue(req *pb.DHTRequest)`. + let msgid = cast[uint](DHTRequestType.SEARCH_VALUE) + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, msgid)) + msg.write(initProtoField(4, key)) + if timeout > 0: + msg.write(initProtoField(7, uint(timeout))) + msg.finish() + result.write(initProtoField(1, cast[uint](RequestType.DHT))) + result.write(initProtoField(5, msg)) + result.finish() + +proc requestDHTPutValue(key: string, value: openarray[byte], + timeout = 0): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go + ## Processing function `doDHTPutValue(req *pb.DHTRequest)`. + let msgid = cast[uint](DHTRequestType.PUT_VALUE) + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, msgid)) + msg.write(initProtoField(4, key)) + msg.write(initProtoField(5, value)) + if timeout > 0: + msg.write(initProtoField(7, uint(timeout))) + msg.finish() + result.write(initProtoField(1, cast[uint](RequestType.DHT))) + result.write(initProtoField(5, msg)) + result.finish() + +proc requestDHTProvide(cid: CID, timeout = 0): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go + ## Processing function `doDHTProvide(req *pb.DHTRequest)`. + let msgid = cast[uint](DHTRequestType.PROVIDE) + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, msgid)) + msg.write(initProtoField(3, cid)) + if timeout > 0: + msg.write(initProtoField(7, uint(timeout))) + msg.finish() + result.write(initProtoField(1, cast[uint](RequestType.DHT))) + result.write(initProtoField(5, msg)) + result.finish() + +proc requestCMTagPeer(peer: PeerID, tag: string, weight: int): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/connmgr.go#L18 + let msgid = cast[uint](ConnManagerRequestType.TAG_PEER) + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, msgid)) + msg.write(initProtoField(2, peer)) + msg.write(initProtoField(3, tag)) + msg.write(initProtoField(4, weight)) + msg.finish() + result.write(initProtoField(1, cast[uint](RequestType.CONNMANAGER))) + result.write(initProtoField(6, msg)) + result.finish() + +proc requestCMUntagPeer(peer: PeerID, tag: string): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/connmgr.go#L33 + let msgid = cast[uint](ConnManagerRequestType.UNTAG_PEER) + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, msgid)) + msg.write(initProtoField(2, peer)) + msg.write(initProtoField(3, tag)) + msg.finish() + result.write(initProtoField(1, cast[uint](RequestType.CONNMANAGER))) + result.write(initProtoField(6, msg)) + result.finish() + +proc requestCMTrim(): ProtoBuffer = + ## https://github.com/libp2p/go-libp2p-daemon/blob/master/connmgr.go#L47 + let msgid = cast[uint](ConnManagerRequestType.TRIM) + result = initProtoBuffer({WithVarintLength}) + var msg = initProtoBuffer() + msg.write(initProtoField(1, msgid)) + msg.finish() + result.write(initProtoField(1, cast[uint](RequestType.CONNMANAGER))) + result.write(initProtoField(6, msg)) + result.finish() + +proc checkResponse(pb: var ProtoBuffer): ResponseKind {.inline.} = + result = ResponseKind.Malformed + var value: uint64 + if getVarintValue(pb, 1, value) > 0: + if value == 0: + result = ResponseKind.Success + else: + result = ResponseKind.Error + +proc getErrorMessage(pb: var ProtoBuffer): string {.inline.} = + if pb.enterSubmessage() == cast[int](ResponseType.ERROR): + if pb.getString(1, result) == -1: + raise newException(DaemonLocalError, "Error message is missing!") + +proc recvMessage(conn: StreamTransport): Future[seq[byte]] {.async.} = + var + size: uint + length: int + res: VarintStatus + var buffer = newSeq[byte](10) + for i in 0.. MaxMessageSize: + raise newException(ValueError, "Invalid message size") + buffer.setLen(size) + await conn.readExactly(addr buffer[0], int(size)) + result = buffer + +proc socketExists(filename: string): bool = + var res: Stat + result = stat(filename, res) >= 0'i32 + +proc newDaemonApi*(flags: set[P2PDaemonFlags] = {}, + bootstrapNodes: seq[string] = @[], + id: string = "", + daemon = DefaultDaemonFile, + sockpath = DefaultSocketPath, + pattern = "/tmp/nim-p2pd-$1.sock", + poolSize = 10): Future[DaemonAPI] {.async.} = + ## Initialize connections to `go-libp2p-daemon` control socket. + result = new DaemonAPI + result.flags = flags + result.servers = newSeq[StreamServer]() + result.address = initTAddress(sockpath) + result.pattern = pattern + result.ucounter = 1 + result.handlers = initTable[string, P2PStreamCallback]() + # We will start daemon process only when control socket path is not default or + # options are specified. + if flags == {} and sockpath == DefaultSocketPath: + result.pool = await newPool(initTAddress(sockpath), poolsize = poolSize) + else: + var args = newSeq[string]() + # DHTFull and DHTClient could not be present at the same time + if P2PDaemonFlags.DHTFull in flags and P2PDaemonFlags.DHTClient in flags: + result.flags.excl(DHTClient) + if P2PDaemonFlags.DHTFull in result.flags: + args.add("-dht") + if P2PDaemonFlags.DHTClient in result.flags: + args.add("-dhtClient") + if P2PDaemonFlags.Bootstrap in result.flags: + args.add("-b") + if len(bootstrapNodes) > 0: + args.add("-bootstrapPeers=" & bootstrapNodes.join(",")) + if len(id) != 0: + args.add("-id=" & id) + if sockpath != DefaultSocketPath: + args.add("-sock=" & sockpath) + # We are trying to get absolute daemon path. + let cmd = findExe(daemon) + if len(cmd) == 0: + raise newException(DaemonLocalError, "Could not find daemon executable!") + # We will try to remove control socket file, because daemon will fail + # if its not able to create new socket control file. + # We can't use `existsFile()` because it do not support unix-domain socket + # endpoints. + if socketExists(sockpath): + discard tryRemoveFile(sockpath) + # Starting daemon process + result.process = startProcess(cmd, "", args, options = {poStdErrToStdOut}) + # Waiting until daemon will not be bound to control socket. + while true: + if not result.process.running(): + echo result.process.errorStream.readAll() + raise newException(DaemonLocalError, + "Daemon executable could not be started!") + if socketExists(sockpath): + break + await sleepAsync(100) + result.sockname = sockpath + result.pool = await newPool(initTAddress(sockpath), poolsize = poolSize) + +proc close*(api: DaemonAPI, stream: P2PStream) {.async.} = + ## Close ``stream``. + if P2PStreamFlags.Closed notin stream.flags: + stream.transp.close() + await stream.transp.join() + stream.transp = nil + stream.flags.incl(P2PStreamFlags.Closed) + else: + raise newException(DaemonLocalError, "Stream is already closed!") + +proc close*(api: DaemonAPI) {.async.} = + ## Shutdown connections to `go-libp2p-daemon` control socket. + await api.pool.close() + # Closing all pending servers. + if len(api.servers) > 0: + var pending = newSeq[Future[void]]() + for server in api.servers: + server.stop() + server.close() + pending.add(server.join()) + await all(pending) + # Closing daemon's process. + if api.flags != {}: + api.process.terminate() + # Attempt to delete control socket endpoint. + # if socketExists(api.sockname): + # discard tryRemoveFile(api.sockname) + +template withMessage(m, body: untyped): untyped = + let kind = m.checkResponse() + if kind == ResponseKind.Error: + raise newException(DaemonRemoteError, m.getErrorMessage()) + elif kind == ResponseKind.Malformed: + raise newException(DaemonLocalError, "Malformed message received!") + else: + body + +proc transactMessage(transp: StreamTransport, + pb: ProtoBuffer): Future[ProtoBuffer] {.async.} = + let length = pb.getLen() + let res = await transp.write(pb.getPtr(), length) + if res != length: + raise newException(DaemonLocalError, "Could not send message to daemon!") + var message = await transp.recvMessage() + result = initProtoBuffer(message) + +proc getPeerInfo(pb: var ProtoBuffer): PeerInfo = + ## Get PeerInfo object from ``pb``. + result.addresses = newSeq[MultiAddress]() + result.peer = newSeq[byte]() + if pb.getBytes(1, result.peer) == -1: + raise newException(DaemonLocalError, "Missing required field `peer`!") + var address = newSeq[byte]() + while pb.getBytes(2, address) != -1: + if len(address) != 0: + result.addresses.add(address) + address.setLen(0) + +proc identity*(api: DaemonAPI): Future[PeerInfo] {.async.} = + ## Get Node identity information + var transp = await api.pool.acquire() + try: + var pb = await transactMessage(transp, requestIdentity()) + pb.withMessage() do: + let res = pb.enterSubmessage() + if res == cast[int](ResponseType.IDENTITY): + result = pb.getPeerInfo() + finally: + api.pool.release(transp) + +proc connect*(api: DaemonAPI, peer: PeerID, + addresses: seq[MultiAddress]) {.async.} = + ## Connect to remote peer with id ``peer`` and addresses ``addresses``. + var transp = await api.pool.acquire() + try: + var pb = await transp.transactMessage(requestConnect(peer, addresses)) + pb.withMessage() do: + discard + finally: + api.pool.release(transp) + +proc disconnect*(api: DaemonAPI, peer: PeerID) {.async.} = + ## Disconnect from remote peer with id ``peer``. + var transp = await api.pool.acquire() + try: + var pb = await transp.transactMessage(requestDisconnect(peer)) + pb.withMessage() do: + discard + finally: + api.pool.release(transp) + +proc openStream*(api: DaemonAPI, peer: PeerID, + protocols: seq[string]): Future[P2PStream] {.async.} = + ## Open new stream to peer ``peer`` using one of the protocols in + ## ``protocols``. Returns ``StreamTransport`` for the stream. + var transp = await connect(api.address) + var stream = new P2PStream + try: + var pb = await transp.transactMessage(requestStreamOpen(peer, protocols)) + pb.withMessage() do: + var res = pb.enterSubmessage() + if res == cast[int](ResponseType.STREAMINFO): + stream.peer = newSeq[byte]() + stream.raddress = newSeq[byte]() + stream.protocol = "" + if pb.getLengthValue(1, stream.peer) == -1: + raise newException(DaemonLocalError, "Missing `peer` field!") + if pb.getLengthValue(2, stream.raddress) == -1: + raise newException(DaemonLocalError, "Missing `address` field!") + if pb.getLengthValue(3, stream.protocol) == -1: + raise newException(DaemonLocalError, "Missing `proto` field!") + stream.flags.incl(Outbound) + stream.transp = transp + result = stream + except: + transp.close() + await transp.join() + raise getCurrentException() + +proc streamHandler(server: StreamServer, transp: StreamTransport) {.async.} = + var api = getUserData[DaemonAPI](server) + var message = await transp.recvMessage() + var pb = initProtoBuffer(message) + var stream = new P2PStream + stream.peer = newSeq[byte]() + stream.raddress = newSeq[byte]() + stream.protocol = "" + if pb.getLengthValue(1, stream.peer) == -1: + raise newException(DaemonLocalError, "Missing `peer` field!") + if pb.getLengthValue(2, stream.raddress) == -1: + raise newException(DaemonLocalError, "Missing `address` field!") + if pb.getLengthValue(3, stream.protocol) == -1: + raise newException(DaemonLocalError, "Missing `proto` field!") + stream.flags.incl(Inbound) + stream.transp = transp + if len(stream.protocol) > 0: + var handler = api.handlers.getOrDefault(stream.protocol) + if not isNil(handler): + asyncCheck handler(api, stream) + +proc addHandler*(api: DaemonAPI, protocols: seq[string], + handler: P2PStreamCallback) {.async.} = + ## Add stream handler ``handler`` for set of protocols ``protocols``. + var transp = await api.pool.acquire() + var sockname = api.pattern % [$api.ucounter] + var localaddr = initTAddress(sockname) + inc(api.ucounter) + var server = createStreamServer(localaddr, streamHandler, udata = api) + try: + for item in protocols: + api.handlers[item] = handler + server.start() + var pb = await transp.transactMessage(requestStreamHandler(sockname, + protocols)) + pb.withMessage() do: + api.servers.add(server) + except: + for item in protocols: + api.handlers.del(item) + server.stop() + server.close() + await server.join() + raise getCurrentException() + finally: + api.pool.release(transp) + +proc listPeers*(api: DaemonAPI): Future[seq[PeerInfo]] {.async.} = + ## Get list of remote peers to which we are currently connected. + var transp = await api.pool.acquire() + try: + var pb = await transp.transactMessage(requestListPeers()) + pb.withMessage() do: + var address = newSeq[byte]() + result = newSeq[PeerInfo]() + var res = pb.enterSubmessage() + while res != 0: + if res == cast[int](ResponseType.PEERINFO): + var peer = pb.getPeerInfo() + result.add(peer) + else: + pb.skipSubmessage() + res = pb.enterSubmessage() + finally: + api.pool.release(transp) + +proc cmTagPeer*(api: DaemonAPI, peer: PeerID, tag: string, + weight: int) {.async.} = + ## Tag peer with id ``peer`` using ``tag`` and ``weight``. + var transp = await api.pool.acquire() + try: + var pb = await transp.transactMessage(requestCMTagPeer(peer, tag, weight)) + withMessage(pb) do: + discard + finally: + api.pool.release(transp) + +proc cmUntagPeer*(api: DaemonAPI, peer: PeerID, tag: string) {.async.} = + ## Remove tag ``tag`` from peer with id ``peer``. + var transp = await api.pool.acquire() + try: + var pb = await transp.transactMessage(requestCMUntagPeer(peer, tag)) + withMessage(pb) do: + discard + finally: + api.pool.release(transp) + +proc cmTrimPeers*(api: DaemonAPI) {.async.} = + ## Trim all connections. + var transp = await api.pool.acquire() + try: + var pb = await transp.transactMessage(requestCMTrim()) + withMessage(pb) do: + discard + finally: + api.pool.release(transp) + +proc dhtGetSinglePeerInfo(pb: var ProtoBuffer): PeerInfo = + if pb.enterSubmessage() == 2: + result = pb.getPeerInfo() + else: + raise newException(DaemonLocalError, "Missing required field `peer`!") + +proc dhtGetSingleValue(pb: var ProtoBuffer): seq[byte] = + result = newSeq[byte]() + if pb.getLengthValue(3, result) == -1: + raise newException(DaemonLocalError, "Missing field `value`!") + +proc enterDhtMessage(pb: var ProtoBuffer, rt: DHTResponseType) {.inline.} = + var dtype: uint + var res = pb.enterSubmessage() + if res == cast[int](ResponseType.DHT): + if pb.getVarintValue(1, dtype) == 0: + raise newException(DaemonLocalError, "Missing required DHT field `type`!") + if dtype != cast[uint](rt): + raise newException(DaemonLocalError, "Wrong DHT answer type! ") + else: + raise newException(DaemonLocalError, "Wrong message type!") + +proc getDhtMessageType(pb: var ProtoBuffer): DHTResponseType {.inline.} = + var dtype: uint + if pb.getVarintValue(1, dtype) == 0: + raise newException(DaemonLocalError, "Missing required DHT field `type`!") + if dtype == cast[uint](DHTResponseType.VALUE): + result = DHTResponseType.VALUE + elif dtype == cast[uint](DHTResponseType.END): + result = DHTResponseType.END + else: + raise newException(DaemonLocalError, "Wrong DHT answer type!") + +proc dhtFindPeer*(api: DaemonAPI, peer: PeerID, + timeout = 0): Future[PeerInfo] {.async.} = + ## Find peer with id ``peer`` and return peer information ``PeerInfo``. + ## + ## You can specify timeout for DHT request with ``timeout`` value. ``0`` value + ## means no timeout. + var transp = await api.pool.acquire() + try: + var pb = await transp.transactMessage(requestDHTFindPeer(peer, timeout)) + withMessage(pb) do: + pb.enterDhtMessage(DHTResponseType.VALUE) + result = pb.dhtGetSinglePeerInfo() + finally: + api.pool.release(transp) + +proc dhtGetPublicKey*(api: DaemonAPI, peer: PeerID, + timeout = 0): Future[LibP2PPublicKey] {.async.} = + ## Get peer's public key from peer with id ``peer``. + ## + ## You can specify timeout for DHT request with ``timeout`` value. ``0`` value + ## means no timeout. + var transp = await api.pool.acquire() + try: + var pb = await transp.transactMessage(requestDHTGetPublicKey(peer, timeout)) + withMessage(pb) do: + pb.enterDhtMessage(DHTResponseType.VALUE) + result = pb.dhtGetSingleValue() + finally: + api.pool.release(transp) + +proc dhtGetValue*(api: DaemonAPI, key: string, + timeout = 0): Future[seq[byte]] {.async.} = + ## Get value associated with ``key``. + ## + ## You can specify timeout for DHT request with ``timeout`` value. ``0`` value + ## means no timeout. + var transp = await api.pool.acquire() + try: + var pb = await transp.transactMessage(requestDHTGetValue(key, timeout)) + withMessage(pb) do: + pb.enterDhtMessage(DHTResponseType.VALUE) + result = pb.dhtGetSingleValue() + finally: + api.pool.release(transp) + +proc dhtPutValue*(api: DaemonAPI, key: string, value: seq[byte], + timeout = 0) {.async.} = + ## Associate ``value`` with ``key``. + ## + ## You can specify timeout for DHT request with ``timeout`` value. ``0`` value + ## means no timeout. + var transp = await api.pool.acquire() + try: + var pb = await transp.transactMessage(requestDHTPutValue(key, value, + timeout)) + withMessage(pb) do: + discard + finally: + api.pool.release(transp) + +proc dhtProvide*(api: DaemonAPI, cid: CID, timeout = 0) {.async.} = + ## Provide content with id ``cid``. + ## + ## You can specify timeout for DHT request with ``timeout`` value. ``0`` value + ## means no timeout. + var transp = await api.pool.acquire() + try: + var pb = await transp.transactMessage(requestDHTProvide(cid, timeout)) + withMessage(pb) do: + discard + finally: + api.pool.release(transp) + +proc dhtFindPeersConnectedToPeer*(api: DaemonAPI, peer: PeerID, + timeout = 0): Future[seq[PeerInfo]] {.async.} = + ## Find peers which are connected to peer with id ``peer``. + ## + ## You can specify timeout for DHT request with ``timeout`` value. ``0`` value + ## means no timeout. + var transp = await api.pool.acquire() + var list = newSeq[PeerInfo]() + try: + let spb = requestDHTFindPeersConnectedToPeer(peer, timeout) + var pb = await transp.transactMessage(spb) + withMessage(pb) do: + pb.enterDhtMessage(DHTResponseType.BEGIN) + while true: + var message = await transp.recvMessage() + var cpb = initProtoBuffer(message) + if cpb.getDhtMessageType() == DHTResponseType.END: + break + list.add(cpb.dhtGetSinglePeerInfo()) + result = list + finally: + api.pool.release(transp) + +proc dhtGetClosestPeers*(api: DaemonAPI, key: string, + timeout = 0): Future[seq[PeerID]] {.async.} = + ## Get closest peers for ``key``. + ## + ## You can specify timeout for DHT request with ``timeout`` value. ``0`` value + ## means no timeout. + var transp = await api.pool.acquire() + var list = newSeq[PeerID]() + try: + let spb = requestDHTGetClosestPeers(key, timeout) + var pb = await transp.transactMessage(spb) + withMessage(pb) do: + pb.enterDhtMessage(DHTResponseType.BEGIN) + while true: + var message = await transp.recvMessage() + var cpb = initProtoBuffer(message) + if cpb.getDhtMessageType() == DHTResponseType.END: + break + list.add(cpb.dhtGetSingleValue()) + result = list + finally: + api.pool.release(transp) + +proc dhtFindProviders*(api: DaemonAPI, cid: CID, count: uint32, + timeout = 0): Future[seq[PeerInfo]] {.async.} = + ## Get ``count`` providers for content with id ``cid``. + ## + ## You can specify timeout for DHT request with ``timeout`` value. ``0`` value + ## means no timeout. + var transp = await api.pool.acquire() + var list = newSeq[PeerInfo]() + try: + let spb = requestDHTFindProviders(cid, count, timeout) + var pb = await transp.transactMessage(spb) + withMessage(pb) do: + pb.enterDhtMessage(DHTResponseType.BEGIN) + while true: + var message = await transp.recvMessage() + var cpb = initProtoBuffer(message) + if cpb.getDhtMessageType() == DHTResponseType.END: + break + list.add(cpb.dhtGetSinglePeerInfo()) + result = list + finally: + api.pool.release(transp) + +proc dhtSearchValue*(api: DaemonAPI, key: string, + timeout = 0): Future[seq[seq[byte]]] {.async.} = + ## Search for value with ``key``, return list of values found. + ## + ## You can specify timeout for DHT request with ``timeout`` value. ``0`` value + ## means no timeout. + var transp = await api.pool.acquire() + var list = newSeq[seq[byte]]() + try: + var pb = await transp.transactMessage(requestDHTSearchValue(key, timeout)) + withMessage(pb) do: + pb.enterDhtMessage(DHTResponseType.BEGIN) + while true: + var message = await transp.recvMessage() + var cpb = initProtoBuffer(message) + if cpb.getDhtMessageType() == DHTResponseType.END: + break + list.add(cpb.dhtGetSingleValue()) + result = list + finally: + api.pool.release(transp) + +when isMainModule: + proc test() {.async.} = + var api1 = await newDaemonApi(sockpath = "/tmp/p2pd-1.sock") + var api2 = await newDaemonApi(sockpath = "/tmp/p2pd-2.sock") + echo await api1.identity() + echo await api2.identity() + await sleepAsync(1000) + await api1.close() + await api2.close() + + waitFor test() diff --git a/libp2p/daemon/transpool.nim b/libp2p/daemon/transpool.nim new file mode 100644 index 000000000..cae8d2cb8 --- /dev/null +++ b/libp2p/daemon/transpool.nim @@ -0,0 +1,141 @@ +## Nim-Libp2p +## Copyright (c) 2018 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +## This module implements Pool of StreamTransport. +import asyncdispatch2 + +const + DefaultPoolSize* = 8 + ## Default pool size + +type + ConnectionFlags = enum + None, Busy + + PoolItem = object + transp*: StreamTransport + flags*: set[ConnectionFlags] + + PoolState = enum + Connecting, Connected, Closing, Closed + + TransportPool* = ref object + ## Transports pool object + transports: seq[PoolItem] + busyCount: int + state: PoolState + bufferSize: int + event: AsyncEvent + + TransportPoolError* = object of AsyncError + +proc waitAll[T](futs: seq[Future[T]]): Future[void] = + ## Performs waiting for all Future[T]. + var counter = len(futs) + var retFuture = newFuture[void]("connpool.waitAllConnections") + proc cb(udata: pointer) = + dec(counter) + if counter == 0: + retFuture.complete() + for fut in futs: + fut.addCallback(cb) + return retFuture + +proc newPool*(address: TransportAddress, poolsize: int = DefaultPoolSize, + bufferSize = DefaultStreamBufferSize, + ): Future[TransportPool] {.async.} = + ## Establish pool of connections to address ``address`` with size + ## ``poolsize``. + result = new TransportPool + result.bufferSize = bufferSize + result.transports = newSeq[PoolItem](poolsize) + var conns = newSeq[Future[StreamTransport]](poolsize) + result.state = Connecting + for i in 0..= 0 + +template getPtr*(pb: ProtoBuffer): pointer = + cast[pointer](unsafeAddr pb.buffer[pb.offset]) + +template getLen*(pb: ProtoBuffer): int = + len(pb.buffer) - pb.offset + +proc vsizeof*(field: ProtoField): int {.inline.} = + ## Returns number of bytes required to store protobuf's field ``field``. + result = vsizeof(protoHeader(field)) + case field.kind + of ProtoFieldKind.Varint: + result += vsizeof(field.vint) + of ProtoFieldKind.Fixed64: + result += sizeof(field.vfloat64) + of ProtoFieldKind.Fixed32: + result += sizeof(field.vfloat32) + of ProtoFieldKind.Length: + result += vsizeof(uint(len(field.vbuffer))) + len(field.vbuffer) + else: + discard + +proc initProtoField*(index: int, value: SomeVarint): ProtoField = + ## Initialize ProtoField with integer value. + result = ProtoField(kind: Varint, index: index) + when type(value) is uint64: + result.vint = value + else: + result.vint = cast[uint64](value) + +proc initProtoField*(index: int, value: openarray[byte]): ProtoField = + ## Initialize ProtoField with bytes array. + result = ProtoField(kind: Length, index: index) + if len(value) > 0: + result.vbuffer = newSeq[byte](len(value)) + copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value)) + +proc initProtoField*(index: int, value: string): ProtoField = + ## Initialize ProtoField with string. + result = ProtoField(kind: Length, index: index) + if len(value) > 0: + result.vbuffer = newSeq[byte](len(value)) + copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value)) + +proc initProtoField*(index: int, value: ProtoBuffer): ProtoField {.inline.} = + ## Initialize ProtoField with nested message stored in ``value``. + ## + ## Note: This procedure performs shallow copy of ``value`` sequence. + result = ProtoField(kind: Length, index: index) + if len(value.buffer) > 0: + shallowCopy(result.vbuffer, value.buffer) + +proc initProtoBuffer*(data: seq[byte], offset = 0, + options: set[ProtoFlags] = {}): ProtoBuffer = + ## Initialize ProtoBuffer with shallow copy of ``data``. + shallowCopy(result.buffer, data) + result.offset = offset + result.options = options + +proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer = + ## Initialize ProtoBuffer with new sequence of capacity ``cap``. + result.buffer = newSeqOfCap[byte](128) + result.options = options + if WithVarintLength in options: + # Our buffer will start from position 10, so we can store length of buffer + # in [0, 9]. + result.buffer.setLen(10) + result.offset = 10 + +proc write*(pb: var ProtoBuffer, field: ProtoField) = + ## Encode protobuf's field ``field`` and store it to protobuf's buffer ``pb``. + var length = 0 + var res: VarintStatus + pb.buffer.setLen(len(pb.buffer) + vsizeof(field)) + res = putUVarint(pb.toOpenArray(), length, protoHeader(field)) + assert(res == VarintStatus.Success) + pb.offset += length + case field.kind + of ProtoFieldKind.Varint: + res = putUVarint(pb.toOpenArray(), length, field.vint) + assert(res == VarintStatus.Success) + pb.offset += length + of ProtoFieldKind.Fixed64: + assert(pb.isEnough(8)) + var value = cast[uint64](field.vfloat64) + pb.buffer[pb.offset] = byte(value and 0xFF'u32) + pb.buffer[pb.offset + 1] = byte((value shr 8) and 0xFF'u32) + pb.buffer[pb.offset + 2] = byte((value shr 16) and 0xFF'u32) + pb.buffer[pb.offset + 3] = byte((value shr 24) and 0xFF'u32) + pb.buffer[pb.offset + 4] = byte((value shr 32) and 0xFF'u32) + pb.buffer[pb.offset + 5] = byte((value shr 40) and 0xFF'u32) + pb.buffer[pb.offset + 6] = byte((value shr 48) and 0xFF'u32) + pb.buffer[pb.offset + 7] = byte((value shr 56) and 0xFF'u32) + pb.offset += 8 + of ProtoFieldKind.Fixed32: + assert(pb.isEnough(4)) + var value = cast[uint32](field.vfloat32) + pb.buffer[pb.offset] = byte(value and 0xFF'u32) + pb.buffer[pb.offset + 1] = byte((value shr 8) and 0xFF'u32) + pb.buffer[pb.offset + 2] = byte((value shr 16) and 0xFF'u32) + pb.buffer[pb.offset + 3] = byte((value shr 24) and 0xFF'u32) + pb.offset += 4 + of ProtoFieldKind.Length: + res = putUVarint(pb.toOpenArray(), length, uint(len(field.vbuffer))) + assert(res == VarintStatus.Success) + pb.offset += length + assert(pb.isEnough(len(field.vbuffer))) + copyMem(addr pb.buffer[pb.offset], unsafeAddr field.vbuffer[0], + len(field.vbuffer)) + pb.offset += len(field.vbuffer) + else: + discard + +proc finish*(pb: var ProtoBuffer) = + ## Prepare protobuf's buffer ``pb`` for writing to stream. + var length = 0 + assert(len(pb.buffer) > 0) + if WithVarintLength in pb.options: + let size = uint(len(pb.buffer) - 10) + let pos = 10 - vsizeof(length) + let res = putUVarint(pb.buffer.toOpenArray(pos, 9), length, size) + assert(res == VarintStatus.Success) + pb.offset = pos + else: + pb.offset = 0 + +proc getVarintValue*(data: var ProtoBuffer, field: int, + value: var SomeVarint): int = + ## Get value of `Varint` type. + var length = 0 + var header = 0'u64 + var soffset = data.offset + + if not data.isEmpty() and + getUVarint(data.toOpenArray(), length, header) == VarintStatus.Success: + data.offset += length + if header == protoHeader(field, Varint): + if not data.isEmpty(): + when type(value) is int32 or type(value) is int64 or type(value) is int: + let res = getSVarint(data.toOpenArray(), length, value) + else: + let res = getUVarint(data.toOpenArray(), length, value) + data.offset += length + result = length + return + # Restore offset on error + data.offset = soffset + +proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int, + buffer: var T): int = + ## Get value of `Length` type. + var length = 0 + var header = 0'u64 + var ssize = 0'u64 + var soffset = data.offset + result = -1 + buffer.setLen(0) + if not data.isEmpty() and + getUVarint(data.toOpenArray(), length, header) == VarintStatus.Success: + data.offset += length + if header == protoHeader(field, Length): + if not data.isEmpty() and + getUVarint(data.toOpenArray(), length, ssize) == VarintStatus.Success: + data.offset += length + if ssize <= MaxMessageSize and data.isEnough(int(ssize)): + buffer.setLen(ssize) + # Protobuf allow zero-length values. + if ssize > 0'u64: + copyMem(addr buffer[0], addr data.buffer[data.offset], ssize) + result = int(ssize) + data.offset += int(ssize) + return + # Restore offset on error + data.offset = soffset + +proc getBytes*(data: var ProtoBuffer, field: int, + buffer: var seq[byte]): int {.inline.} = + ## Get value of `Length` type as bytes. + result = getLengthValue(data, field, buffer) + +proc getString*(data: var ProtoBuffer, field: int, + buffer: var string): int {.inline.} = + ## Get value of `Length` type as string. + result = getLengthValue(data, field, buffer) + +proc enterSubmessage*(pb: var ProtoBuffer): int = + ## Processes protobuf's sub-message and adjust internal offset to enter + ## inside of sub-message. Returns field index of sub-message field or + ## ``0`` on error. + var length = 0 + var header = 0'u64 + var msize = 0'u64 + var soffset = pb.offset + + if not pb.isEmpty() and + getUVarint(pb.toOpenArray(), length, header) == VarintStatus.Success: + pb.offset += length + if (header and 0x07'u64) == cast[uint64](ProtoFieldKind.Length): + if not pb.isEmpty() and + getUVarint(pb.toOpenArray(), length, msize) == VarintStatus.Success: + pb.offset += length + if msize <= MaxMessageSize and pb.isEnough(int(msize)): + pb.length = int(msize) + result = int(header shr 3) + return + # Restore offset on error + pb.offset = soffset + +proc skipSubmessage*(pb: var ProtoBuffer) = + ## Skip current protobuf's sub-message and adjust internal offset to the + ## end of sub-message. + assert(pb.length != 0) + pb.offset += pb.length + pb.length = 0 diff --git a/libp2p/protobuf/varint.nim b/libp2p/protobuf/varint.nim new file mode 100644 index 000000000..d354310eb --- /dev/null +++ b/libp2p/protobuf/varint.nim @@ -0,0 +1,278 @@ +## Nim-Libp2p +## Copyright (c) 2018 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +## This module implements Google ProtoBuf's variable integer `VARINT`. +import bitops + +type + VarintStatus* = enum + Error, + Success, + Overflow, + Incomplete, + Overrun + + SomeUVarint* = uint | uint64 | uint32 + SomeSVarint* = int | int64 | int32 + SomeVarint* = SomeUVarint | SomeSVarint + VarintError* = object of Exception + +proc vsizeof*(x: SomeUVarint|SomeSVarint): int {.inline.} = + ## Returns number of bytes required to encode integer ``x`` as varint. + if x == cast[type(x)](0): + result = 1 + else: + result = (fastLog2(x) + 1 + 7 - 1) div 7 + +proc getUVarint*(pbytes: openarray[byte], outlen: var int, + outval: var SomeUVarint): VarintStatus = + ## Decode `unsigned varint` from buffer ``pbytes`` and store it to ``outval``. + ## On success ``outlen`` will be set to number of bytes processed while + ## decoding `unsigned varint`. + ## + ## If array ``pbytes`` is empty, ``Incomplete`` error will be returned. + ## + ## If there not enough bytes available in array ``pbytes`` to decode `unsigned + ## varint`, ``Incomplete`` error will be returned. + ## + ## If encoded value can produce integer overflow, ``Overflow`` error will be + ## returned. + ## + ## Note, when decoding 10th byte of 64bit integer only 1 bit from byte will be + ## decoded, all other bits will be ignored. When decoding 5th byte of 32bit + ## integer only 4 bits from byte will be decoded, all other bits will be + ## ignored. + const MaxBits = byte(sizeof(outval) * 8) + var shift = 0'u8 + result = VarintStatus.Incomplete + outlen = 0 + outval = cast[type(outval)](0) + for i in 0..= MaxBits: + result = VarintStatus.Overflow + outlen = 0 + outval = cast[type(outval)](0) + break + else: + outval = outval or (cast[type(outval)](b and 0x7F'u8) shl shift) + shift += 7 + inc(outlen) + if (b and 0x80'u8) == 0'u8: + result = VarintStatus.Success + break + + if result == VarintStatus.Incomplete: + outlen = 0 + outval = cast[type(outval)](0) + +proc putUVarint*(pbytes: var openarray[byte], outlen: var int, + outval: SomeUVarint): VarintStatus = + ## Encode `unsigned varint` ``outval`` and store it to array ``pbytes``. + ## + ## On success ``outlen`` will hold number of bytes (octets) used to encode + ## unsigned integer ``v``. + ## + ## If there not enough bytes available in buffer ``pbytes``, ``Incomplete`` + ## error will be returned and ``outlen`` will be set to number of bytes + ## required. + ## + ## Maximum encoded length of 64bit integer is 10 octets. + ## Maximum encoded length of 32bit integer is 5 octets. + var buffer: array[10, byte] + var value = outval + var k = 0 + + if value <= cast[type(outval)](0x7F): + buffer[0] = cast[byte](outval and 0xFF) + inc(k) + else: + while value != cast[type(outval)](0): + buffer[k] = cast[byte]((value and 0x7F) or 0x80) + value = value shr 7 + inc(k) + buffer[k - 1] = buffer[k - 1] and 0x7F'u8 + + outlen = k + if len(pbytes) >= k: + copyMem(addr pbytes[0], addr buffer[0], k) + result = VarintStatus.Success + else: + result = VarintStatus.Overrun + +proc getSVarint*(pbytes: openarray[byte], outsize: var int, + outval: var SomeSVarint): VarintStatus {.inline.} = + ## Decode `signed varint` from buffer ``pbytes`` and store it to ``outval``. + ## On success ``outlen`` will be set to number of bytes processed while + ## decoding `signed varint`. + ## + ## If array ``pbytes`` is empty, ``Incomplete`` error will be returned. + ## + ## If there not enough bytes available in array ``pbytes`` to decode `signed + ## varint`, ``Incomplete`` error will be returned. + ## + ## If encoded value can produce integer overflow, ``Overflow`` error will be + ## returned. + ## + ## Note, when decoding 10th byte of 64bit integer only 1 bit from byte will be + ## decoded, all other bits will be ignored. When decoding 5th byte of 32bit + ## integer only 4 bits from byte will be decoded, all other bits will be + ## ignored. + when sizeof(outval) == 8: + var value: uint64 + else: + var value: uint32 + + result = getUVarint(pbytes, outsize, value) + if result == VarintStatus.Success: + if (value and cast[type(value)](1)) != cast[type(value)](0): + outval = cast[type(outval)](not(value shr 1)) + else: + outval = cast[type(outval)](value shr 1) + +proc putSVarint*(pbytes: var openarray[byte], outsize: var int, + outval: SomeSVarint): VarintStatus {.inline.} = + ## Encode `signed varint` ``outval`` and store it to array ``pbytes``. + ## + ## On success ``outlen`` will hold number of bytes (octets) used to encode + ## unsigned integer ``v``. + ## + ## If there not enough bytes available in buffer ``pbytes``, ``Incomplete`` + ## error will be returned and ``outlen`` will be set to number of bytes + ## required. + ## + ## Maximum encoded length of 64bit integer is 10 octets. + ## Maximum encoded length of 32bit integer is 5 octets. + when sizeof(outval) == 8: + var value: uint64 = + if outval < 0: + not(cast[uint64](outval) shl 1) + else: + cast[uint64](outval) shl 1 + else: + var value: uint32 = + if outval < 0: + not(cast[uint32](outval) shl 1) + else: + cast[uint32](outval) shl 1 + result = putUVarint(pbytes, outsize, value) + +proc encodeVarint*(value: SomeUVarint|SomeSVarint): seq[byte] {.inline.} = + ## Encode integer to `signed/unsigned varint` and returns sequence of bytes + ## as result. + var outsize = 0 + result = newSeqOfCap[byte](10) + when sizeof(value) == 4: + result.setLen(5) + else: + result.setLen(10) + when type(value) is SomeSVarint: + let res = putSVarint(result, outsize, value) + else: + let res = putUVarint(result, outsize, value) + if res == VarintStatus.Success: + result.setLen(outsize) + else: + raise newException(VarintError, "Error '" & $res & "'") + +proc decodeSVarint*(data: openarray[byte]): int {.inline.} = + ## Decode signed integer from array ``data`` and return it as result. + var outsize = 0 + let res = getSVarint(data, outsize, result) + if res != VarintStatus.Success: + raise newException(VarintError, "Error '" & $res & "'") + +proc decodeUVarint*(data: openarray[byte]): uint {.inline.} = + ## Decode unsigned integer from array ``data`` and return it as result. + var outsize = 0 + let res = getUVarint(data, outsize, result) + if res != VarintStatus.Success: + raise newException(VarintError, "Error '" & $res & "'") + +when isMainModule: + import unittest + + const edgeValues = [ + 0'u64, (1'u64 shl 7) - 1'u64, + (1'u64 shl 7), (1'u64 shl 14) - 1'u64, + (1'u64 shl 14), (1'u64 shl 21) - 1'u64, + (1'u64 shl 21), (1'u64 shl 28) - 1'u64, + (1'u64 shl 28), (1'u64 shl 35) - 1'u64, + (1'u64 shl 35), (1'u64 shl 42) - 1'u64, + (1'u64 shl 42), (1'u64 shl 49) - 1'u64, + (1'u64 shl 49), (1'u64 shl 56) - 1'u64, + (1'u64 shl 56), (1'u64 shl 63) - 1'u64, + (1'u64 shl 63), 0xFFFF_FFFF_FFFF_FFFF'u64 + ] + const edgeSizes = [ + 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10 + ] + + suite "Variable integer test suite": + + test "vsizeof() edge cases test": + for i in 0.. 5: + var value = 0'u32 + buffer.setLen(edgeSizes[i]) + check: + putUVarint(buffer, length, edgeValues[i]) == VarintStatus.Success + getUVarint(buffer, length, value) == VarintStatus.Overflow + + test "Integer Overflow 64bit test": + var buffer = newSeq[byte]() + var length = 0 + for i in 0.. 9: + var value = 0'u64 + buffer.setLen(edgeSizes[i] + 1) + check: + putUVarint(buffer, length, edgeValues[i]) == VarintStatus.Success + buffer[9] = buffer[9] or 0x80'u8 + buffer[10] = 0x01'u8 + check: + getUVarint(buffer, length, value) == VarintStatus.Overflow