Merge pull request #10 from libp2p/feat/message-chunking

Split large payloads into chunks
This commit is contained in:
Yusef Napora 2019-12-06 12:47:45 -05:00 committed by GitHub
commit e360301114
2 changed files with 112 additions and 34 deletions

View File

@ -20,6 +20,11 @@ import (
const payload_string = "noise-libp2p-static-key:"
// Each encrypted transport message must be <= 65,535 bytes, including 16
// bytes of authentication data. To write larger plaintexts, we split them
// into fragments of maxPlaintextLength before encrypting.
const maxPlaintextLength = 65519
var log = logging.Logger("noise")
type secureSession struct {
@ -237,33 +242,50 @@ func (s *secureSession) Read(buf []byte) (int, error) {
return l, nil
}
// read length of encrypted message
l, err := s.readLength()
if err != nil {
return 0, err
readChunk := func(buf []byte) (int, error) {
// read length of encrypted message
l, err := s.readLength()
if err != nil {
return 0, err
}
// read and decrypt ciphertext
ciphertext := make([]byte, l)
_, err = s.insecure.Read(ciphertext)
if err != nil {
log.Error("read ciphertext err", err)
return 0, err
}
plaintext, err := s.Decrypt(ciphertext)
if err != nil {
log.Error("decrypt err", err)
return 0, err
}
// append plaintext to message buffer, copy over what can fit in the buf
// then advance message buffer to remove what was copied
s.msgBuffer = append(s.msgBuffer, plaintext...)
c := copy(buf, s.msgBuffer)
s.msgBuffer = s.msgBuffer[c:]
return c, nil
}
// read and decrypt ciphertext
ciphertext := make([]byte, l)
_, err = s.insecure.Read(ciphertext)
if err != nil {
log.Error("read ciphertext err", err)
return 0, err
total := 0
for i := 0; i < len(buf); i += maxPlaintextLength {
end := i + maxPlaintextLength
if end > len(buf) {
end = len(buf)
}
c, err := readChunk(buf[i:end])
total += c
if err != nil {
return total, err
}
}
plaintext, err := s.Decrypt(ciphertext)
if err != nil {
log.Error("decrypt err", err)
return 0, err
}
// append plaintext to message buffer, copy over what can fit in the buf
// then advance message buffer to remove what was copied
s.msgBuffer = append(s.msgBuffer, plaintext...)
c := copy(buf, s.msgBuffer)
s.msgBuffer = s.msgBuffer[c:]
return c, nil
return total, nil
}
func (s *secureSession) RemoteAddr() net.Addr {
@ -294,20 +316,37 @@ func (s *secureSession) Write(in []byte) (int, error) {
s.rwLock.Lock()
defer s.rwLock.Unlock()
ciphertext, err := s.Encrypt(in)
if err != nil {
log.Error("encrypt error", err)
return 0, err
writeChunk := func(in []byte) (int, error) {
ciphertext, err := s.Encrypt(in)
if err != nil {
log.Error("encrypt error", err)
return 0, err
}
err = s.writeLength(len(ciphertext))
if err != nil {
log.Error("write length err", err)
return 0, err
}
_, err = s.insecure.Write(ciphertext)
return len(in), err
}
err = s.writeLength(len(ciphertext))
if err != nil {
log.Error("write length err", err)
return 0, err
}
written := 0
for i := 0; i < len(in); i += maxPlaintextLength {
end := i + maxPlaintextLength
if end > len(in) {
end = len(in)
}
_, err = s.insecure.Write(ciphertext)
return len(in), err
l, err := writeChunk(in[i:end])
written += l
if err != nil {
return written, err
}
}
return written, nil
}
func (s *secureSession) Close() error {

View File

@ -3,6 +3,7 @@ package noise
import (
"bytes"
"context"
"math/rand"
"net"
"testing"
@ -148,6 +149,44 @@ func TestKeys(t *testing.T) {
}
}
func makeLargePlaintext(size int) []byte {
buf := make([]byte, size)
rand.Read(buf)
return buf
}
func TestLargePayloads(t *testing.T) {
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
initConn, respConn := connect(t, initTransport, respTransport)
defer initConn.Close()
defer respConn.Close()
// enough to require a couple Noise messages, with a size that
// isn't a neat multiple of Noise message size, just in case
size := 100000
before := makeLargePlaintext(size)
_, err := initConn.Write(before)
if err != nil {
t.Fatal(err)
}
after := make([]byte, len(before))
afterLen, err := respConn.Read(after)
if err != nil {
t.Fatal(err)
}
if len(before) != afterLen {
t.Errorf("expected to read same amount of data as written. written=%d read=%d", len(before), afterLen)
}
if !bytes.Equal(before, after) {
t.Error("Message mismatch.")
}
}
// Tests XX handshake
func TestHandshakeXX(t *testing.T) {
initTransport := newTestTransport(t, crypto.Ed25519, 2048)