mirror of
https://github.com/logos-messaging/nim-chat-poc.git
synced 2026-01-02 14:13:10 +00:00
Initial impl
This commit is contained in:
parent
1c03bd8c7a
commit
520ce1cb80
6
src/naxolotl.nim
Normal file
6
src/naxolotl.nim
Normal file
@ -0,0 +1,6 @@
|
||||
import naxolotl/[
|
||||
naxolotl,
|
||||
curve25519
|
||||
]
|
||||
|
||||
export naxolotl, curve25519
|
||||
64
src/naxolotl/chacha.nim
Normal file
64
src/naxolotl/chacha.nim
Normal file
@ -0,0 +1,64 @@
|
||||
import nim_chacha20_poly1305/[common, chacha20_poly1305, streaming, helpers, poly1305]
|
||||
import std/[sysrand]
|
||||
import results
|
||||
import strformat
|
||||
import chronicles
|
||||
|
||||
import types
|
||||
import errors
|
||||
|
||||
|
||||
proc encryptWithChaCha20Poly1305*(msgKey: MessageKey, plaintext: var openArray[byte], associatedData: openArray[byte]) : (Nonce, CipherText) =
|
||||
|
||||
var nonce : Nonce
|
||||
discard urandom(nonce)
|
||||
|
||||
var tag: Tag
|
||||
var ciphertext = newSeq[byte](plaintext.len + tag.len)
|
||||
|
||||
var counter : Counter = 0
|
||||
|
||||
# TODO: check plaintext mutability requirement
|
||||
chacha20_aead_poly1305_encrypt(
|
||||
Key(msgKey),
|
||||
nonce,
|
||||
counter,
|
||||
associatedData,
|
||||
plaintext,
|
||||
ciphertext.toOpenArray(0, plaintext.high),
|
||||
tag
|
||||
)
|
||||
|
||||
# Combine tag with cipherkey for ease of transport and consistency with other implementations
|
||||
copyMem(addr ciphertext[plaintext.len], unsafeAddr tag[0], tag.len)
|
||||
(nonce, ciphertext)
|
||||
|
||||
|
||||
proc decryptWithChaCha20Poly1305*(msgKey: MessageKey, nonce: Nonce, ciphertext: var openArray[byte], associatedData: openArray[byte]) : Result[seq[byte], NaxolotlError] =
|
||||
var tag : Tag
|
||||
if ciphertext.len <= tag.len:
|
||||
return err(NaxolotlError(code: errInvalidInput, context: fmt"ciphertext is less than {tag.len} bytes. Expected `ciphertext || tag`" ))
|
||||
|
||||
copyMem(addr tag[0], unsafeAddr ciphertext[^tag.len], tag.len)
|
||||
|
||||
var plaintext = newSeq[byte](ciphertext.len - tag.len)
|
||||
|
||||
var computedTag: Tag
|
||||
var counter : Counter = 0
|
||||
|
||||
chacha20_aead_poly1305_decrypt(
|
||||
Key(msgKey),
|
||||
nonce,
|
||||
counter,
|
||||
associatedData,
|
||||
plaintext,
|
||||
ciphertext.toOpenArray(0,ciphertext.high - tag.len),
|
||||
computedTag
|
||||
)
|
||||
|
||||
if not poly1305_verify(tag, computedTag):
|
||||
return err(NaxolotlError(code: errMessageAuthentication, context: fmt"Got Tag: {tag} expected: {computedTag}"))
|
||||
|
||||
|
||||
ok(plaintext)
|
||||
|
||||
123
src/naxolotl/curve25519.nim
Normal file
123
src/naxolotl/curve25519.nim
Normal file
@ -0,0 +1,123 @@
|
||||
# See https://github.com/vacp2p/nim-libp2p/blob/master/libp2p/crypto/curve25519.nim
|
||||
|
||||
import bearssl/[ec, rand]
|
||||
import results
|
||||
from stew/assign2 import assign
|
||||
export results
|
||||
|
||||
const Curve25519KeySize* = 32
|
||||
|
||||
type
|
||||
Curve25519* = object
|
||||
Curve25519Key* = array[Curve25519KeySize, byte]
|
||||
Curve25519Error* = enum
|
||||
Curver25519GenError
|
||||
|
||||
proc intoCurve25519Key*(s: openArray[byte]): Curve25519Key =
|
||||
assert s.len == Curve25519KeySize
|
||||
assign(result, s)
|
||||
|
||||
proc getBytes*(key: Curve25519Key): seq[byte] =
|
||||
@key
|
||||
|
||||
proc byteswap(buf: var Curve25519Key) {.inline.} =
|
||||
for i in 0 ..< 16:
|
||||
let x = buf[i]
|
||||
buf[i] = buf[31 - i]
|
||||
buf[31 - i] = x
|
||||
|
||||
proc mul*(_: type[Curve25519], point: var Curve25519Key, multiplier: Curve25519Key) =
|
||||
let defaultBrEc = ecGetDefault()
|
||||
|
||||
# multiplier needs to be big-endian
|
||||
var multiplierBs = multiplier
|
||||
multiplierBs.byteswap()
|
||||
let res = defaultBrEc.mul(
|
||||
addr point[0],
|
||||
Curve25519KeySize,
|
||||
addr multiplierBs[0],
|
||||
Curve25519KeySize,
|
||||
EC_curve25519,
|
||||
)
|
||||
assert res == 1
|
||||
|
||||
proc mulgen(_: type[Curve25519], dst: var Curve25519Key, point: Curve25519Key) =
|
||||
let defaultBrEc = ecGetDefault()
|
||||
|
||||
var rpoint = point
|
||||
rpoint.byteswap()
|
||||
|
||||
let size =
|
||||
defaultBrEc.mulgen(addr dst[0], addr rpoint[0], Curve25519KeySize, EC_curve25519)
|
||||
|
||||
assert size == Curve25519KeySize
|
||||
|
||||
proc public*(private: Curve25519Key): Curve25519Key =
|
||||
Curve25519.mulgen(result, private)
|
||||
|
||||
proc random*(_: type[Curve25519Key], rng: var HmacDrbgContext): Curve25519Key =
|
||||
var res: Curve25519Key
|
||||
let defaultBrEc = ecGetDefault()
|
||||
let len = ecKeygen(
|
||||
PrngClassPointerConst(addr rng.vtable), defaultBrEc, nil, addr res[0], EC_curve25519
|
||||
)
|
||||
# Per bearssl documentation, the keygen only fails if the curve is
|
||||
# unrecognised -
|
||||
doAssert len == Curve25519KeySize, "Could not generate curve"
|
||||
|
||||
res
|
||||
|
||||
const FieldElementSize* = Curve25519KeySize
|
||||
|
||||
type FieldElement* = Curve25519Key
|
||||
|
||||
# Convert bytes to FieldElement
|
||||
proc bytesToFieldElement*(bytes: openArray[byte]): Result[FieldElement, string] =
|
||||
if bytes.len != FieldElementSize:
|
||||
return err("Field element size must be 32 bytes")
|
||||
ok(intoCurve25519Key(bytes))
|
||||
|
||||
# Convert FieldElement to bytes
|
||||
proc fieldElementToBytes*(fe: FieldElement): seq[byte] =
|
||||
fe.getBytes()
|
||||
|
||||
# Generate a random FieldElement
|
||||
proc generateRandomFieldElement*(): Result[FieldElement, string] =
|
||||
let rng = HmacDrbgContext.new()
|
||||
if rng.isNil:
|
||||
return err("Failed to creat HmacDrbgContext with system randomness")
|
||||
ok(Curve25519Key.random(rng[]))
|
||||
|
||||
# Generate a key pair (private key and public key are both FieldElements)
|
||||
proc generateKeyPair*(): Result[tuple[privateKey, publicKey: FieldElement], string] =
|
||||
let privateKeyRes = generateRandomFieldElement()
|
||||
if privateKeyRes.isErr:
|
||||
return err(privateKeyRes.error)
|
||||
let privateKey = privateKeyRes.get()
|
||||
|
||||
let publicKey = public(privateKey)
|
||||
ok((privateKey, publicKey))
|
||||
|
||||
# # Multiply a given Curve25519 point with a set of scalars
|
||||
# proc multiplyPointWithScalars*(
|
||||
# point: FieldElement, scalars: openArray[FieldElement]
|
||||
# ): FieldElement =
|
||||
# var res = point
|
||||
# for scalar in scalars:
|
||||
# Curve25519.mul(res, scalar)
|
||||
# res
|
||||
|
||||
# # Multiply the Curve25519 base point with a set of scalars
|
||||
# proc multiplyBasePointWithScalars*(
|
||||
# scalars: openArray[FieldElement]
|
||||
# ): Result[FieldElement, string] =
|
||||
# if scalars.len <= 0:
|
||||
# return err("Atleast one scalar must be provided")
|
||||
# var res: FieldElement = public(scalars[0]) # Use the predefined base point
|
||||
# for i in 1 ..< scalars.len:
|
||||
# Curve25519.mul(res, scalars[i]) # Multiply with each scalar
|
||||
# ok(res)
|
||||
|
||||
# # Compare two FieldElements
|
||||
# proc compareFieldElements*(a, b: FieldElement): bool =
|
||||
# a == b
|
||||
11
src/naxolotl/errors.nim
Normal file
11
src/naxolotl/errors.nim
Normal file
@ -0,0 +1,11 @@
|
||||
|
||||
type
|
||||
NaxolotlError* = object of CatchableError
|
||||
code*: ErrorCode
|
||||
context*: string
|
||||
|
||||
ErrorCode* = enum
|
||||
errDecryption
|
||||
errMessageAuthentication
|
||||
errInvalidInput
|
||||
errProgram
|
||||
197
src/naxolotl/naxolotl.nim
Normal file
197
src/naxolotl/naxolotl.nim
Normal file
@ -0,0 +1,197 @@
|
||||
import curve25519
|
||||
import results
|
||||
import chronicles
|
||||
import nim_chacha20_poly1305/[common,helpers]
|
||||
|
||||
import tables
|
||||
|
||||
import chacha
|
||||
import types
|
||||
import utils
|
||||
import errors
|
||||
|
||||
# converter toArray*(x: RootKey): array[32, byte] =
|
||||
# array[32,byte](x)
|
||||
|
||||
# converter toArray*(x: DhDerivedKey): array[32, byte] =
|
||||
# array[32,byte](x)
|
||||
|
||||
|
||||
|
||||
const maxSkip = 10
|
||||
|
||||
|
||||
|
||||
type Doubleratchet* = object
|
||||
dhSelf: PrivateKey
|
||||
dhRemote: PublicKey
|
||||
|
||||
rootKey: RootKey
|
||||
chainKeySend: ChainKey
|
||||
chainKeyRecv: ChainKey
|
||||
|
||||
msgCountSend: MsgCount
|
||||
msgCountRecv: MsgCount
|
||||
prevChainLen: MsgCount
|
||||
|
||||
# TODO: SkippedKeys
|
||||
skippedMessageKeys: Table[(PublicKey,MsgCount), MessageKey]
|
||||
|
||||
const DomainSepKdfRoot = "DoubleRatchet"
|
||||
const DomainSepKdfMsg = "MessageKey"
|
||||
const DomainSepKdfChain = "ChainKey"
|
||||
|
||||
|
||||
|
||||
type DrHeader* = object
|
||||
dhPublic: PublicKey
|
||||
msgNumber: uint32
|
||||
prevChainLength: uint32
|
||||
|
||||
|
||||
|
||||
func keyId(dh:PublicKey, recvCount: MsgCount ): KeyId =
|
||||
(dh, recvCount)
|
||||
|
||||
#################################################
|
||||
# Kdf
|
||||
#################################################
|
||||
|
||||
|
||||
func kdfRoot(self: var Doubleratchet, rootKey: RootKey, dhOutput:DhDerivedKey): (RootKey, ChainKey) =
|
||||
|
||||
var salt = rootKey
|
||||
var ikm = dhOutput
|
||||
let info = cast[seq[byte]](DomainSepKdfRoot)
|
||||
|
||||
hkdfSplit(salt, ikm, info)
|
||||
|
||||
func kdfChain(self: Doubleratchet, chainKey: ChainKey): (MessageKey, ChainKey) =
|
||||
|
||||
let msgKey = hkdfExtract(chainKey, [0x01u8], cast[seq[byte]](DomainSepKdfMsg))
|
||||
let chainKey = hkdfExtract(chainKey, [0x02u8], cast[seq[byte]](DomainSepKdfChain))
|
||||
|
||||
return(msgKey, chainKey)
|
||||
|
||||
func dhRatchetSend(self: var Doubleratchet) =
|
||||
# Perform DH Ratchet step when receiving a new peer key.
|
||||
let dhOutput : DhDerivedKey = dhExchange(self.dhSelf, self.dhRemote).get()
|
||||
let (newRootKey, newChainKeySend) = kdfRoot(self, self.rootKey, dhOutput)
|
||||
self.rootKey = newRootKey
|
||||
self.chainKeySend = newChainKeySend
|
||||
self.msgCountSend = 0
|
||||
|
||||
proc dhRatchetRecv(self: var Doubleratchet, remotePublickey: PublicKey ) =
|
||||
self.prevChainLen = self.msgCountSend
|
||||
self.msgCountSend = 0
|
||||
self.msgCountRecv = 0
|
||||
|
||||
self.dhRemote = remotePublickey
|
||||
|
||||
let dhOutputPre = self.dhSelf.dhExchange(self.dhRemote).get()
|
||||
let (newRootKey, newChainKeyRecv) = kdfRoot(self, self.rootKey, dhOutputPre)
|
||||
self.rootKey = newRootKey
|
||||
self.chainKeyRecv = newChainKeyRecv
|
||||
|
||||
self.dhSelf = generateKeypair().get()[0]
|
||||
|
||||
let dhOutputPost = self.dhSelf.dhExchange(self.dhRemote).get()
|
||||
(self.rootKey, self.chainKeyRecv) = kdfRoot(self, self.rootKey, dhOutputPost)
|
||||
# let (newRootKey, newChainKeySend) = kdfRoot(self, self.rootKey, dhOutputPost)
|
||||
# self.rootKey = newRootKey
|
||||
# self.chainKeyRecv = newChainKeySend
|
||||
|
||||
|
||||
proc skipMessageKeys(self: var Doubleratchet, until: MsgCount): Result[(), string] =
|
||||
|
||||
if self.msgCountRecv + maxSkip < until:
|
||||
return err("Too many skipped messages")
|
||||
|
||||
while self.msgCountRecv < until:
|
||||
let (msgKey, chainKey) = self.kdfChain(self.chainKeyRecv)
|
||||
self.chainKeyRecv = chainKey
|
||||
|
||||
let keyId = keyId(self.dhRemote, self.msgCountRecv)
|
||||
self.skippedMessageKeys[keyId] = msgKey
|
||||
inc self.msgCountRecv
|
||||
|
||||
ok(())
|
||||
|
||||
proc encrypt(self: var Doubleratchet, plaintext: var seq[byte], associatedData: openArray[byte]): (DrHeader, CipherText) =
|
||||
|
||||
let (msgKey, chainKey) = self.kdfChain(self.chainKeySend)
|
||||
|
||||
let header = DrHeader(
|
||||
dhPublic: self.dhSelf.public, #TODO Serialize
|
||||
msgNumber: self.msgCountSend,
|
||||
prevChainLength: self.prevChainLen)
|
||||
|
||||
self.msgCountSend = self.msgCountSend + 1
|
||||
|
||||
|
||||
var (nonce, ciphertext) = encryptWithChaCha20Poly1305(msgKey, plaintext, associatedData)
|
||||
|
||||
# TODO: optimize copies
|
||||
var output : seq[byte]
|
||||
output.add(nonce)
|
||||
output.add(ciphertext)
|
||||
|
||||
|
||||
(header, output)
|
||||
|
||||
|
||||
proc decrypt*(self: var Doubleratchet, header: DrHeader, ciphertext: CipherText, associatedData: openArray[byte] ) : Result[seq[byte], NaxolotlError] =
|
||||
|
||||
let peerPublic = header.dhPublic
|
||||
|
||||
var msgKey : MessageKey
|
||||
|
||||
# Check Skipped Keys
|
||||
let keyId = keyId(header.dhPublic, header.msgNumber)
|
||||
if self.skippedMessageKeys.hasKey(keyId):
|
||||
debug "detected skipped message", keyId = keyId
|
||||
msgKey = self.skippedMessageKeys[keyId]
|
||||
else:
|
||||
if (peerPublic != self.dhRemote):
|
||||
let r = self.skipMessageKeys(header.prevChainLength)
|
||||
if r.isErr:
|
||||
error "skipMessages", error = r.error()
|
||||
self.dhRatchetRecv(peerPublic)
|
||||
let r = self.skipMessageKeys(header.msgNumber)
|
||||
if r.isErr:
|
||||
error "skipMessages", error = r.error()
|
||||
|
||||
(msgKey, self.chainKeyRecv) = self.kdfChain(self.chainKeyRecv)
|
||||
inc self.msgCountRecv
|
||||
|
||||
var nonce : Nonce
|
||||
copyMem(addr nonce[0], unsafeAddr ciphertext[0], Nonce.len)
|
||||
var cipherTag = ciphertext[Nonce.len .. ^1]
|
||||
|
||||
result = decryptWithChaCha20Poly1305(msgKey,nonce, cipherTag, associatedData )
|
||||
|
||||
if result.isOk:
|
||||
# TODO: persist chainKey state changes
|
||||
self.skippedMessageKeys.del(keyId)
|
||||
|
||||
|
||||
proc encrypt*(self: var Doubleratchet, plaintext: var seq[byte]) : (DrHeader, CipherText) =
|
||||
encrypt(self, plaintext,@[])
|
||||
|
||||
|
||||
func initDoubleratchet*(sharedSecret: array[32, byte], dhSelf: PrivateKey, dhRemote: PublicKey, isSending: bool = true): Doubleratchet =
|
||||
|
||||
result = Doubleratchet(
|
||||
dhSelf: dhSelf,
|
||||
dhRemote: dhRemote,
|
||||
rootKey: RootKey(sharedSecret),
|
||||
msgCountSend: 0,
|
||||
msgCountRecv: 0,
|
||||
prevChainLen: 0,
|
||||
skippedMessageKeys: initTable[(PublicKey, MsgCount), MessageKey]()
|
||||
# chainKeySend: none
|
||||
# chainKeyRecv: none
|
||||
)
|
||||
|
||||
if isSending:
|
||||
result.dhRatchetSend()
|
||||
18
src/naxolotl/types.nim
Normal file
18
src/naxolotl/types.nim
Normal file
@ -0,0 +1,18 @@
|
||||
import curve25519
|
||||
|
||||
type PrivateKey* = Curve25519Key
|
||||
type PublicKey* = Curve25519Key
|
||||
|
||||
type RootKey* = array[32, byte]
|
||||
type ChainKey* = array[32, byte]
|
||||
type MessageKey* = array[32, byte]
|
||||
type DhDerivedKey* = array[32, byte]
|
||||
|
||||
type GenericArray* = array[32, byte]
|
||||
|
||||
type CipherText* = seq[byte]
|
||||
|
||||
type MsgCount* = uint32
|
||||
type KeyId* = (PublicKey, MsgCount)
|
||||
|
||||
const KeyLen* = 32
|
||||
55
src/naxolotl/utils.nim
Normal file
55
src/naxolotl/utils.nim
Normal file
@ -0,0 +1,55 @@
|
||||
import constantine/hashes
|
||||
import constantine/kdf/kdf_hkdf
|
||||
import curve25519
|
||||
import results
|
||||
|
||||
import errors
|
||||
import types
|
||||
|
||||
|
||||
func hkdfExtract*(salt: openArray[byte], ikm: openArray[byte], info: openArray[byte] ) : GenericArray =
|
||||
|
||||
assert GenericArray.len == sha256.digestSize()
|
||||
|
||||
var ctx{.noInit.}: HKDF[sha256]
|
||||
var prk{.noInit.}: array[sha256.digestSize(), byte]
|
||||
ctx.hkdfExtract(prk, salt, ikm)
|
||||
|
||||
return prk
|
||||
|
||||
|
||||
|
||||
func hkdfExtractExpand*(output: var openArray[byte], salt: openArray[byte], ikm: openArray[byte], info: openArray[byte] ) =
|
||||
var ctx{.noInit.}: HKDF[sha256]
|
||||
var prk{.noInit.}: array[sha256.digestSize(), byte]
|
||||
ctx.hkdfExtract(prk, salt, ikm)
|
||||
ctx.hkdfExpand(output, prk, info, true)
|
||||
|
||||
|
||||
func hkdfSplit*(salt: GenericArray, ikm: GenericArray, info: openArray[byte] ) : (RootKey, ChainKey) =
|
||||
|
||||
var output : array[KeyLen*2 , byte]
|
||||
|
||||
hkdfExtractExpand(output, salt, ikm, info)
|
||||
|
||||
var out1 : array[KeyLen, byte]
|
||||
var out2 : array[KeyLen, byte]
|
||||
|
||||
# Unsafe memcopy
|
||||
copyMem(addr output[0], unsafeAddr out1[0], KeyLen)
|
||||
copyMem(addr output[32], unsafeAddr out2[0], KeyLen)
|
||||
|
||||
result = (out1,out2)
|
||||
|
||||
|
||||
|
||||
func dhExchange*(a: PrivateKey, b: PublicKey): Result[DhDerivedKey, NaxolotlError] =
|
||||
var dhOuput = b
|
||||
|
||||
try:
|
||||
Curve25519.mul(dhOuput, a)
|
||||
except CatchableError as e:
|
||||
return err(NaxolotlError( code: errProgram, context: e.msg))
|
||||
ok(DhDerivedKey(dhOuput))
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user