move go-libp2p-pnet here
This commit is contained in:
commit
a225be04d9
|
@ -0,0 +1,18 @@
|
||||||
|
package pnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
ipnet "github.com/libp2p/go-libp2p-core/pnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewProtectedConn creates a new protected connection
|
||||||
|
func NewProtectedConn(psk ipnet.PSK, conn net.Conn) (net.Conn, error) {
|
||||||
|
if len(psk) != 32 {
|
||||||
|
return nil, errors.New("expected 32 byte PSK")
|
||||||
|
}
|
||||||
|
var p [32]byte
|
||||||
|
copy(p[:], psk)
|
||||||
|
return newPSKConn(&p, conn)
|
||||||
|
}
|
|
@ -0,0 +1,83 @@
|
||||||
|
package pnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/libp2p/go-libp2p-core/pnet"
|
||||||
|
|
||||||
|
"github.com/davidlazar/go-crypto/salsa20"
|
||||||
|
pool "github.com/libp2p/go-buffer-pool"
|
||||||
|
)
|
||||||
|
|
||||||
|
// we are using buffer pool as user needs their slice back
|
||||||
|
// so we can't do XOR cripter in place
|
||||||
|
var (
|
||||||
|
errShortNonce = pnet.NewError("could not read full nonce")
|
||||||
|
errInsecureNil = pnet.NewError("insecure is nil")
|
||||||
|
errPSKNil = pnet.NewError("pre-shread key is nil")
|
||||||
|
)
|
||||||
|
|
||||||
|
type pskConn struct {
|
||||||
|
net.Conn
|
||||||
|
psk *[32]byte
|
||||||
|
|
||||||
|
writeS20 cipher.Stream
|
||||||
|
readS20 cipher.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *pskConn) Read(out []byte) (int, error) {
|
||||||
|
if c.readS20 == nil {
|
||||||
|
nonce := make([]byte, 24)
|
||||||
|
_, err := io.ReadFull(c.Conn, nonce)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errShortNonce
|
||||||
|
}
|
||||||
|
c.readS20 = salsa20.New(c.psk, nonce)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := c.Conn.Read(out) // read to in
|
||||||
|
if n > 0 {
|
||||||
|
c.readS20.XORKeyStream(out[:n], out[:n]) // decrypt to out buffer
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *pskConn) Write(in []byte) (int, error) {
|
||||||
|
if c.writeS20 == nil {
|
||||||
|
nonce := make([]byte, 24)
|
||||||
|
_, err := rand.Read(nonce)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
_, err = c.Conn.Write(nonce)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.writeS20 = salsa20.New(c.psk, nonce)
|
||||||
|
}
|
||||||
|
out := pool.Get(len(in))
|
||||||
|
defer pool.Put(out)
|
||||||
|
|
||||||
|
c.writeS20.XORKeyStream(out, in) // encrypt
|
||||||
|
|
||||||
|
return c.Conn.Write(out) // send
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ net.Conn = (*pskConn)(nil)
|
||||||
|
|
||||||
|
func newPSKConn(psk *[32]byte, insecure net.Conn) (net.Conn, error) {
|
||||||
|
if insecure == nil {
|
||||||
|
return nil, errInsecureNil
|
||||||
|
}
|
||||||
|
if psk == nil {
|
||||||
|
return nil, errPSKNil
|
||||||
|
}
|
||||||
|
return &pskConn{
|
||||||
|
Conn: insecure,
|
||||||
|
psk: psk,
|
||||||
|
}, nil
|
||||||
|
}
|
|
@ -0,0 +1,92 @@
|
||||||
|
package pnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupPSKConns(ctx context.Context, t *testing.T) (net.Conn, net.Conn) {
|
||||||
|
testPSK := make([]byte, 32) // null bytes are as good test key as any other key
|
||||||
|
conn1, conn2 := net.Pipe()
|
||||||
|
|
||||||
|
psk1, err := NewProtectedConn(testPSK, conn1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
psk2, err := NewProtectedConn(testPSK, conn2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return psk1, psk2
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPSKSimpelMessges(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
psk1, psk2 := setupPSKConns(ctx, t)
|
||||||
|
msg1 := []byte("hello world")
|
||||||
|
out1 := make([]byte, len(msg1))
|
||||||
|
|
||||||
|
wch := make(chan error)
|
||||||
|
go func() {
|
||||||
|
_, err := psk1.Write(msg1)
|
||||||
|
wch <- err
|
||||||
|
}()
|
||||||
|
n, err := psk2.Read(out1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = <-wch
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != len(out1) {
|
||||||
|
t.Fatalf("expected to read %d bytes, read: %d", len(out1), n)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(msg1, out1) {
|
||||||
|
t.Fatalf("input and output are not the same")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPSKFragmentation(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
psk1, psk2 := setupPSKConns(ctx, t)
|
||||||
|
|
||||||
|
in := make([]byte, 1000)
|
||||||
|
_, err := rand.Read(in)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]byte, 100)
|
||||||
|
|
||||||
|
wch := make(chan error)
|
||||||
|
go func() {
|
||||||
|
_, err := psk1.Write(in)
|
||||||
|
wch <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
if _, err := psk2.Read(out); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(in[:100], out) {
|
||||||
|
t.Fatalf("input and output are not the same")
|
||||||
|
}
|
||||||
|
in = in[100:]
|
||||||
|
}
|
||||||
|
|
||||||
|
err = <-wch
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue