From 962299d8b4390932e56f6bab94bba87381b9fe60 Mon Sep 17 00:00:00 2001 From: Richard Ramos Date: Wed, 21 Dec 2022 11:56:30 -0400 Subject: [PATCH] fix: payload serialization --- noise_test.go | 16 ++++++++++++++-- patterns.go | 9 ++++++++- payload.go | 21 ++++++++++++--------- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/noise_test.go b/noise_test.go index a59b08d..adc60ea 100644 --- a/noise_test.go +++ b/noise_test.go @@ -126,7 +126,13 @@ func handshakeTest(t *testing.T, hsAlice *Handshake, hsBob *Handshake) { encryptedPayload, err := aliceHSResult.WriteMessage(message, defaultMessageNametagBuffer) require.NoError(t, err) - plaintext, err := bobHSResult.ReadMessage(encryptedPayload, defaultMessageNametagBuffer) + serializedPayload, err := encryptedPayload.Serialize() + require.NoError(t, err) + + deserializedPayload, err := DeserializePayloadV2(serializedPayload) + require.NoError(t, err) + + plaintext, err := bobHSResult.ReadMessage(deserializedPayload, defaultMessageNametagBuffer) require.NoError(t, err) require.Equal(t, message, plaintext) @@ -137,7 +143,13 @@ func handshakeTest(t *testing.T, hsAlice *Handshake, hsBob *Handshake) { encryptedPayload, err = bobHSResult.WriteMessage(message, defaultMessageNametagBuffer) require.NoError(t, err) - plaintext, err = aliceHSResult.ReadMessage(encryptedPayload, defaultMessageNametagBuffer) + serializedPayload, err = encryptedPayload.Serialize() + require.NoError(t, err) + + deserializedPayload, err = DeserializePayloadV2(serializedPayload) + require.NoError(t, err) + + plaintext, err = aliceHSResult.ReadMessage(deserializedPayload, defaultMessageNametagBuffer) require.NoError(t, err) require.Equal(t, message, plaintext) diff --git a/patterns.go b/patterns.go index a05e11f..40b3121 100644 --- a/patterns.go +++ b/patterns.go @@ -104,15 +104,17 @@ type HandshakePattern struct { messagePatterns []MessagePattern hashFn func() hash.Hash cipherFn func([]byte) (cipher.AEAD, error) + tagSize int dhKey DHKey } -func NewHandshakePattern(protocolID byte, name string, hashFn func() hash.Hash, cipherFn func([]byte) (cipher.AEAD, error), dhKey DHKey, preMessagePatterns []PreMessagePattern, messagePatterns []MessagePattern) HandshakePattern { +func NewHandshakePattern(protocolID byte, name string, hashFn func() hash.Hash, cipherFn func([]byte) (cipher.AEAD, error), tagSize int, dhKey DHKey, preMessagePatterns []PreMessagePattern, messagePatterns []MessagePattern) HandshakePattern { return HandshakePattern{ protocolID: protocolID, name: name, hashFn: hashFn, cipherFn: cipherFn, + tagSize: tagSize, dhKey: dhKey, premessagePatterns: preMessagePatterns, messagePatterns: messagePatterns, @@ -153,6 +155,7 @@ var K1K1 = NewHandshakePattern( "Noise_K1K1_25519_ChaChaPoly_SHA256", sha256.New, chacha20poly1305.New, + 16, DH25519, []PreMessagePattern{ NewPreMessagePattern(Right, []NoiseTokens{S}), @@ -170,6 +173,7 @@ var XK1 = NewHandshakePattern( "Noise_XK1_25519_ChaChaPoly_SHA256", sha256.New, chacha20poly1305.New, + 16, DH25519, []PreMessagePattern{ NewPreMessagePattern(Left, []NoiseTokens{S}), @@ -186,6 +190,7 @@ var XX = NewHandshakePattern( "Noise_XX_25519_ChaChaPoly_SHA256", sha256.New, chacha20poly1305.New, + 16, DH25519, EmptyPreMessage, []MessagePattern{ @@ -200,6 +205,7 @@ var XXpsk0 = NewHandshakePattern( "Noise_XXpsk0_25519_ChaChaPoly_SHA256", sha256.New, chacha20poly1305.New, + 16, DH25519, EmptyPreMessage, []MessagePattern{ @@ -214,6 +220,7 @@ var WakuPairing = NewHandshakePattern( "Noise_WakuPairing_25519_ChaChaPoly_SHA256", sha256.New, chacha20poly1305.New, + 16, DH25519, []PreMessagePattern{ NewPreMessagePattern(Left, []NoiseTokens{E}), diff --git a/payload.go b/payload.go index 096b4e5..5a2caf6 100644 --- a/payload.go +++ b/payload.go @@ -2,7 +2,6 @@ package noise import ( "bytes" - "crypto/ed25519" "encoding/binary" "errors" ) @@ -106,8 +105,6 @@ func (p *PayloadV2) Serialize() ([]byte, error) { return payloadBuf.Bytes(), nil } -const ChaChaPolyTagSize = byte(16) - // Deserializes a byte sequence to a PayloadV2 object according to https://rfc.vac.dev/spec/35/. // The input serialized payload concatenates the output PayloadV2 object fields as // payload = ( protocolId || serializedHandshakeMessageLen || serializedHandshakeMessage || transportMessageLen || transportMessage) @@ -126,8 +123,14 @@ func DeserializePayloadV2(payload []byte) (*PayloadV2, error) { return nil, err } - if !IsProtocolIDSupported(result.ProtocolId) { - return nil, errors.New("unsupported protocol") + var pattern HandshakePattern + var err error + + if result.ProtocolId != None { + pattern, err = GetHandshakePattern(result.ProtocolId) + if err != nil { + return nil, err + } } // We read the Handshake Message length (1 byte) @@ -150,13 +153,13 @@ func DeserializePayloadV2(payload []byte) (*PayloadV2, error) { if flag == 0 { // If the key is unencrypted, we only read the X coordinate of the EC public key and we deserialize into a Noise Public Key - pkLen := ed25519.PublicKeySize + pkLen := pattern.dhKey.DHLen() var pkBytes SerializedNoisePublicKey = make([]byte, pkLen) if err := binary.Read(payloadBuf, binary.BigEndian, &pkBytes); err != nil { return nil, err } - serializedPK := SerializedNoisePublicKey(make([]byte, ed25519.PublicKeySize+1)) + serializedPK := SerializedNoisePublicKey(make([]byte, pkLen+1)) serializedPK[0] = flag copy(serializedPK[1:], pkBytes) @@ -169,7 +172,7 @@ func DeserializePayloadV2(payload []byte) (*PayloadV2, error) { written += uint8(1 + pkLen) } else if flag == 1 { // If the key is encrypted, we only read the encrypted X coordinate and the authorization tag, and we deserialize into a Noise Public Key - pkLen := ed25519.PublicKeySize + ChaChaPolyTagSize + pkLen := pattern.dhKey.DHLen() + pattern.tagSize // TODO: duplicated code: ============== var pkBytes SerializedNoisePublicKey = make([]byte, pkLen) @@ -177,7 +180,7 @@ func DeserializePayloadV2(payload []byte) (*PayloadV2, error) { return nil, err } - serializedPK := SerializedNoisePublicKey(make([]byte, ed25519.PublicKeySize+1)) + serializedPK := SerializedNoisePublicKey(make([]byte, pkLen+1)) serializedPK[0] = flag copy(serializedPK[1:], pkBytes)