diff --git a/db.go b/db.go index cfee9cb..117445d 100644 --- a/db.go +++ b/db.go @@ -2,34 +2,364 @@ package rendezvous import ( "context" + "crypto/rand" + "database/sql" "errors" + "os" + "time" + + _ "github.com/mattn/go-sqlite3" peer "github.com/libp2p/go-libp2p-peer" ) type DB struct { + db *sql.DB + + insertPeerRegistration *sql.Stmt + deletePeerRegistrations *sql.Stmt + deletePeerRegistrationsNs *sql.Stmt + countPeerRegistrations *sql.Stmt + selectPeerRegistrations *sql.Stmt + selectPeerRegistrationsNS *sql.Stmt + selectPeerRegistrationsC *sql.Stmt + selectPeerRegistrationsNSC *sql.Stmt + deleteExpiredRegistrations *sql.Stmt + + nonce []byte + + cancel func() } func OpenDB(ctx context.Context, path string) (*DB, error) { - return nil, errors.New("IMPLEMENTME: OpenDB") + var create bool + if path == ":memory:" { + create = true + } else { + _, err := os.Stat(path) + switch { + case os.IsNotExist(err): + create = true + case err != nil: + return nil, err + } + } + + db, err := sql.Open("sqlite3", path) + if err != nil { + return nil, err + } + + rdb := &DB{db: db} + if create { + err = rdb.prepareDB() + if err != nil { + db.Close() + return nil, err + } + } else { + err = rdb.loadNonce() + if err != nil { + db.Close() + return nil, err + } + } + + err = rdb.prepareStmts() + if err != nil { + db.Close() + return nil, err + } + + bgctx, cancel := context.WithCancel(ctx) + rdb.cancel = cancel + go rdb.background(bgctx) + + return rdb, nil +} + +func (db *DB) Close() error { + db.cancel() + return db.db.Close() +} + +func (db *DB) prepareDB() error { + _, err := db.db.Exec("CREATE TABLE Registrations (counter INTEGER PRIMARY KEY AUTOINCREMENT, peer VARCHAR(64), ns VARCHAR, expire INTEGER, addrs VARBINARY)") + if err != nil { + return err + } + + _, err = db.db.Exec("CREATE TABLE Nonce (nonce VARBINARY)") + if err != nil { + return err + } + + nonce := make([]byte, 16) + _, err = rand.Read(nonce) + if err != nil { + return err + } + + _, err = db.db.Exec("INSERT INTO Nonce VALUES (?)", nonce) + if err != nil { + return err + } + + db.nonce = nonce + return nil +} + +func (db *DB) loadNonce() error { + var nonce []byte + row := db.db.QueryRow("SELECT nonce FROM Nonce") + err := row.Scan(&nonce) + if err != nil { + return err + } + + db.nonce = nonce + return nil +} + +func (db *DB) prepareStmts() error { + stmt, err := db.db.Prepare("INSERT INTO Registrations VALUES (NULL, ?, ?, ?, ?)") + if err != nil { + return err + } + db.insertPeerRegistration = stmt + + stmt, err = db.db.Prepare("DELETE FROM Registrations WHERE peer = ?") + if err != nil { + return err + } + db.deletePeerRegistrations = stmt + + stmt, err = db.db.Prepare("DELETE FROM Registrations WHERE peer = ? AND ns = ?") + if err != nil { + return err + } + db.deletePeerRegistrationsNs = stmt + + stmt, err = db.db.Prepare("SELECT COUNT(*) FROM Registrations WHERE peer = ?") + if err != nil { + return err + } + db.countPeerRegistrations = stmt + + stmt, err = db.db.Prepare("SELECT * FROM Registrations WHERE expire > ? LIMIT ?") + if err != nil { + return err + } + db.selectPeerRegistrations = stmt + + stmt, err = db.db.Prepare("SELECT * FROM Registrations WHERE ns = ? AND expire > ? LIMIT ?") + if err != nil { + return err + } + db.selectPeerRegistrationsNS = stmt + + stmt, err = db.db.Prepare("SELECT * FROM Registrations WHERE counter > ? AND expire > ? LIMIT ?") + if err != nil { + return err + } + db.selectPeerRegistrationsC = stmt + + stmt, err = db.db.Prepare("SELECT * FROM Registrations WHERE counter > ? AND ns = ? AND expire > ? LIMIT ?") + if err != nil { + return err + } + db.selectPeerRegistrationsNSC = stmt + + stmt, err = db.db.Prepare("DELETE FROM Registrations WHERE expire < ?") + if err != nil { + return err + } + db.deleteExpiredRegistrations = stmt + + return nil } func (db *DB) Register(p peer.ID, ns string, addrs [][]byte, ttl int) error { - return errors.New("IMPLEMENTME: DB.Register") + pid := p.Pretty() + maddrs := packAddrs(addrs) + expire := time.Now().Unix() + int64(ttl) + + tx, err := db.db.Begin() + if err != nil { + return err + } + + delOld := tx.Stmt(db.deletePeerRegistrationsNs) + insertNew := tx.Stmt(db.insertPeerRegistration) + + _, err = delOld.Exec(pid, ns) + if err != nil { + tx.Rollback() + return err + } + + _, err = insertNew.Exec(pid, ns, expire, maddrs) + if err != nil { + tx.Rollback() + return err + } + + return tx.Commit() } func (db *DB) CountRegistrations(p peer.ID) (int, error) { - return 0, errors.New("IMPLEMENTME: DB.CountRegistrations") + pid := p.Pretty() + + row := db.countPeerRegistrations.QueryRow(pid) + + var count int + err := row.Scan(&count) + + return count, err } func (db *DB) Unregister(p peer.ID, ns string) error { - return errors.New("IMPLEMENTME: DB.Unregister") -} + pid := p.Pretty() -func (db *DB) ValidCookie(ns string, cookie []byte) bool { - return false + var err error + + if ns == "" { + _, err = db.deletePeerRegistrations.Exec(pid) + } else { + _, err = db.deletePeerRegistrationsNs.Exec(pid, ns) + } + + return err } func (db *DB) Discover(ns string, cookie []byte, limit int) ([]RegistrationRecord, []byte, error) { - return nil, nil, errors.New("IMPLEMENTME: DB.Discover") + now := time.Now().Unix() + + var ( + counter int64 + rows *sql.Rows + err error + ) + + if cookie != nil { + counter, err = cookieToCounter(cookie) + if err != nil { + log.Errorf("error unpacking cookie: %s", err.Error()) + return nil, nil, err + } + } + + if counter > 0 { + if ns == "" { + rows, err = db.selectPeerRegistrationsC.Query(counter, now, limit) + } else { + rows, err = db.selectPeerRegistrationsNSC.Query(counter, ns, now, limit) + } + } else { + if ns == "" { + rows, err = db.selectPeerRegistrations.Query(now, limit) + } else { + rows, err = db.selectPeerRegistrations.Query(ns, now, limit) + } + } + + if err != nil { + log.Errorf("query error: %s", err.Error()) + return nil, nil, err + } + + defer rows.Close() + + regs := make([]RegistrationRecord, 0, limit) + for rows.Next() { + var ( + reg RegistrationRecord + rid string + rns string + expire int64 + raddrs []byte + addrs [][]byte + p peer.ID + ) + + err = rows.Scan(&counter, &rid, &rns, &expire, &raddrs) + if err != nil { + log.Errorf("row scan error: %s", err.Error()) + return nil, nil, err + } + + p, err = peer.IDB58Decode(rid) + if err != nil { + log.Errorf("error decoding peer id: %s", err.Error()) + continue + } + reg.Id = p + + addrs, err = unpackAddrs(raddrs) + if err != nil { + log.Errorf("error unpacking address: %s", err.Error()) + continue + } + reg.Addrs = addrs + + reg.Ttl = int(expire - now) + + if ns == "" { + reg.Ns = rns + } + + regs = append(regs, reg) + } + + err = rows.Err() + if err != nil { + return nil, nil, err + } + + if counter > 0 { + cookie = counterToCookie(counter, ns, db.nonce) + } + + return regs, cookie, nil +} + +func (db *DB) ValidCookie(ns string, cookie []byte) bool { + // XXX + return false +} + +func (db *DB) background(ctx context.Context) { + for { + now := time.Now().Unix() + _, err := db.deleteExpiredRegistrations.Exec(now) + if err != nil { + log.Errorf("error deleting expired registrations: %s", err.Error()) + } + + select { + case <-time.After(15 * time.Minute): + case <-ctx.Done(): + return + } + } +} + +func packAddrs(addrs [][]byte) []byte { + // XXX + return nil +} + +func unpackAddrs(maddrs []byte) ([][]byte, error) { + // XXX + return nil, errors.New("IMPLEMENTME: unpackAddrs") +} + +func cookieToCounter(cookie []byte) (int64, error) { + // XXX + return 0, errors.New("IMPLEMENTME: cookieToCounter") +} + +func counterToCookie(counter int64, ns string, nonce []byte) []byte { + // XXX + return nil }