From 520ce1cb8094dd1414616053d28ce212bf46fc2a Mon Sep 17 00:00:00 2001 From: Jazz Turner-Baggs <473256+jazzz@users.noreply.github.com> Date: Tue, 7 Oct 2025 18:41:45 -0700 Subject: [PATCH] Initial impl --- src/naxolotl.nim | 6 ++ src/naxolotl/chacha.nim | 64 ++++++++++++ src/naxolotl/curve25519.nim | 123 ++++++++++++++++++++++ src/naxolotl/errors.nim | 11 ++ src/naxolotl/naxolotl.nim | 197 ++++++++++++++++++++++++++++++++++++ src/naxolotl/types.nim | 18 ++++ src/naxolotl/utils.nim | 55 ++++++++++ 7 files changed, 474 insertions(+) create mode 100644 src/naxolotl.nim create mode 100644 src/naxolotl/chacha.nim create mode 100644 src/naxolotl/curve25519.nim create mode 100644 src/naxolotl/errors.nim create mode 100644 src/naxolotl/naxolotl.nim create mode 100644 src/naxolotl/types.nim create mode 100644 src/naxolotl/utils.nim diff --git a/src/naxolotl.nim b/src/naxolotl.nim new file mode 100644 index 0000000..7880e6e --- /dev/null +++ b/src/naxolotl.nim @@ -0,0 +1,6 @@ +import naxolotl/[ + naxolotl, + curve25519 +] + +export naxolotl, curve25519 \ No newline at end of file diff --git a/src/naxolotl/chacha.nim b/src/naxolotl/chacha.nim new file mode 100644 index 0000000..d0d8add --- /dev/null +++ b/src/naxolotl/chacha.nim @@ -0,0 +1,64 @@ +import nim_chacha20_poly1305/[common, chacha20_poly1305, streaming, helpers, poly1305] +import std/[sysrand] +import results +import strformat +import chronicles + +import types +import errors + + +proc encryptWithChaCha20Poly1305*(msgKey: MessageKey, plaintext: var openArray[byte], associatedData: openArray[byte]) : (Nonce, CipherText) = + + var nonce : Nonce + discard urandom(nonce) + + var tag: Tag + var ciphertext = newSeq[byte](plaintext.len + tag.len) + + var counter : Counter = 0 + + # TODO: check plaintext mutability requirement + chacha20_aead_poly1305_encrypt( + Key(msgKey), + nonce, + counter, + associatedData, + plaintext, + ciphertext.toOpenArray(0, plaintext.high), + tag + ) + + # Combine tag with cipherkey for ease of transport and consistency with other implementations + copyMem(addr ciphertext[plaintext.len], unsafeAddr tag[0], tag.len) + (nonce, ciphertext) + + +proc decryptWithChaCha20Poly1305*(msgKey: MessageKey, nonce: Nonce, ciphertext: var openArray[byte], associatedData: openArray[byte]) : Result[seq[byte], NaxolotlError] = + var tag : Tag + if ciphertext.len <= tag.len: + return err(NaxolotlError(code: errInvalidInput, context: fmt"ciphertext is less than {tag.len} bytes. Expected `ciphertext || tag`" )) + + copyMem(addr tag[0], unsafeAddr ciphertext[^tag.len], tag.len) + + var plaintext = newSeq[byte](ciphertext.len - tag.len) + + var computedTag: Tag + var counter : Counter = 0 + + chacha20_aead_poly1305_decrypt( + Key(msgKey), + nonce, + counter, + associatedData, + plaintext, + ciphertext.toOpenArray(0,ciphertext.high - tag.len), + computedTag + ) + + if not poly1305_verify(tag, computedTag): + return err(NaxolotlError(code: errMessageAuthentication, context: fmt"Got Tag: {tag} expected: {computedTag}")) + + + ok(plaintext) + \ No newline at end of file diff --git a/src/naxolotl/curve25519.nim b/src/naxolotl/curve25519.nim new file mode 100644 index 0000000..140b928 --- /dev/null +++ b/src/naxolotl/curve25519.nim @@ -0,0 +1,123 @@ +# See https://github.com/vacp2p/nim-libp2p/blob/master/libp2p/crypto/curve25519.nim + +import bearssl/[ec, rand] +import results +from stew/assign2 import assign +export results + +const Curve25519KeySize* = 32 + +type + Curve25519* = object + Curve25519Key* = array[Curve25519KeySize, byte] + Curve25519Error* = enum + Curver25519GenError + +proc intoCurve25519Key*(s: openArray[byte]): Curve25519Key = + assert s.len == Curve25519KeySize + assign(result, s) + +proc getBytes*(key: Curve25519Key): seq[byte] = + @key + +proc byteswap(buf: var Curve25519Key) {.inline.} = + for i in 0 ..< 16: + let x = buf[i] + buf[i] = buf[31 - i] + buf[31 - i] = x + +proc mul*(_: type[Curve25519], point: var Curve25519Key, multiplier: Curve25519Key) = + let defaultBrEc = ecGetDefault() + + # multiplier needs to be big-endian + var multiplierBs = multiplier + multiplierBs.byteswap() + let res = defaultBrEc.mul( + addr point[0], + Curve25519KeySize, + addr multiplierBs[0], + Curve25519KeySize, + EC_curve25519, + ) + assert res == 1 + +proc mulgen(_: type[Curve25519], dst: var Curve25519Key, point: Curve25519Key) = + let defaultBrEc = ecGetDefault() + + var rpoint = point + rpoint.byteswap() + + let size = + defaultBrEc.mulgen(addr dst[0], addr rpoint[0], Curve25519KeySize, EC_curve25519) + + assert size == Curve25519KeySize + +proc public*(private: Curve25519Key): Curve25519Key = + Curve25519.mulgen(result, private) + +proc random*(_: type[Curve25519Key], rng: var HmacDrbgContext): Curve25519Key = + var res: Curve25519Key + let defaultBrEc = ecGetDefault() + let len = ecKeygen( + PrngClassPointerConst(addr rng.vtable), defaultBrEc, nil, addr res[0], EC_curve25519 + ) + # Per bearssl documentation, the keygen only fails if the curve is + # unrecognised - + doAssert len == Curve25519KeySize, "Could not generate curve" + + res + +const FieldElementSize* = Curve25519KeySize + +type FieldElement* = Curve25519Key + +# Convert bytes to FieldElement +proc bytesToFieldElement*(bytes: openArray[byte]): Result[FieldElement, string] = + if bytes.len != FieldElementSize: + return err("Field element size must be 32 bytes") + ok(intoCurve25519Key(bytes)) + +# Convert FieldElement to bytes +proc fieldElementToBytes*(fe: FieldElement): seq[byte] = + fe.getBytes() + +# Generate a random FieldElement +proc generateRandomFieldElement*(): Result[FieldElement, string] = + let rng = HmacDrbgContext.new() + if rng.isNil: + return err("Failed to creat HmacDrbgContext with system randomness") + ok(Curve25519Key.random(rng[])) + +# Generate a key pair (private key and public key are both FieldElements) +proc generateKeyPair*(): Result[tuple[privateKey, publicKey: FieldElement], string] = + let privateKeyRes = generateRandomFieldElement() + if privateKeyRes.isErr: + return err(privateKeyRes.error) + let privateKey = privateKeyRes.get() + + let publicKey = public(privateKey) + ok((privateKey, publicKey)) + +# # Multiply a given Curve25519 point with a set of scalars +# proc multiplyPointWithScalars*( +# point: FieldElement, scalars: openArray[FieldElement] +# ): FieldElement = +# var res = point +# for scalar in scalars: +# Curve25519.mul(res, scalar) +# res + +# # Multiply the Curve25519 base point with a set of scalars +# proc multiplyBasePointWithScalars*( +# scalars: openArray[FieldElement] +# ): Result[FieldElement, string] = +# if scalars.len <= 0: +# return err("Atleast one scalar must be provided") +# var res: FieldElement = public(scalars[0]) # Use the predefined base point +# for i in 1 ..< scalars.len: +# Curve25519.mul(res, scalars[i]) # Multiply with each scalar +# ok(res) + +# # Compare two FieldElements +# proc compareFieldElements*(a, b: FieldElement): bool = +# a == b diff --git a/src/naxolotl/errors.nim b/src/naxolotl/errors.nim new file mode 100644 index 0000000..57ac784 --- /dev/null +++ b/src/naxolotl/errors.nim @@ -0,0 +1,11 @@ + +type + NaxolotlError* = object of CatchableError + code*: ErrorCode + context*: string + + ErrorCode* = enum + errDecryption + errMessageAuthentication + errInvalidInput + errProgram \ No newline at end of file diff --git a/src/naxolotl/naxolotl.nim b/src/naxolotl/naxolotl.nim new file mode 100644 index 0000000..d0f9ff6 --- /dev/null +++ b/src/naxolotl/naxolotl.nim @@ -0,0 +1,197 @@ +import curve25519 +import results +import chronicles +import nim_chacha20_poly1305/[common,helpers] + +import tables + +import chacha +import types +import utils +import errors + +# converter toArray*(x: RootKey): array[32, byte] = +# array[32,byte](x) + +# converter toArray*(x: DhDerivedKey): array[32, byte] = +# array[32,byte](x) + + + +const maxSkip = 10 + + + +type Doubleratchet* = object + dhSelf: PrivateKey + dhRemote: PublicKey + + rootKey: RootKey + chainKeySend: ChainKey + chainKeyRecv: ChainKey + + msgCountSend: MsgCount + msgCountRecv: MsgCount + prevChainLen: MsgCount + + # TODO: SkippedKeys + skippedMessageKeys: Table[(PublicKey,MsgCount), MessageKey] + +const DomainSepKdfRoot = "DoubleRatchet" +const DomainSepKdfMsg = "MessageKey" +const DomainSepKdfChain = "ChainKey" + + + +type DrHeader* = object + dhPublic: PublicKey + msgNumber: uint32 + prevChainLength: uint32 + + + +func keyId(dh:PublicKey, recvCount: MsgCount ): KeyId = + (dh, recvCount) + +################################################# +# Kdf +################################################# + + +func kdfRoot(self: var Doubleratchet, rootKey: RootKey, dhOutput:DhDerivedKey): (RootKey, ChainKey) = + + var salt = rootKey + var ikm = dhOutput + let info = cast[seq[byte]](DomainSepKdfRoot) + + hkdfSplit(salt, ikm, info) + +func kdfChain(self: Doubleratchet, chainKey: ChainKey): (MessageKey, ChainKey) = + + let msgKey = hkdfExtract(chainKey, [0x01u8], cast[seq[byte]](DomainSepKdfMsg)) + let chainKey = hkdfExtract(chainKey, [0x02u8], cast[seq[byte]](DomainSepKdfChain)) + + return(msgKey, chainKey) + +func dhRatchetSend(self: var Doubleratchet) = + # Perform DH Ratchet step when receiving a new peer key. + let dhOutput : DhDerivedKey = dhExchange(self.dhSelf, self.dhRemote).get() + let (newRootKey, newChainKeySend) = kdfRoot(self, self.rootKey, dhOutput) + self.rootKey = newRootKey + self.chainKeySend = newChainKeySend + self.msgCountSend = 0 + +proc dhRatchetRecv(self: var Doubleratchet, remotePublickey: PublicKey ) = + self.prevChainLen = self.msgCountSend + self.msgCountSend = 0 + self.msgCountRecv = 0 + + self.dhRemote = remotePublickey + + let dhOutputPre = self.dhSelf.dhExchange(self.dhRemote).get() + let (newRootKey, newChainKeyRecv) = kdfRoot(self, self.rootKey, dhOutputPre) + self.rootKey = newRootKey + self.chainKeyRecv = newChainKeyRecv + + self.dhSelf = generateKeypair().get()[0] + + let dhOutputPost = self.dhSelf.dhExchange(self.dhRemote).get() + (self.rootKey, self.chainKeyRecv) = kdfRoot(self, self.rootKey, dhOutputPost) + # let (newRootKey, newChainKeySend) = kdfRoot(self, self.rootKey, dhOutputPost) + # self.rootKey = newRootKey + # self.chainKeyRecv = newChainKeySend + + +proc skipMessageKeys(self: var Doubleratchet, until: MsgCount): Result[(), string] = + + if self.msgCountRecv + maxSkip < until: + return err("Too many skipped messages") + + while self.msgCountRecv < until: + let (msgKey, chainKey) = self.kdfChain(self.chainKeyRecv) + self.chainKeyRecv = chainKey + + let keyId = keyId(self.dhRemote, self.msgCountRecv) + self.skippedMessageKeys[keyId] = msgKey + inc self.msgCountRecv + + ok(()) + +proc encrypt(self: var Doubleratchet, plaintext: var seq[byte], associatedData: openArray[byte]): (DrHeader, CipherText) = + + let (msgKey, chainKey) = self.kdfChain(self.chainKeySend) + + let header = DrHeader( + dhPublic: self.dhSelf.public, #TODO Serialize + msgNumber: self.msgCountSend, + prevChainLength: self.prevChainLen) + + self.msgCountSend = self.msgCountSend + 1 + + + var (nonce, ciphertext) = encryptWithChaCha20Poly1305(msgKey, plaintext, associatedData) + + # TODO: optimize copies + var output : seq[byte] + output.add(nonce) + output.add(ciphertext) + + + (header, output) + + +proc decrypt*(self: var Doubleratchet, header: DrHeader, ciphertext: CipherText, associatedData: openArray[byte] ) : Result[seq[byte], NaxolotlError] = + + let peerPublic = header.dhPublic + + var msgKey : MessageKey + + # Check Skipped Keys + let keyId = keyId(header.dhPublic, header.msgNumber) + if self.skippedMessageKeys.hasKey(keyId): + debug "detected skipped message", keyId = keyId + msgKey = self.skippedMessageKeys[keyId] + else: + if (peerPublic != self.dhRemote): + let r = self.skipMessageKeys(header.prevChainLength) + if r.isErr: + error "skipMessages", error = r.error() + self.dhRatchetRecv(peerPublic) + let r = self.skipMessageKeys(header.msgNumber) + if r.isErr: + error "skipMessages", error = r.error() + + (msgKey, self.chainKeyRecv) = self.kdfChain(self.chainKeyRecv) + inc self.msgCountRecv + + var nonce : Nonce + copyMem(addr nonce[0], unsafeAddr ciphertext[0], Nonce.len) + var cipherTag = ciphertext[Nonce.len .. ^1] + + result = decryptWithChaCha20Poly1305(msgKey,nonce, cipherTag, associatedData ) + + if result.isOk: + # TODO: persist chainKey state changes + self.skippedMessageKeys.del(keyId) + + +proc encrypt*(self: var Doubleratchet, plaintext: var seq[byte]) : (DrHeader, CipherText) = + encrypt(self, plaintext,@[]) + + +func initDoubleratchet*(sharedSecret: array[32, byte], dhSelf: PrivateKey, dhRemote: PublicKey, isSending: bool = true): Doubleratchet = + + result = Doubleratchet( + dhSelf: dhSelf, + dhRemote: dhRemote, + rootKey: RootKey(sharedSecret), + msgCountSend: 0, + msgCountRecv: 0, + prevChainLen: 0, + skippedMessageKeys: initTable[(PublicKey, MsgCount), MessageKey]() + # chainKeySend: none + # chainKeyRecv: none + ) + + if isSending: + result.dhRatchetSend() diff --git a/src/naxolotl/types.nim b/src/naxolotl/types.nim new file mode 100644 index 0000000..a4504f4 --- /dev/null +++ b/src/naxolotl/types.nim @@ -0,0 +1,18 @@ +import curve25519 + +type PrivateKey* = Curve25519Key +type PublicKey* = Curve25519Key + +type RootKey* = array[32, byte] +type ChainKey* = array[32, byte] +type MessageKey* = array[32, byte] +type DhDerivedKey* = array[32, byte] + +type GenericArray* = array[32, byte] + +type CipherText* = seq[byte] + +type MsgCount* = uint32 +type KeyId* = (PublicKey, MsgCount) + +const KeyLen* = 32 \ No newline at end of file diff --git a/src/naxolotl/utils.nim b/src/naxolotl/utils.nim new file mode 100644 index 0000000..9b6e834 --- /dev/null +++ b/src/naxolotl/utils.nim @@ -0,0 +1,55 @@ +import constantine/hashes +import constantine/kdf/kdf_hkdf +import curve25519 +import results + +import errors +import types + + +func hkdfExtract*(salt: openArray[byte], ikm: openArray[byte], info: openArray[byte] ) : GenericArray = + + assert GenericArray.len == sha256.digestSize() + + var ctx{.noInit.}: HKDF[sha256] + var prk{.noInit.}: array[sha256.digestSize(), byte] + ctx.hkdfExtract(prk, salt, ikm) + + return prk + + + +func hkdfExtractExpand*(output: var openArray[byte], salt: openArray[byte], ikm: openArray[byte], info: openArray[byte] ) = + var ctx{.noInit.}: HKDF[sha256] + var prk{.noInit.}: array[sha256.digestSize(), byte] + ctx.hkdfExtract(prk, salt, ikm) + ctx.hkdfExpand(output, prk, info, true) + + +func hkdfSplit*(salt: GenericArray, ikm: GenericArray, info: openArray[byte] ) : (RootKey, ChainKey) = + + var output : array[KeyLen*2 , byte] + + hkdfExtractExpand(output, salt, ikm, info) + + var out1 : array[KeyLen, byte] + var out2 : array[KeyLen, byte] + + # Unsafe memcopy + copyMem(addr output[0], unsafeAddr out1[0], KeyLen) + copyMem(addr output[32], unsafeAddr out2[0], KeyLen) + + result = (out1,out2) + + + +func dhExchange*(a: PrivateKey, b: PublicKey): Result[DhDerivedKey, NaxolotlError] = + var dhOuput = b + + try: + Curve25519.mul(dhOuput, a) + except CatchableError as e: + return err(NaxolotlError( code: errProgram, context: e.msg)) + ok(DhDerivedKey(dhOuput)) + +