pre-calculate handshake message len
This commit is contained in:
parent
d9758632dc
commit
e99d5b623e
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/binary"
|
||||
"fmt"
|
||||
pool "github.com/libp2p/go-buffer-pool"
|
||||
"golang.org/x/crypto/poly1305"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
|
@ -61,7 +62,8 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
|||
if s.initiator {
|
||||
// stage 0 //
|
||||
// do not send the payload just yet, as it would be plaintext; not secret.
|
||||
err = s.sendHandshakeMessage(hs, nil)
|
||||
// Handshake Msg Len = len(DH ephemeral key)
|
||||
err = s.sendHandshakeMessage(hs, nil, noise.DH25519.DHLen())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %w", err)
|
||||
}
|
||||
|
@ -77,7 +79,8 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// stage 2 //
|
||||
err = s.sendHandshakeMessage(hs, payload)
|
||||
// Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted)
|
||||
err = s.sendHandshakeMessage(hs, payload, noise.DH25519.DHLen()+len(payload)+2*poly1305.TagSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %w", err)
|
||||
}
|
||||
|
@ -89,7 +92,10 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// stage 1 //
|
||||
err = s.sendHandshakeMessage(hs, payload)
|
||||
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
|
||||
//MAC(payload is encrypted)
|
||||
err = s.sendHandshakeMessage(hs, payload, 2*noise.DH25519.DHLen()+len(payload)+
|
||||
2*poly1305.TagSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %w", err)
|
||||
}
|
||||
|
@ -128,19 +134,19 @@ func (s *secureSession) setCipherStates(cs1, cs2 *noise.CipherState) {
|
|||
// If payload is non-empty, it will be included in the handshake message.
|
||||
// If this is the final message in the sequence, calls setCipherStates
|
||||
// to initialize cipher states.
|
||||
func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte) error {
|
||||
buf, cs1, cs2, err := hs.WriteMessage(nil, payload)
|
||||
func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte, handshakeMsgCap int) error {
|
||||
hsbuf := pool.Get(handshakeMsgCap + LengthPrefixLength)
|
||||
defer pool.Put(hsbuf)
|
||||
|
||||
bz, cs1, cs2, err := hs.WriteMessage(hsbuf[:0], payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bz := pool.Get(LengthPrefixLength + len(buf))
|
||||
defer pool.Put(bz)
|
||||
copy(hsbuf[LengthPrefixLength:], hsbuf)
|
||||
binary.BigEndian.PutUint16(hsbuf, uint16(len(bz)))
|
||||
|
||||
binary.BigEndian.PutUint16(bz, uint16(len(buf)))
|
||||
copy(bz[LengthPrefixLength:], buf)
|
||||
|
||||
_, err = s.writeMsgInsecure(bz)
|
||||
_, err = s.writeMsgInsecure(hsbuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue