import initial implementation

This commit is contained in:
Andrea Franz 2021-06-30 11:56:52 +02:00
commit 6367795862
No known key found for this signature in database
GPG Key ID: 4F0D2F2D9DE7F29D
9 changed files with 744 additions and 0 deletions

5
go.mod Normal file
View File

@ -0,0 +1,5 @@
module github.com/status-im/megolm
go 1.16
require golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a // indirect

7
go.sum Normal file
View File

@ -0,0 +1,7 @@
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a h1:kr2P4QFmQr29mSLA43kwrOcgcReGTfbE9N577tCTuBc=
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

186
inbound_session.go Normal file
View File

@ -0,0 +1,186 @@
package main
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/ed25519"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
)
const (
AES256_KEY_LENGTH = 32
AES256_IV_LENGTH = 16
HMAC_KEY_LENGTH = 32
HKDF_DEFAULT_SALT_LEN = 32
SESSION_SHARING_LENGTH = 229
SESSION_KEY_VERSION = 2
MESSAGE_INDEX_TAG = 0x08
MESSAGE_CIPHERTEXT_TAG = 0x12
)
var MEGOLM_KDF_INFO = []byte("MEGOLM_KEYS")
type Message struct {
Version byte
Payload []byte
Index int
Plaintext []byte
MAC []byte
Signature []byte
}
type DerivedKeys struct {
aesKey []byte
macKey []byte
aesIV []byte
}
type InboundSession struct {
Version byte
SigningPubKey []byte
InitialRatchet *Megolm
LatestRatchet *Megolm
}
func NewInboundSession(encodedSession string) (*InboundSession, error) {
session, err := base64.StdEncoding.DecodeString(encodedSession)
if err != nil {
return nil, err
}
if len(session) != SESSION_SHARING_LENGTH {
return nil, errors.New("bad session length")
}
version := session[0]
index := int(binary.BigEndian.Uint32(session[1:5]))
ratchets := session[5:133]
signingPubKey := session[133:165]
signature := session[165:229]
initialRatchet, err := NewMegolm(ratchets, index)
if err != nil {
return nil, err
}
latestRatchet, err := NewMegolm(ratchets, index)
if err != nil {
return nil, err
}
if !ed25519.Verify(signingPubKey, session[0:165], signature) {
return nil, errors.New("bad signature")
}
return &InboundSession{
Version: version,
SigningPubKey: signingPubKey,
InitialRatchet: initialRatchet,
LatestRatchet: latestRatchet,
}, nil
}
func (s *InboundSession) Decrypt(rawMessage string) (*Message, error) {
decoded, err := base64.StdEncoding.DecodeString(rawMessage)
if err != nil {
return nil, err
}
// 1 (version) + n (payload) + 8 (MAC) + 72 (signature)
// TODO: it should be min 81 + min message len
if len(decoded) < 81 {
return nil, errors.New("message too short")
}
// TODO: check max length
msgLen := len(decoded)
version := decoded[0]
payload := decoded[1 : msgLen-72]
mac := decoded[msgLen-72 : msgLen]
signature := decoded[msgLen-64 : msgLen]
index, ciphertext, err := decodeMessage(payload)
if err != nil {
return nil, err
}
message := decoded[0 : msgLen-64]
if !ed25519.Verify(s.SigningPubKey, message, signature) {
return nil, errors.New("bad signature")
}
// TODO: advance ratchet to index
derivedKeys, err := deriveKeys(s.LatestRatchet.Data(), []byte{}, MEGOLM_KDF_INFO)
if err != nil {
return nil, err
}
// TODO: check mac
block, err := aes.NewCipher(derivedKeys.aesKey)
if err != nil {
return nil, err
}
mode := cipher.NewCBCDecrypter(block, derivedKeys.aesIV)
mode.CryptBlocks(ciphertext, ciphertext)
ciphertext, _ = pkcs7Unpad(ciphertext, aes.BlockSize)
return &Message{
Version: version,
Payload: payload,
Index: index,
Plaintext: ciphertext,
MAC: mac,
Signature: signature,
}, nil
}
//TODO: check megolm encoding implementation
func decodeMessage(payload []byte) (int, []byte, error) {
buf := bytes.NewBuffer(payload)
tag, err := binary.ReadUvarint(buf)
if err != nil {
return 0, nil, err
}
if tag != MESSAGE_INDEX_TAG {
return 0, nil, fmt.Errorf("expected tag %x, got %x", MESSAGE_INDEX_TAG, tag)
}
index, err := binary.ReadUvarint(buf)
if err != nil {
return 0, nil, err
}
tag, err = binary.ReadUvarint(buf)
if err != nil {
return 0, nil, err
}
if tag != MESSAGE_CIPHERTEXT_TAG {
return 0, nil, fmt.Errorf("expected tag %x, got %x", MESSAGE_CIPHERTEXT_TAG, tag)
}
length, err := binary.ReadUvarint(buf)
if err != nil {
return 0, nil, err
}
ciphertext := make([]byte, length)
n, err := buf.Read(ciphertext)
if err != nil {
return 0, nil, err
}
if n != int(length) {
return 0, nil, fmt.Errorf("expected cipertext length %d, got %d", length, n)
}
return int(index), ciphertext, nil
}

33
inbound_session_test.go Normal file
View File

@ -0,0 +1,33 @@
package main
import (
"bytes"
"log"
"testing"
)
func TestInboundSession(t *testing.T) {
sessionKey := "AgAAAAAwMTIzNDU2Nzg5QUJERUYwMTIzNDU2Nzg5QUJDREVGMDEyMzQ1Njc4OUFCREVGM" +
"DEyMzQ1Njc4OUFCQ0RFRjAxMjM0NTY3ODlBQkRFRjAxMjM0NTY3ODlBQkNERUYwMTIzND" +
"U2Nzg5QUJERUYwMTIzNDU2Nzg5QUJDREVGMDEyMztqJ7zOtqQtYqOo0CpvDXNlMhV3HeJ" +
"DpjrASKGLWdop4lx1cSN3Xv1TgfLPW8rhGiW+hHiMxd36nRuxscNv9k4oJA/KP+o0mi1w" +
"v44StrEJ1wwx9WZHBUIWkQbaBSuBDw=="
message := "AwgAEhAcbh6UpbByoyZxufQ+h2B+8XHMjhR69G8nP4pNZGl/3QMgrzCZPmP+F2aPLyKPz" +
"xRPBMUkeXRJ6Iqm5NeOdx2eERgTW7P20CM+lL3Xpk+ZUOOPvsSQNaAL"
s, err := NewInboundSession(sessionKey)
if err != nil {
log.Fatal(err)
}
m, err := s.Decrypt(message)
if err != nil {
log.Fatal(err)
}
expectedMessage := []byte("Message")
if !bytes.Equal(m.Plaintext, expectedMessage) {
t.Fatalf("expected message to be `%s`, got `%s`", expectedMessage, m.Plaintext)
}
}

125
megolm.go Normal file
View File

