diff --git a/db.go b/db.go index 117445d..dc3786e 100644 --- a/db.go +++ b/db.go @@ -1,10 +1,13 @@ package rendezvous import ( + "bytes" "context" "crypto/rand" + "crypto/sha256" "database/sql" - "errors" + "encoding/binary" + "fmt" "os" "time" @@ -242,7 +245,7 @@ func (db *DB) Discover(ns string, cookie []byte, limit int) ([]RegistrationRecor ) if cookie != nil { - counter, err = cookieToCounter(cookie) + counter, err = unpackCookie(cookie) if err != nil { log.Errorf("error unpacking cookie: %s", err.Error()) return nil, nil, err @@ -293,15 +296,15 @@ func (db *DB) Discover(ns string, cookie []byte, limit int) ([]RegistrationRecor log.Errorf("error decoding peer id: %s", err.Error()) continue } - reg.Id = p - addrs, err = unpackAddrs(raddrs) + addrs, err := unpackAddrs(raddrs) if err != nil { log.Errorf("error unpacking address: %s", err.Error()) continue } - reg.Addrs = addrs + reg.Id = p + reg.Addrs = addrs reg.Ttl = int(expire - now) if ns == "" { @@ -317,15 +320,14 @@ func (db *DB) Discover(ns string, cookie []byte, limit int) ([]RegistrationRecor } if counter > 0 { - cookie = counterToCookie(counter, ns, db.nonce) + cookie = packCookie(counter, ns, db.nonce) } return regs, cookie, nil } func (db *DB) ValidCookie(ns string, cookie []byte) bool { - // XXX - return false + return validCookie(cookie, ns, db.nonce) } func (db *DB) background(ctx context.Context) { @@ -345,21 +347,97 @@ func (db *DB) background(ctx context.Context) { } func packAddrs(addrs [][]byte) []byte { - // XXX - return nil + packlen := 0 + for _, addr := range addrs { + packlen = packlen + 2 + len(addr) + } + + packed := make([]byte, packlen) + buf := packed + for _, addr := range addrs { + binary.BigEndian.PutUint16(buf, uint16(len(addr))) + buf = buf[2:] + copy(buf, addr) + buf = buf[len(addr):] + } + + return packed } -func unpackAddrs(maddrs []byte) ([][]byte, error) { - // XXX - return nil, errors.New("IMPLEMENTME: unpackAddrs") +func unpackAddrs(packed []byte) ([][]byte, error) { + var addrs [][]byte + + buf := packed + for len(buf) > 1 { + l := binary.BigEndian.Uint16(buf) + buf = buf[2:] + if len(buf) < int(l) { + return nil, fmt.Errorf("bad packed address: not enough bytes %v %v", packed, buf) + } + addr := make([]byte, l) + copy(addr, buf[:l]) + buf = buf[l:] + addrs = append(addrs, addr) + } + + if len(buf) > 0 { + return nil, fmt.Errorf("bad packed address: unprocessed bytes: %v %v", packed, buf) + } + + return addrs, nil } -func cookieToCounter(cookie []byte) (int64, error) { - // XXX - return 0, errors.New("IMPLEMENTME: cookieToCounter") +// cookie: counter:SHA256(nonce + ns + counter) +func packCookie(counter int64, ns string, nonce []byte) []byte { + cbits := make([]byte, 8) + binary.BigEndian.PutUint64(cbits, uint64(counter)) + + hash := sha256.New() + _, err := hash.Write(nonce) + if err != nil { + panic(err) + } + _, err = hash.Write([]byte(ns)) + if err != nil { + panic(err) + } + _, err = hash.Write(cbits) + if err != nil { + panic(err) + } + + return hash.Sum(cbits) } -func counterToCookie(counter int64, ns string, nonce []byte) []byte { - // XXX - return nil +func unpackCookie(cookie []byte) (int64, error) { + if len(cookie) < 8 { + return 0, fmt.Errorf("bad packed cookie: not enough bytes: %v", cookie) + } + + counter := binary.BigEndian.Uint64(cookie[:8]) + return int64(counter), nil +} + +func validCookie(cookie []byte, ns string, nonce []byte) bool { + if len(cookie) != 40 { + return false + } + + cbits := cookie[:8] + hash := sha256.New() + _, err := hash.Write(nonce) + if err != nil { + panic(err) + } + _, err = hash.Write([]byte(ns)) + if err != nil { + panic(err) + } + _, err = hash.Write(cbits) + if err != nil { + panic(err) + } + hbits := hash.Sum(nil) + + return bytes.Equal(cookie[8:], hbits) }