Inbound Path (#38)

* refactor invite

* Update tests

* Cleanups

* Remove ContentFrame references
This commit is contained in:
Jazz Turner-Baggs 2025-12-16 08:20:53 -08:00 committed by GitHub
parent 082f63f6c7
commit 7ee12eb250
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 198 additions and 133 deletions

View File

@ -82,7 +82,7 @@ proc main() {.async.} =
# Perform OOB Introduction: Raya -> Saro # Perform OOB Introduction: Raya -> Saro
let raya_bundle = raya.createIntroBundle() let raya_bundle = raya.createIntroBundle()
discard await saro.newPrivateConversation(raya_bundle) discard await saro.newPrivateConversation(raya_bundle, initTextFrame("Init").toContentFrame().toBytes())
await sleepAsync(20.seconds) # Run for some time await sleepAsync(20.seconds) # Run for some time

View File

@ -203,7 +203,7 @@ proc drawStatusBar(app: ChatApp, layout: Pane , fg: ForegroundColor, bg: Backgro
var i = layout.yStart + 1 var i = layout.yStart + 1
var chunk = layout.width - 9 var chunk = layout.width - 9
tb.write(1, i, "Name: " & app.client.getId()) tb.write(1, i, "Name: " & app.client.getName())
inc i inc i
tb.write(1, i, fmt"PeerCount: {app.peerCount}") tb.write(1, i, fmt"PeerCount: {app.peerCount}")
inc i inc i

View File

@ -2,10 +2,13 @@ syntax = "proto3";
package wap.invite; package wap.invite;
import "encryption.proto";
message InvitePrivateV1 { message InvitePrivateV1 {
bytes initiator = 1; bytes initiator = 1;
bytes initiator_ephemeral = 2; bytes initiator_ephemeral = 2;
bytes participant = 3; bytes participant = 3;
int32 participant_ephemeral_id= 4; int32 participant_ephemeral_id= 4;
string discriminator = 5; string discriminator = 5;
encryption.EncryptedPayload initial_message = 6;
} }

View File

@ -54,6 +54,7 @@ type Client* = ref object
conversations: Table[string, Conversation] # Keyed by conversation ID conversations: Table[string, Conversation] # Keyed by conversation ID
inboundQueue: QueueRef inboundQueue: QueueRef
isRunning: bool isRunning: bool
inbox: Inbox
newMessageCallbacks: seq[MessageCallback] newMessageCallbacks: seq[MessageCallback]
newConvoCallbacks: seq[NewConvoCallback] newConvoCallbacks: seq[NewConvoCallback]
@ -71,6 +72,8 @@ proc newClient*(cfg: WakuConfig, ident: Identity): Client {.raises: [IOError,
let rm = newReliabilityManager().valueOr: let rm = newReliabilityManager().valueOr:
raise newException(ValueError, fmt"SDS InitializationError") raise newException(ValueError, fmt"SDS InitializationError")
let defaultInbox = initInbox(ident)
var q = QueueRef(queue: newAsyncQueue[ChatPayload](10)) var q = QueueRef(queue: newAsyncQueue[ChatPayload](10))
var c = Client(ident: ident, var c = Client(ident: ident,
ds: waku, ds: waku,
@ -78,14 +81,14 @@ proc newClient*(cfg: WakuConfig, ident: Identity): Client {.raises: [IOError,
conversations: initTable[string, Conversation](), conversations: initTable[string, Conversation](),
inboundQueue: q, inboundQueue: q,
isRunning: false, isRunning: false,
inbox: defaultInbox,
newMessageCallbacks: @[], newMessageCallbacks: @[],
newConvoCallbacks: @[]) newConvoCallbacks: @[])
let defaultInbox = initInbox(c.ident.getPubkey())
c.conversations[defaultInbox.id()] = defaultInbox c.conversations[defaultInbox.id()] = defaultInbox
notice "Client started", client = c.ident.getId(), notice "Client started", client = c.ident.getName(),
defaultInbox = defaultInbox defaultInbox = defaultInbox, inTopic= topic_inbox(c.ident.get_addr())
result = c result = c
except Exception as e: except Exception as e:
error "newCLient", err = e.msg error "newCLient", err = e.msg
@ -95,7 +98,7 @@ proc newClient*(cfg: WakuConfig, ident: Identity): Client {.raises: [IOError,
################################################# #################################################
proc getId*(client: Client): string = proc getId*(client: Client): string =
result = client.ident.getId() result = client.ident.getName()
proc identity*(client: Client): Identity = proc identity*(client: Client): Identity =
result = client.ident result = client.ident
@ -174,7 +177,7 @@ proc createIntroBundle*(self: var Client): IntroBundle =
################################################# #################################################
proc addConversation*(client: Client, convo: Conversation) = proc addConversation*(client: Client, convo: Conversation) =
notice "Creating conversation", client = client.getId(), topic = convo.id() notice "Creating conversation", client = client.getId(), convoId = convo.id()
client.conversations[convo.id()] = convo client.conversations[convo.id()] = convo
client.notifyNewConversation(convo) client.notifyNewConversation(convo)
@ -183,47 +186,15 @@ proc getConversation*(client: Client, convoId: string): Conversation =
result = client.conversations[convoId] result = client.conversations[convoId]
proc newPrivateConversation*(client: Client, proc newPrivateConversation*(client: Client,
introBundle: IntroBundle): Future[Option[ChatError]] {.async.} = introBundle: IntroBundle, content: Content): Future[Option[ChatError]] {.async.} =
## Creates a private conversation with the given `IntroBundle`. ## Creates a private conversation with the given `IntroBundle`.
## `IntroBundles` are provided out-of-band. ## `IntroBundles` are provided out-of-band.
let remote_pubkey = loadPublicKeyFromBytes(introBundle.ident).get()
let remote_ephemeralkey = loadPublicKeyFromBytes(introBundle.ephemeral).get()
notice "New PRIVATE Convo ", client = client.getId(), let convo = await client.inbox.inviteToPrivateConversation(client.ds,remote_pubkey, remote_ephemeralkey, content )
fromm = introBundle.ident.mapIt(it.toHex(2)).join("") client.addConversation(convo) # TODO: Fix re-entrantancy bug. Convo needs to be saved before payload is sent.
let destPubkey = loadPublicKeyFromBytes(introBundle.ident).valueOr:
raise newException(ValueError, "Invalid public key in intro bundle.")
let convoId = conversationIdFor(destPubkey)
let destConvoTopic = topicInbox(destPubkey.getAddr())
let invite = InvitePrivateV1(
initiator: @(client.ident.getPubkey().bytes()),
initiatorEphemeral: @[0, 0], # TODO: Add ephemeral
participant: @(destPubkey.bytes()),
participantEphemeralId: introBundle.ephemeralId,
discriminator: "test"
)
let env = wrapEnv(encrypt(InboxV1Frame(invitePrivateV1: invite,
recipient: "")), convoId)
let deliveryAckCb = proc(
conversation: Conversation,
msgId: string): Future[void] {.async.} =
client.notifyDeliveryAck(conversation, msgId)
# TODO: remove placeholder key
var key : array[32, byte]
key[2]=2
var convo = initPrivateV1Sender(client.identity(), client.ds, destPubkey, key, deliveryAckCb)
client.addConversation(convo)
# TODO: Subscribe to new content topic
await client.ds.sendPayload(destConvoTopic, env)
return none(ChatError) return none(ChatError)

View File

@ -23,7 +23,8 @@ import convo_type
import message import message
import ../../naxolotl as nax import ../../naxolotl as nax
const TopicPrefixPrivateV1 = "/convo/private/"
type type
ReceivedPrivateV1Message* = ref object of ReceivedMessage ReceivedPrivateV1Message* = ref object of ReceivedMessage
@ -37,17 +38,32 @@ type
ds: WakuClient ds: WakuClient
sdsClient: ReliabilityManager sdsClient: ReliabilityManager
owner: Identity owner: Identity
topic: string
participant: PublicKey participant: PublicKey
discriminator: string discriminator: string
doubleratchet: naxolotl.Doubleratchet doubleratchet: naxolotl.Doubleratchet
const proc derive_topic(participant: PublicKey): string =
TopicPrefixPrivateV1 = "/convo/private/" ## Derives a topic from the participants' public keys.
return TopicPrefixPrivateV1 & participant.get_addr()
proc getTopic*(self: PrivateV1): string = proc getTopicInbound*(self: PrivateV1): string =
## Returns the topic for the PrivateV1 conversation. ## Returns the topic where the local client is listening for messages
return self.topic return derive_topic(self.owner.getPubkey())
proc getTopicOutbound*(self: PrivateV1): string =
## Returns the topic where the remote recipient is listening for messages
return derive_topic(self.participant)
## Parses the topic to extract the conversation ID.
proc parseTopic*(topic: string): Result[string, ChatError] =
if not topic.startsWith(TopicPrefixPrivateV1):
return err(ChatError(code: errTopic, context: "Invalid topic prefix"))
let id = topic.split('/')[^1]
if id == "":
return err(ChatError(code: errTopic, context: "Empty conversation ID"))
return ok(id)
proc allParticipants(self: PrivateV1): seq[PublicKey] = proc allParticipants(self: PrivateV1): seq[PublicKey] =
return @[self.owner.getPubkey(), self.participant] return @[self.owner.getPubkey(), self.participant]
@ -64,20 +80,6 @@ proc getConvoIdRaw(participants: seq[PublicKey],
proc getConvoId*(self: PrivateV1): string = proc getConvoId*(self: PrivateV1): string =
return getConvoIdRaw(@[self.owner.getPubkey(), self.participant], self.discriminator) return getConvoIdRaw(@[self.owner.getPubkey(), self.participant], self.discriminator)
proc derive_topic(participants: seq[PublicKey], discriminator: string): string =
## Derives a topic from the participants' public keys.
return TopicPrefixPrivateV1 & getConvoIdRaw(participants, discriminator)
## Parses the topic to extract the conversation ID.
proc parseTopic*(topic: string): Result[string, ChatError] =
if not topic.startsWith(TopicPrefixPrivateV1):
return err(ChatError(code: errTopic, context: "Invalid topic prefix"))
let id = topic.split('/')[^1]
if id == "":
return err(ChatError(code: errTopic, context: "Empty conversation ID"))
return ok(id)
proc calcMsgId(self: PrivateV1, msgBytes: seq[byte]): string = proc calcMsgId(self: PrivateV1, msgBytes: seq[byte]): string =
let s = fmt"{self.getConvoId()}|{msgBytes}" let s = fmt"{self.getConvoId()}|{msgBytes}"
@ -124,7 +126,7 @@ proc wireCallbacks(convo: PrivateV1, deliveryAckCb: proc(
let funcDeliveryAck = proc(messageId: SdsMessageID, let funcDeliveryAck = proc(messageId: SdsMessageID,
channelId: SdsChannelID) {.gcsafe.} = channelId: SdsChannelID) {.gcsafe.} =
debug "sds message ack", messageId = messageId, debug "sds message ack", messageId = messageId,
channelId = channelId, cb = repr(deliveryAckCb) channelId = channelId
if deliveryAckCb != nil: if deliveryAckCb != nil:
asyncSpawn deliveryAckCb(convo, messageId) asyncSpawn deliveryAckCb(convo, messageId)
@ -146,19 +148,21 @@ proc initPrivateV1*(owner: Identity, ds:WakuClient, participant: PublicKey, seed
msgId: string): Future[void] {.async.} = nil): msgId: string): Future[void] {.async.} = nil):
PrivateV1 = PrivateV1 =
var participants = @[owner.getPubkey(), participant];
var rm = newReliabilityManager().valueOr: var rm = newReliabilityManager().valueOr:
raise newException(ValueError, fmt"sds initialization: {repr(error)}") raise newException(ValueError, fmt"sds initialization: {repr(error)}")
let dr = if isSender:
initDoubleratchetSender(seedKey, participant.bytes)
else:
initDoubleratchetRecipient(seedKey, owner.privateKey.bytes)
result = PrivateV1( result = PrivateV1(
ds: ds, ds: ds,
sdsClient: rm, sdsClient: rm,
owner: owner, owner: owner,
topic: derive_topic(participants, discriminator),
participant: participant, participant: participant,
discriminator: discriminator, discriminator: discriminator,
doubleratchet: initDoubleratchet(seedKey, owner.privateKey.bytes, participant.bytes, isSender) doubleratchet: dr
) )
result.wireCallbacks(deliveryAckCb) result.wireCallbacks(deliveryAckCb)
@ -166,18 +170,7 @@ proc initPrivateV1*(owner: Identity, ds:WakuClient, participant: PublicKey, seed
result.sdsClient.ensureChannel(result.getConvoId()).isOkOr: result.sdsClient.ensureChannel(result.getConvoId()).isOkOr:
raise newException(ValueError, "bad sds channel") raise newException(ValueError, "bad sds channel")
proc encodeFrame*(self: PrivateV1, msg: PrivateV1Frame): (MessageId, EncryptedPayload) =
proc initPrivateV1Sender*(owner:Identity, ds: WakuClient, participant: PublicKey, seedKey: array[32, byte], deliveryAckCb: proc(
conversation: Conversation, msgId: string): Future[void] {.async.} = nil): PrivateV1 =
initPrivateV1(owner, ds, participant, seedKey, "default", true, deliveryAckCb)
proc initPrivateV1Recipient*(owner:Identity,ds: WakuClient, participant: PublicKey, seedKey: array[32, byte], deliveryAckCb: proc(
conversation: Conversation, msgId: string): Future[void] {.async.} = nil): PrivateV1 =
initPrivateV1(owner,ds, participant, seedKey, "default", false, deliveryAckCb)
proc sendFrame(self: PrivateV1, ds: WakuClient,
msg: PrivateV1Frame): Future[MessageId]{.async.} =
let frameBytes = encode(msg) let frameBytes = encode(msg)
let msgId = self.calcMsgId(frameBytes) let msgId = self.calcMsgId(frameBytes)
@ -185,9 +178,12 @@ proc sendFrame(self: PrivateV1, ds: WakuClient,
self.getConvoId()).valueOr: self.getConvoId()).valueOr:
raise newException(ValueError, fmt"sds wrapOutgoingMessage failed: {repr(error)}") raise newException(ValueError, fmt"sds wrapOutgoingMessage failed: {repr(error)}")
let encryptedPayload = self.encrypt(sdsPayload) result = (msgId, self.encrypt(sdsPayload))
discard ds.sendPayload(self.getTopic(), encryptedPayload.toEnvelope( proc sendFrame(self: PrivateV1, ds: WakuClient,
msg: PrivateV1Frame): Future[MessageId]{.async.} =
let (msgId, encryptedPayload) = self.encodeFrame(msg)
discard ds.sendPayload(self.getTopicOutbound(), encryptedPayload.toEnvelope(
self.getConvoId())) self.getConvoId()))
result = msgId result = msgId
@ -197,18 +193,15 @@ method id*(self: PrivateV1): string =
return getConvoIdRaw(self.allParticipants(), self.discriminator) return getConvoIdRaw(self.allParticipants(), self.discriminator)
proc handleFrame*[T: ConversationStore](convo: PrivateV1, client: T, proc handleFrame*[T: ConversationStore](convo: PrivateV1, client: T,
bytes: seq[byte]) = encPayload: EncryptedPayload) =
## Dispatcher for Incoming `PrivateV1Frames`. ## Dispatcher for Incoming `PrivateV1Frames`.
## Calls further processing depending on the kind of frame. ## Calls further processing depending on the kind of frame.
let enc = decode(bytes, EncryptedPayload).valueOr: if convo.doubleratchet.dhSelfPublic() == encPayload.doubleratchet.dh:
raise newException(ValueError, fmt"Failed to decode EncryptedPayload: {repr(error)}")
if convo.doubleratchet.dhSelfPublic() == enc.doubleratchet.dh:
info "outgoing message, no need to handle", convo = convo.id() info "outgoing message, no need to handle", convo = convo.id()
return return
let plaintext = convo.decrypt(enc).valueOr: let plaintext = convo.decrypt(encPayload).valueOr:
error "decryption failed", error = error error "decryption failed", error = error
return return
@ -234,6 +227,15 @@ proc handleFrame*[T: ConversationStore](convo: PrivateV1, client: T,
of typePlaceholder: of typePlaceholder:
notice "Got Placeholder", text = frame.placeholder.counter notice "Got Placeholder", text = frame.placeholder.counter
proc handleFrame*[T: ConversationStore](convo: PrivateV1, client: T,
bytes: seq[byte]) =
## Dispatcher for Incoming `PrivateV1Frames`.
## Calls further processing depending on the kind of frame.
let encPayload = decode(bytes, EncryptedPayload).valueOr:
raise newException(ValueError, fmt"Failed to decode EncryptedPayload: {repr(error)}")
convo.handleFrame(client,encPayload)
method sendMessage*(convo: PrivateV1, content_frame: Content) : Future[MessageId] {.async.} = method sendMessage*(convo: PrivateV1, content_frame: Content) : Future[MessageId] {.async.} =
@ -245,3 +247,36 @@ method sendMessage*(convo: PrivateV1, content_frame: Content) : Future[MessageId
except Exception as e: except Exception as e:
error "Unknown error in PrivateV1:SendMessage" error "Unknown error in PrivateV1:SendMessage"
## Encrypts content without sending it.
proc encryptMessage*(self: PrivateV1, content_frame: Content) : (MessageId, EncryptedPayload) =
try:
let frame = PrivateV1Frame(
sender: @(self.owner.getPubkey().bytes()),
timestamp: getCurrentTimestamp(),
content: content_frame
)
result = self.encodeFrame(frame)
except Exception as e:
error "Unknown error in PrivateV1:EncryptMessage"
proc initPrivateV1Sender*(sender:Identity,
ds: WakuClient,
participant: PublicKey,
seedKey: array[32, byte],
content: Content,
deliveryAckCb: proc(conversation: Conversation, msgId: string): Future[void] {.async.} = nil): (PrivateV1, EncryptedPayload) =
let convo = initPrivateV1(sender, ds, participant, seedKey, "default", true, deliveryAckCb)
# Encrypt Content with Convo
let contentFrame = PrivateV1Frame(sender: @(sender.getPubkey().bytes()), timestamp: getCurrentTimestamp(), content: content)
let (msg_id, encPayload) = convo.encryptMessage(content)
result = (convo, encPayload)
proc initPrivateV1Recipient*(owner:Identity,ds: WakuClient, participant: PublicKey, seedKey: array[32, byte], deliveryAckCb: proc(
conversation: Conversation, msgId: string): Future[void] {.async.} = nil): PrivateV1 =
initPrivateV1(owner,ds, participant, seedKey, "default", false, deliveryAckCb)

View File

@ -29,5 +29,5 @@ proc getAddr*(self: Identity): string =
result = get_addr(self.getPubKey()) result = get_addr(self.getPubKey())
proc getId*(self: Identity): string = proc getName*(self: Identity): string =
result = self.name result = self.name

View File

@ -11,16 +11,19 @@ import
conversations, conversations,
conversation_store, conversation_store,
crypto, crypto,
delivery/waku_client,
errors, errors,
identity,
proto_types, proto_types,
types types,
utils
logScope: logScope:
topics = "chat inbox" topics = "chat inbox"
type type
Inbox* = ref object of Conversation Inbox* = ref object of Conversation
pubkey: PublicKey identity: Identity
inbox_addr: string inbox_addr: string
const const
@ -30,9 +33,9 @@ proc `$`*(conv: Inbox): string =
fmt"Inbox: addr->{conv.inbox_addr}" fmt"Inbox: addr->{conv.inbox_addr}"
proc initInbox*(pubkey: PublicKey): Inbox = proc initInbox*(ident: Identity): Inbox =
## Initializes an Inbox object with the given address and invite callback. ## Initializes an Inbox object with the given address and invite callback.
return Inbox(pubkey: pubkey) return Inbox(identity: ident)
proc encrypt*(frame: InboxV1Frame): EncryptedPayload = proc encrypt*(frame: InboxV1Frame): EncryptedPayload =
return encrypt_plain(frame) return encrypt_plain(frame)
@ -73,13 +76,52 @@ proc parseTopic*(topic: string): Result[string, ChatError] =
return ok(id) return ok(id)
method id*(convo: Inbox): string = method id*(convo: Inbox): string =
return conversation_id_for(convo.pubkey) return conversation_id_for(convo.identity.getPubkey())
## Encrypt and Send a frame to the remote account
proc sendFrame(ds: WakuClient, remote: PublicKey, frame: InboxV1Frame ): Future[void] {.async.} =
let env = wrapEnv(encrypt(frame),conversation_id_for(remote) )
await ds.sendPayload(topic_inbox(remote.get_addr()), env)
proc newPrivateInvite(initator_static: PublicKey,
initator_ephemeral: PublicKey,
recipient_static: PublicKey,
recipient_ephemeral: uint32,
payload: EncryptedPayload) : InboxV1Frame =
let invite = InvitePrivateV1(
initiator: @(initator_static.bytes()),
initiatorEphemeral: @(initator_ephemeral.bytes()),
participant: @(recipient_static.bytes()),
participantEphemeralId: 0,
discriminator: "",
initial_message: payload
)
result = InboxV1Frame(invitePrivateV1: invite, recipient: "")
################################################# #################################################
# Conversation Creation # Conversation Creation
################################################# #################################################
## Establish a PrivateConversation with a remote client
proc inviteToPrivateConversation*(self: Inbox, ds: Wakuclient, remote_static: PublicKey, remote_ephemeral: PublicKey, content: Content ) : Future[PrivateV1] {.async.} =
# Create SeedKey
# TODO: Update key derivations when noise is integrated
var local_ephemeral = generateKey()
var sk{.noInit.} : array[32, byte] = default(array[32, byte])
# Initialize PrivateConversation
let (convo, encPayload) = initPrivateV1Sender(self.identity, ds, remote_static, sk, content, nil)
result = convo
# # Build Invite
let frame = newPrivateInvite(self.identity.getPubkey(), local_ephemeral.getPublicKey(), remote_static, 0, encPayload)
# Send
await sendFrame(ds, remote_static, frame)
## Receive am Invitation to create a new private conversation
proc createPrivateV1FromInvite*[T: ConversationStore](client: T, proc createPrivateV1FromInvite*[T: ConversationStore](client: T,
invite: InvitePrivateV1) = invite: InvitePrivateV1) =
@ -92,13 +134,17 @@ proc createPrivateV1FromInvite*[T: ConversationStore](client: T,
client.notifyDeliveryAck(conversation, msgId) client.notifyDeliveryAck(conversation, msgId)
# TODO: remove placeholder key # TODO: remove placeholder key
var key : array[32, byte] var key : array[32, byte] = default(array[32,byte])
key[2]=2
let convo = initPrivateV1Recipient(client.identity(), client.ds, destPubkey, key, deliveryAckCb) let convo = initPrivateV1Recipient(client.identity(), client.ds, destPubkey, key, deliveryAckCb)
notice "Creating PrivateV1 conversation", client = client.getId(), notice "Creating PrivateV1 conversation", client = client.getId(),
topic = convo.getConvoId() convoId = convo.getConvoId()
client.addConversation(convo)
convo.handleFrame(client, invite.initial_message)
# Calling `addConversation` must only occur after the conversation is completely configured.
# The client calls the OnNewConversation callback, which returns execution to the application.
client.addConversation(convo)
proc handleFrame*[T: ConversationStore](convo: Inbox, client: T, bytes: seq[ proc handleFrame*[T: ConversationStore](convo: Inbox, client: T, bytes: seq[
byte]) = byte]) =

View File

@ -10,7 +10,6 @@ export protobuf_serialization
import_proto3 "../../protos/inbox.proto" import_proto3 "../../protos/inbox.proto"
# import_proto3 "../protos/invite.proto" // Import3 follows protobuf includes so this will result in a redefinition error # import_proto3 "../protos/invite.proto" // Import3 follows protobuf includes so this will result in a redefinition error
import_proto3 "../../protos/encryption.proto"
import_proto3 "../../protos/envelope.proto" import_proto3 "../../protos/envelope.proto"
import_proto3 "../../protos/private_v1.proto" import_proto3 "../../protos/private_v1.proto"

View File

@ -50,12 +50,20 @@ proc `$`*(x: DrHeader): string =
"DrHeader(pubKey=" & hex(x.dhPublic) & ", msgNum=" & $x.msgNumber & ", msgNum=" & $x.prevChainLen & ")" "DrHeader(pubKey=" & hex(x.dhPublic) & ", msgNum=" & $x.msgNumber & ", msgNum=" & $x.prevChainLen & ")"
proc `$`*(key: array[32, byte]): string =
let byteStr = hex(key)
fmt"{byteStr[0..5]}..{byteStr[^6 .. ^1]}"
proc generateDhKey() : PrivateKey =
result = generateKeypair().get()[0]
################################################# #################################################
# Kdf # Kdf
################################################# #################################################
func kdfRoot(self: var Doubleratchet, rootKey: RootKey, dhOutput:DhDerivedKey): (RootKey, ChainKey) = func kdfRoot(rootKey: RootKey, dhOutput:DhDerivedKey): (RootKey, ChainKey) =
var salt = rootKey var salt = rootKey
var ikm = dhOutput var ikm = dhOutput
@ -63,7 +71,7 @@ func kdfRoot(self: var Doubleratchet, rootKey: RootKey, dhOutput:DhDerivedKey):
hkdfSplit(salt, ikm, info) hkdfSplit(salt, ikm, info)
func kdfChain(self: Doubleratchet, chainKey: ChainKey): (MessageKey, ChainKey) = func kdfChain(chainKey: ChainKey): (MessageKey, ChainKey) =
let msgKey = hkdfExtract(chainKey, [0x01u8]) let msgKey = hkdfExtract(chainKey, [0x01u8])
let chainKey = hkdfExtract(chainKey, [0x02u8]) let chainKey = hkdfExtract(chainKey, [0x02u8])
@ -73,7 +81,7 @@ func kdfChain(self: Doubleratchet, chainKey: ChainKey): (MessageKey, ChainKey) =
func dhRatchetSend(self: var Doubleratchet) = func dhRatchetSend(self: var Doubleratchet) =
# Perform DH Ratchet step when receiving a new peer key. # Perform DH Ratchet step when receiving a new peer key.
let dhOutput : DhDerivedKey = dhExchange(self.dhSelf, self.dhRemote).get() let dhOutput : DhDerivedKey = dhExchange(self.dhSelf, self.dhRemote).get()
let (newRootKey, newChainKeySend) = kdfRoot(self, self.rootKey, dhOutput) let (newRootKey, newChainKeySend) = kdfRoot(self.rootKey, dhOutput)
self.rootKey = newRootKey self.rootKey = newRootKey
self.chainKeySend = newChainKeySend self.chainKeySend = newChainKeySend
self.msgCountSend = 0 self.msgCountSend = 0
@ -86,14 +94,14 @@ proc dhRatchetRecv(self: var Doubleratchet, remotePublickey: PublicKey ) =
self.dhRemote = remotePublickey self.dhRemote = remotePublickey
let dhOutputPre = self.dhSelf.dhExchange(self.dhRemote).get() let dhOutputPre = self.dhSelf.dhExchange(self.dhRemote).get()
let (newRootKey, newChainKeyRecv) = kdfRoot(self, self.rootKey, dhOutputPre) let (newRootKey, newChainKeyRecv) = kdfRoot(self.rootKey, dhOutputPre)
self.rootKey = newRootKey self.rootKey = newRootKey
self.chainKeyRecv = newChainKeyRecv self.chainKeyRecv = newChainKeyRecv
self.dhSelf = generateKeypair().get()[0] self.dhSelf = generateDhKey()
let dhOutputPost = self.dhSelf.dhExchange(self.dhRemote).get() let dhOutputPost = self.dhSelf.dhExchange(self.dhRemote).get()
(self.rootKey, self.chainKeySend) = kdfRoot(self, self.rootKey, dhOutputPost) (self.rootKey, self.chainKeySend) = kdfRoot(self.rootKey, dhOutputPost)
proc skipMessageKeys(self: var Doubleratchet, until: MsgCount): Result[(), string] = proc skipMessageKeys(self: var Doubleratchet, until: MsgCount): Result[(), string] =
@ -102,7 +110,7 @@ proc skipMessageKeys(self: var Doubleratchet, until: MsgCount): Result[(), strin
return err("Too many skipped messages") return err("Too many skipped messages")
while self.msgCountRecv < until: while self.msgCountRecv < until:
let (msgKey, chainKey) = self.kdfChain(self.chainKeyRecv) let (msgKey, chainKey) = kdfChain(self.chainKeyRecv)
self.chainKeyRecv = chainKey self.chainKeyRecv = chainKey
let keyId = keyId(self.dhRemote, self.msgCountRecv) let keyId = keyId(self.dhRemote, self.msgCountRecv)
@ -113,7 +121,7 @@ proc skipMessageKeys(self: var Doubleratchet, until: MsgCount): Result[(), strin
proc encrypt(self: var Doubleratchet, plaintext: var seq[byte], associatedData: openArray[byte]): (DrHeader, CipherText) = proc encrypt(self: var Doubleratchet, plaintext: var seq[byte], associatedData: openArray[byte]): (DrHeader, CipherText) =
let (msgKey, chainKey) = self.kdfChain(self.chainKeySend) let (msgKey, chainKey) = kdfChain(self.chainKeySend)
self.chainKeySend = chainKey self.chainKeySend = chainKey
let header = DrHeader( let header = DrHeader(
dhPublic: self.dhSelf.public, #TODO Serialize dhPublic: self.dhSelf.public, #TODO Serialize
@ -130,14 +138,12 @@ proc encrypt(self: var Doubleratchet, plaintext: var seq[byte], associatedData:
output.add(nonce) output.add(nonce)
output.add(ciphertext) output.add(ciphertext)
(header, output) (header, output)
proc decrypt*(self: var Doubleratchet, header: DrHeader, ciphertext: CipherText, associatedData: openArray[byte] ) : Result[seq[byte], NaxolotlError] = proc decrypt*(self: var Doubleratchet, header: DrHeader, ciphertext: CipherText, associatedData: openArray[byte] ) : Result[seq[byte], NaxolotlError] =
let peerPublic = header.dhPublic let peerPublic = header.dhPublic
var msgKey : MessageKey var msgKey : MessageKey
# Check Skipped Keys # Check Skipped Keys
@ -155,7 +161,7 @@ proc decrypt*(self: var Doubleratchet, header: DrHeader, ciphertext: CipherText,
if r.isErr: if r.isErr:
error "skipMessages", error = r.error() error "skipMessages", error = r.error()
(msgKey, self.chainKeyRecv) = self.kdfChain(self.chainKeyRecv) (msgKey, self.chainKeyRecv) = kdfChain(self.chainKeyRecv)
inc self.msgCountRecv inc self.msgCountRecv
var nonce : Nonce var nonce : Nonce
@ -173,10 +179,10 @@ proc encrypt*(self: var Doubleratchet, plaintext: var seq[byte]) : (DrHeader, Ci
encrypt(self, plaintext,@[]) encrypt(self, plaintext,@[])
func initDoubleratchet*(sharedSecret: array[32, byte], dhSelf: PrivateKey, dhRemote: PublicKey, isSending: bool = true): Doubleratchet = proc initDoubleratchetSender*(sharedSecret: array[32, byte], dhRemote: PublicKey): Doubleratchet =
result = Doubleratchet( result = Doubleratchet(
dhSelf: dhSelf, dhSelf: generateDhKey(),
dhRemote: dhRemote, dhRemote: dhRemote,
rootKey: RootKey(sharedSecret), rootKey: RootKey(sharedSecret),
msgCountSend: 0, msgCountSend: 0,
@ -185,8 +191,20 @@ func initDoubleratchet*(sharedSecret: array[32, byte], dhSelf: PrivateKey, dhRem
skippedMessageKeys: initTable[(PublicKey, MsgCount), MessageKey]() skippedMessageKeys: initTable[(PublicKey, MsgCount), MessageKey]()
) )
if isSending: # Update RK, CKs
result.dhRatchetSend() result.dhRatchetSend()
proc initDoubleratchetRecipient*(sharedSecret: array[32, byte], dhSelf: PrivateKey): Doubleratchet =
result = Doubleratchet(
dhSelf: dhSelf,
#dhRemote: None,
rootKey: RootKey(sharedSecret),
msgCountSend: 0,
msgCountRecv: 0,
prevChainLen: 0,
skippedMessageKeys: initTable[(PublicKey, MsgCount), MessageKey]()
)
func dhSelfPublic*(self: Doubleratchet): PublicKey = func dhSelfPublic*(self: Doubleratchet): PublicKey =
self.dhSelf.public self.dhSelf.public

View File

@ -31,7 +31,6 @@ proc hexToArray*[N: static[int]](hexStr: string): array[N, byte] =
"Hex string length (" & $hexStr.len & ") doesn't match array size (" & $( "Hex string length (" & $hexStr.len & ") doesn't match array size (" & $(
N*2) & ")") N*2) & ")")
var result: array[N, byte]
for i in 0..<N: for i in 0..<N:
result[i] = byte(parseHexInt(hexStr[2*i .. 2*i+1])) result[i] = byte(parseHexInt(hexStr[2*i .. 2*i+1]))
@ -46,9 +45,9 @@ func loadTestKeys() : (array[32,byte],array[32,byte],array[32,byte],array[32,byt
(a_priv, a_pub, b_priv, b_pub) (a_priv, a_pub, b_priv, b_pub)
func createTestInstances(a: array[32, byte], apub: array[32, byte], b: array[32, byte], bpub: array[32, byte],sk: array[32, byte]) : (Doubleratchet, Doubleratchet) = proc createTestInstances(b: array[32, byte], bpub: array[32, byte],sk: array[32, byte]) : (Doubleratchet, Doubleratchet) =
let adr = initDoubleratchet(sk, a, bpub, true) let adr = initDoubleratchetSender(sk, bpub)
let bdr = initDoubleratchet(sk, b, apub, false) let bdr = initDoubleratchetRecipient(sk, b)
(adr,bdr) (adr,bdr)
@ -60,9 +59,8 @@ suite "Doubleratchet":
let sk = hexToArray[32](ks7748_shared_key) let sk = hexToArray[32](ks7748_shared_key)
var adr = initDoubleratchet(sk, a_priv, b_pub, true) var (adr, bdr) = createTestInstances(b_priv, b_pub, sk)
var bdr = initDoubleratchet(sk, b_priv, a_pub, true)
var msg :seq[byte] = @[1,2,3,4,5,6,7,8,9,10] var msg :seq[byte] = @[1,2,3,4,5,6,7,8,9,10]
let (header, ciphertext) = adr.encrypt(msg) let (header, ciphertext) = adr.encrypt(msg)
@ -77,8 +75,7 @@ suite "Doubleratchet":
let sk = hexToArray[32](ks7748_shared_key) let sk = hexToArray[32](ks7748_shared_key)
var adr = initDoubleratchet(sk, a_priv, b_pub, true) var (adr, bdr) = createTestInstances(b_priv, b_pub, sk)
var bdr = initDoubleratchet(sk, b_priv, a_pub, true)
var msg0 :seq[byte] = @[1,2,3,4,5,6,7,8,9,10] var msg0 :seq[byte] = @[1,2,3,4,5,6,7,8,9,10]
var msg1 :seq[byte] = @[6,7,8,9,10,1,2,3,4,5] var msg1 :seq[byte] = @[6,7,8,9,10,1,2,3,4,5]
@ -98,8 +95,7 @@ suite "Doubleratchet":
let sk = hexToArray[32](ks7748_shared_key) let sk = hexToArray[32](ks7748_shared_key)
var adr = initDoubleratchet(sk, a_priv, b_pub, true) var (adr, bdr) = createTestInstances(b_priv, b_pub, sk)
var bdr = initDoubleratchet(sk, b_priv, a_pub, true)
var msg : seq[ seq[byte]]= @[ var msg : seq[ seq[byte]]= @[
@[1,2,3,4,5,6,7,8,9,10], @[1,2,3,4,5,6,7,8,9,10],
@ -132,8 +128,7 @@ suite "Doubleratchet":
let (a_priv, a_pub, b_priv, b_pub) = loadTestKeys() let (a_priv, a_pub, b_priv, b_pub) = loadTestKeys()
let sk = hexToArray[32](ks7748_shared_key) let sk = hexToArray[32](ks7748_shared_key)
var adr = initDoubleratchet(sk, a_priv, b_pub, true) var (adr, bdr) = createTestInstances(b_priv, b_pub, sk)
var bdr = initDoubleratchet(sk, b_priv, a_pub, true)
var msg :seq[byte] = @[1,2,3,4,5,6,7,8,9,10] var msg :seq[byte] = @[1,2,3,4,5,6,7,8,9,10]
@ -150,8 +145,7 @@ suite "Doubleratchet":
let (a_priv, a_pub, b_priv, b_pub) = loadTestKeys() let (a_priv, a_pub, b_priv, b_pub) = loadTestKeys()
let sk = hexToArray[32](ks7748_shared_key) let sk = hexToArray[32](ks7748_shared_key)
var adr = initDoubleratchet(sk, a_priv, b_pub, true) var (adr, bdr) = createTestInstances(b_priv, b_pub, sk)
var bdr = initDoubleratchet(sk, b_priv, a_pub, true)
var msg :seq[byte] = @[1,2,3,4,5,6,7,8,9,10] var msg :seq[byte] = @[1,2,3,4,5,6,7,8,9,10]
@ -167,8 +161,7 @@ suite "Doubleratchet":
let sk = hexToArray[32](ks7748_shared_key) let sk = hexToArray[32](ks7748_shared_key)
var adr = initDoubleratchet(sk, a_priv, b_pub, true) var (adr, bdr) = createTestInstances(b_priv, b_pub, sk)
var bdr = initDoubleratchet(sk, b_priv, a_pub, true)
var last_dh_a : PublicKey var last_dh_a : PublicKey
var last_dh_b : PublicKey var last_dh_b : PublicKey