From f6cd9904c556ceea7144b40014435e64f27a9394 Mon Sep 17 00:00:00 2001 From: Richard Ramos Date: Thu, 15 Sep 2022 09:23:45 -0400 Subject: [PATCH] fix: invalid order when pagination is backwards (#313) --- waku/persistence/store.go | 27 ++++++++++++++----- waku/v2/protocol/store/waku_store.go | 9 +++---- .../store/waku_store_pagination_test.go | 10 +++---- 3 files changed, 28 insertions(+), 18 deletions(-) diff --git a/waku/persistence/store.go b/waku/persistence/store.go index 0ed06df5..fb0c88e1 100644 --- a/waku/persistence/store.go +++ b/waku/persistence/store.go @@ -236,7 +236,7 @@ func (d *DBStore) Put(env *protocol.Envelope) error { } // Query retrieves messages from the DB -func (d *DBStore) Query(query *pb.HistoryQuery) ([]StoredMessage, error) { +func (d *DBStore) Query(query *pb.HistoryQuery) (*pb.Index, []StoredMessage, error) { start := time.Now() defer func() { elapsed := time.Since(start) @@ -290,7 +290,7 @@ func (d *DBStore) Query(query *pb.HistoryQuery) ([]StoredMessage, error) { ).Scan(&exists) if err != nil { - return nil, err + return nil, nil, err } if exists { @@ -302,7 +302,7 @@ func (d *DBStore) Query(query *pb.HistoryQuery) ([]StoredMessage, error) { parameters = append(parameters, cursorDBKey.Bytes()) } else { - return nil, ErrInvalidCursor + return nil, nil, ErrInvalidCursor } } @@ -320,28 +320,41 @@ func (d *DBStore) Query(query *pb.HistoryQuery) ([]StoredMessage, error) { stmt, err := d.db.Prepare(sqlQuery) if err != nil { - return nil, err + return nil, nil, err } defer stmt.Close() parameters = append(parameters, query.PagingInfo.PageSize) rows, err := stmt.Query(parameters...) if err != nil { - return nil, err + return nil, nil, err } var result []StoredMessage for rows.Next() { record, err := d.GetStoredMessage(rows) if err != nil { - return nil, err + return nil, nil, err } result = append(result, record) } defer rows.Close() - return result, nil + cursor := &pb.Index{} + if len(result) != 0 { + lastMsgIdx := len(result) - 1 + cursor = protocol.NewEnvelope(result[lastMsgIdx].Message, result[lastMsgIdx].ReceiverTime, result[lastMsgIdx].PubsubTopic).Index() + } + + // The retrieved messages list should always be in chronological order + if query.PagingInfo.Direction == pb.PagingInfo_BACKWARD { + for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 { + result[i], result[j] = result[j], result[i] + } + } + + return cursor, result, nil } // MostRecentTimestamp returns an unix timestamp with the most recent senderTimestamp diff --git a/waku/v2/protocol/store/waku_store.go b/waku/v2/protocol/store/waku_store.go index d843a487..8f7f34fe 100644 --- a/waku/v2/protocol/store/waku_store.go +++ b/waku/v2/protocol/store/waku_store.go @@ -62,7 +62,7 @@ func findMessages(query *pb.HistoryQuery, msgProvider MessageProvider) ([]*pb.Wa query.PagingInfo.PageSize = MaxPageSize } - queryResult, err := msgProvider.Query(query) + cursor, queryResult, err := msgProvider.Query(query) if err != nil { return nil, nil, err } @@ -72,10 +72,7 @@ func findMessages(query *pb.HistoryQuery, msgProvider MessageProvider) ([]*pb.Wa return nil, newPagingInfo, nil } - lastMsgIdx := len(queryResult) - 1 - newCursor := protocol.NewEnvelope(queryResult[lastMsgIdx].Message, queryResult[lastMsgIdx].ReceiverTime, queryResult[lastMsgIdx].PubsubTopic).Index() - - newPagingInfo := &pb.PagingInfo{PageSize: query.PagingInfo.PageSize, Cursor: newCursor, Direction: query.PagingInfo.Direction} + newPagingInfo := &pb.PagingInfo{PageSize: query.PagingInfo.PageSize, Cursor: cursor, Direction: query.PagingInfo.Direction} if newPagingInfo.PageSize > uint64(len(queryResult)) { newPagingInfo.PageSize = uint64(len(queryResult)) } @@ -108,7 +105,7 @@ func (store *WakuStore) FindMessages(query *pb.HistoryQuery) *pb.HistoryResponse type MessageProvider interface { GetAll() ([]persistence.StoredMessage, error) - Query(query *pb.HistoryQuery) ([]persistence.StoredMessage, error) + Query(query *pb.HistoryQuery) (*pb.Index, []persistence.StoredMessage, error) Put(env *protocol.Envelope) error MostRecentTimestamp() (int64, error) Stop() diff --git a/waku/v2/protocol/store/waku_store_pagination_test.go b/waku/v2/protocol/store/waku_store_pagination_test.go index 15ec748c..c1bee493 100644 --- a/waku/v2/protocol/store/waku_store_pagination_test.go +++ b/waku/v2/protocol/store/waku_store_pagination_test.go @@ -160,7 +160,7 @@ func TestBackwardPagination(t *testing.T) { require.NoError(t, err) require.Len(t, messages, 2) - require.Equal(t, []*pb.WakuMessage{msgList[2].Message(), msgList[1].Message()}, messages) + require.Equal(t, []*pb.WakuMessage{msgList[1].Message(), msgList[2].Message()}, messages) require.Equal(t, msgList[1].Index(), newPagingInfo.Cursor) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.PageSize, newPagingInfo.PageSize) @@ -170,7 +170,7 @@ func TestBackwardPagination(t *testing.T) { messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db) require.NoError(t, err) require.Len(t, messages, 2) - require.Equal(t, []*pb.WakuMessage{msgList[9].Message(), msgList[8].Message()}, messages) + require.Equal(t, []*pb.WakuMessage{msgList[8].Message(), msgList[9].Message()}, messages) require.Equal(t, msgList[8].Index(), newPagingInfo.Cursor) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, pagingInfo.PageSize, newPagingInfo.PageSize) @@ -180,8 +180,8 @@ func TestBackwardPagination(t *testing.T) { messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db) require.NoError(t, err) require.Len(t, messages, 10) - require.Equal(t, msgList[0].Message(), messages[9]) - require.Equal(t, msgList[9].Message(), messages[0]) + require.Equal(t, msgList[0].Message(), messages[0]) + require.Equal(t, msgList[9].Message(), messages[9]) require.Equal(t, msgList[0].Index(), newPagingInfo.Cursor) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, uint64(10), newPagingInfo.PageSize) @@ -200,7 +200,7 @@ func TestBackwardPagination(t *testing.T) { messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db) require.NoError(t, err) require.Len(t, messages, 3) - require.Equal(t, []*pb.WakuMessage{msgList[2].Message(), msgList[1].Message(), msgList[0].Message()}, messages) + require.Equal(t, []*pb.WakuMessage{msgList[0].Message(), msgList[1].Message(), msgList[2].Message()}, messages) require.Equal(t, msgList[0].Index(), newPagingInfo.Cursor) require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction) require.Equal(t, uint64(3), newPagingInfo.PageSize)