// Mirror of https://github.com/status-im/op-geth.git (Go, 359 lines, 9.4 KiB).
package p2p
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/davecgh/go-spew/spew"
|
|
"github.com/ethereum/go-ethereum/crypto"
|
|
"github.com/ethereum/go-ethereum/crypto/ecies"
|
|
"github.com/ethereum/go-ethereum/crypto/sha3"
|
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
|
"github.com/ethereum/go-ethereum/rlp"
|
|
)
|
|
|
|
func TestSharedSecret(t *testing.T) {
|
|
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
|
pub0 := &prv0.PublicKey
|
|
prv1, _ := crypto.GenerateKey()
|
|
pub1 := &prv1.PublicKey
|
|
|
|
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
|
|
if err != nil {
|
|
return
|
|
}
|
|
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
|
|
if err != nil {
|
|
return
|
|
}
|
|
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
|
|
if !bytes.Equal(ss0, ss1) {
|
|
t.Errorf("dont match :(")
|
|
}
|
|
}
|
|
|
|
func TestEncHandshake(t *testing.T) {
|
|
for i := 0; i < 10; i++ {
|
|
start := time.Now()
|
|
if err := testEncHandshake(nil); err != nil {
|
|
t.Fatalf("i=%d %v", i, err)
|
|
}
|
|
t.Logf("(without token) %d %v\n", i+1, time.Since(start))
|
|
}
|
|
for i := 0; i < 10; i++ {
|
|
tok := make([]byte, shaLen)
|
|
rand.Reader.Read(tok)
|
|
start := time.Now()
|
|
if err := testEncHandshake(tok); err != nil {
|
|
t.Fatalf("i=%d %v", i, err)
|
|
}
|
|
t.Logf("(with token) %d %v\n", i+1, time.Since(start))
|
|
}
|
|
}
|
|
|
|
func testEncHandshake(token []byte) error {
|
|
type result struct {
|
|
side string
|
|
id discover.NodeID
|
|
err error
|
|
}
|
|
var (
|
|
prv0, _ = crypto.GenerateKey()
|
|
prv1, _ = crypto.GenerateKey()
|
|
fd0, fd1 = net.Pipe()
|
|
c0, c1 = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx)
|
|
output = make(chan result)
|
|
)
|
|
|
|
go func() {
|
|
r := result{side: "initiator"}
|
|
defer func() { output <- r }()
|
|
|
|
dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)}
|
|
r.id, r.err = c0.doEncHandshake(prv0, dest)
|
|
if r.err != nil {
|
|
return
|
|
}
|
|
id1 := discover.PubkeyID(&prv1.PublicKey)
|
|
if r.id != id1 {
|
|
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1)
|
|
}
|
|
}()
|
|
go func() {
|
|
r := result{side: "receiver"}
|
|
defer func() { output <- r }()
|
|
|
|
r.id, r.err = c1.doEncHandshake(prv1, nil)
|
|
if r.err != nil {
|
|
return
|
|
}
|
|
id0 := discover.PubkeyID(&prv0.PublicKey)
|
|
if r.id != id0 {
|
|
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0)
|
|
}
|
|
}()
|
|
|
|
// wait for results from both sides
|
|
r1, r2 := <-output, <-output
|
|
if r1.err != nil {
|
|
return fmt.Errorf("%s side error: %v", r1.side, r1.err)
|
|
}
|
|
if r2.err != nil {
|
|
return fmt.Errorf("%s side error: %v", r2.side, r2.err)
|
|
}
|
|
|
|
// compare derived secrets
|
|
if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) {
|
|
return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC)
|
|
}
|
|
if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) {
|
|
return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC)
|
|
}
|
|
if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) {
|
|
return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc)
|
|
}
|
|
if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) {
|
|
return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestProtocolHandshake(t *testing.T) {
|
|
var (
|
|
prv0, _ = crypto.GenerateKey()
|
|
node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33}
|
|
hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}}
|
|
|
|
prv1, _ = crypto.GenerateKey()
|
|
node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44}
|
|
hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}}
|
|
|
|
fd0, fd1 = net.Pipe()
|
|
wg sync.WaitGroup
|
|
)
|
|
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
rlpx := newRLPX(fd0)
|
|
remid, err := rlpx.doEncHandshake(prv0, node1)
|
|
if err != nil {
|
|
t.Errorf("dial side enc handshake failed: %v", err)
|
|
return
|
|
}
|
|
if remid != node1.ID {
|
|
t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID)
|
|
return
|
|
}
|
|
|
|
phs, err := rlpx.doProtoHandshake(hs0)
|
|
if err != nil {
|
|
t.Errorf("dial side proto handshake error: %v", err)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(phs, hs1) {
|
|
t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1))
|
|
return
|
|
}
|
|
rlpx.close(DiscQuitting)
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
rlpx := newRLPX(fd1)
|
|
remid, err := rlpx.doEncHandshake(prv1, nil)
|
|
if err != nil {
|
|
t.Errorf("listen side enc handshake failed: %v", err)
|
|
return
|
|
}
|
|
if remid != node0.ID {
|
|
t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID)
|
|
return
|
|
}
|
|
|
|
phs, err := rlpx.doProtoHandshake(hs1)
|
|
if err != nil {
|
|
t.Errorf("listen side proto handshake error: %v", err)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(phs, hs0) {
|
|
t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0))
|
|
return
|
|
}
|
|
|
|
if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil {
|
|
t.Errorf("error receiving disconnect: %v", err)
|
|
}
|
|
}()
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestProtocolHandshakeErrors(t *testing.T) {
|
|
our := &protoHandshake{Version: 3, Caps: []Cap{{"foo", 2}, {"bar", 3}}, Name: "quux"}
|
|
id := randomID()
|
|
tests := []struct {
|
|
code uint64
|
|
msg interface{}
|
|
err error
|
|
}{
|
|
{
|
|
code: discMsg,
|
|
msg: []DiscReason{DiscQuitting},
|
|
err: DiscQuitting,
|
|
},
|
|
{
|
|
code: 0x989898,
|
|
msg: []byte{1},
|
|
err: errors.New("expected handshake, got 989898"),
|
|
},
|
|
{
|
|
code: handshakeMsg,
|
|
msg: make([]byte, baseProtocolMaxMsgSize+2),
|
|
err: errors.New("message too big"),
|
|
},
|
|
{
|
|
code: handshakeMsg,
|
|
msg: []byte{1, 2, 3},
|
|
err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"),
|
|
},
|
|
{
|
|
code: handshakeMsg,
|
|
msg: &protoHandshake{Version: 9944, ID: id},
|
|
err: DiscIncompatibleVersion,
|
|
},
|
|
{
|
|
code: handshakeMsg,
|
|
msg: &protoHandshake{Version: 3},
|
|
err: DiscInvalidIdentity,
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
p1, p2 := MsgPipe()
|
|
go Send(p1, test.code, test.msg)
|
|
_, err := readProtocolHandshake(p2, our)
|
|
if !reflect.DeepEqual(err, test.err) {
|
|
t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRLPXFrameFake(t *testing.T) {
|
|
buf := new(bytes.Buffer)
|
|
hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
|
|
rw := newRLPXFrameRW(buf, secrets{
|
|
AES: crypto.Sha3(),
|
|
MAC: crypto.Sha3(),
|
|
IngressMAC: hash,
|
|
EgressMAC: hash,
|
|
})
|
|
|
|
golden := unhex(`
|
|
00828ddae471818bb0bfa6b551d1cb42
|
|
01010101010101010101010101010101
|
|
ba628a4ba590cb43f7848f41c4382885
|
|
01010101010101010101010101010101
|
|
`)
|
|
|
|
// Check WriteMsg. This puts a message into the buffer.
|
|
if err := Send(rw, 8, []uint{1, 2, 3, 4}); err != nil {
|
|
t.Fatalf("WriteMsg error: %v", err)
|
|
}
|
|
written := buf.Bytes()
|
|
if !bytes.Equal(written, golden) {
|
|
t.Fatalf("output mismatch:\n got: %x\n want: %x", written, golden)
|
|
}
|
|
|
|
// Check ReadMsg. It reads the message encoded by WriteMsg, which
|
|
// is equivalent to the golden message above.
|
|
msg, err := rw.ReadMsg()
|
|
if err != nil {
|
|
t.Fatalf("ReadMsg error: %v", err)
|
|
}
|
|
if msg.Size != 5 {
|
|
t.Errorf("msg size mismatch: got %d, want %d", msg.Size, 5)
|
|
}
|
|
if msg.Code != 8 {
|
|
t.Errorf("msg code mismatch: got %d, want %d", msg.Code, 8)
|
|
}
|
|
payload, _ := ioutil.ReadAll(msg.Payload)
|
|
wantPayload := unhex("C401020304")
|
|
if !bytes.Equal(payload, wantPayload) {
|
|
t.Errorf("msg payload mismatch:\ngot %x\nwant %x", payload, wantPayload)
|
|
}
|
|
}
|
|
|
|
// fakeHash is a stand-in hash with a fixed digest: it ignores everything
// written to it and always reports its own bytes as the sum.
type fakeHash []byte

// Size returns the number of bytes in the fixed digest.
func (h fakeHash) Size() int {
	return len(h)
}

// Sum appends the fixed digest to b and returns the resulting slice.
func (h fakeHash) Sum(b []byte) []byte {
	return append(b, h...)
}

// Write discards p and reports that all of it was written.
func (fakeHash) Write(p []byte) (int, error) {
	return len(p), nil
}

// Reset does nothing; there is no state to clear.
func (fakeHash) Reset() {
}

// BlockSize always returns 0.
func (fakeHash) BlockSize() int {
	return 0
}
|
|
|
|
func TestRLPXFrameRW(t *testing.T) {
|
|
var (
|
|
aesSecret = make([]byte, 16)
|
|
macSecret = make([]byte, 16)
|
|
egressMACinit = make([]byte, 32)
|
|
ingressMACinit = make([]byte, 32)
|
|
)
|
|
for _, s := range [][]byte{aesSecret, macSecret, egressMACinit, ingressMACinit} {
|
|
rand.Read(s)
|
|
}
|
|
conn := new(bytes.Buffer)
|
|
|
|
s1 := secrets{
|
|
AES: aesSecret,
|
|
MAC: macSecret,
|
|
EgressMAC: sha3.NewKeccak256(),
|
|
IngressMAC: sha3.NewKeccak256(),
|
|
}
|
|
s1.EgressMAC.Write(egressMACinit)
|
|
s1.IngressMAC.Write(ingressMACinit)
|
|
rw1 := newRLPXFrameRW(conn, s1)
|
|
|
|
s2 := secrets{
|
|
AES: aesSecret,
|
|
MAC: macSecret,
|
|
EgressMAC: sha3.NewKeccak256(),
|
|
IngressMAC: sha3.NewKeccak256(),
|
|
}
|
|
s2.EgressMAC.Write(ingressMACinit)
|
|
s2.IngressMAC.Write(egressMACinit)
|
|
rw2 := newRLPXFrameRW(conn, s2)
|
|
|
|
// send some messages
|
|
for i := 0; i < 10; i++ {
|
|
// write message into conn buffer
|
|
wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)}
|
|
err := Send(rw1, uint64(i), wmsg)
|
|
if err != nil {
|
|
t.Fatalf("WriteMsg error (i=%d): %v", i, err)
|
|
}
|
|
|
|
// read message that rw1 just wrote
|
|
msg, err := rw2.ReadMsg()
|
|
if err != nil {
|
|
t.Fatalf("ReadMsg error (i=%d): %v", i, err)
|
|
}
|
|
if msg.Code != uint64(i) {
|
|
t.Fatalf("msg code mismatch: got %d, want %d", msg.Code, i)
|
|
}
|
|
payload, _ := ioutil.ReadAll(msg.Payload)
|
|
wantPayload, _ := rlp.EncodeToBytes(wmsg)
|
|
if !bytes.Equal(payload, wantPayload) {
|
|
t.Fatalf("msg payload mismatch:\ngot %x\nwant %x", payload, wantPayload)
|
|
}
|
|
}
|
|
}
|