mse: Tons of fixes and improvements
This commit is contained in:
parent
13a5b8b279
commit
d57f5896d4
224
mse/mse.go
224
mse/mse.go
@ -106,10 +106,6 @@ func (me *cipherWriter) Write(b []byte) (n int, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func newCipherWriter(c *rc4.Cipher, w io.Writer) io.Writer {
|
||||
return &cipherWriter{c, w}
|
||||
}
|
||||
|
||||
func readY(r io.Reader) (y big.Int, err error) {
|
||||
var b [96]byte
|
||||
_, err = io.ReadFull(r, b[:])
|
||||
@ -133,12 +129,17 @@ func newX() big.Int {
|
||||
return X
|
||||
}
|
||||
|
||||
// Calculate, and send Y, our public key.
|
||||
func (h *handshake) postY(x *big.Int) error {
|
||||
var y big.Int
|
||||
y.Exp(&g, x, &p)
|
||||
b := y.Bytes()
|
||||
if len(b) != 96 {
|
||||
panic(len(b))
|
||||
b1 := make([]byte, 96)
|
||||
if n := copy(b1[96-len(b):], b); n != len(b) {
|
||||
panic(n)
|
||||
}
|
||||
b = b1
|
||||
}
|
||||
return h.postWrite(b)
|
||||
}
|
||||
@ -173,6 +174,7 @@ type handshake struct {
|
||||
conn io.ReadWriteCloser
|
||||
s big.Int
|
||||
initer bool
|
||||
skeys [][]byte
|
||||
skey []byte
|
||||
|
||||
writeMu sync.Mutex
|
||||
@ -257,6 +259,26 @@ func xor(dst, src []byte) (ret []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
func marshal(w io.Writer, data ...interface{}) (err error) {
|
||||
for _, data := range data {
|
||||
err = binary.Write(w, binary.BigEndian, data)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func unmarshal(r io.Reader, data ...interface{}) (err error) {
|
||||
for _, data := range data {
|
||||
err = binary.Read(r, binary.BigEndian, data)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type cryptoNegotiation struct {
|
||||
VC [8]byte
|
||||
Method uint32
|
||||
@ -265,7 +287,8 @@ type cryptoNegotiation struct {
|
||||
}
|
||||
|
||||
func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
|
||||
_, err = io.ReadFull(r, me.VC[:])
|
||||
err = binary.Read(r, binary.BigEndian, me.VC[:])
|
||||
// _, err = io.ReadFull(r, me.VC[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -283,7 +306,8 @@ func (me *cryptoNegotiation) UnmarshalReader(r io.Reader) (err error) {
|
||||
}
|
||||
|
||||
func (me *cryptoNegotiation) MarshalWriter(w io.Writer) (err error) {
|
||||
_, err = w.Write(me.VC[:])
|
||||
// _, err = w.Write(me.VC[:])
|
||||
err = binary.Write(w, binary.BigEndian, me.VC[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -344,9 +368,101 @@ type readWriter struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
|
||||
h.postWrite(hash(req1, h.s.Bytes()))
|
||||
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())))
|
||||
buf := &bytes.Buffer{}
|
||||
err = (&cryptoNegotiation{
|
||||
Method: cryptoMethodRC4,
|
||||
PadLen: uint16(newPadLen()),
|
||||
}).MarshalWriter(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = marshal(buf, uint16(0))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
e := newEncrypt(true, h.s.Bytes(), h.skey)
|
||||
be := make([]byte, buf.Len())
|
||||
e.XORKeyStream(be, buf.Bytes())
|
||||
h.postWrite(be)
|
||||
bC := newEncrypt(false, h.s.Bytes(), h.skey)
|
||||
var eVC [8]byte
|
||||
bC.XORKeyStream(eVC[:], make([]byte, 8))
|
||||
log.Print(eVC)
|
||||
// Read until the all zero VC.
|
||||
err = readUntil(h.conn, eVC[:])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading until VC: %s", err)
|
||||
return
|
||||
}
|
||||
var cn cryptoNegotiation
|
||||
r := &cipherReader{bC, h.conn}
|
||||
err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
|
||||
log.Printf("initer got %v", cn)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading crypto negotiation: %s", err)
|
||||
return
|
||||
}
|
||||
ret = readWriter{r, &cipherWriter{e, h.conn}}
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
|
||||
err = readUntil(h.conn, hash(req1, h.s.Bytes()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var b [20]byte
|
||||
_, err = io.ReadFull(h.conn, b[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = errors.New("skey doesn't match")
|
||||
for _, skey := range h.skeys {
|
||||
if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s.Bytes())), b[:]) {
|
||||
h.skey = skey
|
||||
err = nil
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var cn cryptoNegotiation
|
||||
r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn)
|
||||
err = cn.UnmarshalReader(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Printf("receiver got %v", cn)
|
||||
if cn.Method&cryptoMethodRC4 == 0 {
|
||||
err = errors.New("no supported crypto methods were provided")
|
||||
return
|
||||
}
|
||||
unmarshal(r, new(uint16))
|
||||
buf := &bytes.Buffer{}
|
||||
w := cipherWriter{newEncrypt(false, h.s.Bytes(), h.skey), buf}
|
||||
err = (&cryptoNegotiation{
|
||||
Method: cryptoMethodRC4,
|
||||
PadLen: uint16(newPadLen()),
|
||||
}).MarshalWriter(&w)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = h.postWrite(buf.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ret = readWriter{r, &cipherWriter{w.c, h.conn}}
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handshake) Do() (ret io.ReadWriter, err error) {
|
||||
err = h.establishS()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error while establishing secret: %s", err)
|
||||
return
|
||||
}
|
||||
pad := make([]byte, newPadLen())
|
||||
@ -356,92 +472,25 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
|
||||
return
|
||||
}
|
||||
if h.initer {
|
||||
h.postWrite(hash(req1, h.s.Bytes()))
|
||||
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())))
|
||||
buf := &bytes.Buffer{}
|
||||
err = (&cryptoNegotiation{
|
||||
Method: cryptoMethodRC4,
|
||||
PadLen: uint16(newPadLen()),
|
||||
}).MarshalWriter(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
e := newEncrypt(true, h.s.Bytes(), h.skey)
|
||||
be := make([]byte, buf.Len())
|
||||
e.XORKeyStream(be, buf.Bytes())
|
||||
h.postWrite(be)
|
||||
bC := newEncrypt(false, h.s.Bytes(), h.skey)
|
||||
var eVC [8]byte
|
||||
bC.XORKeyStream(eVC[:], make([]byte, 8))
|
||||
log.Print(eVC)
|
||||
// Read until the all zero VC.
|
||||
err = readUntil(h.conn, eVC[:])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading until VC: %s", err)
|
||||
return
|
||||
}
|
||||
var cn cryptoNegotiation
|
||||
r := &cipherReader{bC, h.conn}
|
||||
err = cn.UnmarshalReader(io.MultiReader(bytes.NewReader(make([]byte, 8)), r))
|
||||
log.Printf("initer got %v", cn)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading crypto negotiation: %s", err)
|
||||
return
|
||||
}
|
||||
ret = readWriter{r, &cipherWriter{bC, h.conn}}
|
||||
ret, err = h.initerSteps()
|
||||
} else {
|
||||
err = readUntil(h.conn, hash(req1, h.s.Bytes()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var b [20]byte
|
||||
_, err = io.ReadFull(h.conn, b[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(xor(hash(req2, h.skey), hash(req3, h.s.Bytes())), b[:]) {
|
||||
err = errors.New("skey doesn't match")
|
||||
return
|
||||
}
|
||||
var cn cryptoNegotiation
|
||||
r := newCipherReader(newEncrypt(true, h.s.Bytes(), h.skey), h.conn)
|
||||
err = cn.UnmarshalReader(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Printf("receiver got %v", cn)
|
||||
if cn.Method&cryptoMethodRC4 == 0 {
|
||||
err = errors.New("no supported crypto methods were provided")
|
||||
return
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
w := newCipherWriter(newEncrypt(false, h.s.Bytes(), h.skey), buf)
|
||||
err = (&cryptoNegotiation{
|
||||
Method: cryptoMethodRC4,
|
||||
PadLen: uint16(newPadLen()),
|
||||
}).MarshalWriter(w)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Println("encrypted VC", buf.Bytes()[:8])
|
||||
err = h.postWrite(buf.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ret = readWriter{r, w}
|
||||
ret, err = h.receiverSteps()
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = h.finishWriting()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ret = h.conn
|
||||
log.Print("ermahgerd, finished MSE handshake")
|
||||
return
|
||||
}
|
||||
|
||||
func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWriter, err error) {
|
||||
func InitiateHandshake(rw io.ReadWriteCloser, skey []byte) (ret io.ReadWriter, err error) {
|
||||
h := handshake{
|
||||
conn: rw,
|
||||
initer: initer,
|
||||
initer: true,
|
||||
skey: skey,
|
||||
}
|
||||
h.writeCond.L = &h.writeMu
|
||||
@ -449,3 +498,14 @@ func Handshake(rw io.ReadWriteCloser, initer bool, skey []byte) (ret io.ReadWrit
|
||||
go h.writer()
|
||||
return h.Do()
|
||||
}
|
||||
func ReceiveHandshake(rw io.ReadWriteCloser, skeys [][]byte) (ret io.ReadWriter, err error) {
|
||||
h := handshake{
|
||||
conn: rw,
|
||||
initer: false,
|
||||
skeys: skeys,
|
||||
}
|
||||
h.writeCond.L = &h.writeMu
|
||||
h.writerCond.L = &h.writerMu
|
||||
go h.writer()
|
||||
return h.Do()
|
||||
}
|
||||
|
@ -49,7 +49,7 @@ func TestHandshake(t *testing.T) {
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
a, err := Handshake(a, true, []byte("yep"))
|
||||
a, err := InitiateHandshake(a, []byte("yep"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -61,7 +61,7 @@ func TestHandshake(t *testing.T) {
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
b, err := Handshake(b, false, []byte("yep"))
|
||||
b, err := ReceiveHandshake(b, [][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
|
Loading…
x
Reference in New Issue
Block a user