Merge pull request #39 from libp2p/fix/cleanup

use io.ReadFull instead of custom helper
This commit is contained in:
Yusef Napora 2020-01-29 09:58:58 -05:00 committed by GitHub
commit a1e6857b5c
4 changed files with 11 additions and 39 deletions

View File

@ -3,6 +3,7 @@ package noise
import ( import (
"context" "context"
"fmt" "fmt"
"io"
proto "github.com/gogo/protobuf/proto" proto "github.com/gogo/protobuf/proto"
ik "github.com/libp2p/go-libp2p-noise/ik" ik "github.com/libp2p/go-libp2p-noise/ik"
@ -28,7 +29,7 @@ func (s *secureSession) ik_sendHandshakeMessage(payload []byte, initial_stage bo
} }
// send message // send message
_, err = writeAll(s.insecure, encMsgBuf) _, err = s.insecure.Write(encMsgBuf)
if err != nil { if err != nil {
log.Error("ik_sendHandshakeMessage initiator=%v err=%s", s.initiator, err) log.Error("ik_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
return fmt.Errorf("ik_sendHandshakeMessage write to conn err=%s", err) return fmt.Errorf("ik_sendHandshakeMessage write to conn err=%s", err)
@ -45,7 +46,7 @@ func (s *secureSession) ik_recvHandshakeMessage(initial_stage bool) (buf []byte,
buf = make([]byte, l) buf = make([]byte, l)
_, err = fillBuffer(buf, s.insecure) _, err = io.ReadFull(s.insecure, buf)
if err != nil { if err != nil {
return buf, nil, false, fmt.Errorf("ik_recvHandshakeMessage read from conn err=%s", err) return buf, nil, false, fmt.Errorf("ik_recvHandshakeMessage read from conn err=%s", err)
} }

View File

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"sync" "sync"
"time" "time"
@ -121,14 +122,14 @@ func (s *secureSession) NoisePrivateKey() [32]byte {
func (s *secureSession) readLength() (int, error) { func (s *secureSession) readLength() (int, error) {
buf := make([]byte, 2) buf := make([]byte, 2)
_, err := fillBuffer(buf, s.insecure) _, err := io.ReadFull(s.insecure, buf)
return int(binary.BigEndian.Uint16(buf)), err return int(binary.BigEndian.Uint16(buf)), err
} }
func (s *secureSession) writeLength(length int) error { func (s *secureSession) writeLength(length int) error {
buf := make([]byte, 2) buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(length)) binary.BigEndian.PutUint16(buf, uint16(length))
_, err := writeAll(s.insecure, buf) _, err := s.insecure.Write(buf)
return err return err
} }
@ -259,7 +260,7 @@ func (s *secureSession) Read(buf []byte) (int, error) {
// read and decrypt ciphertext // read and decrypt ciphertext
ciphertext := make([]byte, l) ciphertext := make([]byte, l)
_, err = fillBuffer(ciphertext, s.insecure) _, err = io.ReadFull(s.insecure, ciphertext)
if err != nil { if err != nil {
log.Error("read ciphertext err", err) log.Error("read ciphertext err", err)
return 0, err return 0, err
@ -337,7 +338,7 @@ func (s *secureSession) Write(in []byte) (int, error) {
return 0, err return 0, err
} }
_, err = writeAll(s.insecure, ciphertext) _, err = s.insecure.Write(ciphertext)
return len(in), err return len(in), err
} }

View File

@ -1,31 +0,0 @@
package noise
import "io"
// fillBuffer reads from the given reader until the given buffer
// is full
func fillBuffer(buf []byte, reader io.Reader) (int, error) {
total := 0
for total < len(buf) {
c, err := reader.Read(buf[total:])
if err != nil {
return total, err
}
total += c
}
return total, nil
}
// writeAll is a helper that writes to the given io.Writer until all input data
// has been written
func writeAll(writer io.Writer, data []byte) (int, error) {
total := 0
for total < len(data) {
c, err := writer.Write(data[total:])
if err != nil {
return total, err
}
total += c
}
return total, nil
}

View File

@ -3,6 +3,7 @@ package noise
import ( import (
"context" "context"
"fmt" "fmt"
"io"
proto "github.com/gogo/protobuf/proto" proto "github.com/gogo/protobuf/proto"
"github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peer"
@ -27,7 +28,7 @@ func (s *secureSession) xx_sendHandshakeMessage(payload []byte, initial_stage bo
return fmt.Errorf("xx_sendHandshakeMessage write length err=%s", err) return fmt.Errorf("xx_sendHandshakeMessage write length err=%s", err)
} }
_, err = writeAll(s.insecure, encMsgBuf) _, err = s.insecure.Write(encMsgBuf)
if err != nil { if err != nil {
log.Error("xx_sendHandshakeMessage initiator=%v err=%s", s.initiator, err) log.Error("xx_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
return fmt.Errorf("xx_sendHandshakeMessage write to conn err=%s", err) return fmt.Errorf("xx_sendHandshakeMessage write to conn err=%s", err)
@ -46,7 +47,7 @@ func (s *secureSession) xx_recvHandshakeMessage(initial_stage bool) (buf []byte,
buf = make([]byte, l) buf = make([]byte, l)
_, err = fillBuffer(buf, s.insecure) _, err = io.ReadFull(s.insecure, buf)
if err != nil { if err != nil {
return buf, nil, false, fmt.Errorf("xx_recvHandshakeMessage read from conn err=%s", err) return buf, nil, false, fmt.Errorf("xx_recvHandshakeMessage read from conn err=%s", err)
} }