use io.ReadFull instead of custom helper

This commit is contained in:
Yusef Napora 2020-01-29 09:55:17 -05:00
parent 3971c24bec
commit 268e9e2936
4 changed files with 11 additions and 39 deletions

View File

@ -3,6 +3,7 @@ package noise
import (
"context"
"fmt"
"io"
proto "github.com/gogo/protobuf/proto"
ik "github.com/libp2p/go-libp2p-noise/ik"
@ -28,7 +29,7 @@ func (s *secureSession) ik_sendHandshakeMessage(payload []byte, initial_stage bo
}
// send message
_, err = writeAll(s.insecure, encMsgBuf)
_, err = s.insecure.Write(encMsgBuf)
if err != nil {
log.Error("ik_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
return fmt.Errorf("ik_sendHandshakeMessage write to conn err=%s", err)
@ -45,7 +46,7 @@ func (s *secureSession) ik_recvHandshakeMessage(initial_stage bool) (buf []byte,
buf = make([]byte, l)
_, err = fillBuffer(buf, s.insecure)
_, err = io.ReadFull(s.insecure, buf)
if err != nil {
return buf, nil, false, fmt.Errorf("ik_recvHandshakeMessage read from conn err=%s", err)
}

View File

@ -5,6 +5,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
@ -121,14 +122,14 @@ func (s *secureSession) NoisePrivateKey() [32]byte {
func (s *secureSession) readLength() (int, error) {
buf := make([]byte, 2)
_, err := fillBuffer(buf, s.insecure)
_, err := io.ReadFull(s.insecure, buf)
return int(binary.BigEndian.Uint16(buf)), err
}
func (s *secureSession) writeLength(length int) error {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(length))
_, err := writeAll(s.insecure, buf)
_, err := s.insecure.Write(buf)
return err
}
@ -259,7 +260,7 @@ func (s *secureSession) Read(buf []byte) (int, error) {
// read and decrypt ciphertext
ciphertext := make([]byte, l)
_, err = fillBuffer(ciphertext, s.insecure)
_, err = io.ReadFull(s.insecure, ciphertext)
if err != nil {
log.Error("read ciphertext err", err)
return 0, err
@ -337,7 +338,7 @@ func (s *secureSession) Write(in []byte) (int, error) {
return 0, err
}
_, err = writeAll(s.insecure, ciphertext)
_, err = s.insecure.Write(ciphertext)
return len(in), err
}

View File

@ -1,31 +0,0 @@
package noise
import "io"
// fillBuffer reads from the given reader until the given buffer
// is full
func fillBuffer(buf []byte, reader io.Reader) (int, error) {
total := 0
for total < len(buf) {
c, err := reader.Read(buf[total:])
if err != nil {
return total, err
}
total += c
}
return total, nil
}
// writeAll is a helper that writes to the given io.Writer until all input data
// has been written
func writeAll(writer io.Writer, data []byte) (int, error) {
total := 0
for total < len(data) {
c, err := writer.Write(data[total:])
if err != nil {
return total, err
}
total += c
}
return total, nil
}

View File

@ -3,6 +3,7 @@ package noise
import (
"context"
"fmt"
"io"
proto "github.com/gogo/protobuf/proto"
"github.com/libp2p/go-libp2p-core/peer"
@ -27,7 +28,7 @@ func (s *secureSession) xx_sendHandshakeMessage(payload []byte, initial_stage bo
return fmt.Errorf("xx_sendHandshakeMessage write length err=%s", err)
}
_, err = writeAll(s.insecure, encMsgBuf)
_, err = s.insecure.Write(encMsgBuf)
if err != nil {
log.Error("xx_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
return fmt.Errorf("xx_sendHandshakeMessage write to conn err=%s", err)
@ -46,7 +47,7 @@ func (s *secureSession) xx_recvHandshakeMessage(initial_stage bool) (buf []byte,
buf = make([]byte, l)
_, err = fillBuffer(buf, s.insecure)
_, err = io.ReadFull(s.insecure, buf)
if err != nil {
return buf, nil, false, fmt.Errorf("xx_recvHandshakeMessage read from conn err=%s", err)
}