@ -0,0 +1,125 @@
package main
import (
"fmt"
)
const (
SHA256_BLOCK_LENGTH = 64
SHA256_OUTPUT_LENGTH = 32
MEGOLM_RATCHET_PART_LENGTH = 32
MEGOLM_RATCHET_PARTS = 4
MEGOLM_RATCHET_LENGTH = (MEGOLM_RATCHET_PARTS * MEGOLM_RATCHET_PART_LENGTH)
OLM_PROTOCOL_VERSION = 3
)
var HASH_KEY_SEEDS = [MEGOLM_RATCHET_PARTS]byte{0x00, 0x01, 0x02, 0x03}
type Megolm struct {
data [][]byte
counter int
}
func NewMegolm(initialData []byte, initialCounter int) (*Megolm, error) {
if len(initialData) != MEGOLM_RATCHET_LENGTH {
return nil, fmt.Errorf("megolm initial data must be %d bytes. Got %d.", MEGOLM_RATCHET_LENGTH, len(initialData))
}
data := make([][]byte, MEGOLM_RATCHET_PARTS)
for i := 0; i < MEGOLM_RATCHET_PARTS; i++ {
data[i] = make([]byte, MEGOLM_RATCHET_PART_LENGTH)
start := i * MEGOLM_RATCHET_PART_LENGTH
copy(data[i], initialData[start:start+MEGOLM_RATCHET_PART_LENGTH])
}
return &Megolm{
data: data,
counter: initialCounter,
}, nil
}
func (m *Megolm) Data() []byte {
data := make([]byte, MEGOLM_RATCHET_LENGTH)
for i := 0; i < MEGOLM_RATCHET_PARTS; i++ {
start := i * MEGOLM_RATCHET_PART_LENGTH
copy(data[start:start+MEGOLM_RATCHET_PART_LENGTH], m.data[i])
}
return data
}
func (m *Megolm) Advance() {
mask := 0x00FFFFFF
h := 0
m.counter++
/* figure out how much we need to rekey */
for h < MEGOLM_RATCHET_PARTS {
if m.counter&mask == 0 {
break
}
h++
mask >>= 8
}
// update R[h:3] based on h
for i := MEGOLM_RATCHET_PARTS - 1; i >= h; i-- {
m.rehashPart(h, i)
}
}
func (m *Megolm) AdvanceTo(advanceTo int) {
/* starting with R0, see if we need to update each part of the hash */
for j := 0; j < MEGOLM_RATCHET_PARTS; j++ {
shift := (MEGOLM_RATCHET_PARTS - j - 1) * 8
mask := 0xffffffff << shift
/* how many times do we need to rehash this part?
*
* '& 0xff' ensures we handle integer wraparound correctly
*/
steps :=
((advanceTo >> shift) - (m.counter >> shift)) & 0xff
if steps == 0 {
/* deal with the edge case where megolm->counter is slightly larger
* than advanceTo. This should only happen for R(0), and implies
* that advanceTo has wrapped around and we need to advance R(0)
* 256 times.
*/
if advanceTo < m.counter {
steps = 0x100
} else {
continue
}
}
/* for all but the last step, we can just bump R(j) without regard
* to R(j+1)...R(3).
*/
for steps > 1 {
m.rehashPart(j, j)
steps--
}
/* on the last step we also need to bump R(j+1)...R(3).
*
* (Theoretically, we could skip bumping R(j+2) if we're going to bump
* R(j+1) again, but the code to figure that out is a bit baroque and
* doesn't save us much).
*/
for k := 3; k >= j; k-- {
m.rehashPart(j, k)
}
m.counter = advanceTo & mask
}
}
func (m *Megolm) rehashPart(fromPart, toPart int) {
newPart := HMACSHA256(
m.data[fromPart],
[]byte{HASH_KEY_SEEDS[toPart]},
)
m.data[toPart] = newPart
}

116
megolm_test.go Normal file
View File

