changes as per review
This commit is contained in:
parent
e99d5b623e
commit
47797a5deb
|
@ -18,7 +18,6 @@ type testMode int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
readBufferGtEncMsg testMode = iota
|
readBufferGtEncMsg testMode = iota
|
||||||
readBufferGtPlainText
|
|
||||||
readBufferLtPlainText
|
readBufferLtPlainText
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,9 +27,6 @@ var bcs = map[string]struct {
|
||||||
"readBuffer > encrypted message": {
|
"readBuffer > encrypted message": {
|
||||||
readBufferGtEncMsg,
|
readBufferGtEncMsg,
|
||||||
},
|
},
|
||||||
"readBuffer > decrypted plaintext": {
|
|
||||||
readBufferGtPlainText,
|
|
||||||
},
|
|
||||||
"readBuffer < decrypted plaintext": {
|
"readBuffer < decrypted plaintext": {
|
||||||
readBufferLtPlainText,
|
readBufferLtPlainText,
|
||||||
},
|
},
|
||||||
|
@ -178,8 +174,6 @@ func benchDataTransfer(b *benchenv, dataSize int64, m testMode) {
|
||||||
switch m {
|
switch m {
|
||||||
case readBufferGtEncMsg:
|
case readBufferGtEncMsg:
|
||||||
rbuf = make([]byte, len(plainTextBufs[i])+poly1305.TagSize+1)
|
rbuf = make([]byte, len(plainTextBufs[i])+poly1305.TagSize+1)
|
||||||
case readBufferGtPlainText:
|
|
||||||
rbuf = make([]byte, len(plainTextBufs[i])+1)
|
|
||||||
case readBufferLtPlainText:
|
case readBufferLtPlainText:
|
||||||
rbuf = make([]byte, len(plainTextBufs[i])-2)
|
rbuf = make([]byte, len(plainTextBufs[i])-2)
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,11 +59,19 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We can re-use this buffer for all handshake messages as it's size
|
||||||
|
// will be the size of the maximum handshake message for the Noise XX pattern.
|
||||||
|
// Also, since we prefix every noise handshake message with it's length, we need to account for
|
||||||
|
// it when we fetch the buffer from the pool
|
||||||
|
maxMsgSize := 2*noise.DH25519.DHLen() + len(payload) + 2*poly1305.TagSize
|
||||||
|
hbuf := pool.Get(maxMsgSize + LengthPrefixLength)
|
||||||
|
defer pool.Put(hbuf)
|
||||||
|
|
||||||
if s.initiator {
|
if s.initiator {
|
||||||
// stage 0 //
|
// stage 0 //
|
||||||
// do not send the payload just yet, as it would be plaintext; not secret.
|
// do not send the payload just yet, as it would be plaintext; not secret.
|
||||||
// Handshake Msg Len = len(DH ephemeral key)
|
// Handshake Msg Len = len(DH ephemeral key)
|
||||||
err = s.sendHandshakeMessage(hs, nil, noise.DH25519.DHLen())
|
err = s.sendHandshakeMessage(hs, nil, hbuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error sending handshake message: %w", err)
|
return fmt.Errorf("error sending handshake message: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -80,7 +88,7 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
||||||
|
|
||||||
// stage 2 //
|
// stage 2 //
|
||||||
// Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted)
|
// Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted)
|
||||||
err = s.sendHandshakeMessage(hs, payload, noise.DH25519.DHLen()+len(payload)+2*poly1305.TagSize)
|
err = s.sendHandshakeMessage(hs, payload, hbuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error sending handshake message: %w", err)
|
return fmt.Errorf("error sending handshake message: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -94,8 +102,7 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
||||||
// stage 1 //
|
// stage 1 //
|
||||||
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
|
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
|
||||||
//MAC(payload is encrypted)
|
//MAC(payload is encrypted)
|
||||||
err = s.sendHandshakeMessage(hs, payload, 2*noise.DH25519.DHLen()+len(payload)+
|
err = s.sendHandshakeMessage(hs, payload, hbuf)
|
||||||
2*poly1305.TagSize)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error sending handshake message: %w", err)
|
return fmt.Errorf("error sending handshake message: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -134,19 +141,18 @@ func (s *secureSession) setCipherStates(cs1, cs2 *noise.CipherState) {
|
||||||
// If payload is non-empty, it will be included in the handshake message.
|
// If payload is non-empty, it will be included in the handshake message.
|
||||||
// If this is the final message in the sequence, calls setCipherStates
|
// If this is the final message in the sequence, calls setCipherStates
|
||||||
// to initialize cipher states.
|
// to initialize cipher states.
|
||||||
func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte, handshakeMsgCap int) error {
|
func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte, hbuf []byte) error {
|
||||||
hsbuf := pool.Get(handshakeMsgCap + LengthPrefixLength)
|
// the first two bytes will be the length of the noise handshake message.
|
||||||
defer pool.Put(hsbuf)
|
bz, cs1, cs2, err := hs.WriteMessage(hbuf[:LengthPrefixLength], payload)
|
||||||
|
|
||||||
bz, cs1, cs2, err := hs.WriteMessage(hsbuf[:0], payload)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(hsbuf[LengthPrefixLength:], hsbuf)
|
// bz will also include the length prefix as we passed a slice of LengthPrefixLength length
|
||||||
binary.BigEndian.PutUint16(hsbuf, uint16(len(bz)))
|
// to hs.Write().
|
||||||
|
binary.BigEndian.PutUint16(hbuf, uint16(len(bz)-LengthPrefixLength))
|
||||||
|
|
||||||
_, err = s.writeMsgInsecure(hsbuf)
|
_, err = s.writeMsgInsecure(hbuf[:len(bz)])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,32 +54,14 @@ func (s *secureSession) Read(buf []byte) (int, error) {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the buffer is atleast as big as the decrypted message size,
|
// If the buffer is atleast as big as the encrypted message size,
|
||||||
// we can surely decrypt in place.
|
// we can read AND decrypt in place.
|
||||||
if len(buf) >= nextMsgLen-poly1305.TagSize {
|
|
||||||
var toDecrypt []byte
|
|
||||||
|
|
||||||
// If the buffer is atleast as big as the encrypted message, we can
|
|
||||||
// read the message directly into the buffer and then decrypt in place.
|
|
||||||
if len(buf) >= nextMsgLen {
|
if len(buf) >= nextMsgLen {
|
||||||
if err := s.readNextMsgInsecure(buf[:nextMsgLen]); err != nil {
|
if err := s.readNextMsgInsecure(buf[:nextMsgLen]); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
toDecrypt = buf[:nextMsgLen]
|
|
||||||
} else {
|
|
||||||
// Since the buffer is not big enough for the encrypted message,
|
|
||||||
// we need to get one from the pool.
|
|
||||||
cbuf := pool.Get(nextMsgLen)
|
|
||||||
defer pool.Put(cbuf)
|
|
||||||
if err := s.readNextMsgInsecure(cbuf); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
toDecrypt = cbuf
|
|
||||||
}
|
|
||||||
|
|
||||||
// decrypt the message directly into the buffer since we know the buffer is atleast that big.
|
_, err := s.decrypt(buf[:0], buf[:nextMsgLen])
|
||||||
// This will avoid a copy from `cbuf` into buf for the else case above.
|
|
||||||
_, err := s.decrypt(buf[:0], toDecrypt)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -130,16 +112,14 @@ func (s *secureSession) Write(data []byte) (int, error) {
|
||||||
end = total
|
end = total
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := s.encrypt(cbuf[:0], data[written:end])
|
b, err := s.encrypt(cbuf[:LengthPrefixLength], data[written:end])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(cbuf[LengthPrefixLength:], b)
|
binary.BigEndian.PutUint16(cbuf, uint16(len(b)-LengthPrefixLength))
|
||||||
|
|
||||||
binary.BigEndian.PutUint16(cbuf, uint16(len(b)))
|
_, err = s.writeMsgInsecure(cbuf[0:len(b)])
|
||||||
|
|
||||||
_, err = s.writeMsgInsecure(cbuf[0 : len(b)+LengthPrefixLength])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return written, err
|
return written, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue