Initial impl

This commit is contained in:
Jazz Turner-Baggs 2025-10-07 18:41:45 -07:00
parent 1c03bd8c7a
commit 520ce1cb80
7 changed files with 474 additions and 0 deletions

6
src/naxolotl.nim Normal file
View File

@ -0,0 +1,6 @@
import naxolotl/[
naxolotl,
curve25519
]
export naxolotl, curve25519

64
src/naxolotl/chacha.nim Normal file
View File

@ -0,0 +1,64 @@
import nim_chacha20_poly1305/[common, chacha20_poly1305, streaming, helpers, poly1305]
import std/[sysrand]
import results
import strformat
import chronicles
import types
import errors
proc encryptWithChaCha20Poly1305*(msgKey: MessageKey, plaintext: var openArray[byte], associatedData: openArray[byte]) : (Nonce, CipherText) =
var nonce : Nonce
discard urandom(nonce)
var tag: Tag
var ciphertext = newSeq[byte](plaintext.len + tag.len)
var counter : Counter = 0
# TODO: check plaintext mutability requirement
chacha20_aead_poly1305_encrypt(
Key(msgKey),
nonce,
counter,
associatedData,
plaintext,
ciphertext.toOpenArray(0, plaintext.high),
tag
)
# Combine tag with cipherkey for ease of transport and consistency with other implementations
copyMem(addr ciphertext[plaintext.len], unsafeAddr tag[0], tag.len)
(nonce, ciphertext)
proc decryptWithChaCha20Poly1305*(msgKey: MessageKey, nonce: Nonce, ciphertext: var openArray[byte], associatedData: openArray[byte]) : Result[seq[byte], NaxolotlError] =
var tag : Tag
if ciphertext.len <= tag.len:
return err(NaxolotlError(code: errInvalidInput, context: fmt"ciphertext is less than {tag.len} bytes. Expected `ciphertext || tag`" ))
copyMem(addr tag[0], unsafeAddr ciphertext[^tag.len], tag.len)
var plaintext = newSeq[byte](ciphertext.len - tag.len)
var computedTag: Tag
var counter : Counter = 0
chacha20_aead_poly1305_decrypt(
Key(msgKey),
nonce,
counter,
associatedData,
plaintext,
ciphertext.toOpenArray(0,ciphertext.high - tag.len),
computedTag
)
if not poly1305_verify(tag, computedTag):
return err(NaxolotlError(code: errMessageAuthentication, context: fmt"Got Tag: {tag} expected: {computedTag}"))
ok(plaintext)

123
src/naxolotl/curve25519.nim Normal file
View File

@ -0,0 +1,123 @@
# See https://github.com/vacp2p/nim-libp2p/blob/master/libp2p/crypto/curve25519.nim
import bearssl/[ec, rand]
import results
from stew/assign2 import assign
export results
const Curve25519KeySize* = 32
type
Curve25519* = object
Curve25519Key* = array[Curve25519KeySize, byte]
Curve25519Error* = enum
Curver25519GenError
proc intoCurve25519Key*(s: openArray[byte]): Curve25519Key =
assert s.len == Curve25519KeySize
assign(result, s)
proc getBytes*(key: Curve25519Key): seq[byte] =
@key
proc byteswap(buf: var Curve25519Key) {.inline.} =
for i in 0 ..< 16:
let x = buf[i]
buf[i] = buf[31 - i]
buf[31 - i] = x
proc mul*(_: type[Curve25519], point: var Curve25519Key, multiplier: Curve25519Key) =
let defaultBrEc = ecGetDefault()
# multiplier needs to be big-endian
var multiplierBs = multiplier
multiplierBs.byteswap()
let res = defaultBrEc.mul(
addr point[0],
Curve25519KeySize,
addr multiplierBs[0],
Curve25519KeySize,
EC_curve25519,
)
assert res == 1
proc mulgen(_: type[Curve25519], dst: var Curve25519Key, point: Curve25519Key) =
let defaultBrEc = ecGetDefault()
var rpoint = point
rpoint.byteswap()
let size =
defaultBrEc.mulgen(addr dst[0], addr rpoint[0], Curve25519KeySize, EC_curve25519)
assert size == Curve25519KeySize
proc public*(private: Curve25519Key): Curve25519Key =
Curve25519.mulgen(result, private)
proc random*(_: type[Curve25519Key], rng: var HmacDrbgContext): Curve25519Key =
var res: Curve25519Key
let defaultBrEc = ecGetDefault()
let len = ecKeygen(
PrngClassPointerConst(addr rng.vtable), defaultBrEc, nil, addr res[0], EC_curve25519
)
# Per bearssl documentation, the keygen only fails if the curve is
# unrecognised -
doAssert len == Curve25519KeySize, "Could not generate curve"
res
const FieldElementSize* = Curve25519KeySize
type FieldElement* = Curve25519Key
# Convert bytes to FieldElement
proc bytesToFieldElement*(bytes: openArray[byte]): Result[FieldElement, string] =
if bytes.len != FieldElementSize:
return err("Field element size must be 32 bytes")
ok(intoCurve25519Key(bytes))
# Convert FieldElement to bytes
proc fieldElementToBytes*(fe: FieldElement): seq[byte] =
fe.getBytes()
# Generate a random FieldElement
proc generateRandomFieldElement*(): Result[FieldElement, string] =
let rng = HmacDrbgContext.new()
if rng.isNil:
return err("Failed to creat HmacDrbgContext with system randomness")
ok(Curve25519Key.random(rng[]))
# Generate a key pair (private key and public key are both FieldElements)
proc generateKeyPair*(): Result[tuple[privateKey, publicKey: FieldElement], string] =
let privateKeyRes = generateRandomFieldElement()
if privateKeyRes.isErr:
return err(privateKeyRes.error)
let privateKey = privateKeyRes.get()
let publicKey = public(privateKey)
ok((privateKey, publicKey))
# # Multiply a given Curve25519 point with a set of scalars
# proc multiplyPointWithScalars*(
# point: FieldElement, scalars: openArray[FieldElement]
# ): FieldElement =
# var res = point
# for scalar in scalars:
# Curve25519.mul(res, scalar)
# res
# # Multiply the Curve25519 base point with a set of scalars
# proc multiplyBasePointWithScalars*(
# scalars: openArray[FieldElement]
# ): Result[FieldElement, string] =
# if scalars.len <= 0:
# return err("Atleast one scalar must be provided")
# var res: FieldElement = public(scalars[0]) # Use the predefined base point
# for i in 1 ..< scalars.len:
# Curve25519.mul(res, scalars[i]) # Multiply with each scalar
# ok(res)
# # Compare two FieldElements
# proc compareFieldElements*(a, b: FieldElement): bool =
# a == b

11
src/naxolotl/errors.nim Normal file
View File

@ -0,0 +1,11 @@
type
NaxolotlError* = object of CatchableError
code*: ErrorCode
context*: string
ErrorCode* = enum
errDecryption
errMessageAuthentication
errInvalidInput
errProgram

197
src/naxolotl/naxolotl.nim Normal file
View File

