From 8e8d75dda1094d3f4cfb21550dc1f67a327c58f9 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Fri, 13 Mar 2015 14:30:48 +1100 Subject: [PATCH] Support initial payload, and improve tests --- mse/mse.go | 28 ++++++++++++++++------------ mse/mse_test.go | 47 ++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 54 insertions(+), 21 deletions(-) diff --git a/mse/mse.go b/mse/mse.go index e9ea9a67..f675b40b 100644 --- a/mse/mse.go +++ b/mse/mse.go @@ -9,11 +9,12 @@ import ( "crypto/sha1" "encoding/binary" "errors" + "expvar" "fmt" "io" "io/ioutil" - "log" "math/big" + "strconv" "sync" "github.com/bradfitz/iter" @@ -35,6 +36,8 @@ var ( req1 = []byte("req1") req2 = []byte("req2") req3 = []byte("req3") + + cryptoProvidesCount = expvar.NewMap("mseCryptoProvides") ) func init() { @@ -176,6 +179,7 @@ type handshake struct { initer bool skeys [][]byte skey []byte + ia []byte // Initial payload. Only used by the initiator. writeMu sync.Mutex writes [][]byte @@ -288,7 +292,6 @@ type cryptoNegotiation struct { func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) { err = binary.Read(r, binary.BigEndian, me.VC[:]) - // _, err = io.ReadFull(r, me.VC[:]) if err != nil { return } @@ -300,7 +303,6 @@ func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) { if err != nil { return } - log.Print(me.PadLen) _, err = io.CopyN(ioutil.Discard, r, int64(me.PadLen)) return } @@ -344,7 +346,6 @@ func suffixMatchLen(a, b []byte) int { } func readUntil(r io.Reader, b []byte) error { - log.Println("read until", b) b1 := make([]byte, len(b)) i := 0 for { @@ -379,7 +380,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) { if err != nil { return } - err = marshal(buf, uint16(0)) + err = marshal(buf, uint16(len(h.ia)), h.ia) if err != nil { return } @@ -390,7 +391,6 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) { bC := newEncrypt(false, h.s.Bytes(), h.skey) var eVC [8]byte bC.XORKeyStream(eVC[:], make([]byte, 8)) - log.Print(eVC) // Read until the all zero VC. err = readUntil(h.conn, eVC[:]) if err != nil { @@ -400,7 +400,6 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) { var cn cryptoNegotiation r := &cipherReader{bC, h.conn} err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r)) - log.Printf("initer got %v", cn) if err != nil { err = fmt.Errorf("error reading crypto negotiation: %s", err) return @@ -436,12 +435,17 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { if err != nil { return } - log.Printf("receiver got %v", cn) + cryptoProvidesCount.Add(strconv.FormatUint(uint64(cn.Method), 16), 1) if cn.Method&cryptoMethodRC4 == 0 { err = errors.New("no supported crypto methods were provided") return } - unmarshal(r, new(uint16)) + var lenIA uint16 + unmarshal(r, &lenIA) + if lenIA != 0 { + h.ia = make([]byte, lenIA) + unmarshal(r, h.ia) + } buf := &bytes.Buffer{} w := cipherWriter{newEncrypt(false, h.s.Bytes(), h.skey), buf} err = (&cryptoNegotiation{ @@ -455,7 +459,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { if err != nil { return } - ret = readWriter{r, &cipherWriter{w.c, h.conn}} + ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn}} return } @@ -483,15 +487,15 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) { if err != nil { return } - log.Print("ermahgerd, finished MSE handshake") return } -func InitiateHandshake(rw io.ReadWriteCloser, skey []byte) (ret io.ReadWriter, err error) { +func InitiateHandshake(rw io.ReadWriteCloser, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) { h := handshake{ conn: rw, initer: true, skey: skey, + ia: initialPayload, } h.writeCond.L = &h.writeMu h.writerCond.L = &h.writerMu diff --git a/mse/mse_test.go b/mse/mse_test.go index 644a4cdf..6b5f82c3 100644 --- a/mse/mse_test.go +++ b/mse/mse_test.go @@ -3,10 +3,11 @@ package mse import ( "bytes" "io" - "log" "net" "sync" + "github.com/bradfitz/iter" + "testing" ) @@ -43,21 +44,25 @@ func TestSuffixMatchLen(t *testing.T) { test("sup", "person", 1) } -func TestHandshake(t *testing.T) { +func handshakeTest(t testing.TB, ia []byte, aData, bData string) { a, b := net.Pipe() wg := sync.WaitGroup{} wg.Add(2) go func() { defer wg.Done() - a, err := InitiateHandshake(a, []byte("yep")) + a, err := InitiateHandshake(a, []byte("yep"), ia) if err != nil { t.Fatal(err) return } - a.Write([]byte("hello world")) + go a.Write([]byte(aData)) + var msg [20]byte n, _ := a.Read(msg[:]) - log.Print(string(msg[:n])) + if n != len(bData) { + t.FailNow() + } + // t.Log(string(msg[:n])) }() go func() { defer wg.Done() @@ -66,10 +71,34 @@ func TestHandshake(t *testing.T) { t.Fatal(err) return } - var msg [20]byte - n, _ := b.Read(msg[:]) - log.Print(string(msg[:n])) - b.Write([]byte("yo dawg")) + go b.Write([]byte(bData)) + // Need to be exact here, as there are several reads, and net.Pipe is + // most synchronous. + msg := make([]byte, len(ia)+len(aData)) + n, _ := io.ReadFull(b, msg[:]) + if n != len(msg) { + t.FailNow() + } + // t.Log(string(msg[:n])) }() wg.Wait() + a.Close() + b.Close() +} + +func allHandshakeTests(t testing.TB) { + handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg") + handshakeTest(t, nil, "hello world", "yo dawg") + handshakeTest(t, []byte{}, "hello world", "yo dawg") +} + +func TestHandshake(t *testing.T) { + allHandshakeTests(t) + t.Logf("crypto provides encountered: %s", cryptoProvidesCount) +} + +func BenchmarkHandshake(b *testing.B) { + for range iter.N(b.N) { + allHandshakeTests(b) + } }