2025-10-07 18:41:45 -07:00
|
|
|
import curve25519
|
|
|
|
|
import results
|
|
|
|
|
import chronicles
|
2025-11-20 16:02:57 -08:00
|
|
|
import nim_chacha20_poly1305/common
|
2025-10-09 17:07:50 -07:00
|
|
|
import strformat
|
|
|
|
|
import strutils
|
|
|
|
|
import sequtils
|
2025-10-07 18:41:45 -07:00
|
|
|
import tables
|
|
|
|
|
|
|
|
|
|
import chacha
|
|
|
|
|
import types
|
|
|
|
|
import utils
|
|
|
|
|
import errors
|
|
|
|
|
|
|
|
|
|
# Upper bound on how far ahead of the current receive counter a message number
# may be before skipMessageKeys refuses to derive the intermediate keys.
const maxSkip = 10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
type Doubleratchet* = object
  ## Mutable state of one side of a Double Ratchet session.

  # Our current ratchet private key; its public part is sent in every header.
  dhSelf: PrivateKey
  # Ratchet public key of the peer that the current chains were derived for.
  dhRemote: PublicKey

  # Key fed into kdfRoot on every DH ratchet step.
  rootKey: RootKey
  # Current chain keys for outgoing and incoming messages.
  chainKeySend, chainKeyRecv: ChainKey

  # Message counters of the sending / receiving chain, and the sending-chain
  # length recorded at the last receive-side DH ratchet step.
  msgCountSend, msgCountRecv, prevChainLen: MsgCount

  # TODO: SkippedKeys
  # Message keys derived ahead of time, indexed by (peer public key, number).
  skippedMessageKeys: Table[(PublicKey, MsgCount), MessageKey]
|
|
|
|
|
|
2025-11-20 09:49:44 -08:00
|
|
|
# Domain-separation label passed as the `info` argument of hkdfSplit in kdfRoot.
const DomainSepKdfRoot = "DoubleRatchetRootKey"
|
2025-10-07 18:41:45 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
type DrHeader* = object
  ## Header that accompanies every encrypted message.

  # Sender's ratchet public key at the time the message was encrypted.
  dhPublic*: PublicKey
  # Sender's message counter for this message, and the sender's recorded
  # previous-chain length.
  msgNumber*, prevChainLen*: uint32
|
2025-10-07 18:41:45 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func keyId(dh:PublicKey, recvCount: MsgCount ): KeyId =
  ## Builds the lookup key used by the skipped-message-key table: the pair of
  ## the peer's ratchet public key and the message number.
  result = (dh, recvCount)
|
|
|
|
|
|
2025-10-09 17:07:50 -07:00
|
|
|
func hex(a: openArray[byte]) : string =
  ## Renders `a` as an upper-case hexadecimal string, two digits per byte and
  ## no separators. An empty input yields an empty string.
  result = newStringOfCap(a.len * 2)
  for octet in a:
    result.add toHex(octet, 2)
|
|
|
|
|
|
|
|
|
|
proc `$`*(x: DrHeader): string =
  ## Human-readable rendering of a header for logs and debugging.
  ## The public key is printed as upper-case hex.
  # The last field was previously labelled "msgNum" a second time, which made
  # the previous-chain length indistinguishable from the message number.
  "DrHeader(pubKey=" & hex(x.dhPublic) &
    ", msgNum=" & $x.msgNumber &
    ", prevChainLen=" & $x.prevChainLen & ")"
|
|
|
|
|
|
|
|
|
|
|
2025-10-07 18:41:45 -07:00
|
|
|
#################################################
|
|
|
|
|
# Kdf
|
|
|
|
|
#################################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func kdfRoot(self: var Doubleratchet, rootKey: RootKey, dhOutput:DhDerivedKey): (RootKey, ChainKey) =
  ## Root-key derivation step: runs hkdfSplit with `rootKey` as salt,
  ## `dhOutput` as input key material and the DomainSepKdfRoot label as info,
  ## and returns the resulting (root key, chain key) pair.
  ##
  ## `self` is not read or modified here; the caller stores the result.
  var
    hkdfSalt = rootKey
    keyMaterial = dhOutput
  # Reinterpret the label string as raw bytes for the HKDF info argument.
  let domainInfo = cast[seq[byte]](DomainSepKdfRoot)

  result = hkdfSplit(hkdfSalt, keyMaterial, domainInfo)