@ -0,0 +1,197 @@
import curve25519
import results
import chronicles
import nim_chacha20_poly1305/[common,helpers]
import tables
import chacha
import types
import utils
import errors
# converter toArray*(x: RootKey): array[32, byte] =
# array[32,byte](x)
# converter toArray*(x: DhDerivedKey): array[32, byte] =
# array[32,byte](x)
const maxSkip = 10
type Doubleratchet* = object
dhSelf: PrivateKey
dhRemote: PublicKey
rootKey: RootKey
chainKeySend: ChainKey
chainKeyRecv: ChainKey
msgCountSend: MsgCount
msgCountRecv: MsgCount
prevChainLen: MsgCount
# TODO: SkippedKeys
skippedMessageKeys: Table[(PublicKey,MsgCount), MessageKey]
const DomainSepKdfRoot = "DoubleRatchet"
const DomainSepKdfMsg = "MessageKey"
const DomainSepKdfChain = "ChainKey"
type DrHeader* = object
dhPublic: PublicKey
msgNumber: uint32
prevChainLength: uint32
func keyId(dh:PublicKey, recvCount: MsgCount ): KeyId =
(dh, recvCount)
#################################################
# Kdf
#################################################
func kdfRoot(self: var Doubleratchet, rootKey: RootKey, dhOutput:DhDerivedKey): (RootKey, ChainKey) =
var salt = rootKey
var ikm = dhOutput
let info = cast[seq[byte]](DomainSepKdfRoot)
hkdfSplit(salt, ikm, info)
func kdfChain(self: Doubleratchet, chainKey: ChainKey): (MessageKey, ChainKey) =
let msgKey = hkdfExtract(chainKey, [0x01u8], cast[seq[byte]](DomainSepKdfMsg))
let chainKey = hkdfExtract(chainKey, [0x02u8], cast[seq[byte]](DomainSepKdfChain))
return(msgKey, chainKey)
func dhRatchetSend(self: var Doubleratchet) =
# Perform DH Ratchet step when receiving a new peer key.
let dhOutput : DhDerivedKey = dhExchange(self.dhSelf, self.dhRemote).get()
let (newRootKey, newChainKeySend) = kdfRoot(self, self.rootKey, dhOutput)
self.rootKey = newRootKey
self.chainKeySend = newChainKeySend
self.msgCountSend = 0
proc dhRatchetRecv(self: var Doubleratchet, remotePublickey: PublicKey ) =
self.prevChainLen = self.msgCountSend
self.msgCountSend = 0
self.msgCountRecv = 0
self.dhRemote = remotePublickey
let dhOutputPre = self.dhSelf.dhExchange(self.dhRemote).get()
let (newRootKey, newChainKeyRecv) = kdfRoot(self, self.rootKey, dhOutputPre)
self.rootKey = newRootKey
self.chainKeyRecv = newChainKeyRecv
self.dhSelf = generateKeypair().get()[0]
let dhOutputPost = self.dhSelf.dhExchange(self.dhRemote).get()
(self.rootKey, self.chainKeyRecv) = kdfRoot(self, self.rootKey, dhOutputPost)
# let (newRootKey, newChainKeySend) = kdfRoot(self, self.rootKey, dhOutputPost)
# self.rootKey = newRootKey
# self.chainKeyRecv = newChainKeySend
proc skipMessageKeys(self: var Doubleratchet, until: MsgCount): Result[(), string] =
if self.msgCountRecv + maxSkip < until:
return err("Too many skipped messages")
while self.msgCountRecv < until:
let (msgKey, chainKey) = self.kdfChain(self.chainKeyRecv)
self.chainKeyRecv = chainKey
let keyId = keyId(self.dhRemote, self.msgCountRecv)
self.skippedMessageKeys[keyId] = msgKey
inc self.msgCountRecv
ok(())
proc encrypt(self: var Doubleratchet, plaintext: var seq[byte], associatedData: openArray[byte]): (DrHeader, CipherText) =
let (msgKey, chainKey) = self.kdfChain(self.chainKeySend)
let header = DrHeader(
dhPublic: self.dhSelf.public, #TODO Serialize
msgNumber: self.msgCountSend,
prevChainLength: self.prevChainLen)
self.msgCountSend = self.msgCountSend + 1
var (nonce, ciphertext) = encryptWithChaCha20Poly1305(msgKey, plaintext, associatedData)
# TODO: optimize copies
var output : seq[byte]
output.add(nonce)
output.add(ciphertext)
(header, output)
proc decrypt*(self: var Doubleratchet, header: DrHeader, ciphertext: CipherText, associatedData: openArray[byte] ) : Result[seq[byte], NaxolotlError] =
let peerPublic = header.dhPublic
var msgKey : MessageKey
# Check Skipped Keys
let keyId = keyId(header.dhPublic, header.msgNumber)
if self.skippedMessageKeys.hasKey(keyId):
debug "detected skipped message", keyId = keyId
msgKey = self.skippedMessageKeys[keyId]
else:
if (peerPublic != self.dhRemote):
let r = self.skipMessageKeys(header.prevChainLength)
if r.isErr:
error "skipMessages", error = r.error()
self.dhRatchetRecv(peerPublic)
let r = self.skipMessageKeys(header.msgNumber)
if r.isErr:
error "skipMessages", error = r.error()
(msgKey, self.chainKeyRecv) = self.kdfChain(self.chainKeyRecv)
inc self.msgCountRecv
var nonce : Nonce
copyMem(addr nonce[0], unsafeAddr ciphertext[0], Nonce.len)
var cipherTag = ciphertext[Nonce.len .. ^1]
result = decryptWithChaCha20Poly1305(msgKey,nonce, cipherTag, associatedData )
if result.isOk:
# TODO: persist chainKey state changes
self.skippedMessageKeys.del(keyId)
proc encrypt*(self: var Doubleratchet, plaintext: var seq[byte]) : (DrHeader, CipherText) =
encrypt(self, plaintext,@[])
func initDoubleratchet*(sharedSecret: array[32, byte], dhSelf: PrivateKey, dhRemote: PublicKey, isSending: bool = true): Doubleratchet =
result = Doubleratchet(
dhSelf: dhSelf,
dhRemote: dhRemote,
rootKey: RootKey(sharedSecret),
msgCountSend: 0,
msgCountRecv: 0,
prevChainLen: 0,
skippedMessageKeys: initTable[(PublicKey, MsgCount), MessageKey]()
# chainKeySend: none
# chainKeyRecv: none
)
if isSending:
result.dhRatchetSend()