@ -0,0 +1,116 @@
package main
import (
"bytes"
"encoding/hex"
"log"
"testing"
)
func TestAdvance(t *testing.T) {
randomData := []byte("0123456789ABCDEF0123456789ABCDEF" +
"0123456789ABCDEF0123456789ABCDEF" +
"0123456789ABCDEF0123456789ABCDEF" +
"0123456789ABCDEF0123456789ABCDEF")
m, err := NewMegolm(randomData, 0)
if err != nil {
t.Fatal(err)
}
m.Advance()
index := 1
expected1 := [][]byte{
[]byte{0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46,
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46},
[]byte{0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46,
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46},
[]byte{0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46,
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46},
[]byte{0xba, 0x9c, 0xd9, 0x55, 0x74, 0x1d, 0x1c, 0x16, 0x23, 0x23, 0xec, 0x82, 0x5e, 0x7c, 0x5c, 0xe8,
0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a},
}
if m.counter != index {
t.Errorf("expected counter to be: %#x, got %#x", index, m.counter)
}
for i, _ := range m.data {
if !bytes.Equal(m.data[i], expected1[i]) {
t.Errorf("expected part 1 to be: %x, got %x", m.data[i], expected1[i])
}
}
index = 0x1000000
m.AdvanceTo(index)
expected2 := [][]byte{
[]byte{0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9,
0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c},
[]byte{0x70, 0x04, 0xc0, 0x1e, 0xe4, 0x9b, 0xd6, 0xef, 0xe0, 0x07, 0x35, 0x25, 0xaf, 0x9b, 0x16, 0x32,
0xc5, 0xbe, 0x72, 0x6d, 0x12, 0x34, 0x9c, 0xc5, 0xbd, 0x47, 0x2b, 0xdc, 0x2d, 0xf6, 0x54, 0x0f},
[]byte{0x31, 0x12, 0x59, 0x11, 0x94, 0xfd, 0xa6, 0x17, 0xe5, 0x68, 0xc6, 0x83, 0x10, 0x1e, 0xae, 0xcd,
0x7e, 0xdd, 0xd6, 0xde, 0x1f, 0xbc, 0x07, 0x67, 0xae, 0x34, 0xda, 0x1a, 0x09, 0xa5, 0x4e, 0xab},
[]byte{0xba, 0x9c, 0xd9, 0x55, 0x74, 0x1d, 0x1c, 0x16, 0x23, 0x23, 0xec, 0x82, 0x5e, 0x7c, 0x5c, 0xe8,
0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a},
}
if m.counter != index {
t.Errorf("expected counter to be: %#x, got %#x", index, m.counter)
}
for i, _ := range m.data {
if !bytes.Equal(m.data[i], expected2[i]) {
t.Errorf("expected part 1 to be: %x, got %x", m.data[i], expected1[i])
}
}
index = 0x1041506
m.AdvanceTo(index)
expected3 := [][]byte{
[]byte{0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9,
0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c},
[]byte{0x55, 0x58, 0x8d, 0xf5, 0xb7, 0xa4, 0x88, 0x78, 0x42, 0x89, 0x27, 0x86, 0x81, 0x64, 0x58, 0x9f,
0x36, 0x63, 0x44, 0x7b, 0x51, 0xed, 0xc3, 0x59, 0x5b, 0x03, 0x6c, 0xa6, 0x04, 0xc4, 0x6d, 0xcd},
[]byte{0x5c, 0x54, 0x85, 0x0b, 0xfa, 0x98, 0xa1, 0xfd, 0x79, 0xa9, 0xdf, 0x1c, 0xbe, 0x8f, 0xc5, 0x68,
0x19, 0x37, 0xd3, 0x0c, 0x85, 0xc8, 0xc3, 0x1f, 0x7b, 0xb8, 0x28, 0x81, 0x6c, 0xf9, 0xff, 0x3b},
[]byte{0x95, 0x6c, 0xbf, 0x80, 0x7e, 0x65, 0x12, 0x6a, 0x49, 0x55, 0x8d, 0x45, 0xc8, 0x4a, 0x2e, 0x4c,
0xd5, 0x6f, 0x03, 0xe2, 0x44, 0x16, 0xb9, 0x8e, 0x1c, 0xfd, 0x97, 0xc2, 0x06, 0xaa, 0x90, 0x7a},
}
if m.counter != index {
t.Errorf("expected counter to be: %#x, got %#x", index, m.counter)
}
for i, _ := range m.data {
if !bytes.Equal(m.data[i], expected3[i]) {
t.Errorf("expected part 1 to be: %x, got %x", m.data[i], expected1[i])
}
}
}
func decodeHEX(s string) []byte {
decoded, err := hex.DecodeString(s)
if err != nil {
log.Fatal(err)
}
return decoded
}
func TestDeriveKeys(t *testing.T) {
inputKeys := decodeHEX("3031323334353637383941424445463031323334353637383941424344454630313233343536373839414244454630313233343536373839414243444546303132333435363738394142444546303132333435363738394142434445463031323334353637383941424445463031323334353637383941424344454630313233")
info := decodeHEX("4d45474f4c4d5f4b455953")
expected := decodeHEX("bdf842da829272e777e3c6ba70bb2f6e976e32330691e4dc3ca3dbbd43c62c0f4075ce27caf907395c3287c1971d95c58909ce9f810c1ee79fef7eefa10c46dc3694fc16cbd3a8b16b66cbc897121c6b")
keys, err := deriveKeys(inputKeys, []byte{}, info)
if err != nil {
t.Fatal(err)
}
res := keys.aesKey
res = append(res, keys.macKey...)
res = append(res, keys.aesIV...)
if !bytes.Equal(expected, res) {
t.Errorf("expected %x, got %x", expected, res)
}
}

119
outbound_session.go Normal file
View File

