2018-04-26 14:06:58 +03:00
|
|
|
package db
|
2018-04-23 11:54:25 +03:00
|
|
|
|
|
|
|
|
import (
|
2018-04-24 12:00:37 +03:00
|
|
|
"bytes"
|
2018-04-23 11:54:25 +03:00
|
|
|
"context"
|
2018-04-23 22:28:16 +03:00
|
|
|
"crypto/rand"
|
2018-04-24 12:00:37 +03:00
|
|
|
"crypto/sha256"
|
2018-04-23 22:28:16 +03:00
|
|
|
"database/sql"
|
2018-04-24 12:00:37 +03:00
|
|
|
"encoding/binary"
|
|
|
|
|
"fmt"
|
2018-04-23 22:28:16 +03:00
|
|
|
"os"
|
|
|
|
|
"time"
|
|
|
|
|
|
2023-06-01 11:50:48 -04:00
|
|
|
dbi "github.com/waku-org/go-libp2p-rendezvous/db"
|
2018-04-26 14:06:58 +03:00
|
|
|
|
2018-04-23 22:28:16 +03:00
|
|
|
_ "github.com/mattn/go-sqlite3"
|
2018-04-23 11:54:25 +03:00
|
|
|
|
2020-10-28 14:39:21 +01:00
|
|
|
logging "github.com/ipfs/go-log/v2"
|
2022-11-08 16:57:44 +01:00
|
|
|
"github.com/libp2p/go-libp2p/core/peer"
|
2018-04-23 11:54:25 +03:00
|
|
|
)
|
|
|
|
|
|
2018-04-26 14:06:58 +03:00
|
|
|
// log is the package-level logger for this database backend, named
// "rendezvous/db".
var log = logging.Logger("rendezvous/db")
|
|
|
|
|
|
2018-04-23 11:54:25 +03:00
|
|
|
// DB is a SQLite-backed store of rendezvous registrations. Values are created
// by OpenDB and released with Close.
type DB struct {
	// db is the underlying database handle.
	db *sql.DB

	// nonce is a random value (written by prepareDB, read back by loadNonce)
	// that is mixed into the hash of every discovery cookie.
	nonce []byte

	// cancel stops the context driving the background expiry-cleanup loop.
	cancel func()

	// Prepared statements that modify the Registrations table.
	insertPeerRegistration     *sql.Stmt
	deletePeerRegistrations    *sql.Stmt // by peer, every namespace
	deletePeerRegistrationsNs  *sql.Stmt // by peer and namespace
	deleteExpiredRegistrations *sql.Stmt

	// Prepared statements that read the Registrations table. The "NS"
	// variants filter by namespace; the "C" variants resume after a counter.
	countPeerRegistrations     *sql.Stmt
	selectPeerRegistrations    *sql.Stmt
	selectPeerRegistrationsNS  *sql.Stmt
	selectPeerRegistrationsC   *sql.Stmt
	selectPeerRegistrationsNSC *sql.Stmt
	getCounter                 *sql.Stmt
}
|
|
|
|
|
|
|
|
|
|
func OpenDB(ctx context.Context, path string) (*DB, error) {
|
2018-04-23 22:28:16 +03:00
|
|
|
var create bool
|
|
|
|
|
if path == ":memory:" {
|
|
|
|
|
create = true
|
|
|
|
|
} else {
|
|
|
|
|
_, err := os.Stat(path)
|
|
|
|
|
switch {
|
|
|
|
|
case os.IsNotExist(err):
|
|
|
|
|
create = true
|
|
|
|
|
case err != nil:
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
db, err := sql.Open("sqlite3", path)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
2018-04-24 14:08:17 +03:00
|
|
|
if path == ":memory:" {
|
|
|
|
|
// this is necessary to avoid creating a new database on each connection
|
|
|
|
|
db.SetMaxOpenConns(1)
|
|
|
|
|
}
|
|
|
|
|
|
2018-04-23 22:28:16 +03:00
|
|
|
rdb := &DB{db: db}
|
|
|
|
|
if create {
|
|
|
|
|
err = rdb.prepareDB()
|
|
|
|
|
if err != nil {
|
|
|
|
|
db.Close()
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
err = rdb.loadNonce()
|
|
|
|
|
if err != nil {
|
|
|
|
|
db.Close()
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err = rdb.prepareStmts()
|
|
|
|
|
if err != nil {
|
|
|
|
|
db.Close()
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bgctx, cancel := context.WithCancel(ctx)
|
|
|
|
|
rdb.cancel = cancel
|
|
|
|
|
go rdb.background(bgctx)
|
|
|
|
|
|
|
|
|
|
return rdb, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (db *DB) Close() error {
|
|
|
|
|
db.cancel()
|
|
|
|
|
return db.db.Close()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (db *DB) prepareDB() error {
|
2023-06-01 11:50:48 -04:00
|
|
|
_, err := db.db.Exec("CREATE TABLE Registrations (counter INTEGER PRIMARY KEY AUTOINCREMENT, peer VARCHAR(64), ns VARCHAR, expire INTEGER, signedPeerRecord VARBINARY)")
|
2018-04-23 22:28:16 +03:00
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, err = db.db.Exec("CREATE TABLE Nonce (nonce VARBINARY)")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
2018-04-24 18:23:10 +03:00
|
|
|
nonce := make([]byte, 32)
|
2018-04-23 22:28:16 +03:00
|
|
|
_, err = rand.Read(nonce)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, err = db.db.Exec("INSERT INTO Nonce VALUES (?)", nonce)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
db.nonce = nonce
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (db *DB) loadNonce() error {
|
|
|
|
|
var nonce []byte
|
|
|
|
|
row := db.db.QueryRow("SELECT nonce FROM Nonce")
|
|
|
|
|
err := row.Scan(&nonce)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
db.nonce = nonce
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (db *DB) prepareStmts() error {
|
|
|
|
|
stmt, err := db.db.Prepare("INSERT INTO Registrations VALUES (NULL, ?, ?, ?, ?)")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
db.insertPeerRegistration = stmt
|
|
|
|
|
|
|
|
|
|
stmt, err = db.db.Prepare("DELETE FROM Registrations WHERE peer = ?")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
db.deletePeerRegistrations = stmt
|
|
|
|
|
|
|
|
|
|
stmt, err = db.db.Prepare("DELETE FROM Registrations WHERE peer = ? AND ns = ?")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
db.deletePeerRegistrationsNs = stmt
|
|
|
|
|
|
|
|
|
|
stmt, err = db.db.Prepare("SELECT COUNT(*) FROM Registrations WHERE peer = ?")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
db.countPeerRegistrations = stmt
|
|
|
|
|
|
|
|
|
|
stmt, err = db.db.Prepare("SELECT * FROM Registrations WHERE expire > ? LIMIT ?")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
db.selectPeerRegistrations = stmt
|
|
|
|
|
|
|
|
|
|
stmt, err = db.db.Prepare("SELECT * FROM Registrations WHERE ns = ? AND expire > ? LIMIT ?")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
db.selectPeerRegistrationsNS = stmt
|
|
|
|
|
|
|
|
|
|
stmt, err = db.db.Prepare("SELECT * FROM Registrations WHERE counter > ? AND expire > ? LIMIT ?")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
db.selectPeerRegistrationsC = stmt
|
|
|
|
|
|
|
|
|
|
stmt, err = db.db.Prepare("SELECT * FROM Registrations WHERE counter > ? AND ns = ? AND expire > ? LIMIT ?")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
db.selectPeerRegistrationsNSC = stmt
|
|
|
|
|
|
|
|
|
|
stmt, err = db.db.Prepare("DELETE FROM Registrations WHERE expire < ?")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
db.deleteExpiredRegistrations = stmt
|
|
|
|
|
|
2019-01-18 15:46:02 +02:00
|
|
|
stmt, err = db.db.Prepare("SELECT MAX(counter) FROM Registrations")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
db.getCounter = stmt
|
|
|
|
|
|
2018-04-23 22:28:16 +03:00
|
|
|
return nil
|
2018-04-23 11:54:25 +03:00
|
|
|
}
|
|
|
|
|
|
2023-06-01 11:50:48 -04:00
|
|
|
func (db *DB) Register(p peer.ID, ns string, signedPeerRecord []byte, ttl int) (uint64, error) {
|
2018-04-23 22:28:16 +03:00
|
|
|
pid := p.Pretty()
|
|
|
|
|
expire := time.Now().Unix() + int64(ttl)
|
|
|
|
|
|
|
|
|
|
tx, err := db.db.Begin()
|
|
|
|
|
if err != nil {
|
2019-01-18 15:46:02 +02:00
|
|
|
return 0, err
|
2018-04-23 22:28:16 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
delOld := tx.Stmt(db.deletePeerRegistrationsNs)
|
|
|
|
|
insertNew := tx.Stmt(db.insertPeerRegistration)
|
2019-01-18 15:46:02 +02:00
|
|
|
getCounter := tx.Stmt(db.getCounter)
|
2018-04-23 22:28:16 +03:00
|
|
|
|
|
|
|
|
_, err = delOld.Exec(pid, ns)
|
|
|
|
|
if err != nil {
|
|
|
|
|
tx.Rollback()
|
2019-01-18 15:46:02 +02:00
|
|
|
return 0, err
|
2018-04-23 22:28:16 +03:00
|
|
|
}
|
|
|
|
|
|
2023-06-01 11:50:48 -04:00
|
|
|
_, err = insertNew.Exec(pid, ns, expire, signedPeerRecord)
|
2018-04-23 22:28:16 +03:00
|
|
|
if err != nil {
|
|
|
|
|
tx.Rollback()
|
2019-01-18 15:46:02 +02:00
|
|
|
return 0, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var counter uint64
|
|
|
|
|
row := getCounter.QueryRow()
|
|
|
|
|
err = row.Scan(&counter)
|
|
|
|
|
if err != nil {
|
|
|
|
|
tx.Rollback()
|
|
|
|
|
return 0, err
|
2018-04-23 22:28:16 +03:00
|
|
|
}
|
|
|
|
|
|
2019-01-18 15:46:02 +02:00
|
|
|
err = tx.Commit()
|
|
|
|
|
return counter, err
|
2018-04-23 11:54:25 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (db *DB) CountRegistrations(p peer.ID) (int, error) {
|
2018-04-23 22:28:16 +03:00
|
|
|
pid := p.Pretty()
|
|
|
|
|
|
|
|
|
|
row := db.countPeerRegistrations.QueryRow(pid)
|
|
|
|
|
|
|
|
|
|
var count int
|
|
|
|
|
err := row.Scan(&count)
|
|
|
|
|
|
|
|
|
|
return count, err
|
2018-04-23 11:54:25 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (db *DB) Unregister(p peer.ID, ns string) error {
|
2018-04-23 22:28:16 +03:00
|
|
|
pid := p.Pretty()
|
|
|
|
|
|
|
|
|
|
var err error
|
|
|
|
|
|
|
|
|
|
if ns == "" {
|
|
|
|
|
_, err = db.deletePeerRegistrations.Exec(pid)
|
|
|
|
|
} else {
|
|
|
|
|
_, err = db.deletePeerRegistrationsNs.Exec(pid, ns)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
2018-04-26 14:06:58 +03:00
|
|
|
func (db *DB) Discover(ns string, cookie []byte, limit int) ([]dbi.RegistrationRecord, []byte, error) {
|
2018-04-23 22:28:16 +03:00
|
|
|
now := time.Now().Unix()
|
|
|
|
|
|
|
|
|
|
var (
|
|
|
|
|
counter int64
|
|
|
|
|
rows *sql.Rows
|
|
|
|
|
err error
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if cookie != nil {
|
2018-04-24 12:00:37 +03:00
|
|
|
counter, err = unpackCookie(cookie)
|
2018-04-23 22:28:16 +03:00
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("error unpacking cookie: %s", err.Error())
|
|
|
|
|
return nil, nil, err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if counter > 0 {
|
|
|
|
|
if ns == "" {
|
|
|
|
|
rows, err = db.selectPeerRegistrationsC.Query(counter, now, limit)
|
|
|
|
|
} else {
|
|
|
|
|
rows, err = db.selectPeerRegistrationsNSC.Query(counter, ns, now, limit)
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if ns == "" {
|
|
|
|
|
rows, err = db.selectPeerRegistrations.Query(now, limit)
|
|
|
|
|
} else {
|
2018-04-24 14:08:17 +03:00
|
|
|
rows, err = db.selectPeerRegistrationsNS.Query(ns, now, limit)
|
2018-04-23 22:28:16 +03:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("query error: %s", err.Error())
|
|
|
|
|
return nil, nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
defer rows.Close()
|
|
|
|
|
|
2018-04-26 14:06:58 +03:00
|
|
|
regs := make([]dbi.RegistrationRecord, 0, limit)
|
2018-04-23 22:28:16 +03:00
|
|
|
for rows.Next() {
|
|
|
|
|
var (
|
2023-06-01 11:50:48 -04:00
|
|
|
reg dbi.RegistrationRecord
|
|
|
|
|
rid string
|
|
|
|
|
rns string
|
|
|
|
|
expire int64
|
|
|
|
|
signedPeerRecord []byte
|
|
|
|
|
p peer.ID
|
2018-04-23 22:28:16 +03:00
|
|
|
)
|
|
|
|
|
|
2023-06-01 11:50:48 -04:00
|
|
|
err = rows.Scan(&counter, &rid, &rns, &expire, &signedPeerRecord)
|
2018-04-23 22:28:16 +03:00
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("row scan error: %s", err.Error())
|
|
|
|
|
return nil, nil, err
|
|
|
|
|
}
|
|
|
|
|
|
2021-09-15 15:29:04 +02:00
|
|
|
p, err = peer.Decode(rid)
|
2018-04-23 22:28:16 +03:00
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("error decoding peer id: %s", err.Error())
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
|
2018-04-24 12:00:37 +03:00
|
|
|
reg.Id = p
|
2023-06-01 11:50:48 -04:00
|
|
|
reg.SignedPeerRecord = signedPeerRecord
|
2018-04-23 22:28:16 +03:00
|
|
|
reg.Ttl = int(expire - now)
|
|
|
|
|
|
|
|
|
|
if ns == "" {
|
|
|
|
|
reg.Ns = rns
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
regs = append(regs, reg)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err = rows.Err()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if counter > 0 {
|
2018-04-24 12:00:37 +03:00
|
|
|
cookie = packCookie(counter, ns, db.nonce)
|
2018-04-23 22:28:16 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return regs, cookie, nil
|
2018-04-23 11:54:25 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (db *DB) ValidCookie(ns string, cookie []byte) bool {
|
2018-04-24 12:00:37 +03:00
|
|
|
return validCookie(cookie, ns, db.nonce)
|
2018-04-23 11:54:25 +03:00
|
|
|
}
|
|
|
|
|
|
2018-04-23 22:28:16 +03:00
|
|
|
func (db *DB) background(ctx context.Context) {
|
|
|
|
|
for {
|
2018-04-24 14:08:17 +03:00
|
|
|
db.cleanupExpired()
|
2018-04-23 22:28:16 +03:00
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
case <-time.After(15 * time.Minute):
|
|
|
|
|
case <-ctx.Done():
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2018-04-24 14:08:17 +03:00
|
|
|
func (db *DB) cleanupExpired() {
|
|
|
|
|
now := time.Now().Unix()
|
|
|
|
|
_, err := db.deleteExpiredRegistrations.Exec(now)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("error deleting expired registrations: %s", err.Error())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2018-04-24 12:00:37 +03:00
|
|
|
// packCookie builds a 40-byte discovery cookie:
//
//	counter (8 bytes, big-endian) || SHA256(nonce || ns || counter bytes)
//
// The digest ties the cookie to this database's nonce and to namespace ns.
func packCookie(counter int64, ns string, nonce []byte) []byte {
	var counterBytes [8]byte
	binary.BigEndian.PutUint64(counterBytes[:], uint64(counter))

	msg := make([]byte, 0, len(nonce)+len(ns)+len(counterBytes))
	msg = append(msg, nonce...)
	msg = append(msg, ns...)
	msg = append(msg, counterBytes[:]...)
	digest := sha256.Sum256(msg)

	cookie := make([]byte, 0, len(counterBytes)+len(digest))
	cookie = append(cookie, counterBytes[:]...)
	return append(cookie, digest[:]...)
}
|
|
|
|
|
|
2018-04-24 12:00:37 +03:00
|
|
|
// unpackCookie extracts the big-endian counter stored in the first 8 bytes of
// cookie. It does not check the cookie's hash (see validCookie).
func unpackCookie(cookie []byte) (int64, error) {
	const counterLen = 8
	if len(cookie) >= counterLen {
		return int64(binary.BigEndian.Uint64(cookie[:counterLen])), nil
	}
	return 0, fmt.Errorf("bad packed cookie: not enough bytes: %v", cookie)
}
|
|
|
|
|
|
|
|
|
|
// validCookie reports whether cookie has the 40-byte layout produced by
// packCookie (8 counter bytes followed by a SHA-256 digest) and whether that
// digest equals SHA256(nonce || ns || counter bytes).
func validCookie(cookie []byte, ns string, nonce []byte) bool {
	const counterLen = 8
	if len(cookie) != counterLen+sha256.Size {
		return false
	}
	counterBytes, digest := cookie[:counterLen], cookie[counterLen:]

	msg := make([]byte, 0, len(nonce)+len(ns)+counterLen)
	msg = append(msg, nonce...)
	msg = append(msg, ns...)
	msg = append(msg, counterBytes...)
	want := sha256.Sum256(msg)

	return bytes.Equal(digest, want[:])
}
|