mse: Support plaintext crypto mode
This commit is contained in:
parent
11165d4fa5
commit
29e06fb83c
82
mse/mse.go
82
mse/mse.go
@ -13,6 +13,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"sync"
|
||||
@ -25,6 +26,7 @@ const (
|
||||
|
||||
cryptoMethodPlaintext = 1
|
||||
cryptoMethodRC4 = 2
|
||||
AllSupportedCrypto = cryptoMethodPlaintext | cryptoMethodRC4
|
||||
)
|
||||
|
||||
var (
|
||||
@ -209,6 +211,10 @@ type handshake struct {
|
||||
skeys SecretKeyIter // Skeys we'll accept if receiving.
|
||||
skey []byte // Skey we're initiating with.
|
||||
ia []byte // Initial payload. Only used by the initiator.
|
||||
// Return the bit for the crypto method the receiver wants to use.
|
||||
chooseMethod func(supported uint32) uint32
|
||||
// Sent to the receiver.
|
||||
cryptoProvides uint32
|
||||
|
||||
writeMu sync.Mutex
|
||||
writes [][]byte
|
||||
@ -365,11 +371,11 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
|
||||
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
|
||||
buf := &bytes.Buffer{}
|
||||
padLen := uint16(newPadLen())
|
||||
err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
|
||||
if len(h.ia) > math.MaxUint16 {
|
||||
err = errors.New("initial payload too large")
|
||||
return
|
||||
}
|
||||
err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -398,15 +404,18 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if method != cryptoMethodRC4 {
|
||||
err = fmt.Errorf("receiver chose unsupported method: %x", method)
|
||||
return
|
||||
}
|
||||
_, err = io.CopyN(ioutil.Discard, r, int64(padLen))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
|
||||
switch method & h.cryptoProvides {
|
||||
case cryptoMethodRC4:
|
||||
ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
|
||||
case cryptoMethodPlaintext:
|
||||
ret = h.conn
|
||||
default:
|
||||
err = fmt.Errorf("receiver chose unsupported method: %x", method)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -440,20 +449,17 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
|
||||
}
|
||||
r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
|
||||
var (
|
||||
vc [8]byte
|
||||
method uint32
|
||||
padLen uint16
|
||||
vc [8]byte
|
||||
provides uint32
|
||||
padLen uint16
|
||||
)
|
||||
|
||||
err = unmarshal(r, vc[:], &method, &padLen)
|
||||
err = unmarshal(r, vc[:], &provides, &padLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1)
|
||||
if method&cryptoMethodRC4 == 0 {
|
||||
err = errors.New("no supported crypto methods were provided")
|
||||
return
|
||||
}
|
||||
cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
|
||||
chosen := h.chooseMethod(provides)
|
||||
_, err = io.CopyN(ioutil.Discard, r, int64(padLen))
|
||||
if err != nil {
|
||||
return
|
||||
@ -467,7 +473,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
|
||||
buf := &bytes.Buffer{}
|
||||
w := cipherWriter{h.newEncrypt(false), buf, nil}
|
||||
padLen = uint16(newPadLen())
|
||||
err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen])
|
||||
err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -475,7 +481,20 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn, nil}}
|
||||
switch chosen {
|
||||
case cryptoMethodRC4:
|
||||
ret = readWriter{
|
||||
io.MultiReader(bytes.NewReader(h.ia), r),
|
||||
&cipherWriter{w.c, h.conn, nil},
|
||||
}
|
||||
case cryptoMethodPlaintext:
|
||||
ret = readWriter{
|
||||
io.MultiReader(bytes.NewReader(h.ia), h.conn),
|
||||
h.conn,
|
||||
}
|
||||
default:
|
||||
err = errors.New("chosen crypto method is not supported")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -508,21 +527,23 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) {
|
||||
func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides uint32) (ret io.ReadWriter, err error) {
|
||||
h := handshake{
|
||||
conn: rw,
|
||||
initer: true,
|
||||
skey: skey,
|
||||
ia: initialPayload,
|
||||
conn: rw,
|
||||
initer: true,
|
||||
skey: skey,
|
||||
ia: initialPayload,
|
||||
cryptoProvides: cryptoProvides,
|
||||
}
|
||||
return h.Do()
|
||||
}
|
||||
|
||||
func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) {
|
||||
func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte, selectCrypto func(uint32) uint32) (ret io.ReadWriter, err error) {
|
||||
h := handshake{
|
||||
conn: rw,
|
||||
initer: false,
|
||||
skeys: sliceIter(skeys),
|
||||
conn: rw,
|
||||
initer: false,
|
||||
skeys: sliceIter(skeys),
|
||||
chooseMethod: selectCrypto,
|
||||
}
|
||||
return h.Do()
|
||||
}
|
||||
@ -551,3 +572,12 @@ func ReceiveHandshakeLazy(rw io.ReadWriter, skeys SecretKeyIter) (ret io.ReadWri
|
||||
}
|
||||
return h.Do()
|
||||
}
|
||||
|
||||
func DefaultCryptoSelector(provided uint32) uint32 {
|
||||
if provided&cryptoMethodRC4 != 0 {
|
||||
return cryptoMethodRC4
|
||||
}
|
||||
return cryptoMethodPlaintext
|
||||
}
|
||||
|
||||
type CryptoSelector func(uint32) uint32
|
||||
|
107
mse/mse_test.go
107
mse/mse_test.go
@ -10,6 +10,8 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
_ "github.com/anacrolix/envpprof"
|
||||
|
||||
"github.com/bradfitz/iter"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -47,13 +49,13 @@ func TestSuffixMatchLen(t *testing.T) {
|
||||
test("sup", "person", 1)
|
||||
}
|
||||
|
||||
func handshakeTest(t testing.TB, ia []byte, aData, bData string) {
|
||||
func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides uint32, cryptoSelect func(uint32) uint32) {
|
||||
a, b := net.Pipe()
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
a, err := InitiateHandshake(a, []byte("yep"), ia)
|
||||
a, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -69,7 +71,7 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string) {
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")})
|
||||
b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}, cryptoSelect)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -89,20 +91,24 @@ func handshakeTest(t testing.TB, ia []byte, aData, bData string) {
|
||||
b.Close()
|
||||
}
|
||||
|
||||
func allHandshakeTests(t testing.TB) {
|
||||
handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg")
|
||||
handshakeTest(t, nil, "hello world", "yo dawg")
|
||||
handshakeTest(t, []byte{}, "hello world", "yo dawg")
|
||||
func allHandshakeTests(t testing.TB, provides uint32, selector CryptoSelector) {
|
||||
handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg", provides, selector)
|
||||
handshakeTest(t, nil, "hello world", "yo dawg", provides, selector)
|
||||
handshakeTest(t, []byte{}, "hello world", "yo dawg", provides, selector)
|
||||
}
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
allHandshakeTests(t)
|
||||
func TestHandshakeDefault(t *testing.T) {
|
||||
allHandshakeTests(t, AllSupportedCrypto, DefaultCryptoSelector)
|
||||
t.Logf("crypto provides encountered: %s", cryptoProvidesCount)
|
||||
}
|
||||
|
||||
func BenchmarkHandshake(b *testing.B) {
|
||||
func TestHandshakeSelectPlaintext(t *testing.T) {
|
||||
allHandshakeTests(t, AllSupportedCrypto, func(uint32) uint32 { return cryptoMethodPlaintext })
|
||||
}
|
||||
|
||||
func BenchmarkHandshakeDefault(b *testing.B) {
|
||||
for range iter.N(b.N) {
|
||||
allHandshakeTests(b)
|
||||
allHandshakeTests(b, AllSupportedCrypto, DefaultCryptoSelector)
|
||||
}
|
||||
}
|
||||
|
||||
@ -119,7 +125,7 @@ func (tr *trackReader) Read(b []byte) (n int, err error) {
|
||||
|
||||
func TestReceiveRandomData(t *testing.T) {
|
||||
tr := trackReader{rand.Reader, 0}
|
||||
_, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil)
|
||||
_, err := ReceiveHandshake(readWriter{&tr, ioutil.Discard}, nil, DefaultCryptoSelector)
|
||||
// No skey matches
|
||||
require.Error(t, err)
|
||||
// Establishing S, and then reading the maximum padding for giving up on
|
||||
@ -127,7 +133,82 @@ func TestReceiveRandomData(t *testing.T) {
|
||||
require.EqualValues(t, 96+532, tr.n)
|
||||
}
|
||||
|
||||
func BenchmarkPipe(t *testing.B) {
|
||||
func fillRand(t testing.TB, bs ...[]byte) {
|
||||
for _, b := range bs {
|
||||
_, err := rand.Read(b)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func readAndWrite(rw io.ReadWriter, r []byte, w []byte) error {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
var wErr error
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, wErr = rw.Write(w)
|
||||
}()
|
||||
_, err := io.ReadFull(rw, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
wg.Wait()
|
||||
return wErr
|
||||
}
|
||||
|
||||
func benchmarkStream(t *testing.B, crypto uint32) {
|
||||
ia := make([]byte, 0x1000)
|
||||
a := make([]byte, 1<<20)
|
||||
b := make([]byte, 1<<20)
|
||||
fillRand(t, ia, a, b)
|
||||
t.StopTimer()
|
||||
t.SetBytes(int64(len(ia) + len(a) + len(b)))
|
||||
t.ResetTimer()
|
||||
for range iter.N(t.N) {
|
||||
ac, bc := net.Pipe()
|
||||
ar := make([]byte, len(b))
|
||||
br := make([]byte, len(ia)+len(a))
|
||||
t.StartTimer()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer ac.Close()
|
||||
defer wg.Done()
|
||||
rw, err := InitiateHandshake(ac, []byte("cats"), ia, crypto)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, readAndWrite(rw, ar, a))
|
||||
}()
|
||||
func() {
|
||||
defer bc.Close()
|
||||
rw, err := ReceiveHandshake(bc, [][]byte{[]byte("cats")}, func(uint32) uint32 { return crypto })
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, readAndWrite(rw, br, b))
|
||||
}()
|
||||
t.StopTimer()
|
||||
if !bytes.Equal(ar, b) {
|
||||
t.Fatalf("A read the wrong bytes")
|
||||
}
|
||||
if !bytes.Equal(br[:len(ia)], ia) {
|
||||
t.Fatalf("B read the wrong IA")
|
||||
}
|
||||
if !bytes.Equal(br[len(ia):], a) {
|
||||
t.Fatalf("B read the wrong A")
|
||||
}
|
||||
// require.Equal(t, b, ar)
|
||||
// require.Equal(t, ia, br[:len(ia)])
|
||||
// require.Equal(t, a, br[len(ia):])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStreamRC4(t *testing.B) {
|
||||
benchmarkStream(t, cryptoMethodRC4)
|
||||
}
|
||||
|
||||
func BenchmarkStreamPlaintext(t *testing.B) {
|
||||
benchmarkStream(t, cryptoMethodPlaintext)
|
||||
}
|
||||
|
||||
func BenchmarkPipeRC4(t *testing.B) {
|
||||
key := make([]byte, 20)
|
||||
n, _ := rand.Read(key)
|
||||
require.Equal(t, len(key), n)
|
||||
|
Loading…
x
Reference in New Issue
Block a user