go-waku/waku/v2/rendezvous/db.go

391 lines
7.6 KiB
Go

package rendezvous
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/binary"
"errors"
"fmt"
"time"
"github.com/libp2p/go-libp2p/core/peer"
dbi "github.com/waku-org/go-libp2p-rendezvous/db"
"go.uber.org/zap"
)
type DB struct {
db *sql.DB
logger *zap.Logger
insertPeerRegistration *sql.Stmt
deletePeerRegistrations *sql.Stmt
deletePeerRegistrationsNs *sql.Stmt
countPeerRegistrations *sql.Stmt
selectPeerRegistrations *sql.Stmt
selectPeerRegistrationsNS *sql.Stmt
selectPeerRegistrationsC *sql.Stmt
selectPeerRegistrationsNSC *sql.Stmt
deleteExpiredRegistrations *sql.Stmt
getCounter *sql.Stmt
nonce []byte
cancel func()
}
func NewDB(ctx context.Context, db *sql.DB, logger *zap.Logger) *DB {
rdb := &DB{
db: db,
logger: logger.Named("rendezvous/db"),
}
return rdb
}
func (db *DB) Start(ctx context.Context) error {
err := db.loadNonce()
if err != nil {
db.Close()
return err
}
err = db.prepareStmts()
if err != nil {
db.Close()
return err
}
bgctx, cancel := context.WithCancel(ctx)
db.cancel = cancel
go db.background(bgctx)
return nil
}
func (db *DB) Close() error {
db.cancel()
return db.db.Close()
}
func (db *DB) insertNonce() error {
nonce := make([]byte, 32)
_, err := rand.Read(nonce)
if err != nil {
return err
}
_, err = db.db.Exec("INSERT INTO nonce VALUES (?)", nonce)
if err != nil {
return err
}
db.nonce = nonce
return nil
}
func (db *DB) loadNonce() error {
var nonce []byte
row := db.db.QueryRow("SELECT nonce FROM nonce")
err := row.Scan(&nonce)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return db.insertNonce()
}
return err
}
db.nonce = nonce
return nil
}
func (db *DB) prepareStmts() error {
stmt, err := db.db.Prepare("INSERT INTO registrations VALUES (NULL, ?, ?, ?, ?)")
if err != nil {
return err
}
db.insertPeerRegistration = stmt
stmt, err = db.db.Prepare("DELETE FROM registrations WHERE peer = ?")
if err != nil {
return err
}
db.deletePeerRegistrations = stmt
stmt, err = db.db.Prepare("DELETE FROM registrations WHERE peer = ? AND ns = ?")
if err != nil {
return err
}
db.deletePeerRegistrationsNs = stmt
stmt, err = db.db.Prepare("SELECT COUNT(*) FROM registrations WHERE peer = ?")
if err != nil {
return err
}
db.countPeerRegistrations = stmt
stmt, err = db.db.Prepare("SELECT * FROM registrations WHERE expire > ? LIMIT ?")
if err != nil {
return err
}
db.selectPeerRegistrations = stmt
stmt, err = db.db.Prepare("SELECT * FROM registrations WHERE ns = ? AND expire > ? LIMIT ?")
if err != nil {
return err
}
db.selectPeerRegistrationsNS = stmt
stmt, err = db.db.Prepare("SELECT * FROM registrations WHERE counter > ? AND expire > ? LIMIT ?")
if err != nil {
return err
}
db.selectPeerRegistrationsC = stmt
stmt, err = db.db.Prepare("SELECT * FROM registrations WHERE counter > ? AND ns = ? AND expire > ? LIMIT ?")
if err != nil {
return err
}
db.selectPeerRegistrationsNSC = stmt
stmt, err = db.db.Prepare("DELETE FROM registrations WHERE expire < ?")
if err != nil {
return err
}
db.deleteExpiredRegistrations = stmt
stmt, err = db.db.Prepare("SELECT MAX(counter) FROM registrations")
if err != nil {
return err
}
db.getCounter = stmt
return nil
}
func (db *DB) Register(p peer.ID, ns string, signedPeerRecord []byte, ttl int) (uint64, error) {
pid := p.Pretty()
expire := time.Now().Unix() + int64(ttl)
tx, err := db.db.Begin()
if err != nil {
return 0, err
}
delOld := tx.Stmt(db.deletePeerRegistrationsNs)
insertNew := tx.Stmt(db.insertPeerRegistration)
getCounter := tx.Stmt(db.getCounter)
_, err = delOld.Exec(pid, ns)
if err != nil {
_ = tx.Rollback()
return 0, err
}
_, err = insertNew.Exec(pid, ns, expire, signedPeerRecord)
if err != nil {
_ = tx.Rollback()
return 0, err
}
var counter uint64
row := getCounter.QueryRow()
err = row.Scan(&counter)
if err != nil {
_ = tx.Rollback()
return 0, err
}
err = tx.Commit()
return counter, err
}
func (db *DB) CountRegistrations(p peer.ID) (int, error) {
pid := p.Pretty()
row := db.countPeerRegistrations.QueryRow(pid)
var count int
err := row.Scan(&count)
return count, err
}
func (db *DB) Unregister(p peer.ID, ns string) error {
pid := p.Pretty()
var err error
if ns == "" {
_, err = db.deletePeerRegistrations.Exec(pid)
} else {
_, err = db.deletePeerRegistrationsNs.Exec(pid, ns)
}
return err
}
func (db *DB) Discover(ns string, cookie []byte, limit int) ([]dbi.RegistrationRecord, []byte, error) {
now := time.Now().Unix()
var (
counter int64
rows *sql.Rows
err error
)
if cookie != nil {
counter, err = unpackCookie(cookie)
if err != nil {
db.logger.Error("unpacking cookie", zap.Error(err))
return nil, nil, err
}
}
if counter > 0 {
if ns == "" {
rows, err = db.selectPeerRegistrationsC.Query(counter, now, limit)
} else {
rows, err = db.selectPeerRegistrationsNSC.Query(counter, ns, now, limit)
}
} else {
if ns == "" {
rows, err = db.selectPeerRegistrations.Query(now, limit)
} else {
rows, err = db.selectPeerRegistrationsNS.Query(ns, now, limit)
}
}
if err != nil {
db.logger.Error("query", zap.Error(err))
return nil, nil, err
}
defer rows.Close()
regs := make([]dbi.RegistrationRecord, 0, limit)
for rows.Next() {
var (
reg dbi.RegistrationRecord
rid string
rns string
expire int64
signedPeerRecord []byte
p peer.ID
)
err = rows.Scan(&counter, &rid, &rns, &expire, &signedPeerRecord)
if err != nil {
db.logger.Error("row scan error", zap.Error(err))
return nil, nil, err
}
p, err = peer.Decode(rid)
if err != nil {
db.logger.Error("error decoding peer id", zap.Error(err))
continue
}
reg.Id = p
reg.SignedPeerRecord = signedPeerRecord
reg.Ttl = int(expire - now)
if ns == "" {
reg.Ns = rns
}
regs = append(regs, reg)
}
err = rows.Err()
if err != nil {
return nil, nil, err
}
if counter > 0 {
cookie = packCookie(counter, ns, db.nonce)
}
return regs, cookie, nil
}
func (db *DB) ValidCookie(ns string, cookie []byte) bool {
return validCookie(cookie, ns, db.nonce)
}
func (db *DB) background(ctx context.Context) {
for {
db.cleanupExpired()
select {
case <-time.After(15 * time.Minute):
case <-ctx.Done():
return
}
}
}
func (db *DB) cleanupExpired() {
now := time.Now().Unix()
_, err := db.deleteExpiredRegistrations.Exec(now)
if err != nil {
db.logger.Error("deleting expired registrations", zap.Error(err))
}
}
// cookie: counter:SHA256(nonce + ns + counter)
func packCookie(counter int64, ns string, nonce []byte) []byte {
cbits := make([]byte, 8)
binary.BigEndian.PutUint64(cbits, uint64(counter))
hash := sha256.New()
_, err := hash.Write(nonce)
if err != nil {
panic(err)
}
_, err = hash.Write([]byte(ns))
if err != nil {
panic(err)
}
_, err = hash.Write(cbits)
if err != nil {
panic(err)
}
return hash.Sum(cbits)
}
func unpackCookie(cookie []byte) (int64, error) {
if len(cookie) < 8 {
return 0, fmt.Errorf("bad packed cookie: not enough bytes: %v", cookie)
}
counter := binary.BigEndian.Uint64(cookie[:8])
return int64(counter), nil
}
func validCookie(cookie []byte, ns string, nonce []byte) bool {
if len(cookie) != 40 {
return false
}
cbits := cookie[:8]
hash := sha256.New()
_, err := hash.Write(nonce)
if err != nil {
panic(err)
}
_, err = hash.Write([]byte(ns))
if err != nil {
panic(err)
}
_, err = hash.Write(cbits)
if err != nil {
panic(err)
}
hbits := hash.Sum(nil)
return bytes.Equal(cookie[8:], hbits)
}