expose counter in register interface

This commit is contained in:
vyzo 2019-01-18 15:46:02 +02:00
parent 3c726d2ea9
commit f2ee9b3d44
3 changed files with 26 additions and 9 deletions

View File

@ -13,7 +13,7 @@ type RegistrationRecord struct {
type DB interface { type DB interface {
Close() error Close() error
Register(p peer.ID, ns string, addrs [][]byte, ttl int) error Register(p peer.ID, ns string, addrs [][]byte, ttl int) (uint64, error)
Unregister(p peer.ID, ns string) error Unregister(p peer.ID, ns string) error
CountRegistrations(p peer.ID) (int, error) CountRegistrations(p peer.ID) (int, error)
Discover(ns string, cookie []byte, limit int) ([]RegistrationRecord, []byte, error) Discover(ns string, cookie []byte, limit int) ([]RegistrationRecord, []byte, error)

View File

@ -33,6 +33,7 @@ type DB struct {
selectPeerRegistrationsC *sql.Stmt selectPeerRegistrationsC *sql.Stmt
selectPeerRegistrationsNSC *sql.Stmt selectPeerRegistrationsNSC *sql.Stmt
deleteExpiredRegistrations *sql.Stmt deleteExpiredRegistrations *sql.Stmt
getCounter *sql.Stmt
nonce []byte nonce []byte
@ -189,35 +190,51 @@ func (db *DB) prepareStmts() error {
} }
db.deleteExpiredRegistrations = stmt db.deleteExpiredRegistrations = stmt
stmt, err = db.db.Prepare("SELECT MAX(counter) FROM Registrations")
if err != nil {
return err
}
db.getCounter = stmt
return nil return nil
} }
func (db *DB) Register(p peer.ID, ns string, addrs [][]byte, ttl int) error { func (db *DB) Register(p peer.ID, ns string, addrs [][]byte, ttl int) (uint64, error) {
pid := p.Pretty() pid := p.Pretty()
maddrs := packAddrs(addrs) maddrs := packAddrs(addrs)
expire := time.Now().Unix() + int64(ttl) expire := time.Now().Unix() + int64(ttl)
tx, err := db.db.Begin() tx, err := db.db.Begin()
if err != nil { if err != nil {
return err return 0, err
} }
delOld := tx.Stmt(db.deletePeerRegistrationsNs) delOld := tx.Stmt(db.deletePeerRegistrationsNs)
insertNew := tx.Stmt(db.insertPeerRegistration) insertNew := tx.Stmt(db.insertPeerRegistration)
getCounter := tx.Stmt(db.getCounter)
_, err = delOld.Exec(pid, ns) _, err = delOld.Exec(pid, ns)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return 0, err
} }
_, err = insertNew.Exec(pid, ns, expire, maddrs) _, err = insertNew.Exec(pid, ns, expire, maddrs)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return 0, err
} }
return tx.Commit() var counter uint64
row := getCounter.QueryRow()
err = row.Scan(&counter)
if err != nil {
tx.Rollback()
return 0, err
}
err = tx.Commit()
return counter, err
} }
func (db *DB) CountRegistrations(p peer.ID) (int, error) { func (db *DB) CountRegistrations(p peer.ID) (int, error) {

6
svc.go
View File

@ -26,7 +26,7 @@ type RendezvousService struct {
} }
type RendezvousSync interface { type RendezvousSync interface {
Register(p peer.ID, ns string, addrs [][]byte, ttl int) Register(p peer.ID, ns string, addrs [][]byte, ttl int, counter uint64)
Unregister(p peer.ID, ns string) Unregister(p peer.ID, ns string)
} }
@ -160,7 +160,7 @@ func (rz *RendezvousService) handleRegister(p peer.ID, m *pb.Message_Register) *
} }
// ok, seems like we can register // ok, seems like we can register
err = rz.DB.Register(p, ns, maddrs, ttl) counter, err := rz.DB.Register(p, ns, maddrs, ttl)
if err != nil { if err != nil {
log.Errorf("Error registering: %s", err.Error()) log.Errorf("Error registering: %s", err.Error())
return newRegisterResponseError(pb.Message_E_INTERNAL_ERROR, "database error") return newRegisterResponseError(pb.Message_E_INTERNAL_ERROR, "database error")
@ -169,7 +169,7 @@ func (rz *RendezvousService) handleRegister(p peer.ID, m *pb.Message_Register) *
log.Infof("registered peer %s %s (%d)", p, ns, ttl) log.Infof("registered peer %s %s (%d)", p, ns, ttl)
for _, rzs := range rz.rzs { for _, rzs := range rz.rzs {
rzs.Register(p, ns, maddrs, ttl) rzs.Register(p, ns, maddrs, ttl, counter)
} }
return newRegisterResponse(ttl) return newRegisterResponse(ttl)