mirror of https://github.com/status-im/nim-eth.git
Add implementation of Portal wire protocol
This commit is contained in:
parent
d18ebaa570
commit
e2e30247bf
|
@ -77,6 +77,7 @@ build_script:
|
|||
test_script:
|
||||
- nimble test
|
||||
- nimble build_dcli
|
||||
- nimble build_portalcli
|
||||
|
||||
deploy: off
|
||||
|
||||
|
|
|
@ -237,4 +237,4 @@ jobs:
|
|||
nimble install -y --depsOnly
|
||||
nimble test
|
||||
nimble build_dcli
|
||||
|
||||
nimble build_portalcli
|
||||
|
|
|
@ -47,3 +47,4 @@ script:
|
|||
- nimble install -y --depsOnly
|
||||
- nimble test
|
||||
- nimble build_dcli
|
||||
- nimble build_portalcli
|
||||
|
|
|
@ -46,6 +46,9 @@ task test_discv5, "Run discovery v5 tests":
|
|||
task test_discv4, "Run discovery v4 tests":
|
||||
runTest("tests/p2p/test_discovery")
|
||||
|
||||
task test_portal, "Run Portal network tests":
|
||||
runTest("tests/p2p/all_portal_tests")
|
||||
|
||||
task test_p2p, "Run p2p tests":
|
||||
runTest("tests/p2p/all_tests")
|
||||
|
||||
|
@ -86,3 +89,6 @@ task test_discv5_full, "Run discovery v5 and its dependencies tests":
|
|||
|
||||
task build_dcli, "Build dcli":
|
||||
buildBinary("eth/p2p/discoveryv5/dcli")
|
||||
|
||||
task build_portalcli, "Build portalcli":
|
||||
buildBinary("eth/p2p/portal/portalcli")
|
||||
|
|
|
@ -132,16 +132,20 @@ type
|
|||
bootstrapRecords*: seq[Record]
|
||||
ipVote: IpVote
|
||||
enrAutoUpdate: bool
|
||||
talkProtocols: Table[seq[byte], TalkProtocolHandler]
|
||||
talkProtocols*: Table[seq[byte], TalkProtocol] # TODO: Table is a bit of
|
||||
# overkill here, use sequence
|
||||
rng*: ref BrHmacDrbgContext
|
||||
|
||||
PendingRequest = object
|
||||
node: Node
|
||||
message: seq[byte]
|
||||
|
||||
TalkProtocolHandler* = proc(request: seq[byte]): seq[byte]
|
||||
TalkProtocolHandler* = proc(p: TalkProtocol, request: seq[byte]): seq[byte]
|
||||
{.gcsafe, raises: [Defect].}
|
||||
|
||||
TalkProtocol* = ref object of RootObj
|
||||
protocolHandler*: TalkProtocolHandler
|
||||
|
||||
DiscResult*[T] = Result[T, cstring]
|
||||
|
||||
proc addNode*(d: Protocol, node: Node): bool =
|
||||
|
@ -299,15 +303,16 @@ proc handleFindNode(d: Protocol, fromId: NodeId, fromAddr: Address,
|
|||
|
||||
proc handleTalkReq(d: Protocol, fromId: NodeId, fromAddr: Address,
|
||||
talkreq: TalkReqMessage, reqId: RequestId) =
|
||||
let protocolHandler = d.talkProtocols.getOrDefault(talkreq.protocol)
|
||||
let talkProtocol = d.talkProtocols.getOrDefault(talkreq.protocol)
|
||||
|
||||
let talkresp =
|
||||
if protocolHandler.isNil():
|
||||
if talkProtocol.isNil() or talkProtocol.protocolHandler.isNil():
|
||||
# Protocol identifier that is not registered and thus not supported. An
|
||||
# empty response is send as per specification.
|
||||
TalkRespMessage(response: @[])
|
||||
else:
|
||||
TalkRespMessage(response: protocolHandler(talkreq.request))
|
||||
TalkRespMessage(response: talkProtocol.protocolHandler(talkProtocol,
|
||||
talkreq.request))
|
||||
let (data, _) = encodeMessagePacket(d.rng[], d.codec, fromId, fromAddr,
|
||||
encodeMessage(talkresp, reqId))
|
||||
|
||||
|
@ -341,10 +346,10 @@ proc handleMessage(d: Protocol, srcId: NodeId, fromAddr: Address,
|
|||
trace "Timed out or unrequested message", kind = message.kind,
|
||||
origin = fromAddr
|
||||
|
||||
proc registerTalkProtocol*(d: Protocol, protocol: seq[byte],
|
||||
handler: TalkProtocolHandler): DiscResult[void] =
|
||||
proc registerTalkProtocol*(d: Protocol, protocolId: seq[byte],
|
||||
protocol: TalkProtocol): DiscResult[void] =
|
||||
# Currently allow only for one handler per talk protocol.
|
||||
if d.talkProtocols.hasKeyOrPut(protocol, handler):
|
||||
if d.talkProtocols.hasKeyOrPut(protocolId, protocol):
|
||||
err("Protocol identifier already registered")
|
||||
else:
|
||||
ok()
|
||||
|
|
|
@ -0,0 +1,153 @@
|
|||
# nim-eth - Portal Network- Message types
|
||||
# Copyright (c) 2021 Status Research & Development GmbH
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
# As per spec:
|
||||
# https://github.com/ethereum/stateless-ethereum-specs/blob/master/state-network.md#wire-protocol
|
||||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import
|
||||
stint, stew/[results, objects],
|
||||
../../ssz/ssz_serialization
|
||||
|
||||
export ssz_serialization, stint
|
||||
|
||||
type
|
||||
ByteList* = List[byte, 2048]
|
||||
|
||||
MessageKind* = enum
|
||||
unused = 0x00
|
||||
|
||||
ping = 0x01
|
||||
pong = 0x02
|
||||
findnode = 0x03
|
||||
nodes = 0x04
|
||||
findcontent = 0x05
|
||||
foundcontent = 0x06
|
||||
advertise = 0x07
|
||||
requestproofs = 0x08
|
||||
|
||||
PingMessage* = object
|
||||
enrSeq*: uint64
|
||||
dataRadius*: UInt256
|
||||
|
||||
PongMessage* = object
|
||||
enrSeq*: uint64
|
||||
dataRadius*: UInt256
|
||||
|
||||
FindNodeMessage* = object
|
||||
distances*: List[uint16, 256]
|
||||
|
||||
NodesMessage* = object
|
||||
total*: uint8
|
||||
enrs*: List[ByteList, 32] # ByteList here is the rlp encoded ENR. This could
|
||||
# also be limited to 300 bytes instead of 2048
|
||||
|
||||
FindContentMessage* = object
|
||||
contentKey*: ByteList
|
||||
|
||||
FoundContentMessage* = object
|
||||
enrs*: List[ByteList, 32]
|
||||
payload*: ByteList
|
||||
|
||||
AdvertiseMessage* = List[ByteList, 32] # No container, heh...
|
||||
|
||||
# This would be more consistent with the other messages
|
||||
# AdvertiseMessage* = object
|
||||
# contentKeys*: List[ByteList, 32]
|
||||
|
||||
RequestProofsMessage* = object
|
||||
connectionId*: List[byte, 4]
|
||||
contentKeys*: List[ByteList, 32]
|
||||
|
||||
Message* = object
|
||||
case kind*: MessageKind
|
||||
of ping:
|
||||
ping*: PingMessage
|
||||
of pong:
|
||||
pong*: PongMessage
|
||||
of findnode:
|
||||
findNode*: FindNodeMessage
|
||||
of nodes:
|
||||
nodes*: NodesMessage
|
||||
of findcontent:
|
||||
findcontent*: FindContentMessage
|
||||
of foundcontent:
|
||||
foundcontent*: FoundContentMessage
|
||||
of advertise:
|
||||
advertise*: AdvertiseMessage
|
||||
of requestproofs:
|
||||
requestproofs*: RequestProofsMessage
|
||||
else:
|
||||
discard
|
||||
|
||||
SomeMessage* =
|
||||
PingMessage or PongMessage or
|
||||
FindNodeMessage or NodesMessage or
|
||||
FindContentMessage or FoundContentMessage or
|
||||
AdvertiseMessage or RequestProofsMessage
|
||||
|
||||
template messageKind*(T: typedesc[SomeMessage]): MessageKind =
|
||||
when T is PingMessage: ping
|
||||
elif T is PongMessage: pong
|
||||
elif T is FindNodeMessage: findNode
|
||||
elif T is NodesMessage: nodes
|
||||
elif T is FindContentMessage: findcontent
|
||||
elif T is FoundContentMessage: foundcontent
|
||||
elif T is AdvertiseMessage: advertise
|
||||
elif T is RequestProofsMessage: requestproofs
|
||||
|
||||
template toSszType*(x: auto): auto =
|
||||
mixin toSszType
|
||||
|
||||
when x is UInt256: toBytesLE(x)
|
||||
else: x
|
||||
|
||||
func fromSszBytes*(T: type UInt256, data: openArray[byte]):
|
||||
T {.raises: [MalformedSszError, Defect].} =
|
||||
if data.len != sizeof(result):
|
||||
raiseIncorrectSize T
|
||||
|
||||
T.fromBytesLE(data)
|
||||
|
||||
proc encodeMessage*[T: SomeMessage](m: T): seq[byte] =
|
||||
ord(messageKind(T)).byte & SSZ.encode(m)
|
||||
|
||||
proc decodeMessage*(body: openarray[byte]): Result[Message, cstring] =
|
||||
# Decodes to the specific `Message` type.
|
||||
if body.len < 1:
|
||||
return err("No message data")
|
||||
|
||||
var kind: MessageKind
|
||||
if not checkedEnumAssign(kind, body[0]):
|
||||
return err("Invalid message type")
|
||||
|
||||
var message = Message(kind: kind)
|
||||
|
||||
try:
|
||||
case kind
|
||||
of unused: return err("Invalid message type")
|
||||
of ping:
|
||||
message.ping = SSZ.decode(body.toOpenArray(1, body.high), PingMessage)
|
||||
of pong:
|
||||
message.pong = SSZ.decode(body.toOpenArray(1, body.high), PongMessage)
|
||||
of findNode:
|
||||
message.findNode = SSZ.decode(body.toOpenArray(1, body.high), FindNodeMessage)
|
||||
of nodes:
|
||||
message.nodes = SSZ.decode(body.toOpenArray(1, body.high), NodesMessage)
|
||||
of findcontent:
|
||||
message.findcontent = SSZ.decode(body.toOpenArray(1, body.high), FindContentMessage)
|
||||
of foundcontent:
|
||||
message.foundcontent = SSZ.decode(body.toOpenArray(1, body.high), FoundContentMessage)
|
||||
of advertise:
|
||||
message.advertise = SSZ.decode(body.toOpenArray(1, body.high), AdvertiseMessage)
|
||||
of requestproofs:
|
||||
message.requestproofs = SSZ.decode(body.toOpenArray(1, body.high), RequestProofsMessage)
|
||||
except SszError:
|
||||
return err("Invalid message encoding")
|
||||
|
||||
ok(message)
|
|
@ -0,0 +1,221 @@
|
|||
# nim-eth - Portal Network
|
||||
# Copyright (c) 2021 Status Research & Development GmbH
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
import
|
||||
std/[options, strutils, tables],
|
||||
confutils, confutils/std/net, chronicles, chronicles/topics_registry,
|
||||
chronos, metrics, metrics/chronos_httpserver, stew/byteutils,
|
||||
../../keys, ../../net/nat,
|
||||
".."/discoveryv5/[enr, node], ".."/discoveryv5/protocol as discv5_protocol,
|
||||
./messages, ./protocol as portal_protocol
|
||||
|
||||
type
|
||||
PortalCmd* = enum
|
||||
noCommand
|
||||
ping
|
||||
findnode
|
||||
findcontent
|
||||
|
||||
DiscoveryConf* = object
|
||||
logLevel* {.
|
||||
defaultValue: LogLevel.DEBUG
|
||||
desc: "Sets the log level"
|
||||
name: "log-level" .}: LogLevel
|
||||
|
||||
udpPort* {.
|
||||
defaultValue: 9009
|
||||
desc: "UDP listening port"
|
||||
name: "udp-port" .}: uint16
|
||||
|
||||
listenAddress* {.
|
||||
defaultValue: defaultListenAddress(config)
|
||||
desc: "Listening address for the Discovery v5 traffic"
|
||||
name: "listen-address" }: ValidIpAddress
|
||||
|
||||
bootnodes* {.
|
||||
desc: "ENR URI of node to bootstrap discovery with. Argument may be repeated"
|
||||
name: "bootnode" .}: seq[enr.Record]
|
||||
|
||||
nat* {.
|
||||
desc: "Specify method to use for determining public address. " &
|
||||
"Must be one of: any, none, upnp, pmp, extip:<IP>"
|
||||
defaultValue: NatConfig(hasExtIp: false, nat: NatAny)
|
||||
name: "nat" .}: NatConfig
|
||||
|
||||
enrAutoUpdate* {.
|
||||
defaultValue: false
|
||||
desc: "Discovery can automatically update its ENR with the IP address " &
|
||||
"and UDP port as seen by other nodes it communicates with. " &
|
||||
"This option allows to enable/disable this functionality"
|
||||
name: "enr-auto-update" .}: bool
|
||||
|
||||
nodeKey* {.
|
||||
desc: "P2P node private key as hex",
|
||||
defaultValue: PrivateKey.random(keys.newRng()[])
|
||||
name: "nodekey" .}: PrivateKey
|
||||
|
||||
metricsEnabled* {.
|
||||
defaultValue: false
|
||||
desc: "Enable the metrics server"
|
||||
name: "metrics" .}: bool
|
||||
|
||||
metricsAddress* {.
|
||||
defaultValue: defaultAdminListenAddress(config)
|
||||
desc: "Listening address of the metrics server"
|
||||
name: "metrics-address" .}: ValidIpAddress
|
||||
|
||||
metricsPort* {.
|
||||
defaultValue: 8008
|
||||
desc: "Listening HTTP port of the metrics server"
|
||||
name: "metrics-port" .}: Port
|
||||
|
||||
case cmd* {.
|
||||
command
|
||||
defaultValue: noCommand }: PortalCmd
|
||||
of noCommand:
|
||||
discard
|
||||
of ping:
|
||||
pingTarget* {.
|
||||
argument
|
||||
desc: "ENR URI of the node to a send ping message"
|
||||
name: "node" .}: Node
|
||||
of findnode:
|
||||
distance* {.
|
||||
defaultValue: 255
|
||||
desc: "Distance parameter for the findNode message"
|
||||
name: "distance" .}: uint16
|
||||
# TODO: Order here matters as else the help message does not show all the
|
||||
# information, see: https://github.com/status-im/nim-confutils/issues/15
|
||||
findNodeTarget* {.
|
||||
argument
|
||||
desc: "ENR URI of the node to send a findNode message"
|
||||
name: "node" .}: Node
|
||||
of findcontent:
|
||||
findContentTarget* {.
|
||||
argument
|
||||
desc: "ENR URI of the node to send a findContent message"
|
||||
name: "node" .}: Node
|
||||
|
||||
func defaultListenAddress*(conf: DiscoveryConf): ValidIpAddress =
|
||||
(static ValidIpAddress.init("0.0.0.0"))
|
||||
|
||||
func defaultAdminListenAddress*(conf: DiscoveryConf): ValidIpAddress =
|
||||
(static ValidIpAddress.init("127.0.0.1"))
|
||||
|
||||
proc parseCmdArg*(T: type enr.Record, p: TaintedString): T =
|
||||
if not fromURI(result, p):
|
||||
raise newException(ConfigurationError, "Invalid ENR")
|
||||
|
||||
proc completeCmdArg*(T: type enr.Record, val: TaintedString): seq[string] =
|
||||
return @[]
|
||||
|
||||
proc parseCmdArg*(T: type Node, p: TaintedString): T =
|
||||
var record: enr.Record
|
||||
if not fromURI(record, p):
|
||||
raise newException(ConfigurationError, "Invalid ENR")
|
||||
|
||||
let n = newNode(record)
|
||||
if n.isErr:
|
||||
raise newException(ConfigurationError, $n.error)
|
||||
|
||||
if n[].address.isNone():
|
||||
raise newException(ConfigurationError, "ENR without address")
|
||||
|
||||
n[]
|
||||
|
||||
proc completeCmdArg*(T: type Node, val: TaintedString): seq[string] =
|
||||
return @[]
|
||||
|
||||
proc parseCmdArg*(T: type PrivateKey, p: TaintedString): T =
|
||||
try:
|
||||
result = PrivateKey.fromHex(string(p)).tryGet()
|
||||
except CatchableError:
|
||||
raise newException(ConfigurationError, "Invalid private key")
|
||||
|
||||
proc completeCmdArg*(T: type PrivateKey, val: TaintedString): seq[string] =
|
||||
return @[]
|
||||
|
||||
proc discover(d: discv5_protocol.Protocol) {.async.} =
|
||||
while true:
|
||||
let discovered = await d.queryRandom()
|
||||
info "Lookup finished", nodes = discovered.len
|
||||
await sleepAsync(30.seconds)
|
||||
|
||||
proc run(config: DiscoveryConf) =
|
||||
let
|
||||
rng = newRng()
|
||||
bindIp = config.listenAddress
|
||||
udpPort = Port(config.udpPort)
|
||||
# TODO: allow for no TCP port mapping!
|
||||
(extIp, _, extUdpPort) = setupAddress(config.nat,
|
||||
config.listenAddress, udpPort, udpPort, "dcli")
|
||||
|
||||
let d = newProtocol(config.nodeKey,
|
||||
extIp, none(Port), extUdpPort,
|
||||
bootstrapRecords = config.bootnodes,
|
||||
bindIp = bindIp, bindPort = udpPort,
|
||||
enrAutoUpdate = config.enrAutoUpdate,
|
||||
rng = rng)
|
||||
|
||||
d.open()
|
||||
|
||||
let portal = PortalProtocol.new(d)
|
||||
|
||||
if config.metricsEnabled:
|
||||
let
|
||||
address = config.metricsAddress
|
||||
port = config.metricsPort
|
||||
notice "Starting metrics HTTP server",
|
||||
url = "http://" & $address & ":" & $port & "/metrics"
|
||||
try:
|
||||
chronos_httpserver.startMetricsHttpServer($address, port)
|
||||
except CatchableError as exc: raise exc
|
||||
except Exception as exc: raiseAssert exc.msg # TODO fix metrics
|
||||
|
||||
case config.cmd
|
||||
of ping:
|
||||
let pong = waitFor portal.ping(config.pingTarget)
|
||||
|
||||
if pong.isOk():
|
||||
echo pong.get()
|
||||
else:
|
||||
echo pong.error
|
||||
of findnode:
|
||||
let distances = List[uint16, 256](@[config.distance])
|
||||
let nodes = waitFor portal.findNode(config.findNodeTarget, distances)
|
||||
|
||||
if nodes.isOk():
|
||||
echo nodes.get()
|
||||
else:
|
||||
echo nodes.error
|
||||
of findcontent:
|
||||
proc random(T: type UInt256, rng: var BrHmacDrbgContext): T =
|
||||
var key: UInt256
|
||||
brHmacDrbgGenerate(addr rng, addr key, csize_t(sizeof(key)))
|
||||
|
||||
key
|
||||
|
||||
# For now just random content keys
|
||||
let contentKey = ByteList(@(UInt256.random(rng[]).toBytes()))
|
||||
let foundContent = waitFor portal.findContent(config.findContentTarget,
|
||||
contentKey)
|
||||
|
||||
if foundContent.isOk():
|
||||
echo foundContent.get()
|
||||
else:
|
||||
echo foundContent.error
|
||||
|
||||
of noCommand:
|
||||
d.start()
|
||||
waitfor(discover(d))
|
||||
|
||||
when isMainModule:
|
||||
let config = DiscoveryConf.load()
|
||||
|
||||
setLogLevel(config.logLevel)
|
||||
|
||||
run(config)
|
|
@ -0,0 +1,143 @@
|
|||
# nim-eth - Portal Network
|
||||
# Copyright (c) 2021 Status Research & Development GmbH
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import
|
||||
stew/[results, byteutils],
|
||||
../../rlp,
|
||||
../discoveryv5/[protocol, node],
|
||||
./messages
|
||||
|
||||
export messages
|
||||
|
||||
const
|
||||
PortalProtocolId* = "portal".toBytes()
|
||||
|
||||
type
|
||||
PortalProtocol* = ref object of TalkProtocol
|
||||
baseProtocol*: protocol.Protocol
|
||||
dataRadius*: UInt256
|
||||
|
||||
proc handlePing(p: PortalProtocol, ping: PingMessage):
|
||||
seq[byte] =
|
||||
let p = PongMessage(enrSeq: p.baseProtocol.localNode.record.seqNum,
|
||||
dataRadius: p.dataRadius)
|
||||
|
||||
encodeMessage(p)
|
||||
|
||||
proc handleFindNode(p: PortalProtocol, fn: FindNodeMessage): seq[byte] =
|
||||
if fn.distances.len == 0:
|
||||
let enrs = List[ByteList, 32](@[])
|
||||
encodeMessage(NodesMessage(total: 1, enrs: enrs))
|
||||
elif fn.distances.contains(0):
|
||||
# A request for our own record.
|
||||
let enr = ByteList(rlp.encode(p.baseProtocol.localNode.record))
|
||||
encodeMessage(NodesMessage(total: 1, enrs: List[ByteList, 32](@[enr])))
|
||||
else:
|
||||
# TODO: Not implemented for now, sending empty back.
|
||||
let enrs = List[ByteList, 32](@[])
|
||||
encodeMessage(NodesMessage(total: 1, enrs: enrs))
|
||||
|
||||
proc handleFindContent(p: PortalProtocol, ping: FindContentMessage): seq[byte] =
|
||||
# TODO: Neither payload nor enrs implemented, sending empty back.
|
||||
let
|
||||
enrs = List[ByteList, 32](@[])
|
||||
payload = ByteList(@[])
|
||||
encodeMessage(FoundContentMessage(enrs: enrs, payload: payload))
|
||||
|
||||
proc handleAdvertise(p: PortalProtocol, ping: AdvertiseMessage): seq[byte] =
|
||||
# TODO: Not implemented
|
||||
let
|
||||
connectionId = List[byte, 4](@[])
|
||||
contentKeys = List[ByteList, 32](@[])
|
||||
encodeMessage(RequestProofsMessage(connectionId: connectionId,
|
||||
contentKeys: contentKeys))
|
||||
|
||||
proc messageHandler*(protocol: TalkProtocol, request: seq[byte]): seq[byte] =
|
||||
doAssert(protocol of PortalProtocol)
|
||||
|
||||
let p = PortalProtocol(protocol)
|
||||
|
||||
let decoded = decodeMessage(request)
|
||||
if decoded.isOk():
|
||||
let message = decoded.get()
|
||||
case message.kind
|
||||
of MessageKind.ping:
|
||||
p.handlePing(message.ping)
|
||||
of MessageKind.findnode:
|
||||
p.handleFindNode(message.findNode)
|
||||
of MessageKind.findcontent:
|
||||
p.handleFindContent(message.findcontent)
|
||||
of MessageKind.advertise:
|
||||
p.handleAdvertise(message.advertise)
|
||||
else:
|
||||
@[]
|
||||
else:
|
||||
@[]
|
||||
|
||||
proc new*(T: type PortalProtocol, baseProtocol: protocol.Protocol,
|
||||
dataRadius = UInt256.high()): T =
|
||||
let proto = PortalProtocol(
|
||||
protocolHandler: messageHandler,
|
||||
baseProtocol: baseProtocol,
|
||||
dataRadius: dataRadius)
|
||||
|
||||
proto.baseProtocol.registerTalkProtocol(PortalProtocolId, proto).expect(
|
||||
"Only one protocol should have this id")
|
||||
|
||||
return proto
|
||||
|
||||
proc ping*(p: PortalProtocol, dst: Node):
|
||||
Future[DiscResult[PongMessage]] {.async.} =
|
||||
let ping = PingMessage(enrSeq: p.baseProtocol.localNode.record.seqNum,
|
||||
dataRadius: p.dataRadius)
|
||||
|
||||
let talkresp = await talkreq(p.baseProtocol, dst, PortalProtocolId,
|
||||
encodeMessage(ping))
|
||||
|
||||
if talkresp.isOk():
|
||||
let decoded = decodeMessage(talkresp.get().response)
|
||||
if decoded.isOk() and decoded.get().kind == pong:
|
||||
return ok(decoded.get().pong)
|
||||
else:
|
||||
return err("Invalid message received")
|
||||
else:
|
||||
return err(talkresp.error)
|
||||
|
||||
proc findNode*(p: PortalProtocol, dst: Node, distances: List[uint16, 256]):
|
||||
Future[DiscResult[NodesMessage]] {.async.} =
|
||||
let fn = FindNodeMessage(distances: distances)
|
||||
|
||||
let talkresp = await talkreq(p.baseProtocol, dst, PortalProtocolId,
|
||||
encodeMessage(fn))
|
||||
|
||||
if talkresp.isOk():
|
||||
let decoded = decodeMessage(talkresp.get().response)
|
||||
if decoded.isOk() and decoded.get().kind == nodes:
|
||||
# TODO: Verify nodes here
|
||||
return ok(decoded.get().nodes)
|
||||
else:
|
||||
return err("Invalid message received")
|
||||
else:
|
||||
return err(talkresp.error)
|
||||
|
||||
proc findContent*(p: PortalProtocol, dst: Node, contentKey: ByteList):
|
||||
Future[DiscResult[FoundContentMessage]] {.async.} =
|
||||
let fc = FindContentMessage(contentKey: contentKey)
|
||||
|
||||
let talkresp = await talkreq(p.baseProtocol, dst, PortalProtocolId,
|
||||
encodeMessage(fc))
|
||||
|
||||
if talkresp.isOk():
|
||||
let decoded = decodeMessage(talkresp.get().response)
|
||||
if decoded.isOk() and decoded.get().kind == foundcontent:
|
||||
return ok(decoded.get().foundcontent)
|
||||
else:
|
||||
return err("Invalid message received")
|
||||
else:
|
||||
return err(talkresp.error)
|
|
@ -0,0 +1,313 @@
|
|||
# nim-eth
|
||||
# Copyright (c) 2018-2021 Status Research & Development GmbH
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import
|
||||
stew/[bitops2, endians2, ptrops]
|
||||
|
||||
type
|
||||
Bytes = seq[byte]
|
||||
|
||||
BitSeq* = distinct Bytes
|
||||
## The current design of BitSeq tries to follow precisely
|
||||
## the bitwise representation of the SSZ bitlists.
|
||||
## This is a relatively compact representation, but as
|
||||
## evident from the code below, many of the operations
|
||||
## are not trivial.
|
||||
|
||||
BitArray*[bits: static int] = object
|
||||
bytes*: array[(bits + 7) div 8, byte]
|
||||
|
||||
func bitsLen*(bytes: openArray[byte]): int =
|
||||
let
|
||||
bytesCount = bytes.len
|
||||
lastByte = bytes[bytesCount - 1]
|
||||
markerPos = log2trunc(lastByte)
|
||||
|
||||
bytesCount * 8 - (8 - markerPos)
|
||||
|
||||
template len*(s: BitSeq): int =
|
||||
bitsLen(Bytes s)
|
||||
|
||||
template len*(a: BitArray): int =
|
||||
a.bits
|
||||
|
||||
func add*(s: var BitSeq, value: bool) =
|
||||
let
|
||||
lastBytePos = s.Bytes.len - 1
|
||||
lastByte = s.Bytes[lastBytePos]
|
||||
|
||||
if (lastByte and byte(128)) == 0:
|
||||
# There is at least one leading zero, so we have enough
|
||||
# room to store the new bit
|
||||
let markerPos = log2trunc(lastByte)
|
||||
s.Bytes[lastBytePos].changeBit markerPos, value
|
||||
s.Bytes[lastBytePos].setBit markerPos + 1
|
||||
else:
|
||||
s.Bytes[lastBytePos].changeBit 7, value
|
||||
s.Bytes.add byte(1)
|
||||
|
||||
func toBytesLE(x: uint): array[sizeof(x), byte] =
|
||||
# stew/endians2 supports explicitly sized uints only
|
||||
when sizeof(uint) == 4:
|
||||
static: doAssert sizeof(uint) == sizeof(uint32)
|
||||
toBytesLE(x.uint32)
|
||||
elif sizeof(uint) == 8:
|
||||
static: doAssert sizeof(uint) == sizeof(uint64)
|
||||
toBytesLE(x.uint64)
|
||||
else:
|
||||
static: doAssert false, "requires a 32-bit or 64-bit platform"
|
||||
|
||||
func loadLEBytes(WordType: type, bytes: openArray[byte]): WordType =
|
||||
# TODO: this is a temporary proc until the endians API is improved
|
||||
var shift = 0
|
||||
for b in bytes:
|
||||
result = result or (WordType(b) shl shift)
|
||||
shift += 8
|
||||
|
||||
func storeLEBytes(value: SomeUnsignedInt, dst: var openArray[byte]) =
|
||||
doAssert dst.len <= sizeof(value)
|
||||
let bytesLE = toBytesLE(value)
|
||||
copyMem(addr dst[0], unsafeAddr bytesLE[0], dst.len)
|
||||
|
||||
template loopOverWords(lhs, rhs: BitSeq,
|
||||
lhsIsVar, rhsIsVar: static bool,
|
||||
WordType: type,
|
||||
lhsBits, rhsBits, body: untyped) =
|
||||
const hasRhs = astToStr(lhs) != astToStr(rhs)
|
||||
|
||||
let bytesCount = len Bytes(lhs)
|
||||
when hasRhs: doAssert len(Bytes(rhs)) == bytesCount
|
||||
|
||||
var fullWordsCount = bytesCount div sizeof(WordType)
|
||||
let lastWordSize = bytesCount mod sizeof(WordType)
|
||||
|
||||
block:
|
||||
var lhsWord: WordType
|
||||
when hasRhs:
|
||||
var rhsWord: WordType
|
||||
var firstByteOfLastWord, lastByteOfLastWord: int
|
||||
|
||||
# TODO: Returning a `var` value from an iterator is always safe due to
|
||||
# the way inlining works, but currently the compiler reports an error
|
||||
# when a local variable escapes. We have to cheat it with this location
|
||||
# obfuscation through pointers:
|
||||
template lhsBits: auto = (addr(lhsWord))[]
|
||||
|
||||
when hasRhs:
|
||||
template rhsBits: auto = (addr(rhsWord))[]
|
||||
|
||||
template lastWordBytes(bitseq): auto =
|
||||
Bytes(bitseq).toOpenArray(firstByteOfLastWord, lastByteOfLastWord)
|
||||
|
||||
template initLastWords =
|
||||
lhsWord = loadLEBytes(WordType, lastWordBytes(lhs))
|
||||
when hasRhs: rhsWord = loadLEBytes(WordType, lastWordBytes(rhs))
|
||||
|
||||
if lastWordSize == 0:
|
||||
firstByteOfLastWord = bytesCount - sizeof(WordType)
|
||||
lastByteOfLastWord = bytesCount - 1
|
||||
dec fullWordsCount
|
||||
else:
|
||||
firstByteOfLastWord = bytesCount - lastWordSize
|
||||
lastByteOfLastWord = bytesCount - 1
|
||||
|
||||
initLastWords()
|
||||
let markerPos = log2trunc(lhsWord)
|
||||
when hasRhs: doAssert log2trunc(rhsWord) == markerPos
|
||||
|
||||
lhsWord.clearBit markerPos
|
||||
when hasRhs: rhsWord.clearBit markerPos
|
||||
|
||||
body
|
||||
|
||||
when lhsIsVar or rhsIsVar:
|
||||
let
|
||||
markerBit = uint(1 shl markerPos)
|
||||
mask = markerBit - 1'u
|
||||
|
||||
when lhsIsVar:
|
||||
let lhsEndResult = (lhsWord and mask) or markerBit
|
||||
storeLEBytes(lhsEndResult, lastWordBytes(lhs))
|
||||
|
||||
when rhsIsVar:
|
||||
let rhsEndResult = (rhsWord and mask) or markerBit
|
||||
storeLEBytes(rhsEndResult, lastWordBytes(rhs))
|
||||
|
||||
var lhsCurrAddr = cast[ptr WordType](unsafeAddr Bytes(lhs)[0])
|
||||
let lhsEndAddr = offset(lhsCurrAddr, fullWordsCount)
|
||||
when hasRhs:
|
||||
var rhsCurrAddr = cast[ptr WordType](unsafeAddr Bytes(rhs)[0])
|
||||
|
||||
while lhsCurrAddr < lhsEndAddr:
|
||||
template lhsBits: auto = lhsCurrAddr[]
|
||||
when hasRhs:
|
||||
template rhsBits: auto = rhsCurrAddr[]
|
||||
|
||||
body
|
||||
|
||||
lhsCurrAddr = offset(lhsCurrAddr, 1)
|
||||
when hasRhs: rhsCurrAddr = offset(rhsCurrAddr, 1)
|
||||
|
||||
iterator words*(x: var BitSeq): var uint =
|
||||
loopOverWords(x, x, true, false, uint, word, wordB):
|
||||
yield word
|
||||
|
||||
iterator words*(x: BitSeq): uint =
|
||||
loopOverWords(x, x, false, false, uint, word, word):
|
||||
yield word
|
||||
|
||||
iterator words*(a, b: BitSeq): (uint, uint) =
|
||||
loopOverWords(a, b, false, false, uint, wordA, wordB):
|
||||
yield (wordA, wordB)
|
||||
|
||||
iterator words*(a: var BitSeq, b: BitSeq): (var uint, uint) =
|
||||
loopOverWords(a, b, true, false, uint, wordA, wordB):
|
||||
yield (wordA, wordB)
|
||||
|
||||
iterator words*(a, b: var BitSeq): (var uint, var uint) =
|
||||
loopOverWords(a, b, true, true, uint, wordA, wordB):
|
||||
yield (wordA, wordB)
|
||||
|
||||
func `[]`*(s: BitSeq, pos: Natural): bool {.inline.} =
|
||||
doAssert pos < s.len
|
||||
s.Bytes.getBit pos
|
||||
|
||||
func `[]=`*(s: var BitSeq, pos: Natural, value: bool) {.inline.} =
|
||||
doAssert pos < s.len
|
||||
s.Bytes.changeBit pos, value
|
||||
|
||||
func setBit*(s: var BitSeq, pos: Natural) {.inline.} =
|
||||
doAssert pos < s.len
|
||||
setBit s.Bytes, pos
|
||||
|
||||
func clearBit*(s: var BitSeq, pos: Natural) {.inline.} =
|
||||
doAssert pos < s.len
|
||||
clearBit s.Bytes, pos
|
||||
|
||||
func init*(T: type BitSeq, len: int): T =
|
||||
result = BitSeq newSeq[byte](1 + len div 8)
|
||||
Bytes(result).setBit len
|
||||
|
||||
func init*(T: type BitArray): T =
|
||||
# The default zero-initializatio is fine
|
||||
discard
|
||||
|
||||
template `[]`*(a: BitArray, pos: Natural): bool =
|
||||
getBit a.bytes, pos
|
||||
|
||||
template `[]=`*(a: var BitArray, pos: Natural, value: bool) =
|
||||
changeBit a.bytes, pos, value
|
||||
|
||||
template setBit*(a: var BitArray, pos: Natural) =
|
||||
setBit a.bytes, pos
|
||||
|
||||
template clearBit*(a: var BitArray, pos: Natural) =
|
||||
clearBit a.bytes, pos
|
||||
|
||||
# TODO: Submit this to the standard library as `cmp`
|
||||
# At the moment, it doesn't work quite well because Nim selects
|
||||
# the generic cmp[T] from the system module instead of choosing
|
||||
# the openArray overload
|
||||
func compareArrays[T](a, b: openArray[T]): int =
|
||||
result = cmp(a.len, b.len)
|
||||
if result != 0: return
|
||||
|
||||
for i in 0 ..< a.len:
|
||||
result = cmp(a[i], b[i])
|
||||
if result != 0: return
|
||||
|
||||
template cmp*(a, b: BitSeq): int =
|
||||
compareArrays(Bytes a, Bytes b)
|
||||
|
||||
template `==`*(a, b: BitSeq): bool =
|
||||
cmp(a, b) == 0
|
||||
|
||||
func `$`*(a: BitSeq | BitArray): string =
|
||||
let length = a.len
|
||||
result = newStringOfCap(2 + length)
|
||||
result.add "0b"
|
||||
for i in countdown(length - 1, 0):
|
||||
result.add if a[i]: '1' else: '0'
|
||||
|
||||
func incl*(tgt: var BitSeq, src: BitSeq) =
|
||||
# Update `tgt` to include the bits of `src`, as if applying `or` to each bit
|
||||
doAssert tgt.len == src.len
|
||||
for tgtWord, srcWord in words(tgt, src):
|
||||
tgtWord = tgtWord or srcWord
|
||||
|
||||
func overlaps*(a, b: BitSeq): bool =
|
||||
for wa, wb in words(a, b):
|
||||
if (wa and wb) != 0:
|
||||
return true
|
||||
|
||||
func countOverlap*(a, b: BitSeq): int =
|
||||
var res = 0
|
||||
for wa, wb in words(a, b):
|
||||
res += countOnes(wa and wb)
|
||||
res
|
||||
|
||||
func isSubsetOf*(a, b: BitSeq): bool =
|
||||
let alen = a.len
|
||||
doAssert b.len == alen
|
||||
for i in 0 ..< alen:
|
||||
if a[i] and not b[i]:
|
||||
return false
|
||||
true
|
||||
|
||||
func isZeros*(x: BitSeq): bool =
|
||||
for w in words(x):
|
||||
if w != 0: return false
|
||||
return true
|
||||
|
||||
func countOnes*(x: BitSeq): int =
|
||||
# Count the number of set bits
|
||||
var res = 0
|
||||
for w in words(x):
|
||||
res += w.countOnes()
|
||||
res
|
||||
|
||||
func clear*(x: var BitSeq) =
|
||||
for w in words(x):
|
||||
w = 0
|
||||
|
||||
func countZeros*(x: BitSeq): int =
|
||||
x.len() - x.countOnes()
|
||||
|
||||
template bytes*(x: BitSeq): untyped =
|
||||
seq[byte](x)
|
||||
|
||||
iterator items*(x: BitArray): bool =
|
||||
for i in 0..<x.bits:
|
||||
yield x[i]
|
||||
|
||||
iterator pairs*(x: BitArray): (int, bool) =
|
||||
for i in 0..<x.bits:
|
||||
yield (i, x[i])
|
||||
|
||||
func incl*(a: var BitArray, b: BitArray) =
|
||||
# Update `a` to include the bits of `b`, as if applying `or` to each bit
|
||||
for i in 0..<a.bytes.len:
|
||||
a[i] = a[i] or b[i]
|
||||
|
||||
func clear*(a: var BitArray) =
|
||||
for b in a.bytes.mitems(): b = 0
|
||||
|
||||
# Set operations
|
||||
func `+`*(a, b: BitArray): BitArray =
|
||||
for i in 0..<a.bytes.len:
|
||||
result.bytes[i] = a.bytes[i] or b.bytes[i]
|
||||
|
||||
func `-`*(a, b: BitArray): BitArray =
|
||||
for i in 0..<a.bytes.len:
|
||||
result.bytes[i] = a.bytes[i] and (not b.bytes[i])
|
||||
|
||||
iterator oneIndices*(a: BitArray): int =
|
||||
for i in 0..<a.len:
|
||||
if a[i]: yield i
|
||||
|
|
@ -0,0 +1,218 @@
|
|||
# nim-eth - Limited SSZ implementation
|
||||
# Copyright (c) 2018-2021 Status Research & Development GmbH
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import
|
||||
std/[typetraits, options],
|
||||
stew/[endians2, objects],
|
||||
./types
|
||||
|
||||
template raiseIncorrectSize*(T: type) =
|
||||
const typeName = name(T)
|
||||
raise newException(MalformedSszError,
|
||||
"SSZ " & typeName & " input of incorrect size")
|
||||
|
||||
template setOutputSize[R, T](a: var array[R, T], length: int) =
|
||||
if length != a.len:
|
||||
raiseIncorrectSize a.type
|
||||
|
||||
proc setOutputSize(list: var List, length: int) {.raises: [SszError, Defect].} =
|
||||
if not list.setLen length:
|
||||
raise newException(MalformedSszError, "SSZ list maximum size exceeded")
|
||||
|
||||
# fromSszBytes copies the wire representation to a Nim variable,
|
||||
# assuming there's enough data in the buffer
|
||||
func fromSszBytes*(T: type UintN, data: openArray[byte]):
|
||||
T {.raises: [MalformedSszError, Defect].} =
|
||||
## Convert directly to bytes the size of the int. (e.g. ``uint16 = 2 bytes``)
|
||||
## All integers are serialized as **little endian**.
|
||||
if data.len != sizeof(result):
|
||||
raiseIncorrectSize T
|
||||
|
||||
T.fromBytesLE(data)
|
||||
|
||||
func fromSszBytes*(T: type bool, data: openArray[byte]):
|
||||
T {.raises: [MalformedSszError, Defect].} =
|
||||
# Strict: only allow 0 or 1
|
||||
if data.len != 1 or byte(data[0]) > byte(1):
|
||||
raise newException(MalformedSszError, "invalid boolean value")
|
||||
data[0] == 1
|
||||
|
||||
template fromSszBytes*(T: type BitSeq, bytes: openArray[byte]): auto =
|
||||
BitSeq @bytes
|
||||
|
||||
proc `[]`[T, U, V](s: openArray[T], x: HSlice[U, V]) {.error:
|
||||
"Please don't use openArray's [] as it allocates a result sequence".}
|
||||
|
||||
template checkForForbiddenBits(ResulType: type,
|
||||
input: openArray[byte],
|
||||
expectedBits: static int64) =
|
||||
## This checks if the input contains any bits set above the maximum
|
||||
## sized allowed. We only need to check the last byte to verify this:
|
||||
const bitsInLastByte = (expectedBits mod 8)
|
||||
when bitsInLastByte != 0:
|
||||
# As an example, if there are 3 bits expected in the last byte,
|
||||
# we calculate a bitmask equal to 11111000. If the input has any
|
||||
# raised bits in range of the bitmask, this would be a violation
|
||||
# of the size of the BitArray:
|
||||
const forbiddenBitsMask = byte(byte(0xff) shl bitsInLastByte)
|
||||
|
||||
if (input[^1] and forbiddenBitsMask) != 0:
|
||||
raiseIncorrectSize ResulType
|
||||
|
||||
func readSszValue*[T](input: openArray[byte], val: var T)
|
||||
{.raises: [SszError, Defect].} =
|
||||
mixin fromSszBytes, toSszType
|
||||
|
||||
template readOffsetUnchecked(n: int): uint32 {.used.}=
|
||||
fromSszBytes(uint32, input.toOpenArray(n, n + offsetSize - 1))
|
||||
|
||||
template readOffset(n: int): int {.used.} =
|
||||
let offset = readOffsetUnchecked(n)
|
||||
if offset > input.len.uint32:
|
||||
raise newException(MalformedSszError, "SSZ list element offset points past the end of the input")
|
||||
int(offset)
|
||||
|
||||
when val is BitList:
|
||||
if input.len == 0:
|
||||
raise newException(MalformedSszError, "Invalid empty SSZ BitList value")
|
||||
|
||||
# Since our BitLists have an in-memory representation that precisely
|
||||
# matches their SSZ encoding, we can deserialize them as regular Lists:
|
||||
const maxExpectedSize = (val.maxLen div 8) + 1
|
||||
type MatchingListType = List[byte, maxExpectedSize]
|
||||
|
||||
when false:
|
||||
# TODO: Nim doesn't like this simple type coercion,
|
||||
# we'll rely on `cast` for now (see below)
|
||||
readSszValue(input, MatchingListType val)
|
||||
else:
|
||||
static:
|
||||
# As a sanity check, we verify that the coercion is accepted by the compiler:
|
||||
doAssert MatchingListType(val) is MatchingListType
|
||||
readSszValue(input, cast[ptr MatchingListType](addr val)[])
|
||||
|
||||
let resultBytesCount = len bytes(val)
|
||||
|
||||
if bytes(val)[resultBytesCount - 1] == 0:
|
||||
raise newException(MalformedSszError, "SSZ BitList is not properly terminated")
|
||||
|
||||
if resultBytesCount == maxExpectedSize:
|
||||
checkForForbiddenBits(T, input, val.maxLen + 1)
|
||||
|
||||
elif val is List|array:
|
||||
type E = type val[0]
|
||||
|
||||
when E is byte:
|
||||
val.setOutputSize input.len
|
||||
if input.len > 0:
|
||||
copyMem(addr val[0], unsafeAddr input[0], input.len)
|
||||
|
||||
elif isFixedSize(E):
|
||||
const elemSize = fixedPortionSize(E)
|
||||
if input.len mod elemSize != 0:
|
||||
var ex = new SszSizeMismatchError
|
||||
ex.deserializedType = cstring typetraits.name(T)
|
||||
ex.actualSszSize = input.len
|
||||
ex.elementSize = elemSize
|
||||
raise ex
|
||||
val.setOutputSize input.len div elemSize
|
||||
for i in 0 ..< val.len:
|
||||
let offset = i * elemSize
|
||||
readSszValue(input.toOpenArray(offset, offset + elemSize - 1), val[i])
|
||||
|
||||
else:
|
||||
if input.len == 0:
|
||||
# This is an empty list.
|
||||
# The default initialization of the return value is fine.
|
||||
val.setOutputSize 0
|
||||
return
|
||||
elif input.len < offsetSize:
|
||||
raise newException(MalformedSszError, "SSZ input of insufficient size")
|
||||
|
||||
var offset = readOffset 0
|
||||
let resultLen = offset div offsetSize
|
||||
|
||||
if resultLen == 0:
|
||||
# If there are too many elements, other constraints detect problems
|
||||
# (not monotonically increasing, past end of input, or last element
|
||||
# not matching up with its nextOffset properly)
|
||||
raise newException(MalformedSszError, "SSZ list incorrectly encoded of zero length")
|
||||
|
||||
val.setOutputSize resultLen
|
||||
for i in 1 ..< resultLen:
|
||||
let nextOffset = readOffset(i * offsetSize)
|
||||
if nextOffset <= offset:
|
||||
raise newException(MalformedSszError, "SSZ list element offsets are not monotonically increasing")
|
||||
else:
|
||||
readSszValue(input.toOpenArray(offset, nextOffset - 1), val[i - 1])
|
||||
offset = nextOffset
|
||||
|
||||
readSszValue(input.toOpenArray(offset, input.len - 1), val[resultLen - 1])
|
||||
|
||||
elif val is UintN|bool:
|
||||
val = fromSszBytes(T, input)
|
||||
|
||||
elif val is BitArray:
|
||||
if sizeof(val) != input.len:
|
||||
raiseIncorrectSize(T)
|
||||
checkForForbiddenBits(T, input, val.bits)
|
||||
copyMem(addr val.bytes[0], unsafeAddr input[0], input.len)
|
||||
|
||||
elif val is object|tuple:
|
||||
let inputLen = uint32 input.len
|
||||
const minimallyExpectedSize = uint32 fixedPortionSize(T)
|
||||
|
||||
if inputLen < minimallyExpectedSize:
|
||||
raise newException(MalformedSszError, "SSZ input of insufficient size")
|
||||
|
||||
enumInstanceSerializedFields(val, fieldName, field):
|
||||
const boundingOffsets = getFieldBoundingOffsets(T, fieldName)
|
||||
|
||||
# type FieldType = type field # buggy
|
||||
# For some reason, Nim gets confused about the alias here. This could be a
|
||||
# generics caching issue caused by the use of distinct types. Such an
|
||||
# issue is very scary in general.
|
||||
# The bug can be seen with the two List[uint64, N] types that exist in
|
||||
# the spec, with different N.
|
||||
|
||||
type SszType = type toSszType(declval type(field))
|
||||
|
||||
when isFixedSize(SszType):
|
||||
const
|
||||
startOffset = boundingOffsets[0]
|
||||
endOffset = boundingOffsets[1]
|
||||
else:
|
||||
let
|
||||
startOffset = readOffsetUnchecked(boundingOffsets[0])
|
||||
endOffset = if boundingOffsets[1] == -1: inputLen
|
||||
else: readOffsetUnchecked(boundingOffsets[1])
|
||||
|
||||
when boundingOffsets.isFirstOffset:
|
||||
if startOffset != minimallyExpectedSize:
|
||||
raise newException(MalformedSszError, "SSZ object dynamic portion starts at invalid offset")
|
||||
|
||||
if startOffset > endOffset:
|
||||
raise newException(MalformedSszError, "SSZ field offsets are not monotonically increasing")
|
||||
elif endOffset > inputLen:
|
||||
raise newException(MalformedSszError, "SSZ field offset points past the end of the input")
|
||||
elif startOffset < minimallyExpectedSize:
|
||||
raise newException(MalformedSszError, "SSZ field offset points outside bounding offsets")
|
||||
|
||||
# TODO The extra type escaping here is a work-around for a Nim issue:
|
||||
when type(field) is type(SszType):
|
||||
readSszValue(
|
||||
input.toOpenArray(int(startOffset), int(endOffset - 1)),
|
||||
field)
|
||||
else:
|
||||
field = fromSszBytes(
|
||||
type(field),
|
||||
input.toOpenArray(int(startOffset), int(endOffset - 1)))
|
||||
|
||||
else:
|
||||
unsupported T
|
|
@ -0,0 +1,247 @@
|
|||
# nim-eth - Limited SSZ implementation
|
||||
# Copyright (c) 2018-2021 Status Research & Development GmbH
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
## SSZ serialization for core SSZ types, as specified in:
|
||||
# https://github.com/ethereum/eth2.0-specs/blob/v1.0.1/ssz/simple-serialize.md#serialization
|
||||
|
||||
import
|
||||
std/[typetraits, options],
|
||||
stew/[endians2, leb128, objects],
|
||||
serialization, serialization/testing/tracing,
|
||||
./bytes_reader, ./types
|
||||
|
||||
export
|
||||
serialization, types, bytes_reader
|
||||
|
||||
type
|
||||
SszReader* = object
|
||||
stream: InputStream
|
||||
|
||||
SszWriter* = object
|
||||
stream: OutputStream
|
||||
|
||||
SizePrefixed*[T] = distinct T
|
||||
SszMaxSizeExceeded* = object of SerializationError
|
||||
|
||||
VarSizedWriterCtx = object
|
||||
fixedParts: WriteCursor
|
||||
offset: int
|
||||
|
||||
FixedSizedWriterCtx = object
|
||||
|
||||
serializationFormat SSZ
|
||||
|
||||
SSZ.setReader SszReader
|
||||
SSZ.setWriter SszWriter, PreferredOutput = seq[byte]
|
||||
|
||||
template sizePrefixed*[TT](x: TT): untyped =
|
||||
type T = TT
|
||||
SizePrefixed[T](x)
|
||||
|
||||
proc init*(T: type SszReader, stream: InputStream): T {.raises: [Defect].} =
|
||||
T(stream: stream)
|
||||
|
||||
proc writeFixedSized(s: var (OutputStream|WriteCursor), x: auto)
|
||||
{.raises: [Defect, IOError].} =
|
||||
mixin toSszType
|
||||
|
||||
when x is byte:
|
||||
s.write x
|
||||
elif x is bool:
|
||||
s.write byte(ord(x))
|
||||
elif x is UintN:
|
||||
when cpuEndian == bigEndian:
|
||||
s.write toBytesLE(x)
|
||||
else:
|
||||
s.writeMemCopy x
|
||||
elif x is array:
|
||||
when x[0] is byte:
|
||||
trs "APPENDING FIXED SIZE BYTES", x
|
||||
s.write x
|
||||
else:
|
||||
for elem in x:
|
||||
trs "WRITING FIXED SIZE ARRAY ELEMENT"
|
||||
s.writeFixedSized toSszType(elem)
|
||||
elif x is tuple|object:
|
||||
enumInstanceSerializedFields(x, fieldName, field):
|
||||
trs "WRITING FIXED SIZE FIELD", fieldName
|
||||
s.writeFixedSized toSszType(field)
|
||||
else:
|
||||
unsupported x.type
|
||||
|
||||
template writeOffset(cursor: var WriteCursor, offset: int) =
|
||||
write cursor, toBytesLE(uint32 offset)
|
||||
|
||||
template supports*(_: type SSZ, T: type): bool =
|
||||
mixin toSszType
|
||||
anonConst compiles(fixedPortionSize toSszType(declval T))
|
||||
|
||||
func init*(T: type SszWriter, stream: OutputStream): T {.raises: [Defect].} =
|
||||
result.stream = stream
|
||||
|
||||
proc writeVarSizeType(w: var SszWriter, value: auto)
|
||||
{.gcsafe, raises: [Defect, IOError].}
|
||||
|
||||
proc beginRecord*(w: var SszWriter, TT: type): auto {.raises: [Defect].} =
|
||||
type T = TT
|
||||
when isFixedSize(T):
|
||||
FixedSizedWriterCtx()
|
||||
else:
|
||||
const offset = when T is array: len(T) * offsetSize
|
||||
else: fixedPortionSize(T)
|
||||
VarSizedWriterCtx(offset: offset,
|
||||
fixedParts: w.stream.delayFixedSizeWrite(offset))
|
||||
|
||||
template writeField*(w: var SszWriter,
|
||||
ctx: var auto,
|
||||
fieldName: string,
|
||||
field: auto) =
|
||||
mixin toSszType
|
||||
when ctx is FixedSizedWriterCtx:
|
||||
writeFixedSized(w.stream, toSszType(field))
|
||||
else:
|
||||
type FieldType = type toSszType(field)
|
||||
|
||||
when isFixedSize(FieldType):
|
||||
writeFixedSized(ctx.fixedParts, toSszType(field))
|
||||
else:
|
||||
trs "WRITING OFFSET ", ctx.offset, " FOR ", fieldName
|
||||
writeOffset(ctx.fixedParts, ctx.offset)
|
||||
let initPos = w.stream.pos
|
||||
trs "WRITING VAR SIZE VALUE OF TYPE ", name(FieldType)
|
||||
when FieldType is BitList:
|
||||
trs "BIT SEQ ", bytes(field)
|
||||
writeVarSizeType(w, toSszType(field))
|
||||
ctx.offset += w.stream.pos - initPos
|
||||
|
||||
template endRecord*(w: var SszWriter, ctx: var auto) =
|
||||
when ctx is VarSizedWriterCtx:
|
||||
finalize ctx.fixedParts
|
||||
|
||||
proc writeSeq[T](w: var SszWriter, value: seq[T])
|
||||
{.raises: [Defect, IOError].} =
|
||||
# Please note that `writeSeq` exists in order to reduce the code bloat
|
||||
# produced from generic instantiations of the unique `List[N, T]` types.
|
||||
when isFixedSize(T):
|
||||
trs "WRITING LIST WITH FIXED SIZE ELEMENTS"
|
||||
for elem in value:
|
||||
w.stream.writeFixedSized toSszType(elem)
|
||||
trs "DONE"
|
||||
else:
|
||||
trs "WRITING LIST WITH VAR SIZE ELEMENTS"
|
||||
var offset = value.len * offsetSize
|
||||
var cursor = w.stream.delayFixedSizeWrite offset
|
||||
for elem in value:
|
||||
cursor.writeFixedSized uint32(offset)
|
||||
let initPos = w.stream.pos
|
||||
w.writeVarSizeType toSszType(elem)
|
||||
offset += w.stream.pos - initPos
|
||||
finalize cursor
|
||||
trs "DONE"
|
||||
|
||||
proc writeVarSizeType(w: var SszWriter, value: auto)
|
||||
{.raises: [Defect, IOError].} =
|
||||
trs "STARTING VAR SIZE TYPE"
|
||||
|
||||
when value is List:
|
||||
# We reduce code bloat by forwarding all `List` types to a general `seq[T]`
|
||||
# proc.
|
||||
writeSeq(w, asSeq value)
|
||||
elif value is BitList:
|
||||
# ATTENTION! We can reuse `writeSeq` only as long as our BitList type is
|
||||
# implemented to internally match the binary representation of SSZ BitLists
|
||||
# in memory.
|
||||
writeSeq(w, bytes value)
|
||||
elif value is object|tuple|array:
|
||||
trs "WRITING OBJECT OR ARRAY"
|
||||
var ctx = beginRecord(w, type value)
|
||||
enumerateSubFields(value, field):
|
||||
writeField w, ctx, astToStr(field), field
|
||||
endRecord w, ctx
|
||||
else:
|
||||
unsupported type(value)
|
||||
|
||||
proc writeValue*(w: var SszWriter, x: auto)
|
||||
{.gcsafe, raises: [Defect, IOError].} =
|
||||
mixin toSszType
|
||||
type T = type toSszType(x)
|
||||
|
||||
when isFixedSize(T):
|
||||
w.stream.writeFixedSized toSszType(x)
|
||||
else:
|
||||
w.writeVarSizeType toSszType(x)
|
||||
|
||||
func sszSize*(value: auto): int {.gcsafe, raises: [Defect].}
|
||||
|
||||
func sszSizeForVarSizeList[T](value: openArray[T]): int =
|
||||
mixin toSszType
|
||||
result = len(value) * offsetSize
|
||||
for elem in value:
|
||||
result += sszSize(toSszType elem)
|
||||
|
||||
func sszSize*(value: auto): int {.gcsafe, raises: [Defect].} =
|
||||
mixin toSszType
|
||||
type T = type toSszType(value)
|
||||
|
||||
when isFixedSize(T):
|
||||
anonConst fixedPortionSize(T)
|
||||
|
||||
elif T is array|List:
|
||||
type E = ElemType(T)
|
||||
when isFixedSize(E):
|
||||
len(value) * anonConst(fixedPortionSize(E))
|
||||
elif T is HashArray:
|
||||
sszSizeForVarSizeList(value.data)
|
||||
elif T is array:
|
||||
sszSizeForVarSizeList(value)
|
||||
else:
|
||||
sszSizeForVarSizeList(asSeq value)
|
||||
|
||||
elif T is BitList:
|
||||
return len(bytes(value))
|
||||
|
||||
elif T is object|tuple:
|
||||
result = anonConst fixedPortionSize(T)
|
||||
enumInstanceSerializedFields(value, _{.used.}, field):
|
||||
type FieldType = type toSszType(field)
|
||||
when not isFixedSize(FieldType):
|
||||
result += sszSize(toSszType field)
|
||||
|
||||
else:
|
||||
unsupported T
|
||||
|
||||
proc writeValue*[T](w: var SszWriter, x: SizePrefixed[T])
|
||||
{.raises: [Defect, IOError].} =
|
||||
var cursor = w.stream.delayVarSizeWrite(Leb128.maxLen(uint64))
|
||||
let initPos = w.stream.pos
|
||||
w.writeValue T(x)
|
||||
let length = toBytes(uint64(w.stream.pos - initPos), Leb128)
|
||||
cursor.finalWrite length.toOpenArray()
|
||||
|
||||
proc readValue*[T](r: var SszReader, val: var T)
|
||||
{.raises: [Defect, SszError, IOError].} =
|
||||
when isFixedSize(T):
|
||||
const minimalSize = fixedPortionSize(T)
|
||||
if r.stream.readable(minimalSize):
|
||||
readSszValue(r.stream.read(minimalSize), val)
|
||||
else:
|
||||
raise newException(MalformedSszError, "SSZ input of insufficient size")
|
||||
else:
|
||||
# TODO(zah) Read the fixed portion first and precisely measure the
|
||||
# size of the dynamic portion to consume the right number of bytes.
|
||||
readSszValue(r.stream.read(r.stream.len.get), val)
|
||||
|
||||
proc readSszBytes*[T](data: openArray[byte], val: var T) {.
|
||||
raises: [Defect, MalformedSszError, SszSizeMismatchError].} =
|
||||
when isFixedSize(T):
|
||||
const minimalSize = fixedPortionSize(T)
|
||||
if data.len < minimalSize:
|
||||
raise newException(MalformedSszError, "SSZ input of insufficient size")
|
||||
|
||||
readSszValue(data, val)
|
|
@ -0,0 +1,258 @@
|
|||
# nim-eth - Limited SSZ implementation
|
||||
# Copyright (c) 2018-2021 Status Research & Development GmbH
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import
|
||||
std/[tables, options, typetraits, strformat],
|
||||
stew/shims/macros, stew/[byteutils, bitops2, objects],
|
||||
serialization/[object_serialization, errors],
|
||||
./bitseqs
|
||||
|
||||
export bitseqs
|
||||
|
||||
const
|
||||
offsetSize* = 4
|
||||
bytesPerChunk* = 32
|
||||
|
||||
type
|
||||
UintN* = SomeUnsignedInt
|
||||
BasicType* = bool|UintN
|
||||
|
||||
Limit* = int64
|
||||
|
||||
List*[T; maxLen: static Limit] = distinct seq[T]
|
||||
BitList*[maxLen: static Limit] = distinct BitSeq
|
||||
|
||||
# Note for readers:
|
||||
# We use `array` for `Vector` and
|
||||
# `BitArray` for `BitVector`
|
||||
|
||||
SszError* = object of SerializationError
|
||||
|
||||
MalformedSszError* = object of SszError
|
||||
|
||||
SszSizeMismatchError* = object of SszError
|
||||
deserializedType*: cstring
|
||||
actualSszSize*: int
|
||||
elementSize*: int
|
||||
|
||||
template asSeq*(x: List): auto = distinctBase(x)
|
||||
|
||||
template init*[T](L: type List, x: seq[T], N: static Limit): auto =
|
||||
List[T, N](x)
|
||||
|
||||
template init*[T, N](L: type List[T, N], x: seq[T]): auto =
|
||||
List[T, N](x)
|
||||
|
||||
template `$`*(x: List): auto = $(distinctBase x)
|
||||
template len*(x: List): auto = len(distinctBase x)
|
||||
template low*(x: List): auto = low(distinctBase x)
|
||||
template high*(x: List): auto = high(distinctBase x)
|
||||
template `[]`*(x: List, idx: auto): untyped = distinctBase(x)[idx]
|
||||
template `[]=`*(x: var List, idx: auto, val: auto) = distinctBase(x)[idx] = val
|
||||
template `==`*(a, b: List): bool = distinctBase(a) == distinctBase(b)
|
||||
|
||||
template `&`*(a, b: List): auto = (type(a)(distinctBase(a) & distinctBase(b)))
|
||||
|
||||
template items* (x: List): untyped = items(distinctBase x)
|
||||
template pairs* (x: List): untyped = pairs(distinctBase x)
|
||||
template mitems*(x: var List): untyped = mitems(distinctBase x)
|
||||
template mpairs*(x: var List): untyped = mpairs(distinctBase x)
|
||||
|
||||
template contains* (x: List, val: auto): untyped = contains(distinctBase x, val)
|
||||
|
||||
proc add*(x: var List, val: auto): bool =
|
||||
if x.len < x.maxLen:
|
||||
add(distinctBase x, val)
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
||||
proc setLen*(x: var List, newLen: int): bool =
|
||||
if newLen <= x.maxLen:
|
||||
setLen(distinctBase x, newLen)
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
||||
template init*(L: type BitList, x: seq[byte], N: static Limit): auto =
|
||||
BitList[N](data: x)
|
||||
|
||||
template init*[N](L: type BitList[N], x: seq[byte]): auto =
|
||||
L(data: x)
|
||||
|
||||
template init*(T: type BitList, len: int): auto = T init(BitSeq, len)
|
||||
template len*(x: BitList): auto = len(BitSeq(x))
|
||||
template bytes*(x: BitList): auto = seq[byte](x)
|
||||
template `[]`*(x: BitList, idx: auto): auto = BitSeq(x)[idx]
|
||||
template `[]=`*(x: var BitList, idx: auto, val: bool) = BitSeq(x)[idx] = val
|
||||
template `==`*(a, b: BitList): bool = BitSeq(a) == BitSeq(b)
|
||||
template setBit*(x: var BitList, idx: Natural) = setBit(BitSeq(x), idx)
|
||||
template clearBit*(x: var BitList, idx: Natural) = clearBit(BitSeq(x), idx)
|
||||
template overlaps*(a, b: BitList): bool = overlaps(BitSeq(a), BitSeq(b))
|
||||
template incl*(a: var BitList, b: BitList) = incl(BitSeq(a), BitSeq(b))
|
||||
template isSubsetOf*(a, b: BitList): bool = isSubsetOf(BitSeq(a), BitSeq(b))
|
||||
template isZeros*(x: BitList): bool = isZeros(BitSeq(x))
|
||||
template countOnes*(x: BitList): int = countOnes(BitSeq(x))
|
||||
template countZeros*(x: BitList): int = countZeros(BitSeq(x))
|
||||
template countOverlap*(x, y: BitList): int = countOverlap(BitSeq(x), BitSeq(y))
|
||||
template `$`*(a: BitList): string = $(BitSeq(a))
|
||||
|
||||
iterator items*(x: BitList): bool =
|
||||
for i in 0 ..< x.len:
|
||||
yield x[i]
|
||||
|
||||
macro unsupported*(T: typed): untyped =
|
||||
# TODO: {.fatal.} breaks compilation even in `compiles()` context,
|
||||
# so we use this macro instead. It's also much better at figuring
|
||||
# out the actual type that was used in the instantiation.
|
||||
# File both problems as issues.
|
||||
error "SSZ serialization of the type " & humaneTypeName(T) & " is not supported"
|
||||
|
||||
template ElemType*(T: type array): untyped =
|
||||
type(default(T)[low(T)])
|
||||
|
||||
template ElemType*(T: type seq): untyped =
|
||||
type(default(T)[0])
|
||||
|
||||
template ElemType*(T: type List): untyped =
|
||||
T.T
|
||||
|
||||
func isFixedSize*(T0: type): bool {.compileTime.} =
|
||||
mixin toSszType, enumAllSerializedFields
|
||||
|
||||
type T = type toSszType(declval T0)
|
||||
|
||||
when T is BasicType:
|
||||
return true
|
||||
elif T is array:
|
||||
return isFixedSize(ElemType(T))
|
||||
elif T is object|tuple:
|
||||
enumAllSerializedFields(T):
|
||||
when not isFixedSize(FieldType):
|
||||
return false
|
||||
return true
|
||||
|
||||
func fixedPortionSize*(T0: type): int {.compileTime.} =
|
||||
mixin enumAllSerializedFields, toSszType
|
||||
|
||||
type T = type toSszType(declval T0)
|
||||
|
||||
when T is BasicType: sizeof(T)
|
||||
elif T is array:
|
||||
type E = ElemType(T)
|
||||
when isFixedSize(E): int(len(T)) * fixedPortionSize(E)
|
||||
else: int(len(T)) * offsetSize
|
||||
elif T is object|tuple:
|
||||
enumAllSerializedFields(T):
|
||||
when isFixedSize(FieldType):
|
||||
result += fixedPortionSize(FieldType)
|
||||
else:
|
||||
result += offsetSize
|
||||
else:
|
||||
unsupported T0
|
||||
|
||||
# TODO This should have been an iterator, but the VM can't compile the
|
||||
# code due to "too many registers required".
|
||||
proc fieldInfos*(RecordType: type): seq[tuple[name: string,
|
||||
offset: int,
|
||||
fixedSize: int,
|
||||
branchKey: string]] =
|
||||
mixin enumAllSerializedFields
|
||||
|
||||
var
|
||||
offsetInBranch = {"": 0}.toTable
|
||||
nestedUnder = initTable[string, string]()
|
||||
|
||||
enumAllSerializedFields(RecordType):
|
||||
const
|
||||
isFixed = isFixedSize(FieldType)
|
||||
fixedSize = when isFixed: fixedPortionSize(FieldType)
|
||||
else: 0
|
||||
branchKey = when fieldCaseDiscriminator.len == 0: ""
|
||||
else: fieldCaseDiscriminator & ":" & $fieldCaseBranches
|
||||
fieldSize = when isFixed: fixedSize
|
||||
else: offsetSize
|
||||
|
||||
nestedUnder[fieldName] = branchKey
|
||||
|
||||
var fieldOffset: int
|
||||
offsetInBranch.withValue(branchKey, val):
|
||||
fieldOffset = val[]
|
||||
val[] += fieldSize
|
||||
do:
|
||||
try:
|
||||
let parentBranch = nestedUnder.getOrDefault(fieldCaseDiscriminator, "")
|
||||
fieldOffset = offsetInBranch[parentBranch]
|
||||
offsetInBranch[branchKey] = fieldOffset + fieldSize
|
||||
except KeyError as e:
|
||||
raiseAssert e.msg
|
||||
|
||||
result.add((fieldName, fieldOffset, fixedSize, branchKey))
|
||||
|
||||
func getFieldBoundingOffsetsImpl(RecordType: type, fieldName: static string):
|
||||
tuple[fieldOffset, nextFieldOffset: int, isFirstOffset: bool]
|
||||
{.compileTime.} =
|
||||
result = (-1, -1, false)
|
||||
var fieldBranchKey: string
|
||||
var isFirstOffset = true
|
||||
|
||||
for f in fieldInfos(RecordType):
|
||||
if fieldName == f.name:
|
||||
result[0] = f.offset
|
||||
if f.fixedSize > 0:
|
||||
result[1] = result[0] + f.fixedSize
|
||||
return
|
||||
else:
|
||||
fieldBranchKey = f.branchKey
|
||||
result.isFirstOffset = isFirstOffset
|
||||
|
||||
elif result[0] != -1 and
|
||||
f.fixedSize == 0 and
|
||||
f.branchKey == fieldBranchKey:
|
||||
# We have found the next variable sized field
|
||||
result[1] = f.offset
|
||||
return
|
||||
|
||||
if f.fixedSize == 0:
|
||||
isFirstOffset = false
|
||||
|
||||
func getFieldBoundingOffsets*(RecordType: type, fieldName: static string):
|
||||
tuple[fieldOffset, nextFieldOffset: int, isFirstOffset: bool]
|
||||
{.compileTime.} =
|
||||
## Returns the start and end offsets of a field.
|
||||
##
|
||||
## For fixed-size fields, the start offset points to the first
|
||||
## byte of the field and the end offset points to 1 byte past the
|
||||
## end of the field.
|
||||
##
|
||||
## For variable-size fields, the returned offsets point to the
|
||||
## statically known positions of the 32-bit offset values written
|
||||
## within the SSZ object. You must read the 32-bit values stored
|
||||
## at the these locations in order to obtain the actual offsets.
|
||||
##
|
||||
## For variable-size fields, the end offset may be -1 when the
|
||||
## designated field is the last variable sized field within the
|
||||
## object. Then the SSZ object boundary known at run-time marks
|
||||
## the end of the variable-size field.
|
||||
type T = RecordType
|
||||
anonConst getFieldBoundingOffsetsImpl(T, fieldName)
|
||||
|
||||
template enumerateSubFields*(holder, fieldVar, body: untyped) =
|
||||
when holder is array:
|
||||
for fieldVar in holder: body
|
||||
else:
|
||||
enumInstanceSerializedFields(holder, _{.used.}, fieldVar): body
|
||||
|
||||
method formatMsg*(
|
||||
err: ref SszSizeMismatchError,
|
||||
filename: string): string {.gcsafe, raises: [Defect].} =
|
||||
try:
|
||||
&"SSZ size mismatch, element {err.elementSize}, actual {err.actualSszSize}, type {err.deserializedType}, file {filename}"
|
||||
except CatchableError:
|
||||
"SSZ size mismatch"
|
|
@ -0,0 +1,5 @@
|
|||
{.used.}
|
||||
|
||||
import
|
||||
./test_portal_encoding,
|
||||
./test_portal
|
|
@ -1,5 +1,6 @@
|
|||
import
|
||||
./all_discv5_tests,
|
||||
./all_portal_tests,
|
||||
./test_auth,
|
||||
./test_crypt,
|
||||
./test_discovery,
|
||||
|
|
|
@ -645,10 +645,13 @@ procSuite "Discovery v5 Tests":
|
|||
rng, PrivateKey.random(rng[]), localAddress(20303))
|
||||
talkProtocol = "echo".toBytes()
|
||||
|
||||
proc handler(request: seq[byte]): seq[byte] {.gcsafe, raises: [Defect].} =
|
||||
proc handler(protocol: TalkProtocol, request: seq[byte]): seq[byte]
|
||||
{.gcsafe, raises: [Defect].} =
|
||||
request
|
||||
|
||||
check node2.registerTalkProtocol(talkProtocol, handler).isOk()
|
||||
let echoProtocol = TalkProtocol(protocolHandler: handler)
|
||||
|
||||
check node2.registerTalkProtocol(talkProtocol, echoProtocol).isOk()
|
||||
let talkresp = await discv5_protocol.talkreq(node1, node2.localNode,
|
||||
talkProtocol, "hello".toBytes())
|
||||
|
||||
|
@ -667,13 +670,16 @@ procSuite "Discovery v5 Tests":
|
|||
rng, PrivateKey.random(rng[]), localAddress(20303))
|
||||
talkProtocol = "echo".toBytes()
|
||||
|
||||
proc handler(request: seq[byte]): seq[byte] {.gcsafe, raises: [Defect].} =
|
||||
proc handler(protocol: TalkProtocol, request: seq[byte]): seq[byte]
|
||||
{.gcsafe, raises: [Defect].} =
|
||||
request
|
||||
|
||||
let echoProtocol = TalkProtocol(protocolHandler: handler)
|
||||
|
||||
check:
|
||||
node2.registerTalkProtocol(talkProtocol, handler).isOk()
|
||||
node2.registerTalkProtocol(talkProtocol, handler).isErr()
|
||||
node2.registerTalkProtocol("test".toBytes(), handler).isOk()
|
||||
node2.registerTalkProtocol(talkProtocol, echoProtocol).isOk()
|
||||
node2.registerTalkProtocol(talkProtocol, echoProtocol).isErr()
|
||||
node2.registerTalkProtocol("test".toBytes(), echoProtocol).isOk()
|
||||
|
||||
await node1.closeWait()
|
||||
await node2.closeWait()
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
# nim-eth - Portal Network
|
||||
# Copyright (c) 2021 Status Research & Development GmbH
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
{.used.}
|
||||
|
||||
import
|
||||
chronos, testutils/unittests,
|
||||
../../eth/keys, # for rng
|
||||
../../eth/p2p/discoveryv5/protocol as discv5_protocol,
|
||||
../../eth/p2p/portal/protocol as portal_protocol,
|
||||
./discv5_test_helper
|
||||
|
||||
proc random(T: type UInt256, rng: var BrHmacDrbgContext): T =
|
||||
var key: UInt256
|
||||
brHmacDrbgGenerate(addr rng, addr key, csize_t(sizeof(key)))
|
||||
|
||||
key
|
||||
|
||||
procSuite "Portal Tests":
|
||||
let rng = newRng()
|
||||
|
||||
asyncTest "Portal Ping/Pong":
|
||||
let
|
||||
node1 = initDiscoveryNode(
|
||||
rng, PrivateKey.random(rng[]), localAddress(20302))
|
||||
node2 = initDiscoveryNode(
|
||||
rng, PrivateKey.random(rng[]), localAddress(20303))
|
||||
|
||||
proto1 = PortalProtocol.new(node1)
|
||||
proto2 = PortalProtocol.new(node2)
|
||||
|
||||
let pong = await proto1.ping(proto2.baseProtocol.localNode)
|
||||
|
||||
check:
|
||||
pong.isOk()
|
||||
pong.get().enrSeq == 1'u64
|
||||
pong.get().dataRadius == UInt256.high()
|
||||
|
||||
await node1.closeWait()
|
||||
await node2.closeWait()
|
||||
|
||||
asyncTest "Portal FindNode/Nodes":
|
||||
let
|
||||
node1 = initDiscoveryNode(
|
||||
rng, PrivateKey.random(rng[]), localAddress(20302))
|
||||
node2 = initDiscoveryNode(
|
||||
rng, PrivateKey.random(rng[]), localAddress(20303))
|
||||
|
||||
proto1 = PortalProtocol.new(node1)
|
||||
proto2 = PortalProtocol.new(node2)
|
||||
|
||||
block: # Find itself
|
||||
let nodes = await proto1.findNode(proto2.baseProtocol.localNode,
|
||||
List[uint16, 256](@[0'u16]))
|
||||
|
||||
check:
|
||||
nodes.isOk()
|
||||
nodes.get().total == 1'u8
|
||||
nodes.get().enrs.len() == 1
|
||||
|
||||
block: # Find nothing
|
||||
let nodes = await proto1.findNode(proto2.baseProtocol.localNode,
|
||||
List[uint16, 256](@[]))
|
||||
|
||||
check:
|
||||
nodes.isOk()
|
||||
nodes.get().total == 1'u8
|
||||
nodes.get().enrs.len() == 0
|
||||
|
||||
block: # Find for distance
|
||||
# TODO: Add test when implemented
|
||||
discard
|
||||
|
||||
await node1.closeWait()
|
||||
await node2.closeWait()
|
||||
|
||||
asyncTest "Portal FindContent/FoundContent":
|
||||
let
|
||||
node1 = initDiscoveryNode(
|
||||
rng, PrivateKey.random(rng[]), localAddress(20302))
|
||||
node2 = initDiscoveryNode(
|
||||
rng, PrivateKey.random(rng[]), localAddress(20303))
|
||||
|
||||
proto1 = PortalProtocol.new(node1)
|
||||
proto2 = PortalProtocol.new(node2)
|
||||
|
||||
let contentKey = ByteList(@(UInt256.random(rng[]).toBytes()))
|
||||
|
||||
let foundContent = await proto1.findContent(proto2.baseProtocol.localNode,
|
||||
contentKey)
|
||||
|
||||
check:
|
||||
foundContent.isOk()
|
||||
# TODO: adjust when implemented
|
||||
foundContent.get().enrs.len() == 0
|
||||
foundContent.get().payload.len() == 0
|
||||
|
||||
await node1.closeWait()
|
||||
await node2.closeWait()
|
|
@ -0,0 +1,156 @@
|
|||
# nim-eth - Portal Network
|
||||
# Copyright (c) 2021 Status Research & Development GmbH
|
||||
# Licensed and distributed under either of
|
||||
# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
|
||||
# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
|
||||
# at your option. This file may not be copied, modified, or distributed except according to those terms.
|
||||
|
||||
{.used.}
|
||||
|
||||
import
|
||||
std/unittest,
|
||||
stint, stew/[byteutils, results],
|
||||
../../eth/p2p/portal/messages
|
||||
|
||||
suite "Portal Protocol Message Encodings":
|
||||
test "Ping Request":
|
||||
var dataRadius: UInt256
|
||||
let
|
||||
enrSeq = 1'u64
|
||||
p = PingMessage(enrSeq: enrSeq, dataRadius: dataRadius)
|
||||
|
||||
let encoded = encodeMessage(p)
|
||||
check encoded.toHex ==
|
||||
"0101000000000000000000000000000000000000000000000000000000000000000000000000000000"
|
||||
let decoded = decodeMessage(encoded)
|
||||
check decoded.isOk()
|
||||
|
||||
let message = decoded.get()
|
||||
check:
|
||||
message.kind == ping
|
||||
message.ping.enrSeq == enrSeq
|
||||
message.ping.dataRadius == dataRadius
|
||||
|
||||
test "Pong Response":
|
||||
var dataRadius: UInt256
|
||||
let
|
||||
enrSeq = 1'u64
|
||||
p = PongMessage(enrSeq: enrSeq, dataRadius: dataRadius)
|
||||
|
||||
let encoded = encodeMessage(p)
|
||||
check encoded.toHex ==
|
||||
"0201000000000000000000000000000000000000000000000000000000000000000000000000000000"
|
||||
let decoded = decodeMessage(encoded)
|
||||
check decoded.isOk()
|
||||
|
||||
let message = decoded.get()
|
||||
check:
|
||||
message.kind == pong
|
||||
message.pong.enrSeq == enrSeq
|
||||
message.pong.dataRadius == dataRadius
|
||||
|
||||
test "FindNode Request":
|
||||
let
|
||||
distances = List[uint16, 256](@[0x0100'u16])
|
||||
fn = FindNodeMessage(distances: distances)
|
||||
|
||||
let encoded = encodeMessage(fn)
|
||||
check encoded.toHex == "03040000000001"
|
||||
|
||||
let decoded = decodeMessage(encoded)
|
||||
check decoded.isOk()
|
||||
|
||||
let message = decoded.get()
|
||||
check:
|
||||
message.kind == findnode
|
||||
message.findnode.distances == distances
|
||||
|
||||
test "Nodes Response (empty)":
|
||||
let
|
||||
total = 0x1'u8
|
||||
n = NodesMessage(total: total)
|
||||
|
||||
let encoded = encodeMessage(n)
|
||||
check encoded.toHex == "040105000000"
|
||||
|
||||
let decoded = decodeMessage(encoded)
|
||||
check decoded.isOk()
|
||||
|
||||
let message = decoded.get()
|
||||
check:
|
||||
message.kind == nodes
|
||||
message.nodes.total == total
|
||||
message.nodes.enrs.len() == 0
|
||||
|
||||
test "FindContent Request":
|
||||
let
|
||||
contentKey = ByteList(@[byte 0x01, 0x02, 0x03])
|
||||
fn = FindContentMessage(contentKey: contentKey)
|
||||
|
||||
let encoded = encodeMessage(fn)
|
||||
check encoded.toHex == "0504000000010203"
|
||||
|
||||
let decoded = decodeMessage(encoded)
|
||||
check decoded.isOk()
|
||||
|
||||
let message = decoded.get()
|
||||
check:
|
||||
message.kind == findcontent
|
||||
message.findcontent.contentKey == contentKey
|
||||
|
||||
test "FoundContent Response (empty enrs)":
|
||||
let
|
||||
enrs = List[ByteList, 32](@[])
|
||||
payload = ByteList(@[byte 0x01, 0x02, 0x03])
|
||||
n = FoundContentMessage(enrs: enrs, payload: payload)
|
||||
|
||||
let encoded = encodeMessage(n)
|
||||
check encoded.toHex == "060800000008000000010203"
|
||||
|
||||
let decoded = decodeMessage(encoded)
|
||||
check decoded.isOk()
|
||||
|
||||
let message = decoded.get()
|
||||
check:
|
||||
message.kind == foundcontent
|
||||
message.foundcontent.enrs.len() == 0
|
||||
message.foundcontent.payload == payload
|
||||
|
||||
test "Advertise Request":
|
||||
let
|
||||
contentKeys = List[ByteList, 32](List(@[ByteList(@[byte 0x01, 0x02, 0x03])]))
|
||||
am = AdvertiseMessage(contentKeys)
|
||||
# am = AdvertiseMessage(contentKeys: contentKeys)
|
||||
|
||||
let encoded = encodeMessage(am)
|
||||
check encoded.toHex == "0704000000010203"
|
||||
# "070400000004000000010203"
|
||||
|
||||
let decoded = decodeMessage(encoded)
|
||||
check decoded.isOk()
|
||||
|
||||
let message = decoded.get()
|
||||
check:
|
||||
message.kind == advertise
|
||||
message.advertise == contentKeys
|
||||
# message.advertise.contentKeys == contentKeys
|
||||
|
||||
test "RequestProofs Response": # That sounds weird
|
||||
let
|
||||
connectionId = List[byte, 4](@[byte 0x01, 0x02, 0x03, 0x04])
|
||||
contentKeys =
|
||||
List[ByteList, 32](List(@[ByteList(@[byte 0x01, 0x02, 0x03])]))
|
||||
n = RequestProofsMessage(connectionId: connectionId,
|
||||
contentKeys: contentKeys)
|
||||
|
||||
let encoded = encodeMessage(n)
|
||||
check encoded.toHex == "08080000000c0000000102030404000000010203"
|
||||
|
||||
let decoded = decodeMessage(encoded)
|
||||
check decoded.isOk()
|
||||
|
||||
let message = decoded.get()
|
||||
check:
|
||||
message.kind == requestproofs
|
||||
message.requestproofs.connectionId == connectionId
|
||||
message.requestproofs.contentKeys == contentKeys
|
Loading…
Reference in New Issue