2
0
mirror of synced 2025-02-23 06:08:07 +00:00

mse: Support plaintext crypto mode

This commit is contained in:
Matt Joiner 2017-09-13 16:19:14 +10:00
parent 11165d4fa5
commit 29e06fb83c
2 changed files with 150 additions and 39 deletions

View File

@ -13,6 +13,7 @@ import (
"fmt"
"io"
"io/ioutil"
"math"
"math/big"
"strconv"
"sync"
@ -25,6 +26,7 @@ const (
cryptoMethodPlaintext = 1
cryptoMethodRC4 = 2
AllSupportedCrypto = cryptoMethodPlaintext | cryptoMethodRC4
)
var (
@ -209,6 +211,10 @@ type handshake struct {
skeys SecretKeyIter // Skeys we'll accept if receiving.
skey []byte // Skey we're initiating with.
ia []byte // Initial payload. Only used by the initiator.
// Return the bit for the crypto method the receiver wants to use.
chooseMethod func(supported uint32) uint32
// Sent to the receiver.
cryptoProvides uint32
writeMu sync.Mutex
writes [][]byte
@ -365,11 +371,11 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
buf := &bytes.Buffer{}
padLen := uint16(newPadLen())
err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
if len(h.ia) > math.MaxUint16 {
err = errors.New("initial payload too large")
return
}
err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
if err != nil {
return
}
@ -398,15 +404,18 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
if err != nil {
return
}
if method != cryptoMethodRC4 {
err = fmt.Errorf("receiver chose unsupported method: %x", method)
return
}
_, err = io.CopyN(ioutil.Discard, r, int64(padLen))
if err != nil {
return
}
ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
switch method & h.cryptoProvides {
case cryptoMethodRC4:
ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
case cryptoMethodPlaintext:
ret = h.conn
default:
err = fmt.Errorf("receiver chose unsupported method: %x", method)
}
return
}
@ -440,20 +449,17 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
}
r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
var (
vc [8]byte
method uint32
padLen uint16
vc [8]byte
provides uint32
padLen uint16
)
err = unmarshal(r, vc[:], &method, &padLen)
err = unmarshal(r, vc[:], &provides, &padLen)
if err != nil {
return
}
cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1)
if method&cryptoMethodRC4 == 0 {
err = errors.New("no supported crypto methods were provided")
return
}
cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
chosen := h.chooseMethod(provides)
_, err = io.CopyN(ioutil.Discard, r, int64(padLen))
if err != nil {
return
@ -467,7 +473,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
buf := &bytes.Buffer{}
w := cipherWriter{h.newEncrypt(false), buf, nil}
padLen = uint16(newPadLen())
err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen])
err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
if err != nil {
return
}
@ -475,7 +481,20 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
if err != nil {
return
}
ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn, nil}}
switch chosen {
case cryptoMethodRC4:
ret = readWriter{
io.MultiReader(bytes.NewReader(h.ia), r),
&cipherWriter{w.c, h.conn, nil},
}
case cryptoMethodPlaintext:
ret = readWriter{
io.MultiReader(bytes.NewReader(h.ia), h.conn),
h.conn,
}
default:
err = errors.New("chosen crypto method is not supported")
}
return
}
@ -508,21 +527,23 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
return
}
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides uint32) (ret io.ReadWriter, err error) {
h := handshake{
conn: rw,
initer: true,
skey: skey,
ia: initialPayload,
conn: rw,
initer: true,
skey: skey,
ia: initialPayload,
cryptoProvides: cryptoProvides,
}
return h.Do()
}
func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) {
func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte, selectCrypto func(uint32) uint32) (ret io.ReadWriter, err error) {
h := handshake{
conn: rw,
initer: false,
skeys: sliceIter(skeys),
conn: rw,
initer: false,
skeys: sliceIter(skeys),
chooseMethod: selectCrypto,
}
return h.Do()
}
@ -551,3 +572,12 @@ func ReceiveHandshakeLazy(rw io.ReadWriter, skeys SecretKeyIter) (ret io.ReadWri
}
return h.Do()
}
func DefaultCryptoSelector(provided uint32) uint32 {
if provided&cryptoMethodRC4 != 0 {
return cryptoMethodRC4
}
return cryptoMethodPlaintext
}
type CryptoSelector func(uint32) uint32

View File

