fix: invalid order when pagination is backwards (#313)

This commit is contained in:
Richard Ramos 2022-09-15 09:23:45 -04:00 committed by GitHub
parent c39c4d535c
commit f6cd9904c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 18 deletions

View File

@ -236,7 +236,7 @@ func (d *DBStore) Put(env *protocol.Envelope) error {
}
// Query retrieves messages from the DB
func (d *DBStore) Query(query *pb.HistoryQuery) ([]StoredMessage, error) {
func (d *DBStore) Query(query *pb.HistoryQuery) (*pb.Index, []StoredMessage, error) {
start := time.Now()
defer func() {
elapsed := time.Since(start)
@ -290,7 +290,7 @@ func (d *DBStore) Query(query *pb.HistoryQuery) ([]StoredMessage, error) {
).Scan(&exists)
if err != nil {
return nil, err
return nil, nil, err
}
if exists {
@ -302,7 +302,7 @@ func (d *DBStore) Query(query *pb.HistoryQuery) ([]StoredMessage, error) {
parameters = append(parameters, cursorDBKey.Bytes())
} else {
return nil, ErrInvalidCursor
return nil, nil, ErrInvalidCursor
}
}
@ -320,28 +320,41 @@ func (d *DBStore) Query(query *pb.HistoryQuery) ([]StoredMessage, error) {
stmt, err := d.db.Prepare(sqlQuery)
if err != nil {
return nil, err
return nil, nil, err
}
defer stmt.Close()
parameters = append(parameters, query.PagingInfo.PageSize)
rows, err := stmt.Query(parameters...)
if err != nil {
return nil, err
return nil, nil, err
}
var result []StoredMessage
for rows.Next() {
record, err := d.GetStoredMessage(rows)
if err != nil {
return nil, err
return nil, nil, err
}
result = append(result, record)
}
defer rows.Close()
return result, nil
cursor := &pb.Index{}
if len(result) != 0 {
lastMsgIdx := len(result) - 1
cursor = protocol.NewEnvelope(result[lastMsgIdx].Message, result[lastMsgIdx].ReceiverTime, result[lastMsgIdx].PubsubTopic).Index()
}
// The retrieved messages list should always be in chronological order
if query.PagingInfo.Direction == pb.PagingInfo_BACKWARD {
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
result[i], result[j] = result[j], result[i]
}
}
return cursor, result, nil
}
// MostRecentTimestamp returns an unix timestamp with the most recent senderTimestamp

View File

@ -62,7 +62,7 @@ func findMessages(query *pb.HistoryQuery, msgProvider MessageProvider) ([]*pb.Wa
query.PagingInfo.PageSize = MaxPageSize
}
queryResult, err := msgProvider.Query(query)
cursor, queryResult, err := msgProvider.Query(query)
if err != nil {
return nil, nil, err
}
@ -72,10 +72,7 @@ func findMessages(query *pb.HistoryQuery, msgProvider MessageProvider) ([]*pb.Wa
return nil, newPagingInfo, nil
}
lastMsgIdx := len(queryResult) - 1
newCursor := protocol.NewEnvelope(queryResult[lastMsgIdx].Message, queryResult[lastMsgIdx].ReceiverTime, queryResult[lastMsgIdx].PubsubTopic).Index()
newPagingInfo := &pb.PagingInfo{PageSize: query.PagingInfo.PageSize, Cursor: newCursor, Direction: query.PagingInfo.Direction}
newPagingInfo := &pb.PagingInfo{PageSize: query.PagingInfo.PageSize, Cursor: cursor, Direction: query.PagingInfo.Direction}
if newPagingInfo.PageSize > uint64(len(queryResult)) {
newPagingInfo.PageSize = uint64(len(queryResult))
}
@ -108,7 +105,7 @@ func (store *WakuStore) FindMessages(query *pb.HistoryQuery) *pb.HistoryResponse
type MessageProvider interface {
GetAll() ([]persistence.StoredMessage, error)
Query(query *pb.HistoryQuery) ([]persistence.StoredMessage, error)
Query(query *pb.HistoryQuery) (*pb.Index, []persistence.StoredMessage, error)
Put(env *protocol.Envelope) error
MostRecentTimestamp() (int64, error)
Stop()

View File

@ -160,7 +160,7 @@ func TestBackwardPagination(t *testing.T) {
require.NoError(t, err)
require.Len(t, messages, 2)
require.Equal(t, []*pb.WakuMessage{msgList[2].Message(), msgList[1].Message()}, messages)
require.Equal(t, []*pb.WakuMessage{msgList[1].Message(), msgList[2].Message()}, messages)
require.Equal(t, msgList[1].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, pagingInfo.PageSize, newPagingInfo.PageSize)
@ -170,7 +170,7 @@ func TestBackwardPagination(t *testing.T) {
messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.Len(t, messages, 2)
require.Equal(t, []*pb.WakuMessage{msgList[9].Message(), msgList[8].Message()}, messages)
require.Equal(t, []*pb.WakuMessage{msgList[8].Message(), msgList[9].Message()}, messages)
require.Equal(t, msgList[8].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, pagingInfo.PageSize, newPagingInfo.PageSize)
@ -180,8 +180,8 @@ func TestBackwardPagination(t *testing.T) {
messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.Len(t, messages, 10)
require.Equal(t, msgList[0].Message(), messages[9])
require.Equal(t, msgList[9].Message(), messages[0])
require.Equal(t, msgList[0].Message(), messages[0])
require.Equal(t, msgList[9].Message(), messages[9])
require.Equal(t, msgList[0].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, uint64(10), newPagingInfo.PageSize)
@ -200,7 +200,7 @@ func TestBackwardPagination(t *testing.T) {
messages, newPagingInfo, err = findMessages(&pb.HistoryQuery{PagingInfo: pagingInfo}, db)
require.NoError(t, err)
require.Len(t, messages, 3)
require.Equal(t, []*pb.WakuMessage{msgList[2].Message(), msgList[1].Message(), msgList[0].Message()}, messages)
require.Equal(t, []*pb.WakuMessage{msgList[0].Message(), msgList[1].Message(), msgList[2].Message()}, messages)
require.Equal(t, msgList[0].Index(), newPagingInfo.Cursor)
require.Equal(t, pagingInfo.Direction, newPagingInfo.Direction)
require.Equal(t, uint64(3), newPagingInfo.PageSize)