pre-calculate handshake message len

This commit is contained in:
Aarsh Shah 2020-04-29 18:39:12 +05:30
parent d9758632dc
commit e99d5b623e
1 changed files with 17 additions and 11 deletions

View File

@ -6,6 +6,7 @@ import (
"encoding/binary"
"fmt"
pool "github.com/libp2p/go-buffer-pool"
"golang.org/x/crypto/poly1305"
"time"
"github.com/flynn/noise"
@ -61,7 +62,8 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
if s.initiator {
// stage 0 //
// do not send the payload just yet, as it would be plaintext; not secret.
err = s.sendHandshakeMessage(hs, nil)
// Handshake Msg Len = len(DH ephemeral key)
err = s.sendHandshakeMessage(hs, nil, noise.DH25519.DHLen())
if err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
@ -77,7 +79,8 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
}
// stage 2 //
err = s.sendHandshakeMessage(hs, payload)
// Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted)
err = s.sendHandshakeMessage(hs, payload, noise.DH25519.DHLen()+len(payload)+2*poly1305.TagSize)
if err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
@ -89,7 +92,10 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
}
// stage 1 //
err = s.sendHandshakeMessage(hs, payload)
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
//MAC(payload is encrypted)
err = s.sendHandshakeMessage(hs, payload, 2*noise.DH25519.DHLen()+len(payload)+
2*poly1305.TagSize)
if err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
@ -128,19 +134,19 @@ func (s *secureSession) setCipherStates(cs1, cs2 *noise.CipherState) {
// If payload is non-empty, it will be included in the handshake message.
// If this is the final message in the sequence, calls setCipherStates
// to initialize cipher states.
func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte) error {
buf, cs1, cs2, err := hs.WriteMessage(nil, payload)
func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte, handshakeMsgCap int) error {
hsbuf := pool.Get(handshakeMsgCap + LengthPrefixLength)
defer pool.Put(hsbuf)
bz, cs1, cs2, err := hs.WriteMessage(hsbuf[:0], payload)
if err != nil {
return err
}
bz := pool.Get(LengthPrefixLength + len(buf))
defer pool.Put(bz)
copy(hsbuf[LengthPrefixLength:], hsbuf)
binary.BigEndian.PutUint16(hsbuf, uint16(len(bz)))
binary.BigEndian.PutUint16(bz, uint16(len(buf)))
copy(bz[LengthPrefixLength:], buf)
_, err = s.writeMsgInsecure(bz)
_, err = s.writeMsgInsecure(hsbuf)
if err != nil {
return err
}