@ -10,6 +10,8 @@ import (
"sync"
"testing"
_ "github.com/anacrolix/envpprof"
"github.com/bradfitz/iter"
"github.com/stretchr/testify/require"
)
@ -47,13 +49,13 @@ func TestSuffixMatchLen(t *testing.T) {
test("sup", "person", 1)
}
func handshakeTest(t testing.TB, ia []byte, aData, bData string) {
func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides uint32, cryptoSelect func(uint32) uint32) {
a, b := net.Pipe()
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
a, err := InitiateHandshake(a, []byte("yep"), ia)
a, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
if err != nil {
t.Fatal(err)
return
@ -69,7 +71,7 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string) {
}()
go func() {
defer wg.Done()
b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")})
b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}, cryptoSelect)
if err != nil {
t.Fatal(err)
return
@ -89,20 +91,24 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string) {
b.Close()
}
func allHandshakeTests(t testing.TB) {
handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg")
handshakeTest(t, nil, "hello world", "yo dawg")
handshakeTest(t, []byte{}, "hello world", "yo dawg")
func allHandshakeTests(t testing.TB, provides uint32, selector CryptoSelector) {
handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg", provides, selector)
handshakeTest(t, nil, "hello world", "yo dawg", provides, selector)
handshakeTest(t, []byte{}, "hello world", "yo dawg", provides, selector)
}
func TestHandshake(t *testing.T) {
allHandshakeTests(t)
func TestHandshakeDefault(t *testing.T) {
allHandshakeTests(t, AllSupportedCrypto, DefaultCryptoSelector)
t.Logf("crypto provides encountered: %s", cryptoProvidesCount)
}
func BenchmarkHandshake(b *testing.B) {
func TestHandshakeSelectPlaintext(t *testing.T) {
allHandshakeTests(t, AllSupportedCrypto, func(uint32) uint32 { return cryptoMethodPlaintext })
}
func BenchmarkHandshakeDefault(b *testing.B) {
for range iter.N(b.N) {
allHandshakeTests(b)
allHandshakeTests(b, AllSupportedCrypto, DefaultCryptoSelector)
}
}
@ -119,7 +125,7 @@ func (tr *trackReader) Read(b []byte) (n int, err error) {
func TestReceiveRandomData(t *testing.T) {
tr := trackReader{rand.Reader, 0}
_, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil)
_, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
// No skey matches
require.Error(t, err)
// Establishing S, and then reading the maximum padding for giving up on
@ -127,7 +133,82 @@ func TestReceiveRandomData(t *testing.T) {
require.EqualValues(t, 96+532, tr.n)
}
func BenchmarkPipe(t *testing.B) {
func fillRand(t testing.TB, bs ...[]byte) {
for _, b := range bs {
_, err := rand.Read(b)
require.NoError(t, err)
}
}
func readAndWrite(rw io.ReadWriter, r []byte, w []byte) error {
var wg sync.WaitGroup
wg.Add(1)
var wErr error
go func() {
defer wg.Done()
_, wErr = rw.Write(w)
}()
_, err := io.ReadFull(rw, r)
if err != nil {
return err
}
wg.Wait()
return wErr
}
func benchmarkStream(t *testing.B, crypto uint32) {
ia := make([]byte, 0x1000)
a := make([]byte, 1<<20)
b := make([]byte, 1<<20)
fillRand(t, ia, a, b)
t.StopTimer()
t.SetBytes(int64(len(ia) + len(a) + len(b)))
t.ResetTimer()
for range iter.N(t.N) {
ac, bc := net.Pipe()
ar := make([]byte, len(b))
br := make([]byte, len(ia)+len(a))
t.StartTimer()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer ac.Close()
defer wg.Done()
rw, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
require.NoError(t, err)
require.NoError(t, readAndWrite(rw, ar, a))
}()
func() {
defer bc.Close()
rw, err := ReceiveHandshake(bc, [][]byte{[]byte("cats")}, func(uint32) uint32 { return crypto })
require.NoError(t, err)
require.NoError(t, readAndWrite(rw, br, b))
}()
t.StopTimer()
if !bytes.Equal(ar, b) {
t.Fatalf("A read the wrong bytes")
}
if !bytes.Equal(br[:len(ia)], ia) {
t.Fatalf("B read the wrong IA")
}
if !bytes.Equal(br[len(ia):], a) {
t.Fatalf("B read the wrong A")
}
// require.Equal(t, b, ar)
// require.Equal(t, ia, br[:len(ia)])
// require.Equal(t, a, br[len(ia):])
}
}
func BenchmarkStreamRC4(t *testing.B) {
benchmarkStream(t, cryptoMethodRC4)
}
func BenchmarkStreamPlaintext(t *testing.B) {
benchmarkStream(t, cryptoMethodPlaintext)
}
func BenchmarkPipeRC4(t *testing.B) {
key := make([]byte, 20)
n, _ := rand.Read(key)
require.Equal(t, len(key), n)