mirror of https://github.com/status-im/megolm.git
import initial implementation
This commit is contained in:
commit
6367795862
|
@ -0,0 +1,5 @@
|
|||
module github.com/status-im/megolm
|
||||
|
||||
go 1.16
|
||||
|
||||
require golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a // indirect
|
|
@ -0,0 +1,7 @@
|
|||
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a h1:kr2P4QFmQr29mSLA43kwrOcgcReGTfbE9N577tCTuBc=
|
||||
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
|
@ -0,0 +1,186 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
AES256_KEY_LENGTH = 32
|
||||
AES256_IV_LENGTH = 16
|
||||
HMAC_KEY_LENGTH = 32
|
||||
HKDF_DEFAULT_SALT_LEN = 32
|
||||
SESSION_SHARING_LENGTH = 229
|
||||
SESSION_KEY_VERSION = 2
|
||||
|
||||
MESSAGE_INDEX_TAG = 0x08
|
||||
MESSAGE_CIPHERTEXT_TAG = 0x12
|
||||
)
|
||||
|
||||
var MEGOLM_KDF_INFO = []byte("MEGOLM_KEYS")
|
||||
|
||||
type Message struct {
|
||||
Version byte
|
||||
Payload []byte
|
||||
Index int
|
||||
Plaintext []byte
|
||||
MAC []byte
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
type DerivedKeys struct {
|
||||
aesKey []byte
|
||||
macKey []byte
|
||||
aesIV []byte
|
||||
}
|
||||
|
||||
type InboundSession struct {
|
||||
Version byte
|
||||
SigningPubKey []byte
|
||||
InitialRatchet *Megolm
|
||||
LatestRatchet *Megolm
|
||||
}
|
||||
|
||||
func NewInboundSession(encodedSession string) (*InboundSession, error) {
|
||||
session, err := base64.StdEncoding.DecodeString(encodedSession)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(session) != SESSION_SHARING_LENGTH {
|
||||
return nil, errors.New("bad session length")
|
||||
}
|
||||
|
||||
version := session[0]
|
||||
index := int(binary.BigEndian.Uint32(session[1:5]))
|
||||
ratchets := session[5:133]
|
||||
signingPubKey := session[133:165]
|
||||
signature := session[165:229]
|
||||
|
||||
initialRatchet, err := NewMegolm(ratchets, index)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
latestRatchet, err := NewMegolm(ratchets, index)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !ed25519.Verify(signingPubKey, session[0:165], signature) {
|
||||
return nil, errors.New("bad signature")
|
||||
}
|
||||
|
||||
return &InboundSession{
|
||||
Version: version,
|
||||
SigningPubKey: signingPubKey,
|
||||
InitialRatchet: initialRatchet,
|
||||
LatestRatchet: latestRatchet,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *InboundSession) Decrypt(rawMessage string) (*Message, error) {
|
||||
decoded, err := base64.StdEncoding.DecodeString(rawMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 1 (version) + n (payload) + 8 (MAC) + 72 (signature)
|
||||
// TODO: it should be min 81 + min message len
|
||||
if len(decoded) < 81 {
|
||||
return nil, errors.New("message too short")
|
||||
}
|
||||
|
||||
// TODO: check max length
|
||||
msgLen := len(decoded)
|
||||
|
||||
version := decoded[0]
|
||||
payload := decoded[1 : msgLen-72]
|
||||
mac := decoded[msgLen-72 : msgLen]
|
||||
signature := decoded[msgLen-64 : msgLen]
|
||||
|
||||
index, ciphertext, err := decodeMessage(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
message := decoded[0 : msgLen-64]
|
||||
if !ed25519.Verify(s.SigningPubKey, message, signature) {
|
||||
return nil, errors.New("bad signature")
|
||||
}
|
||||
|
||||
// TODO: advance ratchet to index
|
||||
derivedKeys, err := deriveKeys(s.LatestRatchet.Data(), []byte{}, MEGOLM_KDF_INFO)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: check mac
|
||||
block, err := aes.NewCipher(derivedKeys.aesKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mode := cipher.NewCBCDecrypter(block, derivedKeys.aesIV)
|
||||
mode.CryptBlocks(ciphertext, ciphertext)
|
||||
ciphertext, _ = pkcs7Unpad(ciphertext, aes.BlockSize)
|
||||
|
||||
return &Message{
|
||||
Version: version,
|
||||
Payload: payload,
|
||||
Index: index,
|
||||
Plaintext: ciphertext,
|
||||
MAC: mac,
|
||||
Signature: signature,
|
||||
}, nil
|
||||
}
|
||||
|
||||
//TODO: check megolm encoding implementation
|
||||
func decodeMessage(payload []byte) (int, []byte, error) {
|
||||
buf := bytes.NewBuffer(payload)
|
||||
tag, err := binary.ReadUvarint(buf)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if tag != MESSAGE_INDEX_TAG {
|
||||
return 0, nil, fmt.Errorf("expected tag %x, got %x", MESSAGE_INDEX_TAG, tag)
|
||||
}
|
||||
|
||||
index, err := binary.ReadUvarint(buf)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
tag, err = binary.ReadUvarint(buf)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if tag != MESSAGE_CIPHERTEXT_TAG {
|
||||
return 0, nil, fmt.Errorf("expected tag %x, got %x", MESSAGE_CIPHERTEXT_TAG, tag)
|
||||
}
|
||||
|
||||
length, err := binary.ReadUvarint(buf)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
ciphertext := make([]byte, length)
|
||||
n, err := buf.Read(ciphertext)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if n != int(length) {
|
||||
return 0, nil, fmt.Errorf("expected cipertext length %d, got %d", length, n)
|
||||
}
|
||||
|
||||
return int(index), ciphertext, nil
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInboundSession(t *testing.T) {
|
||||
sessionKey := "AgAAAAAwMTIzNDU2Nzg5QUJERUYwMTIzNDU2Nzg5QUJDREVGMDEyMzQ1Njc4OUFCREVGM" +
|
||||
"DEyMzQ1Njc4OUFCQ0RFRjAxMjM0NTY3ODlBQkRFRjAxMjM0NTY3ODlBQkNERUYwMTIzND" +
|
||||
"U2Nzg5QUJERUYwMTIzNDU2Nzg5QUJDREVGMDEyMztqJ7zOtqQtYqOo0CpvDXNlMhV3HeJ" +
|
||||
"DpjrASKGLWdop4lx1cSN3Xv1TgfLPW8rhGiW+hHiMxd36nRuxscNv9k4oJA/KP+o0mi1w" +
|
||||
"v44StrEJ1wwx9WZHBUIWkQbaBSuBDw=="
|
||||
|
||||
message := "AwgAEhAcbh6UpbByoyZxufQ+h2B+8XHMjhR69G8nP4pNZGl/3QMgrzCZPmP+F2aPLyKPz" +
|
||||
"xRPBMUkeXRJ6Iqm5NeOdx2eERgTW7P20CM+lL3Xpk+ZUOOPvsSQNaAL"
|
||||
|
||||
s, err := NewInboundSession(sessionKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
m, err := s.Decrypt(message)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
expectedMessage := []byte("Message")
|
||||
if !bytes.Equal(m.Plaintext, expectedMessage) {
|
||||
t.Fatalf("expected message to be `%s`, got `%s`", expectedMessage, m.Plaintext)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,125 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
SHA256_BLOCK_LENGTH = 64
|
||||
SHA256_OUTPUT_LENGTH = 32
|
||||
MEGOLM_RATCHET_PART_LENGTH = 32
|
||||
MEGOLM_RATCHET_PARTS = 4
|
||||
MEGOLM_RATCHET_LENGTH = (MEGOLM_RATCHET_PARTS * MEGOLM_RATCHET_PART_LENGTH)
|
||||
OLM_PROTOCOL_VERSION = 3
|
||||
)
|
||||
|
||||
var HASH_KEY_SEEDS = [MEGOLM_RATCHET_PARTS]byte{0x00, 0x01, 0x02, 0x03}
|
||||
|
||||
type Megolm struct {
|
||||
data [][]byte
|
||||
counter int
|
||||
}
|
||||
|
||||
func NewMegolm(initialData []byte, initialCounter int) (*Megolm, error) {
|
||||
if len(initialData) != MEGOLM_RATCHET_LENGTH {
|
||||
return nil, fmt.Errorf("megolm initial data must be %d bytes. Got %d.", MEGOLM_RATCHET_LENGTH, len(initialData))
|
||||
}
|
||||
|
||||
data := make([][]byte, MEGOLM_RATCHET_PARTS)
|
||||
for i := 0; i < MEGOLM_RATCHET_PARTS; i++ {
|
||||
data[i] = make([]byte, MEGOLM_RATCHET_PART_LENGTH)
|
||||
start := i * MEGOLM_RATCHET_PART_LENGTH
|
||||
copy(data[i], initialData[start:start+MEGOLM_RATCHET_PART_LENGTH])
|
||||
}
|
||||
|
||||
return &Megolm{
|
||||
data: data,
|
||||
counter: initialCounter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Megolm) Data() []byte {
|
||||
data := make([]byte, MEGOLM_RATCHET_LENGTH)
|
||||
for i := 0; i < MEGOLM_RATCHET_PARTS; i++ {
|
||||
start := i * MEGOLM_RATCHET_PART_LENGTH
|
||||
copy(data[start:start+MEGOLM_RATCHET_PART_LENGTH], m.data[i])
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func (m *Megolm) Advance() {
|
||||
mask := 0x00FFFFFF
|
||||
h := 0
|
||||
m.counter++
|
||||
|
||||
/* figure out how much we need to rekey */
|
||||
for h < MEGOLM_RATCHET_PARTS {
|
||||
if m.counter&mask == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
h++
|
||||
mask >>= 8
|
||||
}
|
||||
|
||||
// update R[h:3] based on h
|
||||
for i := MEGOLM_RATCHET_PARTS - 1; i >= h; i-- {
|
||||
m.rehashPart(h, i)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Megolm) AdvanceTo(advanceTo int) {
|
||||
/* starting with R0, see if we need to update each part of the hash */
|
||||
for j := 0; j < MEGOLM_RATCHET_PARTS; j++ {
|
||||
shift := (MEGOLM_RATCHET_PARTS - j - 1) * 8
|
||||
mask := 0xffffffff << shift
|
||||
|
||||
/* how many times do we need to rehash this part?
|
||||
*
|
||||
* '& 0xff' ensures we handle integer wraparound correctly
|
||||
*/
|
||||
steps :=
|
||||
((advanceTo >> shift) - (m.counter >> shift)) & 0xff
|
||||
|
||||
if steps == 0 {
|
||||
/* deal with the edge case where megolm->counter is slightly larger
|
||||
* than advanceTo. This should only happen for R(0), and implies
|
||||
* that advanceTo has wrapped around and we need to advance R(0)
|
||||
* 256 times.
|
||||
*/
|
||||
if advanceTo < m.counter {
|
||||
steps = 0x100
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
/* for all but the last step, we can just bump R(j) without regard
|
||||
* to R(j+1)...R(3).
|
||||
*/
|
||||
for steps > 1 {
|
||||
m.rehashPart(j, j)
|
||||
steps--
|
||||
}
|
||||
|
||||
/* on the last step we also need to bump R(j+1)...R(3).
|
||||
*
|
||||
* (Theoretically, we could skip bumping R(j+2) if we're going to bump
|
||||
* R(j+1) again, but the code to figure that out is a bit baroque and
|
||||
* doesn't save us much).
|
||||
*/
|
||||
for k := 3; k >= j; k-- {
|
||||
m.rehashPart(j, k)
|
||||
}
|
||||
m.counter = advanceTo & mask
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Megolm) rehashPart(fromPart, toPart int) {
|
||||
newPart := HMACSHA256(
|
||||
m.data[fromPart],
|
||||
[]byte{HASH_KEY_SEEDS[toPart]},
|
||||
)
|
||||
m.data[toPart] = newPart
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAdvance(t *testing.T) {
|
||||
randomData := []byte("0123456789ABCDEF0123456789ABCDEF" +
|
||||
"0123456789ABCDEF0123456789ABCDEF" +
|
||||
"0123456789ABCDEF0123456789ABCDEF" +
|
||||
"0123456789ABCDEF0123456789ABCDEF")
|
||||
|
||||
m, err := NewMegolm(randomData, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m.Advance()
|
||||
index := 1
|
||||
expected1 := [][]byte{
|
||||
[]byte{0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46,
|
||||
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46},
|
||||
[]byte{0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46,
|
||||
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46},
|
||||
[]byte{0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46,
|
||||
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46},
|
||||
[]byte{0xba, 0x9c, 0xd9, 0x55, 0x74, 0x1d, 0x1c, 0x16, 0x23, 0x23, 0xec, 0x82, 0x5e, 0x7c, 0x5c, 0xe8,
|
||||
0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a},
|
||||
}
|
||||
|
||||
if m.counter != index {
|
||||
t.Errorf("expected counter to be: %#x, got %#x", index, m.counter)
|
||||
}
|
||||
|
||||
for i, _ := range m.data {
|
||||
if !bytes.Equal(m.data[i], expected1[i]) {
|
||||
t.Errorf("expected part 1 to be: %x, got %x", m.data[i], expected1[i])
|
||||
}
|
||||
}
|
||||
|
||||
index = 0x1000000
|
||||
m.AdvanceTo(index)
|
||||
expected2 := [][]byte{
|
||||
[]byte{0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9,
|
||||
0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c},
|
||||
[]byte{0x70, 0x04, 0xc0, 0x1e, 0xe4, 0x9b, 0xd6, 0xef, 0xe0, 0x07, 0x35, 0x25, 0xaf, 0x9b, 0x16, 0x32,
|
||||
0xc5, 0xbe, 0x72, 0x6d, 0x12, 0x34, 0x9c, 0xc5, 0xbd, 0x47, 0x2b, 0xdc, 0x2d, 0xf6, 0x54, 0x0f},
|
||||
[]byte{0x31, 0x12, 0x59, 0x11, 0x94, 0xfd, 0xa6, 0x17, 0xe5, 0x68, 0xc6, 0x83, 0x10, 0x1e, 0xae, 0xcd,
|
||||
0x7e, 0xdd, 0xd6, 0xde, 0x1f, 0xbc, 0x07, 0x67, 0xae, 0x34, 0xda, 0x1a, 0x09, 0xa5, 0x4e, 0xab},
|
||||
[]byte{0xba, 0x9c, 0xd9, 0x55, 0x74, 0x1d, 0x1c, 0x16, 0x23, 0x23, 0xec, 0x82, 0x5e, 0x7c, 0x5c, 0xe8,
|
||||
0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a},
|
||||
}
|
||||
|
||||
if m.counter != index {
|
||||
t.Errorf("expected counter to be: %#x, got %#x", index, m.counter)
|
||||
}
|
||||
|
||||
for i, _ := range m.data {
|
||||
if !bytes.Equal(m.data[i], expected2[i]) {
|
||||
t.Errorf("expected part 1 to be: %x, got %x", m.data[i], expected1[i])
|
||||
}
|
||||
}
|
||||
|
||||
index = 0x1041506
|
||||
m.AdvanceTo(index)
|
||||
expected3 := [][]byte{
|
||||
[]byte{0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9,
|
||||
0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c},
|
||||
[]byte{0x55, 0x58, 0x8d, 0xf5, 0xb7, 0xa4, 0x88, 0x78, 0x42, 0x89, 0x27, 0x86, 0x81, 0x64, 0x58, 0x9f,
|
||||
0x36, 0x63, 0x44, 0x7b, 0x51, 0xed, 0xc3, 0x59, 0x5b, 0x03, 0x6c, 0xa6, 0x04, 0xc4, 0x6d, 0xcd},
|
||||
[]byte{0x5c, 0x54, 0x85, 0x0b, 0xfa, 0x98, 0xa1, 0xfd, 0x79, 0xa9, 0xdf, 0x1c, 0xbe, 0x8f, 0xc5, 0x68,
|
||||
0x19, 0x37, 0xd3, 0x0c, 0x85, 0xc8, 0xc3, 0x1f, 0x7b, 0xb8, 0x28, 0x81, 0x6c, 0xf9, 0xff, 0x3b},
|
||||
[]byte{0x95, 0x6c, 0xbf, 0x80, 0x7e, 0x65, 0x12, 0x6a, 0x49, 0x55, 0x8d, 0x45, 0xc8, 0x4a, 0x2e, 0x4c,
|
||||
0xd5, 0x6f, 0x03, 0xe2, 0x44, 0x16, 0xb9, 0x8e, 0x1c, 0xfd, 0x97, 0xc2, 0x06, 0xaa, 0x90, 0x7a},
|
||||
}
|
||||
|
||||
if m.counter != index {
|
||||
t.Errorf("expected counter to be: %#x, got %#x", index, m.counter)
|
||||
}
|
||||
|
||||
for i, _ := range m.data {
|
||||
if !bytes.Equal(m.data[i], expected3[i]) {
|
||||
t.Errorf("expected part 1 to be: %x, got %x", m.data[i], expected1[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decodeHEX(s string) []byte {
|
||||
decoded, err := hex.DecodeString(s)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return decoded
|
||||
}
|
||||
|
||||
func TestDeriveKeys(t *testing.T) {
|
||||
inputKeys := decodeHEX("3031323334353637383941424445463031323334353637383941424344454630313233343536373839414244454630313233343536373839414243444546303132333435363738394142444546303132333435363738394142434445463031323334353637383941424445463031323334353637383941424344454630313233")
|
||||
info := decodeHEX("4d45474f4c4d5f4b455953")
|
||||
expected := decodeHEX("bdf842da829272e777e3c6ba70bb2f6e976e32330691e4dc3ca3dbbd43c62c0f4075ce27caf907395c3287c1971d95c58909ce9f810c1ee79fef7eefa10c46dc3694fc16cbd3a8b16b66cbc897121c6b")
|
||||
|
||||
keys, err := deriveKeys(inputKeys, []byte{}, info)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
res := keys.aesKey
|
||||
res = append(res, keys.macKey...)
|
||||
res = append(res, keys.aesIV...)
|
||||
|
||||
if !bytes.Equal(expected, res) {
|
||||
t.Errorf("expected %x, got %x", expected, res)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/ed25519"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
ED25519_RANDOM_LENGTH = 32
|
||||
ED25519_PUBLIC_KEY_LENGTH = 32
|
||||
ED25519_PRIVATE_KEY_LENGTH = 64
|
||||
ED25519_SIGNATURE_LENGTH = 64
|
||||
GROUP_MESSAGE_INDEX_TAG = 010
|
||||
GROUP_CIPHERTEXT_TAG = 022
|
||||
MAC_LEN = 8
|
||||
)
|
||||
|
||||
type OutboundSession struct {
|
||||
Ratchet *Megolm
|
||||
SigningKey *ed25519.PrivateKey
|
||||
}
|
||||
|
||||
func NewOutboundSession(initialData []byte, initialCounter int) (*OutboundSession, error) {
|
||||
if len(initialData) < MEGOLM_RATCHET_LENGTH+ED25519_RANDOM_LENGTH {
|
||||
return nil, fmt.Errorf("initialData must be MEGOLM_RATCHET_LENGTH + ED25519_RANDOM_LENGTH = %d bytes. Got %d.", MEGOLM_RATCHET_LENGTH+ED25519_RANDOM_LENGTH, len(initialData))
|
||||
}
|
||||
|
||||
ratchet, err := NewMegolm(initialData[:MEGOLM_RATCHET_LENGTH], initialCounter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
seed := initialData[MEGOLM_RATCHET_LENGTH : MEGOLM_RATCHET_LENGTH+ED25519_RANDOM_LENGTH]
|
||||
privKey := ed25519.NewKeyFromSeed(seed)
|
||||
|
||||
return &OutboundSession{
|
||||
Ratchet: ratchet,
|
||||
SigningKey: &privKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OutboundSession) GenerateSessionKey() []byte {
|
||||
key := make([]byte, 1+4+MEGOLM_RATCHET_LENGTH+ED25519_PUBLIC_KEY_LENGTH+ED25519_SIGNATURE_LENGTH)
|
||||
|
||||
key[0] = SESSION_KEY_VERSION
|
||||
|
||||
counter := uint32(s.Ratchet.counter)
|
||||
for i := 0; i < 4; i++ {
|
||||
value := 0xFF & (counter >> 24)
|
||||
binary.BigEndian.PutUint32(key[i*4+1:], value)
|
||||
counter <<= 8
|
||||
}
|
||||
|
||||
copy(key[5:], s.Ratchet.Data())
|
||||
copy(key[5+MEGOLM_RATCHET_LENGTH:], s.SigningKey.Public().(ed25519.PublicKey))
|
||||
|
||||
sig := ed25519.Sign(*s.SigningKey, key[0:5+MEGOLM_RATCHET_LENGTH+ED25519_PUBLIC_KEY_LENGTH])
|
||||
copy(key[5+MEGOLM_RATCHET_LENGTH+ED25519_PUBLIC_KEY_LENGTH:], sig)
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
func (s *OutboundSession) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
derivedKeys, err := deriveKeys(s.Ratchet.Data(), []byte{}, MEGOLM_KDF_INFO)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ciphertextLen := len(plaintext) + aes.BlockSize - len(plaintext)%aes.BlockSize
|
||||
ciphertextEncodedLen := varStringLen(uint64(ciphertextLen))
|
||||
encodedLen :=
|
||||
uvarintLen(OLM_PROTOCOL_VERSION) +
|
||||
1 + // MESSAGE_INDEX_TAG
|
||||
uvarintLen(uint64(s.Ratchet.counter)) +
|
||||
1 + // MESSAGE_CIPHERTEXT_TAG
|
||||
ciphertextEncodedLen +
|
||||
uint64(MAC_LEN) +
|
||||
ED25519_SIGNATURE_LENGTH
|
||||
|
||||
encoded := make([]byte, encodedLen)
|
||||
pos := 0
|
||||
pos += binary.PutUvarint(encoded[pos:], OLM_PROTOCOL_VERSION)
|
||||
pos += binary.PutUvarint(encoded[pos:], MESSAGE_INDEX_TAG)
|
||||
pos += binary.PutUvarint(encoded[pos:], uint64(s.Ratchet.counter))
|
||||
pos += binary.PutUvarint(encoded[pos:], MESSAGE_CIPHERTEXT_TAG)
|
||||
pos += binary.PutUvarint(encoded[pos:], uint64(ciphertextLen))
|
||||
|
||||
block, err := aes.NewCipher(derivedKeys.aesKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paddedPlaintext, err := pkcs7Pad(plaintext, aes.BlockSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ciphertext := make([]byte, len(plaintext)+aes.BlockSize-len(plaintext)%aes.BlockSize)
|
||||
|
||||
mode := cipher.NewCBCEncrypter(block, derivedKeys.aesIV)
|
||||
mode.CryptBlocks(ciphertext, paddedPlaintext)
|
||||
|
||||
copy(encoded[pos:], ciphertext)
|
||||
pos += len(ciphertext)
|
||||
|
||||
mac := HMACSHA256(derivedKeys.macKey, encoded[0:pos])
|
||||
copy(encoded[pos:pos+MAC_LEN], mac[0:MAC_LEN])
|
||||
pos += MAC_LEN
|
||||
|
||||
sig := ed25519.Sign(*s.SigningKey, encoded[0:pos])
|
||||
copy(encoded[pos:], sig)
|
||||
|
||||
s.Ratchet.Advance()
|
||||
|
||||
return encoded, nil
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOutboundSession(t *testing.T) {
|
||||
randomData := []byte(
|
||||
"0123456789ABDEF0123456789ABCDEF" +
|
||||
"0123456789ABDEF0123456789ABCDEF" +
|
||||
"0123456789ABDEF0123456789ABCDEF" +
|
||||
"0123456789ABDEF0123456789ABCDEF" +
|
||||
"0123456789ABDEF0123456789ABCDEF" +
|
||||
"0123456789ABDEF0123456789ABCDEF")
|
||||
|
||||
o, err := NewOutboundSession(randomData, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sessionKey := o.GenerateSessionKey()
|
||||
|
||||
expectedSessionKey := decodeHEX("020000000030313233343536373839414244454630313233343536373839414243444546303132333435363738394142444546303132333435363738394142434445463031323334353637383941424445463031323334353637383941424344454630313233343536373839414244454630313233343536373839414243444546303132330d1b760d410eae0fc7fb25068c34eaaf27fc1f5605fc1663234e07c0e55265c61b0ac599a49bf5b47f798f3180446bc663f4ee11660fdbaa036a7c3cc1dd9d5e72b2682e7c4ea82d9e7ef3f99640e5b75554e70c8967c1df281d4611ba4f1309")
|
||||
if !bytes.Equal(sessionKey, expectedSessionKey) {
|
||||
t.Fatalf("expected %x, got %x", expectedSessionKey, sessionKey)
|
||||
}
|
||||
|
||||
plaintext := "Message"
|
||||
encrypted, err := o.Encrypt([]byte(plaintext))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedEncrytped := decodeHEX("03080012101c6e1e94a5b072a32671b9f43e87607ef171cc8e147af46f05e3eaa331a1659c85eeb09657debf0b9d13911bd4f70d715d6d5baf216acbc917e6e22f1045b7bc54fc064451b9ec1d42b66c5a9ee3fc6c7b679d8aa5c0fb03")
|
||||
if !bytes.Equal(encrypted, expectedEncrytped) {
|
||||
t.Fatalf("expected %x, got %x", expectedEncrytped, encrypted)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,114 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
func deriveKeys(input []byte, salt []byte, info []byte) (*DerivedKeys, error) {
|
||||
h := hkdf.New(sha256.New, input, salt, info)
|
||||
|
||||
keys := &DerivedKeys{
|
||||
aesKey: make([]byte, AES256_KEY_LENGTH),
|
||||
macKey: make([]byte, HMAC_KEY_LENGTH),
|
||||
aesIV: make([]byte, AES256_IV_LENGTH),
|
||||
}
|
||||
|
||||
if _, err := h.Read(keys.aesKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := h.Read(keys.macKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := h.Read(keys.aesIV); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func pkcs7Pad(data []byte, blocksize int) ([]byte, error) {
|
||||
if blocksize <= 0 {
|
||||
return nil, errors.New("ErrInvalidBlockSize")
|
||||
}
|
||||
|
||||
if data == nil || len(data) == 0 {
|
||||
return nil, errors.New("ErrInvalidPKCS7Data")
|
||||
}
|
||||
|
||||
padLen := blocksize - len(data)%blocksize
|
||||
padding := bytes.Repeat([]byte{byte(padLen)}, padLen)
|
||||
|
||||
return append(data, padding...), nil
|
||||
}
|
||||
|
||||
func pkcs7Unpad(data []byte, blocksize int) ([]byte, error) {
|
||||
if blocksize <= 0 {
|
||||
return nil, errors.New("ErrInvalidBlockSize")
|
||||
}
|
||||
|
||||
if data == nil || len(data) == 0 {
|
||||
return nil, errors.New("ErrInvalidPKCS7Data")
|
||||
}
|
||||
|
||||
if len(data)%blocksize != 0 {
|
||||
return nil, errors.New("ErrInvalidPKCS7Padding")
|
||||
}
|
||||
|
||||
c := data[len(data)-1]
|
||||
n := int(c)
|
||||
|
||||
if n == 0 || n > len(data) {
|
||||
return nil, errors.New("ErrInvalidPKCS7Padding")
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
if data[len(data)-n+i] != c {
|
||||
return nil, errors.New("ErrInvalidPKCS7Padding")
|
||||
}
|
||||
}
|
||||
|
||||
return data[:len(data)-n], nil
|
||||
}
|
||||
|
||||
func generateHMACKey(inputKey []byte) []byte {
|
||||
hmacKey := make([]byte, SHA256_BLOCK_LENGTH)
|
||||
if len(inputKey) > SHA256_BLOCK_LENGTH {
|
||||
// TODO: check this part
|
||||
h := sha256.New()
|
||||
h.Write(inputKey)
|
||||
res := h.Sum(nil)
|
||||
copy(hmacKey, res)
|
||||
} else {
|
||||
copy(hmacKey, inputKey)
|
||||
}
|
||||
|
||||
return hmacKey
|
||||
}
|
||||
|
||||
func HMACSHA256(key []byte, input []byte) []byte {
|
||||
hmacKey := generateHMACKey(key)
|
||||
ctx := hmac.New(sha256.New, hmacKey)
|
||||
ctx.Write(input)
|
||||
return ctx.Sum(nil)
|
||||
}
|
||||
|
||||
func uvarintLen(x uint64) uint64 {
|
||||
var res uint64 = 1
|
||||
for x >= 0x80 {
|
||||
res++
|
||||
x >>= 7
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func varStringLen(n uint64) uint64 {
|
||||
return uvarintLen(n) + n
|
||||
}
|
Loading…
Reference in New Issue