@ -0,0 +1,119 @@
package main
import (
"crypto/aes"
"crypto/cipher"
"crypto/ed25519"
"encoding/binary"
"fmt"
)
const (
ED25519_RANDOM_LENGTH = 32
ED25519_PUBLIC_KEY_LENGTH = 32
ED25519_PRIVATE_KEY_LENGTH = 64
ED25519_SIGNATURE_LENGTH = 64
GROUP_MESSAGE_INDEX_TAG = 010
GROUP_CIPHERTEXT_TAG = 022
MAC_LEN = 8
)
type OutboundSession struct {
Ratchet *Megolm
SigningKey *ed25519.PrivateKey
}
func NewOutboundSession(initialData []byte, initialCounter int) (*OutboundSession, error) {
if len(initialData) < MEGOLM_RATCHET_LENGTH+ED25519_RANDOM_LENGTH {
return nil, fmt.Errorf("initialData must be MEGOLM_RATCHET_LENGTH + ED25519_RANDOM_LENGTH = %d bytes. Got %d.", MEGOLM_RATCHET_LENGTH+ED25519_RANDOM_LENGTH, len(initialData))
}
ratchet, err := NewMegolm(initialData[:MEGOLM_RATCHET_LENGTH], initialCounter)
if err != nil {
return nil, err
}
seed := initialData[MEGOLM_RATCHET_LENGTH : MEGOLM_RATCHET_LENGTH+ED25519_RANDOM_LENGTH]
privKey := ed25519.NewKeyFromSeed(seed)
return &OutboundSession{
Ratchet: ratchet,
SigningKey: &privKey,
}, nil
}
func (s *OutboundSession) GenerateSessionKey() []byte {
key := make([]byte, 1+4+MEGOLM_RATCHET_LENGTH+ED25519_PUBLIC_KEY_LENGTH+ED25519_SIGNATURE_LENGTH)
key[0] = SESSION_KEY_VERSION
counter := uint32(s.Ratchet.counter)
for i := 0; i < 4; i++ {
value := 0xFF & (counter >> 24)
binary.BigEndian.PutUint32(key[i*4+1:], value)
counter <<= 8
}
copy(key[5:], s.Ratchet.Data())
copy(key[5+MEGOLM_RATCHET_LENGTH:], s.SigningKey.Public().(ed25519.PublicKey))
sig := ed25519.Sign(*s.SigningKey, key[0:5+MEGOLM_RATCHET_LENGTH+ED25519_PUBLIC_KEY_LENGTH])
copy(key[5+MEGOLM_RATCHET_LENGTH+ED25519_PUBLIC_KEY_LENGTH:], sig)
return key
}
func (s *OutboundSession) Encrypt(plaintext []byte) ([]byte, error) {
derivedKeys, err := deriveKeys(s.Ratchet.Data(), []byte{}, MEGOLM_KDF_INFO)
if err != nil {
return nil, err
}
ciphertextLen := len(plaintext) + aes.BlockSize - len(plaintext)%aes.BlockSize
ciphertextEncodedLen := varStringLen(uint64(ciphertextLen))
encodedLen :=
uvarintLen(OLM_PROTOCOL_VERSION) +
1 + // MESSAGE_INDEX_TAG
uvarintLen(uint64(s.Ratchet.counter)) +
1 + // MESSAGE_CIPHERTEXT_TAG
ciphertextEncodedLen +
uint64(MAC_LEN) +
ED25519_SIGNATURE_LENGTH
encoded := make([]byte, encodedLen)
pos := 0
pos += binary.PutUvarint(encoded[pos:], OLM_PROTOCOL_VERSION)
pos += binary.PutUvarint(encoded[pos:], MESSAGE_INDEX_TAG)
pos += binary.PutUvarint(encoded[pos:], uint64(s.Ratchet.counter))
pos += binary.PutUvarint(encoded[pos:], MESSAGE_CIPHERTEXT_TAG)
pos += binary.PutUvarint(encoded[pos:], uint64(ciphertextLen))
block, err := aes.NewCipher(derivedKeys.aesKey)
if err != nil {
return nil, err
}
paddedPlaintext, err := pkcs7Pad(plaintext, aes.BlockSize)
if err != nil {
return nil, err
}
ciphertext := make([]byte, len(plaintext)+aes.BlockSize-len(plaintext)%aes.BlockSize)
mode := cipher.NewCBCEncrypter(block, derivedKeys.aesIV)
mode.CryptBlocks(ciphertext, paddedPlaintext)
copy(encoded[pos:], ciphertext)
pos += len(ciphertext)
mac := HMACSHA256(derivedKeys.macKey, encoded[0:pos])
copy(encoded[pos:pos+MAC_LEN], mac[0:MAC_LEN])
pos += MAC_LEN
sig := ed25519.Sign(*s.SigningKey, encoded[0:pos])
copy(encoded[pos:], sig)
s.Ratchet.Advance()
return encoded, nil
}

39
outbound_session_test.go Normal file
View File

@ -0,0 +1,39 @@
package main
import (
"bytes"
"testing"
)
func TestOutboundSession(t *testing.T) {
randomData := []byte(
"0123456789ABDEF0123456789ABCDEF" +
"0123456789ABDEF0123456789ABCDEF" +
"0123456789ABDEF0123456789ABCDEF" +
"0123456789ABDEF0123456789ABCDEF" +
"0123456789ABDEF0123456789ABCDEF" +
"0123456789ABDEF0123456789ABCDEF")
o, err := NewOutboundSession(randomData, 0)
if err != nil {
t.Fatal(err)
}
sessionKey := o.GenerateSessionKey()
expectedSessionKey := decodeHEX("020000000030313233343536373839414244454630313233343536373839414243444546303132333435363738394142444546303132333435363738394142434445463031323334353637383941424445463031323334353637383941424344454630313233343536373839414244454630313233343536373839414243444546303132330d1b760d410eae0fc7fb25068c34eaaf27fc1f5605fc1663234e07c0e55265c61b0ac599a49bf5b47f798f3180446bc663f4ee11660fdbaa036a7c3cc1dd9d5e72b2682e7c4ea82d9e7ef3f99640e5b75554e70c8967c1df281d4611ba4f1309")
if !bytes.Equal(sessionKey, expectedSessionKey) {
t.Fatalf("expected %x, got %x", expectedSessionKey, sessionKey)
}
plaintext := "Message"
encrypted, err := o.Encrypt([]byte(plaintext))
if err != nil {
t.Fatal(err)
}
expectedEncrytped := decodeHEX("03080012101c6e1e94a5b072a32671b9f43e87607ef171cc8e147af46f05e3eaa331a1659c85eeb09657debf0b9d13911bd4f70d715d6d5baf216acbc917e6e22f1045b7bc54fc064451b9ec1d42b66c5a9ee3fc6c7b679d8aa5c0fb03")
if !bytes.Equal(encrypted, expectedEncrytped) {
t.Fatalf("expected %x, got %x", expectedEncrytped, encrypted)
}
}

114
utils.go Normal file
View File

@ -0,0 +1,114 @@
package main
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"errors"
"golang.org/x/crypto/hkdf"
)
func deriveKeys(input []byte, salt []byte, info []byte) (*DerivedKeys, error) {
h := hkdf.New(sha256.New, input, salt, info)
keys := &DerivedKeys{
aesKey: make([]byte, AES256_KEY_LENGTH),
macKey: make([]byte, HMAC_KEY_LENGTH),
aesIV: make([]byte, AES256_IV_LENGTH),
}
if _, err := h.Read(keys.aesKey); err != nil {
return nil, err
}
if _, err := h.Read(keys.macKey); err != nil {
return nil, err
}
if _, err := h.Read(keys.aesIV); err != nil {
return nil, err
}
return keys, nil
}
func pkcs7Pad(data []byte, blocksize int) ([]byte, error) {
if blocksize <= 0 {
return nil, errors.New("ErrInvalidBlockSize")
}
if data == nil || len(data) == 0 {
return nil, errors.New("ErrInvalidPKCS7Data")
}
padLen := blocksize - len(data)%blocksize
padding := bytes.Repeat([]byte{byte(padLen)}, padLen)
return append(data, padding...), nil
}
func pkcs7Unpad(data []byte, blocksize int) ([]byte, error) {
if blocksize <= 0 {
return nil, errors.New("ErrInvalidBlockSize")
}
if data == nil || len(data) == 0 {
return nil, errors.New("ErrInvalidPKCS7Data")
}
if len(data)%blocksize != 0 {
return nil, errors.New("ErrInvalidPKCS7Padding")
}
c := data[len(data)-1]
n := int(c)
if n == 0 || n > len(data) {
return nil, errors.New("ErrInvalidPKCS7Padding")
}
for i := 0; i < n; i++ {
if data[len(data)-n+i] != c {
return nil, errors.New("ErrInvalidPKCS7Padding")
}
}
return data[:len(data)-n], nil
}
func generateHMACKey(inputKey []byte) []byte {
hmacKey := make([]byte, SHA256_BLOCK_LENGTH)
if len(inputKey) > SHA256_BLOCK_LENGTH {
// TODO: check this part
h := sha256.New()
h.Write(inputKey)
res := h.Sum(nil)
copy(hmacKey, res)
} else {
copy(hmacKey, inputKey)
}
return hmacKey
}
func HMACSHA256(key []byte, input []byte) []byte {
hmacKey := generateHMACKey(key)
ctx := hmac.New(sha256.New, hmacKey)
ctx.Write(input)
return ctx.Sum(nil)
}
func uvarintLen(x uint64) uint64 {
var res uint64 = 1
for x >= 0x80 {
res++
x >>= 7
}
return res
}
func varStringLen(n uint64) uint64 {
return uvarintLen(n) + n
}