Merge pull request #39 from libp2p/fix/cleanup
use io.ReadFull instead of custom helper
This commit is contained in:
commit
a1e6857b5c
|
@ -3,6 +3,7 @@ package noise
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
proto "github.com/gogo/protobuf/proto"
|
||||
ik "github.com/libp2p/go-libp2p-noise/ik"
|
||||
|
@ -28,7 +29,7 @@ func (s *secureSession) ik_sendHandshakeMessage(payload []byte, initial_stage bo
|
|||
}
|
||||
|
||||
// send message
|
||||
_, err = writeAll(s.insecure, encMsgBuf)
|
||||
_, err = s.insecure.Write(encMsgBuf)
|
||||
if err != nil {
|
||||
log.Error("ik_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
|
||||
return fmt.Errorf("ik_sendHandshakeMessage write to conn err=%s", err)
|
||||
|
@ -45,7 +46,7 @@ func (s *secureSession) ik_recvHandshakeMessage(initial_stage bool) (buf []byte,
|
|||
|
||||
buf = make([]byte, l)
|
||||
|
||||
_, err = fillBuffer(buf, s.insecure)
|
||||
_, err = io.ReadFull(s.insecure, buf)
|
||||
if err != nil {
|
||||
return buf, nil, false, fmt.Errorf("ik_recvHandshakeMessage read from conn err=%s", err)
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -121,14 +122,14 @@ func (s *secureSession) NoisePrivateKey() [32]byte {
|
|||
|
||||
func (s *secureSession) readLength() (int, error) {
|
||||
buf := make([]byte, 2)
|
||||
_, err := fillBuffer(buf, s.insecure)
|
||||
_, err := io.ReadFull(s.insecure, buf)
|
||||
return int(binary.BigEndian.Uint16(buf)), err
|
||||
}
|
||||
|
||||
func (s *secureSession) writeLength(length int) error {
|
||||
buf := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(buf, uint16(length))
|
||||
_, err := writeAll(s.insecure, buf)
|
||||
_, err := s.insecure.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -259,7 +260,7 @@ func (s *secureSession) Read(buf []byte) (int, error) {
|
|||
|
||||
// read and decrypt ciphertext
|
||||
ciphertext := make([]byte, l)
|
||||
_, err = fillBuffer(ciphertext, s.insecure)
|
||||
_, err = io.ReadFull(s.insecure, ciphertext)
|
||||
if err != nil {
|
||||
log.Error("read ciphertext err", err)
|
||||
return 0, err
|
||||
|
@ -337,7 +338,7 @@ func (s *secureSession) Write(in []byte) (int, error) {
|
|||
return 0, err
|
||||
}
|
||||
|
||||
_, err = writeAll(s.insecure, ciphertext)
|
||||
_, err = s.insecure.Write(ciphertext)
|
||||
return len(in), err
|
||||
}
|
||||
|
||||
|
|
|
@ -1,31 +0,0 @@
|
|||
package noise
|
||||
|
||||
import "io"
|
||||
|
||||
// fillBuffer reads from the given reader until the given buffer
|
||||
// is full
|
||||
func fillBuffer(buf []byte, reader io.Reader) (int, error) {
|
||||
total := 0
|
||||
for total < len(buf) {
|
||||
c, err := reader.Read(buf[total:])
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
total += c
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// writeAll is a helper that writes to the given io.Writer until all input data
|
||||
// has been written
|
||||
func writeAll(writer io.Writer, data []byte) (int, error) {
|
||||
total := 0
|
||||
for total < len(data) {
|
||||
c, err := writer.Write(data[total:])
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
total += c
|
||||
}
|
||||
return total, nil
|
||||
}
|
|
@ -3,6 +3,7 @@ package noise
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
proto "github.com/gogo/protobuf/proto"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
|
@ -27,7 +28,7 @@ func (s *secureSession) xx_sendHandshakeMessage(payload []byte, initial_stage bo
|
|||
return fmt.Errorf("xx_sendHandshakeMessage write length err=%s", err)
|
||||
}
|
||||
|
||||
_, err = writeAll(s.insecure, encMsgBuf)
|
||||
_, err = s.insecure.Write(encMsgBuf)
|
||||
if err != nil {
|
||||
log.Error("xx_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
|
||||
return fmt.Errorf("xx_sendHandshakeMessage write to conn err=%s", err)
|
||||
|
@ -46,7 +47,7 @@ func (s *secureSession) xx_recvHandshakeMessage(initial_stage bool) (buf []byte,
|
|||
|
||||
buf = make([]byte, l)
|
||||
|
||||
_, err = fillBuffer(buf, s.insecure)
|
||||
_, err = io.ReadFull(s.insecure, buf)
|
||||
if err != nil {
|
||||
return buf, nil, false, fmt.Errorf("xx_recvHandshakeMessage read from conn err=%s", err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue