changes as per review
This commit is contained in:
parent
e99d5b623e
commit
47797a5deb
|
@ -18,7 +18,6 @@ type testMode int
|
|||
|
||||
const (
|
||||
readBufferGtEncMsg testMode = iota
|
||||
readBufferGtPlainText
|
||||
readBufferLtPlainText
|
||||
)
|
||||
|
||||
|
@ -28,9 +27,6 @@ var bcs = map[string]struct {
|
|||
"readBuffer > encrypted message": {
|
||||
readBufferGtEncMsg,
|
||||
},
|
||||
"readBuffer > decrypted plaintext": {
|
||||
readBufferGtPlainText,
|
||||
},
|
||||
"readBuffer < decrypted plaintext": {
|
||||
readBufferLtPlainText,
|
||||
},
|
||||
|
@ -178,8 +174,6 @@ func benchDataTransfer(b *benchenv, dataSize int64, m testMode) {
|
|||
switch m {
|
||||
case readBufferGtEncMsg:
|
||||
rbuf = make([]byte, len(plainTextBufs[i])+poly1305.TagSize+1)
|
||||
case readBufferGtPlainText:
|
||||
rbuf = make([]byte, len(plainTextBufs[i])+1)
|
||||
case readBufferLtPlainText:
|
||||
rbuf = make([]byte, len(plainTextBufs[i])-2)
|
||||
}
|
||||
|
|
|
@ -59,11 +59,19 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
// We can re-use this buffer for all handshake messages as it's size
|
||||
// will be the size of the maximum handshake message for the Noise XX pattern.
|
||||
// Also, since we prefix every noise handshake message with it's length, we need to account for
|
||||
// it when we fetch the buffer from the pool
|
||||
maxMsgSize := 2*noise.DH25519.DHLen() + len(payload) + 2*poly1305.TagSize
|
||||
hbuf := pool.Get(maxMsgSize + LengthPrefixLength)
|
||||
defer pool.Put(hbuf)
|
||||
|
||||
if s.initiator {
|
||||
// stage 0 //
|
||||
// do not send the payload just yet, as it would be plaintext; not secret.
|
||||
// Handshake Msg Len = len(DH ephemeral key)
|
||||
err = s.sendHandshakeMessage(hs, nil, noise.DH25519.DHLen())
|
||||
err = s.sendHandshakeMessage(hs, nil, hbuf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %w", err)
|
||||
}
|
||||
|
@ -80,7 +88,7 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
|||
|
||||
// stage 2 //
|
||||
// Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted)
|
||||
err = s.sendHandshakeMessage(hs, payload, noise.DH25519.DHLen()+len(payload)+2*poly1305.TagSize)
|
||||
err = s.sendHandshakeMessage(hs, payload, hbuf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %w", err)
|
||||
}
|
||||
|
@ -94,8 +102,7 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
|||
// stage 1 //
|
||||
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
|
||||
//MAC(payload is encrypted)
|
||||
err = s.sendHandshakeMessage(hs, payload, 2*noise.DH25519.DHLen()+len(payload)+
|
||||
2*poly1305.TagSize)
|
||||
err = s.sendHandshakeMessage(hs, payload, hbuf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %w", err)
|
||||
}
|
||||
|
@ -134,19 +141,18 @@ func (s *secureSession) setCipherStates(cs1, cs2 *noise.CipherState) {
|
|||
// If payload is non-empty, it will be included in the handshake message.
|
||||
// If this is the final message in the sequence, calls setCipherStates
|
||||
// to initialize cipher states.
|
||||
func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte, handshakeMsgCap int) error {
|
||||
hsbuf := pool.Get(handshakeMsgCap + LengthPrefixLength)
|
||||
defer pool.Put(hsbuf)
|
||||
|
||||
bz, cs1, cs2, err := hs.WriteMessage(hsbuf[:0], payload)
|
||||
func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte, hbuf []byte) error {
|
||||
// the first two bytes will be the length of the noise handshake message.
|
||||
bz, cs1, cs2, err := hs.WriteMessage(hbuf[:LengthPrefixLength], payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
copy(hsbuf[LengthPrefixLength:], hsbuf)
|
||||
binary.BigEndian.PutUint16(hsbuf, uint16(len(bz)))
|
||||
// bz will also include the length prefix as we passed a slice of LengthPrefixLength length
|
||||
// to hs.Write().
|
||||
binary.BigEndian.PutUint16(hbuf, uint16(len(bz)-LengthPrefixLength))
|
||||
|
||||
_, err = s.writeMsgInsecure(hsbuf)
|
||||
_, err = s.writeMsgInsecure(hbuf[:len(bz)])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -54,32 +54,14 @@ func (s *secureSession) Read(buf []byte) (int, error) {
|
|||
return 0, err
|
||||
}
|
||||
|
||||
// If the buffer is atleast as big as the decrypted message size,
|
||||
// we can surely decrypt in place.
|
||||
if len(buf) >= nextMsgLen-poly1305.TagSize {
|
||||
var toDecrypt []byte
|
||||
|
||||
// If the buffer is atleast as big as the encrypted message, we can
|
||||
// read the message directly into the buffer and then decrypt in place.
|
||||
// If the buffer is atleast as big as the encrypted message size,
|
||||
// we can read AND decrypt in place.
|
||||
if len(buf) >= nextMsgLen {
|
||||
if err := s.readNextMsgInsecure(buf[:nextMsgLen]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
toDecrypt = buf[:nextMsgLen]
|
||||
} else {
|
||||
// Since the buffer is not big enough for the encrypted message,
|
||||
// we need to get one from the pool.
|
||||
cbuf := pool.Get(nextMsgLen)
|
||||
defer pool.Put(cbuf)
|
||||
if err := s.readNextMsgInsecure(cbuf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
toDecrypt = cbuf
|
||||
}
|
||||
|
||||
// decrypt the message directly into the buffer since we know the buffer is atleast that big.
|
||||
// This will avoid a copy from `cbuf` into buf for the else case above.
|
||||
_, err := s.decrypt(buf[:0], toDecrypt)
|
||||
_, err := s.decrypt(buf[:0], buf[:nextMsgLen])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -130,16 +112,14 @@ func (s *secureSession) Write(data []byte) (int, error) {
|
|||
end = total
|
||||
}
|
||||
|
||||
b, err := s.encrypt(cbuf[:0], data[written:end])
|
||||
b, err := s.encrypt(cbuf[:LengthPrefixLength], data[written:end])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
copy(cbuf[LengthPrefixLength:], b)
|
||||
binary.BigEndian.PutUint16(cbuf, uint16(len(b)-LengthPrefixLength))
|
||||
|
||||
binary.BigEndian.PutUint16(cbuf, uint16(len(b)))
|
||||
|
||||
_, err = s.writeMsgInsecure(cbuf[0 : len(b)+LengthPrefixLength])
|
||||
_, err = s.writeMsgInsecure(cbuf[0:len(b)])
|
||||
if err != nil {
|
||||
return written, err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue