refactor: use db for serving history queries (#243)

This commit is contained in:
Richard Ramos 2022-05-30 14:48:22 -04:00 committed by GitHub
parent 7c44369def
commit 7c0206684f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 628 additions and 816 deletions

2
.gitignore vendored
View File

@ -1,4 +1,6 @@
*.db *.db
*.db-shm
*.db-wal
nodekey nodekey
# Binaries for programs and plugins # Binaries for programs and plugins

View File

@ -9,18 +9,21 @@ import (
_ "github.com/mattn/go-sqlite3" // Blank import to register the sqlite3 driver _ "github.com/mattn/go-sqlite3" // Blank import to register the sqlite3 driver
"github.com/status-im/go-waku/waku/persistence"
"github.com/status-im/go-waku/waku/v2/protocol" "github.com/status-im/go-waku/waku/v2/protocol"
"github.com/status-im/go-waku/waku/v2/protocol/pb" "github.com/status-im/go-waku/waku/v2/protocol/pb"
) )
const secondsMonth = int64(30 * time.Hour * 24)
func genRandomBytes(size int) (blk []byte, err error) { func genRandomBytes(size int) (blk []byte, err error) {
blk = make([]byte, size) blk = make([]byte, size)
_, err = rand.Read(blk) _, err = rand.Read(blk)
return return
} }
func genRandomTimestamp(now int64, last30d int64) int64 { func genRandomTimestamp(t30daysAgo int64) int64 {
return rand.Int63n(last30d) + now return rand.Int63n(secondsMonth) + t30daysAgo
} }
func genRandomContentTopic(n int) string { func genRandomContentTopic(n int) string {
@ -38,7 +41,10 @@ func newdb(path string) (*sql.DB, error) {
} }
func createTable(db *sql.DB) error { func createTable(db *sql.DB) error {
sqlStmt := `CREATE TABLE IF NOT EXISTS message ( sqlStmt := `
PRAGMA journal_mode=WAL;
CREATE TABLE IF NOT EXISTS message (
id BLOB, id BLOB,
receiverTimestamp INTEGER NOT NULL, receiverTimestamp INTEGER NOT NULL,
senderTimestamp INTEGER NOT NULL, senderTimestamp INTEGER NOT NULL,
@ -46,7 +52,7 @@ func createTable(db *sql.DB) error {
pubsubTopic BLOB NOT NULL, pubsubTopic BLOB NOT NULL,
payload BLOB, payload BLOB,
version INTEGER NOT NULL DEFAULT 0, version INTEGER NOT NULL DEFAULT 0,
CONSTRAINT messageIndex PRIMARY KEY (senderTimestamp, id, pubsubTopic) CONSTRAINT messageIndex PRIMARY KEY (id, pubsubTopic)
) WITHOUT ROWID; ) WITHOUT ROWID;
CREATE INDEX IF NOT EXISTS message_senderTimestamp ON message(senderTimestamp); CREATE INDEX IF NOT EXISTS message_senderTimestamp ON message(senderTimestamp);
@ -71,6 +77,7 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer db.Close()
query := "INSERT INTO message (id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version) VALUES (?, ?, ?, ?, ?, ?, ?)" query := "INSERT INTO message (id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version) VALUES (?, ?, ?, ?, ?, ?, ?)"
@ -89,8 +96,7 @@ func main() {
panic(err) panic(err)
} }
last30d := time.Now().UnixNano() - time.Now().Add(-30*time.Hour*24).UnixNano() t30daysAgo := time.Now().UnixNano() - secondsMonth
now := time.Now().Add(-1 * time.Minute).UnixNano()
pubsubTopic := protocol.DefaultPubsubTopic().String() pubsubTopic := protocol.DefaultPubsubTopic().String()
for i := 1; i <= N; i++ { for i := 1; i <= N; i++ {
@ -123,16 +129,14 @@ func main() {
msg := pb.WakuMessage{ msg := pb.WakuMessage{
Version: 0, Version: 0,
ContentTopic: genRandomContentTopic(i), ContentTopic: genRandomContentTopic(i),
Timestamp: genRandomTimestamp(now, last30d), Timestamp: genRandomTimestamp(t30daysAgo),
Payload: randPayload, Payload: randPayload,
} }
hash, err := msg.Hash() envelope := protocol.NewEnvelope(&msg, msg.Timestamp, pubsubTopic)
if err != nil { dbKey := persistence.NewDBKey(uint64(msg.Timestamp), pubsubTopic, envelope.Index().Digest)
panic(err)
}
_, err = stmt.Exec(hash, msg.Timestamp, msg.Timestamp, msg.ContentTopic, pubsubTopic, msg.Payload, msg.Version) _, err = stmt.Exec(dbKey.Bytes(), msg.Timestamp, msg.Timestamp, msg.ContentTopic, pubsubTopic, msg.Payload, msg.Version)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -95,10 +95,15 @@ func Execute(options Options) {
} }
var db *sql.DB var db *sql.DB
if options.UseDB { if options.UseDB {
db, err = sqlite.NewDB(options.DBPath) db, err = sqlite.NewDB(options.DBPath)
failOnErr(err, "Could not connect to DB") failOnErr(err, "Could not connect to DB")
logger.Debug("using database: ", zap.String("path", options.DBPath))
} else {
db, err = sqlite.NewDB(":memory:")
failOnErr(err, "Could not create in-memory DB")
logger.Debug("using in-memory database")
} }
ctx := context.Background() ctx := context.Background()
@ -189,13 +194,9 @@ func Execute(options Options) {
if options.Store.Enable { if options.Store.Enable {
nodeOpts = append(nodeOpts, node.WithWakuStoreAndRetentionPolicy(options.Store.ShouldResume, options.Store.RetentionMaxDaysDuration(), options.Store.RetentionMaxMessages)) nodeOpts = append(nodeOpts, node.WithWakuStoreAndRetentionPolicy(options.Store.ShouldResume, options.Store.RetentionMaxDaysDuration(), options.Store.RetentionMaxMessages))
if options.UseDB { dbStore, err := persistence.NewDBStore(logger, persistence.WithDB(db), persistence.WithRetentionPolicy(options.Store.RetentionMaxMessages, options.Store.RetentionMaxDaysDuration()))
dbStore, err := persistence.NewDBStore(logger, persistence.WithDB(db), persistence.WithRetentionPolicy(options.Store.RetentionMaxMessages, options.Store.RetentionMaxDaysDuration())) failOnErr(err, "DBStore")
failOnErr(err, "DBStore") nodeOpts = append(nodeOpts, node.WithMessageProvider(dbStore))
nodeOpts = append(nodeOpts, node.WithMessageProvider(dbStore))
} else {
nodeOpts = append(nodeOpts, node.WithMessageProvider(nil))
}
} }
if options.LightPush.Enable { if options.LightPush.Enable {
@ -303,7 +304,7 @@ func Execute(options Options) {
rpcServer.Start() rpcServer.Start()
} }
utils.Logger().Info("Node setup complete") logger.Info("Node setup complete")
// Wait for a SIGINT or SIGTERM signal // Wait for a SIGINT or SIGTERM signal
ch := make(chan os.Signal, 1) ch := make(chan os.Signal, 1)

View File

@ -0,0 +1,50 @@
package persistence
import (
"crypto/sha256"
"encoding/binary"
"errors"
)
const (
TimestampLength = 8
HashLength = 32
DigestLength = HashLength
PubsubTopicLength = HashLength
DBKeyLength = TimestampLength + PubsubTopicLength + DigestLength
)
type Hash [HashLength]byte
var (
// ErrInvalidByteSize is returned when DBKey can't be created
// from a byte slice because it has invalid length.
ErrInvalidByteSize = errors.New("byte slice has invalid length")
)
// DBKey key to be stored in a db.
type DBKey struct {
raw []byte
}
// Bytes returns a bytes representation of the DBKey.
func (k *DBKey) Bytes() []byte {
return k.raw
}
func (k *DBKey) Digest() []byte {
return k.raw[TimestampLength+PubsubTopicLength : TimestampLength+PubsubTopicLength+DigestLength]
}
// NewDBKey creates a new DBKey with the given values.
func NewDBKey(timestamp uint64, pubsubTopic string, digest []byte) *DBKey {
pubSubHash := sha256.Sum256([]byte(pubsubTopic))
var k DBKey
k.raw = make([]byte, DBKeyLength)
binary.BigEndian.PutUint64(k.raw, timestamp)
copy(k.raw[TimestampLength:], pubSubHash[:])
copy(k.raw[TimestampLength+PubsubTopicLength:], digest)
return &k
}

View File

@ -2,8 +2,13 @@ package persistence
import ( import (
"database/sql" "database/sql"
"errors"
"fmt"
"strings"
"sync"
"time" "time"
"github.com/status-im/go-waku/waku/v2/protocol"
"github.com/status-im/go-waku/waku/v2/protocol/pb" "github.com/status-im/go-waku/waku/v2/protocol/pb"
"github.com/status-im/go-waku/waku/v2/utils" "github.com/status-im/go-waku/waku/v2/utils"
"go.uber.org/zap" "go.uber.org/zap"
@ -11,10 +16,17 @@ import (
type MessageProvider interface { type MessageProvider interface {
GetAll() ([]StoredMessage, error) GetAll() ([]StoredMessage, error)
Put(cursor *pb.Index, pubsubTopic string, message *pb.WakuMessage) error Put(env *protocol.Envelope) error
Query(query *pb.HistoryQuery) ([]StoredMessage, error)
MostRecentTimestamp() (int64, error)
Stop() Stop()
} }
var ErrInvalidCursor = errors.New("invalid cursor")
// WALMode for sqlite.
const WALMode = "wal"
// DBStore is a MessageProvider that has a *sql.DB connection // DBStore is a MessageProvider that has a *sql.DB connection
type DBStore struct { type DBStore struct {
MessageProvider MessageProvider
@ -23,6 +35,9 @@ type DBStore struct {
maxMessages int maxMessages int
maxDuration time.Duration maxDuration time.Duration
wg sync.WaitGroup
quit chan struct{}
} }
type StoredMessage struct { type StoredMessage struct {
@ -69,6 +84,7 @@ func WithRetentionPolicy(maxMessages int, maxDuration time.Duration) DBOption {
func NewDBStore(log *zap.Logger, options ...DBOption) (*DBStore, error) { func NewDBStore(log *zap.Logger, options ...DBOption) (*DBStore, error) {
result := new(DBStore) result := new(DBStore)
result.log = log.Named("dbstore") result.log = log.Named("dbstore")
result.quit = make(chan struct{})
for _, opt := range options { for _, opt := range options {
err := opt(result) err := opt(result)
@ -77,7 +93,30 @@ func NewDBStore(log *zap.Logger, options ...DBOption) (*DBStore, error) {
} }
} }
err := result.createTable() // Disable concurrent access as not supported by the driver
result.db.SetMaxOpenConns(1)
var seq string
var name string
var file string // file will be empty if DB is :memory"
err := result.db.QueryRow("PRAGMA database_list").Scan(&seq, &name, &file)
if err != nil {
return nil, err
}
// readers do not block writers and faster i/o operations
// https://www.sqlite.org/draft/wal.html
// must be set after db is encrypted
var mode string
err = result.db.QueryRow("PRAGMA journal_mode=WAL").Scan(&mode)
if err != nil {
return nil, err
}
if mode != WALMode && file != "" {
return nil, fmt.Errorf("unable to set journal_mode to WAL. actual mode %s", mode)
}
err = result.createTable()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -87,6 +126,9 @@ func NewDBStore(log *zap.Logger, options ...DBOption) (*DBStore, error) {
return nil, err return nil, err
} }
result.wg.Add(1)
go result.checkForOlderRecords(10 * time.Second) // is 10s okay?
return result, nil return result, nil
} }
@ -99,7 +141,7 @@ func (d *DBStore) createTable() error {
pubsubTopic BLOB NOT NULL, pubsubTopic BLOB NOT NULL,
payload BLOB, payload BLOB,
version INTEGER NOT NULL DEFAULT 0, version INTEGER NOT NULL DEFAULT 0,
CONSTRAINT messageIndex PRIMARY KEY (senderTimestamp, id, pubsubTopic) CONSTRAINT messageIndex PRIMARY KEY (id, pubsubTopic)
) WITHOUT ROWID; ) WITHOUT ROWID;
CREATE INDEX IF NOT EXISTS message_senderTimestamp ON message(senderTimestamp); CREATE INDEX IF NOT EXISTS message_senderTimestamp ON message(senderTimestamp);
@ -141,18 +183,39 @@ func (d *DBStore) cleanOlderRecords() error {
return nil return nil
} }
func (d *DBStore) checkForOlderRecords(t time.Duration) {
defer d.wg.Done()
ticker := time.NewTicker(t)
defer ticker.Stop()
for {
select {
case <-d.quit:
return
case <-ticker.C:
d.cleanOlderRecords()
}
}
}
// Closes a DB connection // Closes a DB connection
func (d *DBStore) Stop() { func (d *DBStore) Stop() {
d.quit <- struct{}{}
d.wg.Wait()
d.db.Close() d.db.Close()
} }
// Inserts a WakuMessage into the DB // Inserts a WakuMessage into the DB
func (d *DBStore) Put(cursor *pb.Index, pubsubTopic string, message *pb.WakuMessage) error { func (d *DBStore) Put(env *protocol.Envelope) error {
stmt, err := d.db.Prepare("INSERT INTO message (id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version) VALUES (?, ?, ?, ?, ?, ?, ?)") stmt, err := d.db.Prepare("INSERT INTO message (id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version) VALUES (?, ?, ?, ?, ?, ?, ?)")
if err != nil { if err != nil {
return err return err
} }
_, err = stmt.Exec(cursor.Digest, cursor.ReceiverTime, message.Timestamp, message.ContentTopic, pubsubTopic, message.Payload, message.Version)
cursor := env.Index()
dbKey := NewDBKey(uint64(cursor.SenderTime), env.PubsubTopic(), env.Index().Digest)
_, err = stmt.Exec(dbKey.Bytes(), cursor.ReceiverTime, env.Message().Timestamp, env.Message().ContentTopic, env.PubsubTopic(), env.Message().Payload, env.Message().Version)
if err != nil { if err != nil {
return err return err
} }
@ -165,6 +228,124 @@ func (d *DBStore) Put(cursor *pb.Index, pubsubTopic string, message *pb.WakuMess
return nil return nil
} }
func (d *DBStore) Query(query *pb.HistoryQuery) ([]StoredMessage, error) {
start := time.Now()
defer func() {
elapsed := time.Since(start)
d.log.Info(fmt.Sprintf("Loading records from the DB took %s", elapsed))
}()
sqlQuery := `SELECT id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version
FROM message
%s
ORDER BY senderTimestamp %s, pubsubTopic, id
LIMIT ?`
var conditions []string
var parameters []interface{}
if query.PubsubTopic != "" {
conditions = append(conditions, "pubsubTopic = ?")
parameters = append(parameters, query.PubsubTopic)
}
if query.StartTime != 0 {
conditions = append(conditions, "id >= ?")
startTimeDBKey := NewDBKey(uint64(query.StartTime), "", []byte{})
parameters = append(parameters, startTimeDBKey.Bytes())
}
if query.EndTime != 0 {
conditions = append(conditions, "id <= ?")
endTimeDBKey := NewDBKey(uint64(query.EndTime), "", []byte{})
parameters = append(parameters, endTimeDBKey.Bytes())
}
if len(query.ContentFilters) != 0 {
var ctPlaceHolder []string
for _, ct := range query.ContentFilters {
if ct.ContentTopic != "" {
ctPlaceHolder = append(ctPlaceHolder, "?")
parameters = append(parameters, ct.ContentTopic)
}
}
conditions = append(conditions, "contentTopic IN ("+strings.Join(ctPlaceHolder, ", ")+")")
}
if query.PagingInfo.Cursor != nil {
var exists bool
cursorDBKey := NewDBKey(uint64(query.PagingInfo.Cursor.SenderTime), query.PagingInfo.Cursor.PubsubTopic, query.PagingInfo.Cursor.Digest)
err := d.db.QueryRow("SELECT EXISTS(SELECT 1 FROM message WHERE id = ?)",
cursorDBKey.Bytes(),
).Scan(&exists)
if err != nil {
return nil, err
}
if exists {
eqOp := ">"
if query.PagingInfo.Direction == pb.PagingInfo_BACKWARD {
eqOp = "<"
}
conditions = append(conditions, fmt.Sprintf("id %s ?", eqOp))
parameters = append(parameters, cursorDBKey.Bytes())
} else {
return nil, ErrInvalidCursor
}
}
conditionStr := ""
if len(conditions) != 0 {
conditionStr = "WHERE " + strings.Join(conditions, " AND ")
}
orderDirection := "ASC"
if query.PagingInfo.Direction == pb.PagingInfo_BACKWARD {
orderDirection = "DESC"
}
sqlQuery = fmt.Sprintf(sqlQuery, conditionStr, orderDirection)
stmt, err := d.db.Prepare(sqlQuery)
if err != nil {
return nil, err
}
defer stmt.Close()
parameters = append(parameters, query.PagingInfo.PageSize)
rows, err := stmt.Query(parameters...)
if err != nil {
return nil, err
}
var result []StoredMessage
for rows.Next() {
record, err := d.GetStoredMessage(rows)
if err != nil {
return nil, err
}
result = append(result, record)
}
defer rows.Close()
return result, nil
}
func (d *DBStore) MostRecentTimestamp() (int64, error) {
result := sql.NullInt64{}
err := d.db.QueryRow(`SELECT max(senderTimestamp) FROM message`).Scan(&result)
if err != nil && err != sql.ErrNoRows {
return 0, err
}
return result.Int64, nil
}
// Returns all the stored WakuMessages // Returns all the stored WakuMessages
func (d *DBStore) GetAll() ([]StoredMessage, error) { func (d *DBStore) GetAll() ([]StoredMessage, error) {
start := time.Now() start := time.Now()
@ -183,32 +364,10 @@ func (d *DBStore) GetAll() ([]StoredMessage, error) {
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var id []byte record, err := d.GetStoredMessage(rows)
var receiverTimestamp int64
var senderTimestamp int64
var contentTopic string
var payload []byte
var version uint32
var pubsubTopic string
err = rows.Scan(&id, &receiverTimestamp, &senderTimestamp, &contentTopic, &pubsubTopic, &payload, &version)
if err != nil { if err != nil {
d.log.Fatal("scanning next row", zap.Error(err)) return nil, err
} }
msg := new(pb.WakuMessage)
msg.ContentTopic = contentTopic
msg.Payload = payload
msg.Timestamp = senderTimestamp
msg.Version = version
record := StoredMessage{
ID: id,
PubsubTopic: pubsubTopic,
ReceiverTime: receiverTimestamp,
Message: msg,
}
result = append(result, record) result = append(result, record)
} }
@ -221,3 +380,34 @@ func (d *DBStore) GetAll() ([]StoredMessage, error) {
return result, nil return result, nil
} }
func (d *DBStore) GetStoredMessage(rows *sql.Rows) (StoredMessage, error) {
var id []byte
var receiverTimestamp int64
var senderTimestamp int64
var contentTopic string
var payload []byte
var version uint32
var pubsubTopic string
err := rows.Scan(&id, &receiverTimestamp, &senderTimestamp, &contentTopic, &pubsubTopic, &payload, &version)
if err != nil {
d.log.Error("scanning messages from db", zap.Error(err))
return StoredMessage{}, err
}
msg := new(pb.WakuMessage)
msg.ContentTopic = contentTopic
msg.Payload = payload
msg.Timestamp = senderTimestamp
msg.Version = version
record := StoredMessage{
ID: id,
PubsubTopic: pubsubTopic,
ReceiverTime: receiverTimestamp,
Message: msg,
}
return record, nil
}

View File

@ -7,7 +7,7 @@ import (
_ "github.com/mattn/go-sqlite3" // Blank import to register the sqlite3 driver _ "github.com/mattn/go-sqlite3" // Blank import to register the sqlite3 driver
"github.com/status-im/go-waku/tests" "github.com/status-im/go-waku/tests"
"github.com/status-im/go-waku/waku/v2/protocol/pb" "github.com/status-im/go-waku/waku/v2/protocol"
"github.com/status-im/go-waku/waku/v2/utils" "github.com/status-im/go-waku/waku/v2/utils"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/zap"
@ -22,14 +22,6 @@ func NewMock() *sql.DB {
return db return db
} }
func createIndex(digest []byte, receiverTime int64) *pb.Index {
return &pb.Index{
Digest: digest,
ReceiverTime: receiverTime,
SenderTime: 1.0,
}
}
func TestDbStore(t *testing.T) { func TestDbStore(t *testing.T) {
db := NewMock() db := NewMock()
option := WithDB(db) option := WithDB(db)
@ -40,11 +32,7 @@ func TestDbStore(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Empty(t, res) require.Empty(t, res)
err = store.Put( err = store.Put(protocol.NewEnvelope(tests.CreateWakuMessage("test", 1), utils.GetUnixEpoch(), "test"))
createIndex([]byte("digest"), 1),
"test",
tests.CreateWakuMessage("test", 1),
)
require.NoError(t, err) require.NoError(t, err)
res, err = store.GetAll() res, err = store.GetAll()
@ -59,18 +47,18 @@ func TestStoreRetention(t *testing.T) {
insertTime := time.Now() insertTime := time.Now()
_ = store.Put(createIndex([]byte{1}, insertTime.Add(-70*time.Second).UnixNano()), "test", tests.CreateWakuMessage("test", 1)) _ = store.Put(protocol.NewEnvelope(tests.CreateWakuMessage("test1", insertTime.Add(-70*time.Second).UnixNano()), insertTime.Add(-70*time.Second).UnixNano(), "test"))
_ = store.Put(createIndex([]byte{2}, insertTime.Add(-60*time.Second).UnixNano()), "test", tests.CreateWakuMessage("test", 2)) _ = store.Put(protocol.NewEnvelope(tests.CreateWakuMessage("test2", insertTime.Add(-60*time.Second).UnixNano()), insertTime.Add(-60*time.Second).UnixNano(), "test"))
_ = store.Put(createIndex([]byte{3}, insertTime.Add(-50*time.Second).UnixNano()), "test", tests.CreateWakuMessage("test", 3)) _ = store.Put(protocol.NewEnvelope(tests.CreateWakuMessage("test3", insertTime.Add(-50*time.Second).UnixNano()), insertTime.Add(-50*time.Second).UnixNano(), "test"))
_ = store.Put(createIndex([]byte{4}, insertTime.Add(-40*time.Second).UnixNano()), "test", tests.CreateWakuMessage("test", 4)) _ = store.Put(protocol.NewEnvelope(tests.CreateWakuMessage("test4", insertTime.Add(-40*time.Second).UnixNano()), insertTime.Add(-40*time.Second).UnixNano(), "test"))
_ = store.Put(createIndex([]byte{5}, insertTime.Add(-30*time.Second).UnixNano()), "test", tests.CreateWakuMessage("test", 5)) _ = store.Put(protocol.NewEnvelope(tests.CreateWakuMessage("test5", insertTime.Add(-30*time.Second).UnixNano()), insertTime.Add(-30*time.Second).UnixNano(), "test"))
dbResults, err := store.GetAll() dbResults, err := store.GetAll()
require.NoError(t, err) require.NoError(t, err)
require.Len(t, dbResults, 5) require.Len(t, dbResults, 5)
_ = store.Put(createIndex([]byte{6}, insertTime.Add(-20*time.Second).UnixNano()), "test", tests.CreateWakuMessage("test", 6)) _ = store.Put(protocol.NewEnvelope(tests.CreateWakuMessage("test6", insertTime.Add(-20*time.Second).UnixNano()), insertTime.Add(-20*time.Second).UnixNano(), "test"))
_ = store.Put(createIndex([]byte{7}, insertTime.Add(-10*time.Second).UnixNano()), "test", tests.CreateWakuMessage("test", 7)) _ = store.Put(protocol.NewEnvelope(tests.CreateWakuMessage("test7", insertTime.Add(-10*time.Second).UnixNano()), insertTime.Add(-10*time.Second).UnixNano(), "test"))
// This step simulates starting go-waku again from scratch // This step simulates starting go-waku again from scratch
@ -80,7 +68,7 @@ func TestStoreRetention(t *testing.T) {
dbResults, err = store.GetAll() dbResults, err = store.GetAll()
require.NoError(t, err) require.NoError(t, err)
require.Len(t, dbResults, 3) require.Len(t, dbResults, 3)
require.Equal(t, []byte{5}, dbResults[0].ID) require.Equal(t, "test5", dbResults[0].Message.ContentTopic)
require.Equal(t, []byte{6}, dbResults[1].ID) require.Equal(t, "test6", dbResults[1].Message.ContentTopic)
require.Equal(t, []byte{7}, dbResults[2].ID) require.Equal(t, "test7", dbResults[2].Message.ContentTopic)
} }

View File

@ -5,6 +5,8 @@ import (
"testing" "testing"
"github.com/status-im/go-waku/waku/v2/protocol" "github.com/status-im/go-waku/waku/v2/protocol"
"github.com/status-im/go-waku/waku/v2/protocol/pb"
"github.com/status-im/go-waku/waku/v2/utils"
) )
// Adapted from https://github.com/dustin/go-broadcast/commit/f664265f5a662fb4d1df7f3533b1e8d0e0277120 // Adapted from https://github.com/dustin/go-broadcast/commit/f664265f5a662fb4d1df7f3533b1e8d0e0277120
@ -28,7 +30,7 @@ func TestBroadcast(t *testing.T) {
} }
env := new(protocol.Envelope) env := protocol.NewEnvelope(&pb.WakuMessage{}, utils.GetUnixEpoch(), "abc")
b.Submit(env) b.Submit(env)
wg.Wait() wg.Wait()
@ -55,7 +57,7 @@ func TestBroadcastWait(t *testing.T) {
} }
env := new(protocol.Envelope) env := protocol.NewEnvelope(&pb.WakuMessage{}, utils.GetUnixEpoch(), "abc")
b.Submit(env) b.Submit(env)
wg.Wait() wg.Wait()

View File

@ -7,6 +7,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/status-im/go-waku/waku/persistence"
"github.com/status-im/go-waku/waku/persistence/sqlite"
"github.com/status-im/go-waku/waku/v2/utils"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -62,6 +65,11 @@ func TestConnectionStatusChanges(t *testing.T) {
err = node2.Start() err = node2.Start()
require.NoError(t, err) require.NoError(t, err)
db, err := sqlite.NewDB(":memory:")
require.NoError(t, err)
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db))
require.NoError(t, err)
// Node3: Relay + Store // Node3: Relay + Store
hostAddr3, err := net.ResolveTCPAddr("tcp", "0.0.0.0:0") hostAddr3, err := net.ResolveTCPAddr("tcp", "0.0.0.0:0")
require.NoError(t, err) require.NoError(t, err)
@ -69,6 +77,7 @@ func TestConnectionStatusChanges(t *testing.T) {
WithHostAddress(hostAddr3), WithHostAddress(hostAddr3),
WithWakuRelay(), WithWakuRelay(),
WithWakuStore(false, false), WithWakuStore(false, false),
WithMessageProvider(dbStore),
) )
require.NoError(t, err) require.NoError(t, err)
err = node3.Start() err = node3.Start()

View File

@ -3,6 +3,7 @@ package node
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"net" "net"
"time" "time"
@ -17,7 +18,6 @@ import (
basichost "github.com/libp2p/go-libp2p/p2p/host/basic" basichost "github.com/libp2p/go-libp2p/p2p/host/basic"
"github.com/libp2p/go-tcp-transport" "github.com/libp2p/go-tcp-transport"
"github.com/multiformats/go-multiaddr" "github.com/multiformats/go-multiaddr"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net" manet "github.com/multiformats/go-multiaddr/net"
rendezvous "github.com/status-im/go-waku-rendezvous" rendezvous "github.com/status-im/go-waku-rendezvous"
"github.com/status-im/go-waku/waku/v2/protocol/filter" "github.com/status-im/go-waku/waku/v2/protocol/filter"
@ -36,7 +36,7 @@ type WakuNodeParameters struct {
hostAddr *net.TCPAddr hostAddr *net.TCPAddr
dns4Domain string dns4Domain string
advertiseAddr *net.IP advertiseAddr *net.IP
multiAddr []ma.Multiaddr multiAddr []multiaddr.Multiaddr
addressFactory basichost.AddrsFactory addressFactory basichost.AddrsFactory
privKey *ecdsa.PrivateKey privKey *ecdsa.PrivateKey
libP2POpts []libp2p.Option libP2POpts []libp2p.Option
@ -97,7 +97,7 @@ var DefaultWakuNodeOptions = []WakuNodeOption{
} }
// MultiAddresses return the list of multiaddresses configured in the node // MultiAddresses return the list of multiaddresses configured in the node
func (w WakuNodeParameters) MultiAddresses() []ma.Multiaddr { func (w WakuNodeParameters) MultiAddresses() []multiaddr.Multiaddr {
return w.multiAddr return w.multiAddr
} }
@ -124,24 +124,24 @@ func WithDns4Domain(dns4Domain string) WakuNodeOption {
return func(params *WakuNodeParameters) error { return func(params *WakuNodeParameters) error {
params.dns4Domain = dns4Domain params.dns4Domain = dns4Domain
params.addressFactory = func([]ma.Multiaddr) []ma.Multiaddr { params.addressFactory = func([]multiaddr.Multiaddr) []multiaddr.Multiaddr {
var result []multiaddr.Multiaddr var result []multiaddr.Multiaddr
hostAddrMA, err := ma.NewMultiaddr("/dns4/" + params.dns4Domain) hostAddrMA, err := multiaddr.NewMultiaddr("/dns4/" + params.dns4Domain)
if err != nil { if err != nil {
panic(fmt.Sprintf("invalid dns4 address: %s", err.Error())) panic(fmt.Sprintf("invalid dns4 address: %s", err.Error()))
} }
tcp, _ := ma.NewMultiaddr(fmt.Sprintf("/tcp/%d", params.hostAddr.Port)) tcp, _ := multiaddr.NewMultiaddr(fmt.Sprintf("/tcp/%d", params.hostAddr.Port))
result = append(result, hostAddrMA.Encapsulate(tcp)) result = append(result, hostAddrMA.Encapsulate(tcp))
if params.enableWS || params.enableWSS { if params.enableWS || params.enableWSS {
if params.enableWSS { if params.enableWSS {
wss, _ := ma.NewMultiaddr(fmt.Sprintf("/tcp/%d/wss", params.wssPort)) wss, _ := multiaddr.NewMultiaddr(fmt.Sprintf("/tcp/%d/wss", params.wssPort))
result = append(result, hostAddrMA.Encapsulate(wss)) result = append(result, hostAddrMA.Encapsulate(wss))
} else { } else {
ws, _ := ma.NewMultiaddr(fmt.Sprintf("/tcp/%d/ws", params.wsPort)) ws, _ := multiaddr.NewMultiaddr(fmt.Sprintf("/tcp/%d/ws", params.wsPort))
result = append(result, hostAddrMA.Encapsulate(ws)) result = append(result, hostAddrMA.Encapsulate(ws))
} }
} }
@ -176,7 +176,7 @@ func WithAdvertiseAddress(address *net.TCPAddr) WakuNodeOption {
return err return err
} }
params.addressFactory = func([]ma.Multiaddr) []ma.Multiaddr { params.addressFactory = func([]multiaddr.Multiaddr) []multiaddr.Multiaddr {
var result []multiaddr.Multiaddr var result []multiaddr.Multiaddr
result = append(result, advertiseAddress) result = append(result, advertiseAddress)
if params.enableWS || params.enableWSS { if params.enableWS || params.enableWSS {
@ -195,7 +195,7 @@ func WithAdvertiseAddress(address *net.TCPAddr) WakuNodeOption {
} }
// WithMultiaddress is a WakuNodeOption that configures libp2p to listen on a list of multiaddresses // WithMultiaddress is a WakuNodeOption that configures libp2p to listen on a list of multiaddresses
func WithMultiaddress(addresses []ma.Multiaddr) WakuNodeOption { func WithMultiaddress(addresses []multiaddr.Multiaddr) WakuNodeOption {
return func(params *WakuNodeParameters) error { return func(params *WakuNodeParameters) error {
params.multiAddr = append(params.multiAddr, addresses...) params.multiAddr = append(params.multiAddr, addresses...)
return nil return nil
@ -334,6 +334,9 @@ func WithWakuStoreAndRetentionPolicy(shouldResume bool, maxDuration time.Duratio
// used to store and retrieve persisted messages // used to store and retrieve persisted messages
func WithMessageProvider(s store.MessageProvider) WakuNodeOption { func WithMessageProvider(s store.MessageProvider) WakuNodeOption {
return func(params *WakuNodeParameters) error { return func(params *WakuNodeParameters) error {
if s == nil {
return errors.New("message provider can't be nil")
}
params.messageProvider = s params.messageProvider = s
return nil return nil
} }

View File

@ -9,6 +9,7 @@ import (
"github.com/multiformats/go-multiaddr" "github.com/multiformats/go-multiaddr"
rendezvous "github.com/status-im/go-waku-rendezvous" rendezvous "github.com/status-im/go-waku-rendezvous"
"github.com/status-im/go-waku/tests" "github.com/status-im/go-waku/tests"
"github.com/status-im/go-waku/waku/persistence"
"github.com/status-im/go-waku/waku/v2/protocol/store" "github.com/status-im/go-waku/waku/v2/protocol/store"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -46,7 +47,7 @@ func TestWakuOptions(t *testing.T) {
WithDiscoveryV5(123, nil, false), WithDiscoveryV5(123, nil, false),
WithWakuStore(true, true), WithWakuStore(true, true),
WithWakuStoreAndRetentionPolicy(true, time.Hour, 100), WithWakuStoreAndRetentionPolicy(true, time.Hour, 100),
WithMessageProvider(nil), WithMessageProvider(&persistence.DBStore{}),
WithLightPush(), WithLightPush(),
WithKeepAlive(time.Hour), WithKeepAlive(time.Hour),
WithConnectionStatusChannel(connStatusChan), WithConnectionStatusChannel(connStatusChan),

View File

@ -1,27 +1,38 @@
package protocol package protocol
import "github.com/status-im/go-waku/waku/v2/protocol/pb" import (
"crypto/sha256"
"github.com/status-im/go-waku/waku/v2/protocol/pb"
)
// Envelope contains information about the pubsub topic of a WakuMessage // Envelope contains information about the pubsub topic of a WakuMessage
// and a hash used to identify a message based on the bytes of a WakuMessage // and a hash used to identify a message based on the bytes of a WakuMessage
// protobuffer // protobuffer
type Envelope struct { type Envelope struct {
msg *pb.WakuMessage msg *pb.WakuMessage
pubsubTopic string size int
size int hash []byte
hash []byte index *pb.Index
} }
// NewEnvelope creates a new Envelope that contains a WakuMessage // NewEnvelope creates a new Envelope that contains a WakuMessage
// It's used as a way to know to which Pubsub topic belongs a WakuMessage // It's used as a way to know to which Pubsub topic belongs a WakuMessage
// as well as generating a hash based on the bytes that compose the message // as well as generating a hash based on the bytes that compose the message
func NewEnvelope(msg *pb.WakuMessage, pubSubTopic string) *Envelope { func NewEnvelope(msg *pb.WakuMessage, receiverTime int64, pubSubTopic string) *Envelope {
data, _ := msg.Marshal() data, _ := msg.Marshal()
hash := sha256.Sum256(append([]byte(msg.ContentTopic), msg.Payload...))
return &Envelope{ return &Envelope{
msg: msg, msg: msg,
pubsubTopic: pubSubTopic, size: len(data),
size: len(data), hash: pb.Hash(data),
hash: pb.Hash(data), index: &pb.Index{
Digest: hash[:],
ReceiverTime: receiverTime,
SenderTime: msg.Timestamp,
PubsubTopic: pubSubTopic,
},
} }
} }
@ -32,7 +43,7 @@ func (e *Envelope) Message() *pb.WakuMessage {
// PubsubTopic returns the topic on which a WakuMessage was received // PubsubTopic returns the topic on which a WakuMessage was received
func (e *Envelope) PubsubTopic() string { func (e *Envelope) PubsubTopic() string {
return e.pubsubTopic return e.index.PubsubTopic
} }
// Hash returns a 32 byte hash calculated from the WakuMessage bytes // Hash returns a 32 byte hash calculated from the WakuMessage bytes
@ -44,3 +55,7 @@ func (e *Envelope) Hash() []byte {
func (e *Envelope) Size() int { func (e *Envelope) Size() int {
return e.size return e.size
} }
func (env *Envelope) Index() *pb.Index {
return env.index
}

View File

@ -4,12 +4,14 @@ import (
"testing" "testing"
"github.com/status-im/go-waku/waku/v2/protocol/pb" "github.com/status-im/go-waku/waku/v2/protocol/pb"
"github.com/status-im/go-waku/waku/v2/utils"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestEnvelope(t *testing.T) { func TestEnvelope(t *testing.T) {
e := NewEnvelope( e := NewEnvelope(
&pb.WakuMessage{ContentTopic: "ContentTopic"}, &pb.WakuMessage{ContentTopic: "ContentTopic"},
utils.GetUnixEpoch(),
"test", "test",
) )

View File

@ -5,6 +5,7 @@ import (
"github.com/status-im/go-waku/waku/v2/protocol" "github.com/status-im/go-waku/waku/v2/protocol"
"github.com/status-im/go-waku/waku/v2/protocol/pb" "github.com/status-im/go-waku/waku/v2/protocol/pb"
"github.com/status-im/go-waku/waku/v2/utils"
) )
type FilterMap struct { type FilterMap struct {
@ -79,7 +80,7 @@ func (fm *FilterMap) Notify(msg *pb.WakuMessage, requestId string) {
defer fm.RUnlock() defer fm.RUnlock()
for key, filter := range fm.items { for key, filter := range fm.items {
envelope := protocol.NewEnvelope(msg, filter.Topic) envelope := protocol.NewEnvelope(msg, utils.GetUnixEpoch(), filter.Topic)
// We do this because the key for the filter is set to the requestId received from the filter protocol. // We do this because the key for the filter is set to the requestId received from the filter protocol.
// This means we do not need to check the content filter explicitly as all MessagePushs already contain // This means we do not need to check the content filter explicitly as all MessagePushs already contain

View File

@ -112,7 +112,7 @@ func TestWakuLightPush(t *testing.T) {
// Checking that msg hash is correct // Checking that msg hash is correct
hash, err := client.PublishToTopic(ctx, msg2, testTopic) hash, err := client.PublishToTopic(ctx, msg2, testTopic)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, protocol.NewEnvelope(msg2, string(testTopic)).Hash(), hash) require.Equal(t, protocol.NewEnvelope(msg2, utils.GetUnixEpoch(), string(testTopic)).Hash(), hash)
wg.Wait() wg.Wait()
} }

View File

@ -20,6 +20,7 @@ import (
"github.com/status-im/go-waku/waku/v2/metrics" "github.com/status-im/go-waku/waku/v2/metrics"
waku_proto "github.com/status-im/go-waku/waku/v2/protocol" waku_proto "github.com/status-im/go-waku/waku/v2/protocol"
"github.com/status-im/go-waku/waku/v2/protocol/pb" "github.com/status-im/go-waku/waku/v2/protocol/pb"
"github.com/status-im/go-waku/waku/v2/utils"
) )
const WakuRelayID_v200 = protocol.ID("/vac/waku/relay/2.0.0") const WakuRelayID_v200 = protocol.ID("/vac/waku/relay/2.0.0")
@ -337,7 +338,7 @@ func (w *WakuRelay) subscribeToTopic(t string, subscription *Subscription, sub *
return return
} }
envelope := waku_proto.NewEnvelope(wakuMessage, string(t)) envelope := waku_proto.NewEnvelope(wakuMessage, utils.GetUnixEpoch(), string(t))
if w.bcaster != nil { if w.bcaster != nil {
w.bcaster.Submit(envelope) w.bcaster.Submit(envelope)

View File

@ -1,139 +0,0 @@
package store
import (
"errors"
"sync"
"time"
"github.com/status-im/go-waku/waku/v2/utils"
)
// MaxTimeVariance is the maximum duration in the future allowed for a message timestamp
const MaxTimeVariance = time.Duration(20) * time.Second
type MessageQueue struct {
sync.RWMutex
seen map[[32]byte]struct{}
messages []IndexedWakuMessage
maxMessages int
maxDuration time.Duration
quit chan struct{}
wg *sync.WaitGroup
}
var ErrDuplicatedMessage = errors.New("duplicated message")
var ErrFutureMessage = errors.New("message timestamp in the future")
var ErrTooOld = errors.New("message is too old")
func (self *MessageQueue) Push(msg IndexedWakuMessage) error {
self.Lock()
defer self.Unlock()
var k [32]byte
copy(k[:], msg.index.Digest)
if _, ok := self.seen[k]; ok {
return ErrDuplicatedMessage
}
// Ensure that messages don't "jump" to the front of the queue with future timestamps
if msg.index.SenderTime-msg.index.ReceiverTime > int64(MaxTimeVariance) {
return ErrFutureMessage
}
self.seen[k] = struct{}{}
self.messages = append(self.messages, msg)
if self.maxMessages != 0 && len(self.messages) > self.maxMessages {
if indexComparison(msg.index, self.messages[0].index) < 0 {
return ErrTooOld // :(
}
numToPop := len(self.messages) - self.maxMessages
self.messages = self.messages[numToPop:len(self.messages)]
}
return nil
}
func (self *MessageQueue) Messages() <-chan IndexedWakuMessage {
c := make(chan IndexedWakuMessage)
f := func() {
self.RLock()
defer self.RUnlock()
for _, value := range self.messages {
c <- value
}
close(c)
}
go f()
return c
}
func (self *MessageQueue) cleanOlderRecords() {
self.Lock()
defer self.Unlock()
// TODO: check if retention days was set
t := utils.GetUnixEpochFrom(time.Now().Add(-self.maxDuration))
var idx int
for i := 0; i < len(self.messages); i++ {
if self.messages[i].index.ReceiverTime >= t {
idx = i
break
}
}
self.messages = self.messages[idx:]
}
func (self *MessageQueue) checkForOlderRecords(d time.Duration) {
defer self.wg.Done()
ticker := time.NewTicker(d)
defer ticker.Stop()
for {
select {
case <-self.quit:
return
case <-ticker.C:
self.cleanOlderRecords()
}
}
}
func (self *MessageQueue) Length() int {
self.RLock()
defer self.RUnlock()
return len(self.messages)
}
func NewMessageQueue(maxMessages int, maxDuration time.Duration) *MessageQueue {
result := &MessageQueue{
maxMessages: maxMessages,
maxDuration: maxDuration,
seen: make(map[[32]byte]struct{}),
quit: make(chan struct{}),
wg: &sync.WaitGroup{},
}
if maxDuration != 0 {
result.wg.Add(1)
go result.checkForOlderRecords(10 * time.Second) // is 10s okay?
}
return result
}
func (self *MessageQueue) Stop() {
close(self.quit)
self.wg.Wait()
}

View File

@ -1,64 +0,0 @@
package store
import (
"testing"
"time"
"github.com/status-im/go-waku/tests"
"github.com/status-im/go-waku/waku/v2/protocol/pb"
"github.com/status-im/go-waku/waku/v2/utils"
"github.com/stretchr/testify/require"
)
func TestMessageQueue(t *testing.T) {
msg1 := tests.CreateWakuMessage("1", 1)
msg2 := tests.CreateWakuMessage("2", 2)
msg3 := tests.CreateWakuMessage("3", 3)
msg4 := tests.CreateWakuMessage("3", 3)
msg5 := tests.CreateWakuMessage("3", 3)
msgQ := NewMessageQueue(3, 1*time.Minute)
err := msgQ.Push(IndexedWakuMessage{msg: msg1, index: &pb.Index{Digest: []byte{1}, ReceiverTime: utils.GetUnixEpochFrom(time.Now().Add(-20 * time.Second))}, pubsubTopic: "test"})
require.NoError(t, err)
err = msgQ.Push(IndexedWakuMessage{msg: msg2, index: &pb.Index{Digest: []byte{2}, ReceiverTime: utils.GetUnixEpochFrom(time.Now().Add(-15 * time.Second))}, pubsubTopic: "test"})
require.NoError(t, err)
err = msgQ.Push(IndexedWakuMessage{msg: msg3, index: &pb.Index{Digest: []byte{3}, ReceiverTime: utils.GetUnixEpochFrom(time.Now().Add(-10 * time.Second))}, pubsubTopic: "test"})
require.NoError(t, err)
require.Equal(t, msgQ.Length(), 3)
err = msgQ.Push(IndexedWakuMessage{msg: msg4, index: &pb.Index{Digest: []byte{4}, ReceiverTime: utils.GetUnixEpochFrom(time.Now().Add(-3 * time.Second))}, pubsubTopic: "test"})
require.NoError(t, err)
require.Len(t, msgQ.messages, 3)
require.Equal(t, msg2.Payload, msgQ.messages[0].msg.Payload)
require.Equal(t, msg4.Payload, msgQ.messages[2].msg.Payload)
indexedMsg5 := IndexedWakuMessage{msg: msg5, index: &pb.Index{Digest: []byte{5}, ReceiverTime: utils.GetUnixEpochFrom(time.Now().Add(0 * time.Second))}, pubsubTopic: "test"}
err = msgQ.Push(indexedMsg5)
require.NoError(t, err)
require.Len(t, msgQ.messages, 3)
require.Equal(t, msg3.Payload, msgQ.messages[0].msg.Payload)
require.Equal(t, msg5.Payload, msgQ.messages[2].msg.Payload)
// Test duplication
err = msgQ.Push(indexedMsg5)
require.ErrorIs(t, err, ErrDuplicatedMessage)
require.Len(t, msgQ.messages, 3)
require.Equal(t, msg3.Payload, msgQ.messages[0].msg.Payload)
require.Equal(t, msg4.Payload, msgQ.messages[1].msg.Payload)
require.Equal(t, msg5.Payload, msgQ.messages[2].msg.Payload)
// Test retention
msgQ.maxDuration = 5 * time.Second
msgQ.cleanOlderRecords()
require.Len(t, msgQ.messages, 2)
require.Equal(t, msg4.Payload, msgQ.messages[0].msg.Payload)
require.Equal(t, msg5.Payload, msgQ.messages[1].msg.Payload)
}

View File

@ -0,0 +1,22 @@
package store
import (
"database/sql"
"testing"
"github.com/status-im/go-waku/waku/persistence"
"github.com/status-im/go-waku/waku/persistence/sqlite"
"github.com/status-im/go-waku/waku/v2/utils"
"github.com/stretchr/testify/require"
)
func MemoryDB(t *testing.T) *persistence.DBStore {
var db *sql.DB
db, err := sqlite.NewDB(":memory:")
require.NoError(t, err)
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db))
require.NoError(t, err)
return dbStore
}

View File

@ -15,20 +15,23 @@ import (
) )
func TestFindLastSeenMessage(t *testing.T) { func TestFindLastSeenMessage(t *testing.T) {
msg1 := protocol.NewEnvelope(tests.CreateWakuMessage("1", 1), "test") msg1 := protocol.NewEnvelope(tests.CreateWakuMessage("1", 1), utils.GetUnixEpoch(), "test")
msg2 := protocol.NewEnvelope(tests.CreateWakuMessage("2", 2), "test") msg2 := protocol.NewEnvelope(tests.CreateWakuMessage("2", 2), utils.GetUnixEpoch(), "test")
msg3 := protocol.NewEnvelope(tests.CreateWakuMessage("3", 3), "test") msg3 := protocol.NewEnvelope(tests.CreateWakuMessage("3", 3), utils.GetUnixEpoch(), "test")
msg4 := protocol.NewEnvelope(tests.CreateWakuMessage("4", 4), "test") msg4 := protocol.NewEnvelope(tests.CreateWakuMessage("4", 4), utils.GetUnixEpoch(), "test")
msg5 := protocol.NewEnvelope(tests.CreateWakuMessage("5", 5), "test") msg5 := protocol.NewEnvelope(tests.CreateWakuMessage("5", 5), utils.GetUnixEpoch(), "test")
s := NewWakuStore(nil, nil, nil, 0, 0, utils.Logger()) s := NewWakuStore(nil, nil, MemoryDB(t), 0, 0, utils.Logger())
_ = s.storeMessage(msg1) _ = s.storeMessage(msg1)
_ = s.storeMessage(msg3) _ = s.storeMessage(msg3)
_ = s.storeMessage(msg5) _ = s.storeMessage(msg5)
_ = s.storeMessage(msg2) _ = s.storeMessage(msg2)
_ = s.storeMessage(msg4) _ = s.storeMessage(msg4)
require.Equal(t, msg5.Message().Timestamp, s.findLastSeen()) lastSeen, err := s.findLastSeen()
require.NoError(t, err)
require.Equal(t, msg5.Message().Timestamp, lastSeen)
} }
func TestResume(t *testing.T) { func TestResume(t *testing.T) {
@ -38,7 +41,7 @@ func TestResume(t *testing.T) {
host1, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0")) host1, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0"))
require.NoError(t, err) require.NoError(t, err)
s1 := NewWakuStore(host1, nil, nil, 0, 0, utils.Logger()) s1 := NewWakuStore(host1, nil, MemoryDB(t), 0, 0, utils.Logger())
s1.Start(ctx) s1.Start(ctx)
defer s1.Stop() defer s1.Stop()
@ -49,14 +52,14 @@ func TestResume(t *testing.T) {
} }
wakuMessage := tests.CreateWakuMessage(contentTopic, int64(i+1)) wakuMessage := tests.CreateWakuMessage(contentTopic, int64(i+1))
msg := protocol.NewEnvelope(wakuMessage, "test") msg := protocol.NewEnvelope(wakuMessage, utils.GetUnixEpoch(), "test")
_ = s1.storeMessage(msg) _ = s1.storeMessage(msg)
} }
host2, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0")) host2, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0"))
require.NoError(t, err) require.NoError(t, err)
s2 := NewWakuStore(host2, nil, nil, 0, 0, utils.Logger()) s2 := NewWakuStore(host2, nil, MemoryDB(t), 0, 0, utils.Logger())
s2.Start(ctx) s2.Start(ctx)
defer s2.Stop() defer s2.Stop()
@ -68,7 +71,11 @@ func TestResume(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 10, msgCount) require.Equal(t, 10, msgCount)
require.Len(t, s2.messageQueue.messages, 10)
allMsgs, err := s2.msgProvider.GetAll()
require.NoError(t, err)
require.Len(t, allMsgs, 10)
// Test duplication // Test duplication
msgCount, err = s2.Resume(ctx, "test", []peer.ID{host1.ID()}) msgCount, err = s2.Resume(ctx, "test", []peer.ID{host1.ID()})
@ -88,18 +95,18 @@ func TestResumeWithListOfPeers(t *testing.T) {
host1, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0")) host1, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0"))
require.NoError(t, err) require.NoError(t, err)
s1 := NewWakuStore(host1, nil, nil, 0, 0, utils.Logger()) s1 := NewWakuStore(host1, nil, MemoryDB(t), 0, 0, utils.Logger())
s1.Start(ctx) s1.Start(ctx)
defer s1.Stop() defer s1.Stop()
msg0 := &pb.WakuMessage{Payload: []byte{1, 2, 3}, ContentTopic: "2", Version: 0, Timestamp: 0} msg0 := &pb.WakuMessage{Payload: []byte{1, 2, 3}, ContentTopic: "2", Version: 0, Timestamp: 0}
_ = s1.storeMessage(protocol.NewEnvelope(msg0, "test")) _ = s1.storeMessage(protocol.NewEnvelope(msg0, utils.GetUnixEpoch(), "test"))
host2, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0")) host2, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0"))
require.NoError(t, err) require.NoError(t, err)
s2 := NewWakuStore(host2, nil, nil, 0, 0, utils.Logger()) s2 := NewWakuStore(host2, nil, MemoryDB(t), 0, 0, utils.Logger())
s2.Start(ctx) s2.Start(ctx)
defer s2.Stop() defer s2.Stop()
@ -111,7 +118,10 @@ func TestResumeWithListOfPeers(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, msgCount) require.Equal(t, 1, msgCount)
require.Len(t, s2.messageQueue.messages, 1)
allMsgs, err := s2.msgProvider.GetAll()
require.NoError(t, err)
require.Len(t, allMsgs, 1)
} }
func TestResumeWithoutSpecifyingPeer(t *testing.T) { func TestResumeWithoutSpecifyingPeer(t *testing.T) {
@ -121,18 +131,18 @@ func TestResumeWithoutSpecifyingPeer(t *testing.T) {
host1, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0")) host1, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0"))
require.NoError(t, err) require.NoError(t, err)
s1 := NewWakuStore(host1, nil, nil, 0, 0, utils.Logger()) s1 := NewWakuStore(host1, nil, MemoryDB(t), 0, 0, utils.Logger())
s1.Start(ctx) s1.Start(ctx)
defer s1.Stop() defer s1.Stop()
msg0 := &pb.WakuMessage{Payload: []byte{1, 2, 3}, ContentTopic: "2", Version: 0, Timestamp: 0} msg0 := &pb.WakuMessage{Payload: []byte{1, 2, 3}, ContentTopic: "2", Version: 0, Timestamp: 0}
_ = s1.storeMessage(protocol.NewEnvelope(msg0, "test")) _ = s1.storeMessage(protocol.NewEnvelope(msg0, utils.GetUnixEpoch(), "test"))
host2, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0")) host2, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0"))
require.NoError(t, err) require.NoError(t, err)
s2 := NewWakuStore(host2, nil, nil, 0, 0, utils.Logger()) s2 := NewWakuStore(host2, nil, MemoryDB(t), 0, 0, utils.Logger())
s2.Start(ctx) s2.Start(ctx)
defer s2.Stop() defer s2.Stop()
@ -144,5 +154,8 @@ func TestResumeWithoutSpecifyingPeer(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, msgCount) require.Equal(t, 1, msgCount)
require.Len(t, s2.messageQueue.messages, 1)
allMsgs, err := s2.msgProvider.GetAll()
require.NoError(t, err)
require.Len(t, allMsgs, 1)
} }

View File

@ -1,13 +1,10 @@
package store package store
import ( import (
"bytes"
"context" "context"
"encoding/hex" "encoding/hex"
"errors" "errors"
"math" "math"
"sort"
"strings"
"sync" "sync"
"time" "time"
@ -49,157 +46,68 @@ var (
ErrFailedQuery = errors.New("failed to resolve the query") ErrFailedQuery = errors.New("failed to resolve the query")
) )
func minOf(vars ...int) int { func findMessages(query *pb.HistoryQuery, msgProvider MessageProvider) ([]*pb.WakuMessage, *pb.PagingInfo, error) {
min := vars[0] if query.PagingInfo == nil {
query.PagingInfo = &pb.PagingInfo{
for _, i := range vars { Direction: pb.PagingInfo_FORWARD,
if min > i {
min = i
} }
} }
return min if query.PagingInfo.PageSize == 0 || query.PagingInfo.PageSize > uint64(MaxPageSize) {
} query.PagingInfo.PageSize = MaxPageSize
func paginateWithIndex(list []IndexedWakuMessage, pinfo *pb.PagingInfo) (resMessages []IndexedWakuMessage, resPagingInfo *pb.PagingInfo) {
if pinfo == nil {
pinfo = new(pb.PagingInfo)
} }
// takes list, and performs paging based on pinfo queryResult, err := msgProvider.Query(query)
// returns the page i.e, a sequence of IndexedWakuMessage and the new paging info to be used for the next paging request if err != nil {
cursor := pinfo.Cursor return nil, nil, err
pageSize := pinfo.PageSize
dir := pinfo.Direction
if len(list) == 0 { // no pagination is needed for an empty list
return list, &pb.PagingInfo{PageSize: 0, Cursor: pinfo.Cursor, Direction: pinfo.Direction}
} }
if pageSize == 0 { if len(queryResult) == 0 { // no pagination is needed for an empty list
pageSize = MaxPageSize newPagingInfo := &pb.PagingInfo{PageSize: 0, Cursor: query.PagingInfo.Cursor, Direction: query.PagingInfo.Direction}
return nil, newPagingInfo, nil
} }
msgList := make([]IndexedWakuMessage, len(list)) lastMsgIdx := len(queryResult) - 1
_ = copy(msgList, list) // makes a copy of the list newCursor := protocol.NewEnvelope(queryResult[lastMsgIdx].Message, queryResult[lastMsgIdx].ReceiverTime, queryResult[lastMsgIdx].PubsubTopic).Index()
sort.Slice(msgList, func(i, j int) bool { // sorts msgList based on the custom comparison proc indexedWakuMessageComparison newPagingInfo := &pb.PagingInfo{PageSize: query.PagingInfo.PageSize, Cursor: newCursor, Direction: query.PagingInfo.Direction}
return indexedWakuMessageComparison(msgList[i], msgList[j]) == -1 if newPagingInfo.PageSize > uint64(len(queryResult)) {
}) newPagingInfo.PageSize = uint64(len(queryResult))
initQuery := false
if cursor == nil {
initQuery = true // an empty cursor means it is an initial query
switch dir {
case pb.PagingInfo_FORWARD:
cursor = list[0].index // perform paging from the beginning of the list
case pb.PagingInfo_BACKWARD:
cursor = list[len(list)-1].index // perform paging from the end of the list
}
} }
foundIndex := findIndex(msgList, cursor) resultMessages := make([]*pb.WakuMessage, len(queryResult))
if foundIndex == -1 { // the cursor is not valid for i := range queryResult {
return nil, &pb.PagingInfo{PageSize: 0, Cursor: pinfo.Cursor, Direction: pinfo.Direction} resultMessages[i] = queryResult[i].Message
} }
var retrievedPageSize, s, e int return resultMessages, newPagingInfo, nil
var newCursor *pb.Index // to be returned as part of the new paging info
switch dir {
case pb.PagingInfo_FORWARD: // forward pagination
remainingMessages := len(msgList) - foundIndex - 1
if initQuery {
remainingMessages = remainingMessages + 1
foundIndex = foundIndex - 1
}
// the number of queried messages cannot exceed the MaxPageSize and the total remaining messages i.e., msgList.len-foundIndex
retrievedPageSize = minOf(int(pageSize), MaxPageSize, remainingMessages)
s = foundIndex + 1 // non inclusive
e = foundIndex + retrievedPageSize
newCursor = msgList[e].index // the new cursor points to the end of the page
case pb.PagingInfo_BACKWARD: // backward pagination
remainingMessages := foundIndex
if initQuery {
remainingMessages = remainingMessages + 1
foundIndex = foundIndex + 1
}
// the number of queried messages cannot exceed the MaxPageSize and the total remaining messages i.e., foundIndex-0
retrievedPageSize = minOf(int(pageSize), MaxPageSize, remainingMessages)
s = foundIndex - retrievedPageSize
e = foundIndex - 1
newCursor = msgList[s].index // the new cursor points to the beginning of the page
}
// retrieve the messages
for i := s; i <= e; i++ {
resMessages = append(resMessages, msgList[i])
}
resPagingInfo = &pb.PagingInfo{PageSize: uint64(retrievedPageSize), Cursor: newCursor, Direction: pinfo.Direction}
return
}
func paginateWithoutIndex(list []IndexedWakuMessage, pinfo *pb.PagingInfo) (resMessages []*pb.WakuMessage, resPinfo *pb.PagingInfo) {
// takes list, and performs paging based on pinfo
// returns the page i.e, a sequence of WakuMessage and the new paging info to be used for the next paging request
indexedData, updatedPagingInfo := paginateWithIndex(list, pinfo)
for _, indexedMsg := range indexedData {
resMessages = append(resMessages, indexedMsg.msg)
}
resPinfo = updatedPagingInfo
return
} }
func (store *WakuStore) FindMessages(query *pb.HistoryQuery) *pb.HistoryResponse { func (store *WakuStore) FindMessages(query *pb.HistoryQuery) *pb.HistoryResponse {
result := new(pb.HistoryResponse) result := new(pb.HistoryResponse)
// data holds IndexedWakuMessage whose topics match the query
var data []IndexedWakuMessage
for indexedMsg := range store.messageQueue.Messages() {
// temporal filtering
// check whether the history query contains a time filter
if query.StartTime != 0 && query.EndTime != 0 {
if indexedMsg.msg.Timestamp < query.StartTime || indexedMsg.msg.Timestamp > query.EndTime {
continue
}
}
// filter based on content filters messages, newPagingInfo, err := findMessages(query, store.msgProvider)
// an empty list of contentFilters means no content filter is requested if err != nil {
if len(query.ContentFilters) != 0 { if err == persistence.ErrInvalidCursor {
match := false result.Error = pb.HistoryResponse_INVALID_CURSOR
for _, cf := range query.ContentFilters { } else {
if cf.ContentTopic == indexedMsg.msg.ContentTopic { // TODO: return error in pb.HistoryResponse
match = true store.log.Error("obtaining messages from db", zap.Error(err))
break
}
}
if !match {
continue
}
} }
// filter based on pubsub topic
// an empty pubsub topic means no pubsub topic filter is requested
if query.PubsubTopic != "" {
if indexedMsg.pubsubTopic != query.PubsubTopic {
continue
}
}
// Some criteria matched
data = append(data, indexedMsg)
} }
result.Messages, result.PagingInfo = paginateWithoutIndex(data, query.PagingInfo) result.Messages = messages
result.PagingInfo = newPagingInfo
return result return result
} }
type MessageProvider interface { type MessageProvider interface {
GetAll() ([]persistence.StoredMessage, error) GetAll() ([]persistence.StoredMessage, error)
Put(cursor *pb.Index, pubsubTopic string, message *pb.WakuMessage) error Query(query *pb.HistoryQuery) ([]persistence.StoredMessage, error)
Put(env *protocol.Envelope) error
MostRecentTimestamp() (int64, error)
Stop() Stop()
} }
type Query struct { type Query struct {
Topic string Topic string
ContentTopics []string ContentTopics []string
@ -228,12 +136,6 @@ func (r *Result) Query() *pb.HistoryQuery {
return r.query return r.query
} }
type IndexedWakuMessage struct {
msg *pb.WakuMessage
index *pb.Index
pubsubTopic string
}
type WakuStore struct { type WakuStore struct {
ctx context.Context ctx context.Context
MsgC chan *protocol.Envelope MsgC chan *protocol.Envelope
@ -243,10 +145,9 @@ type WakuStore struct {
started bool started bool
messageQueue *MessageQueue msgProvider MessageProvider
msgProvider MessageProvider h host.Host
h host.Host swap *swap.WakuSwap
swap *swap.WakuSwap
} }
type Store interface { type Store interface {
@ -266,7 +167,6 @@ func NewWakuStore(host host.Host, swap *swap.WakuSwap, p MessageProvider, maxNum
wakuStore.swap = swap wakuStore.swap = swap
wakuStore.wg = &sync.WaitGroup{} wakuStore.wg = &sync.WaitGroup{}
wakuStore.log = log.Named("store") wakuStore.log = log.Named("store")
wakuStore.messageQueue = NewMessageQueue(maxNumberOfMessages, maxRetentionDuration)
return wakuStore return wakuStore
} }
@ -281,6 +181,11 @@ func (store *WakuStore) Start(ctx context.Context) {
return return
} }
if store.msgProvider == nil {
store.log.Info("Store protocol started (no message provider)")
return
}
store.started = true store.started = true
store.ctx = ctx store.ctx = ctx
store.MsgC = make(chan *protocol.Envelope, 1024) store.MsgC = make(chan *protocol.Envelope, 1024)
@ -290,78 +195,17 @@ func (store *WakuStore) Start(ctx context.Context) {
store.wg.Add(1) store.wg.Add(1)
go store.storeIncomingMessages(ctx) go store.storeIncomingMessages(ctx)
if store.msgProvider == nil {
store.log.Info("Store protocol started (no message provider)")
return
}
store.fetchDBRecords(ctx)
store.log.Info("Store protocol started") store.log.Info("Store protocol started")
} }
func (store *WakuStore) fetchDBRecords(ctx context.Context) {
if store.msgProvider == nil {
return
}
start := time.Now()
defer func() {
elapsed := time.Since(start)
store.log.Info("Store initialization complete",
zap.Duration("duration", elapsed),
zap.Int("messages", store.messageQueue.Length()))
}()
storedMessages, err := (store.msgProvider).GetAll()
if err != nil {
store.log.Error("loading DBProvider messages", zap.Error(err))
metrics.RecordStoreError(ctx, "store_load_failure")
return
}
for _, storedMessage := range storedMessages {
idx := &pb.Index{
Digest: storedMessage.ID,
ReceiverTime: storedMessage.ReceiverTime,
}
_ = store.addToMessageQueue(storedMessage.PubsubTopic, idx, storedMessage.Message)
}
metrics.RecordMessage(ctx, "stored", store.messageQueue.Length())
}
func (store *WakuStore) addToMessageQueue(pubsubTopic string, idx *pb.Index, msg *pb.WakuMessage) error {
return store.messageQueue.Push(IndexedWakuMessage{msg: msg, index: idx, pubsubTopic: pubsubTopic})
}
func (store *WakuStore) storeMessage(env *protocol.Envelope) error { func (store *WakuStore) storeMessage(env *protocol.Envelope) error {
index, err := computeIndex(env) err := store.msgProvider.Put(env)
if err != nil {
store.log.Error("creating message index", zap.Error(err))
return err
}
err = store.addToMessageQueue(env.PubsubTopic(), index, env.Message())
if err == ErrDuplicatedMessage {
return err
}
if store.msgProvider == nil {
metrics.RecordMessage(store.ctx, "stored", store.messageQueue.Length())
return err
}
// TODO: Move this to a separate go routine if DB writes becomes a bottleneck
err = store.msgProvider.Put(index, env.PubsubTopic(), env.Message()) // Should the index be stored?
if err != nil { if err != nil {
store.log.Error("storing message", zap.Error(err)) store.log.Error("storing message", zap.Error(err))
metrics.RecordStoreError(store.ctx, "store_failure") metrics.RecordStoreError(store.ctx, "store_failure")
return err return err
} }
metrics.RecordMessage(store.ctx, "stored", store.messageQueue.Length())
return nil return nil
} }
@ -406,72 +250,6 @@ func (store *WakuStore) onRequest(s network.Stream) {
} }
} }
func computeIndex(env *protocol.Envelope) (*pb.Index, error) {
return &pb.Index{
Digest: env.Hash(),
ReceiverTime: utils.GetUnixEpoch(),
SenderTime: env.Message().Timestamp,
PubsubTopic: env.PubsubTopic(),
}, nil
}
func indexComparison(x, y *pb.Index) int {
// compares x and y
// returns 0 if they are equal
// returns -1 if x < y
// returns 1 if x > y
var timecmp int = 0
if x.SenderTime != 0 && y.SenderTime != 0 {
if x.SenderTime > y.SenderTime {
timecmp = 1
} else if x.SenderTime < y.SenderTime {
timecmp = -1
}
}
if timecmp != 0 {
return timecmp // timestamp has a higher priority for comparison
}
digestcm := bytes.Compare(x.Digest, y.Digest)
if digestcm != 0 {
return digestcm
}
pubsubTopicCmp := strings.Compare(x.PubsubTopic, y.PubsubTopic)
if pubsubTopicCmp != 0 {
return pubsubTopicCmp
}
// receiverTimestamp (a fallback only if senderTimestamp unset on either side, and all other fields unequal)
if x.ReceiverTime > y.ReceiverTime {
timecmp = 1
} else if x.ReceiverTime < y.ReceiverTime {
timecmp = -1
}
return timecmp
}
func indexedWakuMessageComparison(x, y IndexedWakuMessage) int {
// compares x and y
// returns 0 if they are equal
// returns -1 if x < y
// returns 1 if x > y
return indexComparison(x.index, y.index)
}
func findIndex(msgList []IndexedWakuMessage, index *pb.Index) int {
// returns the position of an IndexedWakuMessage in msgList whose index value matches the given index
// returns -1 if no match is found
for i, indexedWakuMessage := range msgList {
if bytes.Equal(indexedWakuMessage.index.Digest, index.Digest) && indexedWakuMessage.index.SenderTime == index.SenderTime && indexedWakuMessage.index.PubsubTopic == index.PubsubTopic {
return i
}
}
return -1
}
type HistoryRequestParameters struct { type HistoryRequestParameters struct {
selectedPeer peer.ID selectedPeer peer.ID
requestId []byte requestId []byte
@ -591,7 +369,7 @@ func (store *WakuStore) queryFrom(ctx context.Context, q *pb.HistoryQuery, selec
return nil, err return nil, err
} }
metrics.RecordMessage(ctx, "retrieved", store.messageQueue.Length()) metrics.RecordMessage(ctx, "retrieved", len(historyResponseRPC.Response.Messages))
return historyResponseRPC.Response, nil return historyResponseRPC.Response, nil
} }
@ -661,23 +439,23 @@ func (store *WakuStore) Query(ctx context.Context, query Query, opts ...HistoryR
// specify the cursor and pagination order and max number of results // specify the cursor and pagination order and max number of results
func (store *WakuStore) Next(ctx context.Context, r *Result) (*Result, error) { func (store *WakuStore) Next(ctx context.Context, r *Result) (*Result, error) {
q := &pb.HistoryQuery{ q := &pb.HistoryQuery{
PubsubTopic: r.query.PubsubTopic, PubsubTopic: r.Query().PubsubTopic,
ContentFilters: r.query.ContentFilters, ContentFilters: r.Query().ContentFilters,
StartTime: r.query.StartTime, StartTime: r.Query().StartTime,
EndTime: r.query.EndTime, EndTime: r.Query().EndTime,
PagingInfo: &pb.PagingInfo{ PagingInfo: &pb.PagingInfo{
PageSize: r.query.PagingInfo.PageSize, PageSize: r.Query().PagingInfo.PageSize,
Direction: r.query.PagingInfo.Direction, Direction: r.Query().PagingInfo.Direction,
Cursor: &pb.Index{ Cursor: &pb.Index{
Digest: r.cursor.Digest, Digest: r.Cursor().Digest,
ReceiverTime: r.cursor.ReceiverTime, ReceiverTime: r.Cursor().ReceiverTime,
SenderTime: r.cursor.SenderTime, SenderTime: r.Cursor().SenderTime,
PubsubTopic: r.cursor.PubsubTopic, PubsubTopic: r.Cursor().PubsubTopic,
}, },
}, },
} }
response, err := store.queryFrom(ctx, q, r.peerId, protocol.GenerateRequestId()) response, err := store.queryFrom(ctx, q, r.PeerID(), protocol.GenerateRequestId())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -690,7 +468,7 @@ func (store *WakuStore) Next(ctx context.Context, r *Result) (*Result, error) {
Messages: response.Messages, Messages: response.Messages,
cursor: response.PagingInfo.Cursor, cursor: response.PagingInfo.Cursor,
query: q, query: q,
peerId: r.peerId, peerId: r.PeerID(),
}, nil }, nil
} }
@ -732,14 +510,8 @@ func (store *WakuStore) queryLoop(ctx context.Context, query *pb.HistoryQuery, c
return nil, ErrFailedQuery return nil, ErrFailedQuery
} }
func (store *WakuStore) findLastSeen() int64 { func (store *WakuStore) findLastSeen() (int64, error) {
var lastSeenTime int64 = 0 return store.msgProvider.MostRecentTimestamp()
for imsg := range store.messageQueue.Messages() {
if imsg.msg.Timestamp > lastSeenTime {
lastSeenTime = imsg.msg.Timestamp
}
}
return lastSeenTime
} }
func max(x, y int64) int64 { func max(x, y int64) int64 {
@ -763,7 +535,10 @@ func (store *WakuStore) Resume(ctx context.Context, pubsubTopic string, peerList
} }
currentTime := utils.GetUnixEpoch() currentTime := utils.GetUnixEpoch()
lastSeenTime := store.findLastSeen() lastSeenTime, err := store.findLastSeen()
if err != nil {
return 0, err
}
var offset int64 = int64(20 * time.Nanosecond) var offset int64 = int64(20 * time.Nanosecond)
currentTime = currentTime + offset currentTime = currentTime + offset
@ -797,7 +572,7 @@ func (store *WakuStore) Resume(ctx context.Context, pubsubTopic string, peerList
msgCount := 0 msgCount := 0
for _, msg := range messages { for _, msg := range messages {
if err = store.storeMessage(protocol.NewEnvelope(msg, pubsubTopic)); err == nil { if err = store.storeMessage(protocol.NewEnvelope(msg, utils.GetUnixEpoch(), pubsubTopic)); err == nil {
msgCount++ msgCount++
} }
} }

View File

@ -1,9 +1,9 @@
package store package store
import ( import (
"sort"
"testing" "testing"
"github.com/status-im/go-waku/waku/persistence"
"github.com/status-im/go-waku/waku/v2/protocol" "github.com/status-im/go-waku/waku/v2/protocol"
"github.com/status-im/go-waku/waku/v2/protocol/pb" "github.com/status-im/go-waku/waku/v2/protocol/pb"
"github.com/status-im/go-waku/waku/v2/utils" "github.com/status-im/go-waku/waku/v2/utils"
@ -16,8 +16,7 @@ func TestIndexComputation(t *testing.T) {
Timestamp: utils.GetUnixEpoch(), Timestamp: utils.GetUnixEpoch(),
} }
idx, err := computeIndex(protocol.NewEnvelope(msg, "test")) idx := protocol.NewEnvelope(msg, utils.GetUnixEpoch(), "test").Index()
require.NoError(t, err)
require.NotZero(t, idx.ReceiverTime) require.NotZero(t, idx.ReceiverTime)
require.Equal(t, msg.Timestamp, idx.SenderTime) require.Equal(t, msg.Timestamp, idx.SenderTime)
require.NotZero(t, idx.Digest) require.NotZero(t, idx.Digest)
@ -28,268 +27,217 @@ func TestIndexComputation(t *testing.T) {
Timestamp: 123, Timestamp: 123,
ContentTopic: "/waku/2/default-content/proto", ContentTopic: "/waku/2/default-content/proto",
} }
idx1, err := computeIndex(protocol.NewEnvelope(msg1, "test")) idx1 := protocol.NewEnvelope(msg1, utils.GetUnixEpoch(), "test").Index()
require.NoError(t, err)
msg2 := &pb.WakuMessage{ msg2 := &pb.WakuMessage{
Payload: []byte{1, 2, 3}, Payload: []byte{1, 2, 3},
Timestamp: 123, Timestamp: 123,
ContentTopic: "/waku/2/default-content/proto", ContentTopic: "/waku/2/default-content/proto",
} }
idx2, err := computeIndex(protocol.NewEnvelope(msg2, "test")) idx2 := protocol.NewEnvelope(msg2, utils.GetUnixEpoch(), "test").Index()
require.NoError(t, err)
require.Equal(t, idx1.Digest, idx2.Digest) require.Equal(t, idx1.Digest, idx2.Digest)
} }
func TestIndexComparison(t *testing.T) { func createSampleList(s int) []*protocol.Envelope {
var result []*protocol.Envelope
index1 := &pb.Index{
ReceiverTime: 2,
SenderTime: 1,
Digest: []byte{1},
PubsubTopic: "abc",
}
index2 := &pb.Index{
ReceiverTime: 2,
SenderTime: 1,
Digest: []byte{2},
PubsubTopic: "abc",
}
index3 := &pb.Index{
ReceiverTime: 1,
SenderTime: 2,
Digest: []byte{3},
PubsubTopic: "abc",
}
index4 := &pb.Index{
ReceiverTime: 1,
SenderTime: 2,
Digest: []byte{3},
PubsubTopic: "def",
}
iwm1 := IndexedWakuMessage{index: index1}
iwm2 := IndexedWakuMessage{index: index2}
iwm3 := IndexedWakuMessage{index: index3}
iwm4 := IndexedWakuMessage{index: index4}
require.Equal(t, 0, indexComparison(index1, index1))
require.Equal(t, -1, indexComparison(index1, index2))
require.Equal(t, 1, indexComparison(index2, index1))
require.Equal(t, -1, indexComparison(index1, index3))
require.Equal(t, 1, indexComparison(index3, index1))
require.Equal(t, -1, indexComparison(index3, index4))
require.Equal(t, 0, indexedWakuMessageComparison(iwm1, iwm1))
require.Equal(t, -1, indexedWakuMessageComparison(iwm1, iwm2))
require.Equal(t, 1, indexedWakuMessageComparison(iwm2, iwm1))
require.Equal(t, -1, indexedWakuMessageComparison(iwm1, iwm3))
require.Equal(t, 1, indexedWakuMessageComparison(iwm3, iwm1))
require.Equal(t, -1, indexedWakuMessageComparison(iwm3, iwm4))
sortingList := []IndexedWakuMessage{iwm3, iwm1, iwm2, iwm4}
sort.Slice(sortingList, func(i, j int) bool {
return indexedWakuMessageComparison(sortingList[i], sortingList[j]) == -1
})
require.Equal(t, iwm1, sortingList[0])
require.Equal(t, iwm2, sortingList[1])
require.Equal(t, iwm3, sortingList[2])
require.Equal(t, iwm4, sortingList[3])
}
func createSampleList(s int) []IndexedWakuMessage {
var result []IndexedWakuMessage
for i := 0; i < s; i++ { for i := 0; i < s; i++ {
result = append(result, IndexedWakuMessage{ msg :=
msg: &pb.WakuMessage{ &pb.WakuMessage{
Payload: []byte{byte(i)}, Payload: []byte{byte(i)},
}, Timestamp: int64(i),
index: &pb.Index{ }
ReceiverTime: int64(i), result = append(result, protocol.NewEnvelope(msg, int64(i), "abc"))
SenderTime: int64(i),
Digest: []byte{1},
PubsubTopic: "abc",
},
})
} }
return result return result
} }
func TestFindIndex(t *testing.T) {
msgList := createSampleList(10)
require.Equal(t, 3, findIndex(msgList, msgList[3].index))
require.Equal(t, -1, findIndex(msgList, &pb.Index{}))
}
func TestForwardPagination(t *testing.T) { func TestForwardPagination(t *testing.T) {
msgList := createSampleList(10) msgList := createSampleList(10)
db := MemoryDB(t)
for _, m := range msgList {
err := db.Put(m)
require.NoError(t, err)
}
// test for a normal pagination // test for a normal pagination
pagingInfo := &pb.PagingInfo{PageSize: 2, Cursor: msgList[3].index, Direction: pb.PagingInfo_FORWARD} pagingInfo := &pb.PagingInfo{PageSize: 2, Cursor: msgList[3].Index(), Direction: pb.PagingInfo_FORWARD}
messages, newPagingInfo := paginateWithoutIndex(msgList, pagingInfo) messages, newPagingInfo, err := findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.Len(t, messages, 2) require.Len(t, messages, 2)
require.Equal(t, []*pb.WakuMessage{msgList[4].msg, msgList[5].msg}, messages) require.Equal(t, []*pb.WakuMessage{msgList[4].Message(), msgList[5].Message()}, messages)
require.Equal(t, msgList[5].index, newPagingInfo.Cursor) require.Equal(t, msgList[5].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, pagingInfo.PageSize, newPagingInfo.PageSize) require.Equal(t, pagingInfo.PageSize, newPagingInfo.PageSize)
// test for an initial pagination request with an empty cursor // test for an initial pagination request with an empty cursor
pagingInfo = &pb.PagingInfo{PageSize: 2, Direction: pb.PagingInfo_FORWARD} pagingInfo = &pb.PagingInfo{PageSize: 2, Direction: pb.PagingInfo_FORWARD}
messages, newPagingInfo = paginateWithoutIndex(msgList, pagingInfo) messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.Len(t, messages, 2) require.Len(t, messages, 2)
require.Equal(t, []*pb.WakuMessage{msgList[0].msg, msgList[1].msg}, messages) require.Equal(t, []*pb.WakuMessage{msgList[0].Message(), msgList[1].Message()}, messages)
require.Equal(t, msgList[1].index, newPagingInfo.Cursor) require.Equal(t, msgList[1].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, pagingInfo.PageSize, newPagingInfo.PageSize) require.Equal(t, pagingInfo.PageSize, newPagingInfo.PageSize)
// test for an initial pagination request with an empty cursor to fetch the entire history // test for an initial pagination request with an empty cursor to fetch the entire history
pagingInfo = &pb.PagingInfo{PageSize: 13, Direction: pb.PagingInfo_FORWARD} pagingInfo = &pb.PagingInfo{PageSize: 13, Direction: pb.PagingInfo_FORWARD}
messages, newPagingInfo = paginateWithoutIndex(msgList, pagingInfo) messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.Len(t, messages, 10) require.Len(t, messages, 10)
require.Equal(t, msgList[9].msg, messages[9]) require.Equal(t, msgList[9].Message(), messages[9])
require.Equal(t, msgList[9].index, newPagingInfo.Cursor) require.Equal(t, msgList[9].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, uint64(10), newPagingInfo.PageSize) require.Equal(t, uint64(10), newPagingInfo.PageSize)
// test for an empty msgList // test for an empty msgList
pagingInfo = &pb.PagingInfo{PageSize: 2, Direction: pb.PagingInfo_FORWARD} pagingInfo = &pb.PagingInfo{PageSize: 2, Direction: pb.PagingInfo_FORWARD}
var msgList2 []IndexedWakuMessage messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, MemoryDB(t))
messages, newPagingInfo = paginateWithoutIndex(msgList2, pagingInfo) require.NoError(t, err)
require.Len(t, messages, 0) require.Len(t, messages, 0)
require.Equal(t, pagingInfo.Cursor, newPagingInfo.Cursor) require.Equal(t, pagingInfo.Cursor, newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, uint64(0), newPagingInfo.PageSize) require.Equal(t, uint64(0), newPagingInfo.PageSize)
// test for a page size larger than the remaining messages // test for a page size larger than the remaining messages
pagingInfo = &pb.PagingInfo{PageSize: 10, Cursor: msgList[3].index, Direction: pb.PagingInfo_FORWARD} pagingInfo = &pb.PagingInfo{PageSize: 10, Cursor: msgList[3].Index(), Direction: pb.PagingInfo_FORWARD}
messages, newPagingInfo = paginateWithoutIndex(msgList, pagingInfo) messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.Len(t, messages, 6) require.Len(t, messages, 6)
require.Equal(t, []*pb.WakuMessage{msgList[4].msg, msgList[5].msg, msgList[6].msg, msgList[7].msg, msgList[8].msg, msgList[9].msg}, messages) require.Equal(t, []*pb.WakuMessage{msgList[4].Message(), msgList[5].Message(), msgList[6].Message(), msgList[7].Message(), msgList[8].Message(), msgList[9].Message()}, messages)
require.Equal(t, msgList[9].index, newPagingInfo.Cursor) require.Equal(t, msgList[9].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, uint64(6), newPagingInfo.PageSize) require.Equal(t, uint64(6), newPagingInfo.PageSize)
// test for a page size larger than the maximum allowed page size // test for a page size larger than the maximum allowed page size
pagingInfo = &pb.PagingInfo{PageSize: MaxPageSize + 1, Cursor: msgList[3].index, Direction: pb.PagingInfo_FORWARD} pagingInfo = &pb.PagingInfo{PageSize: MaxPageSize + 1, Cursor: msgList[3].Index(), Direction: pb.PagingInfo_FORWARD}
messages, newPagingInfo = paginateWithoutIndex(msgList, pagingInfo) messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.True(t, len(messages) <= MaxPageSize) require.True(t, len(messages) <= MaxPageSize)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.True(t, newPagingInfo.PageSize <= MaxPageSize) require.True(t, newPagingInfo.PageSize <= MaxPageSize)
// test for a cursor pointing to the end of the message list // test for a cursor pointing to the end of the message list
pagingInfo = &pb.PagingInfo{PageSize: 10, Cursor: msgList[9].index, Direction: pb.PagingInfo_FORWARD} pagingInfo = &pb.PagingInfo{PageSize: 10, Cursor: msgList[9].Index(), Direction: pb.PagingInfo_FORWARD}
messages, newPagingInfo = paginateWithoutIndex(msgList, pagingInfo) messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.Len(t, messages, 0) require.Len(t, messages, 0)
require.Equal(t, msgList[9].index, newPagingInfo.Cursor) require.Equal(t, msgList[9].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, uint64(0), newPagingInfo.PageSize) require.Equal(t, uint64(0), newPagingInfo.PageSize)
// test for an invalid cursor // test for an invalid cursor
invalidIndex, err := computeIndex(protocol.NewEnvelope(&pb.WakuMessage{Payload: []byte{255, 255, 255}}, "test")) invalidIndex := protocol.NewEnvelope(&pb.WakuMessage{Payload: []byte{255, 255, 255}}, utils.GetUnixEpoch(), "test").Index()
require.NoError(t, err)
pagingInfo = &pb.PagingInfo{PageSize: 10, Cursor: invalidIndex, Direction: pb.PagingInfo_FORWARD} pagingInfo = &pb.PagingInfo{PageSize: 10, Cursor: invalidIndex, Direction: pb.PagingInfo_FORWARD}
messages, newPagingInfo = paginateWithoutIndex(msgList, pagingInfo) _, _, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.ErrorIs(t, err, persistence.ErrInvalidCursor)
require.Len(t, messages, 0) require.Len(t, messages, 0)
require.Equal(t, pagingInfo.Cursor, newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, uint64(0), newPagingInfo.PageSize)
// test initial paging query over a message list with one message // test initial paging query over a message list with one message
singleItemMsgList := msgList[0:1] singleItemDB := MemoryDB(t)
err = singleItemDB.Put(msgList[0])
require.NoError(t, err)
pagingInfo = &pb.PagingInfo{PageSize: 10, Direction: pb.PagingInfo_FORWARD} pagingInfo = &pb.PagingInfo{PageSize: 10, Direction: pb.PagingInfo_FORWARD}
messages, newPagingInfo = paginateWithoutIndex(singleItemMsgList, pagingInfo) messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, singleItemDB)
require.NoError(t, err)
require.Len(t, messages, 1) require.Len(t, messages, 1)
require.Equal(t, msgList[0].index, newPagingInfo.Cursor) require.Equal(t, msgList[0].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, uint64(1), newPagingInfo.PageSize) require.Equal(t, uint64(1), newPagingInfo.PageSize)
} }
func TestBackwardPagination(t *testing.T) { func TestBackwardPagination(t *testing.T) {
msgList := createSampleList(10) msgList := createSampleList(10)
db := MemoryDB(t)
for _, m := range msgList {
err := db.Put(m)
require.NoError(t, err)
}
// test for a normal pagination // test for a normal pagination
pagingInfo := &pb.PagingInfo{PageSize: 2, Cursor: msgList[3].index, Direction: pb.PagingInfo_BACKWARD} pagingInfo := &pb.PagingInfo{PageSize: 2, Cursor: msgList[3].Index(), Direction: pb.PagingInfo_BACKWARD}
messages, newPagingInfo := paginateWithoutIndex(msgList, pagingInfo) messages, newPagingInfo, err := findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.Len(t, messages, 2) require.Len(t, messages, 2)
require.Equal(t, []*pb.WakuMessage{msgList[1].msg, msgList[2].msg}, messages)
require.Equal(t, msgList[1].index, newPagingInfo.Cursor) require.Equal(t, []*pb.WakuMessage{msgList[2].Message(), msgList[1].Message()}, messages)
require.Equal(t, msgList[1].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, pagingInfo.PageSize, newPagingInfo.PageSize) require.Equal(t, pagingInfo.PageSize, newPagingInfo.PageSize)
// test for an initial pagination request with an empty cursor // test for an initial pagination request with an empty cursor
pagingInfo = &pb.PagingInfo{PageSize: 2, Direction: pb.PagingInfo_BACKWARD} pagingInfo = &pb.PagingInfo{PageSize: 2, Direction: pb.PagingInfo_BACKWARD}
messages, newPagingInfo = paginateWithoutIndex(msgList, pagingInfo) messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.Len(t, messages, 2) require.Len(t, messages, 2)
require.Equal(t, []*pb.WakuMessage{msgList[8].msg, msgList[9].msg}, messages) require.Equal(t, []*pb.WakuMessage{msgList[9].Message(), msgList[8].Message()}, messages)
require.Equal(t, msgList[8].index, newPagingInfo.Cursor) require.Equal(t, msgList[8].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, pagingInfo.PageSize, newPagingInfo.PageSize) require.Equal(t, pagingInfo.PageSize, newPagingInfo.PageSize)
// test for an initial pagination request with an empty cursor to fetch the entire history // test for an initial pagination request with an empty cursor to fetch the entire history
pagingInfo = &pb.PagingInfo{PageSize: 13, Direction: pb.PagingInfo_BACKWARD} pagingInfo = &pb.PagingInfo{PageSize: 13, Direction: pb.PagingInfo_BACKWARD}
messages, newPagingInfo = paginateWithoutIndex(msgList, pagingInfo) messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.Len(t, messages, 10) require.Len(t, messages, 10)
require.Equal(t, msgList[0].msg, messages[0]) require.Equal(t, msgList[0].Message(), messages[9])
require.Equal(t, msgList[9].msg, messages[9]) require.Equal(t, msgList[9].Message(), messages[0])
require.Equal(t, msgList[0].index, newPagingInfo.Cursor) require.Equal(t, msgList[0].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, uint64(10), newPagingInfo.PageSize) require.Equal(t, uint64(10), newPagingInfo.PageSize)
// test for an empty msgList // test for an empty msgList
pagingInfo = &pb.PagingInfo{PageSize: 2, Direction: pb.PagingInfo_BACKWARD} pagingInfo = &pb.PagingInfo{PageSize: 2, Direction: pb.PagingInfo_BACKWARD}
var msgList2 []IndexedWakuMessage messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, MemoryDB(t))
messages, newPagingInfo = paginateWithoutIndex(msgList2, pagingInfo) require.NoError(t, err)
require.Len(t, messages, 0) require.Len(t, messages, 0)
require.Equal(t, pagingInfo.Cursor, newPagingInfo.Cursor) require.Equal(t, pagingInfo.Cursor, newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, uint64(0), newPagingInfo.PageSize) require.Equal(t, uint64(0), newPagingInfo.PageSize)
// test for a page size larger than the remaining messages // test for a page size larger than the remaining messages
pagingInfo = &pb.PagingInfo{PageSize: 5, Cursor: msgList[3].index, Direction: pb.PagingInfo_BACKWARD} pagingInfo = &pb.PagingInfo{PageSize: 5, Cursor: msgList[3].Index(), Direction: pb.PagingInfo_BACKWARD}
messages, newPagingInfo = paginateWithoutIndex(msgList, pagingInfo) messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.Len(t, messages, 3) require.Len(t, messages, 3)
require.Equal(t, []*pb.WakuMessage{msgList[0].msg, msgList[1].msg, msgList[2].msg}, messages) require.Equal(t, []*pb.WakuMessage{msgList[2].Message(), msgList[1].Message(), msgList[0].Message()}, messages)
require.Equal(t, msgList[0].index, newPagingInfo.Cursor) require.Equal(t, msgList[0].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, uint64(3), newPagingInfo.PageSize) require.Equal(t, uint64(3), newPagingInfo.PageSize)
// test for a page size larger than the maximum allowed page size // test for a page size larger than the maximum allowed page size
pagingInfo = &pb.PagingInfo{PageSize: MaxPageSize + 1, Cursor: msgList[3].index, Direction: pb.PagingInfo_BACKWARD} pagingInfo = &pb.PagingInfo{PageSize: MaxPageSize + 1, Cursor: msgList[3].Index(), Direction: pb.PagingInfo_BACKWARD}
messages, newPagingInfo = paginateWithoutIndex(msgList, pagingInfo) messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.True(t, len(messages) <= MaxPageSize) require.True(t, len(messages) <= MaxPageSize)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.True(t, newPagingInfo.PageSize <= MaxPageSize) require.True(t, newPagingInfo.PageSize <= MaxPageSize)
// test for a cursor pointing to the beginning of the message list // test for a cursor pointing to the beginning of the message list
pagingInfo = &pb.PagingInfo{PageSize: 10, Cursor: msgList[0].index, Direction: pb.PagingInfo_BACKWARD} pagingInfo = &pb.PagingInfo{PageSize: 10, Cursor: msgList[0].Index(), Direction: pb.PagingInfo_BACKWARD}
messages, newPagingInfo = paginateWithoutIndex(msgList, pagingInfo) messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.Len(t, messages, 0) require.Len(t, messages, 0)
require.Equal(t, msgList[0].index, newPagingInfo.Cursor) require.Equal(t, msgList[0].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, uint64(0), newPagingInfo.PageSize) require.Equal(t, uint64(0), newPagingInfo.PageSize)
// test for an invalid cursor // test for an invalid cursor
invalidIndex, err := computeIndex(protocol.NewEnvelope(&pb.WakuMessage{Payload: []byte{255, 255, 255}}, "test")) invalidIndex := protocol.NewEnvelope(&pb.WakuMessage{Payload: []byte{255, 255, 255}}, utils.GetUnixEpoch(), "test").Index()
require.NoError(t, err)
pagingInfo = &pb.PagingInfo{PageSize: 10, Cursor: invalidIndex, Direction: pb.PagingInfo_BACKWARD} pagingInfo = &pb.PagingInfo{PageSize: 10, Cursor: invalidIndex, Direction: pb.PagingInfo_BACKWARD}
messages, newPagingInfo = paginateWithoutIndex(msgList, pagingInfo) _, _, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.ErrorIs(t, err, persistence.ErrInvalidCursor)
require.Len(t, messages, 0) require.Len(t, messages, 0)
require.Equal(t, pagingInfo.Cursor, newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, uint64(0), newPagingInfo.PageSize)
// test initial paging query over a message list with one message // test initial paging query over a message list with one message
singleItemMsgList := msgList[0:1] singleItemDB := MemoryDB(t)
err = singleItemDB.Put(msgList[0])
require.NoError(t, err)
pagingInfo = &pb.PagingInfo{PageSize: 10, Direction: pb.PagingInfo_BACKWARD} pagingInfo = &pb.PagingInfo{PageSize: 10, Direction: pb.PagingInfo_BACKWARD}
messages, newPagingInfo = paginateWithoutIndex(singleItemMsgList, pagingInfo) messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, singleItemDB)
require.NoError(t, err)
require.Len(t, messages, 1) require.Len(t, messages, 1)
require.Equal(t, msgList[0].index, newPagingInfo.Cursor) require.Equal(t, msgList[0].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, uint64(1), newPagingInfo.PageSize) require.Equal(t, uint64(1), newPagingInfo.PageSize)
} }

View File

@ -1,12 +1,8 @@
package store package store
import ( import (
"context"
"database/sql"
"testing" "testing"
"github.com/status-im/go-waku/waku/persistence"
"github.com/status-im/go-waku/waku/persistence/sqlite"
"github.com/status-im/go-waku/waku/v2/protocol" "github.com/status-im/go-waku/waku/v2/protocol"
"github.com/status-im/go-waku/waku/v2/protocol/pb" "github.com/status-im/go-waku/waku/v2/protocol/pb"
"github.com/status-im/go-waku/waku/v2/utils" "github.com/status-im/go-waku/waku/v2/utils"
@ -14,19 +10,9 @@ import (
) )
func TestStorePersistence(t *testing.T) { func TestStorePersistence(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) db := MemoryDB(t)
defer cancel()
var db *sql.DB s1 := NewWakuStore(nil, nil, db, 0, 0, utils.Logger())
db, err := sqlite.NewDB(":memory:")
require.NoError(t, err)
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db))
require.NoError(t, err)
s1 := NewWakuStore(nil, nil, dbStore, 0, 0, utils.Logger())
s1.fetchDBRecords(ctx)
require.Len(t, s1.messageQueue.messages, 0)
defaultPubSubTopic := "test" defaultPubSubTopic := "test"
defaultContentTopic := "1" defaultContentTopic := "1"
@ -37,14 +23,14 @@ func TestStorePersistence(t *testing.T) {
Timestamp: utils.GetUnixEpoch(), Timestamp: utils.GetUnixEpoch(),
} }
_ = s1.storeMessage(protocol.NewEnvelope(msg, defaultPubSubTopic)) _ = s1.storeMessage(protocol.NewEnvelope(msg, utils.GetUnixEpoch(), defaultPubSubTopic))
s2 := NewWakuStore(nil, nil, dbStore, 0, 0, utils.Logger()) allMsgs, err := db.GetAll()
s2.fetchDBRecords(ctx) require.NoError(t, err)
require.Len(t, s2.messageQueue.messages, 1) require.Len(t, allMsgs, 1)
require.Equal(t, msg, s2.messageQueue.messages[0].msg) require.Equal(t, msg, allMsgs[0].Message)
// Storing a duplicated message should not crash. It's okay to generate an error log in this case // Storing a duplicated message should not crash. It's okay to generate an error log in this case
err = s1.storeMessage(protocol.NewEnvelope(msg, defaultPubSubTopic)) err = s1.storeMessage(protocol.NewEnvelope(msg, utils.GetUnixEpoch(), defaultPubSubTopic))
require.ErrorIs(t, err, ErrDuplicatedMessage) require.Error(t, err)
} }

View File

@ -20,7 +20,7 @@ func TestWakuStoreProtocolQuery(t *testing.T) {
host1, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0")) host1, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0"))
require.NoError(t, err) require.NoError(t, err)
s1 := NewWakuStore(host1, nil, nil, 0, 0, utils.Logger()) s1 := NewWakuStore(host1, nil, MemoryDB(t), 0, 0, utils.Logger())
s1.Start(ctx) s1.Start(ctx)
defer s1.Stop() defer s1.Stop()
@ -37,9 +37,9 @@ func TestWakuStoreProtocolQuery(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Simulate a message has been received via relay protocol // Simulate a message has been received via relay protocol
s1.MsgC <- protocol.NewEnvelope(msg, pubsubTopic1) s1.MsgC <- protocol.NewEnvelope(msg, utils.GetUnixEpoch(), pubsubTopic1)
s2 := NewWakuStore(host2, nil, nil, 0, 0, utils.Logger()) s2 := NewWakuStore(host2, nil, MemoryDB(t), 0, 0, utils.Logger())
s2.Start(ctx) s2.Start(ctx)
defer s2.Stop() defer s2.Stop()
@ -66,7 +66,9 @@ func TestWakuStoreProtocolNext(t *testing.T) {
host1, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0")) host1, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0"))
require.NoError(t, err) require.NoError(t, err)
s1 := NewWakuStore(host1, nil, nil, 0, 0, utils.Logger()) db := MemoryDB(t)
s1 := NewWakuStore(host1, nil, db, 0, 0, utils.Logger())
s1.Start(ctx) s1.Start(ctx)
defer s1.Stop() defer s1.Stop()
@ -79,11 +81,11 @@ func TestWakuStoreProtocolNext(t *testing.T) {
msg4 := tests.CreateWakuMessage(topic1, 4) msg4 := tests.CreateWakuMessage(topic1, 4)
msg5 := tests.CreateWakuMessage(topic1, 5) msg5 := tests.CreateWakuMessage(topic1, 5)
s1.MsgC <- protocol.NewEnvelope(msg1, pubsubTopic1) s1.MsgC <- protocol.NewEnvelope(msg1, utils.GetUnixEpoch(), pubsubTopic1)
s1.MsgC <- protocol.NewEnvelope(msg2, pubsubTopic1) s1.MsgC <- protocol.NewEnvelope(msg2, utils.GetUnixEpoch(), pubsubTopic1)
s1.MsgC <- protocol.NewEnvelope(msg3, pubsubTopic1) s1.MsgC <- protocol.NewEnvelope(msg3, utils.GetUnixEpoch(), pubsubTopic1)
s1.MsgC <- protocol.NewEnvelope(msg4, pubsubTopic1) s1.MsgC <- protocol.NewEnvelope(msg4, utils.GetUnixEpoch(), pubsubTopic1)
s1.MsgC <- protocol.NewEnvelope(msg5, pubsubTopic1) s1.MsgC <- protocol.NewEnvelope(msg5, utils.GetUnixEpoch(), pubsubTopic1)
host2, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0")) host2, err := libp2p.New(libp2p.DefaultTransports, libp2p.ListenAddrStrings("/ip4/0.0.0.0/tcp/0"))
require.NoError(t, err) require.NoError(t, err)
@ -92,7 +94,7 @@ func TestWakuStoreProtocolNext(t *testing.T) {
err = host2.Peerstore().AddProtocols(host1.ID(), string(StoreID_v20beta4)) err = host2.Peerstore().AddProtocols(host1.ID(), string(StoreID_v20beta4))
require.NoError(t, err) require.NoError(t, err)
s2 := NewWakuStore(host2, nil, nil, 0, 0, utils.Logger()) s2 := NewWakuStore(host2, nil, db, 0, 0, utils.Logger())
s2.Start(ctx) s2.Start(ctx)
defer s2.Stop() defer s2.Stop()

View File

@ -17,9 +17,9 @@ func TestStoreQuery(t *testing.T) {
msg1 := tests.CreateWakuMessage(defaultContentTopic, utils.GetUnixEpoch()) msg1 := tests.CreateWakuMessage(defaultContentTopic, utils.GetUnixEpoch())
msg2 := tests.CreateWakuMessage("2", utils.GetUnixEpoch()) msg2 := tests.CreateWakuMessage("2", utils.GetUnixEpoch())
s := NewWakuStore(nil, nil, nil, 0, 0, utils.Logger()) s := NewWakuStore(nil, nil, MemoryDB(t), 0, 0, utils.Logger())
_ = s.storeMessage(protocol.NewEnvelope(msg1, defaultPubSubTopic)) _ = s.storeMessage(protocol.NewEnvelope(msg1, utils.GetUnixEpoch(), defaultPubSubTopic))
_ = s.storeMessage(protocol.NewEnvelope(msg2, defaultPubSubTopic)) _ = s.storeMessage(protocol.NewEnvelope(msg2, utils.GetUnixEpoch(), defaultPubSubTopic))
response := s.FindMessages(&pb.HistoryQuery{ response := s.FindMessages(&pb.HistoryQuery{
ContentFilters: []*pb.ContentFilter{ ContentFilters: []*pb.ContentFilter{
@ -43,11 +43,11 @@ func TestStoreQueryMultipleContentFilters(t *testing.T) {
msg2 := tests.CreateWakuMessage(topic2, utils.GetUnixEpoch()) msg2 := tests.CreateWakuMessage(topic2, utils.GetUnixEpoch())
msg3 := tests.CreateWakuMessage(topic3, utils.GetUnixEpoch()) msg3 := tests.CreateWakuMessage(topic3, utils.GetUnixEpoch())
s := NewWakuStore(nil, nil, nil, 0, 0, utils.Logger()) s := NewWakuStore(nil, nil, MemoryDB(t), 0, 0, utils.Logger())
_ = s.storeMessage(protocol.NewEnvelope(msg1, defaultPubSubTopic)) _ = s.storeMessage(protocol.NewEnvelope(msg1, utils.GetUnixEpoch(), defaultPubSubTopic))
_ = s.storeMessage(protocol.NewEnvelope(msg2, defaultPubSubTopic)) _ = s.storeMessage(protocol.NewEnvelope(msg2, utils.GetUnixEpoch(), defaultPubSubTopic))
_ = s.storeMessage(protocol.NewEnvelope(msg3, defaultPubSubTopic)) _ = s.storeMessage(protocol.NewEnvelope(msg3, utils.GetUnixEpoch(), defaultPubSubTopic))
response := s.FindMessages(&pb.HistoryQuery{ response := s.FindMessages(&pb.HistoryQuery{
ContentFilters: []*pb.ContentFilter{ ContentFilters: []*pb.ContentFilter{
@ -77,10 +77,10 @@ func TestStoreQueryPubsubTopicFilter(t *testing.T) {
msg2 := tests.CreateWakuMessage(topic2, utils.GetUnixEpoch()) msg2 := tests.CreateWakuMessage(topic2, utils.GetUnixEpoch())
msg3 := tests.CreateWakuMessage(topic3, utils.GetUnixEpoch()) msg3 := tests.CreateWakuMessage(topic3, utils.GetUnixEpoch())
s := NewWakuStore(nil, nil, nil, 0, 0, utils.Logger()) s := NewWakuStore(nil, nil, MemoryDB(t), 0, 0, utils.Logger())
_ = s.storeMessage(protocol.NewEnvelope(msg1, pubsubTopic1)) _ = s.storeMessage(protocol.NewEnvelope(msg1, utils.GetUnixEpoch(), pubsubTopic1))
_ = s.storeMessage(protocol.NewEnvelope(msg2, pubsubTopic2)) _ = s.storeMessage(protocol.NewEnvelope(msg2, utils.GetUnixEpoch(), pubsubTopic2))
_ = s.storeMessage(protocol.NewEnvelope(msg3, pubsubTopic2)) _ = s.storeMessage(protocol.NewEnvelope(msg3, utils.GetUnixEpoch(), pubsubTopic2))
response := s.FindMessages(&pb.HistoryQuery{ response := s.FindMessages(&pb.HistoryQuery{
PubsubTopic: pubsubTopic1, PubsubTopic: pubsubTopic1,
@ -109,10 +109,10 @@ func TestStoreQueryPubsubTopicNoMatch(t *testing.T) {
msg2 := tests.CreateWakuMessage(topic2, utils.GetUnixEpoch()) msg2 := tests.CreateWakuMessage(topic2, utils.GetUnixEpoch())
msg3 := tests.CreateWakuMessage(topic3, utils.GetUnixEpoch()) msg3 := tests.CreateWakuMessage(topic3, utils.GetUnixEpoch())
s := NewWakuStore(nil, nil, nil, 0, 0, utils.Logger()) s := NewWakuStore(nil, nil, MemoryDB(t), 0, 0, utils.Logger())
_ = s.storeMessage(protocol.NewEnvelope(msg1, pubsubTopic2)) _ = s.storeMessage(protocol.NewEnvelope(msg1, utils.GetUnixEpoch(), pubsubTopic2))
_ = s.storeMessage(protocol.NewEnvelope(msg2, pubsubTopic2)) _ = s.storeMessage(protocol.NewEnvelope(msg2, utils.GetUnixEpoch(), pubsubTopic2))
_ = s.storeMessage(protocol.NewEnvelope(msg3, pubsubTopic2)) _ = s.storeMessage(protocol.NewEnvelope(msg3, utils.GetUnixEpoch(), pubsubTopic2))
response := s.FindMessages(&pb.HistoryQuery{ response := s.FindMessages(&pb.HistoryQuery{
PubsubTopic: pubsubTopic1, PubsubTopic: pubsubTopic1,
@ -131,10 +131,10 @@ func TestStoreQueryPubsubTopicAllMessages(t *testing.T) {
msg2 := tests.CreateWakuMessage(topic2, utils.GetUnixEpoch()) msg2 := tests.CreateWakuMessage(topic2, utils.GetUnixEpoch())
msg3 := tests.CreateWakuMessage(topic3, utils.GetUnixEpoch()) msg3 := tests.CreateWakuMessage(topic3, utils.GetUnixEpoch())
s := NewWakuStore(nil, nil, nil, 0, 0, utils.Logger()) s := NewWakuStore(nil, nil, MemoryDB(t), 0, 0, utils.Logger())
_ = s.storeMessage(protocol.NewEnvelope(msg1, pubsubTopic1)) _ = s.storeMessage(protocol.NewEnvelope(msg1, utils.GetUnixEpoch(), pubsubTopic1))
_ = s.storeMessage(protocol.NewEnvelope(msg2, pubsubTopic1)) _ = s.storeMessage(protocol.NewEnvelope(msg2, utils.GetUnixEpoch(), pubsubTopic1))
_ = s.storeMessage(protocol.NewEnvelope(msg3, pubsubTopic1)) _ = s.storeMessage(protocol.NewEnvelope(msg3, utils.GetUnixEpoch(), pubsubTopic1))
response := s.FindMessages(&pb.HistoryQuery{ response := s.FindMessages(&pb.HistoryQuery{
PubsubTopic: pubsubTopic1, PubsubTopic: pubsubTopic1,
@ -150,11 +150,11 @@ func TestStoreQueryForwardPagination(t *testing.T) {
topic1 := "1" topic1 := "1"
pubsubTopic1 := "topic1" pubsubTopic1 := "topic1"
s := NewWakuStore(nil, nil, nil, 0, 0, utils.Logger()) s := NewWakuStore(nil, nil, MemoryDB(t), 0, 0, utils.Logger())
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
msg := tests.CreateWakuMessage(topic1, utils.GetUnixEpoch()) msg := tests.CreateWakuMessage(topic1, utils.GetUnixEpoch())
msg.Payload = []byte{byte(i)} msg.Payload = []byte{byte(i)}
_ = s.storeMessage(protocol.NewEnvelope(msg, pubsubTopic1)) _ = s.storeMessage(protocol.NewEnvelope(msg, utils.GetUnixEpoch(), pubsubTopic1))
} }
response := s.FindMessages(&pb.HistoryQuery{ response := s.FindMessages(&pb.HistoryQuery{
@ -174,7 +174,7 @@ func TestStoreQueryBackwardPagination(t *testing.T) {
topic1 := "1" topic1 := "1"
pubsubTopic1 := "topic1" pubsubTopic1 := "topic1"
s := NewWakuStore(nil, nil, nil, 0, 0, utils.Logger()) s := NewWakuStore(nil, nil, MemoryDB(t), 0, 0, utils.Logger())
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
msg := &pb.WakuMessage{ msg := &pb.WakuMessage{
Payload: []byte{byte(i)}, Payload: []byte{byte(i)},
@ -182,7 +182,7 @@ func TestStoreQueryBackwardPagination(t *testing.T) {
Version: 0, Version: 0,
Timestamp: utils.GetUnixEpoch(), Timestamp: utils.GetUnixEpoch(),
} }
_ = s.storeMessage(protocol.NewEnvelope(msg, pubsubTopic1)) _ = s.storeMessage(protocol.NewEnvelope(msg, utils.GetUnixEpoch(), pubsubTopic1))
} }
@ -200,7 +200,7 @@ func TestStoreQueryBackwardPagination(t *testing.T) {
} }
func TestTemporalHistoryQueries(t *testing.T) { func TestTemporalHistoryQueries(t *testing.T) {
s := NewWakuStore(nil, nil, nil, 0, 0, utils.Logger()) s := NewWakuStore(nil, nil, MemoryDB(t), 0, 0, utils.Logger())
var messages []*pb.WakuMessage var messages []*pb.WakuMessage
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@ -209,7 +209,7 @@ func TestTemporalHistoryQueries(t *testing.T) {
contentTopic = "2" contentTopic = "2"
} }
msg := tests.CreateWakuMessage(contentTopic, int64(i)) msg := tests.CreateWakuMessage(contentTopic, int64(i))
_ = s.storeMessage(protocol.NewEnvelope(msg, "test")) _ = s.storeMessage(protocol.NewEnvelope(msg, utils.GetUnixEpoch(), "test"))
messages = append(messages, msg) messages = append(messages, msg)
} }