Merge pull request #10 from libp2p/feat/message-chunking
Split large payloads into chunks
This commit is contained in:
commit
e360301114
|
@ -20,6 +20,11 @@ import (
|
|||
|
||||
const payload_string = "noise-libp2p-static-key:"
|
||||
|
||||
// Each encrypted transport message must be <= 65,535 bytes, including 16
|
||||
// bytes of authentication data. To write larger plaintexts, we split them
|
||||
// into fragments of maxPlaintextLength before encrypting.
|
||||
const maxPlaintextLength = 65519
|
||||
|
||||
var log = logging.Logger("noise")
|
||||
|
||||
type secureSession struct {
|
||||
|
@ -237,33 +242,50 @@ func (s *secureSession) Read(buf []byte) (int, error) {
|
|||
return l, nil
|
||||
}
|
||||
|
||||
// read length of encrypted message
|
||||
l, err := s.readLength()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
readChunk := func(buf []byte) (int, error) {
|
||||
// read length of encrypted message
|
||||
l, err := s.readLength()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// read and decrypt ciphertext
|
||||
ciphertext := make([]byte, l)
|
||||
_, err = s.insecure.Read(ciphertext)
|
||||
if err != nil {
|
||||
log.Error("read ciphertext err", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
plaintext, err := s.Decrypt(ciphertext)
|
||||
if err != nil {
|
||||
log.Error("decrypt err", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// append plaintext to message buffer, copy over what can fit in the buf
|
||||
// then advance message buffer to remove what was copied
|
||||
s.msgBuffer = append(s.msgBuffer, plaintext...)
|
||||
c := copy(buf, s.msgBuffer)
|
||||
s.msgBuffer = s.msgBuffer[c:]
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// read and decrypt ciphertext
|
||||
ciphertext := make([]byte, l)
|
||||
_, err = s.insecure.Read(ciphertext)
|
||||
if err != nil {
|
||||
log.Error("read ciphertext err", err)
|
||||
return 0, err
|
||||
total := 0
|
||||
for i := 0; i < len(buf); i += maxPlaintextLength {
|
||||
end := i + maxPlaintextLength
|
||||
if end > len(buf) {
|
||||
end = len(buf)
|
||||
}
|
||||
|
||||
c, err := readChunk(buf[i:end])
|
||||
total += c
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
|
||||
plaintext, err := s.Decrypt(ciphertext)
|
||||
if err != nil {
|
||||
log.Error("decrypt err", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// append plaintext to message buffer, copy over what can fit in the buf
|
||||
// then advance message buffer to remove what was copied
|
||||
s.msgBuffer = append(s.msgBuffer, plaintext...)
|
||||
c := copy(buf, s.msgBuffer)
|
||||
s.msgBuffer = s.msgBuffer[c:]
|
||||
|
||||
return c, nil
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (s *secureSession) RemoteAddr() net.Addr {
|
||||
|
@ -294,20 +316,37 @@ func (s *secureSession) Write(in []byte) (int, error) {
|
|||
s.rwLock.Lock()
|
||||
defer s.rwLock.Unlock()
|
||||
|
||||
ciphertext, err := s.Encrypt(in)
|
||||
if err != nil {
|
||||
log.Error("encrypt error", err)
|
||||
return 0, err
|
||||
writeChunk := func(in []byte) (int, error) {
|
||||
ciphertext, err := s.Encrypt(in)
|
||||
if err != nil {
|
||||
log.Error("encrypt error", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = s.writeLength(len(ciphertext))
|
||||
if err != nil {
|
||||
log.Error("write length err", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
_, err = s.insecure.Write(ciphertext)
|
||||
return len(in), err
|
||||
}
|
||||
|
||||
err = s.writeLength(len(ciphertext))
|
||||
if err != nil {
|
||||
log.Error("write length err", err)
|
||||
return 0, err
|
||||
}
|
||||
written := 0
|
||||
for i := 0; i < len(in); i += maxPlaintextLength {
|
||||
end := i + maxPlaintextLength
|
||||
if end > len(in) {
|
||||
end = len(in)
|
||||
}
|
||||
|
||||
_, err = s.insecure.Write(ciphertext)
|
||||
return len(in), err
|
||||
l, err := writeChunk(in[i:end])
|
||||
written += l
|
||||
if err != nil {
|
||||
return written, err
|
||||
}
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (s *secureSession) Close() error {
|
||||
|
|
|
@ -3,6 +3,7 @@ package noise
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
|
@ -148,6 +149,44 @@ func TestKeys(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func makeLargePlaintext(size int) []byte {
|
||||
buf := make([]byte, size)
|
||||
rand.Read(buf)
|
||||
return buf
|
||||
}
|
||||
|
||||
func TestLargePayloads(t *testing.T) {
|
||||
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
|
||||
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
|
||||
|
||||
initConn, respConn := connect(t, initTransport, respTransport)
|
||||
defer initConn.Close()
|
||||
defer respConn.Close()
|
||||
|
||||
// enough to require a couple Noise messages, with a size that
|
||||
// isn't a neat multiple of Noise message size, just in case
|
||||
size := 100000
|
||||
|
||||
before := makeLargePlaintext(size)
|
||||
_, err := initConn.Write(before)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
after := make([]byte, len(before))
|
||||
afterLen, err := respConn.Read(after)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(before) != afterLen {
|
||||
t.Errorf("expected to read same amount of data as written. written=%d read=%d", len(before), afterLen)
|
||||
}
|
||||
if !bytes.Equal(before, after) {
|
||||
t.Error("Message mismatch.")
|
||||
}
|
||||
}
|
||||
|
||||
// Tests XX handshake
|
||||
func TestHandshakeXX(t *testing.T) {
|
||||
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
|
||||
|
|
Loading…
Reference in New Issue