mirror of https://github.com/status-im/go-waku.git
587 lines
16 KiB
Go
587 lines
16 KiB
Go
package persistence
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/waku-org/go-waku/waku/v2/protocol"
|
|
wpb "github.com/waku-org/go-waku/waku/v2/protocol/pb"
|
|
"github.com/waku-org/go-waku/waku/v2/protocol/store/pb"
|
|
"github.com/waku-org/go-waku/waku/v2/timesource"
|
|
"go.uber.org/zap"
|
|
"google.golang.org/protobuf/proto"
|
|
)
|
|
|
|
// MessageProvider is an interface that provides access to store/retrieve messages from a persistence store.
|
|
type MessageProvider interface {
|
|
GetAll() ([]StoredMessage, error)
|
|
Validate(env *protocol.Envelope) error
|
|
Put(env *protocol.Envelope) error
|
|
Query(query *pb.HistoryQuery) ([]StoredMessage, error)
|
|
MostRecentTimestamp() (int64, error)
|
|
Start(ctx context.Context, timesource timesource.Timesource) error
|
|
Stop()
|
|
}
|
|
|
|
var (
	// ErrInvalidCursor indicates that the paging cursor passed to a store query
	// does not correspond to any stored message.
	ErrInvalidCursor = errors.New("invalid cursor")

	// ErrFutureMessage indicates that a message whose timestamp lies too far in
	// the future was requested to be stored.
	ErrFutureMessage = errors.New("message timestamp in the future")

	// ErrMessageTooOld indicates that a message whose timestamp lies too far in
	// the past was requested to be stored.
	ErrMessageTooOld = errors.New("message too old")
)

