mirror of
https://github.com/vacp2p/nim-libp2p.git
synced 2025-01-13 10:16:08 +00:00
Add daemon api sources.
This commit is contained in:
parent
e7e87763c3
commit
1211ffbb5c
@ -39,11 +39,10 @@ install:
|
||||
cd ../.. ;
|
||||
}"
|
||||
- "export PATH=$PWD/nim/$NIMVER/bin:$PATH"
|
||||
- echo $TRAVIS_GO_VERSION
|
||||
- go get -v github.com/libp2p/go-libp2p-daemon
|
||||
- cd $GOPATH/src/github.com/libp2p/go-libp2p-daemon
|
||||
- make
|
||||
|
||||
script:
|
||||
# - nimble install -y
|
||||
- nimble install -y
|
||||
# - nimble test
|
||||
|
11
libp2p.nim
Normal file
11
libp2p.nim
Normal file
@ -0,0 +1,11 @@
|
||||
## Nim-LibP2P
|
||||
## Copyright (c) 2018 Status Research & Development GmbH
|
||||
## Licensed under either of
|
||||
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
## at your option.
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
import libp2p/daemon/[daemonapi, transpool]
|
||||
import libp2p/protobuf/[minprotobuf, varint]
|
||||
export daemonapi, minprotobuf, varint, transpool
|
12
libp2p.nimble
Normal file
12
libp2p.nimble
Normal file
@ -0,0 +1,12 @@
|
||||
mode = ScriptMode.Verbose
|
||||
|
||||
packageName = "nim-libp2p"
|
||||
version = "0.0.1"
|
||||
author = "Status Research & Development GmbH"
|
||||
description = "LibP2P implementation"
|
||||
license = "MIT"
|
||||
skipDirs = @["tests", "Nim"]
|
||||
|
||||
requires "nim > 0.18.0"
|
||||
|
||||
# task tests, "Runs the test suite":
|
866
libp2p/daemon/daemonapi.nim
Normal file
866
libp2p/daemon/daemonapi.nim
Normal file
@ -0,0 +1,866 @@
|
||||
## Nim-LibP2P
|
||||
## Copyright (c) 2018 Status Research & Development GmbH
|
||||
## Licensed under either of
|
||||
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
## at your option.
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
## This module implementes API for `go-libp2p-daemon`.
|
||||
import os, osproc, strutils, tables, streams
|
||||
import asyncdispatch2
|
||||
import ../protobuf/varint, ../protobuf/minprotobuf, transpool
|
||||
|
||||
when not defined(windows):
|
||||
import posix
|
||||
|
||||
const
|
||||
DefaultSocketPath* = "/tmp/p2pd.sock"
|
||||
DefaultDaemonFile* = "p2pd"
|
||||
|
||||
type
|
||||
RequestType* {.pure.} = enum
|
||||
IDENTITY = 0,
|
||||
CONNECT = 1,
|
||||
STREAM_OPEN = 2,
|
||||
STREAM_HANDLER = 3,
|
||||
DHT = 4,
|
||||
LIST_PEERS = 5,
|
||||
CONNMANAGER = 6,
|
||||
DISCONNECT = 7
|
||||
|
||||
DHTRequestType* {.pure.} = enum
|
||||
FIND_PEER = 0,
|
||||
FIND_PEERS_CONNECTED_TO_PEER = 1,
|
||||
FIND_PROVIDERS = 2,
|
||||
GET_CLOSEST_PEERS = 3,
|
||||
GET_PUBLIC_KEY = 4,
|
||||
GET_VALUE = 5,
|
||||
SEARCH_VALUE = 6,
|
||||
PUT_VALUE = 7,
|
||||
PROVIDE = 8
|
||||
|
||||
ConnManagerRequestType* {.pure.} = enum
|
||||
TAG_PEER = 0,
|
||||
UNTAG_PEER = 1,
|
||||
TRIM = 2
|
||||
|
||||
ResponseKind* = enum
|
||||
Malformed,
|
||||
Error,
|
||||
Success
|
||||
|
||||
ResponseType* {.pure.} = enum
|
||||
ERROR = 2,
|
||||
STREAMINFO = 3,
|
||||
IDENTITY = 4,
|
||||
DHT = 5,
|
||||
PEERINFO = 6
|
||||
|
||||
DHTResponseType* {.pure.} = enum
|
||||
BEGIN = 0,
|
||||
VALUE = 1,
|
||||
END = 2
|
||||
|
||||
PeerID* = seq[byte]
|
||||
MultiProtocol* = string
|
||||
MultiAddress* = seq[byte]
|
||||
CID* = seq[byte]
|
||||
LibP2PPublicKey* = seq[byte]
|
||||
DHTValue* = seq[byte]
|
||||
|
||||
P2PStreamFlags* {.pure.} = enum
|
||||
None, Closed, Inbound, Outbound
|
||||
|
||||
P2PDaemonFlags* {.pure.} = enum
|
||||
DHTClient, DHTFull, Bootstrap
|
||||
|
||||
P2PStream* = ref object
|
||||
flags*: set[P2PStreamFlags]
|
||||
peer*: PeerID
|
||||
raddress*: MultiAddress
|
||||
protocol*: string
|
||||
transp*: StreamTransport
|
||||
|
||||
DaemonAPI* = ref object
|
||||
pool*: TransportPool
|
||||
flags*: set[P2PDaemonFlags]
|
||||
address*: TransportAddress
|
||||
sockname*: string
|
||||
pattern*: string
|
||||
ucounter*: int
|
||||
process*: Process
|
||||
handlers*: Table[string, P2PStreamCallback]
|
||||
servers*: seq[StreamServer]
|
||||
|
||||
PeerInfo* = object
|
||||
peer*: PeerID
|
||||
addresses: seq[MultiAddress]
|
||||
|
||||
P2PStreamCallback* = proc(api: DaemonAPI,
|
||||
stream: P2PStream): Future[void] {.gcsafe.}
|
||||
|
||||
DaemonRemoteError* = object of Exception
|
||||
DaemonLocalError* = object of Exception
|
||||
|
||||
proc requestIdentity(): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go
|
||||
## Processing function `doIdentify(req *pb.Request)`.
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
result.write(initProtoField(1, cast[uint](RequestType.IDENTITY)))
|
||||
result.finish()
|
||||
|
||||
proc requestConnect(peerid: PeerID,
|
||||
addresses: openarray[MultiAddress]): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go
|
||||
## Processing function `doConnect(req *pb.Request)`.
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, peerid))
|
||||
for item in addresses:
|
||||
msg.write(initProtoField(2, item))
|
||||
result.write(initProtoField(1, cast[uint](RequestType.CONNECT)))
|
||||
result.write(initProtoField(2, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestDisconnect(peerid: PeerID): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go
|
||||
## Processing function `doDisconnect(req *pb.Request)`.
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, peerid))
|
||||
result.write(initProtoField(1, cast[uint](RequestType.DISCONNECT)))
|
||||
result.write(initProtoField(7, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestStreamOpen(peerid: PeerID,
|
||||
protocols: openarray[string]): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go
|
||||
## Processing function `doStreamOpen(req *pb.Request)`.
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, peerid))
|
||||
for item in protocols:
|
||||
msg.write(initProtoField(2, item))
|
||||
result.write(initProtoField(1, cast[uint](RequestType.STREAM_OPEN)))
|
||||
result.write(initProtoField(3, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestStreamHandler(path: string,
|
||||
protocols: openarray[MultiProtocol]): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go
|
||||
## Processing function `doStreamHandler(req *pb.Request)`.
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, path))
|
||||
for item in protocols:
|
||||
msg.write(initProtoField(2, item))
|
||||
result.write(initProtoField(1, cast[uint](RequestType.STREAM_HANDLER)))
|
||||
result.write(initProtoField(4, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestListPeers(): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/conn.go
|
||||
## Processing function `doListPeers(req *pb.Request)`
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
result.write(initProtoField(1, cast[uint](RequestType.LIST_PEERS)))
|
||||
result.finish()
|
||||
|
||||
proc requestDHTFindPeer(peer: PeerID, timeout = 0): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
|
||||
## Processing function `doDHTFindPeer(req *pb.DHTRequest)`.
|
||||
let msgid = cast[uint](DHTRequestType.FIND_PEER)
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, msgid))
|
||||
msg.write(initProtoField(2, peer))
|
||||
if timeout > 0:
|
||||
msg.write(initProtoField(7, uint(timeout)))
|
||||
msg.finish()
|
||||
result.write(initProtoField(1, cast[uint](RequestType.DHT)))
|
||||
result.write(initProtoField(5, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestDHTFindPeersConnectedToPeer(peer: PeerID,
|
||||
timeout = 0): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
|
||||
## Processing function `doDHTFindPeersConnectedToPeer(req *pb.DHTRequest)`.
|
||||
let msgid = cast[uint](DHTRequestType.FIND_PEERS_CONNECTED_TO_PEER)
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, msgid))
|
||||
msg.write(initProtoField(2, peer))
|
||||
if timeout > 0:
|
||||
msg.write(initProtoField(7, uint(timeout)))
|
||||
msg.finish()
|
||||
result.write(initProtoField(1, cast[uint](RequestType.DHT)))
|
||||
result.write(initProtoField(5, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestDHTFindProviders(cid: CID,
|
||||
count: uint32, timeout = 0): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
|
||||
## Processing function `doDHTFindProviders(req *pb.DHTRequest)`.
|
||||
let msgid = cast[uint](DHTRequestType.FIND_PROVIDERS)
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, msgid))
|
||||
msg.write(initProtoField(3, cid))
|
||||
msg.write(initProtoField(6, count))
|
||||
if timeout > 0:
|
||||
msg.write(initProtoField(7, uint(timeout)))
|
||||
msg.finish()
|
||||
result.write(initProtoField(1, cast[uint](RequestType.DHT)))
|
||||
result.write(initProtoField(5, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestDHTGetClosestPeers(key: string, timeout = 0): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
|
||||
## Processing function `doDHTGetClosestPeers(req *pb.DHTRequest)`.
|
||||
let msgid = cast[uint](DHTRequestType.GET_CLOSEST_PEERS)
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, msgid))
|
||||
msg.write(initProtoField(4, key))
|
||||
if timeout > 0:
|
||||
msg.write(initProtoField(7, uint(timeout)))
|
||||
msg.finish()
|
||||
result.write(initProtoField(1, cast[uint](RequestType.DHT)))
|
||||
result.write(initProtoField(5, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestDHTGetPublicKey(peer: PeerID, timeout = 0): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
|
||||
## Processing function `doDHTGetPublicKey(req *pb.DHTRequest)`.
|
||||
let msgid = cast[uint](DHTRequestType.GET_PUBLIC_KEY)
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, msgid))
|
||||
msg.write(initProtoField(2, peer))
|
||||
if timeout > 0:
|
||||
msg.write(initProtoField(7, uint(timeout)))
|
||||
msg.finish()
|
||||
result.write(initProtoField(1, cast[uint](RequestType.DHT)))
|
||||
result.write(initProtoField(5, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestDHTGetValue(key: string, timeout = 0): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
|
||||
## Processing function `doDHTGetValue(req *pb.DHTRequest)`.
|
||||
let msgid = cast[uint](DHTRequestType.GET_VALUE)
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, msgid))
|
||||
msg.write(initProtoField(4, key))
|
||||
if timeout > 0:
|
||||
msg.write(initProtoField(7, uint(timeout)))
|
||||
msg.finish()
|
||||
result.write(initProtoField(1, cast[uint](RequestType.DHT)))
|
||||
result.write(initProtoField(5, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestDHTSearchValue(key: string, timeout = 0): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
|
||||
## Processing function `doDHTSearchValue(req *pb.DHTRequest)`.
|
||||
let msgid = cast[uint](DHTRequestType.SEARCH_VALUE)
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, msgid))
|
||||
msg.write(initProtoField(4, key))
|
||||
if timeout > 0:
|
||||
msg.write(initProtoField(7, uint(timeout)))
|
||||
msg.finish()
|
||||
result.write(initProtoField(1, cast[uint](RequestType.DHT)))
|
||||
result.write(initProtoField(5, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestDHTPutValue(key: string, value: openarray[byte],
|
||||
timeout = 0): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
|
||||
## Processing function `doDHTPutValue(req *pb.DHTRequest)`.
|
||||
let msgid = cast[uint](DHTRequestType.PUT_VALUE)
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, msgid))
|
||||
msg.write(initProtoField(4, key))
|
||||
msg.write(initProtoField(5, value))
|
||||
if timeout > 0:
|
||||
msg.write(initProtoField(7, uint(timeout)))
|
||||
msg.finish()
|
||||
result.write(initProtoField(1, cast[uint](RequestType.DHT)))
|
||||
result.write(initProtoField(5, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestDHTProvide(cid: CID, timeout = 0): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/dht.go
|
||||
## Processing function `doDHTProvide(req *pb.DHTRequest)`.
|
||||
let msgid = cast[uint](DHTRequestType.PROVIDE)
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, msgid))
|
||||
msg.write(initProtoField(3, cid))
|
||||
if timeout > 0:
|
||||
msg.write(initProtoField(7, uint(timeout)))
|
||||
msg.finish()
|
||||
result.write(initProtoField(1, cast[uint](RequestType.DHT)))
|
||||
result.write(initProtoField(5, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestCMTagPeer(peer: PeerID, tag: string, weight: int): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/connmgr.go#L18
|
||||
let msgid = cast[uint](ConnManagerRequestType.TAG_PEER)
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, msgid))
|
||||
msg.write(initProtoField(2, peer))
|
||||
msg.write(initProtoField(3, tag))
|
||||
msg.write(initProtoField(4, weight))
|
||||
msg.finish()
|
||||
result.write(initProtoField(1, cast[uint](RequestType.CONNMANAGER)))
|
||||
result.write(initProtoField(6, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestCMUntagPeer(peer: PeerID, tag: string): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/connmgr.go#L33
|
||||
let msgid = cast[uint](ConnManagerRequestType.UNTAG_PEER)
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, msgid))
|
||||
msg.write(initProtoField(2, peer))
|
||||
msg.write(initProtoField(3, tag))
|
||||
msg.finish()
|
||||
result.write(initProtoField(1, cast[uint](RequestType.CONNMANAGER)))
|
||||
result.write(initProtoField(6, msg))
|
||||
result.finish()
|
||||
|
||||
proc requestCMTrim(): ProtoBuffer =
|
||||
## https://github.com/libp2p/go-libp2p-daemon/blob/master/connmgr.go#L47
|
||||
let msgid = cast[uint](ConnManagerRequestType.TRIM)
|
||||
result = initProtoBuffer({WithVarintLength})
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, msgid))
|
||||
msg.finish()
|
||||
result.write(initProtoField(1, cast[uint](RequestType.CONNMANAGER)))
|
||||
result.write(initProtoField(6, msg))
|
||||
result.finish()
|
||||
|
||||
proc checkResponse(pb: var ProtoBuffer): ResponseKind {.inline.} =
|
||||
result = ResponseKind.Malformed
|
||||
var value: uint64
|
||||
if getVarintValue(pb, 1, value) > 0:
|
||||
if value == 0:
|
||||
result = ResponseKind.Success
|
||||
else:
|
||||
result = ResponseKind.Error
|
||||
|
||||
proc getErrorMessage(pb: var ProtoBuffer): string {.inline.} =
|
||||
if pb.enterSubmessage() == cast[int](ResponseType.ERROR):
|
||||
if pb.getString(1, result) == -1:
|
||||
raise newException(DaemonLocalError, "Error message is missing!")
|
||||
|
||||
proc recvMessage(conn: StreamTransport): Future[seq[byte]] {.async.} =
|
||||
var
|
||||
size: uint
|
||||
length: int
|
||||
res: VarintStatus
|
||||
var buffer = newSeq[byte](10)
|
||||
for i in 0..<len(buffer):
|
||||
await conn.readExactly(addr buffer[i], 1)
|
||||
res = getUVarint(buffer.toOpenArray(0, i), length, size)
|
||||
if res == VarintStatus.Success:
|
||||
break
|
||||
if res != VarintStatus.Success or size > MaxMessageSize:
|
||||
raise newException(ValueError, "Invalid message size")
|
||||
buffer.setLen(size)
|
||||
await conn.readExactly(addr buffer[0], int(size))
|
||||
result = buffer
|
||||
|
||||
proc socketExists(filename: string): bool =
|
||||
var res: Stat
|
||||
result = stat(filename, res) >= 0'i32
|
||||
|
||||
proc newDaemonApi*(flags: set[P2PDaemonFlags] = {},
|
||||
bootstrapNodes: seq[string] = @[],
|
||||
id: string = "",
|
||||
daemon = DefaultDaemonFile,
|
||||
sockpath = DefaultSocketPath,
|
||||
pattern = "/tmp/nim-p2pd-$1.sock",
|
||||
poolSize = 10): Future[DaemonAPI] {.async.} =
|
||||
## Initialize connections to `go-libp2p-daemon` control socket.
|
||||
result = new DaemonAPI
|
||||
result.flags = flags
|
||||
result.servers = newSeq[StreamServer]()
|
||||
result.address = initTAddress(sockpath)
|
||||
result.pattern = pattern
|
||||
result.ucounter = 1
|
||||
result.handlers = initTable[string, P2PStreamCallback]()
|
||||
# We will start daemon process only when control socket path is not default or
|
||||
# options are specified.
|
||||
if flags == {} and sockpath == DefaultSocketPath:
|
||||
result.pool = await newPool(initTAddress(sockpath), poolsize = poolSize)
|
||||
else:
|
||||
var args = newSeq[string]()
|
||||
# DHTFull and DHTClient could not be present at the same time
|
||||
if P2PDaemonFlags.DHTFull in flags and P2PDaemonFlags.DHTClient in flags:
|
||||
result.flags.excl(DHTClient)
|
||||
if P2PDaemonFlags.DHTFull in result.flags:
|
||||
args.add("-dht")
|
||||
if P2PDaemonFlags.DHTClient in result.flags:
|
||||
args.add("-dhtClient")
|
||||
if P2PDaemonFlags.Bootstrap in result.flags:
|
||||
args.add("-b")
|
||||
if len(bootstrapNodes) > 0:
|
||||
args.add("-bootstrapPeers=" & bootstrapNodes.join(","))
|
||||
if len(id) != 0:
|
||||
args.add("-id=" & id)
|
||||
if sockpath != DefaultSocketPath:
|
||||
args.add("-sock=" & sockpath)
|
||||
# We are trying to get absolute daemon path.
|
||||
let cmd = findExe(daemon)
|
||||
if len(cmd) == 0:
|
||||
raise newException(DaemonLocalError, "Could not find daemon executable!")
|
||||
# We will try to remove control socket file, because daemon will fail
|
||||
# if its not able to create new socket control file.
|
||||
# We can't use `existsFile()` because it do not support unix-domain socket
|
||||
# endpoints.
|
||||
if socketExists(sockpath):
|
||||
discard tryRemoveFile(sockpath)
|
||||
# Starting daemon process
|
||||
result.process = startProcess(cmd, "", args, options = {poStdErrToStdOut})
|
||||
# Waiting until daemon will not be bound to control socket.
|
||||
while true:
|
||||
if not result.process.running():
|
||||
echo result.process.errorStream.readAll()
|
||||
raise newException(DaemonLocalError,
|
||||
"Daemon executable could not be started!")
|
||||
if socketExists(sockpath):
|
||||
break
|
||||
await sleepAsync(100)
|
||||
result.sockname = sockpath
|
||||
result.pool = await newPool(initTAddress(sockpath), poolsize = poolSize)
|
||||
|
||||
proc close*(api: DaemonAPI, stream: P2PStream) {.async.} =
|
||||
## Close ``stream``.
|
||||
if P2PStreamFlags.Closed notin stream.flags:
|
||||
stream.transp.close()
|
||||
await stream.transp.join()
|
||||
stream.transp = nil
|
||||
stream.flags.incl(P2PStreamFlags.Closed)
|
||||
else:
|
||||
raise newException(DaemonLocalError, "Stream is already closed!")
|
||||
|
||||
proc close*(api: DaemonAPI) {.async.} =
|
||||
## Shutdown connections to `go-libp2p-daemon` control socket.
|
||||
await api.pool.close()
|
||||
# Closing all pending servers.
|
||||
if len(api.servers) > 0:
|
||||
var pending = newSeq[Future[void]]()
|
||||
for server in api.servers:
|
||||
server.stop()
|
||||
server.close()
|
||||
pending.add(server.join())
|
||||
await all(pending)
|
||||
# Closing daemon's process.
|
||||
if api.flags != {}:
|
||||
api.process.terminate()
|
||||
# Attempt to delete control socket endpoint.
|
||||
# if socketExists(api.sockname):
|
||||
# discard tryRemoveFile(api.sockname)
|
||||
|
||||
template withMessage(m, body: untyped): untyped =
|
||||
let kind = m.checkResponse()
|
||||
if kind == ResponseKind.Error:
|
||||
raise newException(DaemonRemoteError, m.getErrorMessage())
|
||||
elif kind == ResponseKind.Malformed:
|
||||
raise newException(DaemonLocalError, "Malformed message received!")
|
||||
else:
|
||||
body
|
||||
|
||||
proc transactMessage(transp: StreamTransport,
|
||||
pb: ProtoBuffer): Future[ProtoBuffer] {.async.} =
|
||||
let length = pb.getLen()
|
||||
let res = await transp.write(pb.getPtr(), length)
|
||||
if res != length:
|
||||
raise newException(DaemonLocalError, "Could not send message to daemon!")
|
||||
var message = await transp.recvMessage()
|
||||
result = initProtoBuffer(message)
|
||||
|
||||
proc getPeerInfo(pb: var ProtoBuffer): PeerInfo =
|
||||
## Get PeerInfo object from ``pb``.
|
||||
result.addresses = newSeq[MultiAddress]()
|
||||
result.peer = newSeq[byte]()
|
||||
if pb.getBytes(1, result.peer) == -1:
|
||||
raise newException(DaemonLocalError, "Missing required field `peer`!")
|
||||
var address = newSeq[byte]()
|
||||
while pb.getBytes(2, address) != -1:
|
||||
if len(address) != 0:
|
||||
result.addresses.add(address)
|
||||
address.setLen(0)
|
||||
|
||||
proc identity*(api: DaemonAPI): Future[PeerInfo] {.async.} =
|
||||
## Get Node identity information
|
||||
var transp = await api.pool.acquire()
|
||||
try:
|
||||
var pb = await transactMessage(transp, requestIdentity())
|
||||
pb.withMessage() do:
|
||||
let res = pb.enterSubmessage()
|
||||
if res == cast[int](ResponseType.IDENTITY):
|
||||
result = pb.getPeerInfo()
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc connect*(api: DaemonAPI, peer: PeerID,
|
||||
addresses: seq[MultiAddress]) {.async.} =
|
||||
## Connect to remote peer with id ``peer`` and addresses ``addresses``.
|
||||
var transp = await api.pool.acquire()
|
||||
try:
|
||||
var pb = await transp.transactMessage(requestConnect(peer, addresses))
|
||||
pb.withMessage() do:
|
||||
discard
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc disconnect*(api: DaemonAPI, peer: PeerID) {.async.} =
|
||||
## Disconnect from remote peer with id ``peer``.
|
||||
var transp = await api.pool.acquire()
|
||||
try:
|
||||
var pb = await transp.transactMessage(requestDisconnect(peer))
|
||||
pb.withMessage() do:
|
||||
discard
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc openStream*(api: DaemonAPI, peer: PeerID,
|
||||
protocols: seq[string]): Future[P2PStream] {.async.} =
|
||||
## Open new stream to peer ``peer`` using one of the protocols in
|
||||
## ``protocols``. Returns ``StreamTransport`` for the stream.
|
||||
var transp = await connect(api.address)
|
||||
var stream = new P2PStream
|
||||
try:
|
||||
var pb = await transp.transactMessage(requestStreamOpen(peer, protocols))
|
||||
pb.withMessage() do:
|
||||
var res = pb.enterSubmessage()
|
||||
if res == cast[int](ResponseType.STREAMINFO):
|
||||
stream.peer = newSeq[byte]()
|
||||
stream.raddress = newSeq[byte]()
|
||||
stream.protocol = ""
|
||||
if pb.getLengthValue(1, stream.peer) == -1:
|
||||
raise newException(DaemonLocalError, "Missing `peer` field!")
|
||||
if pb.getLengthValue(2, stream.raddress) == -1:
|
||||
raise newException(DaemonLocalError, "Missing `address` field!")
|
||||
if pb.getLengthValue(3, stream.protocol) == -1:
|
||||
raise newException(DaemonLocalError, "Missing `proto` field!")
|
||||
stream.flags.incl(Outbound)
|
||||
stream.transp = transp
|
||||
result = stream
|
||||
except:
|
||||
transp.close()
|
||||
await transp.join()
|
||||
raise getCurrentException()
|
||||
|
||||
proc streamHandler(server: StreamServer, transp: StreamTransport) {.async.} =
|
||||
var api = getUserData[DaemonAPI](server)
|
||||
var message = await transp.recvMessage()
|
||||
var pb = initProtoBuffer(message)
|
||||
var stream = new P2PStream
|
||||
stream.peer = newSeq[byte]()
|
||||
stream.raddress = newSeq[byte]()
|
||||
stream.protocol = ""
|
||||
if pb.getLengthValue(1, stream.peer) == -1:
|
||||
raise newException(DaemonLocalError, "Missing `peer` field!")
|
||||
if pb.getLengthValue(2, stream.raddress) == -1:
|
||||
raise newException(DaemonLocalError, "Missing `address` field!")
|
||||
if pb.getLengthValue(3, stream.protocol) == -1:
|
||||
raise newException(DaemonLocalError, "Missing `proto` field!")
|
||||
stream.flags.incl(Inbound)
|
||||
stream.transp = transp
|
||||
if len(stream.protocol) > 0:
|
||||
var handler = api.handlers.getOrDefault(stream.protocol)
|
||||
if not isNil(handler):
|
||||
asyncCheck handler(api, stream)
|
||||
|
||||
proc addHandler*(api: DaemonAPI, protocols: seq[string],
|
||||
handler: P2PStreamCallback) {.async.} =
|
||||
## Add stream handler ``handler`` for set of protocols ``protocols``.
|
||||
var transp = await api.pool.acquire()
|
||||
var sockname = api.pattern % [$api.ucounter]
|
||||
var localaddr = initTAddress(sockname)
|
||||
inc(api.ucounter)
|
||||
var server = createStreamServer(localaddr, streamHandler, udata = api)
|
||||
try:
|
||||
for item in protocols:
|
||||
api.handlers[item] = handler
|
||||
server.start()
|
||||
var pb = await transp.transactMessage(requestStreamHandler(sockname,
|
||||
protocols))
|
||||
pb.withMessage() do:
|
||||
api.servers.add(server)
|
||||
except:
|
||||
for item in protocols:
|
||||
api.handlers.del(item)
|
||||
server.stop()
|
||||
server.close()
|
||||
await server.join()
|
||||
raise getCurrentException()
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc listPeers*(api: DaemonAPI): Future[seq[PeerInfo]] {.async.} =
|
||||
## Get list of remote peers to which we are currently connected.
|
||||
var transp = await api.pool.acquire()
|
||||
try:
|
||||
var pb = await transp.transactMessage(requestListPeers())
|
||||
pb.withMessage() do:
|
||||
var address = newSeq[byte]()
|
||||
result = newSeq[PeerInfo]()
|
||||
var res = pb.enterSubmessage()
|
||||
while res != 0:
|
||||
if res == cast[int](ResponseType.PEERINFO):
|
||||
var peer = pb.getPeerInfo()
|
||||
result.add(peer)
|
||||
else:
|
||||
pb.skipSubmessage()
|
||||
res = pb.enterSubmessage()
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc cmTagPeer*(api: DaemonAPI, peer: PeerID, tag: string,
|
||||
weight: int) {.async.} =
|
||||
## Tag peer with id ``peer`` using ``tag`` and ``weight``.
|
||||
var transp = await api.pool.acquire()
|
||||
try:
|
||||
var pb = await transp.transactMessage(requestCMTagPeer(peer, tag, weight))
|
||||
withMessage(pb) do:
|
||||
discard
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc cmUntagPeer*(api: DaemonAPI, peer: PeerID, tag: string) {.async.} =
|
||||
## Remove tag ``tag`` from peer with id ``peer``.
|
||||
var transp = await api.pool.acquire()
|
||||
try:
|
||||
var pb = await transp.transactMessage(requestCMUntagPeer(peer, tag))
|
||||
withMessage(pb) do:
|
||||
discard
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc cmTrimPeers*(api: DaemonAPI) {.async.} =
|
||||
## Trim all connections.
|
||||
var transp = await api.pool.acquire()
|
||||
try:
|
||||
var pb = await transp.transactMessage(requestCMTrim())
|
||||
withMessage(pb) do:
|
||||
discard
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc dhtGetSinglePeerInfo(pb: var ProtoBuffer): PeerInfo =
|
||||
if pb.enterSubmessage() == 2:
|
||||
result = pb.getPeerInfo()
|
||||
else:
|
||||
raise newException(DaemonLocalError, "Missing required field `peer`!")
|
||||
|
||||
proc dhtGetSingleValue(pb: var ProtoBuffer): seq[byte] =
|
||||
result = newSeq[byte]()
|
||||
if pb.getLengthValue(3, result) == -1:
|
||||
raise newException(DaemonLocalError, "Missing field `value`!")
|
||||
|
||||
proc enterDhtMessage(pb: var ProtoBuffer, rt: DHTResponseType) {.inline.} =
|
||||
var dtype: uint
|
||||
var res = pb.enterSubmessage()
|
||||
if res == cast[int](ResponseType.DHT):
|
||||
if pb.getVarintValue(1, dtype) == 0:
|
||||
raise newException(DaemonLocalError, "Missing required DHT field `type`!")
|
||||
if dtype != cast[uint](rt):
|
||||
raise newException(DaemonLocalError, "Wrong DHT answer type! ")
|
||||
else:
|
||||
raise newException(DaemonLocalError, "Wrong message type!")
|
||||
|
||||
proc getDhtMessageType(pb: var ProtoBuffer): DHTResponseType {.inline.} =
|
||||
var dtype: uint
|
||||
if pb.getVarintValue(1, dtype) == 0:
|
||||
raise newException(DaemonLocalError, "Missing required DHT field `type`!")
|
||||
if dtype == cast[uint](DHTResponseType.VALUE):
|
||||
result = DHTResponseType.VALUE
|
||||
elif dtype == cast[uint](DHTResponseType.END):
|
||||
result = DHTResponseType.END
|
||||
else:
|
||||
raise newException(DaemonLocalError, "Wrong DHT answer type!")
|
||||
|
||||
proc dhtFindPeer*(api: DaemonAPI, peer: PeerID,
|
||||
timeout = 0): Future[PeerInfo] {.async.} =
|
||||
## Find peer with id ``peer`` and return peer information ``PeerInfo``.
|
||||
##
|
||||
## You can specify timeout for DHT request with ``timeout`` value. ``0`` value
|
||||
## means no timeout.
|
||||
var transp = await api.pool.acquire()
|
||||
try:
|
||||
var pb = await transp.transactMessage(requestDHTFindPeer(peer, timeout))
|
||||
withMessage(pb) do:
|
||||
pb.enterDhtMessage(DHTResponseType.VALUE)
|
||||
result = pb.dhtGetSinglePeerInfo()
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc dhtGetPublicKey*(api: DaemonAPI, peer: PeerID,
|
||||
timeout = 0): Future[LibP2PPublicKey] {.async.} =
|
||||
## Get peer's public key from peer with id ``peer``.
|
||||
##
|
||||
## You can specify timeout for DHT request with ``timeout`` value. ``0`` value
|
||||
## means no timeout.
|
||||
var transp = await api.pool.acquire()
|
||||
try:
|
||||
var pb = await transp.transactMessage(requestDHTGetPublicKey(peer, timeout))
|
||||
withMessage(pb) do:
|
||||
pb.enterDhtMessage(DHTResponseType.VALUE)
|
||||
result = pb.dhtGetSingleValue()
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc dhtGetValue*(api: DaemonAPI, key: string,
|
||||
timeout = 0): Future[seq[byte]] {.async.} =
|
||||
## Get value associated with ``key``.
|
||||
##
|
||||
## You can specify timeout for DHT request with ``timeout`` value. ``0`` value
|
||||
## means no timeout.
|
||||
var transp = await api.pool.acquire()
|
||||
try:
|
||||
var pb = await transp.transactMessage(requestDHTGetValue(key, timeout))
|
||||
withMessage(pb) do:
|
||||
pb.enterDhtMessage(DHTResponseType.VALUE)
|
||||
result = pb.dhtGetSingleValue()
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc dhtPutValue*(api: DaemonAPI, key: string, value: seq[byte],
|
||||
timeout = 0) {.async.} =
|
||||
## Associate ``value`` with ``key``.
|
||||
##
|
||||
## You can specify timeout for DHT request with ``timeout`` value. ``0`` value
|
||||
## means no timeout.
|
||||
var transp = await api.pool.acquire()
|
||||
try:
|
||||
var pb = await transp.transactMessage(requestDHTPutValue(key, value,
|
||||
timeout))
|
||||
withMessage(pb) do:
|
||||
discard
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc dhtProvide*(api: DaemonAPI, cid: CID, timeout = 0) {.async.} =
|
||||
## Provide content with id ``cid``.
|
||||
##
|
||||
## You can specify timeout for DHT request with ``timeout`` value. ``0`` value
|
||||
## means no timeout.
|
||||
var transp = await api.pool.acquire()
|
||||
try:
|
||||
var pb = await transp.transactMessage(requestDHTProvide(cid, timeout))
|
||||
withMessage(pb) do:
|
||||
discard
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc dhtFindPeersConnectedToPeer*(api: DaemonAPI, peer: PeerID,
|
||||
timeout = 0): Future[seq[PeerInfo]] {.async.} =
|
||||
## Find peers which are connected to peer with id ``peer``.
|
||||
##
|
||||
## You can specify timeout for DHT request with ``timeout`` value. ``0`` value
|
||||
## means no timeout.
|
||||
var transp = await api.pool.acquire()
|
||||
var list = newSeq[PeerInfo]()
|
||||
try:
|
||||
let spb = requestDHTFindPeersConnectedToPeer(peer, timeout)
|
||||
var pb = await transp.transactMessage(spb)
|
||||
withMessage(pb) do:
|
||||
pb.enterDhtMessage(DHTResponseType.BEGIN)
|
||||
while true:
|
||||
var message = await transp.recvMessage()
|
||||
var cpb = initProtoBuffer(message)
|
||||
if cpb.getDhtMessageType() == DHTResponseType.END:
|
||||
break
|
||||
list.add(cpb.dhtGetSinglePeerInfo())
|
||||
result = list
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc dhtGetClosestPeers*(api: DaemonAPI, key: string,
|
||||
timeout = 0): Future[seq[PeerID]] {.async.} =
|
||||
## Get closest peers for ``key``.
|
||||
##
|
||||
## You can specify timeout for DHT request with ``timeout`` value. ``0`` value
|
||||
## means no timeout.
|
||||
var transp = await api.pool.acquire()
|
||||
var list = newSeq[PeerID]()
|
||||
try:
|
||||
let spb = requestDHTGetClosestPeers(key, timeout)
|
||||
var pb = await transp.transactMessage(spb)
|
||||
withMessage(pb) do:
|
||||
pb.enterDhtMessage(DHTResponseType.BEGIN)
|
||||
while true:
|
||||
var message = await transp.recvMessage()
|
||||
var cpb = initProtoBuffer(message)
|
||||
if cpb.getDhtMessageType() == DHTResponseType.END:
|
||||
break
|
||||
list.add(cpb.dhtGetSingleValue())
|
||||
result = list
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc dhtFindProviders*(api: DaemonAPI, cid: CID, count: uint32,
|
||||
timeout = 0): Future[seq[PeerInfo]] {.async.} =
|
||||
## Get ``count`` providers for content with id ``cid``.
|
||||
##
|
||||
## You can specify timeout for DHT request with ``timeout`` value. ``0`` value
|
||||
## means no timeout.
|
||||
var transp = await api.pool.acquire()
|
||||
var list = newSeq[PeerInfo]()
|
||||
try:
|
||||
let spb = requestDHTFindProviders(cid, count, timeout)
|
||||
var pb = await transp.transactMessage(spb)
|
||||
withMessage(pb) do:
|
||||
pb.enterDhtMessage(DHTResponseType.BEGIN)
|
||||
while true:
|
||||
var message = await transp.recvMessage()
|
||||
var cpb = initProtoBuffer(message)
|
||||
if cpb.getDhtMessageType() == DHTResponseType.END:
|
||||
break
|
||||
list.add(cpb.dhtGetSinglePeerInfo())
|
||||
result = list
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
proc dhtSearchValue*(api: DaemonAPI, key: string,
|
||||
timeout = 0): Future[seq[seq[byte]]] {.async.} =
|
||||
## Search for value with ``key``, return list of values found.
|
||||
##
|
||||
## You can specify timeout for DHT request with ``timeout`` value. ``0`` value
|
||||
## means no timeout.
|
||||
var transp = await api.pool.acquire()
|
||||
var list = newSeq[seq[byte]]()
|
||||
try:
|
||||
var pb = await transp.transactMessage(requestDHTSearchValue(key, timeout))
|
||||
withMessage(pb) do:
|
||||
pb.enterDhtMessage(DHTResponseType.BEGIN)
|
||||
while true:
|
||||
var message = await transp.recvMessage()
|
||||
var cpb = initProtoBuffer(message)
|
||||
if cpb.getDhtMessageType() == DHTResponseType.END:
|
||||
break
|
||||
list.add(cpb.dhtGetSingleValue())
|
||||
result = list
|
||||
finally:
|
||||
api.pool.release(transp)
|
||||
|
||||
when isMainModule:
|
||||
proc test() {.async.} =
|
||||
var api1 = await newDaemonApi(sockpath = "/tmp/p2pd-1.sock")
|
||||
var api2 = await newDaemonApi(sockpath = "/tmp/p2pd-2.sock")
|
||||
echo await api1.identity()
|
||||
echo await api2.identity()
|
||||
await sleepAsync(1000)
|
||||
await api1.close()
|
||||
await api2.close()
|
||||
|
||||
waitFor test()
|
141
libp2p/daemon/transpool.nim
Normal file
141
libp2p/daemon/transpool.nim
Normal file
@ -0,0 +1,141 @@
|
||||
## Nim-Libp2p
|
||||
## Copyright (c) 2018 Status Research & Development GmbH
|
||||
## Licensed under either of
|
||||
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
## at your option.
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
## This module implements Pool of StreamTransport.
|
||||
import asyncdispatch2
|
||||
|
||||
const
|
||||
DefaultPoolSize* = 8
|
||||
## Default pool size
|
||||
|
||||
type
|
||||
ConnectionFlags = enum
|
||||
None, Busy
|
||||
|
||||
PoolItem = object
|
||||
transp*: StreamTransport
|
||||
flags*: set[ConnectionFlags]
|
||||
|
||||
PoolState = enum
|
||||
Connecting, Connected, Closing, Closed
|
||||
|
||||
TransportPool* = ref object
|
||||
## Transports pool object
|
||||
transports: seq[PoolItem]
|
||||
busyCount: int
|
||||
state: PoolState
|
||||
bufferSize: int
|
||||
event: AsyncEvent
|
||||
|
||||
TransportPoolError* = object of AsyncError
|
||||
|
||||
proc waitAll[T](futs: seq[Future[T]]): Future[void] =
|
||||
## Performs waiting for all Future[T].
|
||||
var counter = len(futs)
|
||||
var retFuture = newFuture[void]("connpool.waitAllConnections")
|
||||
proc cb(udata: pointer) =
|
||||
dec(counter)
|
||||
if counter == 0:
|
||||
retFuture.complete()
|
||||
for fut in futs:
|
||||
fut.addCallback(cb)
|
||||
return retFuture
|
||||
|
||||
proc newPool*(address: TransportAddress, poolsize: int = DefaultPoolSize,
|
||||
bufferSize = DefaultStreamBufferSize,
|
||||
): Future[TransportPool] {.async.} =
|
||||
## Establish pool of connections to address ``address`` with size
|
||||
## ``poolsize``.
|
||||
result = new TransportPool
|
||||
result.bufferSize = bufferSize
|
||||
result.transports = newSeq[PoolItem](poolsize)
|
||||
var conns = newSeq[Future[StreamTransport]](poolsize)
|
||||
result.state = Connecting
|
||||
for i in 0..<poolsize:
|
||||
conns[i] = connect(address, bufferSize)
|
||||
# Waiting for all connections to be established.
|
||||
await waitAll(conns)
|
||||
# Checking connections and preparing pool.
|
||||
for i in 0..<poolsize:
|
||||
if conns[i].failed:
|
||||
raise conns[i].error
|
||||
else:
|
||||
let transp = conns[i].read()
|
||||
let item = PoolItem(transp: transp)
|
||||
result.transports[i] = item
|
||||
# Setup available connections event
|
||||
result.event = newAsyncEvent()
|
||||
result.state = Connected
|
||||
|
||||
proc acquire*(pool: TransportPool): Future[StreamTransport] {.async.} =
|
||||
## Acquire non-busy connection from pool ``pool``.
|
||||
var transp: StreamTransport
|
||||
if pool.state in {Connected}:
|
||||
while true:
|
||||
if pool.busyCount < len(pool.transports):
|
||||
for conn in pool.transports.mitems():
|
||||
if Busy notin conn.flags:
|
||||
conn.flags.incl(Busy)
|
||||
inc(pool.busyCount)
|
||||
transp = conn.transp
|
||||
break
|
||||
else:
|
||||
await pool.event.wait()
|
||||
pool.event.clear()
|
||||
|
||||
if not isNil(transp):
|
||||
break
|
||||
else:
|
||||
raise newException(TransportPoolError, "Pool is not ready!")
|
||||
result = transp
|
||||
|
||||
proc release*(pool: TransportPool, transp: StreamTransport) =
|
||||
## Release connection ``transp`` back to pool ``pool``.
|
||||
if pool.state in {Connected, Closing}:
|
||||
var found = false
|
||||
for conn in pool.transports.mitems():
|
||||
if conn.transp == transp:
|
||||
conn.flags.excl(Busy)
|
||||
dec(pool.busyCount)
|
||||
pool.event.fire()
|
||||
found = true
|
||||
break
|
||||
if not found:
|
||||
raise newException(TransportPoolError, "Transport not bound to pool!")
|
||||
else:
|
||||
raise newException(TransportPoolError, "Pool is not ready!")
|
||||
|
||||
proc join*(pool: TransportPool) {.async.} =
|
||||
## Waiting for all connection to become available.
|
||||
if pool.state in {Connected, Closing}:
|
||||
while true:
|
||||
if pool.busyCount == 0:
|
||||
break
|
||||
else:
|
||||
await pool.event.wait()
|
||||
pool.event.clear()
|
||||
elif pool.state == Connecting:
|
||||
raise newException(TransportPoolError, "Pool is not ready!")
|
||||
|
||||
proc close*(pool: TransportPool) {.async.} =
|
||||
## Closes transports pool ``pool`` and release all resources.
|
||||
if pool.state == Connected:
|
||||
pool.state = Closing
|
||||
# Waiting for all transports to become available.
|
||||
await pool.join()
|
||||
# Closing all transports
|
||||
var pending = newSeq[Future[void]](len(pool.transports))
|
||||
for i in 0..<len(pool.transports):
|
||||
let transp = pool.transports[i].transp
|
||||
transp.close()
|
||||
pending[i] = transp.join()
|
||||
# Waiting for all transports to be closed
|
||||
await waitAll(pending)
|
||||
# Mark pool as `Closed`.
|
||||
pool.state = Closed
|
276
libp2p/protobuf/minprotobuf.nim
Normal file
276
libp2p/protobuf/minprotobuf.nim
Normal file
@ -0,0 +1,276 @@
|
||||
## Nim-Libp2p
|
||||
## Copyright (c) 2018 Status Research & Development GmbH
|
||||
## Licensed under either of
|
||||
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
## at your option.
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
## This module implements minimal Google's ProtoBuf primitives.
|
||||
import varint
|
||||
|
||||
const
|
||||
MaxMessageSize* = 1'u shl 22
|
||||
|
||||
type
|
||||
ProtoFieldKind* = enum
|
||||
## Protobuf's field types enum
|
||||
Varint, Fixed64, Length, StartGroup, EndGroup, Fixed32
|
||||
|
||||
ProtoFlags* = enum
|
||||
## Protobuf's encoding types
|
||||
WithVarintLength
|
||||
|
||||
ProtoBuffer* = object
|
||||
## Protobuf's message representation object
|
||||
options: set[ProtoFlags]
|
||||
buffer*: seq[byte]
|
||||
offset*: int
|
||||
length*: int
|
||||
|
||||
ProtoField* = object
|
||||
## Protobuf's message field representation object
|
||||
index: int
|
||||
case kind: ProtoFieldKind
|
||||
of Varint:
|
||||
vint*: uint64
|
||||
of Fixed64:
|
||||
vfloat64*: float64
|
||||
of Length:
|
||||
vbuffer*: seq[byte]
|
||||
of Fixed32:
|
||||
vfloat32*: float32
|
||||
of StartGroup, EndGroup:
|
||||
discard
|
||||
|
||||
template protoHeader*(index: int, wire: ProtoFieldKind): uint =
|
||||
## Get protobuf's field header integer for ``index`` and ``wire``.
|
||||
((uint(index) shl 3) or cast[uint](wire))
|
||||
|
||||
template protoHeader*(field: ProtoField): uint =
|
||||
## Get protobuf's field header integer for ``field``.
|
||||
((uint(field.index) shl 3) or cast[uint](field.kind))
|
||||
|
||||
template toOpenArray*(pb: ProtoBuffer): untyped =
|
||||
toOpenArray(pb.buffer, pb.offset, len(pb.buffer) - 1)
|
||||
|
||||
template isEmpty*(pb: ProtoBuffer): bool =
|
||||
len(pb.buffer) - pb.offset <= 0
|
||||
|
||||
template isEnough*(pb: ProtoBuffer, length: int): bool =
|
||||
len(pb.buffer) - pb.offset - length >= 0
|
||||
|
||||
template getPtr*(pb: ProtoBuffer): pointer =
|
||||
cast[pointer](unsafeAddr pb.buffer[pb.offset])
|
||||
|
||||
template getLen*(pb: ProtoBuffer): int =
|
||||
len(pb.buffer) - pb.offset
|
||||
|
||||
proc vsizeof*(field: ProtoField): int {.inline.} =
|
||||
## Returns number of bytes required to store protobuf's field ``field``.
|
||||
result = vsizeof(protoHeader(field))
|
||||
case field.kind
|
||||
of ProtoFieldKind.Varint:
|
||||
result += vsizeof(field.vint)
|
||||
of ProtoFieldKind.Fixed64:
|
||||
result += sizeof(field.vfloat64)
|
||||
of ProtoFieldKind.Fixed32:
|
||||
result += sizeof(field.vfloat32)
|
||||
of ProtoFieldKind.Length:
|
||||
result += vsizeof(uint(len(field.vbuffer))) + len(field.vbuffer)
|
||||
else:
|
||||
discard
|
||||
|
||||
proc initProtoField*(index: int, value: SomeVarint): ProtoField =
|
||||
## Initialize ProtoField with integer value.
|
||||
result = ProtoField(kind: Varint, index: index)
|
||||
when type(value) is uint64:
|
||||
result.vint = value
|
||||
else:
|
||||
result.vint = cast[uint64](value)
|
||||
|
||||
proc initProtoField*(index: int, value: openarray[byte]): ProtoField =
|
||||
## Initialize ProtoField with bytes array.
|
||||
result = ProtoField(kind: Length, index: index)
|
||||
if len(value) > 0:
|
||||
result.vbuffer = newSeq[byte](len(value))
|
||||
copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value))
|
||||
|
||||
proc initProtoField*(index: int, value: string): ProtoField =
|
||||
## Initialize ProtoField with string.
|
||||
result = ProtoField(kind: Length, index: index)
|
||||
if len(value) > 0:
|
||||
result.vbuffer = newSeq[byte](len(value))
|
||||
copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value))
|
||||
|
||||
proc initProtoField*(index: int, value: ProtoBuffer): ProtoField {.inline.} =
|
||||
## Initialize ProtoField with nested message stored in ``value``.
|
||||
##
|
||||
## Note: This procedure performs shallow copy of ``value`` sequence.
|
||||
result = ProtoField(kind: Length, index: index)
|
||||
if len(value.buffer) > 0:
|
||||
shallowCopy(result.vbuffer, value.buffer)
|
||||
|
||||
proc initProtoBuffer*(data: seq[byte], offset = 0,
|
||||
options: set[ProtoFlags] = {}): ProtoBuffer =
|
||||
## Initialize ProtoBuffer with shallow copy of ``data``.
|
||||
shallowCopy(result.buffer, data)
|
||||
result.offset = offset
|
||||
result.options = options
|
||||
|
||||
proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer =
|
||||
## Initialize ProtoBuffer with new sequence of capacity ``cap``.
|
||||
result.buffer = newSeqOfCap[byte](128)
|
||||
result.options = options
|
||||
if WithVarintLength in options:
|
||||
# Our buffer will start from position 10, so we can store length of buffer
|
||||
# in [0, 9].
|
||||
result.buffer.setLen(10)
|
||||
result.offset = 10
|
||||
|
||||
proc write*(pb: var ProtoBuffer, field: ProtoField) =
|
||||
## Encode protobuf's field ``field`` and store it to protobuf's buffer ``pb``.
|
||||
var length = 0
|
||||
var res: VarintStatus
|
||||
pb.buffer.setLen(len(pb.buffer) + vsizeof(field))
|
||||
res = putUVarint(pb.toOpenArray(), length, protoHeader(field))
|
||||
assert(res == VarintStatus.Success)
|
||||
pb.offset += length
|
||||
case field.kind
|
||||
of ProtoFieldKind.Varint:
|
||||
res = putUVarint(pb.toOpenArray(), length, field.vint)
|
||||
assert(res == VarintStatus.Success)
|
||||
pb.offset += length
|
||||
of ProtoFieldKind.Fixed64:
|
||||
assert(pb.isEnough(8))
|
||||
var value = cast[uint64](field.vfloat64)
|
||||
pb.buffer[pb.offset] = byte(value and 0xFF'u32)
|
||||
pb.buffer[pb.offset + 1] = byte((value shr 8) and 0xFF'u32)
|
||||
pb.buffer[pb.offset + 2] = byte((value shr 16) and 0xFF'u32)
|
||||
pb.buffer[pb.offset + 3] = byte((value shr 24) and 0xFF'u32)
|
||||
pb.buffer[pb.offset + 4] = byte((value shr 32) and 0xFF'u32)
|
||||
pb.buffer[pb.offset + 5] = byte((value shr 40) and 0xFF'u32)
|
||||
pb.buffer[pb.offset + 6] = byte((value shr 48) and 0xFF'u32)
|
||||
pb.buffer[pb.offset + 7] = byte((value shr 56) and 0xFF'u32)
|
||||
pb.offset += 8
|
||||
of ProtoFieldKind.Fixed32:
|
||||
assert(pb.isEnough(4))
|
||||
var value = cast[uint32](field.vfloat32)
|
||||
pb.buffer[pb.offset] = byte(value and 0xFF'u32)
|
||||
pb.buffer[pb.offset + 1] = byte((value shr 8) and 0xFF'u32)
|
||||
pb.buffer[pb.offset + 2] = byte((value shr 16) and 0xFF'u32)
|
||||
pb.buffer[pb.offset + 3] = byte((value shr 24) and 0xFF'u32)
|
||||
pb.offset += 4
|
||||
of ProtoFieldKind.Length:
|
||||
res = putUVarint(pb.toOpenArray(), length, uint(len(field.vbuffer)))
|
||||
assert(res == VarintStatus.Success)
|
||||
pb.offset += length
|
||||
assert(pb.isEnough(len(field.vbuffer)))
|
||||
copyMem(addr pb.buffer[pb.offset], unsafeAddr field.vbuffer[0],
|
||||
len(field.vbuffer))
|
||||
pb.offset += len(field.vbuffer)
|
||||
else:
|
||||
discard
|
||||
|
||||
proc finish*(pb: var ProtoBuffer) =
|
||||
## Prepare protobuf's buffer ``pb`` for writing to stream.
|
||||
var length = 0
|
||||
assert(len(pb.buffer) > 0)
|
||||
if WithVarintLength in pb.options:
|
||||
let size = uint(len(pb.buffer) - 10)
|
||||
let pos = 10 - vsizeof(length)
|
||||
let res = putUVarint(pb.buffer.toOpenArray(pos, 9), length, size)
|
||||
assert(res == VarintStatus.Success)
|
||||
pb.offset = pos
|
||||
else:
|
||||
pb.offset = 0
|
||||
|
||||
proc getVarintValue*(data: var ProtoBuffer, field: int,
|
||||
value: var SomeVarint): int =
|
||||
## Get value of `Varint` type.
|
||||
var length = 0
|
||||
var header = 0'u64
|
||||
var soffset = data.offset
|
||||
|
||||
if not data.isEmpty() and
|
||||
getUVarint(data.toOpenArray(), length, header) == VarintStatus.Success:
|
||||
data.offset += length
|
||||
if header == protoHeader(field, Varint):
|
||||
if not data.isEmpty():
|
||||
when type(value) is int32 or type(value) is int64 or type(value) is int:
|
||||
let res = getSVarint(data.toOpenArray(), length, value)
|
||||
else:
|
||||
let res = getUVarint(data.toOpenArray(), length, value)
|
||||
data.offset += length
|
||||
result = length
|
||||
return
|
||||
# Restore offset on error
|
||||
data.offset = soffset
|
||||
|
||||
proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
|
||||
buffer: var T): int =
|
||||
## Get value of `Length` type.
|
||||
var length = 0
|
||||
var header = 0'u64
|
||||
var ssize = 0'u64
|
||||
var soffset = data.offset
|
||||
result = -1
|
||||
buffer.setLen(0)
|
||||
if not data.isEmpty() and
|
||||
getUVarint(data.toOpenArray(), length, header) == VarintStatus.Success:
|
||||
data.offset += length
|
||||
if header == protoHeader(field, Length):
|
||||
if not data.isEmpty() and
|
||||
getUVarint(data.toOpenArray(), length, ssize) == VarintStatus.Success:
|
||||
data.offset += length
|
||||
if ssize <= MaxMessageSize and data.isEnough(int(ssize)):
|
||||
buffer.setLen(ssize)
|
||||
# Protobuf allow zero-length values.
|
||||
if ssize > 0'u64:
|
||||
copyMem(addr buffer[0], addr data.buffer[data.offset], ssize)
|
||||
result = int(ssize)
|
||||
data.offset += int(ssize)
|
||||
return
|
||||
# Restore offset on error
|
||||
data.offset = soffset
|
||||
|
||||
proc getBytes*(data: var ProtoBuffer, field: int,
|
||||
buffer: var seq[byte]): int {.inline.} =
|
||||
## Get value of `Length` type as bytes.
|
||||
result = getLengthValue(data, field, buffer)
|
||||
|
||||
proc getString*(data: var ProtoBuffer, field: int,
|
||||
buffer: var string): int {.inline.} =
|
||||
## Get value of `Length` type as string.
|
||||
result = getLengthValue(data, field, buffer)
|
||||
|
||||
proc enterSubmessage*(pb: var ProtoBuffer): int =
|
||||
## Processes protobuf's sub-message and adjust internal offset to enter
|
||||
## inside of sub-message. Returns field index of sub-message field or
|
||||
## ``0`` on error.
|
||||
var length = 0
|
||||
var header = 0'u64
|
||||
var msize = 0'u64
|
||||
var soffset = pb.offset
|
||||
|
||||
if not pb.isEmpty() and
|
||||
getUVarint(pb.toOpenArray(), length, header) == VarintStatus.Success:
|
||||
pb.offset += length
|
||||
if (header and 0x07'u64) == cast[uint64](ProtoFieldKind.Length):
|
||||
if not pb.isEmpty() and
|
||||
getUVarint(pb.toOpenArray(), length, msize) == VarintStatus.Success:
|
||||
pb.offset += length
|
||||
if msize <= MaxMessageSize and pb.isEnough(int(msize)):
|
||||
pb.length = int(msize)
|
||||
result = int(header shr 3)
|
||||
return
|
||||
# Restore offset on error
|
||||
pb.offset = soffset
|
||||
|
||||
proc skipSubmessage*(pb: var ProtoBuffer) =
|
||||
## Skip current protobuf's sub-message and adjust internal offset to the
|
||||
## end of sub-message.
|
||||
assert(pb.length != 0)
|
||||
pb.offset += pb.length
|
||||
pb.length = 0
|
278
libp2p/protobuf/varint.nim
Normal file
278
libp2p/protobuf/varint.nim
Normal file
@ -0,0 +1,278 @@
|
||||
## Nim-Libp2p
|
||||
## Copyright (c) 2018 Status Research & Development GmbH
|
||||
## Licensed under either of
|
||||
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
## at your option.
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
## This module implements Google ProtoBuf's variable integer `VARINT`.
|
||||
import bitops
|
||||
|
||||
type
|
||||
VarintStatus* = enum
|
||||
Error,
|
||||
Success,
|
||||
Overflow,
|
||||
Incomplete,
|
||||
Overrun
|
||||
|
||||
SomeUVarint* = uint | uint64 | uint32
|
||||
SomeSVarint* = int | int64 | int32
|
||||
SomeVarint* = SomeUVarint | SomeSVarint
|
||||
VarintError* = object of Exception
|
||||
|
||||
proc vsizeof*(x: SomeUVarint|SomeSVarint): int {.inline.} =
|
||||
## Returns number of bytes required to encode integer ``x`` as varint.
|
||||
if x == cast[type(x)](0):
|
||||
result = 1
|
||||
else:
|
||||
result = (fastLog2(x) + 1 + 7 - 1) div 7
|
||||
|
||||
proc getUVarint*(pbytes: openarray[byte], outlen: var int,
|
||||
outval: var SomeUVarint): VarintStatus =
|
||||
## Decode `unsigned varint` from buffer ``pbytes`` and store it to ``outval``.
|
||||
## On success ``outlen`` will be set to number of bytes processed while
|
||||
## decoding `unsigned varint`.
|
||||
##
|
||||
## If array ``pbytes`` is empty, ``Incomplete`` error will be returned.
|
||||
##
|
||||
## If there not enough bytes available in array ``pbytes`` to decode `unsigned
|
||||
## varint`, ``Incomplete`` error will be returned.
|
||||
##
|
||||
## If encoded value can produce integer overflow, ``Overflow`` error will be
|
||||
## returned.
|
||||
##
|
||||
## Note, when decoding 10th byte of 64bit integer only 1 bit from byte will be
|
||||
## decoded, all other bits will be ignored. When decoding 5th byte of 32bit
|
||||
## integer only 4 bits from byte will be decoded, all other bits will be
|
||||
## ignored.
|
||||
const MaxBits = byte(sizeof(outval) * 8)
|
||||
var shift = 0'u8
|
||||
result = VarintStatus.Incomplete
|
||||
outlen = 0
|
||||
outval = cast[type(outval)](0)
|
||||
for i in 0..<len(pbytes):
|
||||
let b = pbytes[i]
|
||||
if shift >= MaxBits:
|
||||
result = VarintStatus.Overflow
|
||||
outlen = 0
|
||||
outval = cast[type(outval)](0)
|
||||
break
|
||||
else:
|
||||
outval = outval or (cast[type(outval)](b and 0x7F'u8) shl shift)
|
||||
shift += 7
|
||||
inc(outlen)
|
||||
if (b and 0x80'u8) == 0'u8:
|
||||
result = VarintStatus.Success
|
||||
break
|
||||
|
||||
if result == VarintStatus.Incomplete:
|
||||
outlen = 0
|
||||
outval = cast[type(outval)](0)
|
||||
|
||||
proc putUVarint*(pbytes: var openarray[byte], outlen: var int,
|
||||
outval: SomeUVarint): VarintStatus =
|
||||
## Encode `unsigned varint` ``outval`` and store it to array ``pbytes``.
|
||||
##
|
||||
## On success ``outlen`` will hold number of bytes (octets) used to encode
|
||||
## unsigned integer ``v``.
|
||||
##
|
||||
## If there not enough bytes available in buffer ``pbytes``, ``Incomplete``
|
||||
## error will be returned and ``outlen`` will be set to number of bytes
|
||||
## required.
|
||||
##
|
||||
## Maximum encoded length of 64bit integer is 10 octets.
|
||||
## Maximum encoded length of 32bit integer is 5 octets.
|
||||
var buffer: array[10, byte]
|
||||
var value = outval
|
||||
var k = 0
|
||||
|
||||
if value <= cast[type(outval)](0x7F):
|
||||
buffer[0] = cast[byte](outval and 0xFF)
|
||||
inc(k)
|
||||
else:
|
||||
while value != cast[type(outval)](0):
|
||||
buffer[k] = cast[byte]((value and 0x7F) or 0x80)
|
||||
value = value shr 7
|
||||
inc(k)
|
||||
buffer[k - 1] = buffer[k - 1] and 0x7F'u8
|
||||
|
||||
outlen = k
|
||||
if len(pbytes) >= k:
|
||||
copyMem(addr pbytes[0], addr buffer[0], k)
|
||||
result = VarintStatus.Success
|
||||
else:
|
||||
result = VarintStatus.Overrun
|
||||
|
||||
proc getSVarint*(pbytes: openarray[byte], outsize: var int,
|
||||
outval: var SomeSVarint): VarintStatus {.inline.} =
|
||||
## Decode `signed varint` from buffer ``pbytes`` and store it to ``outval``.
|
||||
## On success ``outlen`` will be set to number of bytes processed while
|
||||
## decoding `signed varint`.
|
||||
##
|
||||
## If array ``pbytes`` is empty, ``Incomplete`` error will be returned.
|
||||
##
|
||||
## If there not enough bytes available in array ``pbytes`` to decode `signed
|
||||
## varint`, ``Incomplete`` error will be returned.
|
||||
##
|
||||
## If encoded value can produce integer overflow, ``Overflow`` error will be
|
||||
## returned.
|
||||
##
|
||||
## Note, when decoding 10th byte of 64bit integer only 1 bit from byte will be
|
||||
## decoded, all other bits will be ignored. When decoding 5th byte of 32bit
|
||||
## integer only 4 bits from byte will be decoded, all other bits will be
|
||||
## ignored.
|
||||
when sizeof(outval) == 8:
|
||||
var value: uint64
|
||||
else:
|
||||
var value: uint32
|
||||
|
||||
result = getUVarint(pbytes, outsize, value)
|
||||
if result == VarintStatus.Success:
|
||||
if (value and cast[type(value)](1)) != cast[type(value)](0):
|
||||
outval = cast[type(outval)](not(value shr 1))
|
||||
else:
|
||||
outval = cast[type(outval)](value shr 1)
|
||||
|
||||
proc putSVarint*(pbytes: var openarray[byte], outsize: var int,
|
||||
outval: SomeSVarint): VarintStatus {.inline.} =
|
||||
## Encode `signed varint` ``outval`` and store it to array ``pbytes``.
|
||||
##
|
||||
## On success ``outlen`` will hold number of bytes (octets) used to encode
|
||||
## unsigned integer ``v``.
|
||||
##
|
||||
## If there not enough bytes available in buffer ``pbytes``, ``Incomplete``
|
||||
## error will be returned and ``outlen`` will be set to number of bytes
|
||||
## required.
|
||||
##
|
||||
## Maximum encoded length of 64bit integer is 10 octets.
|
||||
## Maximum encoded length of 32bit integer is 5 octets.
|
||||
when sizeof(outval) == 8:
|
||||
var value: uint64 =
|
||||
if outval < 0:
|
||||
not(cast[uint64](outval) shl 1)
|
||||
else:
|
||||
cast[uint64](outval) shl 1
|
||||
else:
|
||||
var value: uint32 =
|
||||
if outval < 0:
|
||||
not(cast[uint32](outval) shl 1)
|
||||
else:
|
||||
cast[uint32](outval) shl 1
|
||||
result = putUVarint(pbytes, outsize, value)
|
||||
|
||||
proc encodeVarint*(value: SomeUVarint|SomeSVarint): seq[byte] {.inline.} =
|
||||
## Encode integer to `signed/unsigned varint` and returns sequence of bytes
|
||||
## as result.
|
||||
var outsize = 0
|
||||
result = newSeqOfCap[byte](10)
|
||||
when sizeof(value) == 4:
|
||||
result.setLen(5)
|
||||
else:
|
||||
result.setLen(10)
|
||||
when type(value) is SomeSVarint:
|
||||
let res = putSVarint(result, outsize, value)
|
||||
else:
|
||||
let res = putUVarint(result, outsize, value)
|
||||
if res == VarintStatus.Success:
|
||||
result.setLen(outsize)
|
||||
else:
|
||||
raise newException(VarintError, "Error '" & $res & "'")
|
||||
|
||||
proc decodeSVarint*(data: openarray[byte]): int {.inline.} =
|
||||
## Decode signed integer from array ``data`` and return it as result.
|
||||
var outsize = 0
|
||||
let res = getSVarint(data, outsize, result)
|
||||
if res != VarintStatus.Success:
|
||||
raise newException(VarintError, "Error '" & $res & "'")
|
||||
|
||||
proc decodeUVarint*(data: openarray[byte]): uint {.inline.} =
|
||||
## Decode unsigned integer from array ``data`` and return it as result.
|
||||
var outsize = 0
|
||||
let res = getUVarint(data, outsize, result)
|
||||
if res != VarintStatus.Success:
|
||||
raise newException(VarintError, "Error '" & $res & "'")
|
||||
|
||||
when isMainModule:
|
||||
import unittest
|
||||
|
||||
const edgeValues = [
|
||||
0'u64, (1'u64 shl 7) - 1'u64,
|
||||
(1'u64 shl 7), (1'u64 shl 14) - 1'u64,
|
||||
(1'u64 shl 14), (1'u64 shl 21) - 1'u64,
|
||||
(1'u64 shl 21), (1'u64 shl 28) - 1'u64,
|
||||
(1'u64 shl 28), (1'u64 shl 35) - 1'u64,
|
||||
(1'u64 shl 35), (1'u64 shl 42) - 1'u64,
|
||||
(1'u64 shl 42), (1'u64 shl 49) - 1'u64,
|
||||
(1'u64 shl 49), (1'u64 shl 56) - 1'u64,
|
||||
(1'u64 shl 56), (1'u64 shl 63) - 1'u64,
|
||||
(1'u64 shl 63), 0xFFFF_FFFF_FFFF_FFFF'u64
|
||||
]
|
||||
const edgeSizes = [
|
||||
1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10
|
||||
]
|
||||
|
||||
suite "Variable integer test suite":
|
||||
|
||||
test "vsizeof() edge cases test":
|
||||
for i in 0..<len(edgeValues):
|
||||
check vsizeof(edgeValues[i]) == edgeSizes[i]
|
||||
|
||||
test "Success edge cases test":
|
||||
var buffer = newSeq[byte]()
|
||||
var length = 0
|
||||
var value = 0'u64
|
||||
for i in 0..<len(edgeValues):
|
||||
buffer.setLen(edgeSizes[i])
|
||||
check:
|
||||
putUVarint(buffer, length, edgeValues[i]) == VarintStatus.Success
|
||||
getUVarint(buffer, length, value) == VarintStatus.Success
|
||||
value == edgeValues[i]
|
||||
|
||||
test "Buffer Overrun edge cases test":
|
||||
var buffer = newSeq[byte]()
|
||||
var length = 0
|
||||
for i in 0..<len(edgeValues):
|
||||
buffer.setLen(edgeSizes[i] - 1)
|
||||
let res = putUVarint(buffer, length, edgeValues[i])
|
||||
check:
|
||||
res == VarintStatus.Overrun
|
||||
length == edgeSizes[i]
|
||||
|
||||
test "Buffer Incomplete edge cases test":
|
||||
var buffer = newSeq[byte]()
|
||||
var length = 0
|
||||
var value = 0'u64
|
||||
for i in 0..<len(edgeValues):
|
||||
buffer.setLen(edgeSizes[i])
|
||||
check putUVarint(buffer, length, edgeValues[i]) == VarintStatus.Success
|
||||
buffer.setLen(len(buffer) - 1)
|
||||
check:
|
||||
getUVarint(buffer, length, value) == VarintStatus.Incomplete
|
||||
|
||||
test "Integer Overflow 32bit test":
|
||||
var buffer = newSeq[byte]()
|
||||
var length = 0
|
||||
for i in 0..<len(edgeValues):
|
||||
if edgeSizes[i] > 5:
|
||||
var value = 0'u32
|
||||
buffer.setLen(edgeSizes[i])
|
||||
check:
|
||||
putUVarint(buffer, length, edgeValues[i]) == VarintStatus.Success
|
||||
getUVarint(buffer, length, value) == VarintStatus.Overflow
|
||||
|
||||
test "Integer Overflow 64bit test":
|
||||
var buffer = newSeq[byte]()
|
||||
var length = 0
|
||||
for i in 0..<len(edgeValues):
|
||||
if edgeSizes[i] > 9:
|
||||
var value = 0'u64
|
||||
buffer.setLen(edgeSizes[i] + 1)
|
||||
check:
|
||||
putUVarint(buffer, length, edgeValues[i]) == VarintStatus.Success
|
||||
buffer[9] = buffer[9] or 0x80'u8
|
||||
buffer[10] = 0x01'u8
|
||||
check:
|
||||
getUVarint(buffer, length, value) == VarintStatus.Overflow
|
Loading…
x
Reference in New Issue
Block a user