2
0
mirror of synced 2025-02-23 14:18:13 +00:00

mse: Tons of fixes and improvements

This commit is contained in:
Matt Joiner 2015-03-13 06:16:49 +11:00
parent 13a5b8b279
commit d57f5896d4
2 changed files with 144 additions and 84 deletions

View File

@ -106,10 +106,6 @@ func (me *cipherWriter) Write(b []byte) (n int, err error) {
return
}
func newCipherWriter(c *rc4.Cipher, w io.Writer) io.Writer {
return &cipherWriter{c, w}
}
func readY(r io.Reader) (y big.Int, err error) {
var b [96]byte
_, err = io.ReadFull(r, b[:])
@ -133,12 +129,17 @@ func newX() big.Int {
return X
}
// Calculate, and send Y, our public key.
func (h *handshake) postY(x *big.Int) error {
var y big.Int
y.Exp(&g, x, &p)
b := y.Bytes()
if len(b) != 96 {
panic(len(b))
b1 := make([]byte, 96)
if n := copy(b1[96-len(b):], b); n != len(b) {
panic(n)
}
b = b1
}
return h.postWrite(b)
}
@ -173,6 +174,7 @@ type handshake struct {
conn io.ReadWriteCloser
s big.Int
initer bool
skeys [][]byte
skey []byte
writeMu sync.Mutex
@ -257,6 +259,26 @@ func xor(dst, src []byte) (ret []byte) {
return
}
func marshal(w io.Writer, data ...interface{}) (err error) {
for _, data := range data {
err = binary.Write(w, binary.BigEndian, data)
if err != nil {
break
}
}
return
}
func unmarshal(r io.Reader, data ...interface{}) (err error) {
for _, data := range data {
err = binary.Read(r, binary.BigEndian, data)
if err != nil {
break
}
}
return
}
type cryptoNegotiation struct {
VC [8]byte
Method uint32
@ -265,7 +287,8 @@ type cryptoNegotiation struct {
}
func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
_, err = io.ReadFull(r, me.VC[:])
err = binary.Read(r, binary.BigEndian, me.VC[:])
// _, err = io.ReadFull(r, me.VC[:])
if err != nil {
return
}
@ -283,7 +306,8 @@ func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
}
func (me *cryptoNegotiation) MarshalWriter(w io.Writer) (err error) {
_, err = w.Write(me.VC[:])
// _, err = w.Write(me.VC[:])
err = binary.Write(w, binary.BigEndian, me.VC[:])
if err != nil {
return
}
@ -344,9 +368,101 @@ type readWriter struct {
io.Writer
}
func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
h.postWrite(hash(req1, h.s.Bytes()))
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())))
buf := &bytes.Buffer{}
err = (&cryptoNegotiation{
Method: cryptoMethodRC4,
PadLen: uint16(newPadLen()),
}).MarshalWriter(buf)
if err != nil {
return
}
err = marshal(buf, uint16(0))
if err != nil {
return
}
e := newEncrypt(true, h.s.Bytes(), h.skey)
be := make([]byte, buf.Len())
e.XORKeyStream(be, buf.Bytes())
h.postWrite(be)
bC := newEncrypt(false, h.s.Bytes(), h.skey)
var eVC [8]byte
bC.XORKeyStream(eVC[:], make([]byte, 8))
log.Print(eVC)
// Read until the all zero VC.
err = readUntil(h.conn, eVC[:])
if err != nil {
err = fmt.Errorf("error reading until VC: %s", err)
return
}
var cn cryptoNegotiation
r := &cipherReader{bC, h.conn}
err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
log.Printf("initer got %v", cn)
if err != nil {
err = fmt.Errorf("error reading crypto negotiation: %s", err)
return
}
ret = readWriter{r, &cipherWriter{e, h.conn}}
return
}
func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
err = readUntil(h.conn, hash(req1, h.s.Bytes()))
if err != nil {
return
}
var b [20]byte
_, err = io.ReadFull(h.conn, b[:])
if err != nil {
return
}
err = errors.New("skey doesn't match")
for _, skey := range h.skeys {
if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s.Bytes())), b[:]) {
h.skey = skey
err = nil
break
}
}
if err != nil {
return
}
var cn cryptoNegotiation
r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn)
err = cn.UnmarshalReader(r)
if err != nil {
return
}
log.Printf("receiver got %v", cn)
if cn.Method&cryptoMethodRC4 == 0 {
err = errors.New("no supported crypto methods were provided")
return
}
unmarshal(r, new(uint16))
buf := &bytes.Buffer{}
w := cipherWriter{newEncrypt(false, h.s.Bytes(), h.skey), buf}
err = (&cryptoNegotiation{
Method: cryptoMethodRC4,
PadLen: uint16(newPadLen()),
}).MarshalWriter(&w)
if err != nil {
return
}
err = h.postWrite(buf.Bytes())
if err != nil {
return
}
ret = readWriter{r, &cipherWriter{w.c, h.conn}}
return
}
func (h *handshake) Do() (ret io.ReadWriter, err error) {
err = h.establishS()
if err != nil {
err = fmt.Errorf("error while establishing secret: %s", err)
return
}
pad := make([]byte, newPadLen())
@ -356,92 +472,25 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
return
}
if h.initer {
h.postWrite(hash(req1, h.s.Bytes()))
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())))
buf := &bytes.Buffer{}
err = (&cryptoNegotiation{
Method: cryptoMethodRC4,
PadLen: uint16(newPadLen()),
}).MarshalWriter(buf)
if err != nil {
return
}
e := newEncrypt(true, h.s.Bytes(), h.skey)
be := make([]byte, buf.Len())
e.XORKeyStream(be, buf.Bytes())
h.postWrite(be)
bC := newEncrypt(false, h.s.Bytes(), h.skey)
var eVC [8]byte
bC.XORKeyStream(eVC[:], make([]byte, 8))
log.Print(eVC)
// Read until the all zero VC.
err = readUntil(h.conn, eVC[:])
if err != nil {
err = fmt.Errorf("error reading until VC: %s", err)
return
}
var cn cryptoNegotiation
r := &cipherReader{bC, h.conn}
err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
log.Printf("initer got %v", cn)
if err != nil {
err = fmt.Errorf("error reading crypto negotiation: %s", err)
return
}
ret = readWriter{r, &cipherWriter{bC, h.conn}}
ret, err = h.initerSteps()
} else {
err = readUntil(h.conn, hash(req1, h.s.Bytes()))
if err != nil {
return
}
var b [20]byte
_, err = io.ReadFull(h.conn, b[:])
if err != nil {
return
}
if !bytes.Equal(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())), b[:]) {
err = errors.New("skey doesn't match")
return
}
var cn cryptoNegotiation
r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn)
err = cn.UnmarshalReader(r)
if err != nil {
return
}
log.Printf("receiver got %v", cn)
if cn.Method&cryptoMethodRC4 == 0 {
err = errors.New("no supported crypto methods were provided")
return
}
buf := &bytes.Buffer{}
w := newCipherWriter(newEncrypt(false, h.s.Bytes(), h.skey), buf)
err = (&cryptoNegotiation{
Method: cryptoMethodRC4,
PadLen: uint16(newPadLen()),
}).MarshalWriter(w)
if err != nil {
return
}
log.Println("encrypted VC", buf.Bytes()[:8])
err = h.postWrite(buf.Bytes())
if err != nil {
return
}
ret = readWriter{r, w}
ret, err = h.receiverSteps()
}
if err != nil {
return
}
err = h.finishWriting()
if err != nil {
return
}
ret = h.conn
log.Print("ermahgerd, finished MSE handshake")
return
}
func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWriter, err error) {
func InitiateHandshake(rw io.ReadWriteCloser, skey []byte) (ret io.ReadWriter, err error) {
h := handshake{
conn: rw,
initer: initer,
initer: true,
skey: skey,
}
h.writeCond.L = &h.writeMu
@ -449,3 +498,14 @@ func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWrit
go h.writer()
return h.Do()
}
func ReceiveHandshake(rw io.ReadWriteCloser, skeys [][]byte) (ret io.ReadWriter, err error) {
h := handshake{
conn: rw,
initer: false,
skeys: skeys,
}
h.writeCond.L = &h.writeMu
h.writerCond.L = &h.writerMu
go h.writer()
return h.Do()
}

View File

@ -49,7 +49,7 @@ func TestHandshake(t *testing.T) {
wg.Add(2)
go func() {
defer wg.Done()
a, err := Handshake(a, true, []byte("yep"))
a, err := InitiateHandshake(a, []byte("yep"))
if err != nil {
t.Fatal(err)
return
@ -61,7 +61,7 @@ func TestHandshake(t *testing.T) {
}()
go func() {
defer wg.Done()
b, err := Handshake(b, false, []byte("yep"))
b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")})
if err != nil {
t.Fatal(err)
return