const (
	// WALMode is the "wal" (write-ahead log) journal mode name for SQLite.
	WALMode = "wal"

	// MaxTimeVariance is the maximum allowed difference, in either direction,
	// between a message's timestamp and the time it was received.
	MaxTimeVariance = 20 * time.Second
)
|
|
|
|
// DBStore is a MessageProvider that has a *sql.DB connection.
//
// NOTE(review): the embedded MessageProvider is never assigned in this file and
// every method is implemented directly on *DBStore. Because *DBStore.Query
// returns an extra *pb.Index, it shadows the promoted Query rather than
// implementing MessageProvider.Query — confirm whether *DBStore is meant to
// satisfy MessageProvider.
type DBStore struct {
	MessageProvider

	// Database handle and the optional schema-migration hook run by NewDBStore.
	db          *sql.DB
	migrationFn func(db *sql.DB) error

	metrics    Metrics
	timesource timesource.Timesource // set by Start; used by the retention policy
	log        *zap.Logger

	// Retention policy (see WithRetentionPolicy); a non-positive value disables
	// the corresponding limit in cleanOlderRecords.
	maxMessages int
	maxDuration time.Duration

	enableMigrations bool // set by WithMigrations

	// Lifecycle of the background goroutines launched by Start and torn down by Stop.
	wg     sync.WaitGroup
	cancel context.CancelFunc
}
|
|
|
|
// StoredMessage is the format of the message stored in persistence store,
// as rebuilt from one row of the message table by GetStoredMessage.
type StoredMessage struct {
	ID           []byte           // id column; Put writes the DBKey bytes here
	PubsubTopic  string           // pubsubTopic column
	ReceiverTime int64            // receiverTimestamp column; Validate treats it as Unix nanoseconds
	Message      *wpb.WakuMessage // content topic, payload, and optional sender timestamp/version
}
|
|
|
|
// DBOption is an optional setting that can be used to configure the DBStore
|
|
type DBOption func(*DBStore) error
|
|
|
|
// WithDB is a DBOption that lets you use any custom *sql.DB with a DBStore.
|
|
func WithDB(db *sql.DB) DBOption {
|
|
return func(d *DBStore) error {
|
|
d.db = db
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// ConnectionPoolOptions is the options to be used for DB connection pooling.
// WithDriver passes each field to the matching *sql.DB setter.
type ConnectionPoolOptions struct {
	MaxOpenConnections    int           // passed to (*sql.DB).SetMaxOpenConns
	MaxIdleConnections    int           // passed to (*sql.DB).SetMaxIdleConns
	ConnectionMaxLifetime time.Duration // passed to (*sql.DB).SetConnMaxLifetime
	ConnectionMaxIdleTime time.Duration // passed to (*sql.DB).SetConnMaxIdleTime
}
|
|
|
|
// WithDriver is a DBOption that will open a *sql.DB connection
|
|
func WithDriver(driverName string, datasourceName string, connectionPoolOptions ...ConnectionPoolOptions) DBOption {
|
|
return func(d *DBStore) error {
|
|
db, err := sql.Open(driverName, datasourceName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(connectionPoolOptions) != 0 {
|
|
db.SetConnMaxIdleTime(connectionPoolOptions[0].ConnectionMaxIdleTime)
|
|
db.SetConnMaxLifetime(connectionPoolOptions[0].ConnectionMaxLifetime)
|
|
db.SetMaxIdleConns(connectionPoolOptions[0].MaxIdleConnections)
|
|
db.SetMaxOpenConns(connectionPoolOptions[0].MaxOpenConnections)
|
|
}
|
|
|
|
d.db = db
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// WithRetentionPolicy is a DBOption that specifies the max number of messages
|
|
// to be stored and duration before they're removed from the message store
|
|
func WithRetentionPolicy(maxMessages int, maxDuration time.Duration) DBOption {
|
|
return func(d *DBStore) error {
|
|
d.maxDuration = maxDuration
|
|
d.maxMessages = maxMessages
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// MigrationFn applies schema migrations to the given database. NewDBStore runs
// it once when migrations are enabled with WithMigrations.
type MigrationFn func(*sql.DB) error
|
|
|
|
// WithMigrations is a DBOption used to determine if migrations should
|
|
// be executed, and what driver to use
|
|
func WithMigrations(migrationFn MigrationFn) DBOption {
|
|
return func(d *DBStore) error {
|
|
d.enableMigrations = true
|
|
d.migrationFn = migrationFn
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// DefaultOptions returns the default DBoptions to be used.
|
|
func DefaultOptions() []DBOption {
|
|
return []DBOption{}
|
|
}
|
|
|
|
// Creates a new DB store using the db specified via options.
|
|
// It will create a messages table if it does not exist and
|
|
// clean up records according to the retention policy used
|
|
func NewDBStore(reg prometheus.Registerer, log *zap.Logger, options ...DBOption) (*DBStore, error) {
|
|
result := new(DBStore)
|
|
result.log = log.Named("dbstore")
|
|
result.metrics = newMetrics(reg)
|
|
|
|
optList := DefaultOptions()
|
|
optList = append(optList, options...)
|
|
|
|
for _, opt := range optList {
|
|
err := opt(result)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if result.enableMigrations {
|
|
err := result.migrationFn(result.db)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// Start starts the store server functionality
|
|
func (d *DBStore) Start(ctx context.Context, timesource timesource.Timesource) error {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
|
|
d.cancel = cancel
|
|
d.timesource = timesource
|
|
|
|
err := d.cleanOlderRecords(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
d.wg.Add(2)
|
|
go d.checkForOlderRecords(ctx, 60*time.Second)
|
|
go d.updateMetrics(ctx)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *DBStore) updateMetrics(ctx context.Context) {
|
|
ticker := time.NewTicker(5 * time.Second)
|
|
defer ticker.Stop()
|
|
defer d.wg.Done()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
msgCount, err := d.Count()
|
|
if err != nil {
|
|
d.log.Error("updating store metrics", zap.Error(err))
|
|
} else {
|
|
d.metrics.RecordMessage(msgCount)
|
|
}
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *DBStore) cleanOlderRecords(ctx context.Context) error {
|
|
d.log.Info("Cleaning older records...")
|
|
|
|
// Delete older messages
|
|
if d.maxDuration > 0 {
|
|
start := time.Now()
|
|
sqlStmt := `DELETE FROM message WHERE receiverTimestamp < $1`
|
|
_, err := d.db.Exec(sqlStmt, d.timesource.Now().Add(-d.maxDuration).UnixNano())
|
|
if err != nil {
|
|
d.metrics.RecordError(retPolicyFailure)
|
|
return err
|
|
}
|
|
elapsed := time.Since(start)
|
|
d.log.Debug("deleting older records from the DB", zap.Duration("duration", elapsed))
|
|
}
|
|
|
|
// Limit number of records to a max N
|
|
if d.maxMessages > 0 {
|
|
start := time.Now()
|
|
|
|
_, err := d.db.Exec(d.getDeleteOldRowsQuery(), d.maxMessages)
|
|
if err != nil {
|
|
d.metrics.RecordError(retPolicyFailure)
|
|
return err
|
|
}
|
|
elapsed := time.Since(start)
|
|
d.log.Debug("deleting excess records from the DB", zap.Duration("duration", elapsed))
|
|
}
|
|
|
|
d.log.Info("Older records removed")
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *DBStore) getDeleteOldRowsQuery() string {
|
|
sqlStmt := `DELETE FROM message WHERE id IN (SELECT id FROM message ORDER BY receiverTimestamp DESC %s OFFSET $1)`
|
|
switch GetDriverType(d.db) {
|
|
case SQLiteDriver:
|
|
sqlStmt = fmt.Sprintf(sqlStmt, "LIMIT -1")
|
|
case PostgresDriver:
|
|
sqlStmt = fmt.Sprintf(sqlStmt, "")
|
|
}
|
|
return sqlStmt
|
|
}
|
|
|
|
func (d *DBStore) checkForOlderRecords(ctx context.Context, t time.Duration) {
|
|
defer d.wg.Done()
|
|
|
|
ticker := time.NewTicker(t)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
err := d.cleanOlderRecords(ctx)
|
|
if err != nil {
|
|
d.log.Error("cleaning older records", zap.Error(err))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Stop closes a DB connection
|
|
func (d *DBStore) Stop() {
|
|
if d.cancel == nil {
|
|
return
|
|
}
|
|
|
|
d.cancel()
|
|
d.wg.Wait()
|
|
d.db.Close()
|
|
}
|
|
|
|
// Validate validates the message to be stored against possible fradulent conditions.
|
|
func (d *DBStore) Validate(env *protocol.Envelope) error {
|
|
n := time.Unix(0, env.Index().ReceiverTime)
|
|
upperBound := n.Add(MaxTimeVariance)
|
|
lowerBound := n.Add(-MaxTimeVariance)
|
|
|
|
// Ensure that messages don't "jump" to the front of the queue with future timestamps
|
|
if env.Message().GetTimestamp() > upperBound.UnixNano() {
|
|
return ErrFutureMessage
|
|
}
|
|
|
|
if env.Message().GetTimestamp() < lowerBound.UnixNano() {
|
|
return ErrMessageTooOld
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Put inserts a WakuMessage into the DB
|
|
func (d *DBStore) Put(env *protocol.Envelope) error {
|
|
stmt, err := d.db.Prepare("INSERT INTO message (id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version) VALUES ($1, $2, $3, $4, $5, $6, $7)")
|
|
if err != nil {
|
|
d.metrics.RecordError(insertFailure)
|
|
return err
|
|
}
|
|
|
|
cursor := env.Index()
|
|
dbKey := NewDBKey(uint64(cursor.SenderTime), uint64(cursor.ReceiverTime), env.PubsubTopic(), env.Index().Digest)
|
|
|
|
start := time.Now()
|
|
_, err = stmt.Exec(dbKey.Bytes(), cursor.ReceiverTime, env.Message().GetTimestamp(), env.Message().ContentTopic, env.PubsubTopic(), env.Message().Payload, env.Message().GetVersion())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
d.metrics.RecordInsertDuration(time.Since(start))
|
|
|
|
err = stmt.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *DBStore) handleQueryCursor(query *pb.HistoryQuery, paramCnt *int, conditions []string, parameters []interface{}) ([]string, []interface{}, error) {
|
|
usesCursor := false
|
|
if query.PagingInfo.Cursor != nil {
|
|
usesCursor = true
|
|
var exists bool
|
|
cursorDBKey := NewDBKey(uint64(query.PagingInfo.Cursor.SenderTime), uint64(query.PagingInfo.Cursor.ReceiverTime), query.PagingInfo.Cursor.PubsubTopic, query.PagingInfo.Cursor.Digest)
|
|
|
|
err := d.db.QueryRow("SELECT EXISTS(SELECT 1 FROM message WHERE id = $1)",
|
|
cursorDBKey.Bytes(),
|
|
).Scan(&exists)
|
|
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
if exists {
|
|
eqOp := ">"
|
|
if query.PagingInfo.Direction == pb.PagingInfo_BACKWARD {
|
|
eqOp = "<"
|
|
}
|
|
*paramCnt++
|
|
conditions = append(conditions, fmt.Sprintf("id %s $%d", eqOp, *paramCnt))
|
|
|
|
parameters = append(parameters, cursorDBKey.Bytes())
|
|
} else {
|
|
return nil, nil, ErrInvalidCursor
|
|
}
|
|
}
|
|
|
|
handleTimeParam := func(time int64, op string) {
|
|
*paramCnt++
|
|
conditions = append(conditions, fmt.Sprintf("id %s $%d", op, *paramCnt))
|
|
timeDBKey := NewDBKey(uint64(time), 0, "", []byte{})
|
|
parameters = append(parameters, timeDBKey.Bytes())
|
|
}
|
|
|
|
startTime := query.GetStartTime()
|
|
if startTime != 0 {
|
|
if !usesCursor || query.PagingInfo.Direction == pb.PagingInfo_BACKWARD {
|
|
handleTimeParam(startTime, ">=")
|
|
}
|
|
}
|
|
|
|
endTime := query.GetEndTime()
|
|
if endTime != 0 {
|
|
if !usesCursor || query.PagingInfo.Direction == pb.PagingInfo_FORWARD {
|
|
handleTimeParam(endTime+1, "<")
|
|
}
|
|
}
|
|
return conditions, parameters, nil
|
|
}
|
|
|
|
func (d *DBStore) prepareQuerySQL(query *pb.HistoryQuery) (string, []interface{}, error) {
|
|
sqlQuery := `SELECT id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version
|
|
FROM message
|
|
%s
|
|
ORDER BY senderTimestamp %s, id %s, pubsubTopic %s, receiverTimestamp %s `
|
|
|
|
var conditions []string
|
|
//var parameters []interface{}
|
|
parameters := make([]interface{}, 0) //Allocating as a slice so that references get passed rather than value
|
|
paramCnt := 0
|
|
|
|
if query.PubsubTopic != "" {
|
|
paramCnt++
|
|
conditions = append(conditions, fmt.Sprintf("pubsubTopic = $%d", paramCnt))
|
|
parameters = append(parameters, query.PubsubTopic)
|
|
}
|
|
|
|
if len(query.ContentFilters) != 0 {
|
|
var ctPlaceHolder []string
|
|
for _, ct := range query.ContentFilters {
|
|
if ct.ContentTopic != "" {
|
|
paramCnt++
|
|
ctPlaceHolder = append(ctPlaceHolder, fmt.Sprintf("$%d", paramCnt))
|
|
parameters = append(parameters, ct.ContentTopic)
|
|
}
|
|
}
|
|
conditions = append(conditions, "contentTopic IN ("+strings.Join(ctPlaceHolder, ", ")+")")
|
|
}
|
|
|
|
conditions, parameters, err := d.handleQueryCursor(query, ¶mCnt, conditions, parameters)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
conditionStr := ""
|
|
if len(conditions) != 0 {
|
|
conditionStr = "WHERE " + strings.Join(conditions, " AND ")
|
|
}
|
|
|
|
orderDirection := "ASC"
|
|
if query.PagingInfo.Direction == pb.PagingInfo_BACKWARD {
|
|
orderDirection = "DESC"
|
|
}
|
|
|
|
paramCnt++
|
|
|
|
sqlQuery += fmt.Sprintf("LIMIT $%d", paramCnt)
|
|
// Always search for _max page size_ + 1. If the extra row does not exist, do not return pagination info.
|
|
pageSize := query.PagingInfo.PageSize + 1
|
|
parameters = append(parameters, pageSize)
|
|
|
|
sqlQuery = fmt.Sprintf(sqlQuery, conditionStr, orderDirection, orderDirection, orderDirection, orderDirection)
|
|
d.log.Info(fmt.Sprintf("sqlQuery: %s", sqlQuery))
|
|
|
|
return sqlQuery, parameters, nil
|
|
|
|
}
|
|
|
|
// Query retrieves messages from the DB
|
|
func (d *DBStore) Query(query *pb.HistoryQuery) (*pb.Index, []StoredMessage, error) {
|
|
start := time.Now()
|
|
defer func() {
|
|
elapsed := time.Since(start)
|
|
d.log.Info(fmt.Sprintf("Loading records from the DB took %s", elapsed))
|
|
}()
|
|
|
|
sqlQuery, parameters, err := d.prepareQuerySQL(query)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
stmt, err := d.db.Prepare(sqlQuery)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
defer stmt.Close()
|
|
//
|
|
measurementStart := time.Now()
|
|
rows, err := stmt.Query(parameters...)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
d.metrics.RecordQueryDuration(time.Since(measurementStart))
|
|
|
|
var result []StoredMessage
|
|
for rows.Next() {
|
|
record, err := d.GetStoredMessage(rows)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
result = append(result, record)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var cursor *pb.Index
|
|
if len(result) != 0 {
|
|
// since there are more rows than pagingInfo.PageSize, we need to return a cursor, for pagination
|
|
if len(result) > int(query.PagingInfo.PageSize) {
|
|
result = result[0:query.PagingInfo.PageSize]
|
|
lastMsgIdx := len(result) - 1
|
|
cursor = protocol.NewEnvelope(result[lastMsgIdx].Message, result[lastMsgIdx].ReceiverTime, result[lastMsgIdx].PubsubTopic).Index()
|
|
}
|
|
}
|
|
|
|
// The retrieved messages list should always be in chronological order
|
|
if query.PagingInfo.Direction == pb.PagingInfo_BACKWARD {
|
|
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
|
|
result[i], result[j] = result[j], result[i]
|
|
}
|
|
}
|
|
|
|
return cursor, result, nil
|
|
}
|
|
|
|
// MostRecentTimestamp returns an unix timestamp with the most recent senderTimestamp
|
|
// in the message table
|
|
func (d *DBStore) MostRecentTimestamp() (int64, error) {
|
|
result := sql.NullInt64{}
|
|
|
|
err := d.db.QueryRow(`SELECT max(senderTimestamp) FROM message`).Scan(&result)
|
|
if err != nil && err != sql.ErrNoRows {
|
|
return 0, err
|
|
}
|
|
return result.Int64, nil
|
|
}
|
|
|
|
// Count returns the number of rows in the message table
|
|
func (d *DBStore) Count() (int, error) {
|
|
var result int
|
|
err := d.db.QueryRow(`SELECT COUNT(*) FROM message`).Scan(&result)
|
|
if err != nil && err != sql.ErrNoRows {
|
|
return 0, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// GetAll returns all the stored WakuMessages
|
|
func (d *DBStore) GetAll() ([]StoredMessage, error) {
|
|
start := time.Now()
|
|
defer func() {
|
|
elapsed := time.Since(start)
|
|
d.log.Info("loading records from the DB", zap.Duration("duration", elapsed))
|
|
}()
|
|
|
|
rows, err := d.db.Query("SELECT id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version FROM message ORDER BY senderTimestamp ASC")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var result []StoredMessage
|
|
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
record, err := d.GetStoredMessage(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result = append(result, record)
|
|
}
|
|
|
|
d.log.Info("DB returned records", zap.Int("count", len(result)))
|
|
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// GetStoredMessage is a helper function used to convert a `*sql.Rows` into a `StoredMessage`
|
|
func (d *DBStore) GetStoredMessage(row *sql.Rows) (StoredMessage, error) {
|
|
var id []byte
|
|
var receiverTimestamp int64
|
|
var senderTimestamp int64
|
|
var contentTopic string
|
|
var payload []byte
|
|
var version uint32
|
|
var pubsubTopic string
|
|
|
|
err := row.Scan(&id, &receiverTimestamp, &senderTimestamp, &contentTopic, &pubsubTopic, &payload, &version)
|
|
if err != nil {
|
|
d.log.Error("scanning messages from db", zap.Error(err))
|
|
return StoredMessage{}, err
|
|
}
|
|
|
|
msg := new(wpb.WakuMessage)
|
|
msg.ContentTopic = contentTopic
|
|
msg.Payload = payload
|
|
|
|
if senderTimestamp != 0 {
|
|
msg.Timestamp = proto.Int64(senderTimestamp)
|
|
}
|
|
|
|
if version > 0 {
|
|
msg.Version = proto.Uint32(version)
|
|
}
|
|
|
|
record := StoredMessage{
|
|
ID: id,
|
|
PubsubTopic: pubsubTopic,
|
|
ReceiverTime: receiverTimestamp,
|
|
Message: msg,
|
|
}
|
|
|
|
return record, nil
|
|
}
|