18
src/naxolotl/types.nim Normal file
View File

@ -0,0 +1,18 @@
import curve25519
type PrivateKey* = Curve25519Key
type PublicKey* = Curve25519Key
type RootKey* = array[32, byte]
type ChainKey* = array[32, byte]
type MessageKey* = array[32, byte]
type DhDerivedKey* = array[32, byte]
type GenericArray* = array[32, byte]
type CipherText* = seq[byte]
type MsgCount* = uint32
type KeyId* = (PublicKey, MsgCount)
const KeyLen* = 32

55
src/naxolotl/utils.nim Normal file
View File

@ -0,0 +1,55 @@
import constantine/hashes
import constantine/kdf/kdf_hkdf
import curve25519
import results
import errors
import types
func hkdfExtract*(salt: openArray[byte], ikm: openArray[byte], info: openArray[byte] ) : GenericArray =
assert GenericArray.len == sha256.digestSize()
var ctx{.noInit.}: HKDF[sha256]
var prk{.noInit.}: array[sha256.digestSize(), byte]
ctx.hkdfExtract(prk, salt, ikm)
return prk
func hkdfExtractExpand*(output: var openArray[byte], salt: openArray[byte], ikm: openArray[byte], info: openArray[byte] ) =
var ctx{.noInit.}: HKDF[sha256]
var prk{.noInit.}: array[sha256.digestSize(), byte]
ctx.hkdfExtract(prk, salt, ikm)
ctx.hkdfExpand(output, prk, info, true)
func hkdfSplit*(salt: GenericArray, ikm: GenericArray, info: openArray[byte] ) : (RootKey, ChainKey) =
var output : array[KeyLen*2 , byte]
hkdfExtractExpand(output, salt, ikm, info)
var out1 : array[KeyLen, byte]
var out2 : array[KeyLen, byte]
# Unsafe memcopy
copyMem(addr output[0], unsafeAddr out1[0], KeyLen)
copyMem(addr output[32], unsafeAddr out2[0], KeyLen)
result = (out1,out2)
func dhExchange*(a: PrivateKey, b: PublicKey): Result[DhDerivedKey, NaxolotlError] =
var dhOuput = b
try:
Curve25519.mul(dhOuput, a)
except CatchableError as e:
return err(NaxolotlError( code: errProgram, context: e.msg))
ok(DhDerivedKey(dhOuput))