269 lines
5.4 KiB
Go
269 lines
5.4 KiB
Go
package secio
|
|
|
|
import (
|
|
"crypto/cipher"
|
|
"crypto/hmac"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"sync"
|
|
|
|
pool "github.com/libp2p/go-buffer-pool"
|
|
msgio "github.com/libp2p/go-msgio"
|
|
)
|
|
|
|
// ErrMACInvalid signals that a MAC verification failed
|
|
var ErrMACInvalid = errors.New("MAC verification failed")
|
|
|
|
type etmWriter struct {
|
|
str cipher.Stream // the stream cipher to encrypt with
|
|
mac HMAC // the mac to authenticate data with
|
|
w io.Writer
|
|
|
|
sync.Mutex
|
|
}
|
|
|
|
// NewETMWriter Encrypt-Then-MAC
|
|
func NewETMWriter(w io.Writer, s cipher.Stream, mac HMAC) msgio.WriteCloser {
|
|
return &etmWriter{w: w, str: s, mac: mac}
|
|
}
|
|
|
|
// Write writes passed in buffer as a single message.
|
|
func (w *etmWriter) Write(b []byte) (int, error) {
|
|
if err := w.WriteMsg(b); err != nil {
|
|
return 0, err
|
|
}
|
|
return len(b), nil
|
|
}
|
|
|
|
// WriteMsg writes the msg in the passed in buffer.
|
|
func (w *etmWriter) WriteMsg(b []byte) error {
|
|
w.Lock()
|
|
defer w.Unlock()
|
|
|
|
// encrypt.
|
|
buf := pool.Get(4 + len(b) + w.mac.Size())
|
|
defer pool.Put(buf)
|
|
data := buf[4 : 4+len(b)]
|
|
w.str.XORKeyStream(data, b)
|
|
|
|
// log.Debugf("ENC plaintext (%d): %s %v", len(b), b, b)
|
|
// log.Debugf("ENC ciphertext (%d): %s %v", len(data), data, data)
|
|
|
|
// then, mac.
|
|
if _, err := w.mac.Write(data); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Sum appends.
|
|
data = w.mac.Sum(data)
|
|
w.mac.Reset()
|
|
binary.BigEndian.PutUint32(buf[:4], uint32(len(data)))
|
|
|
|
_, err := w.w.Write(buf)
|
|
return err
|
|
}
|
|
|
|
func (w *etmWriter) Close() error {
|
|
if c, ok := w.w.(io.Closer); ok {
|
|
return c.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type etmReader struct {
|
|
msgio.Reader
|
|
io.Closer
|
|
|
|
// internal buffer returned from the msgio
|
|
buf []byte
|
|
|
|
// low and high watermark for the buffered data
|
|
lowat int
|
|
hiwat int
|
|
|
|
// params
|
|
msg msgio.ReadCloser // msgio for knowing where boundaries lie
|
|
str cipher.Stream // the stream cipher to encrypt with
|
|
mac HMAC // the mac to authenticate data with
|
|
|
|
// internal buffer used for checking MACs, this saves us quite a few
|
|
// allocations and should be quite small.
|
|
macBuf []byte
|
|
|
|
sync.Mutex
|
|
}
|
|
|
|
// NewETMReader Encrypt-Then-MAC
|
|
func NewETMReader(r io.Reader, s cipher.Stream, mac HMAC) msgio.ReadCloser {
|
|
return &etmReader{msg: msgio.NewReader(r), str: s, mac: mac}
|
|
}
|
|
|
|
func (r *etmReader) NextMsgLen() (int, error) {
|
|
return r.msg.NextMsgLen()
|
|
}
|
|
|
|
func (r *etmReader) drain(buf []byte) int {
|
|
// Return zero if there is no data remaining in the internal buffer.
|
|
if r.lowat == r.hiwat {
|
|
return 0
|
|
}
|
|
|
|
// Copy data to the output buffer.
|
|
n := copy(buf, r.buf[r.lowat:r.hiwat])
|
|
|
|
// Update the low watermark.
|
|
r.lowat += n
|
|
|
|
// Release the buffer and reset the watermarks if it has been fully read.
|
|
if r.lowat == r.hiwat {
|
|
r.msg.ReleaseMsg(r.buf)
|
|
r.buf = nil
|
|
r.lowat = 0
|
|
r.hiwat = 0
|
|
}
|
|
|
|
return n
|
|
}
|
|
|
|
func (r *etmReader) fill() error {
|
|
// Read a message from the underlying msgio.
|
|
msg, err := r.msg.ReadMsg()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Check the MAC.
|
|
n, err := r.macCheckThenDecrypt(msg)
|
|
if err != nil {
|
|
r.msg.ReleaseMsg(msg)
|
|
return err
|
|
}
|
|
|
|
// Retain the buffer so it can be drained from and later released.
|
|
r.buf = msg
|
|
r.lowat = 0
|
|
r.hiwat = n
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *etmReader) Read(buf []byte) (int, error) {
|
|
r.Lock()
|
|
defer r.Unlock()
|
|
|
|
// Return buffered data without reading more, if possible.
|
|
copied := r.drain(buf)
|
|
if copied > 0 {
|
|
return copied, nil
|
|
}
|
|
|
|
// Check the length of the next message.
|
|
fullLen, err := r.msg.NextMsgLen()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// If the destination buffer is too short, fill an internal buffer and then
|
|
// drain as much of that into the output buffer as will fit.
|
|
if len(buf) < fullLen {
|
|
err := r.fill()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
copied := r.drain(buf)
|
|
return copied, nil
|
|
}
|
|
|
|
// Otherwise, read directly into the destination buffer.
|
|
n, err := io.ReadFull(r.msg, buf[:fullLen])
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
m, err := r.macCheckThenDecrypt(buf[:n])
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return m, nil
|
|
}
|
|
|
|
func (r *etmReader) ReadMsg() ([]byte, error) {
|
|
r.Lock()
|
|
defer r.Unlock()
|
|
|
|
msg, err := r.msg.ReadMsg()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
n, err := r.macCheckThenDecrypt(msg)
|
|
if err != nil {
|
|
r.msg.ReleaseMsg(msg)
|
|
return nil, err
|
|
}
|
|
return msg[:n], nil
|
|
}
|
|
|
|
func (r *etmReader) macCheckThenDecrypt(m []byte) (int, error) {
|
|
l := len(m)
|
|
if l < r.mac.size {
|
|
return 0, fmt.Errorf("buffer (%d) shorter than MAC size (%d)", l, r.mac.size)
|
|
}
|
|
|
|
mark := l - r.mac.size
|
|
data := m[:mark]
|
|
macd := m[mark:]
|
|
|
|
r.mac.Write(data)
|
|
r.macBuf = r.mac.Sum(r.macBuf[:0])
|
|
r.mac.Reset()
|
|
|
|
// check mac. if failed, return error.
|
|
if !hmac.Equal(macd, r.macBuf) {
|
|
log.Debug("MAC Invalid:", r.macBuf, "!=", macd)
|
|
return 0, ErrMACInvalid
|
|
}
|
|
|
|
// ok seems good. decrypt. (can decrypt in place, yay!)
|
|
// log.Debugf("DEC ciphertext (%d): %s %v", len(data), data, data)
|
|
r.str.XORKeyStream(data, data)
|
|
// log.Debugf("DEC plaintext (%d): %s %v", len(data), data, data)
|
|
|
|
return mark, nil
|
|
}
|
|
|
|
func (r *etmReader) Close() error {
|
|
return r.msg.Close()
|
|
}
|
|
|
|
// ReleaseMsg signals a buffer can be reused.
|
|
func (r *etmReader) ReleaseMsg(b []byte) {
|
|
r.msg.ReleaseMsg(b)
|
|
}
|
|
|
|
// read and write a message at the same time.
|
|
func readWriteMsg(c msgio.ReadWriter, out []byte) ([]byte, error) {
|
|
wresult := make(chan error)
|
|
go func() {
|
|
wresult <- c.WriteMsg(out)
|
|
}()
|
|
|
|
msg, err1 := c.ReadMsg()
|
|
|
|
// Always wait for the read to finish.
|
|
err2 := <-wresult
|
|
|
|
if err1 != nil {
|
|
return nil, err1
|
|
}
|
|
if err2 != nil {
|
|
c.ReleaseMsg(msg)
|
|
return nil, err2
|
|
}
|
|
return msg, nil
|
|
}
|