diff --git a/db/dbi.go b/db/dbi.go index b6fe8f6..3568877 100644 --- a/db/dbi.go +++ b/db/dbi.go @@ -13,7 +13,7 @@ type RegistrationRecord struct { type DB interface { Close() error - Register(p peer.ID, ns string, addrs [][]byte, ttl int) error + Register(p peer.ID, ns string, addrs [][]byte, ttl int) (uint64, error) Unregister(p peer.ID, ns string) error CountRegistrations(p peer.ID) (int, error) Discover(ns string, cookie []byte, limit int) ([]RegistrationRecord, []byte, error) diff --git a/db/sqlite/db.go b/db/sqlite/db.go index 356ed68..1399df2 100644 --- a/db/sqlite/db.go +++ b/db/sqlite/db.go @@ -33,6 +33,7 @@ type DB struct { selectPeerRegistrationsC *sql.Stmt selectPeerRegistrationsNSC *sql.Stmt deleteExpiredRegistrations *sql.Stmt + getCounter *sql.Stmt nonce []byte @@ -189,35 +190,51 @@ func (db *DB) prepareStmts() error { } db.deleteExpiredRegistrations = stmt + stmt, err = db.db.Prepare("SELECT MAX(counter) FROM Registrations") + if err != nil { + return err + } + db.getCounter = stmt + return nil } -func (db *DB) Register(p peer.ID, ns string, addrs [][]byte, ttl int) error { +func (db *DB) Register(p peer.ID, ns string, addrs [][]byte, ttl int) (uint64, error) { pid := p.Pretty() maddrs := packAddrs(addrs) expire := time.Now().Unix() + int64(ttl) tx, err := db.db.Begin() if err != nil { - return err + return 0, err } delOld := tx.Stmt(db.deletePeerRegistrationsNs) insertNew := tx.Stmt(db.insertPeerRegistration) + getCounter := tx.Stmt(db.getCounter) _, err = delOld.Exec(pid, ns) if err != nil { tx.Rollback() - return err + return 0, err } _, err = insertNew.Exec(pid, ns, expire, maddrs) if err != nil { tx.Rollback() - return err + return 0, err } - return tx.Commit() + var counter uint64 + row := getCounter.QueryRow() + err = row.Scan(&counter) + if err != nil { + tx.Rollback() + return 0, err + } + + err = tx.Commit() + return counter, err } func (db *DB) CountRegistrations(p peer.ID) (int, error) { diff --git a/svc.go b/svc.go index 0be1f07..0293765 100644 --- a/svc.go +++ b/svc.go @@ -26,7 +26,7 @@ type RendezvousService struct { } type RendezvousSync interface { - Register(p peer.ID, ns string, addrs [][]byte, ttl int) + Register(p peer.ID, ns string, addrs [][]byte, ttl int, counter uint64) Unregister(p peer.ID, ns string) } @@ -160,7 +160,7 @@ func (rz *RendezvousService) handleRegister(p peer.ID, m *pb.Message_Register) * } // ok, seems like we can register - err = rz.DB.Register(p, ns, maddrs, ttl) + counter, err := rz.DB.Register(p, ns, maddrs, ttl) if err != nil { log.Errorf("Error registering: %s", err.Error()) return newRegisterResponseError(pb.Message_E_INTERNAL_ERROR, "database error") @@ -169,7 +169,7 @@ func (rz *RendezvousService) handleRegister(p peer.ID, m *pb.Message_Register) * log.Infof("registered peer %s %s (%d)", p, ns, ttl) for _, rzs := range rz.rzs { - rzs.Register(p, ns, maddrs, ttl) + rzs.Register(p, ns, maddrs, ttl, counter) } return newRegisterResponse(ttl)