From 636779586288b27d64e35e59775ba92904100ee7 Mon Sep 17 00:00:00 2001 From: Andrea Franz Date: Wed, 30 Jun 2021 11:56:52 +0200 Subject: [PATCH] import initial implementation --- go.mod | 5 ++ go.sum | 7 ++ inbound_session.go | 186 +++++++++++++++++++++++++++++++++++++++ inbound_session_test.go | 33 +++++++ megolm.go | 125 ++++++++++++++++++++++++++ megolm_test.go | 116 ++++++++++++++++++++++++ outbound_session.go | 119 +++++++++++++++++++++++++ outbound_session_test.go | 39 ++++++++ utils.go | 114 ++++++++++++++++++++++++ 9 files changed, 744 insertions(+) create mode 100644 go.mod create mode 100644 go.sum create mode 100644 inbound_session.go create mode 100644 inbound_session_test.go create mode 100644 megolm.go create mode 100644 megolm_test.go create mode 100644 outbound_session.go create mode 100644 outbound_session_test.go create mode 100644 utils.go diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..13337e3 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module github.com/status-im/megolm + +go 1.16 + +require golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..b4f9bee --- /dev/null +++ b/go.sum @@ -0,0 +1,7 @@ +golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a h1:kr2P4QFmQr29mSLA43kwrOcgcReGTfbE9N577tCTuBc= +golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/inbound_session.go b/inbound_session.go new file mode 100644 index 0000000..a314474 --- /dev/null +++ b/inbound_session.go @@ -0,0 +1,186 @@ +package main + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/ed25519" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" +) + +const ( + AES256_KEY_LENGTH = 32 + AES256_IV_LENGTH = 16 + HMAC_KEY_LENGTH = 32 + HKDF_DEFAULT_SALT_LEN = 32 + SESSION_SHARING_LENGTH = 229 + SESSION_KEY_VERSION = 2 + + MESSAGE_INDEX_TAG = 0x08 + MESSAGE_CIPHERTEXT_TAG = 0x12 +) + +var MEGOLM_KDF_INFO = []byte("MEGOLM_KEYS") + +type Message struct { + Version byte + Payload []byte + Index int + Plaintext []byte + MAC []byte + Signature []byte +} + +type DerivedKeys struct { + aesKey []byte + macKey []byte + aesIV []byte +} + +type InboundSession struct { + Version byte + SigningPubKey []byte + InitialRatchet *Megolm + LatestRatchet *Megolm +} + +func NewInboundSession(encodedSession string) (*InboundSession, error) { + session, err := base64.StdEncoding.DecodeString(encodedSession) + if err != nil { + return nil, err + } + + if len(session) != SESSION_SHARING_LENGTH { + return nil, errors.New("bad session length") + } + + version := session[0] + index := int(binary.BigEndian.Uint32(session[1:5])) + ratchets := session[5:133] + signingPubKey := session[133:165] + signature := session[165:229] + + initialRatchet, err := NewMegolm(ratchets, index) + if err != nil { + return nil, err + } + + latestRatchet, err := NewMegolm(ratchets, index) + if err != nil { + return nil, err + } + + if !ed25519.Verify(signingPubKey, session[0:165], signature) { + return nil, errors.New("bad signature") + } + + return &InboundSession{ + Version: version, + SigningPubKey: signingPubKey, + InitialRatchet: initialRatchet, + LatestRatchet: latestRatchet, + }, nil +} + +func (s *InboundSession) Decrypt(rawMessage string) (*Message, error) { + decoded, err := base64.StdEncoding.DecodeString(rawMessage) + if err != nil { + return nil, err + } + + // 1 (version) + n (payload) + 8 (MAC) + 72 (signature) + // TODO: it should be min 81 + min message len + if len(decoded) < 81 { + return nil, errors.New("message too short") + } + + // TODO: check max length + msgLen := len(decoded) + + version := decoded[0] + payload := decoded[1 : msgLen-72] + mac := decoded[msgLen-72 : msgLen] + signature := decoded[msgLen-64 : msgLen] + + index, ciphertext, err := decodeMessage(payload) + if err != nil { + return nil, err + } + + message := decoded[0 : msgLen-64] + if !ed25519.Verify(s.SigningPubKey, message, signature) { + return nil, errors.New("bad signature") + } + + // TODO: advance ratchet to index + derivedKeys, err := deriveKeys(s.LatestRatchet.Data(), []byte{}, MEGOLM_KDF_INFO) + if err != nil { + return nil, err + } + + // TODO: check mac + block, err := aes.NewCipher(derivedKeys.aesKey) + if err != nil { + return nil, err + } + + mode := cipher.NewCBCDecrypter(block, derivedKeys.aesIV) + mode.CryptBlocks(ciphertext, ciphertext) + ciphertext, _ = pkcs7Unpad(ciphertext, aes.BlockSize) + + return &Message{ + Version: version, + Payload: payload, + Index: index, + Plaintext: ciphertext, + MAC: mac, + Signature: signature, + }, nil +} + +//TODO: check megolm encoding implementation +func decodeMessage(payload []byte) (int, []byte, error) { + buf := bytes.NewBuffer(payload) + tag, err := binary.ReadUvarint(buf) + if err != nil { + return 0, nil, err + } + + if tag != MESSAGE_INDEX_TAG { + return 0, nil, fmt.Errorf("expected tag %x, got %x", MESSAGE_INDEX_TAG, tag) + } + + index, err := binary.ReadUvarint(buf) + if err != nil { + return 0, nil, err + } + + tag, err = binary.ReadUvarint(buf) + if err != nil { + return 0, nil, err + } + + if tag != MESSAGE_CIPHERTEXT_TAG { + return 0, nil, fmt.Errorf("expected tag %x, got %x", MESSAGE_CIPHERTEXT_TAG, tag) + } + + length, err := binary.ReadUvarint(buf) + if err != nil { + return 0, nil, err + } + + ciphertext := make([]byte, length) + n, err := buf.Read(ciphertext) + if err != nil { + return 0, nil, err + } + + if n != int(length) { + return 0, nil, fmt.Errorf("expected cipertext length %d, got %d", length, n) + } + + return int(index), ciphertext, nil +} diff --git a/inbound_session_test.go b/inbound_session_test.go new file mode 100644 index 0000000..5ba1b2f --- /dev/null +++ b/inbound_session_test.go @@ -0,0 +1,33 @@ +package main + +import ( + "bytes" + "log" + "testing" +) + +func TestInboundSession(t *testing.T) { + sessionKey := "AgAAAAAwMTIzNDU2Nzg5QUJERUYwMTIzNDU2Nzg5QUJDREVGMDEyMzQ1Njc4OUFCREVGM" + + "DEyMzQ1Njc4OUFCQ0RFRjAxMjM0NTY3ODlBQkRFRjAxMjM0NTY3ODlBQkNERUYwMTIzND" + + "U2Nzg5QUJERUYwMTIzNDU2Nzg5QUJDREVGMDEyMztqJ7zOtqQtYqOo0CpvDXNlMhV3HeJ" + + "DpjrASKGLWdop4lx1cSN3Xv1TgfLPW8rhGiW+hHiMxd36nRuxscNv9k4oJA/KP+o0mi1w" + + "v44StrEJ1wwx9WZHBUIWkQbaBSuBDw==" + + message := "AwgAEhAcbh6UpbByoyZxufQ+h2B+8XHMjhR69G8nP4pNZGl/3QMgrzCZPmP+F2aPLyKPz" + + "xRPBMUkeXRJ6Iqm5NeOdx2eERgTW7P20CM+lL3Xpk+ZUOOPvsSQNaAL" + + s, err := NewInboundSession(sessionKey) + if err != nil { + log.Fatal(err) + } + + m, err := s.Decrypt(message) + if err != nil { + log.Fatal(err) + } + + expectedMessage := []byte("Message") + if !bytes.Equal(m.Plaintext, expectedMessage) { + t.Fatalf("expected message to be `%s`, got `%s`", expectedMessage, m.Plaintext) + } +} diff --git a/megolm.go b/megolm.go new file mode 100644 index 0000000..cb6a140 --- /dev/null +++ b/megolm.go @@ -0,0 +1,125 @@ +package main + +import ( + "fmt" +) + +const ( + SHA256_BLOCK_LENGTH = 64 + SHA256_OUTPUT_LENGTH = 32 + MEGOLM_RATCHET_PART_LENGTH = 32 + MEGOLM_RATCHET_PARTS = 4 + MEGOLM_RATCHET_LENGTH = (MEGOLM_RATCHET_PARTS * MEGOLM_RATCHET_PART_LENGTH) + OLM_PROTOCOL_VERSION = 3 +) + +var HASH_KEY_SEEDS = [MEGOLM_RATCHET_PARTS]byte{0x00, 0x01, 0x02, 0x03} + +type Megolm struct { + data [][]byte + counter int +} + +func NewMegolm(initialData []byte, initialCounter int) (*Megolm, error) { + if len(initialData) != MEGOLM_RATCHET_LENGTH { + return nil, fmt.Errorf("megolm initial data must be %d bytes. Got %d.", MEGOLM_RATCHET_LENGTH, len(initialData)) + } + + data := make([][]byte, MEGOLM_RATCHET_PARTS) + for i := 0; i < MEGOLM_RATCHET_PARTS; i++ { + data[i] = make([]byte, MEGOLM_RATCHET_PART_LENGTH) + start := i * MEGOLM_RATCHET_PART_LENGTH + copy(data[i], initialData[start:start+MEGOLM_RATCHET_PART_LENGTH]) + } + + return &Megolm{ + data: data, + counter: initialCounter, + }, nil +} + +func (m *Megolm) Data() []byte { + data := make([]byte, MEGOLM_RATCHET_LENGTH) + for i := 0; i < MEGOLM_RATCHET_PARTS; i++ { + start := i * MEGOLM_RATCHET_PART_LENGTH + copy(data[start:start+MEGOLM_RATCHET_PART_LENGTH], m.data[i]) + } + + return data +} + +func (m *Megolm) Advance() { + mask := 0x00FFFFFF + h := 0 + m.counter++ + + /* figure out how much we need to rekey */ + for h < MEGOLM_RATCHET_PARTS { + if m.counter&mask == 0 { + break + } + + h++ + mask >>= 8 + } + + // update R[h:3] based on h + for i := MEGOLM_RATCHET_PARTS - 1; i >= h; i-- { + m.rehashPart(h, i) + } +} + +func (m *Megolm) AdvanceTo(advanceTo int) { + /* starting with R0, see if we need to update each part of the hash */ + for j := 0; j < MEGOLM_RATCHET_PARTS; j++ { + shift := (MEGOLM_RATCHET_PARTS - j - 1) * 8 + mask := 0xffffffff << shift + + /* how many times do we need to rehash this part? + * + * '& 0xff' ensures we handle integer wraparound correctly + */ + steps := + ((advanceTo >> shift) - (m.counter >> shift)) & 0xff + + if steps == 0 { + /* deal with the edge case where megolm->counter is slightly larger + * than advanceTo. This should only happen for R(0), and implies + * that advanceTo has wrapped around and we need to advance R(0) + * 256 times. + */ + if advanceTo < m.counter { + steps = 0x100 + } else { + continue + } + } + + /* for all but the last step, we can just bump R(j) without regard + * to R(j+1)...R(3). + */ + for steps > 1 { + m.rehashPart(j, j) + steps-- + } + + /* on the last step we also need to bump R(j+1)...R(3). + * + * (Theoretically, we could skip bumping R(j+2) if we're going to bump + * R(j+1) again, but the code to figure that out is a bit baroque and + * doesn't save us much). + */ + for k := 3; k >= j; k-- { + m.rehashPart(j, k) + } + m.counter = advanceTo & mask + } +} + +func (m *Megolm) rehashPart(fromPart, toPart int) { + newPart := HMACSHA256( + m.data[fromPart], + []byte{HASH_KEY_SEEDS[toPart]}, + ) + m.data[toPart] = newPart +} diff --git a/megolm_test.go b/megolm_test.go new file mode 100644 index 0000000..d4a15e5 --- /dev/null +++ b/megolm_test.go @@ -0,0 +1,116 @@ +package main + +import ( + "bytes" + "encoding/hex" + "log" + "testing" +) + +func TestAdvance(t *testing.T) { + randomData := []byte("0123456789ABCDEF0123456789ABCDEF" + + "0123456789ABCDEF0123456789ABCDEF" + + "0123456789ABCDEF0123456789ABCDEF" + + "0123456789ABCDEF0123456789ABCDEF") + + m, err := NewMegolm(randomData, 0) + if err != nil { + t.Fatal(err) + } + + m.Advance() + index := 1 + expected1 := [][]byte{ + []byte{0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46}, + []byte{0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46}, + []byte{0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46}, + []byte{0xba, 0x9c, 0xd9, 0x55, 0x74, 0x1d, 0x1c, 0x16, 0x23, 0x23, 0xec, 0x82, 0x5e, 0x7c, 0x5c, 0xe8, + 0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a}, + } + + if m.counter != index { + t.Errorf("expected counter to be: %#x, got %#x", index, m.counter) + } + + for i, _ := range m.data { + if !bytes.Equal(m.data[i], expected1[i]) { + t.Errorf("expected part 1 to be: %x, got %x", m.data[i], expected1[i]) + } + } + + index = 0x1000000 + m.AdvanceTo(index) + expected2 := [][]byte{ + []byte{0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9, + 0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c}, + []byte{0x70, 0x04, 0xc0, 0x1e, 0xe4, 0x9b, 0xd6, 0xef, 0xe0, 0x07, 0x35, 0x25, 0xaf, 0x9b, 0x16, 0x32, + 0xc5, 0xbe, 0x72, 0x6d, 0x12, 0x34, 0x9c, 0xc5, 0xbd, 0x47, 0x2b, 0xdc, 0x2d, 0xf6, 0x54, 0x0f}, + []byte{0x31, 0x12, 0x59, 0x11, 0x94, 0xfd, 0xa6, 0x17, 0xe5, 0x68, 0xc6, 0x83, 0x10, 0x1e, 0xae, 0xcd, + 0x7e, 0xdd, 0xd6, 0xde, 0x1f, 0xbc, 0x07, 0x67, 0xae, 0x34, 0xda, 0x1a, 0x09, 0xa5, 0x4e, 0xab}, + []byte{0xba, 0x9c, 0xd9, 0x55, 0x74, 0x1d, 0x1c, 0x16, 0x23, 0x23, 0xec, 0x82, 0x5e, 0x7c, 0x5c, 0xe8, + 0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a}, + } + + if m.counter != index { + t.Errorf("expected counter to be: %#x, got %#x", index, m.counter) + } + + for i, _ := range m.data { + if !bytes.Equal(m.data[i], expected2[i]) { + t.Errorf("expected part 1 to be: %x, got %x", m.data[i], expected1[i]) + } + } + + index = 0x1041506 + m.AdvanceTo(index) + expected3 := [][]byte{ + []byte{0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9, + 0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c}, + []byte{0x55, 0x58, 0x8d, 0xf5, 0xb7, 0xa4, 0x88, 0x78, 0x42, 0x89, 0x27, 0x86, 0x81, 0x64, 0x58, 0x9f, + 0x36, 0x63, 0x44, 0x7b, 0x51, 0xed, 0xc3, 0x59, 0x5b, 0x03, 0x6c, 0xa6, 0x04, 0xc4, 0x6d, 0xcd}, + []byte{0x5c, 0x54, 0x85, 0x0b, 0xfa, 0x98, 0xa1, 0xfd, 0x79, 0xa9, 0xdf, 0x1c, 0xbe, 0x8f, 0xc5, 0x68, + 0x19, 0x37, 0xd3, 0x0c, 0x85, 0xc8, 0xc3, 0x1f, 0x7b, 0xb8, 0x28, 0x81, 0x6c, 0xf9, 0xff, 0x3b}, + []byte{0x95, 0x6c, 0xbf, 0x80, 0x7e, 0x65, 0x12, 0x6a, 0x49, 0x55, 0x8d, 0x45, 0xc8, 0x4a, 0x2e, 0x4c, + 0xd5, 0x6f, 0x03, 0xe2, 0x44, 0x16, 0xb9, 0x8e, 0x1c, 0xfd, 0x97, 0xc2, 0x06, 0xaa, 0x90, 0x7a}, + } + + if m.counter != index { + t.Errorf("expected counter to be: %#x, got %#x", index, m.counter) + } + + for i, _ := range m.data { + if !bytes.Equal(m.data[i], expected3[i]) { + t.Errorf("expected part 1 to be: %x, got %x", m.data[i], expected1[i]) + } + } +} + +func decodeHEX(s string) []byte { + decoded, err := hex.DecodeString(s) + if err != nil { + log.Fatal(err) + } + return decoded +} + +func TestDeriveKeys(t *testing.T) { + inputKeys := decodeHEX("3031323334353637383941424445463031323334353637383941424344454630313233343536373839414244454630313233343536373839414243444546303132333435363738394142444546303132333435363738394142434445463031323334353637383941424445463031323334353637383941424344454630313233") + info := decodeHEX("4d45474f4c4d5f4b455953") + expected := decodeHEX("bdf842da829272e777e3c6ba70bb2f6e976e32330691e4dc3ca3dbbd43c62c0f4075ce27caf907395c3287c1971d95c58909ce9f810c1ee79fef7eefa10c46dc3694fc16cbd3a8b16b66cbc897121c6b") + + keys, err := deriveKeys(inputKeys, []byte{}, info) + if err != nil { + t.Fatal(err) + } + + res := keys.aesKey + res = append(res, keys.macKey...) + res = append(res, keys.aesIV...) + + if !bytes.Equal(expected, res) { + t.Errorf("expected %x, got %x", expected, res) + } +} diff --git a/outbound_session.go b/outbound_session.go new file mode 100644 index 0000000..90c9c6f --- /dev/null +++ b/outbound_session.go @@ -0,0 +1,119 @@ +package main + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/ed25519" + "encoding/binary" + "fmt" +) + +const ( + ED25519_RANDOM_LENGTH = 32 + ED25519_PUBLIC_KEY_LENGTH = 32 + ED25519_PRIVATE_KEY_LENGTH = 64 + ED25519_SIGNATURE_LENGTH = 64 + GROUP_MESSAGE_INDEX_TAG = 010 + GROUP_CIPHERTEXT_TAG = 022 + MAC_LEN = 8 +) + +type OutboundSession struct { + Ratchet *Megolm + SigningKey *ed25519.PrivateKey +} + +func NewOutboundSession(initialData []byte, initialCounter int) (*OutboundSession, error) { + if len(initialData) < MEGOLM_RATCHET_LENGTH+ED25519_RANDOM_LENGTH { + return nil, fmt.Errorf("initialData must be MEGOLM_RATCHET_LENGTH + ED25519_RANDOM_LENGTH = %d bytes. Got %d.", MEGOLM_RATCHET_LENGTH+ED25519_RANDOM_LENGTH, len(initialData)) + } + + ratchet, err := NewMegolm(initialData[:MEGOLM_RATCHET_LENGTH], initialCounter) + if err != nil { + return nil, err + } + + seed := initialData[MEGOLM_RATCHET_LENGTH : MEGOLM_RATCHET_LENGTH+ED25519_RANDOM_LENGTH] + privKey := ed25519.NewKeyFromSeed(seed) + + return &OutboundSession{ + Ratchet: ratchet, + SigningKey: &privKey, + }, nil +} + +func (s *OutboundSession) GenerateSessionKey() []byte { + key := make([]byte, 1+4+MEGOLM_RATCHET_LENGTH+ED25519_PUBLIC_KEY_LENGTH+ED25519_SIGNATURE_LENGTH) + + key[0] = SESSION_KEY_VERSION + + counter := uint32(s.Ratchet.counter) + for i := 0; i < 4; i++ { + value := 0xFF & (counter >> 24) + binary.BigEndian.PutUint32(key[i*4+1:], value) + counter <<= 8 + } + + copy(key[5:], s.Ratchet.Data()) + copy(key[5+MEGOLM_RATCHET_LENGTH:], s.SigningKey.Public().(ed25519.PublicKey)) + + sig := ed25519.Sign(*s.SigningKey, key[0:5+MEGOLM_RATCHET_LENGTH+ED25519_PUBLIC_KEY_LENGTH]) + copy(key[5+MEGOLM_RATCHET_LENGTH+ED25519_PUBLIC_KEY_LENGTH:], sig) + + return key +} + +func (s *OutboundSession) Encrypt(plaintext []byte) ([]byte, error) { + derivedKeys, err := deriveKeys(s.Ratchet.Data(), []byte{}, MEGOLM_KDF_INFO) + if err != nil { + return nil, err + } + + ciphertextLen := len(plaintext) + aes.BlockSize - len(plaintext)%aes.BlockSize + ciphertextEncodedLen := varStringLen(uint64(ciphertextLen)) + encodedLen := + uvarintLen(OLM_PROTOCOL_VERSION) + + 1 + // MESSAGE_INDEX_TAG + uvarintLen(uint64(s.Ratchet.counter)) + + 1 + // MESSAGE_CIPHERTEXT_TAG + ciphertextEncodedLen + + uint64(MAC_LEN) + + ED25519_SIGNATURE_LENGTH + + encoded := make([]byte, encodedLen) + pos := 0 + pos += binary.PutUvarint(encoded[pos:], OLM_PROTOCOL_VERSION) + pos += binary.PutUvarint(encoded[pos:], MESSAGE_INDEX_TAG) + pos += binary.PutUvarint(encoded[pos:], uint64(s.Ratchet.counter)) + pos += binary.PutUvarint(encoded[pos:], MESSAGE_CIPHERTEXT_TAG) + pos += binary.PutUvarint(encoded[pos:], uint64(ciphertextLen)) + + block, err := aes.NewCipher(derivedKeys.aesKey) + if err != nil { + return nil, err + } + + paddedPlaintext, err := pkcs7Pad(plaintext, aes.BlockSize) + if err != nil { + return nil, err + } + + ciphertext := make([]byte, len(plaintext)+aes.BlockSize-len(plaintext)%aes.BlockSize) + + mode := cipher.NewCBCEncrypter(block, derivedKeys.aesIV) + mode.CryptBlocks(ciphertext, paddedPlaintext) + + copy(encoded[pos:], ciphertext) + pos += len(ciphertext) + + mac := HMACSHA256(derivedKeys.macKey, encoded[0:pos]) + copy(encoded[pos:pos+MAC_LEN], mac[0:MAC_LEN]) + pos += MAC_LEN + + sig := ed25519.Sign(*s.SigningKey, encoded[0:pos]) + copy(encoded[pos:], sig) + + s.Ratchet.Advance() + + return encoded, nil +} diff --git a/outbound_session_test.go b/outbound_session_test.go new file mode 100644 index 0000000..c0d715b --- /dev/null +++ b/outbound_session_test.go @@ -0,0 +1,39 @@ +package main + +import ( + "bytes" + "testing" +) + +func TestOutboundSession(t *testing.T) { + randomData := []byte( + "0123456789ABDEF0123456789ABCDEF" + + "0123456789ABDEF0123456789ABCDEF" + + "0123456789ABDEF0123456789ABCDEF" + + "0123456789ABDEF0123456789ABCDEF" + + "0123456789ABDEF0123456789ABCDEF" + + "0123456789ABDEF0123456789ABCDEF") + + o, err := NewOutboundSession(randomData, 0) + if err != nil { + t.Fatal(err) + } + + sessionKey := o.GenerateSessionKey() + + expectedSessionKey := decodeHEX("020000000030313233343536373839414244454630313233343536373839414243444546303132333435363738394142444546303132333435363738394142434445463031323334353637383941424445463031323334353637383941424344454630313233343536373839414244454630313233343536373839414243444546303132330d1b760d410eae0fc7fb25068c34eaaf27fc1f5605fc1663234e07c0e55265c61b0ac599a49bf5b47f798f3180446bc663f4ee11660fdbaa036a7c3cc1dd9d5e72b2682e7c4ea82d9e7ef3f99640e5b75554e70c8967c1df281d4611ba4f1309") + if !bytes.Equal(sessionKey, expectedSessionKey) { + t.Fatalf("expected %x, got %x", expectedSessionKey, sessionKey) + } + + plaintext := "Message" + encrypted, err := o.Encrypt([]byte(plaintext)) + if err != nil { + t.Fatal(err) + } + + expectedEncrytped := decodeHEX("03080012101c6e1e94a5b072a32671b9f43e87607ef171cc8e147af46f05e3eaa331a1659c85eeb09657debf0b9d13911bd4f70d715d6d5baf216acbc917e6e22f1045b7bc54fc064451b9ec1d42b66c5a9ee3fc6c7b679d8aa5c0fb03") + if !bytes.Equal(encrypted, expectedEncrytped) { + t.Fatalf("expected %x, got %x", expectedEncrytped, encrypted) + } +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..eca1e9e --- /dev/null +++ b/utils.go @@ -0,0 +1,114 @@ +package main + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "errors" + + "golang.org/x/crypto/hkdf" +) + +func deriveKeys(input []byte, salt []byte, info []byte) (*DerivedKeys, error) { + h := hkdf.New(sha256.New, input, salt, info) + + keys := &DerivedKeys{ + aesKey: make([]byte, AES256_KEY_LENGTH), + macKey: make([]byte, HMAC_KEY_LENGTH), + aesIV: make([]byte, AES256_IV_LENGTH), + } + + if _, err := h.Read(keys.aesKey); err != nil { + return nil, err + } + + if _, err := h.Read(keys.macKey); err != nil { + return nil, err + } + + if _, err := h.Read(keys.aesIV); err != nil { + return nil, err + } + + return keys, nil +} + +func pkcs7Pad(data []byte, blocksize int) ([]byte, error) { + if blocksize <= 0 { + return nil, errors.New("ErrInvalidBlockSize") + } + + if data == nil || len(data) == 0 { + return nil, errors.New("ErrInvalidPKCS7Data") + } + + padLen := blocksize - len(data)%blocksize + padding := bytes.Repeat([]byte{byte(padLen)}, padLen) + + return append(data, padding...), nil +} + +func pkcs7Unpad(data []byte, blocksize int) ([]byte, error) { + if blocksize <= 0 { + return nil, errors.New("ErrInvalidBlockSize") + } + + if data == nil || len(data) == 0 { + return nil, errors.New("ErrInvalidPKCS7Data") + } + + if len(data)%blocksize != 0 { + return nil, errors.New("ErrInvalidPKCS7Padding") + } + + c := data[len(data)-1] + n := int(c) + + if n == 0 || n > len(data) { + return nil, errors.New("ErrInvalidPKCS7Padding") + } + + for i := 0; i < n; i++ { + if data[len(data)-n+i] != c { + return nil, errors.New("ErrInvalidPKCS7Padding") + } + } + + return data[:len(data)-n], nil +} + +func generateHMACKey(inputKey []byte) []byte { + hmacKey := make([]byte, SHA256_BLOCK_LENGTH) + if len(inputKey) > SHA256_BLOCK_LENGTH { + // TODO: check this part + h := sha256.New() + h.Write(inputKey) + res := h.Sum(nil) + copy(hmacKey, res) + } else { + copy(hmacKey, inputKey) + } + + return hmacKey +} + +func HMACSHA256(key []byte, input []byte) []byte { + hmacKey := generateHMACKey(key) + ctx := hmac.New(sha256.New, hmacKey) + ctx.Write(input) + return ctx.Sum(nil) +} + +func uvarintLen(x uint64) uint64 { + var res uint64 = 1 + for x >= 0x80 { + res++ + x >>= 7 + } + + return res +} + +func varStringLen(n uint64) uint64 { + return uvarintLen(n) + n +}