Remove all `Result.get()`s & `Option` -> `Opt` (#902)

Co-authored-by: Ludovic Chenut <ludovic@status.im>
Co-authored-by: Diego <diego@status.im>
This commit is contained in:
Tanguy 2023-06-28 16:44:58 +02:00 committed by GitHub
parent 1c4d0832ce
commit 66f9dc9167
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
48 changed files with 820 additions and 973 deletions

View File

@ -54,7 +54,7 @@ type
protoVersion: string protoVersion: string
agentVersion: string agentVersion: string
nameResolver: NameResolver nameResolver: NameResolver
peerStoreCapacity: Option[int] peerStoreCapacity: Opt[int]
autonat: bool autonat: bool
circuitRelay: Relay circuitRelay: Relay
rdv: RendezVous rdv: RendezVous
@ -170,7 +170,7 @@ proc withMaxConnsPerPeer*(b: SwitchBuilder, maxConnsPerPeer: int): SwitchBuilder
b b
proc withPeerStore*(b: SwitchBuilder, capacity: int): SwitchBuilder {.public.} = proc withPeerStore*(b: SwitchBuilder, capacity: int): SwitchBuilder {.public.} =
b.peerStoreCapacity = some(capacity) b.peerStoreCapacity = Opt.some(capacity)
b b
proc withProtoVersion*(b: SwitchBuilder, protoVersion: string): SwitchBuilder {.public.} = proc withProtoVersion*(b: SwitchBuilder, protoVersion: string): SwitchBuilder {.public.} =
@ -242,9 +242,9 @@ proc build*(b: SwitchBuilder): Switch
if isNil(b.rng): if isNil(b.rng):
b.rng = newRng() b.rng = newRng()
let peerStore = let peerStore = block:
if isSome(b.peerStoreCapacity): b.peerStoreCapacity.withValue(capacity):
PeerStore.new(identify, b.peerStoreCapacity.get()) PeerStore.new(identify, capacity)
else: else:
PeerStore.new(identify) PeerStore.new(identify)
@ -316,7 +316,7 @@ proc newStandardSwitch*(
.withNameResolver(nameResolver) .withNameResolver(nameResolver)
.withNoise() .withNoise()
if privKey.isSome(): privKey.withValue(pkey):
b = b.withPrivateKey(privKey.get()) b = b.withPrivateKey(pkey)
b.build() b.build()

View File

@ -276,9 +276,6 @@ proc `$`*(cid: Cid): string =
BTCBase58.encode(cid.data.buffer) BTCBase58.encode(cid.data.buffer)
elif cid.cidver == CIDv1: elif cid.cidver == CIDv1:
let res = MultiBase.encode("base58btc", cid.data.buffer) let res = MultiBase.encode("base58btc", cid.data.buffer)
if res.isOk(): res.get("")
res.get()
else:
""
else: else:
"" ""

View File

@ -9,7 +9,7 @@
{.push raises: [].} {.push raises: [].}
import std/[options, tables, sequtils, sets] import std/[tables, sequtils, sets]
import pkg/[chronos, chronicles, metrics] import pkg/[chronos, chronicles, metrics]
import peerinfo, import peerinfo,
peerstore, peerstore,

View File

@ -468,7 +468,7 @@ proc init*[T: PrivateKey|PublicKey](key: var T, data: openArray[byte]): bool =
var pb = initProtoBuffer(@data) var pb = initProtoBuffer(@data)
let r1 = pb.getField(1, id) let r1 = pb.getField(1, id)
let r2 = pb.getField(2, buffer) let r2 = pb.getField(2, buffer)
if not(r1.isOk() and r1.get() and r2.isOk() and r2.get()): if not(r1.get(false) and r2.get(false)):
false false
else: else:
if cast[int8](id) notin SupportedSchemesInt or len(buffer) <= 0: if cast[int8](id) notin SupportedSchemesInt or len(buffer) <= 0:
@ -973,9 +973,8 @@ proc decodeProposal*(message: seq[byte], nonce, pubkey: var seq[byte],
let r4 = pb.getField(4, ciphers) let r4 = pb.getField(4, ciphers)
let r5 = pb.getField(5, hashes) let r5 = pb.getField(5, hashes)
r1.isOk() and r1.get() and r2.isOk() and r2.get() and r1.get(false) and r2.get(false) and r3.get(false) and
r3.isOk() and r3.get() and r4.isOk() and r4.get() and r4.get(false) and r5.get(false)
r5.isOk() and r5.get()
proc createExchange*(epubkey, signature: openArray[byte]): seq[byte] = proc createExchange*(epubkey, signature: openArray[byte]): seq[byte] =
## Create SecIO exchange message using ephemeral public key ``epubkey`` and ## Create SecIO exchange message using ephemeral public key ``epubkey`` and
@ -995,7 +994,7 @@ proc decodeExchange*(message: seq[byte],
var pb = initProtoBuffer(message) var pb = initProtoBuffer(message)
let r1 = pb.getField(1, pubkey) let r1 = pb.getField(1, pubkey)
let r2 = pb.getField(2, signature) let r2 = pb.getField(2, signature)
r1.isOk() and r1.get() and r2.isOk() and r2.get() r1.get(false) and r2.get(false)
## Serialization/Deserialization helpers ## Serialization/Deserialization helpers

View File

@ -25,7 +25,7 @@
## 5. LocalAddress: optional bytes ## 5. LocalAddress: optional bytes
## 6. RemoteAddress: optional bytes ## 6. RemoteAddress: optional bytes
## 7. Message: required bytes ## 7. Message: required bytes
import os, options import os
import nimcrypto/utils, stew/endians2 import nimcrypto/utils, stew/endians2
import protobuf/minprotobuf, stream/connection, protocols/secure/secure, import protobuf/minprotobuf, stream/connection, protocols/secure/secure,
multiaddress, peerid, varint, muxers/mplex/coder multiaddress, peerid, varint, muxers/mplex/coder
@ -33,7 +33,7 @@ import protobuf/minprotobuf, stream/connection, protocols/secure/secure,
from times import getTime, toUnix, fromUnix, nanosecond, format, Time, from times import getTime, toUnix, fromUnix, nanosecond, format, Time,
NanosecondRange, initTime NanosecondRange, initTime
from strutils import toHex, repeat from strutils import toHex, repeat
export peerid, options, multiaddress export peerid, multiaddress
type type
FlowDirection* = enum FlowDirection* = enum
@ -43,10 +43,10 @@ type
timestamp*: uint64 timestamp*: uint64
direction*: FlowDirection direction*: FlowDirection
message*: seq[byte] message*: seq[byte]
seqID*: Option[uint64] seqID*: Opt[uint64]
mtype*: Option[uint64] mtype*: Opt[uint64]
local*: Option[MultiAddress] local*: Opt[MultiAddress]
remote*: Option[MultiAddress] remote*: Opt[MultiAddress]
const const
libp2p_dump_dir* {.strdefine.} = "nim-libp2p" libp2p_dump_dir* {.strdefine.} = "nim-libp2p"
@ -72,7 +72,8 @@ proc dumpMessage*(conn: SecureConn, direction: FlowDirection,
var pb = initProtoBuffer(options = {WithVarintLength}) var pb = initProtoBuffer(options = {WithVarintLength})
pb.write(2, getTimestamp()) pb.write(2, getTimestamp())
pb.write(4, uint64(direction)) pb.write(4, uint64(direction))
pb.write(6, conn.observedAddr) conn.observedAddr.withValue(oaddr):
pb.write(6, oaddr)
pb.write(7, data) pb.write(7, data)
pb.finish() pb.finish()
@ -100,7 +101,7 @@ proc dumpMessage*(conn: SecureConn, direction: FlowDirection,
finally: finally:
close(handle) close(handle)
proc decodeDumpMessage*(data: openArray[byte]): Option[ProtoMessage] = proc decodeDumpMessage*(data: openArray[byte]): Opt[ProtoMessage] =
## Decode protobuf's message ProtoMessage from array of bytes ``data``. ## Decode protobuf's message ProtoMessage from array of bytes ``data``.
var var
pb = initProtoBuffer(data) pb = initProtoBuffer(data)
@ -108,13 +109,12 @@ proc decodeDumpMessage*(data: openArray[byte]): Option[ProtoMessage] =
ma1, ma2: MultiAddress ma1, ma2: MultiAddress
pmsg: ProtoMessage pmsg: ProtoMessage
let res2 = pb.getField(2, pmsg.timestamp) let
if res2.isErr() or not(res2.get()): r2 = pb.getField(2, pmsg.timestamp)
return none[ProtoMessage]() r4 = pb.getField(4, value)
r7 = pb.getField(7, pmsg.message)
let res4 = pb.getField(4, value) if not r2.get(false) or not r4.get(false) or not r7.get(false):
if res4.isErr() or not(res4.get()): return Opt.none(ProtoMessage)
return none[ProtoMessage]()
# `case` statement could not work here with an error "selector must be of an # `case` statement could not work here with an error "selector must be of an
# ordinal type, float or string" # ordinal type, float or string"
@ -124,30 +124,27 @@ proc decodeDumpMessage*(data: openArray[byte]): Option[ProtoMessage] =
elif value == uint64(Incoming): elif value == uint64(Incoming):
Incoming Incoming
else: else:
return none[ProtoMessage]() return Opt.none(ProtoMessage)
let res7 = pb.getField(7, pmsg.message) let r1 = pb.getField(1, value)
if res7.isErr() or not(res7.get()): if r1.get(false):
return none[ProtoMessage]() pmsg.seqID = Opt.some(value)
value = 0'u64 let r3 = pb.getField(3, value)
let res1 = pb.getField(1, value) if r3.get(false):
if res1.isOk() and res1.get(): pmsg.mtype = Opt.some(value)
pmsg.seqID = some(value)
value = 0'u64
let res3 = pb.getField(3, value)
if res3.isOk() and res3.get():
pmsg.mtype = some(value)
let res5 = pb.getField(5, ma1)
if res5.isOk() and res5.get():
pmsg.local = some(ma1)
let res6 = pb.getField(6, ma2)
if res6.isOk() and res6.get():
pmsg.remote = some(ma2)
some(pmsg) let
r5 = pb.getField(5, ma1)
r6 = pb.getField(6, ma2)
if r5.get(false):
pmsg.local = Opt.some(ma1)
if r6.get(false):
pmsg.remote = Opt.some(ma2)
iterator messages*(data: seq[byte]): Option[ProtoMessage] = Opt.some(pmsg)
iterator messages*(data: seq[byte]): Opt[ProtoMessage] =
## Iterate over sequence of bytes and decode all the ``ProtoMessage`` ## Iterate over sequence of bytes and decode all the ``ProtoMessage``
## messages we found. ## messages we found.
var value: uint64 var value: uint64
@ -242,27 +239,19 @@ proc toString*(msg: ProtoMessage, dump = true): string =
" >> " " >> "
let address = let address =
block: block:
let local = let local = block:
if msg.local.isSome(): msg.local.withValue(loc): "[" & $loc & "]"
"[" & $(msg.local.get()) & "]" else: "[LOCAL]"
else: let remote = block:
"[LOCAL]" msg.remote.withValue(rem): "[" & $rem & "]"
let remote = else: "[REMOTE]"
if msg.remote.isSome():
"[" & $(msg.remote.get()) & "]"
else:
"[REMOTE]"
local & direction & remote local & direction & remote
let seqid = let seqid = block:
if msg.seqID.isSome(): msg.seqID.wihValue(seqid): "seqID = " & $seqid & " "
"seqID = " & $(msg.seqID.get()) & " " else: ""
else: let mtype = block:
"" msg.mtype.withValue(typ): "type = " & $typ & " "
let mtype = else: ""
if msg.mtype.isSome():
"type = " & $(msg.mtype.get()) & " "
else:
""
res.add(" ") res.add(" ")
res.add(address) res.add(address)
res.add(" ") res.add(" ")

View File

@ -150,7 +150,7 @@ proc dialAndUpgrade(
if not isNil(result): if not isNil(result):
return result return result
proc tryReusingConnection(self: Dialer, peerId: PeerId): Future[Opt[Muxer]] {.async.} = proc tryReusingConnection(self: Dialer, peerId: PeerId): Opt[Muxer] =
let muxer = self.connManager.selectMuxer(peerId) let muxer = self.connManager.selectMuxer(peerId)
if muxer == nil: if muxer == nil:
return Opt.none(Muxer) return Opt.none(Muxer)
@ -174,10 +174,10 @@ proc internalConnect(
try: try:
await lock.acquire() await lock.acquire()
if peerId.isSome and reuseConnection: if reuseConnection:
let muxOpt = await self.tryReusingConnection(peerId.get()) peerId.withValue(peerId):
if muxOpt.isSome: self.tryReusingConnection(peerId).withValue(mux):
return muxOpt.get() return mux
let slot = self.connManager.getOutgoingSlot(forceDial) let slot = self.connManager.getOutgoingSlot(forceDial)
let muxed = let muxed =
@ -225,20 +225,20 @@ method connect*(
allowUnknownPeerId = false): Future[PeerId] {.async.} = allowUnknownPeerId = false): Future[PeerId] {.async.} =
## Connects to a peer and retrieve its PeerId ## Connects to a peer and retrieve its PeerId
let fullAddress = parseFullAddress(address) parseFullAddress(address).toOpt().withValue(fullAddress):
if fullAddress.isOk:
return (await self.internalConnect( return (await self.internalConnect(
Opt.some(fullAddress.get()[0]), Opt.some(fullAddress[0]),
@[fullAddress.get()[1]], @[fullAddress[1]],
false)).connection.peerId
else:
if allowUnknownPeerId == false:
raise newException(DialFailedError, "Address without PeerID and unknown peer id disabled!")
return (await self.internalConnect(
Opt.none(PeerId),
@[address],
false)).connection.peerId false)).connection.peerId
if allowUnknownPeerId == false:
raise newException(DialFailedError, "Address without PeerID and unknown peer id disabled!")
return (await self.internalConnect(
Opt.none(PeerId),
@[address],
false)).connection.peerId
proc negotiateStream( proc negotiateStream(
self: Dialer, self: Dialer,
conn: Connection, conn: Connection,

View File

@ -1080,19 +1080,15 @@ proc matchPart(pat: MaPattern, protos: seq[MultiCodec]): MaPatResult =
proc match*(pat: MaPattern, address: MultiAddress): bool = proc match*(pat: MaPattern, address: MultiAddress): bool =
## Match full ``address`` using pattern ``pat`` and return ``true`` if ## Match full ``address`` using pattern ``pat`` and return ``true`` if
## ``address`` satisfies pattern. ## ``address`` satisfies pattern.
let protos = address.protocols() let protos = address.protocols().valueOr: return false
if protos.isErr(): let res = matchPart(pat, protos)
return false
let res = matchPart(pat, protos.get())
res.flag and (len(res.rem) == 0) res.flag and (len(res.rem) == 0)
proc matchPartial*(pat: MaPattern, address: MultiAddress): bool = proc matchPartial*(pat: MaPattern, address: MultiAddress): bool =
## Match prefix part of ``address`` using pattern ``pat`` and return ## Match prefix part of ``address`` using pattern ``pat`` and return
## ``true`` if ``address`` starts with pattern. ## ``true`` if ``address`` starts with pattern.
let protos = address.protocols() let protos = address.protocols().valueOr: return false
if protos.isErr(): let res = matchPart(pat, protos)
return false
let res = matchPart(pat, protos.get())
res.flag res.flag
proc `$`*(pat: MaPattern): string = proc `$`*(pat: MaPattern): string =
@ -1121,12 +1117,8 @@ proc getField*(pb: ProtoBuffer, field: int,
if not(res): if not(res):
ok(false) ok(false)
else: else:
let ma = MultiAddress.init(buffer) value = MultiAddress.init(buffer).valueOr: return err(ProtoError.IncorrectBlob)
if ma.isOk(): ok(true)
value = ma.get()
ok(true)
else:
err(ProtoError.IncorrectBlob)
proc getRepeatedField*(pb: ProtoBuffer, field: int, proc getRepeatedField*(pb: ProtoBuffer, field: int,
value: var seq[MultiAddress]): ProtoResult[bool] {. value: var seq[MultiAddress]): ProtoResult[bool] {.
@ -1142,11 +1134,11 @@ proc getRepeatedField*(pb: ProtoBuffer, field: int,
ok(false) ok(false)
else: else:
for item in items: for item in items:
let ma = MultiAddress.init(item) let ma = MultiAddress.init(item).valueOr:
if ma.isOk(): debug "Unsupported MultiAddress in blob", ma = item
value.add(ma.get()) continue
else:
debug "Not supported MultiAddress in blob", ma = item value.add(ma)
if value.len == 0: if value.len == 0:
err(ProtoError.IncorrectBlob) err(ProtoError.IncorrectBlob)
else: else:

View File

@ -118,7 +118,7 @@ proc resolveMAddress*(
if not DNS.matchPartial(address): if not DNS.matchPartial(address):
res.incl(address) res.incl(address)
else: else:
let code = address[0].get().protoCode().get() let code = address[0].tryGet().protoCode().tryGet()
let seq = case code: let seq = case code:
of multiCodec("dns"): of multiCodec("dns"):
await self.resolveOneAddress(address) await self.resolveOneAddress(address)
@ -129,7 +129,7 @@ proc resolveMAddress*(
of multiCodec("dnsaddr"): of multiCodec("dnsaddr"):
await self.resolveDnsAddr(address) await self.resolveDnsAddr(address)
else: else:
doAssert false assert false
@[address] @[address]
for ad in seq: for ad in seq:
res.incl(ad) res.incl(ad)

View File

@ -9,10 +9,9 @@
{.push raises: [].} {.push raises: [].}
import import std/[sequtils, tables, sugar]
std/[sequtils, tables], import chronos
chronos, chronicles, import multiaddress, multicodec
multiaddress, multicodec
type type
## Manages observed MultiAddresses by reomte peers. It keeps track of the most observed IP and IP/Port. ## Manages observed MultiAddresses by reomte peers. It keeps track of the most observed IP and IP/Port.
@ -33,14 +32,16 @@ proc getProtocol(self: ObservedAddrManager, observations: seq[MultiAddress], mul
countTable.sort() countTable.sort()
var orderedPairs = toSeq(countTable.pairs) var orderedPairs = toSeq(countTable.pairs)
for (ma, count) in orderedPairs: for (ma, count) in orderedPairs:
let maFirst = ma[0].get() let protoCode = (ma[0].flatMap(protoCode)).valueOr: continue
if maFirst.protoCode.get() == multiCodec and count >= self.minCount: if protoCode == multiCodec and count >= self.minCount:
return Opt.some(ma) return Opt.some(ma)
return Opt.none(MultiAddress) return Opt.none(MultiAddress)
proc getMostObservedProtocol(self: ObservedAddrManager, multiCodec: MultiCodec): Opt[MultiAddress] = proc getMostObservedProtocol(self: ObservedAddrManager, multiCodec: MultiCodec): Opt[MultiAddress] =
## Returns the most observed IP address or none if the number of observations are less than minCount. ## Returns the most observed IP address or none if the number of observations are less than minCount.
let observedIPs = self.observedIPsAndPorts.mapIt(it[0].get()) let observedIPs = collect:
for observedIp in self.observedIPsAndPorts:
observedIp[0].valueOr: continue
return self.getProtocol(observedIPs, multiCodec) return self.getProtocol(observedIPs, multiCodec)
proc getMostObservedProtoAndPort(self: ObservedAddrManager, multiCodec: MultiCodec): Opt[MultiAddress] = proc getMostObservedProtoAndPort(self: ObservedAddrManager, multiCodec: MultiCodec): Opt[MultiAddress] =
@ -51,34 +52,24 @@ proc getMostObservedProtosAndPorts*(self: ObservedAddrManager): seq[MultiAddress
## Returns the most observed IP4/Port and IP6/Port address or an empty seq if the number of observations ## Returns the most observed IP4/Port and IP6/Port address or an empty seq if the number of observations
## are less than minCount. ## are less than minCount.
var res: seq[MultiAddress] var res: seq[MultiAddress]
let ip4 = self.getMostObservedProtoAndPort(multiCodec("ip4")) self.getMostObservedProtoAndPort(multiCodec("ip4")).withValue(ip4):
if ip4.isSome(): res.add(ip4)
res.add(ip4.get()) self.getMostObservedProtoAndPort(multiCodec("ip6")).withValue(ip6):
let ip6 = self.getMostObservedProtoAndPort(multiCodec("ip6")) res.add(ip6)
if ip6.isSome():
res.add(ip6.get())
return res return res
proc guessDialableAddr*( proc guessDialableAddr*(
self: ObservedAddrManager, self: ObservedAddrManager,
ma: MultiAddress): MultiAddress = ma: MultiAddress): MultiAddress =
## Replaces the first proto valeu of each listen address by the corresponding (matching the proto code) most observed value. ## Replaces the first proto value of each listen address by the corresponding (matching the proto code) most observed value.
## If the most observed value is not available, the original MultiAddress is returned. ## If the most observed value is not available, the original MultiAddress is returned.
try: let
let maFirst = ma[0] maFirst = ma[0].valueOr: return ma
let maRest = ma[1..^1] maRest = ma[1..^1].valueOr: return ma
if maRest.isErr(): maFirstProto = maFirst.protoCode().valueOr: return ma
return ma
let observedIP = self.getMostObservedProtocol(maFirst.get().protoCode().get()) let observedIP = self.getMostObservedProtocol(maFirstProto).valueOr: return ma
return return concat(observedIP, maRest).valueOr: ma
if observedIP.isNone() or maFirst.get() == observedIP.get():
ma
else:
observedIP.get() & maRest.get()
except CatchableError as error:
debug "Error while handling manual port forwarding", msg = error.msg
return ma
proc `$`*(self: ObservedAddrManager): string = proc `$`*(self: ObservedAddrManager): string =
## Returns a string representation of the ObservedAddrManager. ## Returns a string representation of the ObservedAddrManager.

View File

@ -185,19 +185,11 @@ proc random*(t: typedesc[PeerId], rng = newRng()): Result[PeerId, cstring] =
func match*(pid: PeerId, pubkey: PublicKey): bool = func match*(pid: PeerId, pubkey: PublicKey): bool =
## Returns ``true`` if ``pid`` matches public key ``pubkey``. ## Returns ``true`` if ``pid`` matches public key ``pubkey``.
let p = PeerId.init(pubkey) PeerId.init(pubkey) == Result[PeerId, cstring].ok(pid)
if p.isErr:
false
else:
pid == p.get()
func match*(pid: PeerId, seckey: PrivateKey): bool = func match*(pid: PeerId, seckey: PrivateKey): bool =
## Returns ``true`` if ``pid`` matches private key ``seckey``. ## Returns ``true`` if ``pid`` matches private key ``seckey``.
let p = PeerId.init(seckey) PeerId.init(seckey) == Result[PeerId, cstring].ok(pid)
if p.isErr:
false
else:
pid == p.get()
## Serialization/Deserialization helpers ## Serialization/Deserialization helpers

View File

@ -10,7 +10,7 @@
{.push raises: [].} {.push raises: [].}
{.push public.} {.push public.}
import std/[options, sequtils] import std/sequtils
import pkg/[chronos, chronicles, stew/results] import pkg/[chronos, chronicles, stew/results]
import peerid, multiaddress, multicodec, crypto/crypto, routing_record, errors, utility import peerid, multiaddress, multicodec, crypto/crypto, routing_record, errors, utility
@ -53,15 +53,12 @@ proc update*(p: PeerInfo) {.async.} =
for mapper in p.addressMappers: for mapper in p.addressMappers:
p.addrs = await mapper(p.addrs) p.addrs = await mapper(p.addrs)
let sprRes = SignedPeerRecord.init( p.signedPeerRecord = SignedPeerRecord.init(
p.privateKey, p.privateKey,
PeerRecord.init(p.peerId, p.addrs) PeerRecord.init(p.peerId, p.addrs)
) ).valueOr():
if sprRes.isOk: info "Can't update the signed peer record"
p.signedPeerRecord = sprRes.get() return
else:
discard
#info "Can't update the signed peer record"
proc addrs*(p: PeerInfo): seq[MultiAddress] = proc addrs*(p: PeerInfo): seq[MultiAddress] =
p.addrs p.addrs

View File

@ -16,7 +16,7 @@ runnableExamples:
# Create a custom book type # Create a custom book type
type MoodBook = ref object of PeerBook[string] type MoodBook = ref object of PeerBook[string]
var somePeerId = PeerId.random().get() var somePeerId = PeerId.random().expect("get random key")
peerStore[MoodBook][somePeerId] = "Happy" peerStore[MoodBook][somePeerId] = "Happy"
doAssert peerStore[MoodBook][somePeerId] == "Happy" doAssert peerStore[MoodBook][somePeerId] == "Happy"
@ -158,20 +158,20 @@ proc updatePeerInfo*(
if info.addrs.len > 0: if info.addrs.len > 0:
peerStore[AddressBook][info.peerId] = info.addrs peerStore[AddressBook][info.peerId] = info.addrs
if info.pubkey.isSome: info.pubkey.withValue(pubkey):
peerStore[KeyBook][info.peerId] = info.pubkey.get() peerStore[KeyBook][info.peerId] = pubkey
if info.agentVersion.isSome: info.agentVersion.withValue(agentVersion):
peerStore[AgentBook][info.peerId] = info.agentVersion.get().string peerStore[AgentBook][info.peerId] = agentVersion.string
if info.protoVersion.isSome: info.protoVersion.withValue(protoVersion):
peerStore[ProtoVersionBook][info.peerId] = info.protoVersion.get().string peerStore[ProtoVersionBook][info.peerId] = protoVersion.string
if info.protos.len > 0: if info.protos.len > 0:
peerStore[ProtoBook][info.peerId] = info.protos peerStore[ProtoBook][info.peerId] = info.protos
if info.signedPeerRecord.isSome: info.signedPeerRecord.withValue(signedPeerRecord):
peerStore[SPRBook][info.peerId] = info.signedPeerRecord.get() peerStore[SPRBook][info.peerId] = signedPeerRecord
let cleanupPos = peerStore.toClean.find(info.peerId) let cleanupPos = peerStore.toClean.find(info.peerId)
if cleanupPos >= 0: if cleanupPos >= 0:
@ -207,11 +207,11 @@ proc identify*(
let info = await peerStore.identify.identify(stream, stream.peerId) let info = await peerStore.identify.identify(stream, stream.peerId)
when defined(libp2p_agents_metrics): when defined(libp2p_agents_metrics):
var knownAgent = "unknown" var
if info.agentVersion.isSome and info.agentVersion.get().len > 0: knownAgent = "unknown"
let shortAgent = info.agentVersion.get().split("/")[0].safeToLowerAscii() shortAgent = info.agentVersion.get("").split("/")[0].safeToLowerAscii().get("")
if shortAgent.isOk() and KnownLibP2PAgentsSeq.contains(shortAgent.get()): if KnownLibP2PAgentsSeq.contains(shortAgent):
knownAgent = shortAgent.get() knownAgent = shortAgent
muxer.connection.setShortAgent(knownAgent) muxer.connection.setShortAgent(knownAgent)
peerStore.updatePeerInfo(info) peerStore.updatePeerInfo(info)

View File

@ -576,26 +576,18 @@ proc getField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
proc getField*(pb: ProtoBuffer, field: int, proc getField*(pb: ProtoBuffer, field: int,
output: var ProtoBuffer): ProtoResult[bool] {.inline.} = output: var ProtoBuffer): ProtoResult[bool] {.inline.} =
var buffer: seq[byte] var buffer: seq[byte]
let res = pb.getField(field, buffer) if ? pb.getField(field, buffer):
if res.isOk(): output = initProtoBuffer(buffer)
if res.get(): ok(true)
output = initProtoBuffer(buffer)
ok(true)
else:
ok(false)
else: else:
err(res.error) ok(false)
proc getRequiredField*[T](pb: ProtoBuffer, field: int, proc getRequiredField*[T](pb: ProtoBuffer, field: int,
output: var T): ProtoResult[void] {.inline.} = output: var T): ProtoResult[void] {.inline.} =
let res = pb.getField(field, output) if ? pb.getField(field, output):
if res.isOk(): ok()
if res.get():
ok()
else:
err(RequiredFieldMissing)
else: else:
err(res.error) err(RequiredFieldMissing)
proc getRepeatedField*[T: seq[byte]|string](data: ProtoBuffer, field: int, proc getRepeatedField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
output: var seq[T]): ProtoResult[bool] = output: var seq[T]): ProtoResult[bool] =
@ -675,14 +667,10 @@ proc getRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
proc getRequiredRepeatedField*[T](pb: ProtoBuffer, field: int, proc getRequiredRepeatedField*[T](pb: ProtoBuffer, field: int,
output: var seq[T]): ProtoResult[void] {.inline.} = output: var seq[T]): ProtoResult[void] {.inline.} =
let res = pb.getRepeatedField(field, output) if ? pb.getRepeatedField(field, output):
if res.isOk(): ok()
if res.get():
ok()
else:
err(RequiredFieldMissing)
else: else:
err(res.error) err(RequiredFieldMissing)
proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int, proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
output: var seq[T]): ProtoResult[bool] = output: var seq[T]): ProtoResult[bool] =

View File

@ -9,7 +9,6 @@
{.push raises: [].} {.push raises: [].}
import std/options
import stew/results import stew/results
import chronos, chronicles import chronos, chronicles
import ../../../switch, import ../../../switch,
@ -24,8 +23,8 @@ type
AutonatClient* = ref object of RootObj AutonatClient* = ref object of RootObj
proc sendDial(conn: Connection, pid: PeerId, addrs: seq[MultiAddress]) {.async.} = proc sendDial(conn: Connection, pid: PeerId, addrs: seq[MultiAddress]) {.async.} =
let pb = AutonatDial(peerInfo: some(AutonatPeerInfo( let pb = AutonatDial(peerInfo: Opt.some(AutonatPeerInfo(
id: some(pid), id: Opt.some(pid),
addrs: addrs addrs: addrs
))).encode() ))).encode()
await conn.writeLp(pb.buffer) await conn.writeLp(pb.buffer)
@ -33,15 +32,13 @@ proc sendDial(conn: Connection, pid: PeerId, addrs: seq[MultiAddress]) {.async.}
method dialMe*(self: AutonatClient, switch: Switch, pid: PeerId, addrs: seq[MultiAddress] = newSeq[MultiAddress]()): method dialMe*(self: AutonatClient, switch: Switch, pid: PeerId, addrs: seq[MultiAddress] = newSeq[MultiAddress]()):
Future[MultiAddress] {.base, async.} = Future[MultiAddress] {.base, async.} =
proc getResponseOrRaise(autonatMsg: Option[AutonatMsg]): AutonatDialResponse {.raises: [AutonatError].} = proc getResponseOrRaise(autonatMsg: Opt[AutonatMsg]): AutonatDialResponse {.raises: [AutonatError].} =
if autonatMsg.isNone() or autonatMsg.withValue(msg):
autonatMsg.get().msgType != DialResponse or if msg.msgType == DialResponse:
autonatMsg.get().response.isNone() or msg.response.withValue(res):
(autonatMsg.get().response.get().status == Ok and if not (res.status == Ok and res.ma.isNone()):
autonatMsg.get().response.get().ma.isNone()): return res
raise newException(AutonatError, "Unexpected response") raise newException(AutonatError, "Unexpected response")
else:
autonatMsg.get().response.get()
let conn = let conn =
try: try:
@ -66,7 +63,7 @@ method dialMe*(self: AutonatClient, switch: Switch, pid: PeerId, addrs: seq[Mult
let response = getResponseOrRaise(AutonatMsg.decode(await conn.readLp(1024))) let response = getResponseOrRaise(AutonatMsg.decode(await conn.readLp(1024)))
return case response.status: return case response.status:
of ResponseStatus.Ok: of ResponseStatus.Ok:
response.ma.get() response.ma.tryGet()
of ResponseStatus.DialError: of ResponseStatus.DialError:
raise newException(AutonatUnreachableError, "Peer could not dial us back: " & response.text.get("")) raise newException(AutonatUnreachableError, "Peer could not dial us back: " & response.text.get(""))
else: else:

View File

@ -9,7 +9,6 @@
{.push raises: [].} {.push raises: [].}
import std/[options]
import stew/[results, objects] import stew/[results, objects]
import chronos, chronicles import chronos, chronicles
import ../../../multiaddress, import ../../../multiaddress,
@ -39,29 +38,29 @@ type
InternalError = 300 InternalError = 300
AutonatPeerInfo* = object AutonatPeerInfo* = object
id*: Option[PeerId] id*: Opt[PeerId]
addrs*: seq[MultiAddress] addrs*: seq[MultiAddress]
AutonatDial* = object AutonatDial* = object
peerInfo*: Option[AutonatPeerInfo] peerInfo*: Opt[AutonatPeerInfo]
AutonatDialResponse* = object AutonatDialResponse* = object
status*: ResponseStatus status*: ResponseStatus
text*: Option[string] text*: Opt[string]
ma*: Option[MultiAddress] ma*: Opt[MultiAddress]
AutonatMsg* = object AutonatMsg* = object
msgType*: MsgType msgType*: MsgType
dial*: Option[AutonatDial] dial*: Opt[AutonatDial]
response*: Option[AutonatDialResponse] response*: Opt[AutonatDialResponse]
NetworkReachability* {.pure.} = enum NetworkReachability* {.pure.} = enum
Unknown, NotReachable, Reachable Unknown, NotReachable, Reachable
proc encode(p: AutonatPeerInfo): ProtoBuffer = proc encode(p: AutonatPeerInfo): ProtoBuffer =
result = initProtoBuffer() result = initProtoBuffer()
if p.id.isSome(): p.id.withValue(id):
result.write(1, p.id.get()) result.write(1, id)
for ma in p.addrs: for ma in p.addrs:
result.write(2, ma.data.buffer) result.write(2, ma.data.buffer)
result.finish() result.finish()
@ -70,8 +69,8 @@ proc encode*(d: AutonatDial): ProtoBuffer =
result = initProtoBuffer() result = initProtoBuffer()
result.write(1, MsgType.Dial.uint) result.write(1, MsgType.Dial.uint)
var dial = initProtoBuffer() var dial = initProtoBuffer()
if d.peerInfo.isSome(): d.peerInfo.withValue(pinfo):
dial.write(1, encode(d.peerInfo.get())) dial.write(1, encode(pinfo))
dial.finish() dial.finish()
result.write(2, dial.buffer) result.write(2, dial.buffer)
result.finish() result.finish()
@ -81,72 +80,60 @@ proc encode*(r: AutonatDialResponse): ProtoBuffer =
result.write(1, MsgType.DialResponse.uint) result.write(1, MsgType.DialResponse.uint)
var bufferResponse = initProtoBuffer() var bufferResponse = initProtoBuffer()
bufferResponse.write(1, r.status.uint) bufferResponse.write(1, r.status.uint)
if r.text.isSome(): r.text.withValue(text):
bufferResponse.write(2, r.text.get()) bufferResponse.write(2, text)
if r.ma.isSome(): r.ma.withValue(ma):
bufferResponse.write(3, r.ma.get()) bufferResponse.write(3, ma)
bufferResponse.finish() bufferResponse.finish()
result.write(3, bufferResponse.buffer) result.write(3, bufferResponse.buffer)
result.finish() result.finish()
proc encode*(msg: AutonatMsg): ProtoBuffer = proc encode*(msg: AutonatMsg): ProtoBuffer =
if msg.dial.isSome(): msg.dial.withValue(dial):
return encode(msg.dial.get()) return encode(dial)
if msg.response.isSome(): msg.response.withValue(res):
return encode(msg.response.get()) return encode(res)
proc decode*(_: typedesc[AutonatMsg], buf: seq[byte]): Option[AutonatMsg] = proc decode*(_: typedesc[AutonatMsg], buf: seq[byte]): Opt[AutonatMsg] =
var var
msgTypeOrd: uint32 msgTypeOrd: uint32
pbDial: ProtoBuffer pbDial: ProtoBuffer
pbResponse: ProtoBuffer pbResponse: ProtoBuffer
msg: AutonatMsg msg: AutonatMsg
let let pb = initProtoBuffer(buf)
pb = initProtoBuffer(buf)
r1 = pb.getField(1, msgTypeOrd)
r2 = pb.getField(2, pbDial)
r3 = pb.getField(3, pbResponse)
if r1.isErr() or r2.isErr() or r3.isErr(): return none(AutonatMsg)
if r1.get() and not checkedEnumAssign(msg.msgType, msgTypeOrd): if ? pb.getField(1, msgTypeOrd).toOpt() and not checkedEnumAssign(msg.msgType, msgTypeOrd):
return none(AutonatMsg) return Opt.none(AutonatMsg)
if r2.get(): if ? pb.getField(2, pbDial).toOpt():
var var
pbPeerInfo: ProtoBuffer pbPeerInfo: ProtoBuffer
dial: AutonatDial dial: AutonatDial
let let r4 = ? pbDial.getField(1, pbPeerInfo).toOpt()
r4 = pbDial.getField(1, pbPeerInfo)
if r4.isErr(): return none(AutonatMsg)
var peerInfo: AutonatPeerInfo var peerInfo: AutonatPeerInfo
if r4.get(): if r4:
var pid: PeerId var pid: PeerId
let let
r5 = pbPeerInfo.getField(1, pid) r5 = ? pbPeerInfo.getField(1, pid).toOpt()
r6 = pbPeerInfo.getRepeatedField(2, peerInfo.addrs) r6 = ? pbPeerInfo.getRepeatedField(2, peerInfo.addrs).toOpt()
if r5.isErr() or r6.isErr(): return none(AutonatMsg) if r5: peerInfo.id = Opt.some(pid)
if r5.get(): peerInfo.id = some(pid) dial.peerInfo = Opt.some(peerInfo)
dial.peerInfo = some(peerInfo) msg.dial = Opt.some(dial)
msg.dial = some(dial)
if r3.get(): if ? pb.getField(3, pbResponse).toOpt():
var var
statusOrd: uint statusOrd: uint
text: string text: string
ma: MultiAddress ma: MultiAddress
response: AutonatDialResponse response: AutonatDialResponse
let if ? pbResponse.getField(1, statusOrd).optValue():
r4 = pbResponse.getField(1, statusOrd) if not checkedEnumAssign(response.status, statusOrd):
r5 = pbResponse.getField(2, text) return Opt.none(AutonatMsg)
r6 = pbResponse.getField(3, ma) if ? pbResponse.getField(2, text).optValue():
response.text = Opt.some(text)
if r4.isErr() or r5.isErr() or r6.isErr() or if ? pbResponse.getField(3, ma).optValue():
(r4.get() and not checkedEnumAssign(response.status, statusOrd)): response.ma = Opt.some(ma)
return none(AutonatMsg) msg.response = Opt.some(response)
if r5.get(): response.text = some(text) return Opt.some(msg)
if r6.get(): response.ma = some(ma)
msg.response = some(response)
return some(msg)

View File

@ -9,7 +9,7 @@
{.push raises: [].} {.push raises: [].}
import std/[options, sets, sequtils] import std/[sets, sequtils]
import stew/results import stew/results
import chronos, chronicles import chronos, chronicles
import ../../protocol, import ../../protocol,
@ -33,8 +33,8 @@ type
dialTimeout: Duration dialTimeout: Duration
proc sendDial(conn: Connection, pid: PeerId, addrs: seq[MultiAddress]) {.async.} = proc sendDial(conn: Connection, pid: PeerId, addrs: seq[MultiAddress]) {.async.} =
let pb = AutonatDial(peerInfo: some(AutonatPeerInfo( let pb = AutonatDial(peerInfo: Opt.some(AutonatPeerInfo(
id: some(pid), id: Opt.some(pid),
addrs: addrs addrs: addrs
))).encode() ))).encode()
await conn.writeLp(pb.buffer) await conn.writeLp(pb.buffer)
@ -42,16 +42,16 @@ proc sendDial(conn: Connection, pid: PeerId, addrs: seq[MultiAddress]) {.async.}
proc sendResponseError(conn: Connection, status: ResponseStatus, text: string = "") {.async.} = proc sendResponseError(conn: Connection, status: ResponseStatus, text: string = "") {.async.} =
let pb = AutonatDialResponse( let pb = AutonatDialResponse(
status: status, status: status,
text: if text == "": none(string) else: some(text), text: if text == "": Opt.none(string) else: Opt.some(text),
ma: none(MultiAddress) ma: Opt.none(MultiAddress)
).encode() ).encode()
await conn.writeLp(pb.buffer) await conn.writeLp(pb.buffer)
proc sendResponseOk(conn: Connection, ma: MultiAddress) {.async.} = proc sendResponseOk(conn: Connection, ma: MultiAddress) {.async.} =
let pb = AutonatDialResponse( let pb = AutonatDialResponse(
status: ResponseStatus.Ok, status: ResponseStatus.Ok,
text: some("Ok"), text: Opt.some("Ok"),
ma: some(ma) ma: Opt.some(ma)
).encode() ).encode()
await conn.writeLp(pb.buffer) await conn.writeLp(pb.buffer)
@ -70,8 +70,8 @@ proc tryDial(autonat: Autonat, conn: Connection, addrs: seq[MultiAddress]) {.asy
futs = addrs.mapIt(autonat.switch.dialer.tryDial(conn.peerId, @[it])) futs = addrs.mapIt(autonat.switch.dialer.tryDial(conn.peerId, @[it]))
let fut = await anyCompleted(futs).wait(autonat.dialTimeout) let fut = await anyCompleted(futs).wait(autonat.dialTimeout)
let ma = await fut let ma = await fut
if ma.isSome: ma.withValue(maddr):
await conn.sendResponseOk(ma.get()) await conn.sendResponseOk(maddr)
else: else:
await conn.sendResponseError(DialError, "Missing observed address") await conn.sendResponseError(DialError, "Missing observed address")
except CancelledError as exc: except CancelledError as exc:
@ -92,42 +92,40 @@ proc tryDial(autonat: Autonat, conn: Connection, addrs: seq[MultiAddress]) {.asy
f.cancel() f.cancel()
proc handleDial(autonat: Autonat, conn: Connection, msg: AutonatMsg): Future[void] = proc handleDial(autonat: Autonat, conn: Connection, msg: AutonatMsg): Future[void] =
if msg.dial.isNone() or msg.dial.get().peerInfo.isNone(): let dial = msg.dial.valueOr:
return conn.sendResponseError(BadRequest, "Missing Dial")
let peerInfo = dial.peerInfo.valueOr:
return conn.sendResponseError(BadRequest, "Missing Peer Info") return conn.sendResponseError(BadRequest, "Missing Peer Info")
let peerInfo = msg.dial.get().peerInfo.get() peerInfo.id.withValue(id):
if peerInfo.id.isSome() and peerInfo.id.get() != conn.peerId: if id != conn.peerId:
return conn.sendResponseError(BadRequest, "PeerId mismatch") return conn.sendResponseError(BadRequest, "PeerId mismatch")
if conn.observedAddr.isNone: let observedAddr = conn.observedAddr.valueOr:
return conn.sendResponseError(BadRequest, "Missing observed address") return conn.sendResponseError(BadRequest, "Missing observed address")
let observedAddr = conn.observedAddr.get()
var isRelayed = observedAddr.contains(multiCodec("p2p-circuit")) var isRelayed = observedAddr.contains(multiCodec("p2p-circuit")).valueOr:
if isRelayed.isErr() or isRelayed.get(): return conn.sendResponseError(DialRefused, "Invalid observed address")
if isRelayed:
return conn.sendResponseError(DialRefused, "Refused to dial a relayed observed address") return conn.sendResponseError(DialRefused, "Refused to dial a relayed observed address")
let hostIp = observedAddr[0] let hostIp = observedAddr[0].valueOr:
if hostIp.isErr() or not IP.match(hostIp.get()): return conn.sendResponseError(InternalError, "Wrong observed address")
trace "wrong observed address", address=observedAddr if not IP.match(hostIp):
return conn.sendResponseError(InternalError, "Expected an IP address") return conn.sendResponseError(InternalError, "Expected an IP address")
var addrs = initHashSet[MultiAddress]() var addrs = initHashSet[MultiAddress]()
addrs.incl(observedAddr) addrs.incl(observedAddr)
trace "addrs received", addrs = peerInfo.addrs trace "addrs received", addrs = peerInfo.addrs
for ma in peerInfo.addrs: for ma in peerInfo.addrs:
isRelayed = ma.contains(multiCodec("p2p-circuit")) isRelayed = ma.contains(multiCodec("p2p-circuit")).valueOr: continue
if isRelayed.isErr() or isRelayed.get(): let maFirst = ma[0].valueOr: continue
continue if not DNS_OR_IP.match(maFirst): continue
let maFirst = ma[0]
if maFirst.isErr() or not DNS_OR_IP.match(maFirst.get()):
continue
try: try:
addrs.incl( addrs.incl(
if maFirst.get() == hostIp.get(): if maFirst == hostIp:
ma ma
else: else:
let maEnd = ma[1..^1] let maEnd = ma[1..^1].valueOr: continue
if maEnd.isErr(): continue hostIp & maEnd
hostIp.get() & maEnd.get()
) )
except LPError as exc: except LPError as exc:
continue continue
@ -144,10 +142,10 @@ proc new*(T: typedesc[Autonat], switch: Switch, semSize: int = 1, dialTimeout =
let autonat = T(switch: switch, sem: newAsyncSemaphore(semSize), dialTimeout: dialTimeout) let autonat = T(switch: switch, sem: newAsyncSemaphore(semSize), dialTimeout: dialTimeout)
proc handleStream(conn: Connection, proto: string) {.async, gcsafe.} = proc handleStream(conn: Connection, proto: string) {.async, gcsafe.} =
try: try:
let msgOpt = AutonatMsg.decode(await conn.readLp(1024)) let msg = AutonatMsg.decode(await conn.readLp(1024)).valueOr:
if msgOpt.isNone() or msgOpt.get().msgType != MsgType.Dial:
raise newException(AutonatError, "Received malformed message") raise newException(AutonatError, "Received malformed message")
let msg = msgOpt.get() if msg.msgType != MsgType.Dial:
raise newException(AutonatError, "Message type should be dial")
await autonat.handleDial(conn, msg) await autonat.handleDial(conn, msg)
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc

View File

@ -9,7 +9,7 @@
{.push raises: [].} {.push raises: [].}
import std/[options, deques, sequtils] import std/[deques, sequtils]
import chronos, metrics import chronos, metrics
import ../../../switch import ../../../switch
import ../../../wire import ../../../wire
@ -18,7 +18,7 @@ from core import NetworkReachability, AutonatUnreachableError
import ../../../utils/heartbeat import ../../../utils/heartbeat
import ../../../crypto/crypto import ../../../crypto/crypto
export options, core.NetworkReachability export core.NetworkReachability
logScope: logScope:
topics = "libp2p autonatservice" topics = "libp2p autonatservice"
@ -31,12 +31,12 @@ type
addressMapper: AddressMapper addressMapper: AddressMapper
scheduleHandle: Future[void] scheduleHandle: Future[void]
networkReachability*: NetworkReachability networkReachability*: NetworkReachability
confidence: Option[float] confidence: Opt[float]
answers: Deque[NetworkReachability] answers: Deque[NetworkReachability]
autonatClient: AutonatClient autonatClient: AutonatClient
statusAndConfidenceHandler: StatusAndConfidenceHandler statusAndConfidenceHandler: StatusAndConfidenceHandler
rng: ref HmacDrbgContext rng: ref HmacDrbgContext
scheduleInterval: Option[Duration] scheduleInterval: Opt[Duration]
askNewConnectedPeers: bool askNewConnectedPeers: bool
numPeersToAsk: int numPeersToAsk: int
maxQueueSize: int maxQueueSize: int
@ -44,13 +44,13 @@ type
dialTimeout: Duration dialTimeout: Duration
enableAddressMapper: bool enableAddressMapper: bool
StatusAndConfidenceHandler* = proc (networkReachability: NetworkReachability, confidence: Option[float]): Future[void] {.gcsafe, raises: [].} StatusAndConfidenceHandler* = proc (networkReachability: NetworkReachability, confidence: Opt[float]): Future[void] {.gcsafe, raises: [].}
proc new*( proc new*(
T: typedesc[AutonatService], T: typedesc[AutonatService],
autonatClient: AutonatClient, autonatClient: AutonatClient,
rng: ref HmacDrbgContext, rng: ref HmacDrbgContext,
scheduleInterval: Option[Duration] = none(Duration), scheduleInterval: Opt[Duration] = Opt.none(Duration),
askNewConnectedPeers = true, askNewConnectedPeers = true,
numPeersToAsk: int = 5, numPeersToAsk: int = 5,
maxQueueSize: int = 10, maxQueueSize: int = 10,
@ -60,7 +60,7 @@ proc new*(
return T( return T(
scheduleInterval: scheduleInterval, scheduleInterval: scheduleInterval,
networkReachability: Unknown, networkReachability: Unknown,
confidence: none(float), confidence: Opt.none(float),
answers: initDeque[NetworkReachability](), answers: initDeque[NetworkReachability](),
autonatClient: autonatClient, autonatClient: autonatClient,
rng: rng, rng: rng,
@ -95,14 +95,14 @@ proc handleAnswer(self: AutonatService, ans: NetworkReachability): Future[bool]
self.answers.addLast(ans) self.answers.addLast(ans)
self.networkReachability = Unknown self.networkReachability = Unknown
self.confidence = none(float) self.confidence = Opt.none(float)
const reachabilityPriority = [Reachable, NotReachable] const reachabilityPriority = [Reachable, NotReachable]
for reachability in reachabilityPriority: for reachability in reachabilityPriority:
let confidence = self.answers.countIt(it == reachability) / self.maxQueueSize let confidence = self.answers.countIt(it == reachability) / self.maxQueueSize
libp2p_autonat_reachability_confidence.set(value = confidence, labelValues = [$reachability]) libp2p_autonat_reachability_confidence.set(value = confidence, labelValues = [$reachability])
if self.confidence.isNone and confidence >= self.minConfidence: if self.confidence.isNone and confidence >= self.minConfidence:
self.networkReachability = reachability self.networkReachability = reachability
self.confidence = some(confidence) self.confidence = Opt.some(confidence)
debug "Current status", currentStats = $self.networkReachability, confidence = $self.confidence, answers = self.answers debug "Current status", currentStats = $self.networkReachability, confidence = $self.confidence, answers = self.answers
@ -189,8 +189,8 @@ method setup*(self: AutonatService, switch: Switch): Future[bool] {.async.} =
self.newConnectedPeerHandler = proc (peerId: PeerId, event: PeerEvent): Future[void] {.async.} = self.newConnectedPeerHandler = proc (peerId: PeerId, event: PeerEvent): Future[void] {.async.} =
discard askPeer(self, switch, peerId) discard askPeer(self, switch, peerId)
switch.connManager.addPeerEventHandler(self.newConnectedPeerHandler, PeerEventKind.Joined) switch.connManager.addPeerEventHandler(self.newConnectedPeerHandler, PeerEventKind.Joined)
if self.scheduleInterval.isSome(): self.scheduleInterval.withValue(interval):
self.scheduleHandle = schedule(self, switch, self.scheduleInterval.get()) self.scheduleHandle = schedule(self, switch, interval)
if self.enableAddressMapper: if self.enableAddressMapper:
switch.peerInfo.addressMappers.add(self.addressMapper) switch.peerInfo.addressMappers.add(self.addressMapper)
return hasBeenSetup return hasBeenSetup

View File

@ -9,8 +9,7 @@
{.push raises: [].} {.push raises: [].}
import std/[options, sets, sequtils] import std/[sets, sequtils]
import stew/[results, objects] import stew/[results, objects]
import chronos, chronicles import chronos, chronicles

View File

@ -9,10 +9,8 @@
{.push raises: [].} {.push raises: [].}
import times, options import times
import chronos, chronicles import chronos, chronicles
import ./relay, import ./relay,
./messages, ./messages,
./rconn, ./rconn,
@ -22,8 +20,6 @@ import ./relay,
../../../multiaddress, ../../../multiaddress,
../../../stream/connection ../../../stream/connection
export options
logScope: logScope:
topics = "libp2p relay relay-client" topics = "libp2p relay relay-client"
@ -44,28 +40,27 @@ type
Rsvp* = object Rsvp* = object
expire*: uint64 # required, Unix expiration time (UTC) expire*: uint64 # required, Unix expiration time (UTC)
addrs*: seq[MultiAddress] # relay address for reserving peer addrs*: seq[MultiAddress] # relay address for reserving peer
voucher*: Option[Voucher] # optional, reservation voucher voucher*: Opt[Voucher] # optional, reservation voucher
limitDuration*: uint32 # seconds limitDuration*: uint32 # seconds
limitData*: uint64 # bytes limitData*: uint64 # bytes
proc sendStopError(conn: Connection, code: StatusV2) {.async.} = proc sendStopError(conn: Connection, code: StatusV2) {.async.} =
trace "send stop status", status = $code & " (" & $ord(code) & ")" trace "send stop status", status = $code & " (" & $ord(code) & ")"
let msg = StopMessage(msgType: StopMessageType.Status, status: some(code)) let msg = StopMessage(msgType: StopMessageType.Status, status: Opt.some(code))
await conn.writeLp(encode(msg).buffer) await conn.writeLp(encode(msg).buffer)
proc handleRelayedConnect(cl: RelayClient, conn: Connection, msg: StopMessage) {.async.} = proc handleRelayedConnect(cl: RelayClient, conn: Connection, msg: StopMessage) {.async.} =
if msg.peer.isNone():
await sendStopError(conn, MalformedMessage)
return
let let
# TODO: check the go version to see in which way this could fail # TODO: check the go version to see in which way this could fail
# it's unclear in the spec # it's unclear in the spec
src = msg.peer.get() src = msg.peer.valueOr:
await sendStopError(conn, MalformedMessage)
return
limitDuration = msg.limit.duration limitDuration = msg.limit.duration
limitData = msg.limit.data limitData = msg.limit.data
msg = StopMessage( msg = StopMessage(
msgType: StopMessageType.Status, msgType: StopMessageType.Status,
status: some(Ok)) status: Opt.some(Ok))
pb = encode(msg) pb = encode(msg)
trace "incoming relay connection", src trace "incoming relay connection", src
@ -89,7 +84,7 @@ proc reserve*(cl: RelayClient,
pb = encode(HopMessage(msgType: HopMessageType.Reserve)) pb = encode(HopMessage(msgType: HopMessageType.Reserve))
msg = try: msg = try:
await conn.writeLp(pb.buffer) await conn.writeLp(pb.buffer)
HopMessage.decode(await conn.readLp(RelayClientMsgSize)).get() HopMessage.decode(await conn.readLp(RelayClientMsgSize)).tryGet()
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except CatchableError as exc: except CatchableError as exc:
@ -100,21 +95,21 @@ proc reserve*(cl: RelayClient,
raise newException(ReservationError, "Unexpected relay response type") raise newException(ReservationError, "Unexpected relay response type")
if msg.status.get(UnexpectedMessage) != Ok: if msg.status.get(UnexpectedMessage) != Ok:
raise newException(ReservationError, "Reservation failed") raise newException(ReservationError, "Reservation failed")
if msg.reservation.isNone():
raise newException(ReservationError, "Missing reservation information")
let reservation = msg.reservation.get() let reservation = msg.reservation.valueOr:
raise newException(ReservationError, "Missing reservation information")
if reservation.expire > int64.high().uint64 or if reservation.expire > int64.high().uint64 or
now().utc > reservation.expire.int64.fromUnix.utc: now().utc > reservation.expire.int64.fromUnix.utc:
raise newException(ReservationError, "Bad expiration date") raise newException(ReservationError, "Bad expiration date")
result.expire = reservation.expire result.expire = reservation.expire
result.addrs = reservation.addrs result.addrs = reservation.addrs
if reservation.svoucher.isSome(): reservation.svoucher.withValue(sv):
let svoucher = SignedVoucher.decode(reservation.svoucher.get()) let svoucher = SignedVoucher.decode(sv).valueOr:
if svoucher.isErr() or svoucher.get().data.relayPeerId != peerId:
raise newException(ReservationError, "Invalid voucher") raise newException(ReservationError, "Invalid voucher")
result.voucher = some(svoucher.get().data) if svoucher.data.relayPeerId != peerId:
raise newException(ReservationError, "Invalid voucher PeerId")
result.voucher = Opt.some(svoucher.data)
result.limitDuration = msg.limit.duration result.limitDuration = msg.limit.duration
result.limitData = msg.limit.data result.limitData = msg.limit.data
@ -126,9 +121,9 @@ proc dialPeerV1*(
dstAddrs: seq[MultiAddress]): Future[Connection] {.async.} = dstAddrs: seq[MultiAddress]): Future[Connection] {.async.} =
var var
msg = RelayMessage( msg = RelayMessage(
msgType: some(RelayType.Hop), msgType: Opt.some(RelayType.Hop),
srcPeer: some(RelayPeer(peerId: cl.switch.peerInfo.peerId, addrs: cl.switch.peerInfo.addrs)), srcPeer: Opt.some(RelayPeer(peerId: cl.switch.peerInfo.peerId, addrs: cl.switch.peerInfo.addrs)),
dstPeer: some(RelayPeer(peerId: dstPeerId, addrs: dstAddrs))) dstPeer: Opt.some(RelayPeer(peerId: dstPeerId, addrs: dstAddrs)))
pb = encode(msg) pb = encode(msg)
trace "Dial peer", msgSend=msg trace "Dial peer", msgSend=msg
@ -151,16 +146,18 @@ proc dialPeerV1*(
raise exc raise exc
try: try:
if msgRcvFromRelayOpt.isNone: let msgRcvFromRelay = msgRcvFromRelayOpt.valueOr:
raise newException(RelayV1DialError, "Hop can't open destination stream") raise newException(RelayV1DialError, "Hop can't open destination stream")
let msgRcvFromRelay = msgRcvFromRelayOpt.get() if msgRcvFromRelay.msgType.tryGet() != RelayType.Status:
if msgRcvFromRelay.msgType.isNone or msgRcvFromRelay.msgType.get() != RelayType.Status:
raise newException(RelayV1DialError, "Hop can't open destination stream: wrong message type") raise newException(RelayV1DialError, "Hop can't open destination stream: wrong message type")
if msgRcvFromRelay.status.isNone or msgRcvFromRelay.status.get() != StatusV1.Success: if msgRcvFromRelay.status.tryGet() != StatusV1.Success:
raise newException(RelayV1DialError, "Hop can't open destination stream: status failed") raise newException(RelayV1DialError, "Hop can't open destination stream: status failed")
except RelayV1DialError as exc: except RelayV1DialError as exc:
await sendStatus(conn, StatusV1.HopCantOpenDstStream) await sendStatus(conn, StatusV1.HopCantOpenDstStream)
raise exc raise exc
except ValueError as exc:
await sendStatus(conn, StatusV1.HopCantOpenDstStream)
raise newException(RelayV1DialError, exc.msg)
result = conn result = conn
proc dialPeerV2*( proc dialPeerV2*(
@ -170,13 +167,13 @@ proc dialPeerV2*(
dstAddrs: seq[MultiAddress]): Future[Connection] {.async.} = dstAddrs: seq[MultiAddress]): Future[Connection] {.async.} =
let let
p = Peer(peerId: dstPeerId, addrs: dstAddrs) p = Peer(peerId: dstPeerId, addrs: dstAddrs)
pb = encode(HopMessage(msgType: HopMessageType.Connect, peer: some(p))) pb = encode(HopMessage(msgType: HopMessageType.Connect, peer: Opt.some(p)))
trace "Dial peer", p trace "Dial peer", p
let msgRcvFromRelay = try: let msgRcvFromRelay = try:
await conn.writeLp(pb.buffer) await conn.writeLp(pb.buffer)
HopMessage.decode(await conn.readLp(RelayClientMsgSize)).get() HopMessage.decode(await conn.readLp(RelayClientMsgSize)).tryGet()
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except CatchableError as exc: except CatchableError as exc:
@ -186,19 +183,17 @@ proc dialPeerV2*(
if msgRcvFromRelay.msgType != HopMessageType.Status: if msgRcvFromRelay.msgType != HopMessageType.Status:
raise newException(RelayV2DialError, "Unexpected stop response") raise newException(RelayV2DialError, "Unexpected stop response")
if msgRcvFromRelay.status.get(UnexpectedMessage) != Ok: if msgRcvFromRelay.status.get(UnexpectedMessage) != Ok:
trace "Relay stop failed", msg = msgRcvFromRelay.status.get() trace "Relay stop failed", msg = msgRcvFromRelay.status
raise newException(RelayV2DialError, "Relay stop failure") raise newException(RelayV2DialError, "Relay stop failure")
conn.limitDuration = msgRcvFromRelay.limit.duration conn.limitDuration = msgRcvFromRelay.limit.duration
conn.limitData = msgRcvFromRelay.limit.data conn.limitData = msgRcvFromRelay.limit.data
return conn return conn
proc handleStopStreamV2(cl: RelayClient, conn: Connection) {.async, gcsafe.} = proc handleStopStreamV2(cl: RelayClient, conn: Connection) {.async, gcsafe.} =
let msgOpt = StopMessage.decode(await conn.readLp(RelayClientMsgSize)) let msg = StopMessage.decode(await conn.readLp(RelayClientMsgSize)).valueOr:
if msgOpt.isNone():
await sendHopStatus(conn, MalformedMessage) await sendHopStatus(conn, MalformedMessage)
return return
trace "client circuit relay v2 handle stream", msg = msgOpt.get() trace "client circuit relay v2 handle stream", msg
let msg = msgOpt.get()
if msg.msgType == StopMessageType.Connect: if msg.msgType == StopMessageType.Connect:
await cl.handleRelayedConnect(conn, msg) await cl.handleRelayedConnect(conn, msg)
@ -207,16 +202,14 @@ proc handleStopStreamV2(cl: RelayClient, conn: Connection) {.async, gcsafe.} =
await sendStopError(conn, MalformedMessage) await sendStopError(conn, MalformedMessage)
proc handleStop(cl: RelayClient, conn: Connection, msg: RelayMessage) {.async, gcsafe.} = proc handleStop(cl: RelayClient, conn: Connection, msg: RelayMessage) {.async, gcsafe.} =
if msg.srcPeer.isNone: let src = msg.srcPeer.valueOr:
await sendStatus(conn, StatusV1.StopSrcMultiaddrInvalid) await sendStatus(conn, StatusV1.StopSrcMultiaddrInvalid)
return return
let src = msg.srcPeer.get()
if msg.dstPeer.isNone: let dst = msg.dstPeer.valueOr:
await sendStatus(conn, StatusV1.StopDstMultiaddrInvalid) await sendStatus(conn, StatusV1.StopDstMultiaddrInvalid)
return return
let dst = msg.dstPeer.get()
if dst.peerId != cl.switch.peerInfo.peerId: if dst.peerId != cl.switch.peerInfo.peerId:
await sendStatus(conn, StatusV1.StopDstMultiaddrInvalid) await sendStatus(conn, StatusV1.StopDstMultiaddrInvalid)
return return
@ -234,13 +227,16 @@ proc handleStop(cl: RelayClient, conn: Connection, msg: RelayMessage) {.async, g
else: await conn.close() else: await conn.close()
proc handleStreamV1(cl: RelayClient, conn: Connection) {.async, gcsafe.} = proc handleStreamV1(cl: RelayClient, conn: Connection) {.async, gcsafe.} =
let msgOpt = RelayMessage.decode(await conn.readLp(RelayClientMsgSize)) let msg = RelayMessage.decode(await conn.readLp(RelayClientMsgSize)).valueOr:
if msgOpt.isNone:
await sendStatus(conn, StatusV1.MalformedMessage) await sendStatus(conn, StatusV1.MalformedMessage)
return return
trace "client circuit relay v1 handle stream", msg = msgOpt.get() trace "client circuit relay v1 handle stream", msg
let msg = msgOpt.get()
case msg.msgType.get: let typ = msg.msgType.valueOr:
trace "Message type not set"
await sendStatus(conn, StatusV1.MalformedMessage)
return
case typ:
of RelayType.Hop: of RelayType.Hop:
if cl.canHop: await cl.handleHop(conn, msg) if cl.canHop: await cl.handleHop(conn, msg)
else: await sendStatus(conn, StatusV1.HopCantSpeakRelay) else: await sendStatus(conn, StatusV1.HopCantSpeakRelay)

View File

@ -9,8 +9,8 @@
{.push raises: [].} {.push raises: [].}
import options, macros import macros
import stew/objects import stew/[objects, results]
import ../../../peerinfo, import ../../../peerinfo,
../../../signed_envelope ../../../signed_envelope
@ -46,36 +46,36 @@ type
addrs*: seq[MultiAddress] addrs*: seq[MultiAddress]
RelayMessage* = object RelayMessage* = object
msgType*: Option[RelayType] msgType*: Opt[RelayType]
srcPeer*: Option[RelayPeer] srcPeer*: Opt[RelayPeer]
dstPeer*: Option[RelayPeer] dstPeer*: Opt[RelayPeer]
status*: Option[StatusV1] status*: Opt[StatusV1]
proc encode*(msg: RelayMessage): ProtoBuffer = proc encode*(msg: RelayMessage): ProtoBuffer =
result = initProtoBuffer() result = initProtoBuffer()
if isSome(msg.msgType): msg.msgType.withValue(typ):
result.write(1, msg.msgType.get().ord.uint) result.write(1, typ.ord.uint)
if isSome(msg.srcPeer): msg.srcPeer.withValue(srcPeer):
var peer = initProtoBuffer() var peer = initProtoBuffer()
peer.write(1, msg.srcPeer.get().peerId) peer.write(1, srcPeer.peerId)
for ma in msg.srcPeer.get().addrs: for ma in srcPeer.addrs:
peer.write(2, ma.data.buffer) peer.write(2, ma.data.buffer)
peer.finish() peer.finish()
result.write(2, peer.buffer) result.write(2, peer.buffer)
if isSome(msg.dstPeer): msg.dstPeer.withValue(dstPeer):
var peer = initProtoBuffer() var peer = initProtoBuffer()
peer.write(1, msg.dstPeer.get().peerId) peer.write(1, dstPeer.peerId)
for ma in msg.dstPeer.get().addrs: for ma in dstPeer.addrs:
peer.write(2, ma.data.buffer) peer.write(2, ma.data.buffer)
peer.finish() peer.finish()
result.write(3, peer.buffer) result.write(3, peer.buffer)
if isSome(msg.status): msg.status.withValue(status):
result.write(4, msg.status.get().ord.uint) result.write(4, status.ord.uint)
result.finish() result.finish()
proc decode*(_: typedesc[RelayMessage], buf: seq[byte]): Option[RelayMessage] = proc decode*(_: typedesc[RelayMessage], buf: seq[byte]): Opt[RelayMessage] =
var var
rMsg: RelayMessage rMsg: RelayMessage
msgTypeOrd: uint32 msgTypeOrd: uint32
@ -85,38 +85,29 @@ proc decode*(_: typedesc[RelayMessage], buf: seq[byte]): Option[RelayMessage] =
pbSrc: ProtoBuffer pbSrc: ProtoBuffer
pbDst: ProtoBuffer pbDst: ProtoBuffer
let let pb = initProtoBuffer(buf)
pb = initProtoBuffer(buf)
r1 = pb.getField(1, msgTypeOrd)
r2 = pb.getField(2, pbSrc)
r3 = pb.getField(3, pbDst)
r4 = pb.getField(4, statusOrd)
if r1.isErr() or r2.isErr() or r3.isErr() or r4.isErr(): if ? pb.getField(1, msgTypeOrd).toOpt():
return none(RelayMessage)
if r2.get() and
(pbSrc.getField(1, src.peerId).isErr() or
pbSrc.getRepeatedField(2, src.addrs).isErr()):
return none(RelayMessage)
if r3.get() and
(pbDst.getField(1, dst.peerId).isErr() or
pbDst.getRepeatedField(2, dst.addrs).isErr()):
return none(RelayMessage)
if r1.get():
if msgTypeOrd.int notin RelayType: if msgTypeOrd.int notin RelayType:
return none(RelayMessage) return Opt.none(RelayMessage)
rMsg.msgType = some(RelayType(msgTypeOrd)) rMsg.msgType = Opt.some(RelayType(msgTypeOrd))
if r2.get(): rMsg.srcPeer = some(src)
if r3.get(): rMsg.dstPeer = some(dst) if ? pb.getField(2, pbSrc).toOpt():
if r4.get(): discard ? pbSrc.getField(1, src.peerId).toOpt()
discard ? pbSrc.getRepeatedField(2, src.addrs).toOpt()
rMsg.srcPeer = Opt.some(src)
if ? pb.getField(3, pbDst).toOpt():
discard ? pbDst.getField(1, dst.peerId).toOpt()
discard ? pbDst.getRepeatedField(2, dst.addrs).toOpt()
rMsg.dstPeer = Opt.some(dst)
if ? pb.getField(4, statusOrd).toOpt():
var status: StatusV1 var status: StatusV1
if not checkedEnumAssign(status, statusOrd): if not checkedEnumAssign(status, statusOrd):
return none(RelayMessage) return Opt.none(RelayMessage)
rMsg.status = some(status) rMsg.status = Opt.some(status)
some(rMsg) Opt.some(rMsg)
# Voucher # Voucher
@ -176,7 +167,7 @@ type
Reservation* = object Reservation* = object
expire*: uint64 # required, Unix expiration time (UTC) expire*: uint64 # required, Unix expiration time (UTC)
addrs*: seq[MultiAddress] # relay address for reserving peer addrs*: seq[MultiAddress] # relay address for reserving peer
svoucher*: Option[seq[byte]] # optional, reservation voucher svoucher*: Opt[seq[byte]] # optional, reservation voucher
Limit* = object Limit* = object
duration*: uint32 # seconds duration*: uint32 # seconds
data*: uint64 # bytes data*: uint64 # bytes
@ -196,30 +187,29 @@ type
Status = 2 Status = 2
HopMessage* = object HopMessage* = object
msgType*: HopMessageType msgType*: HopMessageType
peer*: Option[Peer] peer*: Opt[Peer]
reservation*: Option[Reservation] reservation*: Opt[Reservation]
limit*: Limit limit*: Limit
status*: Option[StatusV2] status*: Opt[StatusV2]
proc encode*(msg: HopMessage): ProtoBuffer = proc encode*(msg: HopMessage): ProtoBuffer =
var pb = initProtoBuffer() var pb = initProtoBuffer()
pb.write(1, msg.msgType.ord.uint) pb.write(1, msg.msgType.ord.uint)
if msg.peer.isSome(): msg.peer.withValue(peer):
var ppb = initProtoBuffer() var ppb = initProtoBuffer()
ppb.write(1, msg.peer.get().peerId) ppb.write(1, peer.peerId)
for ma in msg.peer.get().addrs: for ma in peer.addrs:
ppb.write(2, ma.data.buffer) ppb.write(2, ma.data.buffer)
ppb.finish() ppb.finish()
pb.write(2, ppb.buffer) pb.write(2, ppb.buffer)
if msg.reservation.isSome(): msg.reservation.withValue(rsrv):
let rsrv = msg.reservation.get()
var rpb = initProtoBuffer() var rpb = initProtoBuffer()
rpb.write(1, rsrv.expire) rpb.write(1, rsrv.expire)
for ma in rsrv.addrs: for ma in rsrv.addrs:
rpb.write(2, ma.data.buffer) rpb.write(2, ma.data.buffer)
if rsrv.svoucher.isSome(): rsrv.svoucher.withValue(vouch):
rpb.write(3, rsrv.svoucher.get()) rpb.write(3, vouch)
rpb.finish() rpb.finish()
pb.write(3, rpb.buffer) pb.write(3, rpb.buffer)
if msg.limit.duration > 0 or msg.limit.data > 0: if msg.limit.duration > 0 or msg.limit.data > 0:
@ -228,66 +218,51 @@ proc encode*(msg: HopMessage): ProtoBuffer =
if msg.limit.data > 0: lpb.write(2, msg.limit.data) if msg.limit.data > 0: lpb.write(2, msg.limit.data)
lpb.finish() lpb.finish()
pb.write(4, lpb.buffer) pb.write(4, lpb.buffer)
if msg.status.isSome(): msg.status.withValue(status):
pb.write(5, msg.status.get().ord.uint) pb.write(5, status.ord.uint)
pb.finish() pb.finish()
pb pb
proc decode*(_: typedesc[HopMessage], buf: seq[byte]): Option[HopMessage] = proc decode*(_: typedesc[HopMessage], buf: seq[byte]): Opt[HopMessage] =
var var msg: HopMessage
msg: HopMessage let pb = initProtoBuffer(buf)
msgTypeOrd: uint32
pbPeer: ProtoBuffer
pbReservation: ProtoBuffer
pbLimit: ProtoBuffer
statusOrd: uint32
peer: Peer
reservation: Reservation
limit: Limit
res: bool
let
pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, msgTypeOrd)
r2 = pb.getField(2, pbPeer)
r3 = pb.getField(3, pbReservation)
r4 = pb.getField(4, pbLimit)
r5 = pb.getField(5, statusOrd)
if r1.isErr() or r2.isErr() or r3.isErr() or r4.isErr() or r5.isErr():
return none(HopMessage)
if r2.get() and
(pbPeer.getRequiredField(1, peer.peerId).isErr() or
pbPeer.getRepeatedField(2, peer.addrs).isErr()):
return none(HopMessage)
if r3.get():
var svoucher: seq[byte]
let rSVoucher = pbReservation.getField(3, svoucher)
if pbReservation.getRequiredField(1, reservation.expire).isErr() or
pbReservation.getRepeatedField(2, reservation.addrs).isErr() or
rSVoucher.isErr():
return none(HopMessage)
if rSVoucher.get(): reservation.svoucher = some(svoucher)
if r4.get() and
(pbLimit.getField(1, limit.duration).isErr() or
pbLimit.getField(2, limit.data).isErr()):
return none(HopMessage)
var msgTypeOrd: uint32
? pb.getRequiredField(1, msgTypeOrd).toOpt()
if not checkedEnumAssign(msg.msgType, msgTypeOrd): if not checkedEnumAssign(msg.msgType, msgTypeOrd):
return none(HopMessage) return Opt.none(HopMessage)
if r2.get(): msg.peer = some(peer)
if r3.get(): msg.reservation = some(reservation) var pbPeer: ProtoBuffer
if r4.get(): msg.limit = limit if ? pb.getField(2, pbPeer).toOpt():
if r5.get(): var peer: Peer
? pbPeer.getRequiredField(1, peer.peerId).toOpt()
discard ? pbPeer.getRepeatedField(2, peer.addrs).toOpt()
msg.peer = Opt.some(peer)
var pbReservation: ProtoBuffer
if ? pb.getField(3, pbReservation).toOpt():
var
svoucher: seq[byte]
reservation: Reservation
if ? pbReservation.getField(3, svoucher).toOpt():
reservation.svoucher = Opt.some(svoucher)
? pbReservation.getRequiredField(1, reservation.expire).toOpt()
discard ? pbReservation.getRepeatedField(2, reservation.addrs).toOpt()
msg.reservation = Opt.some(reservation)
var pbLimit: ProtoBuffer
if ? pb.getField(4, pbLimit).toOpt():
discard ? pbLimit.getField(1, msg.limit.duration).toOpt()
discard ? pbLimit.getField(2, msg.limit.data).toOpt()
var statusOrd: uint32
if ? pb.getField(5, statusOrd).toOpt():
var status: StatusV2 var status: StatusV2
if not checkedEnumAssign(status, statusOrd): if not checkedEnumAssign(status, statusOrd):
return none(HopMessage) return Opt.none(HopMessage)
msg.status = some(status) msg.status = Opt.some(status)
some(msg) Opt.some(msg)
# Circuit Relay V2 Stop Message # Circuit Relay V2 Stop Message
@ -297,19 +272,19 @@ type
Status = 1 Status = 1
StopMessage* = object StopMessage* = object
msgType*: StopMessageType msgType*: StopMessageType
peer*: Option[Peer] peer*: Opt[Peer]
limit*: Limit limit*: Limit
status*: Option[StatusV2] status*: Opt[StatusV2]
proc encode*(msg: StopMessage): ProtoBuffer = proc encode*(msg: StopMessage): ProtoBuffer =
var pb = initProtoBuffer() var pb = initProtoBuffer()
pb.write(1, msg.msgType.ord.uint) pb.write(1, msg.msgType.ord.uint)
if msg.peer.isSome(): msg.peer.withValue(peer):
var ppb = initProtoBuffer() var ppb = initProtoBuffer()
ppb.write(1, msg.peer.get().peerId) ppb.write(1, peer.peerId)
for ma in msg.peer.get().addrs: for ma in peer.addrs:
ppb.write(2, ma.data.buffer) ppb.write(2, ma.data.buffer)
ppb.finish() ppb.finish()
pb.write(2, ppb.buffer) pb.write(2, ppb.buffer)
@ -319,52 +294,40 @@ proc encode*(msg: StopMessage): ProtoBuffer =
if msg.limit.data > 0: lpb.write(2, msg.limit.data) if msg.limit.data > 0: lpb.write(2, msg.limit.data)
lpb.finish() lpb.finish()
pb.write(3, lpb.buffer) pb.write(3, lpb.buffer)
if msg.status.isSome(): msg.status.withValue(status):
pb.write(4, msg.status.get().ord.uint) pb.write(4, status.ord.uint)
pb.finish() pb.finish()
pb pb
proc decode*(_: typedesc[StopMessage], buf: seq[byte]): Option[StopMessage] = proc decode*(_: typedesc[StopMessage], buf: seq[byte]): Opt[StopMessage] =
var var msg: StopMessage
msg: StopMessage
msgTypeOrd: uint32
pbPeer: ProtoBuffer
pbLimit: ProtoBuffer
statusOrd: uint32
peer: Peer
limit: Limit
rVoucher: ProtoResult[bool]
res: bool
let let pb = initProtoBuffer(buf)
pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, msgTypeOrd)
r2 = pb.getField(2, pbPeer)
r3 = pb.getField(3, pbLimit)
r4 = pb.getField(4, statusOrd)
if r1.isErr() or r2.isErr() or r3.isErr() or r4.isErr(): var msgTypeOrd: uint32
return none(StopMessage) ? pb.getRequiredField(1, msgTypeOrd).toOpt()
if msgTypeOrd.int notin StopMessageType:
if r2.get() and return Opt.none(StopMessage)
(pbPeer.getRequiredField(1, peer.peerId).isErr() or
pbPeer.getRepeatedField(2, peer.addrs).isErr()):
return none(StopMessage)
if r3.get() and
(pbLimit.getField(1, limit.duration).isErr() or
pbLimit.getField(2, limit.data).isErr()):
return none(StopMessage)
if msgTypeOrd.int notin StopMessageType.low.ord .. StopMessageType.high.ord:
return none(StopMessage)
msg.msgType = StopMessageType(msgTypeOrd) msg.msgType = StopMessageType(msgTypeOrd)
if r2.get(): msg.peer = some(peer)
if r3.get(): msg.limit = limit
if r4.get(): var pbPeer: ProtoBuffer
if ? pb.getField(2, pbPeer).toOpt():
var peer: Peer
? pbPeer.getRequiredField(1, peer.peerId).toOpt()
discard ? pbPeer.getRepeatedField(2, peer.addrs).toOpt()
msg.peer = Opt.some(peer)
var pbLimit: ProtoBuffer
if ? pb.getField(3, pbLimit).toOpt():
discard ? pbLimit.getField(1, msg.limit.duration).toOpt()
discard ? pbLimit.getField(2, msg.limit.data).toOpt()
var statusOrd: uint32
if ? pb.getField(4, statusOrd).toOpt():
var status: StatusV2 var status: StatusV2
if not checkedEnumAssign(status, statusOrd): if not checkedEnumAssign(status, statusOrd):
return none(StopMessage) return Opt.none(StopMessage)
msg.status = some(status) msg.status = Opt.some(status)
some(msg) Opt.some(msg)

View File

@ -9,7 +9,7 @@
{.push raises: [].} {.push raises: [].}
import options, sequtils, tables import sequtils, tables
import chronos, chronicles import chronos, chronicles
@ -90,11 +90,11 @@ proc createReserveResponse(
rsrv = Reservation(expire: expireUnix, rsrv = Reservation(expire: expireUnix,
addrs: r.switch.peerInfo.addrs.mapIt( addrs: r.switch.peerInfo.addrs.mapIt(
? it.concat(ma).orErr(CryptoError.KeyError)), ? it.concat(ma).orErr(CryptoError.KeyError)),
svoucher: some(? sv.encode)) svoucher: Opt.some(? sv.encode))
msg = HopMessage(msgType: HopMessageType.Status, msg = HopMessage(msgType: HopMessageType.Status,
reservation: some(rsrv), reservation: Opt.some(rsrv),
limit: r.limit, limit: r.limit,
status: some(Ok)) status: Opt.some(Ok))
return ok(msg) return ok(msg)
proc isRelayed*(conn: Connection): bool = proc isRelayed*(conn: Connection): bool =
@ -115,17 +115,16 @@ proc handleReserve(r: Relay, conn: Connection) {.async, gcsafe.} =
trace "Too many reservations", pid = conn.peerId trace "Too many reservations", pid = conn.peerId
await sendHopStatus(conn, ReservationRefused) await sendHopStatus(conn, ReservationRefused)
return return
trace "reserving relay slot for", pid = conn.peerId
let let
pid = conn.peerId pid = conn.peerId
expire = now().utc + r.reservationTTL expire = now().utc + r.reservationTTL
msg = r.createReserveResponse(pid, expire) msg = r.createReserveResponse(pid, expire).valueOr:
trace "error signing the voucher", pid
return
trace "reserving relay slot for", pid
if msg.isErr():
trace "error signing the voucher", error = error(msg), pid
return
r.rsvp[pid] = expire r.rsvp[pid] = expire
await conn.writeLp(encode(msg.get()).buffer) await conn.writeLp(encode(msg).buffer)
proc handleConnect(r: Relay, proc handleConnect(r: Relay,
connSrc: Connection, connSrc: Connection,
@ -134,13 +133,12 @@ proc handleConnect(r: Relay,
trace "connection attempt over relay connection" trace "connection attempt over relay connection"
await sendHopStatus(connSrc, PermissionDenied) await sendHopStatus(connSrc, PermissionDenied)
return return
if msg.peer.isNone():
await sendHopStatus(connSrc, MalformedMessage)
return
let let
msgPeer = msg.peer.valueOr:
await sendHopStatus(connSrc, MalformedMessage)
return
src = connSrc.peerId src = connSrc.peerId
dst = msg.peer.get().peerId dst = msgPeer.peerId
if dst notin r.rsvp: if dst notin r.rsvp:
trace "refusing connection, no reservation", src, dst trace "refusing connection, no reservation", src, dst
await sendHopStatus(connSrc, NoReservation) await sendHopStatus(connSrc, NoReservation)
@ -173,16 +171,17 @@ proc handleConnect(r: Relay,
proc sendStopMsg() {.async.} = proc sendStopMsg() {.async.} =
let stopMsg = StopMessage(msgType: StopMessageType.Connect, let stopMsg = StopMessage(msgType: StopMessageType.Connect,
peer: some(Peer(peerId: src, addrs: @[])), peer: Opt.some(Peer(peerId: src, addrs: @[])),
limit: r.limit) limit: r.limit)
await connDst.writeLp(encode(stopMsg).buffer) await connDst.writeLp(encode(stopMsg).buffer)
let msg = StopMessage.decode(await connDst.readLp(r.msgSize)).get() let msg = StopMessage.decode(await connDst.readLp(r.msgSize)).valueOr:
raise newException(SendStopError, "Malformed message")
if msg.msgType != StopMessageType.Status: if msg.msgType != StopMessageType.Status:
raise newException(SendStopError, "Unexpected stop response, not a status message") raise newException(SendStopError, "Unexpected stop response, not a status message")
if msg.status.get(UnexpectedMessage) != Ok: if msg.status.get(UnexpectedMessage) != Ok:
raise newException(SendStopError, "Relay stop failure") raise newException(SendStopError, "Relay stop failure")
await connSrc.writeLp(encode(HopMessage(msgType: HopMessageType.Status, await connSrc.writeLp(encode(HopMessage(msgType: HopMessageType.Status,
status: some(Ok))).buffer) status: Opt.some(Ok))).buffer)
try: try:
await sendStopMsg() await sendStopMsg()
except CancelledError as exc: except CancelledError as exc:
@ -202,12 +201,10 @@ proc handleConnect(r: Relay,
await bridge(rconnSrc, rconnDst) await bridge(rconnSrc, rconnDst)
proc handleHopStreamV2*(r: Relay, conn: Connection) {.async, gcsafe.} = proc handleHopStreamV2*(r: Relay, conn: Connection) {.async, gcsafe.} =
let msgOpt = HopMessage.decode(await conn.readLp(r.msgSize)) let msg = HopMessage.decode(await conn.readLp(r.msgSize)).valueOr:
if msgOpt.isNone():
await sendHopStatus(conn, MalformedMessage) await sendHopStatus(conn, MalformedMessage)
return return
trace "relayv2 handle stream", msg = msgOpt.get() trace "relayv2 handle stream", msg = msg
let msg = msgOpt.get()
case msg.msgType: case msg.msgType:
of HopMessageType.Reserve: await r.handleReserve(conn) of HopMessageType.Reserve: await r.handleReserve(conn)
of HopMessageType.Connect: await r.handleConnect(conn, msg) of HopMessageType.Connect: await r.handleConnect(conn, msg)
@ -225,15 +222,14 @@ proc handleHop*(r: Relay, connSrc: Connection, msg: RelayMessage) {.async, gcsaf
await sendStatus(connSrc, StatusV1.HopCantSpeakRelay) await sendStatus(connSrc, StatusV1.HopCantSpeakRelay)
return return
var src, dst: RelayPeer
proc checkMsg(): Result[RelayMessage, StatusV1] = proc checkMsg(): Result[RelayMessage, StatusV1] =
if msg.srcPeer.isNone: src = msg.srcPeer.valueOr:
return err(StatusV1.HopSrcMultiaddrInvalid) return err(StatusV1.HopSrcMultiaddrInvalid)
let src = msg.srcPeer.get()
if src.peerId != connSrc.peerId: if src.peerId != connSrc.peerId:
return err(StatusV1.HopSrcMultiaddrInvalid) return err(StatusV1.HopSrcMultiaddrInvalid)
if msg.dstPeer.isNone: dst = msg.dstPeer.valueOr:
return err(StatusV1.HopDstMultiaddrInvalid) return err(StatusV1.HopDstMultiaddrInvalid)
let dst = msg.dstPeer.get()
if dst.peerId == r.switch.peerInfo.peerId: if dst.peerId == r.switch.peerInfo.peerId:
return err(StatusV1.HopCantRelayToSelf) return err(StatusV1.HopCantRelayToSelf)
if not r.switch.isConnected(dst.peerId): if not r.switch.isConnected(dst.peerId):
@ -245,9 +241,6 @@ proc handleHop*(r: Relay, connSrc: Connection, msg: RelayMessage) {.async, gcsaf
await sendStatus(connSrc, check.error()) await sendStatus(connSrc, check.error())
return return
let
src = msg.srcPeer.get()
dst = msg.dstPeer.get()
if r.peerCount[src.peerId] >= r.maxCircuitPerPeer or if r.peerCount[src.peerId] >= r.maxCircuitPerPeer or
r.peerCount[dst.peerId] >= r.maxCircuitPerPeer: r.peerCount[dst.peerId] >= r.maxCircuitPerPeer:
trace "refusing connection; too many connection from src or to dst", src, dst trace "refusing connection; too many connection from src or to dst", src, dst
@ -271,9 +264,9 @@ proc handleHop*(r: Relay, connSrc: Connection, msg: RelayMessage) {.async, gcsaf
await connDst.close() await connDst.close()
let msgToSend = RelayMessage( let msgToSend = RelayMessage(
msgType: some(RelayType.Stop), msgType: Opt.some(RelayType.Stop),
srcPeer: some(src), srcPeer: Opt.some(src),
dstPeer: some(dst)) dstPeer: Opt.some(dst))
let msgRcvFromDstOpt = try: let msgRcvFromDstOpt = try:
await connDst.writeLp(encode(msgToSend).buffer) await connDst.writeLp(encode(msgToSend).buffer)
@ -285,12 +278,11 @@ proc handleHop*(r: Relay, connSrc: Connection, msg: RelayMessage) {.async, gcsaf
await sendStatus(connSrc, StatusV1.HopCantOpenDstStream) await sendStatus(connSrc, StatusV1.HopCantOpenDstStream)
return return
if msgRcvFromDstOpt.isNone: let msgRcvFromDst = msgRcvFromDstOpt.valueOr:
trace "error reading stop response", msg = msgRcvFromDstOpt trace "error reading stop response", msg = msgRcvFromDstOpt
await sendStatus(connSrc, StatusV1.HopCantOpenDstStream) await sendStatus(connSrc, StatusV1.HopCantOpenDstStream)
return return
let msgRcvFromDst = msgRcvFromDstOpt.get()
if msgRcvFromDst.msgType.get(RelayType.Stop) != RelayType.Status or if msgRcvFromDst.msgType.get(RelayType.Stop) != RelayType.Status or
msgRcvFromDst.status.get(StatusV1.StopRelayRefused) != StatusV1.Success: msgRcvFromDst.status.get(StatusV1.StopRelayRefused) != StatusV1.Success:
trace "unexcepted relay stop response", msgRcvFromDst trace "unexcepted relay stop response", msgRcvFromDst
@ -302,13 +294,16 @@ proc handleHop*(r: Relay, connSrc: Connection, msg: RelayMessage) {.async, gcsaf
await bridge(connSrc, connDst) await bridge(connSrc, connDst)
proc handleStreamV1(r: Relay, conn: Connection) {.async, gcsafe.} = proc handleStreamV1(r: Relay, conn: Connection) {.async, gcsafe.} =
let msgOpt = RelayMessage.decode(await conn.readLp(r.msgSize)) let msg = RelayMessage.decode(await conn.readLp(r.msgSize)).valueOr:
if msgOpt.isNone:
await sendStatus(conn, StatusV1.MalformedMessage) await sendStatus(conn, StatusV1.MalformedMessage)
return return
trace "relay handle stream", msg = msgOpt.get() trace "relay handle stream", msg
let msg = msgOpt.get()
case msg.msgType.get: let typ = msg.msgType.valueOr:
trace "Message type not set"
await sendStatus(conn, StatusV1.MalformedMessage)
return
case typ:
of RelayType.Hop: await r.handleHop(conn, msg) of RelayType.Hop: await r.handleHop(conn, msg)
of RelayType.Stop: await sendStatus(conn, StatusV1.StopRelayRefused) of RelayType.Stop: await sendStatus(conn, StatusV1.StopRelayRefused)
of RelayType.CanHop: await sendStatus(conn, StatusV1.Success) of RelayType.CanHop: await sendStatus(conn, StatusV1.Success)

View File

@ -61,9 +61,9 @@ proc dial*(self: RelayTransport, ma: MultiAddress): Future[Connection] {.async,
var var
relayPeerId: PeerId relayPeerId: PeerId
dstPeerId: PeerId dstPeerId: PeerId
if not relayPeerId.init(($(sma[^3].get())).split('/')[2]): if not relayPeerId.init(($(sma[^3].tryGet())).split('/')[2]):
raise newException(RelayV2DialError, "Relay doesn't exist") raise newException(RelayV2DialError, "Relay doesn't exist")
if not dstPeerId.init(($(sma[^1].get())).split('/')[2]): if not dstPeerId.init(($(sma[^1].tryGet())).split('/')[2]):
raise newException(RelayV2DialError, "Destination doesn't exist") raise newException(RelayV2DialError, "Destination doesn't exist")
trace "Dial", relayPeerId, dstPeerId trace "Dial", relayPeerId, dstPeerId
@ -91,13 +91,17 @@ method dial*(
hostname: string, hostname: string,
ma: MultiAddress, ma: MultiAddress,
peerId: Opt[PeerId] = Opt.none(PeerId)): Future[Connection] {.async, gcsafe.} = peerId: Opt[PeerId] = Opt.none(PeerId)): Future[Connection] {.async, gcsafe.} =
let address = MultiAddress.init($ma & "/p2p/" & $peerId.get()).tryGet() peerId.withValue(pid):
result = await self.dial(address) let address = MultiAddress.init($ma & "/p2p/" & $pid).tryGet()
result = await self.dial(address)
method handles*(self: RelayTransport, ma: MultiAddress): bool {.gcsafe} = method handles*(self: RelayTransport, ma: MultiAddress): bool {.gcsafe.} =
if ma.protocols.isOk(): try:
let sma = toSeq(ma.items()) if ma.protocols.isOk():
result = sma.len >= 2 and CircuitRelay.match(sma[^1].get()) let sma = toSeq(ma.items())
result = sma.len >= 2 and CircuitRelay.match(sma[^1].tryGet())
except CatchableError as exc:
result = false
trace "Handles return", ma, result trace "Handles return", ma, result
proc new*(T: typedesc[RelayTransport], cl: RelayClient, upgrader: Upgrade): T = proc new*(T: typedesc[RelayTransport], cl: RelayClient, upgrader: Upgrade): T =

View File

@ -9,10 +9,7 @@
{.push raises: [].} {.push raises: [].}
import options
import chronos, chronicles import chronos, chronicles
import ./messages, import ./messages,
../../../stream/connection ../../../stream/connection
@ -27,21 +24,21 @@ const
proc sendStatus*(conn: Connection, code: StatusV1) {.async, gcsafe.} = proc sendStatus*(conn: Connection, code: StatusV1) {.async, gcsafe.} =
trace "send relay/v1 status", status = $code & "(" & $ord(code) & ")" trace "send relay/v1 status", status = $code & "(" & $ord(code) & ")"
let let
msg = RelayMessage(msgType: some(RelayType.Status), status: some(code)) msg = RelayMessage(msgType: Opt.some(RelayType.Status), status: Opt.some(code))
pb = encode(msg) pb = encode(msg)
await conn.writeLp(pb.buffer) await conn.writeLp(pb.buffer)
proc sendHopStatus*(conn: Connection, code: StatusV2) {.async, gcsafe.} = proc sendHopStatus*(conn: Connection, code: StatusV2) {.async, gcsafe.} =
trace "send hop relay/v2 status", status = $code & "(" & $ord(code) & ")" trace "send hop relay/v2 status", status = $code & "(" & $ord(code) & ")"
let let
msg = HopMessage(msgType: HopMessageType.Status, status: some(code)) msg = HopMessage(msgType: HopMessageType.Status, status: Opt.some(code))
pb = encode(msg) pb = encode(msg)
await conn.writeLp(pb.buffer) await conn.writeLp(pb.buffer)
proc sendStopStatus*(conn: Connection, code: StatusV2) {.async.} = proc sendStopStatus*(conn: Connection, code: StatusV2) {.async.} =
trace "send stop relay/v2 status", status = $code & " (" & $ord(code) & ")" trace "send stop relay/v2 status", status = $code & " (" & $ord(code) & ")"
let let
msg = StopMessage(msgType: StopMessageType.Status, status: some(code)) msg = StopMessage(msgType: StopMessageType.Status, status: Opt.some(code))
pb = encode(msg) pb = encode(msg)
await conn.writeLp(pb.buffer) await conn.writeLp(pb.buffer)

View File

@ -71,9 +71,7 @@ chronicles.expandIt(IdentifyInfo):
pubkey = ($it.pubkey).shortLog pubkey = ($it.pubkey).shortLog
addresses = it.addrs.map(x => $x).join(",") addresses = it.addrs.map(x => $x).join(",")
protocols = it.protos.map(x => $x).join(",") protocols = it.protos.map(x => $x).join(",")
observable_address = observable_address = $it.observedAddr
if it.observedAddr.isSome(): $it.observedAddr.get()
else: "None"
proto_version = it.protoVersion.get("None") proto_version = it.protoVersion.get("None")
agent_version = it.agentVersion.get("None") agent_version = it.agentVersion.get("None")
signedPeerRecord = signedPeerRecord =
@ -88,13 +86,13 @@ proc encodeMsg(peerInfo: PeerInfo, observedAddr: Opt[MultiAddress], sendSpr: boo
let pkey = peerInfo.publicKey let pkey = peerInfo.publicKey
result.write(1, pkey.getBytes().get()) result.write(1, pkey.getBytes().expect("valid key"))
for ma in peerInfo.addrs: for ma in peerInfo.addrs:
result.write(2, ma.data.buffer) result.write(2, ma.data.buffer)
for proto in peerInfo.protocols: for proto in peerInfo.protocols:
result.write(3, proto) result.write(3, proto)
if observedAddr.isSome: observedAddr.withValue(observed):
result.write(4, observedAddr.get().data.buffer) result.write(4, observed.data.buffer)
let protoVersion = ProtoVersion let protoVersion = ProtoVersion
result.write(5, protoVersion) result.write(5, protoVersion)
let agentVersion = if peerInfo.agentVersion.len <= 0: let agentVersion = if peerInfo.agentVersion.len <= 0:
@ -106,13 +104,12 @@ proc encodeMsg(peerInfo: PeerInfo, observedAddr: Opt[MultiAddress], sendSpr: boo
## Optionally populate signedPeerRecord field. ## Optionally populate signedPeerRecord field.
## See https://github.com/libp2p/go-libp2p/blob/ddf96ce1cfa9e19564feb9bd3e8269958bbc0aba/p2p/protocol/identify/pb/identify.proto for reference. ## See https://github.com/libp2p/go-libp2p/blob/ddf96ce1cfa9e19564feb9bd3e8269958bbc0aba/p2p/protocol/identify/pb/identify.proto for reference.
if sendSpr: if sendSpr:
let sprBuff = peerInfo.signedPeerRecord.envelope.encode() peerInfo.signedPeerRecord.envelope.encode().toOpt().withValue(sprBuff):
if sprBuff.isOk(): result.write(8, sprBuff)
result.write(8, sprBuff.get())
result.finish() result.finish()
proc decodeMsg*(buf: seq[byte]): Option[IdentifyInfo] = proc decodeMsg*(buf: seq[byte]): Opt[IdentifyInfo] =
var var
iinfo: IdentifyInfo iinfo: IdentifyInfo
pubkey: PublicKey pubkey: PublicKey
@ -122,37 +119,22 @@ proc decodeMsg*(buf: seq[byte]): Option[IdentifyInfo] =
signedPeerRecord: SignedPeerRecord signedPeerRecord: SignedPeerRecord
var pb = initProtoBuffer(buf) var pb = initProtoBuffer(buf)
if ? pb.getField(1, pubkey).toOpt():
iinfo.pubkey = some(pubkey)
if ? pb.getField(8, signedPeerRecord).toOpt() and
pubkey == signedPeerRecord.envelope.publicKey:
iinfo.signedPeerRecord = some(signedPeerRecord.envelope)
discard ? pb.getRepeatedField(2, iinfo.addrs).toOpt()
discard ? pb.getRepeatedField(3, iinfo.protos).toOpt()
if ? pb.getField(4, oaddr).toOpt():
iinfo.observedAddr = some(oaddr)
if ? pb.getField(5, protoVersion).toOpt():
iinfo.protoVersion = some(protoVersion)
if ? pb.getField(6, agentVersion).toOpt():
iinfo.agentVersion = some(agentVersion)
let r1 = pb.getField(1, pubkey) debug "decodeMsg: decoded identify", iinfo
let r2 = pb.getRepeatedField(2, iinfo.addrs) Opt.some(iinfo)
let r3 = pb.getRepeatedField(3, iinfo.protos)
let r4 = pb.getField(4, oaddr)
let r5 = pb.getField(5, protoVersion)
let r6 = pb.getField(6, agentVersion)
let r8 = pb.getField(8, signedPeerRecord)
let res = r1.isOk() and r2.isOk() and r3.isOk() and
r4.isOk() and r5.isOk() and r6.isOk() and
r8.isOk()
if res:
if r1.get():
iinfo.pubkey = some(pubkey)
if r4.get():
iinfo.observedAddr = some(oaddr)
if r5.get():
iinfo.protoVersion = some(protoVersion)
if r6.get():
iinfo.agentVersion = some(agentVersion)
if r8.get() and r1.get():
if iinfo.pubkey.get() == signedPeerRecord.envelope.publicKey:
iinfo.signedPeerRecord = some(signedPeerRecord.envelope)
debug "decodeMsg: decoded identify", iinfo
some(iinfo)
else:
trace "decodeMsg: failed to decode received message"
none[IdentifyInfo]()
proc new*( proc new*(
T: typedesc[Identify], T: typedesc[Identify],
@ -193,26 +175,19 @@ proc identify*(self: Identify,
trace "identify: Empty message received!", conn trace "identify: Empty message received!", conn
raise newException(IdentityInvalidMsgError, "Empty message received!") raise newException(IdentityInvalidMsgError, "Empty message received!")
let infoOpt = decodeMsg(message) var info = decodeMsg(message).valueOr: raise newException(IdentityInvalidMsgError, "Incorrect message received!")
if infoOpt.isNone(): let
raise newException(IdentityInvalidMsgError, "Incorrect message received!") pubkey = info.pubkey.valueOr: raise newException(IdentityInvalidMsgError, "No pubkey in identify")
peer = PeerId.init(pubkey).valueOr: raise newException(IdentityInvalidMsgError, $error)
var info = infoOpt.get() if peer != remotePeerId:
if info.pubkey.isNone():
raise newException(IdentityInvalidMsgError, "No pubkey in identify")
let peer = PeerId.init(info.pubkey.get())
if peer.isErr:
raise newException(IdentityInvalidMsgError, $peer.error)
if peer.get() != remotePeerId:
trace "Peer ids don't match", remote = peer, local = remotePeerId trace "Peer ids don't match", remote = peer, local = remotePeerId
raise newException(IdentityNoMatchError, "Peer ids don't match") raise newException(IdentityNoMatchError, "Peer ids don't match")
info.peerId = peer.get() info.peerId = peer
if info.observedAddr.isSome: info.observedAddr.withValue(observed):
if not self.observedAddrManager.addObservation(info.observedAddr.get()): if not self.observedAddrManager.addObservation(observed):
debug "Observed address is not valid", observedAddr = info.observedAddr.get() debug "Observed address is not valid", observedAddr = observed
return info return info
proc new*(T: typedesc[IdentifyPush], handler: IdentifyPushHandler = nil): T {.public.} = proc new*(T: typedesc[IdentifyPush], handler: IdentifyPushHandler = nil): T {.public.} =
@ -228,21 +203,18 @@ proc init*(p: IdentifyPush) =
try: try:
var message = await conn.readLp(64*1024) var message = await conn.readLp(64*1024)
let infoOpt = decodeMsg(message) var identInfo = decodeMsg(message).valueOr:
if infoOpt.isNone():
raise newException(IdentityInvalidMsgError, "Incorrect message received!") raise newException(IdentityInvalidMsgError, "Incorrect message received!")
var indentInfo = infoOpt.get() identInfo.pubkey.withValue(pubkey):
let receivedPeerId = PeerId.init(pubkey).tryGet()
if indentInfo.pubkey.isSome:
let receivedPeerId = PeerId.init(indentInfo.pubkey.get()).tryGet()
if receivedPeerId != conn.peerId: if receivedPeerId != conn.peerId:
raise newException(IdentityNoMatchError, "Peer ids don't match") raise newException(IdentityNoMatchError, "Peer ids don't match")
indentInfo.peerId = receivedPeerId identInfo.peerId = receivedPeerId
trace "triggering peer event", peerInfo = conn.peerId trace "triggering peer event", peerInfo = conn.peerId
if not isNil(p.identifyHandler): if not isNil(p.identifyHandler):
await p.identifyHandler(conn.peerId, indentInfo) await p.identifyHandler(conn.peerId, identInfo)
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except CatchableError as exc: except CatchableError as exc:

View File

@ -185,11 +185,11 @@ method unsubscribePeer*(g: GossipSub, peer: PeerId) =
return return
# remove from peer IPs collection too # remove from peer IPs collection too
if pubSubPeer.address.isSome(): pubSubPeer.address.withValue(address):
g.peersInIP.withValue(pubSubPeer.address.get(), s): g.peersInIP.withValue(address, s):
s[].excl(pubSubPeer.peerId) s[].excl(pubSubPeer.peerId)
if s[].len == 0: if s[].len == 0:
g.peersInIP.del(pubSubPeer.address.get()) g.peersInIP.del(address)
for t in toSeq(g.mesh.keys): for t in toSeq(g.mesh.keys):
trace "pruning unsubscribing peer", pubSubPeer, score = pubSubPeer.score trace "pruning unsubscribing peer", pubSubPeer, score = pubSubPeer.score

View File

@ -49,9 +49,7 @@ proc pruned*(g: GossipSub,
backoff = none(Duration)) {.raises: [].} = backoff = none(Duration)) {.raises: [].} =
if setBackoff: if setBackoff:
let let
backoffDuration = backoffDuration = backoff.get(g.parameters.pruneBackoff)
if isSome(backoff): backoff.get()
else: g.parameters.pruneBackoff
backoffMoment = Moment.fromNow(backoffDuration) backoffMoment = Moment.fromNow(backoffDuration)
g.backingOff g.backingOff
@ -191,20 +189,15 @@ proc handleGraft*(g: GossipSub,
proc getPeers(prune: ControlPrune, peer: PubSubPeer): seq[(PeerId, Option[PeerRecord])] = proc getPeers(prune: ControlPrune, peer: PubSubPeer): seq[(PeerId, Option[PeerRecord])] =
var routingRecords: seq[(PeerId, Option[PeerRecord])] var routingRecords: seq[(PeerId, Option[PeerRecord])]
for record in prune.peers: for record in prune.peers:
let peerRecord = var peerRecord = none(PeerRecord)
if record.signedPeerRecord.len == 0: if record.signedPeerRecord.len > 0:
none(PeerRecord) SignedPeerRecord.decode(record.signedPeerRecord).toOpt().withValue(spr):
else: if record.peerId != spr.data.peerId:
let signedRecord = SignedPeerRecord.decode(record.signedPeerRecord) trace "peer sent envelope with wrong public key", peer
if signedRecord.isErr:
trace "peer sent invalid SPR", peer, error=signedRecord.error
none(PeerRecord)
else: else:
if record.peerId != signedRecord.get().data.peerId: peerRecord = some(spr.data)
trace "peer sent envelope with wrong public key", peer else:
none(PeerRecord) trace "peer sent invalid SPR", peer
else:
some(signedRecord.get().data)
routingRecords.add((record.peerId, peerRecord)) routingRecords.add((record.peerId, peerRecord))
@ -296,12 +289,11 @@ proc handleIWant*(g: GossipSub,
libp2p_gossipsub_received_iwants.inc(1, labelValues=["skipped"]) libp2p_gossipsub_received_iwants.inc(1, labelValues=["skipped"])
return messages return messages
continue continue
let msg = g.mcache.get(mid) let msg = g.mcache.get(mid).valueOr:
if msg.isSome:
libp2p_gossipsub_received_iwants.inc(1, labelValues=["correct"])
messages.add(msg.get())
else:
libp2p_gossipsub_received_iwants.inc(1, labelValues=["unknown"]) libp2p_gossipsub_received_iwants.inc(1, labelValues=["unknown"])
continue
libp2p_gossipsub_received_iwants.inc(1, labelValues=["correct"])
messages.add(msg)
return messages return messages
proc commitMetrics(metrics: var MeshMetrics) {.raises: [].} = proc commitMetrics(metrics: var MeshMetrics) {.raises: [].} =

View File

@ -9,7 +9,7 @@
{.push raises: [].} {.push raises: [].}
import std/[tables, sets, options] import std/[tables, sets]
import chronos, chronicles, metrics import chronos, chronicles, metrics
import "."/[types] import "."/[types]
import ".."/[pubsubpeer] import ".."/[pubsubpeer]
@ -71,20 +71,17 @@ func `/`(a, b: Duration): float64 =
func byScore*(x,y: PubSubPeer): int = system.cmp(x.score, y.score) func byScore*(x,y: PubSubPeer): int = system.cmp(x.score, y.score)
proc colocationFactor(g: GossipSub, peer: PubSubPeer): float64 = proc colocationFactor(g: GossipSub, peer: PubSubPeer): float64 =
if peer.address.isNone(): let address = peer.address.valueOr: return 0.0
0.0
g.peersInIP.mgetOrPut(address, initHashSet[PeerId]()).incl(peer.peerId)
let
ipPeers = g.peersInIP.getOrDefault(address).len().float64
if ipPeers > g.parameters.ipColocationFactorThreshold:
trace "colocationFactor over threshold", peer, address, ipPeers
let over = ipPeers - g.parameters.ipColocationFactorThreshold
over * over
else: else:
let 0.0
address = peer.address.get()
g.peersInIP.mgetOrPut(address, initHashSet[PeerId]()).incl(peer.peerId)
let
ipPeers = g.peersInIP.getOrDefault(address).len().float64
if ipPeers > g.parameters.ipColocationFactorThreshold:
trace "colocationFactor over threshold", peer, address, ipPeers
let over = ipPeers - g.parameters.ipColocationFactorThreshold
over * over
else:
0.0
{.pop.} {.pop.}

View File

@ -170,10 +170,9 @@ proc broadcast*(
else: else:
libp2p_pubsub_broadcast_messages.inc(npeers, labelValues = ["generic"]) libp2p_pubsub_broadcast_messages.inc(npeers, labelValues = ["generic"])
if msg.control.isSome(): msg.control.withValue(control):
libp2p_pubsub_broadcast_iwant.inc(npeers * msg.control.get().iwant.len.int64) libp2p_pubsub_broadcast_iwant.inc(npeers * control.iwant.len.int64)
let control = msg.control.get()
for ihave in control.ihave: for ihave in control.ihave:
if p.knownTopics.contains(ihave.topicId): if p.knownTopics.contains(ihave.topicId):
libp2p_pubsub_broadcast_ihave.inc(npeers, labelValues = [ihave.topicId]) libp2p_pubsub_broadcast_ihave.inc(npeers, labelValues = [ihave.topicId])
@ -244,9 +243,8 @@ proc updateMetrics*(p: PubSub, rpcMsg: RPCMsg) =
else: else:
libp2p_pubsub_received_messages.inc(labelValues = ["generic"]) libp2p_pubsub_received_messages.inc(labelValues = ["generic"])
if rpcMsg.control.isSome(): rpcMsg.control.withValue(control):
libp2p_pubsub_received_iwant.inc(rpcMsg.control.get().iwant.len.int64) libp2p_pubsub_received_iwant.inc(control.iwant.len.int64)
template control: untyped = rpcMsg.control.unsafeGet()
for ihave in control.ihave: for ihave in control.ihave:
if p.knownTopics.contains(ihave.topicId): if p.knownTopics.contains(ihave.topicId):
libp2p_pubsub_received_ihave.inc(labelValues = [ihave.topicId]) libp2p_pubsub_received_ihave.inc(labelValues = [ihave.topicId])

View File

@ -133,28 +133,26 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} =
conn, peer = p, closed = conn.closed, conn, peer = p, closed = conn.closed,
data = data.shortLog data = data.shortLog
var rmsg = decodeRpcMsg(data) var rmsg = decodeRpcMsg(data).valueOr:
data = newSeq[byte]() # Release memory debug "failed to decode msg from peer",
if rmsg.isErr():
notice "failed to decode msg from peer",
conn, peer = p, closed = conn.closed, conn, peer = p, closed = conn.closed,
err = rmsg.error() err = error
break break
data = newSeq[byte]() # Release memory
trace "decoded msg from peer", trace "decoded msg from peer",
conn, peer = p, closed = conn.closed, conn, peer = p, closed = conn.closed,
msg = rmsg.get().shortLog msg = rmsg.shortLog
# trigger hooks # trigger hooks
p.recvObservers(rmsg.get()) p.recvObservers(rmsg)
when defined(libp2p_expensive_metrics): when defined(libp2p_expensive_metrics):
for m in rmsg.get().messages: for m in rmsg.messages:
for t in m.topicIDs: for t in m.topicIDs:
# metrics # metrics
libp2p_pubsub_received_messages.inc(labelValues = [$p.peerId, t]) libp2p_pubsub_received_messages.inc(labelValues = [$p.peerId, t])
await p.handler(p, rmsg.get()) await p.handler(p, rmsg)
finally: finally:
await conn.close() await conn.close()
except CancelledError: except CancelledError:

View File

@ -66,18 +66,17 @@ proc init*(
var msg = Message(data: data, topicIDs: @[topic]) var msg = Message(data: data, topicIDs: @[topic])
# order matters, we want to include seqno in the signature # order matters, we want to include seqno in the signature
if seqno.isSome: seqno.withValue(seqn):
msg.seqno = @(seqno.get().toBytesBE()) msg.seqno = @(seqn.toBytesBE())
if peer.isSome: peer.withValue(peer):
let peer = peer.get()
msg.fromPeer = peer.peerId msg.fromPeer = peer.peerId
if sign: if sign:
msg.signature = sign(msg, peer.privateKey).expect("Couldn't sign message!") msg.signature = sign(msg, peer.privateKey).expect("Couldn't sign message!")
msg.key = peer.privateKey.getPublicKey().expect("Invalid private key!") msg.key = peer.privateKey.getPublicKey().expect("Invalid private key!")
.getBytes().expect("Couldn't get public key bytes!") .getBytes().expect("Couldn't get public key bytes!")
elif sign: else:
raise (ref LPError)(msg: "Cannot sign message without peer info") if sign: raise (ref LPError)(msg: "Cannot sign message without peer info")
msg msg
@ -91,6 +90,6 @@ proc init*(
var msg = Message(data: data, topicIDs: @[topic]) var msg = Message(data: data, topicIDs: @[topic])
msg.fromPeer = peerId msg.fromPeer = peerId
if seqno.isSome: seqno.withValue(seqn):
msg.seqno = @(seqno.get().toBytesBE()) msg.seqno = @(seqn.toBytesBE())
msg msg

View File

@ -110,15 +110,8 @@ func shortLog*(msg: Message): auto =
) )
func shortLog*(m: RPCMsg): auto = func shortLog*(m: RPCMsg): auto =
if m.control.isSome: (
( subscriptions: m.subscriptions,
subscriptions: m.subscriptions, messages: mapIt(m.messages, it.shortLog),
messages: mapIt(m.messages, it.shortLog), control: m.control.get(ControlMessage()).shortLog
control: m.control.get().shortLog )
)
else:
(
subscriptions: m.subscriptions,
messages: mapIt(m.messages, it.shortLog),
control: ControlMessage().shortLog
)

View File

@ -314,8 +314,8 @@ proc encodeRpcMsg*(msg: RPCMsg, anonymize: bool): seq[byte] =
pb.write(1, item) pb.write(1, item)
for item in msg.messages: for item in msg.messages:
pb.write(2, item, anonymize) pb.write(2, item, anonymize)
if msg.control.isSome(): msg.control.withValue(control):
pb.write(3, msg.control.get()) pb.write(3, control)
# nim-libp2p extension, using fields which are unlikely to be used # nim-libp2p extension, using fields which are unlikely to be used
# by other extensions # by other extensions
if msg.ping.len > 0: if msg.ping.len > 0:
@ -329,10 +329,10 @@ proc encodeRpcMsg*(msg: RPCMsg, anonymize: bool): seq[byte] =
proc decodeRpcMsg*(msg: seq[byte]): ProtoResult[RPCMsg] {.inline.} = proc decodeRpcMsg*(msg: seq[byte]): ProtoResult[RPCMsg] {.inline.} =
trace "decodeRpcMsg: decoding message", msg = msg.shortLog() trace "decodeRpcMsg: decoding message", msg = msg.shortLog()
var pb = initProtoBuffer(msg, maxSize = uint.high) var pb = initProtoBuffer(msg, maxSize = uint.high)
var rpcMsg = ok(RPCMsg()) var rpcMsg = RPCMsg()
assign(rpcMsg.get().messages, ? pb.decodeMessages()) assign(rpcMsg.messages, ? pb.decodeMessages())
assign(rpcMsg.get().subscriptions, ? pb.decodeSubscriptions()) assign(rpcMsg.subscriptions, ? pb.decodeSubscriptions())
assign(rpcMsg.get().control, ? pb.decodeControl()) assign(rpcMsg.control, ? pb.decodeControl())
discard ? pb.getField(60, rpcMsg.get().ping) discard ? pb.getField(60, rpcMsg.ping)
discard ? pb.getField(61, rpcMsg.get().pong) discard ? pb.getField(61, rpcMsg.pong)
rpcMsg ok(rpcMsg)

View File

@ -13,6 +13,8 @@ import std/[tables]
import chronos/timer, stew/results import chronos/timer, stew/results
import ../../utility
const Timeout* = 10.seconds # default timeout in ms const Timeout* = 10.seconds # default timeout in ms
type type
@ -55,9 +57,9 @@ func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool =
var previous = t.del(k) # Refresh existing item var previous = t.del(k) # Refresh existing item
let addedAt = var addedAt = now
if previous.isSome: previous.get().addedAt previous.withValue(previous):
else: now addedAt = previous.addedAt
let node = TimedEntry[K](key: k, addedAt: addedAt, expiresAt: now + t.timeout) let node = TimedEntry[K](key: k, addedAt: addedAt, expiresAt: now + t.timeout)

View File

@ -9,12 +9,12 @@
{.push raises: [].} {.push raises: [].}
import tables, sequtils, sugar, sets, options import tables, sequtils, sugar, sets
import metrics except collect import metrics except collect
import chronos, import chronos,
chronicles, chronicles,
bearssl/rand, bearssl/rand,
stew/[byteutils, objects] stew/[byteutils, objects, results]
import ./protocol, import ./protocol,
../switch, ../switch,
../routing_record, ../routing_record,
@ -68,34 +68,34 @@ type
Register = object Register = object
ns : string ns : string
signedPeerRecord: seq[byte] signedPeerRecord: seq[byte]
ttl: Option[uint64] # in seconds ttl: Opt[uint64] # in seconds
RegisterResponse = object RegisterResponse = object
status: ResponseStatus status: ResponseStatus
text: Option[string] text: Opt[string]
ttl: Option[uint64] # in seconds ttl: Opt[uint64] # in seconds
Unregister = object Unregister = object
ns: string ns: string
Discover = object Discover = object
ns: string ns: string
limit: Option[uint64] limit: Opt[uint64]
cookie: Option[seq[byte]] cookie: Opt[seq[byte]]
DiscoverResponse = object DiscoverResponse = object
registrations: seq[Register] registrations: seq[Register]
cookie: Option[seq[byte]] cookie: Opt[seq[byte]]
status: ResponseStatus status: ResponseStatus
text: Option[string] text: Opt[string]
Message = object Message = object
msgType: MessageType msgType: MessageType
register: Option[Register] register: Opt[Register]
registerResponse: Option[RegisterResponse] registerResponse: Opt[RegisterResponse]
unregister: Option[Unregister] unregister: Opt[Unregister]
discover: Option[Discover] discover: Opt[Discover]
discoverResponse: Option[DiscoverResponse] discoverResponse: Opt[DiscoverResponse]
proc encode(c: Cookie): ProtoBuffer = proc encode(c: Cookie): ProtoBuffer =
result = initProtoBuffer() result = initProtoBuffer()
@ -107,17 +107,17 @@ proc encode(r: Register): ProtoBuffer =
result = initProtoBuffer() result = initProtoBuffer()
result.write(1, r.ns) result.write(1, r.ns)
result.write(2, r.signedPeerRecord) result.write(2, r.signedPeerRecord)
if r.ttl.isSome(): r.ttl.withValue(ttl):
result.write(3, r.ttl.get()) result.write(3, ttl)
result.finish() result.finish()
proc encode(rr: RegisterResponse): ProtoBuffer = proc encode(rr: RegisterResponse): ProtoBuffer =
result = initProtoBuffer() result = initProtoBuffer()
result.write(1, rr.status.uint) result.write(1, rr.status.uint)
if rr.text.isSome(): rr.text.withValue(text):
result.write(2, rr.text.get()) result.write(2, text)
if rr.ttl.isSome(): rr.ttl.withValue(ttl):
result.write(3, rr.ttl.get()) result.write(3, ttl)
result.finish() result.finish()
proc encode(u: Unregister): ProtoBuffer = proc encode(u: Unregister): ProtoBuffer =
@ -128,48 +128,48 @@ proc encode(u: Unregister): ProtoBuffer =
proc encode(d: Discover): ProtoBuffer = proc encode(d: Discover): ProtoBuffer =
result = initProtoBuffer() result = initProtoBuffer()
result.write(1, d.ns) result.write(1, d.ns)
if d.limit.isSome(): d.limit.withValue(limit):
result.write(2, d.limit.get()) result.write(2, limit)
if d.cookie.isSome(): d.cookie.withValue(cookie):
result.write(3, d.cookie.get()) result.write(3, cookie)
result.finish() result.finish()
proc encode(d: DiscoverResponse): ProtoBuffer = proc encode(dr: DiscoverResponse): ProtoBuffer =
result = initProtoBuffer() result = initProtoBuffer()
for reg in d.registrations: for reg in dr.registrations:
result.write(1, reg.encode()) result.write(1, reg.encode())
if d.cookie.isSome(): dr.cookie.withValue(cookie):
result.write(2, d.cookie.get()) result.write(2, cookie)
result.write(3, d.status.uint) result.write(3, dr.status.uint)
if d.text.isSome(): dr.text.withValue(text):
result.write(4, d.text.get()) result.write(4, text)
result.finish() result.finish()
proc encode(msg: Message): ProtoBuffer = proc encode(msg: Message): ProtoBuffer =
result = initProtoBuffer() result = initProtoBuffer()
result.write(1, msg.msgType.uint) result.write(1, msg.msgType.uint)
if msg.register.isSome(): msg.register.withValue(register):
result.write(2, msg.register.get().encode()) result.write(2, register.encode())
if msg.registerResponse.isSome(): msg.registerResponse.withValue(registerResponse):
result.write(3, msg.registerResponse.get().encode()) result.write(3, registerResponse.encode())
if msg.unregister.isSome(): msg.unregister.withValue(unregister):
result.write(4, msg.unregister.get().encode()) result.write(4, unregister.encode())
if msg.discover.isSome(): msg.discover.withValue(discover):
result.write(5, msg.discover.get().encode()) result.write(5, discover.encode())
if msg.discoverResponse.isSome(): msg.discoverResponse.withValue(discoverResponse):
result.write(6, msg.discoverResponse.get().encode()) result.write(6, discoverResponse.encode())
result.finish() result.finish()
proc decode(_: typedesc[Cookie], buf: seq[byte]): Option[Cookie] = proc decode(_: typedesc[Cookie], buf: seq[byte]): Opt[Cookie] =
var c: Cookie var c: Cookie
let let
pb = initProtoBuffer(buf) pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, c.offset) r1 = pb.getRequiredField(1, c.offset)
r2 = pb.getRequiredField(2, c.ns) r2 = pb.getRequiredField(2, c.ns)
if r1.isErr() or r2.isErr(): return none(Cookie) if r1.isErr() or r2.isErr(): return Opt.none(Cookie)
some(c) Opt.some(c)
proc decode(_: typedesc[Register], buf: seq[byte]): Option[Register] = proc decode(_: typedesc[Register], buf: seq[byte]): Opt[Register] =
var var
r: Register r: Register
ttl: uint64 ttl: uint64
@ -178,11 +178,11 @@ proc decode(_: typedesc[Register], buf: seq[byte]): Option[Register] =
r1 = pb.getRequiredField(1, r.ns) r1 = pb.getRequiredField(1, r.ns)
r2 = pb.getRequiredField(2, r.signedPeerRecord) r2 = pb.getRequiredField(2, r.signedPeerRecord)
r3 = pb.getField(3, ttl) r3 = pb.getField(3, ttl)
if r1.isErr() or r2.isErr() or r3.isErr(): return none(Register) if r1.isErr() or r2.isErr() or r3.isErr(): return Opt.none(Register)
if r3.get(): r.ttl = some(ttl) if r3.get(false): r.ttl = Opt.some(ttl)
some(r) Opt.some(r)
proc decode(_: typedesc[RegisterResponse], buf: seq[byte]): Option[RegisterResponse] = proc decode(_: typedesc[RegisterResponse], buf: seq[byte]): Opt[RegisterResponse] =
var var
rr: RegisterResponse rr: RegisterResponse
statusOrd: uint statusOrd: uint
@ -194,20 +194,20 @@ proc decode(_: typedesc[RegisterResponse], buf: seq[byte]): Option[RegisterRespo
r2 = pb.getField(2, text) r2 = pb.getField(2, text)
r3 = pb.getField(3, ttl) r3 = pb.getField(3, ttl)
if r1.isErr() or r2.isErr() or r3.isErr() or if r1.isErr() or r2.isErr() or r3.isErr() or
not checkedEnumAssign(rr.status, statusOrd): return none(RegisterResponse) not checkedEnumAssign(rr.status, statusOrd): return Opt.none(RegisterResponse)
if r2.get(): rr.text = some(text) if r2.get(false): rr.text = Opt.some(text)
if r3.get(): rr.ttl = some(ttl) if r3.get(false): rr.ttl = Opt.some(ttl)
some(rr) Opt.some(rr)
proc decode(_: typedesc[Unregister], buf: seq[byte]): Option[Unregister] = proc decode(_: typedesc[Unregister], buf: seq[byte]): Opt[Unregister] =
var u: Unregister var u: Unregister
let let
pb = initProtoBuffer(buf) pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, u.ns) r1 = pb.getRequiredField(1, u.ns)
if r1.isErr(): return none(Unregister) if r1.isErr(): return Opt.none(Unregister)
some(u) Opt.some(u)
proc decode(_: typedesc[Discover], buf: seq[byte]): Option[Discover] = proc decode(_: typedesc[Discover], buf: seq[byte]): Opt[Discover] =
var var
d: Discover d: Discover
limit: uint64 limit: uint64
@ -217,12 +217,12 @@ proc decode(_: typedesc[Discover], buf: seq[byte]): Option[Discover] =
r1 = pb.getRequiredField(1, d.ns) r1 = pb.getRequiredField(1, d.ns)
r2 = pb.getField(2, limit) r2 = pb.getField(2, limit)
r3 = pb.getField(3, cookie) r3 = pb.getField(3, cookie)
if r1.isErr() or r2.isErr() or r3.isErr: return none(Discover) if r1.isErr() or r2.isErr() or r3.isErr: return Opt.none(Discover)
if r2.get(): d.limit = some(limit) if r2.get(false): d.limit = Opt.some(limit)
if r3.get(): d.cookie = some(cookie) if r3.get(false): d.cookie = Opt.some(cookie)
some(d) Opt.some(d)
proc decode(_: typedesc[DiscoverResponse], buf: seq[byte]): Option[DiscoverResponse] = proc decode(_: typedesc[DiscoverResponse], buf: seq[byte]): Opt[DiscoverResponse] =
var var
dr: DiscoverResponse dr: DiscoverResponse
registrations: seq[seq[byte]] registrations: seq[seq[byte]]
@ -236,48 +236,47 @@ proc decode(_: typedesc[DiscoverResponse], buf: seq[byte]): Option[DiscoverRespo
r3 = pb.getRequiredField(3, statusOrd) r3 = pb.getRequiredField(3, statusOrd)
r4 = pb.getField(4, text) r4 = pb.getField(4, text)
if r1.isErr() or r2.isErr() or r3.isErr or r4.isErr() or if r1.isErr() or r2.isErr() or r3.isErr or r4.isErr() or
not checkedEnumAssign(dr.status, statusOrd): return none(DiscoverResponse) not checkedEnumAssign(dr.status, statusOrd): return Opt.none(DiscoverResponse)
for reg in registrations: for reg in registrations:
var r: Register var r: Register
let regOpt = Register.decode(reg) let regOpt = Register.decode(reg).valueOr:
if regOpt.isNone(): return none(DiscoverResponse) return
dr.registrations.add(regOpt.get()) dr.registrations.add(regOpt)
if r2.get(): dr.cookie = some(cookie) if r2.get(false): dr.cookie = Opt.some(cookie)
if r4.get(): dr.text = some(text) if r4.get(false): dr.text = Opt.some(text)
some(dr) Opt.some(dr)
proc decode(_: typedesc[Message], buf: seq[byte]): Option[Message] = proc decode(_: typedesc[Message], buf: seq[byte]): Opt[Message] =
var var
msg: Message msg: Message
statusOrd: uint statusOrd: uint
pbr, pbrr, pbu, pbd, pbdr: ProtoBuffer pbr, pbrr, pbu, pbd, pbdr: ProtoBuffer
let let pb = initProtoBuffer(buf)
pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, statusOrd) ? pb.getRequiredField(1, statusOrd).toOpt
r2 = pb.getField(2, pbr) if not checkedEnumAssign(msg.msgType, statusOrd): return Opt.none(Message)
r3 = pb.getField(3, pbrr)
r4 = pb.getField(4, pbu) if ? pb.getField(2, pbr).optValue:
r5 = pb.getField(5, pbd)
r6 = pb.getField(6, pbdr)
if r1.isErr() or r2.isErr() or r3.isErr() or
r4.isErr() or r5.isErr() or r6.isErr() or
not checkedEnumAssign(msg.msgType, statusOrd): return none(Message)
if r2.get():
msg.register = Register.decode(pbr.buffer) msg.register = Register.decode(pbr.buffer)
if msg.register.isNone(): return none(Message) if msg.register.isNone(): return Opt.none(Message)
if r3.get():
if ? pb.getField(3, pbrr).optValue:
msg.registerResponse = RegisterResponse.decode(pbrr.buffer) msg.registerResponse = RegisterResponse.decode(pbrr.buffer)
if msg.registerResponse.isNone(): return none(Message) if msg.registerResponse.isNone(): return Opt.none(Message)
if r4.get():
if ? pb.getField(4, pbu).optValue:
msg.unregister = Unregister.decode(pbu.buffer) msg.unregister = Unregister.decode(pbu.buffer)
if msg.unregister.isNone(): return none(Message) if msg.unregister.isNone(): return Opt.none(Message)
if r5.get():
if ? pb.getField(5, pbd).optValue:
msg.discover = Discover.decode(pbd.buffer) msg.discover = Discover.decode(pbd.buffer)
if msg.discover.isNone(): return none(Message) if msg.discover.isNone(): return Opt.none(Message)
if r6.get():
if ? pb.getField(6, pbdr).optValue:
msg.discoverResponse = DiscoverResponse.decode(pbdr.buffer) msg.discoverResponse = DiscoverResponse.decode(pbdr.buffer)
if msg.discoverResponse.isNone(): return none(Message) if msg.discoverResponse.isNone(): return Opt.none(Message)
some(msg)
Opt.some(msg)
type type
@ -317,7 +316,7 @@ proc sendRegisterResponse(conn: Connection,
ttl: uint64) {.async.} = ttl: uint64) {.async.} =
let msg = encode(Message( let msg = encode(Message(
msgType: MessageType.RegisterResponse, msgType: MessageType.RegisterResponse,
registerResponse: some(RegisterResponse(status: Ok, ttl: some(ttl))))) registerResponse: Opt.some(RegisterResponse(status: Ok, ttl: Opt.some(ttl)))))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
proc sendRegisterResponseError(conn: Connection, proc sendRegisterResponseError(conn: Connection,
@ -325,7 +324,7 @@ proc sendRegisterResponseError(conn: Connection,
text: string = "") {.async.} = text: string = "") {.async.} =
let msg = encode(Message( let msg = encode(Message(
msgType: MessageType.RegisterResponse, msgType: MessageType.RegisterResponse,
registerResponse: some(RegisterResponse(status: status, text: some(text))))) registerResponse: Opt.some(RegisterResponse(status: status, text: Opt.some(text)))))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
proc sendDiscoverResponse(conn: Connection, proc sendDiscoverResponse(conn: Connection,
@ -333,10 +332,10 @@ proc sendDiscoverResponse(conn: Connection,
cookie: Cookie) {.async.} = cookie: Cookie) {.async.} =
let msg = encode(Message( let msg = encode(Message(
msgType: MessageType.DiscoverResponse, msgType: MessageType.DiscoverResponse,
discoverResponse: some(DiscoverResponse( discoverResponse: Opt.some(DiscoverResponse(
status: Ok, status: Ok,
registrations: s, registrations: s,
cookie: some(cookie.encode().buffer) cookie: Opt.some(cookie.encode().buffer)
)) ))
)) ))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
@ -346,7 +345,7 @@ proc sendDiscoverResponseError(conn: Connection,
text: string = "") {.async.} = text: string = "") {.async.} =
let msg = encode(Message( let msg = encode(Message(
msgType: MessageType.DiscoverResponse, msgType: MessageType.DiscoverResponse,
discoverResponse: some(DiscoverResponse(status: status, text: some(text))))) discoverResponse: Opt.some(DiscoverResponse(status: status, text: Opt.some(text)))))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
proc countRegister(rdv: RendezVous, peerId: PeerId): int = proc countRegister(rdv: RendezVous, peerId: PeerId): int =
@ -419,7 +418,7 @@ proc discover(rdv: RendezVous, conn: Connection, d: Discover) {.async.} =
cookie = cookie =
if d.cookie.isSome(): if d.cookie.isSome():
try: try:
Cookie.decode(d.cookie.get()).get() Cookie.decode(d.cookie.tryGet()).tryGet()
except CatchableError: except CatchableError:
await conn.sendDiscoverResponseError(InvalidCookie) await conn.sendDiscoverResponseError(InvalidCookie)
return return
@ -450,7 +449,7 @@ proc discover(rdv: RendezVous, conn: Connection, d: Discover) {.async.} =
break break
if reg.expiration < n or index.uint64 <= cookie.offset: continue if reg.expiration < n or index.uint64 <= cookie.offset: continue
limit.dec() limit.dec()
reg.data.ttl = some((reg.expiration - Moment.now()).seconds.uint64) reg.data.ttl = Opt.some((reg.expiration - Moment.now()).seconds.uint64)
reg.data reg.data
rdv.rng.shuffle(s) rdv.rng.shuffle(s)
await conn.sendDiscoverResponse(s, Cookie(offset: offset.uint64, ns: d.ns)) await conn.sendDiscoverResponse(s, Cookie(offset: offset.uint64, ns: d.ns))
@ -465,11 +464,10 @@ proc advertisePeer(rdv: RendezVous,
await conn.writeLp(msg) await conn.writeLp(msg)
let let
buf = await conn.readLp(4096) buf = await conn.readLp(4096)
msgRecv = Message.decode(buf).get() msgRecv = Message.decode(buf).tryGet()
if msgRecv.msgType != MessageType.RegisterResponse: if msgRecv.msgType != MessageType.RegisterResponse:
trace "Unexpected register response", peer, msgType = msgRecv.msgType trace "Unexpected register response", peer, msgType = msgRecv.msgType
elif msgRecv.registerResponse.isNone() or elif msgRecv.registerResponse.tryGet().status != ResponseStatus.Ok:
msgRecv.registerResponse.get().status != ResponseStatus.Ok:
trace "Refuse to register", peer, response = msgRecv.registerResponse trace "Refuse to register", peer, response = msgRecv.registerResponse
except CatchableError as exc: except CatchableError as exc:
trace "exception in the advertise", error = exc.msg trace "exception in the advertise", error = exc.msg
@ -481,16 +479,15 @@ proc advertisePeer(rdv: RendezVous,
proc advertise*(rdv: RendezVous, proc advertise*(rdv: RendezVous,
ns: string, ns: string,
ttl: Duration = MinimumDuration) {.async.} = ttl: Duration = MinimumDuration) {.async.} =
let sprBuff = rdv.switch.peerInfo.signedPeerRecord.encode() let sprBuff = rdv.switch.peerInfo.signedPeerRecord.encode().valueOr:
if sprBuff.isErr():
raise newException(RendezVousError, "Wrong Signed Peer Record") raise newException(RendezVousError, "Wrong Signed Peer Record")
if ns.len notin 1..255: if ns.len notin 1..255:
raise newException(RendezVousError, "Invalid namespace") raise newException(RendezVousError, "Invalid namespace")
if ttl notin MinimumDuration..MaximumDuration: if ttl notin MinimumDuration..MaximumDuration:
raise newException(RendezVousError, "Invalid time to live") raise newException(RendezVousError, "Invalid time to live")
let let
r = Register(ns: ns, signedPeerRecord: sprBuff.get(), ttl: some(ttl.seconds.uint64)) r = Register(ns: ns, signedPeerRecord: sprBuff, ttl: Opt.some(ttl.seconds.uint64))
msg = encode(Message(msgType: MessageType.Register, register: some(r))) msg = encode(Message(msgType: MessageType.Register, register: Opt.some(r)))
rdv.save(ns, rdv.switch.peerInfo.peerId, r) rdv.save(ns, rdv.switch.peerInfo.peerId, r)
let fut = collect(newSeq()): let fut = collect(newSeq()):
for peer in rdv.peers: for peer in rdv.peers:
@ -506,7 +503,9 @@ proc requestLocally*(rdv: RendezVous, ns: string): seq[PeerRecord] =
collect(newSeq()): collect(newSeq()):
for index in rdv.namespaces[nsSalted]: for index in rdv.namespaces[nsSalted]:
if rdv.registered[index].expiration > n: if rdv.registered[index].expiration > n:
SignedPeerRecord.decode(rdv.registered[index].data.signedPeerRecord).get().data let res = SignedPeerRecord.decode(rdv.registered[index].data.signedPeerRecord).valueOr:
continue
res.data
except KeyError as exc: except KeyError as exc:
@[] @[]
@ -527,38 +526,42 @@ proc request*(rdv: RendezVous,
proc requestPeer(peer: PeerId) {.async.} = proc requestPeer(peer: PeerId) {.async.} =
let conn = await rdv.switch.dial(peer, RendezVousCodec) let conn = await rdv.switch.dial(peer, RendezVousCodec)
defer: await conn.close() defer: await conn.close()
d.limit = some(limit) d.limit = Opt.some(limit)
d.cookie = d.cookie =
try: try:
some(rdv.cookiesSaved[peer][ns]) Opt.some(rdv.cookiesSaved[peer][ns])
except KeyError as exc: except KeyError as exc:
none(seq[byte]) Opt.none(seq[byte])
await conn.writeLp(encode(Message( await conn.writeLp(encode(Message(
msgType: MessageType.Discover, msgType: MessageType.Discover,
discover: some(d))).buffer) discover: Opt.some(d))).buffer)
let let
buf = await conn.readLp(65536) buf = await conn.readLp(65536)
msgRcv = Message.decode(buf).get() msgRcv = Message.decode(buf).valueOr:
if msgRcv.msgType != MessageType.DiscoverResponse or debug "Message undecodable"
msgRcv.discoverResponse.isNone(): return
if msgRcv.msgType != MessageType.DiscoverResponse:
debug "Unexpected discover response", msgType = msgRcv.msgType debug "Unexpected discover response", msgType = msgRcv.msgType
return return
let resp = msgRcv.discoverResponse.get() let resp = msgRcv.discoverResponse.valueOr:
debug "Discover response is empty"
return
if resp.status != ResponseStatus.Ok: if resp.status != ResponseStatus.Ok:
trace "Cannot discover", ns, status = resp.status, text = resp.text trace "Cannot discover", ns, status = resp.status, text = resp.text
return return
if resp.cookie.isSome() and resp.cookie.get().len < 1000: resp.cookie.withValue(cookie):
if rdv.cookiesSaved.hasKeyOrPut(peer, {ns: resp.cookie.get()}.toTable): if cookie.len() < 1000 and rdv.cookiesSaved.hasKeyOrPut(peer, {ns: cookie}.toTable()):
rdv.cookiesSaved[peer][ns] = resp.cookie.get() rdv.cookiesSaved[peer][ns] = cookie
for r in resp.registrations: for r in resp.registrations:
if limit == 0: return if limit == 0: return
if r.ttl.isNone() or r.ttl.get() > MaximumTTL: continue let ttl = r.ttl.get(MaximumTTL + 1)
let sprRes = SignedPeerRecord.decode(r.signedPeerRecord) if ttl > MaximumTTL: continue
if sprRes.isErr(): continue let
let pr = sprRes.get().data spr = SignedPeerRecord.decode(r.signedPeerRecord).valueOr: continue
pr = spr.data
if s.hasKey(pr.peerId): if s.hasKey(pr.peerId):
let (prSaved, rSaved) = s[pr.peerId] let (prSaved, rSaved) = s[pr.peerId]
if (prSaved.seqNo == pr.seqNo and rSaved.ttl.get() < r.ttl.get()) or if (prSaved.seqNo == pr.seqNo and rSaved.ttl.get(MaximumTTL) < ttl) or
prSaved.seqNo < pr.seqNo: prSaved.seqNo < pr.seqNo:
s[pr.peerId] = (pr, r) s[pr.peerId] = (pr, r)
else: else:
@ -597,7 +600,7 @@ proc unsubscribe*(rdv: RendezVous, ns: string) {.async.} =
rdv.unsubscribeLocally(ns) rdv.unsubscribeLocally(ns)
let msg = encode(Message( let msg = encode(Message(
msgType: MessageType.Unregister, msgType: MessageType.Unregister,
unregister: some(Unregister(ns: ns)))) unregister: Opt.some(Unregister(ns: ns))))
proc unsubscribePeer(rdv: RendezVous, peerId: PeerId) {.async.} = proc unsubscribePeer(rdv: RendezVous, peerId: PeerId) {.async.} =
try: try:
@ -635,13 +638,13 @@ proc new*(T: typedesc[RendezVous],
try: try:
let let
buf = await conn.readLp(4096) buf = await conn.readLp(4096)
msg = Message.decode(buf).get() msg = Message.decode(buf).tryGet()
case msg.msgType: case msg.msgType:
of MessageType.Register: await rdv.register(conn, msg.register.get()) of MessageType.Register: await rdv.register(conn, msg.register.tryGet())
of MessageType.RegisterResponse: of MessageType.RegisterResponse:
trace "Got an unexpected Register Response", response = msg.registerResponse trace "Got an unexpected Register Response", response = msg.registerResponse
of MessageType.Unregister: rdv.unregister(conn, msg.unregister.get()) of MessageType.Unregister: rdv.unregister(conn, msg.unregister.tryGet())
of MessageType.Discover: await rdv.discover(conn, msg.discover.get()) of MessageType.Discover: await rdv.discover(conn, msg.discover.tryGet())
of MessageType.DiscoverResponse: of MessageType.DiscoverResponse:
trace "Got an unexpected Discover Response", response = msg.discoverResponse trace "Got an unexpected Discover Response", response = msg.discoverResponse
except CancelledError as exc: except CancelledError as exc:

View File

@ -554,8 +554,7 @@ method handshake*(p: Noise, conn: Connection, initiator: bool, peerId: Opt[PeerI
trace "Remote peer id", pid = $pid trace "Remote peer id", pid = $pid
if peerId.isSome(): peerId.withValue(targetPid):
let targetPid = peerId.get()
if not targetPid.validate(): if not targetPid.validate():
raise newException(NoiseHandshakeError, "Failed to validate expected peerId.") raise newException(NoiseHandshakeError, "Failed to validate expected peerId.")

View File

@ -339,8 +339,7 @@ method handshake*(s: Secio, conn: Connection, initiator: bool, peerId: Opt[PeerI
remotePeerId = PeerId.init(remotePubkey).tryGet() remotePeerId = PeerId.init(remotePubkey).tryGet()
if peerId.isSome(): peerId.withValue(targetPid):
let targetPid = peerId.get()
if not targetPid.validate(): if not targetPid.validate():
raise newException(SecioError, "Failed to validate expected peerId.") raise newException(SecioError, "Failed to validate expected peerId.")
@ -436,14 +435,10 @@ proc new*(
T: typedesc[Secio], T: typedesc[Secio],
rng: ref HmacDrbgContext, rng: ref HmacDrbgContext,
localPrivateKey: PrivateKey): T = localPrivateKey: PrivateKey): T =
let pkRes = localPrivateKey.getPublicKey()
if pkRes.isErr:
raise newException(Defect, "Invalid private key")
let secio = Secio( let secio = Secio(
rng: rng, rng: rng,
localPrivateKey: localPrivateKey, localPrivateKey: localPrivateKey,
localPublicKey: pkRes.get(), localPublicKey: localPrivateKey.getPublicKey().expect("Invalid private key"),
) )
secio.init() secio.init()
secio secio

View File

@ -42,14 +42,12 @@ proc decode*(
? pb.getRequiredField(2, record.seqNo) ? pb.getRequiredField(2, record.seqNo)
var addressInfos: seq[seq[byte]] var addressInfos: seq[seq[byte]]
let pb3 = ? pb.getRepeatedField(3, addressInfos) if ? pb.getRepeatedField(3, addressInfos):
if pb3:
for address in addressInfos: for address in addressInfos:
var addressInfo = AddressInfo() var addressInfo = AddressInfo()
let subProto = initProtoBuffer(address) let subProto = initProtoBuffer(address)
let f = subProto.getField(1, addressInfo.address) let f = subProto.getField(1, addressInfo.address)
if f.isOk() and f.get(): if f.get(false):
record.addresses &= addressInfo record.addresses &= addressInfo
if record.addresses.len == 0: if record.addresses.len == 0:

View File

@ -45,9 +45,7 @@ proc tryStartingDirectConn(self: HPService, switch: Switch, peerId: PeerId): Fut
for address in switch.peerStore[AddressBook][peerId]: for address in switch.peerStore[AddressBook][peerId]:
try: try:
let isRelayed = address.contains(multiCodec("p2p-circuit")) let isRelayed = address.contains(multiCodec("p2p-circuit"))
if isRelayed.isErr() or isRelayed.get(): if not isRelayed.get(false) and address.isPublicMA():
continue
if address.isPublicMA():
return await tryConnect(address) return await tryConnect(address)
except CatchableError as err: except CatchableError as err:
debug "Failed to create direct connection.", err = err.msg debug "Failed to create direct connection.", err = err.msg
@ -96,7 +94,7 @@ method setup*(self: HPService, switch: Switch): Future[bool] {.async.} =
switch.connManager.addPeerEventHandler(self.newConnectedPeerHandler, PeerEventKind.Joined) switch.connManager.addPeerEventHandler(self.newConnectedPeerHandler, PeerEventKind.Joined)
self.onNewStatusHandler = proc (networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} = self.onNewStatusHandler = proc (networkReachability: NetworkReachability, confidence: Opt[float]) {.gcsafe, async.} =
if networkReachability == NetworkReachability.NotReachable and not self.autoRelayService.isRunning(): if networkReachability == NetworkReachability.NotReachable and not self.autoRelayService.isRunning():
discard await self.autoRelayService.setup(switch) discard await self.autoRelayService.setup(switch)
elif networkReachability == NetworkReachability.Reachable and self.autoRelayService.isRunning(): elif networkReachability == NetworkReachability.Reachable and self.autoRelayService.isRunning():

View File

@ -112,19 +112,12 @@ proc getField*(pb: ProtoBuffer, field: int,
if not(res): if not(res):
ok(false) ok(false)
else: else:
let env = Envelope.decode(buffer, domain) value = Envelope.decode(buffer, domain).valueOr: return err(ProtoError.IncorrectBlob)
if env.isOk(): ok(true)
value = env.get()
ok(true)
else:
err(ProtoError.IncorrectBlob)
proc write*(pb: var ProtoBuffer, field: int, env: Envelope): Result[void, CryptoError] = proc write*(pb: var ProtoBuffer, field: int, env: Envelope): Result[void, CryptoError] =
let e = env.encode() let e = ? env.encode()
pb.write(field, e)
if e.isErr():
return err(e.error)
pb.write(field, e.get())
ok() ok()
type type
@ -142,7 +135,7 @@ proc init*[T](_: typedesc[SignedPayload[T]],
T.payloadType(), T.payloadType(),
data.encode(), data.encode(),
T.payloadDomain) T.payloadDomain)
ok(SignedPayload[T](data: data, envelope: envelope)) ok(SignedPayload[T](data: data, envelope: envelope))
proc getField*[T](pb: ProtoBuffer, field: int, proc getField*[T](pb: ProtoBuffer, field: int,

View File

@ -141,7 +141,9 @@ proc parseOnion3(address: MultiAddress): (byte, seq[byte], seq[byte]) {.raises:
dstPort = address.data.buffer[37..38] dstPort = address.data.buffer[37..38]
return (Socks5AddressType.FQDN.byte, dstAddr, dstPort) return (Socks5AddressType.FQDN.byte, dstAddr, dstPort)
proc parseIpTcp(address: MultiAddress): (byte, seq[byte], seq[byte]) {.raises: [LPError, ValueError].} = proc parseIpTcp(address: MultiAddress):
(byte, seq[byte], seq[byte])
{.raises: [LPError, ValueError].} =
let (codec, atyp) = let (codec, atyp) =
if IPv4Tcp.match(address): if IPv4Tcp.match(address):
(multiCodec("ip4"), Socks5AddressType.IPv4.byte) (multiCodec("ip4"), Socks5AddressType.IPv4.byte)
@ -150,15 +152,17 @@ proc parseIpTcp(address: MultiAddress): (byte, seq[byte], seq[byte]) {.raises: [
else: else:
raise newException(LPError, fmt"IP address not supported {address}") raise newException(LPError, fmt"IP address not supported {address}")
let let
dstAddr = address[codec].get().protoArgument().get() dstAddr = address[codec].tryGet().protoArgument().tryGet()
dstPort = address[multiCodec("tcp")].get().protoArgument().get() dstPort = address[multiCodec("tcp")].tryGet().protoArgument().tryGet()
(atyp, dstAddr, dstPort) (atyp, dstAddr, dstPort)
proc parseDnsTcp(address: MultiAddress): (byte, seq[byte], seq[byte]) = proc parseDnsTcp(address: MultiAddress):
(byte, seq[byte], seq[byte])
{.raises: [LPError, ValueError].} =
let let
dnsAddress = address[multiCodec("dns")].get().protoArgument().get() dnsAddress = address[multiCodec("dns")].tryGet().protoArgument().tryGet()
dstAddr = @(uint8(dnsAddress.len).toBytes()) & dnsAddress dstAddr = @(uint8(dnsAddress.len).toBytes()) & dnsAddress
dstPort = address[multiCodec("tcp")].get().protoArgument().get() dstPort = address[multiCodec("tcp")].tryGet().protoArgument().tryGet()
(Socks5AddressType.FQDN.byte, dstAddr, dstPort) (Socks5AddressType.FQDN.byte, dstAddr, dstPort)
proc dialPeer( proc dialPeer(
@ -214,9 +218,9 @@ method start*(
warn "Invalid address detected, skipping!", address = ma warn "Invalid address detected, skipping!", address = ma
continue continue
let listenAddress = ma[0..1].get() let listenAddress = ma[0..1].tryGet()
listenAddrs.add(listenAddress) listenAddrs.add(listenAddress)
let onion3 = ma[multiCodec("onion3")].get() let onion3 = ma[multiCodec("onion3")].tryGet()
onion3Addrs.add(onion3) onion3Addrs.add(onion3)
if len(listenAddrs) != 0 and len(onion3Addrs) != 0: if len(listenAddrs) != 0 and len(onion3Addrs) != 0:

View File

@ -99,9 +99,8 @@ method handles*(
# by default we skip circuit addresses to avoid # by default we skip circuit addresses to avoid
# having to repeat the check in every transport # having to repeat the check in every transport
if address.protocols.isOk: let protocols = address.protocols.valueOr: return false
return address.protocols return protocols
.get()
.filterIt( .filterIt(
it == multiCodec("p2p-circuit") it == multiCodec("p2p-circuit")
).len == 0 ).len == 0

View File

@ -9,7 +9,10 @@
{.push raises: [].} {.push raises: [].}
import stew/byteutils import std/options, std/macros
import stew/[byteutils, results]
export results
template public* {.pragma.} template public* {.pragma.}
@ -50,9 +53,6 @@ when defined(libp2p_agents_metrics):
import strutils import strutils
export split export split
import stew/results
export results
proc safeToLowerAscii*(s: string): Result[string, cstring] = proc safeToLowerAscii*(s: string): Result[string, cstring] =
try: try:
ok(s.toLowerAscii()) ok(s.toLowerAscii())
@ -83,3 +83,30 @@ template exceptionToAssert*(body: untyped): untyped =
when defined(nimHasWarnBareExcept): when defined(nimHasWarnBareExcept):
{.pop.} {.pop.}
res res
template withValue*[T](self: Opt[T] | Option[T], value, body: untyped): untyped =
if self.isSome:
let value {.inject.} = self.get()
body
macro withValue*[T](self: Opt[T] | Option[T], value, body, body2: untyped): untyped =
let elseBody = body2[0]
quote do:
if `self`.isSome:
let `value` {.inject.} = `self`.get()
`body`
else:
`elseBody`
template valueOr*[T](self: Option[T], body: untyped): untyped =
if self.isSome:
self.get()
else:
body
template toOpt*[T, E](self: Result[T, E]): Opt[T] =
if self.isOk:
when T is void: Result[void, void].ok()
else: Opt.some(self.unsafeGet())
else:
Opt.none(type(T))

View File

@ -89,7 +89,7 @@ proc connect*(
compilesOr: compilesOr:
return connect(transportAddress, bufferSize, child, return connect(transportAddress, bufferSize, child,
if localAddress.isSome(): initTAddress(localAddress.get()).tryGet() else : TransportAddress(), if localAddress.isSome(): initTAddress(localAddress.expect("just checked")).tryGet() else: TransportAddress(),
flags) flags)
do: do:
# support for older chronos versions # support for older chronos versions
@ -152,7 +152,7 @@ proc createStreamServer*[T](ma: MultiAddress,
raise newException(LPError, exc.msg) raise newException(LPError, exc.msg)
proc createAsyncSocket*(ma: MultiAddress): AsyncFD proc createAsyncSocket*(ma: MultiAddress): AsyncFD
{.raises: [LPError].} = {.raises: [ValueError, LPError].} =
## Create new asynchronous socket using MultiAddress' ``ma`` socket type and ## Create new asynchronous socket using MultiAddress' ``ma`` socket type and
## protocol information. ## protocol information.
## ##

View File

@ -43,8 +43,8 @@ proc makeAutonatServicePrivate(): Switch =
discard await conn.readLp(1024) discard await conn.readLp(1024)
await conn.writeLp(AutonatDialResponse( await conn.writeLp(AutonatDialResponse(
status: DialError, status: DialError,
text: some("dial failed"), text: Opt.some("dial failed"),
ma: none(MultiAddress)).encode().buffer) ma: Opt.none(MultiAddress)).encode().buffer)
await conn.close() await conn.close()
autonatProtocol.codec = AutonatCodec autonatProtocol.codec = AutonatCodec
result = newStandardSwitch() result = newStandardSwitch()
@ -93,8 +93,8 @@ suite "Autonat":
await src.connect(dst.peerInfo.peerId, dst.peerInfo.addrs) await src.connect(dst.peerInfo.peerId, dst.peerInfo.addrs)
let conn = await src.dial(dst.peerInfo.peerId, @[AutonatCodec]) let conn = await src.dial(dst.peerInfo.peerId, @[AutonatCodec])
let buffer = AutonatDial(peerInfo: some(AutonatPeerInfo( let buffer = AutonatDial(peerInfo: Opt.some(AutonatPeerInfo(
id: some(src.peerInfo.peerId), id: Opt.some(src.peerInfo.peerId),
# we ask to be dialed in the does nothing listener instead # we ask to be dialed in the does nothing listener instead
addrs: doesNothingListener.addrs addrs: doesNothingListener.addrs
))).encode().buffer ))).encode().buffer

View File

@ -78,7 +78,7 @@ suite "Autonat Service":
asyncTest "Peer must be reachable": asyncTest "Peer must be reachable":
let autonatService = AutonatService.new(AutonatClient.new(), newRng(), some(1.seconds)) let autonatService = AutonatService.new(AutonatClient.new(), newRng(), Opt.some(1.seconds))
let switch1 = createSwitch(autonatService) let switch1 = createSwitch(autonatService)
let switch2 = createSwitch() let switch2 = createSwitch()
@ -87,7 +87,7 @@ suite "Autonat Service":
let awaiter = newFuture[void]() let awaiter = newFuture[void]()
proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} = proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Opt[float]) {.gcsafe, async.} =
if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() >= 0.3: if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() >= 0.3:
if not awaiter.finished: if not awaiter.finished:
awaiter.complete() awaiter.complete()
@ -122,7 +122,7 @@ suite "Autonat Service":
let autonatClientStub = AutonatClientStub.new(expectedDials = 6) let autonatClientStub = AutonatClientStub.new(expectedDials = 6)
autonatClientStub.answer = NotReachable autonatClientStub.answer = NotReachable
let autonatService = AutonatService.new(autonatClientStub, newRng(), some(1.seconds)) let autonatService = AutonatService.new(autonatClientStub, newRng(), Opt.some(1.seconds))
let switch1 = createSwitch(autonatService) let switch1 = createSwitch(autonatService)
let switch2 = createSwitch() let switch2 = createSwitch()
@ -131,7 +131,7 @@ suite "Autonat Service":
let awaiter = newFuture[void]() let awaiter = newFuture[void]()
proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} = proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Opt[float]) {.gcsafe, async.} =
if networkReachability == NetworkReachability.NotReachable and confidence.isSome() and confidence.get() >= 0.3: if networkReachability == NetworkReachability.NotReachable and confidence.isSome() and confidence.get() >= 0.3:
if not awaiter.finished: if not awaiter.finished:
autonatClientStub.answer = Reachable autonatClientStub.answer = Reachable
@ -164,7 +164,7 @@ suite "Autonat Service":
asyncTest "Peer must be reachable when one connected peer has autonat disabled": asyncTest "Peer must be reachable when one connected peer has autonat disabled":
let autonatService = AutonatService.new(AutonatClient.new(), newRng(), some(1.seconds), maxQueueSize = 2) let autonatService = AutonatService.new(AutonatClient.new(), newRng(), Opt.some(1.seconds), maxQueueSize = 2)
let switch1 = createSwitch(autonatService) let switch1 = createSwitch(autonatService)
let switch2 = createSwitch(withAutonat = false) let switch2 = createSwitch(withAutonat = false)
@ -173,7 +173,7 @@ suite "Autonat Service":
let awaiter = newFuture[void]() let awaiter = newFuture[void]()
proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} = proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Opt[float]) {.gcsafe, async.} =
if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1: if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1:
if not awaiter.finished: if not awaiter.finished:
awaiter.complete() awaiter.complete()
@ -204,7 +204,7 @@ suite "Autonat Service":
let autonatClientStub = AutonatClientStub.new(expectedDials = 6) let autonatClientStub = AutonatClientStub.new(expectedDials = 6)
autonatClientStub.answer = NotReachable autonatClientStub.answer = NotReachable
let autonatService = AutonatService.new(autonatClientStub, newRng(), some(1.seconds), maxQueueSize = 3) let autonatService = AutonatService.new(autonatClientStub, newRng(), Opt.some(1.seconds), maxQueueSize = 3)
let switch1 = createSwitch(autonatService) let switch1 = createSwitch(autonatService)
let switch2 = createSwitch() let switch2 = createSwitch()
@ -213,7 +213,7 @@ suite "Autonat Service":
let awaiter = newFuture[void]() let awaiter = newFuture[void]()
proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} = proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Opt[float]) {.gcsafe, async.} =
if networkReachability == NetworkReachability.NotReachable and confidence.isSome() and confidence.get() >= 0.3: if networkReachability == NetworkReachability.NotReachable and confidence.isSome() and confidence.get() >= 0.3:
if not awaiter.finished: if not awaiter.finished:
autonatClientStub.answer = Unknown autonatClientStub.answer = Unknown
@ -247,7 +247,7 @@ suite "Autonat Service":
asyncTest "Calling setup and stop twice must work": asyncTest "Calling setup and stop twice must work":
let switch = createSwitch() let switch = createSwitch()
let autonatService = AutonatService.new(AutonatClientStub.new(expectedDials = 0), newRng(), some(1.seconds)) let autonatService = AutonatService.new(AutonatClientStub.new(expectedDials = 0), newRng(), Opt.some(1.seconds))
check (await autonatService.setup(switch)) == true check (await autonatService.setup(switch)) == true
check (await autonatService.setup(switch)) == false check (await autonatService.setup(switch)) == false
@ -258,7 +258,7 @@ suite "Autonat Service":
await allFuturesThrowing(switch.stop()) await allFuturesThrowing(switch.stop())
asyncTest "Must bypass maxConnectionsPerPeer limit": asyncTest "Must bypass maxConnectionsPerPeer limit":
let autonatService = AutonatService.new(AutonatClient.new(), newRng(), some(1.seconds), maxQueueSize = 1) let autonatService = AutonatService.new(AutonatClient.new(), newRng(), Opt.some(1.seconds), maxQueueSize = 1)
let switch1 = createSwitch(autonatService, maxConnsPerPeer = 0) let switch1 = createSwitch(autonatService, maxConnsPerPeer = 0)
await switch1.setDNSAddr() await switch1.setDNSAddr()
@ -267,7 +267,7 @@ suite "Autonat Service":
let awaiter = newFuture[void]() let awaiter = newFuture[void]()
proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} = proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Opt[float]) {.gcsafe, async.} =
if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1: if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1:
if not awaiter.finished: if not awaiter.finished:
awaiter.complete() awaiter.complete()
@ -290,9 +290,9 @@ suite "Autonat Service":
switch1.stop(), switch2.stop()) switch1.stop(), switch2.stop())
asyncTest "Must work when peers ask each other at the same time with max 1 conn per peer": asyncTest "Must work when peers ask each other at the same time with max 1 conn per peer":
let autonatService1 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3) let autonatService1 = AutonatService.new(AutonatClient.new(), newRng(), Opt.some(500.millis), maxQueueSize = 3)
let autonatService2 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3) let autonatService2 = AutonatService.new(AutonatClient.new(), newRng(), Opt.some(500.millis), maxQueueSize = 3)
let autonatService3 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3) let autonatService3 = AutonatService.new(AutonatClient.new(), newRng(), Opt.some(500.millis), maxQueueSize = 3)
let switch1 = createSwitch(autonatService1, maxConnsPerPeer = 0) let switch1 = createSwitch(autonatService1, maxConnsPerPeer = 0)
let switch2 = createSwitch(autonatService2, maxConnsPerPeer = 0) let switch2 = createSwitch(autonatService2, maxConnsPerPeer = 0)
@ -302,12 +302,12 @@ suite "Autonat Service":
let awaiter2 = newFuture[void]() let awaiter2 = newFuture[void]()
let awaiter3 = newFuture[void]() let awaiter3 = newFuture[void]()
proc statusAndConfidenceHandler1(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} = proc statusAndConfidenceHandler1(networkReachability: NetworkReachability, confidence: Opt[float]) {.gcsafe, async.} =
if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1: if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1:
if not awaiter1.finished: if not awaiter1.finished:
awaiter1.complete() awaiter1.complete()
proc statusAndConfidenceHandler2(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} = proc statusAndConfidenceHandler2(networkReachability: NetworkReachability, confidence: Opt[float]) {.gcsafe, async.} =
if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1: if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1:
if not awaiter2.finished: if not awaiter2.finished:
awaiter2.complete() awaiter2.complete()
@ -337,15 +337,15 @@ suite "Autonat Service":
switch1.stop(), switch2.stop(), switch3.stop()) switch1.stop(), switch2.stop(), switch3.stop())
asyncTest "Must work for one peer when two peers ask each other at the same time with max 1 conn per peer": asyncTest "Must work for one peer when two peers ask each other at the same time with max 1 conn per peer":
let autonatService1 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3) let autonatService1 = AutonatService.new(AutonatClient.new(), newRng(), Opt.some(500.millis), maxQueueSize = 3)
let autonatService2 = AutonatService.new(AutonatClient.new(), newRng(), some(500.millis), maxQueueSize = 3) let autonatService2 = AutonatService.new(AutonatClient.new(), newRng(), Opt.some(500.millis), maxQueueSize = 3)
let switch1 = createSwitch(autonatService1, maxConnsPerPeer = 0) let switch1 = createSwitch(autonatService1, maxConnsPerPeer = 0)
let switch2 = createSwitch(autonatService2, maxConnsPerPeer = 0) let switch2 = createSwitch(autonatService2, maxConnsPerPeer = 0)
let awaiter1 = newFuture[void]() let awaiter1 = newFuture[void]()
proc statusAndConfidenceHandler1(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} = proc statusAndConfidenceHandler1(networkReachability: NetworkReachability, confidence: Opt[float]) {.gcsafe, async.} =
if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1: if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1:
if not awaiter1.finished: if not awaiter1.finished:
awaiter1.complete() awaiter1.complete()
@ -378,7 +378,7 @@ suite "Autonat Service":
switch1.stop(), switch2.stop()) switch1.stop(), switch2.stop())
asyncTest "Must work with low maxConnections": asyncTest "Must work with low maxConnections":
let autonatService = AutonatService.new(AutonatClient.new(), newRng(), some(1.seconds), maxQueueSize = 1) let autonatService = AutonatService.new(AutonatClient.new(), newRng(), Opt.some(1.seconds), maxQueueSize = 1)
let switch1 = createSwitch(autonatService, maxConns = 4) let switch1 = createSwitch(autonatService, maxConns = 4)
let switch2 = createSwitch() let switch2 = createSwitch()
@ -388,7 +388,7 @@ suite "Autonat Service":
var awaiter = newFuture[void]() var awaiter = newFuture[void]()
proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} = proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Opt[float]) {.gcsafe, async.} =
if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1: if networkReachability == NetworkReachability.Reachable and confidence.isSome() and confidence.get() == 1:
if not awaiter.finished: if not awaiter.finished:
awaiter.complete() awaiter.complete()
@ -428,7 +428,7 @@ suite "Autonat Service":
let switch1 = createSwitch(autonatService) let switch1 = createSwitch(autonatService)
let switch2 = createSwitch() let switch2 = createSwitch()
proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Option[float]) {.gcsafe, async.} = proc statusAndConfidenceHandler(networkReachability: NetworkReachability, confidence: Opt[float]) {.gcsafe, async.} =
fail() fail()
check autonatService.networkReachability == NetworkReachability.Unknown check autonatService.networkReachability == NetworkReachability.Unknown

View File

@ -48,20 +48,20 @@ suite "Circuit Relay":
r {.threadvar.}: Relay r {.threadvar.}: Relay
conn {.threadvar.}: Connection conn {.threadvar.}: Connection
msg {.threadvar.}: ProtoBuffer msg {.threadvar.}: ProtoBuffer
rcv {.threadvar.}: Option[RelayMessage] rcv {.threadvar.}: Opt[RelayMessage]
proc createMsg( proc createMsg(
msgType: Option[RelayType] = RelayType.none, msgType: Opt[RelayType] = Opt.none(RelayType),
status: Option[StatusV1] = StatusV1.none, status: Opt[StatusV1] = Opt.none(StatusV1),
src: Option[RelayPeer] = RelayPeer.none, src: Opt[RelayPeer] = Opt.none(RelayPeer),
dst: Option[RelayPeer] = RelayPeer.none): ProtoBuffer = dst: Opt[RelayPeer] = Opt.none(RelayPeer)): ProtoBuffer =
encode(RelayMessage(msgType: msgType, srcPeer: src, dstPeer: dst, status: status)) encode(RelayMessage(msgType: msgType, srcPeer: src, dstPeer: dst, status: status))
proc checkMsg(msg: Option[RelayMessage], proc checkMsg(msg: Opt[RelayMessage],
msgType: Option[RelayType] = none[RelayType](), msgType: Opt[RelayType] = Opt.none(RelayType),
status: Option[StatusV1] = none[StatusV1](), status: Opt[StatusV1] = Opt.none(StatusV1),
src: Option[RelayPeer] = none[RelayPeer](), src: Opt[RelayPeer] = Opt.none(RelayPeer),
dst: Option[RelayPeer] = none[RelayPeer]()) = dst: Opt[RelayPeer] = Opt.none(RelayPeer)) =
check: msg.isSome check: msg.isSome
let m = msg.get() let m = msg.get()
check: m.msgType == msgType check: m.msgType == msgType
@ -119,116 +119,116 @@ suite "Circuit Relay":
await srelay.start() await srelay.start()
asyncTest "Handle CanHop": asyncTest "Handle CanHop":
msg = createMsg(some(CanHop)) msg = createMsg(Opt.some(CanHop))
conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec)
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
rcv.checkMsg(some(RelayType.Status), some(StatusV1.Success)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(StatusV1.Success))
conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec) conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec)
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
rcv.checkMsg(some(RelayType.Status), some(HopCantSpeakRelay)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(HopCantSpeakRelay))
await conn.close() await conn.close()
asyncTest "Malformed": asyncTest "Malformed":
conn = await srelay.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec) conn = await srelay.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec)
msg = createMsg(some(RelayType.Status)) msg = createMsg(Opt.some(RelayType.Status))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
await conn.close() await conn.close()
rcv.checkMsg(some(RelayType.Status), some(StatusV1.MalformedMessage)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(StatusV1.MalformedMessage))
asyncTest "Handle Stop Error": asyncTest "Handle Stop Error":
conn = await srelay.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec) conn = await srelay.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec)
msg = createMsg(some(RelayType.Stop), msg = createMsg(Opt.some(RelayType.Stop),
none(StatusV1), Opt.none(StatusV1),
none(RelayPeer), Opt.none(RelayPeer),
some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) Opt.some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
rcv.checkMsg(some(RelayType.Status), some(StopSrcMultiaddrInvalid)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(StopSrcMultiaddrInvalid))
conn = await srelay.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec) conn = await srelay.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec)
msg = createMsg(some(RelayType.Stop), msg = createMsg(Opt.some(RelayType.Stop),
none(StatusV1), Opt.none(StatusV1),
some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), Opt.some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
none(RelayPeer)) Opt.none(RelayPeer))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
rcv.checkMsg(some(RelayType.Status), some(StopDstMultiaddrInvalid)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(StopDstMultiaddrInvalid))
conn = await srelay.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec) conn = await srelay.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec)
msg = createMsg(some(RelayType.Stop), msg = createMsg(Opt.some(RelayType.Stop),
none(StatusV1), Opt.none(StatusV1),
some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)), Opt.some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)),
some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs))) Opt.some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
await conn.close() await conn.close()
rcv.checkMsg(some(RelayType.Status), some(StopDstMultiaddrInvalid)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(StopDstMultiaddrInvalid))
asyncTest "Handle Hop Error": asyncTest "Handle Hop Error":
conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec) conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec)
msg = createMsg(some(RelayType.Hop)) msg = createMsg(Opt.some(RelayType.Hop))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
rcv.checkMsg(some(RelayType.Status), some(HopCantSpeakRelay)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(HopCantSpeakRelay))
conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec)
msg = createMsg(some(RelayType.Hop), msg = createMsg(Opt.some(RelayType.Hop),
none(StatusV1), Opt.none(StatusV1),
none(RelayPeer), Opt.none(RelayPeer),
some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) Opt.some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
rcv.checkMsg(some(RelayType.Status), some(HopSrcMultiaddrInvalid)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(HopSrcMultiaddrInvalid))
conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec)
msg = createMsg(some(RelayType.Hop), msg = createMsg(Opt.some(RelayType.Hop),
none(StatusV1), Opt.none(StatusV1),
some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)), Opt.some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)),
some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) Opt.some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
rcv.checkMsg(some(RelayType.Status), some(HopSrcMultiaddrInvalid)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(HopSrcMultiaddrInvalid))
conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec)
msg = createMsg(some(RelayType.Hop), msg = createMsg(Opt.some(RelayType.Hop),
none(StatusV1), Opt.none(StatusV1),
some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), Opt.some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
none(RelayPeer)) Opt.none(RelayPeer))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
rcv.checkMsg(some(RelayType.Status), some(HopDstMultiaddrInvalid)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(HopDstMultiaddrInvalid))
conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec)
msg = createMsg(some(RelayType.Hop), msg = createMsg(Opt.some(RelayType.Hop),
none(StatusV1), Opt.none(StatusV1),
some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), Opt.some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
some(RelayPeer(peerId: srelay.peerInfo.peerId, addrs: srelay.peerInfo.addrs))) Opt.some(RelayPeer(peerId: srelay.peerInfo.peerId, addrs: srelay.peerInfo.addrs)))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
rcv.checkMsg(some(RelayType.Status), some(HopCantRelayToSelf)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(HopCantRelayToSelf))
conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec)
msg = createMsg(some(RelayType.Hop), msg = createMsg(Opt.some(RelayType.Hop),
none(StatusV1), Opt.none(StatusV1),
some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), Opt.some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
some(RelayPeer(peerId: srelay.peerInfo.peerId, addrs: srelay.peerInfo.addrs))) Opt.some(RelayPeer(peerId: srelay.peerInfo.peerId, addrs: srelay.peerInfo.addrs)))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
rcv.checkMsg(some(RelayType.Status), some(HopCantRelayToSelf)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(HopCantRelayToSelf))
conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec)
msg = createMsg(some(RelayType.Hop), msg = createMsg(Opt.some(RelayType.Hop),
none(StatusV1), Opt.none(StatusV1),
some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), Opt.some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) Opt.some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
rcv.checkMsg(some(RelayType.Status), some(HopNoConnToDst)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(HopNoConnToDst))
await srelay.connect(dst.peerInfo.peerId, dst.peerInfo.addrs) await srelay.connect(dst.peerInfo.peerId, dst.peerInfo.addrs)
@ -237,7 +237,7 @@ suite "Circuit Relay":
conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec)
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
rcv.checkMsg(some(RelayType.Status), some(HopCantSpeakRelay)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(HopCantSpeakRelay))
r.maxCircuit = tmp r.maxCircuit = tmp
await conn.close() await conn.close()
@ -246,7 +246,7 @@ suite "Circuit Relay":
conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec)
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
rcv.checkMsg(some(RelayType.Status), some(HopCantSpeakRelay)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(HopCantSpeakRelay))
r.maxCircuitPerPeer = tmp r.maxCircuitPerPeer = tmp
await conn.close() await conn.close()
@ -255,13 +255,13 @@ suite "Circuit Relay":
await srelay.connect(dst2.peerInfo.peerId, dst2.peerInfo.addrs) await srelay.connect(dst2.peerInfo.peerId, dst2.peerInfo.addrs)
conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec)
msg = createMsg(some(RelayType.Hop), msg = createMsg(Opt.some(RelayType.Hop),
none(StatusV1), Opt.none(StatusV1),
some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), Opt.some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
some(RelayPeer(peerId: dst2.peerInfo.peerId, addrs: dst2.peerInfo.addrs))) Opt.some(RelayPeer(peerId: dst2.peerInfo.peerId, addrs: dst2.peerInfo.addrs)))
await conn.writeLp(msg.buffer) await conn.writeLp(msg.buffer)
rcv = RelayMessage.decode(await conn.readLp(1024)) rcv = RelayMessage.decode(await conn.readLp(1024))
rcv.checkMsg(some(RelayType.Status), some(HopCantDialDst)) rcv.checkMsg(Opt.some(RelayType.Status), Opt.some(HopCantDialDst))
await allFutures(dst2.stop()) await allFutures(dst2.stop())
asyncTest "Dial Peer": asyncTest "Dial Peer":

View File

@ -81,7 +81,7 @@ suite "Circuit Relay V2":
let msg = HopMessage.decode(await conn.readLp(RelayMsgSize)).get() let msg = HopMessage.decode(await conn.readLp(RelayMsgSize)).get()
check: check:
msg.msgType == HopMessageType.Status msg.msgType == HopMessageType.Status
msg.status == some(StatusV2.ReservationRefused) msg.status == Opt.some(StatusV2.ReservationRefused)
asyncTest "Too many reservations + Reconnect": asyncTest "Too many reservations + Reconnect":
expect(ReservationError): expect(ReservationError):

View File

@ -9,7 +9,6 @@
# This file may not be copied, modified, or distributed except according to # This file may not be copied, modified, or distributed except according to
# those terms. # those terms.
import strformat
import ./helpers import ./helpers
import ../libp2p/utility import ../libp2p/utility