megolm/outbound_session.go

120 lines
3.3 KiB
Go

package main
import (
"crypto/aes"
"crypto/cipher"
"crypto/ed25519"
"encoding/binary"
"fmt"
)
// Sizes and tag bytes used by the outbound group session.
const (
	// Ed25519 sizes, in bytes.
	ED25519_RANDOM_LENGTH      = 32 // seed bytes consumed when deriving the signing key
	ED25519_PUBLIC_KEY_LENGTH  = 32
	ED25519_PRIVATE_KEY_LENGTH = 64
	ED25519_SIGNATURE_LENGTH   = 64

	// Field tags for the group message encoding.
	// NOTE(review): presumably protobuf-style tags (field number << 3 | wire
	// type) — confirm against the message format in use.
	GROUP_MESSAGE_INDEX_TAG = 0x08 // octal 010
	GROUP_CIPHERTEXT_TAG    = 0x12 // octal 022

	// MAC_LEN is the number of leading HMAC bytes kept in an encoded message.
	MAC_LEN = 8
)
// OutboundSession is the sending side of a group session. It pairs a Megolm
// ratchet, from which per-message keys are derived, with the Ed25519 key used
// to sign session keys and encrypted messages.
type OutboundSession struct {
// Ratchet supplies the key material and message counter; Encrypt advances
// it after every message.
Ratchet *Megolm
// SigningKey signs the output of GenerateSessionKey and Encrypt.
SigningKey *ed25519.PrivateKey
}
// NewOutboundSession builds an OutboundSession from caller-supplied key
// material.
//
// initialData must hold at least MEGOLM_RATCHET_LENGTH + ED25519_RANDOM_LENGTH
// bytes: the first MEGOLM_RATCHET_LENGTH bytes initialise the ratchet and the
// following ED25519_RANDOM_LENGTH bytes are the seed for the Ed25519 signing
// key. Any bytes beyond that are ignored. initialCounter is handed to
// NewMegolm unchanged.
//
// An error is returned when initialData is too short or when NewMegolm fails.
func NewOutboundSession(initialData []byte, initialCounter int) (*OutboundSession, error) {
	required := MEGOLM_RATCHET_LENGTH + ED25519_RANDOM_LENGTH
	if got := len(initialData); got < required {
		return nil, fmt.Errorf("initialData must be MEGOLM_RATCHET_LENGTH + ED25519_RANDOM_LENGTH = %d bytes. Got %d.", required, got)
	}

	ratchetSeed := initialData[:MEGOLM_RATCHET_LENGTH]
	signingSeed := initialData[MEGOLM_RATCHET_LENGTH:required]

	ratchet, err := NewMegolm(ratchetSeed, initialCounter)
	if err != nil {
		return nil, err
	}

	signingKey := ed25519.NewKeyFromSeed(signingSeed)
	session := &OutboundSession{
		Ratchet:    ratchet,
		SigningKey: &signingKey,
	}
	return session, nil
}
// GenerateSessionKey serialises the current state of the session into a
// signed session key.
//
// Layout of the returned slice:
//
//	1 byte                     SESSION_KEY_VERSION
//	4 bytes                    ratchet counter, big-endian
//	MEGOLM_RATCHET_LENGTH      ratchet data
//	ED25519_PUBLIC_KEY_LENGTH  public half of the signing key
//	ED25519_SIGNATURE_LENGTH   Ed25519 signature over all preceding bytes
//
// The session itself is not modified.
func (s *OutboundSession) GenerateSessionKey() []byte {
	ratchetOffset := 1 + 4 // version byte + 32-bit counter
	publicKeyOffset := ratchetOffset + MEGOLM_RATCHET_LENGTH
	signatureOffset := publicKeyOffset + ED25519_PUBLIC_KEY_LENGTH

	key := make([]byte, signatureOffset+ED25519_SIGNATURE_LENGTH)
	key[0] = SESSION_KEY_VERSION

	// The counter occupies exactly bytes 1..4. Write it with a single
	// big-endian store: emitting it byte-by-byte through PutUint32 would put
	// each byte in its own 4-byte word and leave only the most significant
	// byte (usually zero) inside the counter field.
	binary.BigEndian.PutUint32(key[1:ratchetOffset], uint32(s.Ratchet.counter))

	copy(key[ratchetOffset:], s.Ratchet.Data())
	copy(key[publicKeyOffset:], s.SigningKey.Public().(ed25519.PublicKey))

	// Sign everything written so far and append the signature.
	sig := ed25519.Sign(*s.SigningKey, key[:signatureOffset])
	copy(key[signatureOffset:], sig)
	return key
}
// Encrypt encrypts plaintext with keys derived from the current ratchet state
// and returns the encoded message.
//
// The output is, in order: the protocol version, the message-index tag and
// the ratchet counter, the ciphertext tag and the length-prefixed AES-CBC
// ciphertext, the first MAC_LEN bytes of an HMAC-SHA256 over everything before
// it, and an Ed25519 signature over everything before the signature. All
// integers are written as unsigned varints.
//
// On success the ratchet is advanced, so the next call uses fresh keys. If
// key derivation, cipher construction or padding fails, the error is returned
// and the ratchet is left untouched.
//
// NOTE(review): this file declares GROUP_MESSAGE_INDEX_TAG and
// GROUP_CIPHERTEXT_TAG, but the tags written here are MESSAGE_INDEX_TAG and
// MESSAGE_CIPHERTEXT_TAG, which are defined elsewhere — confirm they carry the
// intended values.
func (s *OutboundSession) Encrypt(plaintext []byte) ([]byte, error) {
	keys, err := deriveKeys(s.Ratchet.Data(), []byte{}, MEGOLM_KDF_INFO)
	if err != nil {
		return nil, err
	}

	// The ciphertext is the plaintext length rounded up to the next multiple
	// of the block size; an already aligned plaintext gains one full block.
	paddedLen := (len(plaintext)/aes.BlockSize + 1) * aes.BlockSize
	messageIndex := uint64(s.Ratchet.counter)

	// Work out the exact size of the encoded message up front.
	size := uvarintLen(OLM_PROTOCOL_VERSION)
	size += 1 + uvarintLen(messageIndex)         // index tag + index
	size += 1 + varStringLen(uint64(paddedLen)) // ciphertext tag + length-prefixed ciphertext
	size += uint64(MAC_LEN) + ED25519_SIGNATURE_LENGTH
	out := make([]byte, size)

	// Header: every field up to and including the ciphertext length is a varint.
	offset := 0
	for _, field := range []uint64{
		OLM_PROTOCOL_VERSION,
		MESSAGE_INDEX_TAG,
		messageIndex,
		MESSAGE_CIPHERTEXT_TAG,
		uint64(paddedLen),
	} {
		offset += binary.PutUvarint(out[offset:], field)
	}

	blockCipher, err := aes.NewCipher(keys.aesKey)
	if err != nil {
		return nil, err
	}
	padded, err := pkcs7Pad(plaintext, aes.BlockSize)
	if err != nil {
		return nil, err
	}

	ciphertext := make([]byte, paddedLen)
	cipher.NewCBCEncrypter(blockCipher, keys.aesIV).CryptBlocks(ciphertext, padded)
	copy(out[offset:], ciphertext)
	offset += len(ciphertext)

	// Truncated MAC over header and ciphertext.
	tag := HMACSHA256(keys.macKey, out[:offset])
	macEnd := offset + MAC_LEN
	copy(out[offset:macEnd], tag[:MAC_LEN])

	// The signature covers everything, including the MAC.
	signature := ed25519.Sign(*s.SigningKey, out[:macEnd])
	copy(out[macEnd:], signature)

	s.Ratchet.Advance()
	return out, nil
}