// File: go-libp2p/p2p/security/noise/transport_test.go (218 lines, 5.0 KiB, Go)

package noise
import (
	"bytes"
	"context"
	"io"
	"math/rand"
	"net"
	"testing"

	crypto "github.com/libp2p/go-libp2p-core/crypto"
	"github.com/libp2p/go-libp2p-core/peer"
	"github.com/libp2p/go-libp2p-core/sec"
)
// newTestTransport returns a Transport whose identity is a freshly generated
// key pair. typ and bits are passed straight to crypto.GenerateKeyPair, and
// localID is the peer ID derived from the generated public key. The test is
// aborted if key generation or peer-ID derivation fails.
func newTestTransport(t *testing.T, typ, bits int) *Transport {
	t.Helper()
	priv, pub, err := crypto.GenerateKeyPair(typ, bits)
	if err != nil {
		t.Fatalf("generating key pair: %v", err)
	}
	id, err := peer.IDFromPublicKey(pub)
	if err != nil {
		t.Fatalf("deriving peer ID from public key: %v", err)
	}
	return &Transport{
		localID:    id,
		privateKey: priv,
	}
}
// newConnPair creates a connected pair of loopback TCP sockets. The first
// return value is the dialing (client) end and the second is the accepting
// (server) end. The test is aborted if listening, dialing or accepting fails.
func newConnPair(t *testing.T) (net.Conn, net.Conn) {
	t.Helper()
	lstnr, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Failed to listen: %v", err)
	}
	// Only a single connection is needed; the listener is closed on return.
	defer lstnr.Close()

	addr := lstnr.Addr()
	var (
		client    net.Conn
		clientErr error
	)
	done := make(chan struct{})
	go func() {
		defer close(done)
		client, clientErr = net.Dial(addr.Network(), addr.String())
		if clientErr != nil {
			// Nobody will connect, so unblock the Accept below instead of
			// letting it wait forever.
			lstnr.Close()
		}
	}()
	server, acceptErr := lstnr.Accept()
	<-done // client/clientErr are only safe to read after the goroutine ends

	if clientErr != nil {
		if server != nil {
			server.Close()
		}
		t.Fatalf("Failed to connect: %v", clientErr)
	}
	if acceptErr != nil {
		client.Close()
		t.Fatalf("Failed to accept: %v", acceptErr)
	}
	return client, server
}
// connect runs a handshake between initTransport (outbound side) and
// respTransport (inbound side) over a fresh loopback TCP pair. It returns
// the initiator's and responder's sessions, in that order. The initiator
// dials respTransport.localID. If either handshake fails, both raw sockets
// are closed and the test is aborted, reporting the initiator's error first.
func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSession, *secureSession) {
	t.Helper()
	initRaw, respRaw := newConnPair(t)

	var initConn sec.SecureConn
	var initErr error
	done := make(chan struct{})
	// Both sides must progress concurrently for the handshake to complete.
	go func() {
		defer close(done)
		initConn, initErr = initTransport.SecureOutbound(context.TODO(), initRaw, respTransport.localID)
	}()
	respConn, respErr := respTransport.SecureInbound(context.TODO(), respRaw)
	<-done

	if initErr != nil || respErr != nil {
		// Release the sockets before aborting so a failed handshake does not
		// leak file descriptors.
		initRaw.Close()
		respRaw.Close()
		if initErr != nil {
			t.Fatal(initErr)
		}
		t.Fatal(respErr)
	}
	return initConn.(*secureSession), respConn.(*secureSession)
}
// TestIDs checks that, after a successful handshake, both the initiator and
// the responder report the expected local and remote peer IDs.
func TestIDs(t *testing.T) {
	initTransport := newTestTransport(t, crypto.Ed25519, 2048)
	respTransport := newTestTransport(t, crypto.Ed25519, 2048)
	initConn, respConn := connect(t, initTransport, respTransport)
	defer initConn.Close()
	defer respConn.Close()

	if initConn.LocalPeer() != initTransport.localID {
		t.Fatal("Initiator Local Peer ID mismatch.")
	}
	if respConn.RemotePeer() != initTransport.localID {
		t.Fatal("Responder Remote Peer ID mismatch.")
	}
	// The responder must report its own transport's identity. Comparing the
	// initiator's local ID with the responder's remote ID would only repeat
	// the two checks above.
	if respConn.LocalPeer() != respTransport.localID {
		t.Fatal("Responder Local Peer ID mismatch.")
	}
	// TODO: check after stage 0 of handshake if updated
	if initConn.RemotePeer() != respTransport.localID {
		t.Errorf("Initiator Remote Peer ID mismatch. expected %x got %x", respTransport.localID, initConn.RemotePeer())
	}
}
// TestKeys checks that each side of an established session exposes its own
// transport's private key as its local key, and that the other side sees
// the matching public key as its remote key.
func TestKeys(t *testing.T) {
	initTransport := newTestTransport(t, crypto.Ed25519, 2048)
	respTransport := newTestTransport(t, crypto.Ed25519, 2048)
	initConn, respConn := connect(t, initTransport, respTransport)
	defer initConn.Close()
	defer respConn.Close()

	// Responder identity, as seen locally and by the initiator.
	sk := respConn.LocalPrivateKey()
	pk := sk.GetPublic()
	if !sk.Equals(respTransport.privateKey) {
		t.Error("Private key Mismatch.")
	}
	if !pk.Equals(initConn.RemotePublicKey()) {
		t.Errorf("Public key mismatch. expected %x got %x", pk, initConn.RemotePublicKey())
	}

	// Initiator identity, as seen locally and by the responder.
	initSK := initConn.LocalPrivateKey()
	initPK := initSK.GetPublic()
	if !initSK.Equals(initTransport.privateKey) {
		t.Error("Initiator private key mismatch.")
	}
	if !initPK.Equals(respConn.RemotePublicKey()) {
		t.Errorf("Initiator public key mismatch. expected %x got %x", initPK, respConn.RemotePublicKey())
	}
}
// TestPeerIDMismatchFailsHandshake checks that an outbound handshake fails
// when the dialed peer ID does not match the identity the responder proves.
func TestPeerIDMismatchFailsHandshake(t *testing.T) {
	initTransport := newTestTransport(t, crypto.Ed25519, 2048)
	respTransport := newTestTransport(t, crypto.Ed25519, 2048)
	initRaw, respRaw := newConnPair(t)
	// Close the sockets regardless of how the handshake ends so the test does
	// not leak file descriptors.
	defer initRaw.Close()
	defer respRaw.Close()

	var initErr error
	done := make(chan struct{})
	go func() {
		defer close(done)
		_, initErr = initTransport.SecureOutbound(context.TODO(), initRaw, "a-random-peer-id")
	}()
	// Only the initiator's outcome is under test. The responder is run solely
	// so the initiator has a peer to talk to, so its result is deliberately
	// ignored.
	_, _ = respTransport.SecureInbound(context.TODO(), respRaw)
	<-done

	if initErr == nil {
		t.Fatal("expected initiator to fail with peer ID mismatch error")
	}
}
// makeLargePlaintext returns a slice of size bytes filled by math/rand's
// package-level Read. The content is pseudo-random test data, not suitable
// for anything security-sensitive.
func makeLargePlaintext(size int) []byte {
	payload := make([]byte, size)
	// math/rand.Read is documented to always return len(p), nil.
	_, _ = rand.Read(payload)
	return payload
}
// TestLargePayloads checks that a payload spanning several Noise messages
// arrives intact at the responder.
func TestLargePayloads(t *testing.T) {
	initTransport := newTestTransport(t, crypto.Ed25519, 2048)
	respTransport := newTestTransport(t, crypto.Ed25519, 2048)
	initConn, respConn := connect(t, initTransport, respTransport)
	defer initConn.Close()
	defer respConn.Close()

	// enough to require a couple Noise messages, with a size that
	// isn't a neat multiple of Noise message size, just in case
	size := 100000
	before := makeLargePlaintext(size)

	// Write concurrently with the read. With nobody reading yet, a payload
	// this large may not fit in the socket buffers and a synchronous Write
	// could block. The buffered channel lets the goroutine finish even if the
	// test aborts before receiving from it.
	writeErr := make(chan error, 1)
	go func() {
		_, err := initConn.Write(before)
		writeErr <- err
	}()

	// io.Reader permits short reads, so a single Read is not guaranteed to
	// fill the buffer. ReadFull keeps reading until it does or an error occurs.
	after := make([]byte, len(before))
	afterLen, err := io.ReadFull(respConn, after)
	if err != nil {
		t.Fatal(err)
	}
	if err := <-writeErr; err != nil {
		t.Fatal(err)
	}

	if len(before) != afterLen {
		t.Errorf("expected to read same amount of data as written. written=%d read=%d", len(before), afterLen)
	}
	if !bytes.Equal(before, after) {
		t.Error("Message mismatch.")
	}
}
// TestHandshakeXX checks that, after the (XX) handshake completes, a short
// message written by the initiator is received unchanged by the responder.
func TestHandshakeXX(t *testing.T) {
	initTransport := newTestTransport(t, crypto.Ed25519, 2048)
	respTransport := newTestTransport(t, crypto.Ed25519, 2048)
	initConn, respConn := connect(t, initTransport, respTransport)
	defer initConn.Close()
	defer respConn.Close()

	before := []byte("hello world")
	if _, err := initConn.Write(before); err != nil {
		t.Fatal(err)
	}

	// A bare Read may legally return fewer bytes than requested, and its
	// count was previously ignored. ReadFull requires the whole message.
	after := make([]byte, len(before))
	if _, err := io.ReadFull(respConn, after); err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(before, after) {
		t.Errorf("Message mismatch. %v != %v", before, after)
	}
}