noise updates (#255)
* clear secrets explicitly * simplify keygen * avoid some trivial memory allocations * fix little endian encoding of nonce
This commit is contained in:
parent
4e12d0d97a
commit
45c089ff0d
|
@ -105,10 +105,13 @@ proc mulgen*(_: type[Curve25519], dst: var Curve25519Key, point: Curve25519Key)
|
||||||
proc public*(private: Curve25519Key): Curve25519Key =
|
proc public*(private: Curve25519Key): Curve25519Key =
|
||||||
Curve25519.mulgen(result, private)
|
Curve25519.mulgen(result, private)
|
||||||
|
|
||||||
proc random*(_: type[Curve25519Key], rng: var BrHmacDrbgContext): Result[Curve25519Key, Curve25519Error] =
|
proc random*(_: type[Curve25519Key], rng: var BrHmacDrbgContext): Curve25519Key =
|
||||||
var res: Curve25519Key
|
var res: Curve25519Key
|
||||||
let defaultBrEc = brEcGetDefault()
|
let defaultBrEc = brEcGetDefault()
|
||||||
if brEcKeygen(addr rng.vtable, defaultBrEc, nil, addr res[0], EC_curve25519) != Curve25519KeySize:
|
let len = brEcKeygen(
|
||||||
err(Curver25519GenError)
|
addr rng.vtable, defaultBrEc, nil, addr res[0], EC_curve25519)
|
||||||
else:
|
# Per bearssl documentation, the keygen only fails if the curve is
|
||||||
ok(res)
|
# unrecognised -
|
||||||
|
doAssert len == Curve25519KeySize, "Could not generate curve"
|
||||||
|
|
||||||
|
res
|
||||||
|
|
|
@ -39,11 +39,11 @@ const
|
||||||
|
|
||||||
type
|
type
|
||||||
EcPrivateKey* = ref object
|
EcPrivateKey* = ref object
|
||||||
buffer*: seq[byte]
|
buffer*: array[BR_EC_KBUF_PRIV_MAX_SIZE, byte]
|
||||||
key*: BrEcPrivateKey
|
key*: BrEcPrivateKey
|
||||||
|
|
||||||
EcPublicKey* = ref object
|
EcPublicKey* = ref object
|
||||||
buffer*: seq[byte]
|
buffer*: array[BR_EC_KBUF_PUB_MAX_SIZE, byte]
|
||||||
key*: BrEcPublicKey
|
key*: BrEcPublicKey
|
||||||
|
|
||||||
EcKeyPair* = object
|
EcKeyPair* = object
|
||||||
|
@ -237,7 +237,6 @@ proc random*(
|
||||||
## secp521r1).
|
## secp521r1).
|
||||||
var ecimp = brEcGetDefault()
|
var ecimp = brEcGetDefault()
|
||||||
var res = new EcPrivateKey
|
var res = new EcPrivateKey
|
||||||
res.buffer = newSeq[byte](BR_EC_KBUF_PRIV_MAX_SIZE)
|
|
||||||
if brEcKeygen(addr rng.vtable, ecimp,
|
if brEcKeygen(addr rng.vtable, ecimp,
|
||||||
addr res.key, addr res.buffer[0],
|
addr res.key, addr res.buffer[0],
|
||||||
cast[cint](kind)) == 0:
|
cast[cint](kind)) == 0:
|
||||||
|
@ -254,7 +253,6 @@ proc getKey*(seckey: EcPrivateKey): EcResult[EcPublicKey] =
|
||||||
if seckey.key.curve in EcSupportedCurvesCint:
|
if seckey.key.curve in EcSupportedCurvesCint:
|
||||||
var length = getPublicKeyLength(cast[EcCurveKind](seckey.key.curve))
|
var length = getPublicKeyLength(cast[EcCurveKind](seckey.key.curve))
|
||||||
var res = new EcPublicKey
|
var res = new EcPublicKey
|
||||||
res.buffer = newSeq[byte](length)
|
|
||||||
if brEcComputePublicKey(ecimp, addr res.key,
|
if brEcComputePublicKey(ecimp, addr res.key,
|
||||||
addr res.buffer[0], unsafeAddr seckey.key) == 0:
|
addr res.buffer[0], unsafeAddr seckey.key) == 0:
|
||||||
err(EcKeyIncorrectError)
|
err(EcKeyIncorrectError)
|
||||||
|
@ -621,7 +619,6 @@ proc init*(key: var EcPrivateKey, data: openarray[byte]): Result[void, Asn1Error
|
||||||
|
|
||||||
if checkScalar(raw.toOpenArray(), curve) == 1'u32:
|
if checkScalar(raw.toOpenArray(), curve) == 1'u32:
|
||||||
key = new EcPrivateKey
|
key = new EcPrivateKey
|
||||||
key.buffer = newSeq[byte](raw.length)
|
|
||||||
copyMem(addr key.buffer[0], addr raw.buffer[raw.offset], raw.length)
|
copyMem(addr key.buffer[0], addr raw.buffer[raw.offset], raw.length)
|
||||||
key.key.x = cast[ptr cuchar](addr key.buffer[0])
|
key.key.x = cast[ptr cuchar](addr key.buffer[0])
|
||||||
key.key.xlen = raw.length
|
key.key.xlen = raw.length
|
||||||
|
@ -681,7 +678,6 @@ proc init*(pubkey: var EcPublicKey, data: openarray[byte]): Result[void, Asn1Err
|
||||||
|
|
||||||
if checkPublic(raw.toOpenArray(), curve) != 0:
|
if checkPublic(raw.toOpenArray(), curve) != 0:
|
||||||
pubkey = new EcPublicKey
|
pubkey = new EcPublicKey
|
||||||
pubkey.buffer = newSeq[byte](raw.length)
|
|
||||||
copyMem(addr pubkey.buffer[0], addr raw.buffer[raw.offset], raw.length)
|
copyMem(addr pubkey.buffer[0], addr raw.buffer[raw.offset], raw.length)
|
||||||
pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0])
|
pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0])
|
||||||
pubkey.key.qlen = raw.length
|
pubkey.key.qlen = raw.length
|
||||||
|
@ -769,7 +765,6 @@ proc initRaw*(key: var EcPrivateKey, data: openarray[byte]): bool =
|
||||||
if checkScalar(data, curve) == 1'u32:
|
if checkScalar(data, curve) == 1'u32:
|
||||||
let length = len(data)
|
let length = len(data)
|
||||||
key = new EcPrivateKey
|
key = new EcPrivateKey
|
||||||
key.buffer = newSeq[byte](length)
|
|
||||||
copyMem(addr key.buffer[0], unsafeAddr data[0], length)
|
copyMem(addr key.buffer[0], unsafeAddr data[0], length)
|
||||||
key.key.x = cast[ptr cuchar](addr key.buffer[0])
|
key.key.x = cast[ptr cuchar](addr key.buffer[0])
|
||||||
key.key.xlen = length
|
key.key.xlen = length
|
||||||
|
@ -801,7 +796,6 @@ proc initRaw*(pubkey: var EcPublicKey, data: openarray[byte]): bool =
|
||||||
if checkPublic(data, curve) != 0:
|
if checkPublic(data, curve) != 0:
|
||||||
let length = len(data)
|
let length = len(data)
|
||||||
pubkey = new EcPublicKey
|
pubkey = new EcPublicKey
|
||||||
pubkey.buffer = newSeq[byte](length)
|
|
||||||
copyMem(addr pubkey.buffer[0], unsafeAddr data[0], length)
|
copyMem(addr pubkey.buffer[0], unsafeAddr data[0], length)
|
||||||
pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0])
|
pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0])
|
||||||
pubkey.key.qlen = length
|
pubkey.key.qlen = length
|
||||||
|
|
|
@ -12,15 +12,13 @@ import chronicles
|
||||||
import bearssl
|
import bearssl
|
||||||
import stew/[endians2, byteutils]
|
import stew/[endians2, byteutils]
|
||||||
import nimcrypto/[utils, sha2, hmac]
|
import nimcrypto/[utils, sha2, hmac]
|
||||||
import ../../stream/lpstream
|
import ../../stream/[connection, streamseq]
|
||||||
import ../../peerid
|
import ../../peerid
|
||||||
import ../../peerinfo
|
import ../../peerinfo
|
||||||
import ../../protobuf/minprotobuf
|
import ../../protobuf/minprotobuf
|
||||||
import ../../utility
|
import ../../utility
|
||||||
import ../../stream/lpstream
|
|
||||||
import secure,
|
import secure,
|
||||||
../../crypto/[crypto, chacha20poly1305, curve25519, hkdf],
|
../../crypto/[crypto, chacha20poly1305, curve25519, hkdf]
|
||||||
../../stream/bufferstream
|
|
||||||
|
|
||||||
logScope:
|
logScope:
|
||||||
topics = "noise"
|
topics = "noise"
|
||||||
|
@ -34,7 +32,7 @@ const
|
||||||
ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256"
|
ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256"
|
||||||
|
|
||||||
# Empty is a special value which indicates k has not yet been initialized.
|
# Empty is a special value which indicates k has not yet been initialized.
|
||||||
EmptyKey: ChaChaPolyKey = [0.byte, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
EmptyKey = default(ChaChaPolyKey)
|
||||||
NonceMax = uint64.high - 1 # max is reserved
|
NonceMax = uint64.high - 1 # max is reserved
|
||||||
NoiseSize = 32
|
NoiseSize = 32
|
||||||
MaxPlainSize = int(uint16.high - NoiseSize - ChaChaPolyTag.len)
|
MaxPlainSize = int(uint16.high - NoiseSize - ChaChaPolyTag.len)
|
||||||
|
@ -72,7 +70,7 @@ type
|
||||||
Noise* = ref object of Secure
|
Noise* = ref object of Secure
|
||||||
rng: ref BrHmacDrbgContext
|
rng: ref BrHmacDrbgContext
|
||||||
localPrivateKey: PrivateKey
|
localPrivateKey: PrivateKey
|
||||||
localPublicKey: PublicKey
|
localPublicKey: seq[byte]
|
||||||
noiseKeys: KeyPair
|
noiseKeys: KeyPair
|
||||||
commonPrologue: seq[byte]
|
commonPrologue: seq[byte]
|
||||||
outgoing: bool
|
outgoing: bool
|
||||||
|
@ -89,7 +87,7 @@ type
|
||||||
# Utility
|
# Utility
|
||||||
|
|
||||||
proc genKeyPair(rng: var BrHmacDrbgContext): KeyPair =
|
proc genKeyPair(rng: var BrHmacDrbgContext): KeyPair =
|
||||||
result.privateKey = Curve25519Key.random(rng).tryGet()
|
result.privateKey = Curve25519Key.random(rng)
|
||||||
result.publicKey = result.privateKey.public()
|
result.publicKey = result.privateKey.public()
|
||||||
|
|
||||||
proc hashProtocol(name: string): MDigest[256] =
|
proc hashProtocol(name: string): MDigest[256] =
|
||||||
|
@ -110,12 +108,11 @@ proc dh(priv: Curve25519Key, pub: Curve25519Key): Curve25519Key =
|
||||||
proc hasKey(cs: CipherState): bool =
|
proc hasKey(cs: CipherState): bool =
|
||||||
cs.k != EmptyKey
|
cs.k != EmptyKey
|
||||||
|
|
||||||
proc encryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte] =
|
proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
|
||||||
var
|
var
|
||||||
tag: ChaChaPolyTag
|
tag: ChaChaPolyTag
|
||||||
nonce: ChaChaPolyNonce
|
nonce: ChaChaPolyNonce
|
||||||
np = cast[ptr uint64](addr nonce[4])
|
nonce[4..<12] = toBytesLE(state.n)
|
||||||
np[] = state.n
|
|
||||||
result = @data
|
result = @data
|
||||||
ChaChaPoly.encrypt(state.k, nonce, tag, result, ad)
|
ChaChaPoly.encrypt(state.k, nonce, tag, result, ad)
|
||||||
inc state.n
|
inc state.n
|
||||||
|
@ -124,13 +121,12 @@ proc encryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte]
|
||||||
result &= tag
|
result &= tag
|
||||||
trace "encryptWithAd", tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1
|
trace "encryptWithAd", tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1
|
||||||
|
|
||||||
proc decryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte] =
|
proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
|
||||||
var
|
var
|
||||||
tagIn = data[^ChaChaPolyTag.len..data.high].intoChaChaPolyTag
|
tagIn = data.toOpenArray(data.len - ChaChaPolyTag.len, data.high).intoChaChaPolyTag
|
||||||
tagOut = tagIn
|
tagOut: ChaChaPolyTag
|
||||||
nonce: ChaChaPolyNonce
|
nonce: ChaChaPolyNonce
|
||||||
np = cast[ptr uint64](addr nonce[4])
|
nonce[4..<12] = toBytesLE(state.n)
|
||||||
np[] = state.n
|
|
||||||
result = data[0..(data.high - ChaChaPolyTag.len)]
|
result = data[0..(data.high - ChaChaPolyTag.len)]
|
||||||
ChaChaPoly.decrypt(state.k, nonce, tagOut, result, ad)
|
ChaChaPoly.decrypt(state.k, nonce, tagOut, result, ad)
|
||||||
trace "decryptWithAd", tagIn = tagIn.shortLog, tagOut = tagOut.shortLog, nonce = state.n
|
trace "decryptWithAd", tagIn = tagIn.shortLog, tagOut = tagOut.shortLog, nonce = state.n
|
||||||
|
@ -156,7 +152,7 @@ proc mixKey(ss: var SymmetricState, ikm: ChaChaPolyKey) =
|
||||||
ss.cs = CipherState(k: temp_keys[1])
|
ss.cs = CipherState(k: temp_keys[1])
|
||||||
trace "mixKey", key = ss.cs.k.shortLog
|
trace "mixKey", key = ss.cs.k.shortLog
|
||||||
|
|
||||||
proc mixHash(ss: var SymmetricState; data: openarray[byte]) =
|
proc mixHash(ss: var SymmetricState; data: openArray[byte]) =
|
||||||
var ctx: sha256
|
var ctx: sha256
|
||||||
ctx.init()
|
ctx.init()
|
||||||
ctx.update(ss.h.data)
|
ctx.update(ss.h.data)
|
||||||
|
@ -165,7 +161,7 @@ proc mixHash(ss: var SymmetricState; data: openarray[byte]) =
|
||||||
trace "mixHash", hash = ss.h.data.shortLog
|
trace "mixHash", hash = ss.h.data.shortLog
|
||||||
|
|
||||||
# We might use this for other handshake patterns/tokens
|
# We might use this for other handshake patterns/tokens
|
||||||
proc mixKeyAndHash(ss: var SymmetricState; ikm: openarray[byte]) {.used.} =
|
proc mixKeyAndHash(ss: var SymmetricState; ikm: openArray[byte]) {.used.} =
|
||||||
var
|
var
|
||||||
temp_keys: array[3, ChaChaPolyKey]
|
temp_keys: array[3, ChaChaPolyKey]
|
||||||
sha256.hkdf(ss.ck, ikm, [], temp_keys)
|
sha256.hkdf(ss.ck, ikm, [], temp_keys)
|
||||||
|
@ -173,7 +169,7 @@ proc mixKeyAndHash(ss: var SymmetricState; ikm: openarray[byte]) {.used.} =
|
||||||
ss.mixHash(temp_keys[1])
|
ss.mixHash(temp_keys[1])
|
||||||
ss.cs = CipherState(k: temp_keys[2])
|
ss.cs = CipherState(k: temp_keys[2])
|
||||||
|
|
||||||
proc encryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] =
|
proc encryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte] =
|
||||||
# according to spec if key is empty leave plaintext
|
# according to spec if key is empty leave plaintext
|
||||||
if ss.cs.hasKey:
|
if ss.cs.hasKey:
|
||||||
result = ss.cs.encryptWithAd(ss.h.data, data)
|
result = ss.cs.encryptWithAd(ss.h.data, data)
|
||||||
|
@ -181,7 +177,7 @@ proc encryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] =
|
||||||
result = @data
|
result = @data
|
||||||
ss.mixHash(result)
|
ss.mixHash(result)
|
||||||
|
|
||||||
proc decryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] =
|
proc decryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte] =
|
||||||
# according to spec if key is empty leave plaintext
|
# according to spec if key is empty leave plaintext
|
||||||
if ss.cs.hasKey:
|
if ss.cs.hasKey:
|
||||||
result = ss.cs.decryptWithAd(ss.h.data, data)
|
result = ss.cs.decryptWithAd(ss.h.data, data)
|
||||||
|
@ -202,13 +198,13 @@ template write_e: untyped =
|
||||||
trace "noise write e"
|
trace "noise write e"
|
||||||
# Sets e (which must be empty) to GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key).
|
# Sets e (which must be empty) to GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key).
|
||||||
hs.e = genKeyPair(p.rng[])
|
hs.e = genKeyPair(p.rng[])
|
||||||
msg &= hs.e.publicKey
|
msg.add hs.e.publicKey
|
||||||
hs.ss.mixHash(hs.e.publicKey)
|
hs.ss.mixHash(hs.e.publicKey)
|
||||||
|
|
||||||
template write_s: untyped =
|
template write_s: untyped =
|
||||||
trace "noise write s"
|
trace "noise write s"
|
||||||
# Appends EncryptAndHash(s.public_key) to the buffer.
|
# Appends EncryptAndHash(s.public_key) to the buffer.
|
||||||
msg &= hs.ss.encryptAndHash(hs.s.publicKey)
|
msg.add hs.ss.encryptAndHash(hs.s.publicKey)
|
||||||
|
|
||||||
template dh_ee: untyped =
|
template dh_ee: untyped =
|
||||||
trace "noise dh ee"
|
trace "noise dh ee"
|
||||||
|
@ -244,8 +240,8 @@ template read_e: untyped =
|
||||||
raise newException(NoiseHandshakeError, "Noise E, expected more data")
|
raise newException(NoiseHandshakeError, "Noise E, expected more data")
|
||||||
|
|
||||||
# Sets re (which must be empty) to the next DHLEN bytes from the message. Calls MixHash(re.public_key).
|
# Sets re (which must be empty) to the next DHLEN bytes from the message. Calls MixHash(re.public_key).
|
||||||
hs.re[0..Curve25519Key.high] = msg[0..Curve25519Key.high]
|
hs.re[0..Curve25519Key.high] = msg.toOpenArray(0, Curve25519Key.high)
|
||||||
msg = msg[Curve25519Key.len..msg.high]
|
msg.consume(Curve25519Key.len)
|
||||||
hs.ss.mixHash(hs.re)
|
hs.ss.mixHash(hs.re)
|
||||||
|
|
||||||
template read_s: untyped =
|
template read_s: untyped =
|
||||||
|
@ -253,30 +249,33 @@ template read_s: untyped =
|
||||||
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes otherwise.
|
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes otherwise.
|
||||||
# Sets rs (which must be empty) to DecryptAndHash(temp).
|
# Sets rs (which must be empty) to DecryptAndHash(temp).
|
||||||
let
|
let
|
||||||
temp =
|
rsLen =
|
||||||
if hs.ss.cs.hasKey:
|
if hs.ss.cs.hasKey:
|
||||||
if msg.len < Curve25519Key.len + ChaChaPolyTag.len:
|
if msg.len < Curve25519Key.len + ChaChaPolyTag.len:
|
||||||
raise newException(NoiseHandshakeError, "Noise S, expected more data")
|
raise newException(NoiseHandshakeError, "Noise S, expected more data")
|
||||||
msg[0..Curve25519Key.high + ChaChaPolyTag.len]
|
Curve25519Key.len + ChaChaPolyTag.len
|
||||||
else:
|
else:
|
||||||
if msg.len < Curve25519Key.len:
|
if msg.len < Curve25519Key.len:
|
||||||
raise newException(NoiseHandshakeError, "Noise S, expected more data")
|
raise newException(NoiseHandshakeError, "Noise S, expected more data")
|
||||||
msg[0..Curve25519Key.high]
|
Curve25519Key.len
|
||||||
msg = msg[temp.len..msg.high]
|
hs.rs[0..Curve25519Key.high] =
|
||||||
let plain = hs.ss.decryptAndHash(temp)
|
hs.ss.decryptAndHash(msg.toOpenArray(0, rsLen - 1))
|
||||||
hs.rs[0..Curve25519Key.high] = plain
|
|
||||||
|
msg.consume(rsLen)
|
||||||
|
|
||||||
proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} =
|
proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} =
|
||||||
var besize: array[2, byte]
|
var besize: array[2, byte]
|
||||||
await sconn.readExactly(addr besize[0], besize.len)
|
await sconn.readExactly(addr besize[0], besize.len)
|
||||||
let size = uint16.fromBytesBE(besize).int
|
let size = uint16.fromBytesBE(besize).int
|
||||||
trace "receiveHSMessage", size
|
trace "receiveHSMessage", size
|
||||||
|
if size == 0:
|
||||||
|
return
|
||||||
|
|
||||||
var buffer = newSeq[byte](size)
|
var buffer = newSeq[byte](size)
|
||||||
if buffer.len > 0:
|
await sconn.readExactly(addr buffer[0], buffer.len)
|
||||||
await sconn.readExactly(addr buffer[0], buffer.len)
|
|
||||||
return buffer
|
return buffer
|
||||||
|
|
||||||
proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} =
|
proc sendHSMessage(sconn: Connection; buf: openArray[byte]): Future[void] =
|
||||||
var
|
var
|
||||||
lesize = buf.len.uint16
|
lesize = buf.len.uint16
|
||||||
besize = lesize.toBytesBE
|
besize = lesize.toBytesBE
|
||||||
|
@ -284,97 +283,106 @@ proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} =
|
||||||
trace "sendHSMessage", size = lesize
|
trace "sendHSMessage", size = lesize
|
||||||
outbuf &= besize
|
outbuf &= besize
|
||||||
outbuf &= buf
|
outbuf &= buf
|
||||||
await sconn.write(outbuf)
|
sconn.write(outbuf)
|
||||||
|
|
||||||
proc handshakeXXOutbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Future[HandshakeResult] {.async.} =
|
proc handshakeXXOutbound(
|
||||||
|
p: Noise, conn: Connection,
|
||||||
|
p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} =
|
||||||
const initiator = true
|
const initiator = true
|
||||||
|
|
||||||
var
|
var
|
||||||
hs = HandshakeState.init()
|
hs = HandshakeState.init()
|
||||||
p2psecret = p2pProof.buffer
|
|
||||||
|
|
||||||
hs.ss.mixHash(p.commonPrologue)
|
try:
|
||||||
hs.s = p.noiseKeys
|
|
||||||
|
|
||||||
# -> e
|
hs.ss.mixHash(p.commonPrologue)
|
||||||
var msg: seq[byte]
|
hs.s = p.noiseKeys
|
||||||
|
|
||||||
write_e()
|
# -> e
|
||||||
|
var msg: StreamSeq
|
||||||
|
|
||||||
# IK might use this btw!
|
write_e()
|
||||||
msg &= hs.ss.encryptAndHash(@[])
|
|
||||||
|
|
||||||
await conn.sendHSMessage(msg)
|
# IK might use this btw!
|
||||||
|
msg.add hs.ss.encryptAndHash([])
|
||||||
|
|
||||||
# <- e, ee, s, es
|
await conn.sendHSMessage(msg.data)
|
||||||
|
|
||||||
msg = await conn.receiveHSMessage()
|
# <- e, ee, s, es
|
||||||
|
|
||||||
read_e()
|
msg.assign(await conn.receiveHSMessage())
|
||||||
dh_ee()
|
|
||||||
read_s()
|
|
||||||
dh_es()
|
|
||||||
|
|
||||||
let remoteP2psecret = hs.ss.decryptAndHash(msg)
|
read_e()
|
||||||
|
dh_ee()
|
||||||
|
read_s()
|
||||||
|
dh_es()
|
||||||
|
|
||||||
# -> s, se
|
let remoteP2psecret = hs.ss.decryptAndHash(msg.data)
|
||||||
|
msg.clear()
|
||||||
|
|
||||||
msg.setLen(0)
|
# -> s, se
|
||||||
|
|
||||||
write_s()
|
write_s()
|
||||||
dh_se()
|
dh_se()
|
||||||
|
|
||||||
# last payload must follow the ecrypted way of sending
|
# last payload must follow the encrypted way of sending
|
||||||
msg &= hs.ss.encryptAndHash(p2psecret)
|
msg.add hs.ss.encryptAndHash(p2psecret)
|
||||||
|
|
||||||
await conn.sendHSMessage(msg)
|
await conn.sendHSMessage(msg.data)
|
||||||
|
|
||||||
let (cs1, cs2) = hs.ss.split()
|
let (cs1, cs2) = hs.ss.split()
|
||||||
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
|
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
|
||||||
|
finally:
|
||||||
|
burnMem(hs)
|
||||||
|
|
||||||
proc handshakeXXInbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Future[HandshakeResult] {.async.} =
|
proc handshakeXXInbound(
|
||||||
|
p: Noise, conn: Connection,
|
||||||
|
p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} =
|
||||||
const initiator = false
|
const initiator = false
|
||||||
|
|
||||||
var
|
var
|
||||||
hs = HandshakeState.init()
|
hs = HandshakeState.init()
|
||||||
p2psecret = p2pProof.buffer
|
|
||||||
|
|
||||||
hs.ss.mixHash(p.commonPrologue)
|
try:
|
||||||
hs.s = p.noiseKeys
|
hs.ss.mixHash(p.commonPrologue)
|
||||||
|
hs.s = p.noiseKeys
|
||||||
|
|
||||||
# -> e
|
# -> e
|
||||||
|
|
||||||
var msg = await conn.receiveHSMessage()
|
var msg: StreamSeq
|
||||||
|
msg.add(await conn.receiveHSMessage())
|
||||||
|
|
||||||
read_e()
|
read_e()
|
||||||
|
|
||||||
# we might use this early data one day, keeping it here for clarity
|
# we might use this early data one day, keeping it here for clarity
|
||||||
let earlyData {.used.} = hs.ss.decryptAndHash(msg)
|
let earlyData {.used.} = hs.ss.decryptAndHash(msg.data)
|
||||||
|
|
||||||
# <- e, ee, s, es
|
# <- e, ee, s, es
|
||||||
|
|
||||||
msg.setLen(0)
|
msg.consume(msg.len)
|
||||||
|
|
||||||
write_e()
|
write_e()
|
||||||
dh_ee()
|
dh_ee()
|
||||||
write_s()
|
write_s()
|
||||||
dh_es()
|
dh_es()
|
||||||
|
|
||||||
msg &= hs.ss.encryptAndHash(p2psecret)
|
msg.add hs.ss.encryptAndHash(p2psecret)
|
||||||
|
|
||||||
await conn.sendHSMessage(msg)
|
await conn.sendHSMessage(msg.data)
|
||||||
|
msg.clear()
|
||||||
|
|
||||||
# -> s, se
|
# -> s, se
|
||||||
|
|
||||||
msg = await conn.receiveHSMessage()
|
msg.add(await conn.receiveHSMessage())
|
||||||
|
|
||||||
read_s()
|
read_s()
|
||||||
dh_se()
|
dh_se()
|
||||||
|
|
||||||
let remoteP2psecret = hs.ss.decryptAndHash(msg)
|
let
|
||||||
|
remoteP2psecret = hs.ss.decryptAndHash(msg.data)
|
||||||
let (cs1, cs2) = hs.ss.split()
|
(cs1, cs2) = hs.ss.split()
|
||||||
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
|
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
|
||||||
|
finally:
|
||||||
|
burnMem(hs)
|
||||||
|
|
||||||
method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
|
method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
|
||||||
while true: # Discard 0-length payloads
|
while true: # Discard 0-length payloads
|
||||||
|
@ -399,7 +407,8 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.
|
||||||
while left > 0:
|
while left > 0:
|
||||||
let
|
let
|
||||||
chunkSize = if left > MaxPlainSize: MaxPlainSize else: left
|
chunkSize = if left > MaxPlainSize: MaxPlainSize else: left
|
||||||
cipher = sconn.writeCs.encryptWithAd([], message.toOpenArray(offset, offset + chunkSize - 1))
|
cipher = sconn.writeCs.encryptWithAd(
|
||||||
|
[], message.toOpenArray(offset, offset + chunkSize - 1))
|
||||||
left = left - chunkSize
|
left = left - chunkSize
|
||||||
offset = offset + chunkSize
|
offset = offset + chunkSize
|
||||||
var
|
var
|
||||||
|
@ -421,65 +430,75 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
||||||
|
|
||||||
var
|
var
|
||||||
libp2pProof = initProtoBuffer()
|
libp2pProof = initProtoBuffer()
|
||||||
libp2pProof.write(initProtoField(1, p.localPublicKey.getBytes.tryGet()))
|
libp2pProof.write(initProtoField(1, p.localPublicKey))
|
||||||
libp2pProof.write(initProtoField(2, signedPayload.getBytes()))
|
libp2pProof.write(initProtoField(2, signedPayload.getBytes()))
|
||||||
# data field also there but not used!
|
# data field also there but not used!
|
||||||
libp2pProof.finish()
|
libp2pProof.finish()
|
||||||
|
|
||||||
let handshakeRes =
|
var handshakeRes =
|
||||||
if initiator:
|
if initiator:
|
||||||
await handshakeXXOutbound(p, conn, libp2pProof)
|
await handshakeXXOutbound(p, conn, libp2pProof.buffer)
|
||||||
else:
|
else:
|
||||||
await handshakeXXInbound(p, conn, libp2pProof)
|
await handshakeXXInbound(p, conn, libp2pProof.buffer)
|
||||||
|
|
||||||
var
|
var secure = try:
|
||||||
remoteProof = initProtoBuffer(handshakeRes.remoteP2psecret)
|
var
|
||||||
remotePubKey: PublicKey
|
remoteProof = initProtoBuffer(handshakeRes.remoteP2psecret)
|
||||||
remotePubKeyBytes: seq[byte]
|
remotePubKey: PublicKey
|
||||||
remoteSig: Signature
|
remotePubKeyBytes: seq[byte]
|
||||||
remoteSigBytes: seq[byte]
|
remoteSig: Signature
|
||||||
|
remoteSigBytes: seq[byte]
|
||||||
|
|
||||||
if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0:
|
if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0:
|
||||||
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||||
if remoteProof.getLengthValue(2, remoteSigBytes) <= 0:
|
if remoteProof.getLengthValue(2, remoteSigBytes) <= 0:
|
||||||
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||||
|
|
||||||
if not remotePubKey.init(remotePubKeyBytes):
|
if not remotePubKey.init(remotePubKeyBytes):
|
||||||
raise newException(NoiseHandshakeError, "Failed to decode remote public key. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
raise newException(NoiseHandshakeError, "Failed to decode remote public key. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||||
if not remoteSig.init(remoteSigBytes):
|
if not remoteSig.init(remoteSigBytes):
|
||||||
raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||||
|
|
||||||
let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes
|
let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes
|
||||||
if not remoteSig.verify(verifyPayload, remotePubKey):
|
if not remoteSig.verify(verifyPayload, remotePubKey):
|
||||||
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
|
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
|
||||||
else:
|
else:
|
||||||
trace "Remote signature verified", peer = $conn
|
trace "Remote signature verified", peer = $conn
|
||||||
|
|
||||||
if initiator and not isNil(conn.peerInfo):
|
if initiator and not isNil(conn.peerInfo):
|
||||||
let pid = PeerID.init(remotePubKey)
|
let pid = PeerID.init(remotePubKey)
|
||||||
if not conn.peerInfo.peerId.validate():
|
if not conn.peerInfo.peerId.validate():
|
||||||
raise newException(NoiseHandshakeError, "Failed to validate peerId.")
|
raise newException(NoiseHandshakeError, "Failed to validate peerId.")
|
||||||
if pid.isErr or pid.get() != conn.peerInfo.peerId:
|
if pid.isErr or pid.get() != conn.peerInfo.peerId:
|
||||||
var
|
var
|
||||||
failedKey: PublicKey
|
failedKey: PublicKey
|
||||||
discard extractPublicKey(conn.peerInfo.peerId, failedKey)
|
discard extractPublicKey(conn.peerInfo.peerId, failedKey)
|
||||||
debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey
|
debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey
|
||||||
raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId)
|
raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId)
|
||||||
|
|
||||||
var secure = NoiseConnection.init(conn,
|
var tmp = NoiseConnection.init(
|
||||||
PeerInfo.init(remotePubKey),
|
conn, PeerInfo.init(remotePubKey), conn.observedAddr)
|
||||||
conn.observedAddr)
|
|
||||||
if initiator:
|
if initiator:
|
||||||
secure.readCs = handshakeRes.cs2
|
tmp.readCs = handshakeRes.cs2
|
||||||
secure.writeCs = handshakeRes.cs1
|
tmp.writeCs = handshakeRes.cs1
|
||||||
else:
|
else:
|
||||||
secure.readCs = handshakeRes.cs1
|
tmp.readCs = handshakeRes.cs1
|
||||||
secure.writeCs = handshakeRes.cs2
|
tmp.writeCs = handshakeRes.cs2
|
||||||
|
tmp
|
||||||
|
finally:
|
||||||
|
burnMem(handshakeRes)
|
||||||
|
|
||||||
trace "Noise handshake completed!", initiator, peer = $secure.peerInfo
|
trace "Noise handshake completed!", initiator, peer = $secure.peerInfo
|
||||||
|
|
||||||
return secure
|
return secure
|
||||||
|
|
||||||
|
method close*(s: NoiseConnection) {.async.} =
|
||||||
|
await procCall SecureConn(s).close()
|
||||||
|
|
||||||
|
burnMem(s.readCs)
|
||||||
|
burnMem(s.writeCs)
|
||||||
|
|
||||||
method init*(p: Noise) {.gcsafe.} =
|
method init*(p: Noise) {.gcsafe.} =
|
||||||
procCall Secure(p).init()
|
procCall Secure(p).init()
|
||||||
p.codec = NoiseCodec
|
p.codec = NoiseCodec
|
||||||
|
@ -491,7 +510,7 @@ proc newNoise*(
|
||||||
rng: rng,
|
rng: rng,
|
||||||
outgoing: outgoing,
|
outgoing: outgoing,
|
||||||
localPrivateKey: privateKey,
|
localPrivateKey: privateKey,
|
||||||
localPublicKey: privateKey.getKey().tryGet(),
|
localPublicKey: privateKey.getKey().tryGet().getBytes().tryGet(),
|
||||||
noiseKeys: genKeyPair(rng[]),
|
noiseKeys: genKeyPair(rng[]),
|
||||||
commonPrologue: commonPrologue,
|
commonPrologue: commonPrologue,
|
||||||
)
|
)
|
||||||
|
|
|
@ -59,10 +59,9 @@ proc handleConn*(s: Secure,
|
||||||
conn: Connection,
|
conn: Connection,
|
||||||
initiator: bool): Future[Connection] {.async, gcsafe.} =
|
initiator: bool): Future[Connection] {.async, gcsafe.} =
|
||||||
var sconn = await s.handshake(conn, initiator)
|
var sconn = await s.handshake(conn, initiator)
|
||||||
|
if not isNil(sconn):
|
||||||
conn.closeEvent.wait()
|
conn.closeEvent.wait()
|
||||||
.addCallback do(udata: pointer = nil):
|
.addCallback do(udata: pointer = nil):
|
||||||
if not(isNil(sconn)):
|
|
||||||
asyncCheck sconn.close()
|
asyncCheck sconn.close()
|
||||||
|
|
||||||
return sconn
|
return sconn
|
||||||
|
|
|
@ -61,6 +61,11 @@ template data*(v: StreamSeq): openArray[byte] =
|
||||||
# TODO a double-hash comment here breaks compile (!)
|
# TODO a double-hash comment here breaks compile (!)
|
||||||
v.buf.toOpenArray(v.rpos, v.wpos - 1)
|
v.buf.toOpenArray(v.rpos, v.wpos - 1)
|
||||||
|
|
||||||
|
template toOpenArray*(v: StreamSeq, b, e: int): openArray[byte] =
|
||||||
|
# Data that is ready to be consumed
|
||||||
|
# TODO a double-hash comment here breaks compile (!)
|
||||||
|
v.buf.toOpenArray(v.rpos + b, v.rpos + e - b)
|
||||||
|
|
||||||
func consume*(v: var StreamSeq, n: int) =
|
func consume*(v: var StreamSeq, n: int) =
|
||||||
## Mark `n` bytes that were returned via `data` as consumed
|
## Mark `n` bytes that were returned via `data` as consumed
|
||||||
v.rpos += n
|
v.rpos += n
|
||||||
|
@ -71,3 +76,10 @@ func consumeTo*(v: var StreamSeq, buf: var openArray[byte]): int =
|
||||||
copyMem(addr buf[0], addr v.buf[v.rpos], bytes)
|
copyMem(addr buf[0], addr v.buf[v.rpos], bytes)
|
||||||
v.consume(bytes)
|
v.consume(bytes)
|
||||||
bytes
|
bytes
|
||||||
|
|
||||||
|
func clear*(v: var StreamSeq) =
|
||||||
|
v.consume(v.len)
|
||||||
|
|
||||||
|
func assign*(v: var StreamSeq, buf: openArray[byte]) =
|
||||||
|
v.clear()
|
||||||
|
v.add(buf)
|
||||||
|
|
Loading…
Reference in New Issue