// https://wiki.vuze.com/w/Message_Stream_Encryption package mse import ( "bytes" "crypto/rand" "crypto/rc4" "crypto/sha1" "encoding/binary" "errors" "expvar" "fmt" "io" "io/ioutil" "math/big" "strconv" "sync" "github.com/anacrolix/missinggo" "github.com/bradfitz/iter" ) const ( maxPadLen = 512 cryptoMethodPlaintext = 1 cryptoMethodRC4 = 2 ) var ( // Prime P according to the spec, and G, the generator. p, g big.Int // The rand.Int max arg for use in newPadLen() newPadLenMax big.Int // For use in initer's hashes req1 = []byte("req1") req2 = []byte("req2") req3 = []byte("req3") // Verification constant "VC" which is all zeroes in the bittorrent // implementation. vc [8]byte // Zero padding zeroPad [512]byte // Tracks counts of received crypto_provides cryptoProvidesCount = expvar.NewMap("mseCryptoProvides") ) func init() { p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0) g.SetInt64(2) newPadLenMax.SetInt64(maxPadLen + 1) } func hash(parts ...[]byte) []byte { h := sha1.New() for _, p := range parts { n, err := h.Write(p) if err != nil { panic(err) } if n != len(p) { panic(n) } } return h.Sum(nil) } func newEncrypt(initer bool, s []byte, skey []byte) (c *rc4.Cipher) { c, err := rc4.NewCipher(hash([]byte(func() string { if initer { return "keyA" } else { return "keyB" } }()), s, skey)) if err != nil { panic(err) } var burnSrc, burnDst [1024]byte c.XORKeyStream(burnDst[:], burnSrc[:]) return } type cipherReader struct { c *rc4.Cipher r io.Reader } func (cr *cipherReader) Read(b []byte) (n int, err error) { be := make([]byte, len(b)) n, err = cr.r.Read(be) cr.c.XORKeyStream(b[:n], be[:n]) return } func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader { return &cipherReader{c, r} } type cipherWriter struct { c *rc4.Cipher w io.Writer } func (cr *cipherWriter) Write(b []byte) (n int, err error) { be := make([]byte, len(b)) cr.c.XORKeyStream(be, b) n, err = cr.w.Write(be) if n != len(be) { // The cipher will have advanced beyond the callers stream position. // We can't use the cipher anymore. cr.c = nil } return } func readY(r io.Reader) (y big.Int, err error) { var b [96]byte _, err = io.ReadFull(r, b[:]) if err != nil { return } y.SetBytes(b[:]) return } func newX() big.Int { var X big.Int X.SetBytes(func() []byte { var b [20]byte _, err := rand.Read(b[:]) if err != nil { panic(err) } return b[:] }()) return X } func paddedLeft(b []byte, _len int) []byte { if len(b) == _len { return b } ret := make([]byte, _len) if n := copy(ret[_len-len(b):], b); n != len(b) { panic(n) } return ret } // Calculate, and send Y, our public key. func (h *handshake) postY(x *big.Int) error { var y big.Int y.Exp(&g, x, &p) return h.postWrite(paddedLeft(y.Bytes(), 96)) } func (h *handshake) establishS() (err error) { x := newX() h.postY(&x) var b [96]byte _, err = io.ReadFull(h.conn, b[:]) if err != nil { return } var Y, S big.Int Y.SetBytes(b[:]) S.Exp(&Y, &x, &p) missinggo.CopyExact(&h.s, paddedLeft(S.Bytes(), 96)) return } func newPadLen() int64 { i, err := rand.Int(rand.Reader, &newPadLenMax) if err != nil { panic(err) } ret := i.Int64() if ret < 0 || ret > maxPadLen { panic(ret) } return ret } // Manages state for both initiating and receiving handshakes. type handshake struct { conn io.ReadWriter s [96]byte initer bool // Whether we're initiating or receiving. skeys [][]byte // Skeys we'll accept if receiving. skey []byte // Skey we're initiating with. ia []byte // Initial payload. Only used by the initiator. writeMu sync.Mutex writes [][]byte writeErr error writeCond sync.Cond writeClose bool writerMu sync.Mutex writerCond sync.Cond writerDone bool } func (h *handshake) finishWriting() { h.writeMu.Lock() h.writeClose = true h.writeCond.Broadcast() h.writeMu.Unlock() h.writerMu.Lock() for !h.writerDone { h.writerCond.Wait() } h.writerMu.Unlock() return } func (h *handshake) writer() { defer func() { h.writerMu.Lock() h.writerDone = true h.writerCond.Broadcast() h.writerMu.Unlock() }() for { h.writeMu.Lock() for { if len(h.writes) != 0 { break } if h.writeClose { h.writeMu.Unlock() return } h.writeCond.Wait() } b := h.writes[0] h.writes = h.writes[1:] h.writeMu.Unlock() _, err := h.conn.Write(b) if err != nil { h.writeMu.Lock() h.writeErr = err h.writeMu.Unlock() return } } } func (h *handshake) postWrite(b []byte) error { h.writeMu.Lock() defer h.writeMu.Unlock() if h.writeErr != nil { return h.writeErr } h.writes = append(h.writes, b) h.writeCond.Signal() return nil } func xor(dst, src []byte) (ret []byte) { max := len(dst) if max > len(src) { max = len(src) } ret = make([]byte, 0, max) for i := range iter.N(max) { ret = append(ret, dst[i]^src[i]) } return } func marshal(w io.Writer, data ...interface{}) (err error) { for _, data := range data { err = binary.Write(w, binary.BigEndian, data) if err != nil { break } } return } func unmarshal(r io.Reader, data ...interface{}) (err error) { for _, data := range data { err = binary.Read(r, binary.BigEndian, data) if err != nil { break } } return } // Looking for b at the end of a. func suffixMatchLen(a, b []byte) int { if len(b) > len(a) { b = b[:len(a)] } // i is how much of b to try to match for i := len(b); i > 0; i-- { // j is how many chars we've compared j := 0 for ; j < i; j++ { if b[i-1-j] != a[len(a)-1-j] { goto shorter } } return j shorter: } return 0 } // Reads from r until b has been seen. Keeps the minimum amount of data in // memory. func readUntil(r io.Reader, b []byte) error { b1 := make([]byte, len(b)) i := 0 for { _, err := io.ReadFull(r, b1[i:]) if err != nil { return err } i = suffixMatchLen(b1, b) if i == len(b) { break } if copy(b1, b1[len(b1)-i:]) != i { panic("wat") } } return nil } type readWriter struct { io.Reader io.Writer } func (h *handshake) newEncrypt(initer bool) *rc4.Cipher { return newEncrypt(initer, h.s[:], h.skey) } func (h *handshake) initerSteps() (ret io.ReadWriter, err error) { h.postWrite(hash(req1, h.s[:])) h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:]))) buf := &bytes.Buffer{} padLen := uint16(newPadLen()) err = marshal(buf, vc[:], uint32(cryptoMethodRC4), padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia) if err != nil { return } e := h.newEncrypt(true) be := make([]byte, buf.Len()) e.XORKeyStream(be, buf.Bytes()) h.postWrite(be) bC := h.newEncrypt(false) var eVC [8]byte bC.XORKeyStream(eVC[:], vc[:]) // Read until the all zero VC. At this point we've only read the 96 byte // public key, Y. There is potentially 512 byte padding, between us and // the 8 byte verification constant. err = readUntil(io.LimitReader(h.conn, 520), eVC[:]) if err != nil { if err == io.EOF { err = errors.New("failed to synchronize on VC") } else { err = fmt.Errorf("error reading until VC: %s", err) } return } r := &cipherReader{bC, h.conn} var method uint32 err = unmarshal(r, &method, &padLen) if err != nil { return } if method != cryptoMethodRC4 { err = fmt.Errorf("receiver chose unsupported method: %x", method) return } _, err = io.CopyN(ioutil.Discard, r, int64(padLen)) if err != nil { return } ret = readWriter{r, &cipherWriter{e, h.conn}} return } var ErrNoSecretKeyMatch = errors.New("no skey matched") func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) { // There is up to 512 bytes of padding, then the 20 byte hash. err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:])) if err != nil { if err == io.EOF { err = errors.New("failed to synchronize on S hash") } return } var b [20]byte _, err = io.ReadFull(h.conn, b[:]) if err != nil { return } err = ErrNoSecretKeyMatch for _, skey := range h.skeys { if bytes.Equal(xor(hash(req2, skey), hash(req3, h.s[:])), b[:]) { h.skey = skey err = nil break } } if err != nil { return } r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn) var ( vc [8]byte method uint32 padLen uint16 ) err = unmarshal(r, vc[:], &method, &padLen) if err != nil { return } cryptoProvidesCount.Add(strconv.FormatUint(uint64(method), 16), 1) if method&cryptoMethodRC4 == 0 { err = errors.New("no supported crypto methods were provided") return } _, err = io.CopyN(ioutil.Discard, r, int64(padLen)) if err != nil { return } var lenIA uint16 unmarshal(r, &lenIA) if lenIA != 0 { h.ia = make([]byte, lenIA) unmarshal(r, h.ia) } buf := &bytes.Buffer{} w := cipherWriter{h.newEncrypt(false), buf} padLen = uint16(newPadLen()) err = marshal(&w, &vc, uint32(cryptoMethodRC4), padLen, zeroPad[:padLen]) if err != nil { return } err = h.postWrite(buf.Bytes()) if err != nil { return } ret = readWriter{io.MultiReader(bytes.NewReader(h.ia), r), &cipherWriter{w.c, h.conn}} return } func (h *handshake) Do() (ret io.ReadWriter, err error) { h.writeCond.L = &h.writeMu h.writerCond.L = &h.writerMu go h.writer() defer func() { h.finishWriting() if err == nil { err = h.writeErr } }() err = h.establishS() if err != nil { err = fmt.Errorf("error while establishing secret: %s", err) return } pad := make([]byte, newPadLen()) io.ReadFull(rand.Reader, pad) err = h.postWrite(pad) if err != nil { return } if h.initer { ret, err = h.initerSteps() } else { ret, err = h.receiverSteps() } return } func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte) (ret io.ReadWriter, err error) { h := handshake{ conn: rw, initer: true, skey: skey, ia: initialPayload, } return h.Do() } func ReceiveHandshake(rw io.ReadWriter, skeys [][]byte) (ret io.ReadWriter, err error) { h := handshake{ conn: rw, initer: false, skeys: skeys, } return h.Do() }