// Package persistence provides a SQL-backed store for Waku messages.
package persistence
2021-04-12 13:59:41 -04:00
import (
2023-03-09 11:48:25 -04:00
"context"
2021-04-12 13:59:41 -04:00
"database/sql"
"errors"
"fmt"
"strings"
"sync"
"time"
2021-04-12 13:59:41 -04:00
"github.com/waku-org/go-waku/waku/v2/protocol"
wpb "github.com/waku-org/go-waku/waku/v2/protocol/pb"
"github.com/waku-org/go-waku/waku/v2/protocol/store/pb"
2022-12-08 23:08:04 -04:00
"github.com/waku-org/go-waku/waku/v2/timesource"
"github.com/waku-org/go-waku/waku/v2/utils"
"go.uber.org/zap"
2021-04-12 13:59:41 -04:00
)
2021-10-25 15:41:08 -04:00
type MessageProvider interface {
GetAll() ([]StoredMessage, error)
Put(env *protocol.Envelope) error
Query(query *pb.HistoryQuery) ([]StoredMessage, error)
MostRecentTimestamp() (int64, error)
2023-03-09 11:48:25 -04:00
Start(ctx context.Context, timesource timesource.Timesource) error
2021-10-25 15:41:08 -04:00
Stop()
}
// ErrInvalidCursor is returned by DBStore.Query when the paging cursor of a
// query does not match any row in the message table.
var ErrInvalidCursor = errors.New("invalid cursor")
// WALMode is the "wal" mode name for sqlite.
// NOTE(review): presumably the value for SQLite's journal_mode setting; it is
// not referenced in the visible part of this file — confirm against callers.
const WALMode = "wal"
2021-04-22 14:49:52 -04:00
// DBStore is a MessageProvider that has a *sql.DB connection
2021-04-12 13:59:41 -04:00
type DBStore struct {
2021-10-25 15:41:08 -04:00
MessageProvider
2023-01-04 13:58:14 -04:00
db *sql.DB
migrationFn func(db *sql.DB) error
2022-12-08 23:08:04 -04:00
timesource timesource.Timesource
log *zap.Logger
maxMessages int
maxDuration time.Duration
enableMigrations bool
2023-03-09 11:48:25 -04:00
wg sync.WaitGroup
cancel context.CancelFunc
2021-04-12 13:59:41 -04:00
}
2021-10-25 15:41:08 -04:00
type StoredMessage struct {
ID []byte
PubsubTopic string
ReceiverTime int64
Message *wpb.WakuMessage
2021-10-25 15:41:08 -04:00
}
2021-10-09 14:18:53 -04:00
// DBOption is an optional setting that can be used to configure the DBStore
2021-04-13 14:52:57 -04:00
type DBOption func(*DBStore) error
2021-04-22 14:49:52 -04:00
// WithDB is a DBOption that lets you use any custom *sql.DB with a DBStore.
2021-04-13 14:52:57 -04:00
func WithDB(db *sql.DB) DBOption {
return func(d *DBStore) error {
d.db = db
return nil
}
}
// ConnectionPoolOptions carries the connection-pool settings that WithDriver
// forwards to the *sql.DB it opens.
type ConnectionPoolOptions struct {
	// MaxOpenConnections is passed to (*sql.DB).SetMaxOpenConns.
	MaxOpenConnections int
	// MaxIdleConnections is passed to (*sql.DB).SetMaxIdleConns.
	MaxIdleConnections int
	// ConnectionMaxLifetime is passed to (*sql.DB).SetConnMaxLifetime.
	ConnectionMaxLifetime time.Duration
	// ConnectionMaxIdleTime is passed to (*sql.DB).SetConnMaxIdleTime.
	ConnectionMaxIdleTime time.Duration
}
2021-04-22 14:49:52 -04:00
// WithDriver is a DBOption that will open a *sql.DB connection
func WithDriver(driverName string, datasourceName string, connectionPoolOptions ...ConnectionPoolOptions) DBOption {
2021-04-13 14:52:57 -04:00
return func(d *DBStore) error {
db, err := sql.Open(driverName, datasourceName)
if err != nil {
return err
}
if len(connectionPoolOptions) != 0 {
db.SetConnMaxIdleTime(connectionPoolOptions[0].ConnectionMaxIdleTime)
db.SetConnMaxLifetime(connectionPoolOptions[0].ConnectionMaxLifetime)
db.SetMaxIdleConns(connectionPoolOptions[0].MaxIdleConnections)
db.SetMaxOpenConns(connectionPoolOptions[0].MaxOpenConnections)
}
2021-04-13 14:52:57 -04:00
d.db = db
return nil
}
}
2022-07-25 11:28:17 -04:00
// WithRetentionPolicy returns a DBOption that sets the retention policy: the
// maximum number of messages kept in the store and the maximum age a message
// may reach before it is removed.
func WithRetentionPolicy(maxMessages int, maxDuration time.Duration) DBOption {
	setPolicy := func(store *DBStore) error {
		store.maxDuration, store.maxMessages = maxDuration, maxMessages
		return nil
	}
	return setPolicy
}
2023-01-04 13:58:14 -04:00
// WithMigrations is a DBOption used to determine if migrations should
// be executed, and what driver to use
func WithMigrations(migrationFn func(db *sql.DB) error) DBOption {
return func(d *DBStore) error {
2023-01-04 13:58:14 -04:00
d.enableMigrations = true
d.migrationFn = migrationFn
return nil
}
}
func DefaultOptions() []DBOption {
2023-01-04 13:58:14 -04:00
return []DBOption{}
}
2021-04-22 14:49:52 -04:00
// Creates a new DB store using the db specified via options.
// It will create a messages table if it does not exist and
// clean up records according to the retention policy used
func NewDBStore(log *zap.Logger, options ...DBOption) (*DBStore, error) {
2021-04-13 14:52:57 -04:00
result := new(DBStore)
result.log = log.Named("dbstore")
2021-04-13 14:52:57 -04:00
optList := DefaultOptions()
optList = append(optList, options...)
for _, opt := range optList {
err := opt(result)
if err != nil {
return nil, err
}
}
if result.enableMigrations {
2023-01-04 13:58:14 -04:00
err := result.migrationFn(result.db)
if err != nil {
return nil, err
}
2021-04-12 13:59:41 -04:00
}
2022-12-08 23:08:04 -04:00
return result, nil
}
2023-03-09 11:48:25 -04:00
func (d *DBStore) Start(ctx context.Context, timesource timesource.Timesource) error {
ctx, cancel := context.WithCancel(ctx)
d.cancel = cancel
2022-12-08 23:08:04 -04:00
d.timesource = timesource
err := d.cleanOlderRecords()
2021-04-12 13:59:41 -04:00
if err != nil {
2022-12-08 23:08:04 -04:00
return err
2021-04-12 13:59:41 -04:00
}
2022-12-08 23:08:04 -04:00
d.wg.Add(1)
2023-03-09 11:48:25 -04:00
go d.checkForOlderRecords(ctx, 60*time.Second)
2022-12-08 23:08:04 -04:00
return nil
2021-04-12 13:59:41 -04:00
}
func (d *DBStore) cleanOlderRecords() error {
2022-11-25 16:54:11 -04:00
d.log.Info("Cleaning older records...")
// Delete older messages
if d.maxDuration > 0 {
start := time.Now()
2023-01-04 13:58:14 -04:00
sqlStmt := `DELETE FROM message WHERE receiverTimestamp < $1`
2022-12-08 23:08:04 -04:00
_, err := d.db.Exec(sqlStmt, utils.GetUnixEpochFrom(d.timesource.Now().Add(-d.maxDuration)))
if err != nil {
return err
}
elapsed := time.Since(start)
d.log.Debug("deleting older records from the DB", zap.Duration("duration", elapsed))
}
// Limit number of records to a max N
if d.maxMessages > 0 {
start := time.Now()
2023-01-04 13:58:14 -04:00
sqlStmt := `DELETE FROM message WHERE id IN (SELECT id FROM message ORDER BY receiverTimestamp DESC LIMIT -1 OFFSET $1)`
_, err := d.db.Exec(sqlStmt, d.maxMessages)
if err != nil {
return err
}
elapsed := time.Since(start)
d.log.Debug("deleting excess records from the DB", zap.Duration("duration", elapsed))
}
2022-11-25 16:54:11 -04:00
d.log.Info("Older records removed")
return nil
}
2023-03-09 11:48:25 -04:00
func (d *DBStore) checkForOlderRecords(ctx context.Context, t time.Duration) {
defer d.wg.Done()
ticker := time.NewTicker(t)
defer ticker.Stop()
for {
select {
2023-03-09 11:48:25 -04:00
case <-ctx.Done():
return
case <-ticker.C:
2022-05-27 15:55:35 -04:00
err := d.cleanOlderRecords()
if err != nil {
d.log.Error("cleaning older records", zap.Error(err))
}
}
}
}
2022-07-25 11:28:17 -04:00
// Stop closes a DB connection
2021-04-12 13:59:41 -04:00
func (d *DBStore) Stop() {
2023-03-09 11:48:25 -04:00
if d.cancel == nil {
return
}
d.cancel()
d.wg.Wait()
2021-04-12 13:59:41 -04:00
d.db.Close()
}
2022-07-25 11:28:17 -04:00
// Put inserts a WakuMessage into the DB
func (d *DBStore) Put(env *protocol.Envelope) error {
2023-01-04 13:58:14 -04:00
stmt, err := d.db.Prepare("INSERT INTO message (id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version) VALUES ($1, $2, $3, $4, $5, $6, $7)")
2021-04-12 13:59:41 -04:00
if err != nil {
return err
}
cursor := env.Index()
dbKey := NewDBKey(uint64(cursor.SenderTime), uint64(cursor.ReceiverTime), env.PubsubTopic(), env.Index().Digest)
_, err = stmt.Exec(dbKey.Bytes(), cursor.ReceiverTime, env.Message().Timestamp, env.Message().ContentTopic, env.PubsubTopic(), env.Message().Payload, env.Message().Version)
2021-04-12 13:59:41 -04:00
if err != nil {
return err
}
2022-05-27 14:34:13 -04:00
err = stmt.Close()
if err != nil {
return err
}
2021-04-12 13:59:41 -04:00
return nil
}
2022-07-25 11:28:17 -04:00
// Query retrieves messages from the DB
func (d *DBStore) Query(query *pb.HistoryQuery) (*pb.Index, []StoredMessage, error) {
start := time.Now()
defer func() {
elapsed := time.Since(start)
d.log.Info(fmt.Sprintf("Loading records from the DB took %s", elapsed))
}()
sqlQuery := `SELECT id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version
FROM message
%s
2023-01-04 13:58:14 -04:00
ORDER BY senderTimestamp %s, id %s, pubsubTopic %s, receiverTimestamp %s `
var conditions []string
var parameters []interface{}
2023-01-04 13:58:14 -04:00
paramCnt := 0
if query.PubsubTopic != "" {
2023-01-04 13:58:14 -04:00
paramCnt++
conditions = append(conditions, fmt.Sprintf("pubsubTopic = $%d", paramCnt))
parameters = append(parameters, query.PubsubTopic)
}
if len(query.ContentFilters) != 0 {
var ctPlaceHolder []string
for _, ct := range query.ContentFilters {
if ct.ContentTopic != "" {
2023-01-04 13:58:14 -04:00
paramCnt++
ctPlaceHolder = append(ctPlaceHolder, fmt.Sprintf("$%d", paramCnt))
parameters = append(parameters, ct.ContentTopic)
}
}
conditions = append(conditions, "contentTopic IN ("+strings.Join(ctPlaceHolder, ", ")+")")
}
2022-11-25 16:54:11 -04:00
usesCursor := false
if query.PagingInfo.Cursor != nil {
2022-11-25 16:54:11 -04:00
usesCursor = true
var exists bool
cursorDBKey := NewDBKey(uint64(query.PagingInfo.Cursor.SenderTime), uint64(query.PagingInfo.Cursor.ReceiverTime), query.PagingInfo.Cursor.PubsubTopic, query.PagingInfo.Cursor.Digest)
2023-01-04 13:58:14 -04:00
err := d.db.QueryRow("SELECT EXISTS(SELECT 1 FROM message WHERE id = $1)",
cursorDBKey.Bytes(),
).Scan(&exists)
if err != nil {
return nil, nil, err
}
if exists {
eqOp := ">"
if query.PagingInfo.Direction == pb.PagingInfo_BACKWARD {
eqOp = "<"
}
2023-01-04 13:58:14 -04:00
paramCnt++
conditions = append(conditions, fmt.Sprintf("id %s $%d", eqOp, paramCnt))
parameters = append(parameters, cursorDBKey.Bytes())
} else {
return nil, nil, ErrInvalidCursor
}
}
2022-11-25 16:54:11 -04:00
if query.StartTime != 0 {
if !usesCursor || query.PagingInfo.Direction == pb.PagingInfo_BACKWARD {
2023-01-04 13:58:14 -04:00
paramCnt++
conditions = append(conditions, fmt.Sprintf("id >= $%d", paramCnt))
startTimeDBKey := NewDBKey(uint64(query.StartTime), uint64(query.StartTime), "", []byte{})
2022-11-25 16:54:11 -04:00
parameters = append(parameters, startTimeDBKey.Bytes())
}
}
if query.EndTime != 0 {
if !usesCursor || query.PagingInfo.Direction == pb.PagingInfo_FORWARD {
2023-01-04 13:58:14 -04:00
paramCnt++
conditions = append(conditions, fmt.Sprintf("id <= $%d", paramCnt))
endTimeDBKey := NewDBKey(uint64(query.EndTime), uint64(query.EndTime), "", []byte{})
2022-11-25 16:54:11 -04:00
parameters = append(parameters, endTimeDBKey.Bytes())
}
}
conditionStr := ""
if len(conditions) != 0 {
conditionStr = "WHERE " + strings.Join(conditions, " AND ")
}
orderDirection := "ASC"
if query.PagingInfo.Direction == pb.PagingInfo_BACKWARD {
orderDirection = "DESC"
}
2023-01-04 13:58:14 -04:00
paramCnt++
sqlQuery += fmt.Sprintf("LIMIT $%d", paramCnt)
sqlQuery = fmt.Sprintf(sqlQuery, conditionStr, orderDirection, orderDirection, orderDirection, orderDirection)
stmt, err := d.db.Prepare(sqlQuery)
if err != nil {
return nil, nil, err
}
defer stmt.Close()
2022-10-03 15:26:45 -04:00
pageSize := query.PagingInfo.PageSize + 1
parameters = append(parameters, pageSize)
rows, err := stmt.Query(parameters...)
if err != nil {
return nil, nil, err
}
var result []StoredMessage
for rows.Next() {
record, err := d.GetStoredMessage(rows)
if err != nil {
return nil, nil, err
}
result = append(result, record)
}
defer rows.Close()
2022-11-25 16:54:11 -04:00
var cursor *pb.Index
if len(result) != 0 {
2022-10-03 15:26:45 -04:00
if len(result) > int(query.PagingInfo.PageSize) {
result = result[0:query.PagingInfo.PageSize]
lastMsgIdx := len(result) - 1
cursor = protocol.NewEnvelope(result[lastMsgIdx].Message, result[lastMsgIdx].ReceiverTime, result[lastMsgIdx].PubsubTopic).Index()
}
}
// The retrieved messages list should always be in chronological order
if query.PagingInfo.Direction == pb.PagingInfo_BACKWARD {
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
result[i], result[j] = result[j], result[i]
}
}
return cursor, result, nil
}
2022-07-25 11:28:17 -04:00
// MostRecentTimestamp returns the largest senderTimestamp found in the
// message table. It returns 0 when the table yields no value (NULL result or
// no rows).
func (d *DBStore) MostRecentTimestamp() (int64, error) {
	var newest sql.NullInt64

	row := d.db.QueryRow(`SELECT max(senderTimestamp) FROM message`)
	switch err := row.Scan(&newest); err {
	case nil, sql.ErrNoRows:
		return newest.Int64, nil
	default:
		return 0, err
	}
}
2022-07-28 15:17:12 -04:00
// Count returns how many rows the message table currently holds.
func (d *DBStore) Count() (int, error) {
	var total int

	err := d.db.QueryRow(`SELECT COUNT(*) FROM message`).Scan(&total)
	if err == nil || err == sql.ErrNoRows {
		return total, nil
	}
	return 0, err
}
2022-07-25 11:28:17 -04:00
// GetAll returns all the stored WakuMessages
2021-10-25 15:41:08 -04:00
func (d *DBStore) GetAll() ([]StoredMessage, error) {
start := time.Now()
defer func() {
elapsed := time.Since(start)
d.log.Info("loading records from the DB", zap.Duration("duration", elapsed))
}()
rows, err := d.db.Query("SELECT id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version FROM message ORDER BY senderTimestamp ASC")
2021-04-12 13:59:41 -04:00
if err != nil {
return nil, err
}
2021-10-25 15:41:08 -04:00
var result []StoredMessage
2021-04-12 13:59:41 -04:00
defer rows.Close()
for rows.Next() {
record, err := d.GetStoredMessage(rows)
2021-04-12 13:59:41 -04:00
if err != nil {
return nil, err
}
result = append(result, record)
2021-04-12 13:59:41 -04:00
}
d.log.Info("DB returned records", zap.Int("count", len(result)))
2021-04-12 13:59:41 -04:00
err = rows.Err()
if err != nil {
return nil, err
}
return result, nil
}
2022-07-25 11:28:17 -04:00
// GetStoredMessage is a helper function used to convert a `*sql.Rows` into a `StoredMessage`
func (d *DBStore) GetStoredMessage(row *sql.Rows) (StoredMessage, error) {
var id []byte
var receiverTimestamp int64
var senderTimestamp int64
var contentTopic string
var payload []byte
var version uint32
var pubsubTopic string
2022-07-25 11:28:17 -04:00
err := row.Scan(&id, &receiverTimestamp, &senderTimestamp, &contentTopic, &pubsubTopic, &payload, &version)
if err != nil {
d.log.Error("scanning messages from db", zap.Error(err))
return StoredMessage{}, err
}
msg := new(wpb.WakuMessage)
msg.ContentTopic = contentTopic
msg.Payload = payload
msg.Timestamp = senderTimestamp
msg.Version = version
record := StoredMessage{
ID: id,
PubsubTopic: pubsubTopic,
ReceiverTime: receiverTimestamp,
Message: msg,
}
return record, nil
}