changes as per review

This commit is contained in:
Aarsh Shah 2020-05-04 15:34:12 +05:30
parent e99d5b623e
commit 47797a5deb
3 changed files with 27 additions and 47 deletions

View File

@ -18,7 +18,6 @@ type testMode int
const ( const (
readBufferGtEncMsg testMode = iota readBufferGtEncMsg testMode = iota
readBufferGtPlainText
readBufferLtPlainText readBufferLtPlainText
) )
@ -28,9 +27,6 @@ var bcs = map[string]struct {
"readBuffer > encrypted message": { "readBuffer > encrypted message": {
readBufferGtEncMsg, readBufferGtEncMsg,
}, },
"readBuffer > decrypted plaintext": {
readBufferGtPlainText,
},
"readBuffer < decrypted plaintext": { "readBuffer < decrypted plaintext": {
readBufferLtPlainText, readBufferLtPlainText,
}, },
@ -178,8 +174,6 @@ func benchDataTransfer(b *benchenv, dataSize int64, m testMode) {
switch m { switch m {
case readBufferGtEncMsg: case readBufferGtEncMsg:
rbuf = make([]byte, len(plainTextBufs[i])+poly1305.TagSize+1) rbuf = make([]byte, len(plainTextBufs[i])+poly1305.TagSize+1)
case readBufferGtPlainText:
rbuf = make([]byte, len(plainTextBufs[i])+1)
case readBufferLtPlainText: case readBufferLtPlainText:
rbuf = make([]byte, len(plainTextBufs[i])-2) rbuf = make([]byte, len(plainTextBufs[i])-2)
} }

View File

@ -59,11 +59,19 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
} }
} }
// We can re-use this buffer for all handshake messages as it's size
// will be the size of the maximum handshake message for the Noise XX pattern.
// Also, since we prefix every noise handshake message with it's length, we need to account for
// it when we fetch the buffer from the pool
maxMsgSize := 2*noise.DH25519.DHLen() + len(payload) + 2*poly1305.TagSize
hbuf := pool.Get(maxMsgSize + LengthPrefixLength)
defer pool.Put(hbuf)
if s.initiator { if s.initiator {
// stage 0 // // stage 0 //
// do not send the payload just yet, as it would be plaintext; not secret. // do not send the payload just yet, as it would be plaintext; not secret.
// Handshake Msg Len = len(DH ephemeral key) // Handshake Msg Len = len(DH ephemeral key)
err = s.sendHandshakeMessage(hs, nil, noise.DH25519.DHLen()) err = s.sendHandshakeMessage(hs, nil, hbuf)
if err != nil { if err != nil {
return fmt.Errorf("error sending handshake message: %w", err) return fmt.Errorf("error sending handshake message: %w", err)
} }
@ -80,7 +88,7 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
// stage 2 // // stage 2 //
// Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted) // Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted)
err = s.sendHandshakeMessage(hs, payload, noise.DH25519.DHLen()+len(payload)+2*poly1305.TagSize) err = s.sendHandshakeMessage(hs, payload, hbuf)
if err != nil { if err != nil {
return fmt.Errorf("error sending handshake message: %w", err) return fmt.Errorf("error sending handshake message: %w", err)
} }
@ -94,8 +102,7 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
// stage 1 // // stage 1 //
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) + // Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
//MAC(payload is encrypted) //MAC(payload is encrypted)
err = s.sendHandshakeMessage(hs, payload, 2*noise.DH25519.DHLen()+len(payload)+ err = s.sendHandshakeMessage(hs, payload, hbuf)
2*poly1305.TagSize)
if err != nil { if err != nil {
return fmt.Errorf("error sending handshake message: %w", err) return fmt.Errorf("error sending handshake message: %w", err)
} }
@ -134,19 +141,18 @@ func (s *secureSession) setCipherStates(cs1, cs2 *noise.CipherState) {
// If payload is non-empty, it will be included in the handshake message. // If payload is non-empty, it will be included in the handshake message.
// If this is the final message in the sequence, calls setCipherStates // If this is the final message in the sequence, calls setCipherStates
// to initialize cipher states. // to initialize cipher states.
func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte, handshakeMsgCap int) error { func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte, hbuf []byte) error {
hsbuf := pool.Get(handshakeMsgCap + LengthPrefixLength) // the first two bytes will be the length of the noise handshake message.
defer pool.Put(hsbuf) bz, cs1, cs2, err := hs.WriteMessage(hbuf[:LengthPrefixLength], payload)
bz, cs1, cs2, err := hs.WriteMessage(hsbuf[:0], payload)
if err != nil { if err != nil {
return err return err
} }
copy(hsbuf[LengthPrefixLength:], hsbuf) // bz will also include the length prefix as we passed a slice of LengthPrefixLength length
binary.BigEndian.PutUint16(hsbuf, uint16(len(bz))) // to hs.Write().
binary.BigEndian.PutUint16(hbuf, uint16(len(bz)-LengthPrefixLength))
_, err = s.writeMsgInsecure(hsbuf) _, err = s.writeMsgInsecure(hbuf[:len(bz)])
if err != nil { if err != nil {
return err return err
} }

View File

@ -54,32 +54,14 @@ func (s *secureSession) Read(buf []byte) (int, error) {
return 0, err return 0, err
} }
// If the buffer is atleast as big as the decrypted message size, // If the buffer is atleast as big as the encrypted message size,
// we can surely decrypt in place. // we can read AND decrypt in place.
if len(buf) >= nextMsgLen-poly1305.TagSize {
var toDecrypt []byte
// If the buffer is atleast as big as the encrypted message, we can
// read the message directly into the buffer and then decrypt in place.
if len(buf) >= nextMsgLen { if len(buf) >= nextMsgLen {
if err := s.readNextMsgInsecure(buf[:nextMsgLen]); err != nil { if err := s.readNextMsgInsecure(buf[:nextMsgLen]); err != nil {
return 0, err return 0, err
} }
toDecrypt = buf[:nextMsgLen]
} else {
// Since the buffer is not big enough for the encrypted message,
// we need to get one from the pool.
cbuf := pool.Get(nextMsgLen)
defer pool.Put(cbuf)
if err := s.readNextMsgInsecure(cbuf); err != nil {
return 0, err
}
toDecrypt = cbuf
}
// decrypt the message directly into the buffer since we know the buffer is atleast that big. _, err := s.decrypt(buf[:0], buf[:nextMsgLen])
// This will avoid a copy from `cbuf` into buf for the else case above.
_, err := s.decrypt(buf[:0], toDecrypt)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -130,16 +112,14 @@ func (s *secureSession) Write(data []byte) (int, error) {
end = total end = total
} }
b, err := s.encrypt(cbuf[:0], data[written:end]) b, err := s.encrypt(cbuf[:LengthPrefixLength], data[written:end])
if err != nil { if err != nil {
return 0, err return 0, err
} }
copy(cbuf[LengthPrefixLength:], b) binary.BigEndian.PutUint16(cbuf, uint16(len(b)-LengthPrefixLength))
binary.BigEndian.PutUint16(cbuf, uint16(len(b))) _, err = s.writeMsgInsecure(cbuf[0:len(b)])
_, err = s.writeMsgInsecure(cbuf[0 : len(b)+LengthPrefixLength])
if err != nil { if err != nil {
return written, err return written, err
} }