|
|
|
|
|
|
|
|
|
|
func kdfChain(self: Doubleratchet, chainKey: ChainKey): (MessageKey, ChainKey) =
  ## Symmetric ratchet step. Derives two values from `chainKey` via
  ## hkdfExtract, using the single-byte inputs 0x01 and 0x02 respectively, and
  ## returns them as (message key, next chain key).
  ##
  ## `self` is not used; the caller is responsible for storing the new chain key.
  let
    messageKey = hkdfExtract(chainKey, [0x01u8])
    nextChainKey = hkdfExtract(chainKey, [0x02u8])

  (messageKey, nextChainKey)
|
|
|
|
|
|
|
|
|
|
func dhRatchetSend(self: var Doubleratchet) =
  ## Starts a new sending chain from the current key pair and peer key:
  ## performs a DH exchange between `dhSelf` and `dhRemote`, feeds the output
  ## through kdfRoot, stores the new root key and sending chain key, and
  ## resets the send counter.
  ##
  ## The DH result is unwrapped with `get()`, so a failed exchange is not
  ## reported to the caller as a value.
  let sharedDh : DhDerivedKey = dhExchange(self.dhSelf, self.dhRemote).get()
  let derived = kdfRoot(self, self.rootKey, sharedDh)

  self.rootKey = derived[0]
  self.chainKeySend = derived[1]
  self.msgCountSend = 0
|
|
|
|
|
|
|
|
|
|
proc dhRatchetRecv(self: var Doubleratchet, remotePublickey: PublicKey ) =
  ## DH ratchet step performed when a message carrying a new peer ratchet key
  ## arrives.
  ##
  ## 1. Records the length of the current sending chain in `prevChainLen` and
  ##    resets both message counters.
  ## 2. Adopts `remotePublickey` and derives the new *receiving* chain from
  ##    DH(current own key, new peer key).
  ## 3. Generates a fresh own key pair and derives the new *sending* chain from
  ##    DH(new own key, new peer key).
  ##
  ## DH and key-generation results are unwrapped with `get()`, so a failure in
  ## either is not reported to the caller as a value.
  self.prevChainLen = self.msgCountSend
  self.msgCountSend = 0
  self.msgCountRecv = 0

  self.dhRemote = remotePublickey

  # Receiving chain: derived with the key pair the peer already knows about.
  let dhOutputPre = self.dhSelf.dhExchange(self.dhRemote).get()
  let (newRootKey, newChainKeyRecv) = kdfRoot(self, self.rootKey, dhOutputPre)
  self.rootKey = newRootKey
  self.chainKeyRecv = newChainKeyRecv

  # Rotate our own ratchet key.
  self.dhSelf = generateKeypair().get()[0]

  # Sending chain: derived with the freshly generated key pair. This must go
  # into chainKeySend; writing it to chainKeyRecv would discard the receiving
  # chain derived above and leave the sending chain stale.
  let dhOutputPost = self.dhSelf.dhExchange(self.dhRemote).get()
  (self.rootKey, self.chainKeySend) = kdfRoot(self, self.rootKey, dhOutputPost)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
proc skipMessageKeys(self: var Doubleratchet, until: MsgCount): Result[(), string] =
  ## Advances the receiving chain up to (but not including) message number
  ## `until`, storing each derived message key in `skippedMessageKeys` under
  ## (current peer key, message number).
  ##
  ## Returns an error, without changing any state, when `until` lies more than
  ## `maxSkip` messages ahead of the current receive counter. When `until` is
  ## not ahead of the counter nothing is derived and the call succeeds.
  if until > self.msgCountRecv + maxSkip:
    return err("Too many skipped messages")

  while self.msgCountRecv < until:
    let derived = self.kdfChain(self.chainKeyRecv)
    let slot = keyId(self.dhRemote, self.msgCountRecv)

    self.skippedMessageKeys[slot] = derived[0]
    self.chainKeyRecv = derived[1]
    inc self.msgCountRecv

  ok(())
|
|
|
|
|
|
|
|
|
|
proc encrypt(self: var Doubleratchet, plaintext: var seq[byte], associatedData: openArray[byte]): (DrHeader, CipherText) =
  ## Encrypts `plaintext` with the next message key of the sending chain.
  ##
  ## Advances the sending chain and send counter, builds the header from the
  ## current own public key, the pre-increment send counter and
  ## `prevChainLen`, and returns the header together with `nonce || ciphertext`
  ## as produced by encryptWithChaCha20Poly1305.
  let (messageKey, nextChainKey) = self.kdfChain(self.chainKeySend)
  self.chainKeySend = nextChainKey

  # The header carries the counter value *before* it is advanced.
  let header = DrHeader(
    dhPublic: self.dhSelf.public, #TODO Serialize
    msgNumber: self.msgCountSend,
    prevChainLen: self.prevChainLen)
  inc self.msgCountSend

  let (nonce, sealed) = encryptWithChaCha20Poly1305(messageKey, plaintext, associatedData)

  # TODO: optimize copies
  var framed : seq[byte]
  framed.add(nonce)
  framed.add(sealed)

  (header, framed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
proc decrypt*(self: var Doubleratchet, header: DrHeader, ciphertext: CipherText, associatedData: openArray[byte] ) : Result[seq[byte], NaxolotlError] =
  ## Decrypts a message produced by the peer's `encrypt`.
  ##
  ## `ciphertext` must be laid out as `nonce || payload`, where the nonce is
  ## `Nonce.len` bytes long. The message key is taken from the skipped-key
  ## table when one was stored for (header.dhPublic, header.msgNumber);
  ## otherwise the receiving chain is advanced (performing a DH ratchet step
  ## first if the header carries a new peer key).
  ##
  ## Returns whatever decryptWithChaCha20Poly1305 returns. On success a used
  ## skipped key is removed. On failure every change made to `self` during this
  ## call is rolled back, so an undecryptable or forged message cannot
  ## desynchronise the session.
  ##
  ## An input shorter than the nonce is a caller error and fails the bounds
  ## check below (no partial state change happens before that point).

  # Split the input first, with bounds-checked indexing, before any ratchet
  # state is touched.
  var nonce : Nonce
  for i in 0 ..< Nonce.len:
    nonce[i] = ciphertext[i]
  var cipherTag = ciphertext[Nonce.len .. ^1]

  # Copy of the full state (including the skipped-key table) used to undo the
  # ratchet advance if authentication fails.
  let snapshot = self

  let peerPublic = header.dhPublic
  var msgKey : MessageKey

  # Check Skipped Keys
  let skippedId = keyId(header.dhPublic, header.msgNumber)
  if self.skippedMessageKeys.hasKey(skippedId):
    debug "detected skipped message", keyId = skippedId
    msgKey = self.skippedMessageKeys[skippedId]
  else:
    if (peerPublic != self.dhRemote):
      # Store the remaining keys of the chain that is about to be replaced.
      let skipPrev = self.skipMessageKeys(header.prevChainLen)
      if skipPrev.isErr:
        error "skipMessages", error = skipPrev.error()
      self.dhRatchetRecv(peerPublic)

    # NOTE(review): a skip failure is only logged, as before. If it leaves the
    # chain at the wrong position, decryption below fails and the state is
    # restored from the snapshot.
    let skipCurrent = self.skipMessageKeys(header.msgNumber)
    if skipCurrent.isErr:
      error "skipMessages", error = skipCurrent.error()

    (msgKey, self.chainKeyRecv) = self.kdfChain(self.chainKeyRecv)
    inc self.msgCountRecv

  result = decryptWithChaCha20Poly1305(msgKey,nonce, cipherTag, associatedData )

  if result.isOk:
    # TODO: persist chainKey state changes
    self.skippedMessageKeys.del(skippedId)
  else:
    # Discard every state change made while processing this message.
    self = snapshot
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
proc encrypt*(self: var Doubleratchet, plaintext: var seq[byte]) : (DrHeader, CipherText) =
  ## Public entry point for encryption: forwards to the three-argument
  ## `encrypt` with empty associated data and returns its result unchanged.
  result = self.encrypt(plaintext, @[])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func initDoubleratchet*(sharedSecret: array[32, byte], dhSelf: PrivateKey, dhRemote: PublicKey, isSending: bool = true): Doubleratchet =
  ## Creates a session state from a 32-byte shared secret and the initial keys.
  ##
  ## The shared secret becomes the initial root key, all counters start at
  ## zero and the skipped-key table starts empty. When `isSending` is true
  ## (the default) the sending chain is derived immediately via dhRatchetSend;
  ## otherwise both chain keys are left at their zero value.
  var session = Doubleratchet(
    dhSelf: dhSelf,
    dhRemote: dhRemote,
    rootKey: RootKey(sharedSecret),
    msgCountSend: 0,
    msgCountRecv: 0,
    prevChainLen: 0,
    skippedMessageKeys: initTable[(PublicKey, MsgCount), MessageKey]()
  )

  if isSending:
    session.dhRatchetSend()

  session
|