refactor: validate protobuffer for store (#841)

This commit is contained in:
richΛrd 2023-10-30 12:55:36 -04:00 committed by GitHub
parent 38202e7a2e
commit 4584bb4324
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 230 additions and 88 deletions

View File

@ -7,7 +7,7 @@ var (
errMissingQuery = errors.New("missing Query field") errMissingQuery = errors.New("missing Query field")
errMissingMessage = errors.New("missing Message field") errMissingMessage = errors.New("missing Message field")
errMissingPubsubTopic = errors.New("missing PubsubTopic field") errMissingPubsubTopic = errors.New("missing PubsubTopic field")
errRequestIDMismatch = errors.New("RequestID in response does not match request") errRequestIDMismatch = errors.New("requestID in response does not match request")
errMissingResponse = errors.New("missing Response field") errMissingResponse = errors.New("missing Response field")
) )

View File

@ -0,0 +1,68 @@
package pb
import (
"errors"
)
// MaxContentFilters is the maximum number of allowed content filters in a query
const MaxContentFilters = 10

// Sentinel errors returned by the Validate* helpers below; callers compare
// against these with errors.Is.
var (
	errMissingRequestID   = errors.New("missing RequestId field")
	errMissingQuery       = errors.New("missing Query field")
	errRequestIDMismatch  = errors.New("requestID in response does not match request")
	errMaxContentFilters  = errors.New("exceeds the maximum number of content filters allowed")
	errEmptyContentTopics = errors.New("one or more content topics specified is empty")
)
// Validate checks that the query stays within the content-filter limit and
// that every content filter names a non-empty content topic.
func (x *HistoryQuery) Validate() error {
	filters := x.ContentFilters
	if len(filters) > MaxContentFilters {
		return errMaxContentFilters
	}

	for _, filter := range filters {
		if len(filter.ContentTopic) == 0 {
			return errEmptyContentTopics
		}
	}

	return nil
}
// ValidateQuery verifies that a history request RPC carries both a request ID
// and a query, then delegates to the query's own validation.
func (x *HistoryRPC) ValidateQuery() error {
	switch {
	case x.RequestId == "":
		return errMissingRequestID
	case x.Query == nil:
		return errMissingQuery
	default:
		return x.Query.Validate()
	}
}
// Validate runs validation over every message contained in the response and
// returns the first validation error encountered, if any.
func (x *HistoryResponse) Validate() error {
	for i := range x.Messages {
		if err := x.Messages[i].Validate(); err != nil {
			return err
		}
	}

	return nil
}
// ValidateResponse checks a history response RPC: it must carry a request ID
// equal to requestID (the ID of the request we issued), and any embedded
// response must itself validate. A nil Response is accepted — some peers omit
// it entirely when there are no results.
func (x *HistoryRPC) ValidateResponse(requestID string) error {
	switch {
	case x.RequestId == "":
		return errMissingRequestID
	case x.RequestId != requestID:
		return errRequestIDMismatch
	}

	if x.Response == nil {
		return nil
	}

	return x.Response.Validate()
}

View File

@ -0,0 +1,42 @@
package pb
import (
"testing"
"github.com/stretchr/testify/require"
)
// cf builds a ContentFilter for the given content topic.
func cf(topic string) *ContentFilter {
	return &ContentFilter{ContentTopic: topic}
}
// TestValidateRequest walks HistoryRPC.ValidateQuery through each failure
// mode in turn — missing request ID, missing query, too many content filters,
// an empty content topic — and finally a request that passes.
func TestValidateRequest(t *testing.T) {
	request := HistoryRPC{}
	require.ErrorIs(t, request.ValidateQuery(), errMissingRequestID)

	request.RequestId = "test"
	require.ErrorIs(t, request.ValidateQuery(), errMissingQuery)

	// One filter more than MaxContentFilters allows.
	filters := make([]*ContentFilter, 0, MaxContentFilters+1)
	for _, topic := range []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11"} {
		filters = append(filters, cf(topic))
	}
	request.Query = &HistoryQuery{ContentFilters: filters}
	require.ErrorIs(t, request.ValidateQuery(), errMaxContentFilters)

	request.Query.ContentFilters = []*ContentFilter{cf("a"), cf("")}
	require.ErrorIs(t, request.ValidateQuery(), errEmptyContentTopics)

	request.Query.ContentFilters = []*ContentFilter{cf("a")}
	require.NoError(t, request.ValidateQuery())
}
// TestValidateResponse checks HistoryRPC.ValidateResponse against a missing
// request ID, a mismatched ID, and finally a matching ID with a valid
// (empty) response payload.
func TestValidateResponse(t *testing.T) {
	const wantID = "test"

	response := HistoryRPC{}
	require.ErrorIs(t, response.ValidateResponse(wantID), errMissingRequestID)

	response.RequestId = "test1"
	require.ErrorIs(t, response.ValidateResponse(wantID), errRequestIDMismatch)

	response.RequestId = wantID
	response.Response = &HistoryResponse{}
	require.NoError(t, response.ValidateResponse(wantID))
}

View File

@ -170,7 +170,7 @@ func DefaultOptions() []HistoryRequestOption {
} }
} }
func (store *WakuStore) queryFrom(ctx context.Context, q *pb.HistoryQuery, selectedPeer peer.ID, requestID []byte) (*pb.HistoryResponse, error) { func (store *WakuStore) queryFrom(ctx context.Context, historyRequest *pb.HistoryRPC, selectedPeer peer.ID) (*pb.HistoryResponse, error) {
logger := store.log.With(logging.HostID("peer", selectedPeer)) logger := store.log.With(logging.HostID("peer", selectedPeer))
logger.Info("querying message history") logger.Info("querying message history")
@ -181,8 +181,6 @@ func (store *WakuStore) queryFrom(ctx context.Context, q *pb.HistoryQuery, selec
return nil, err return nil, err
} }
historyRequest := &pb.HistoryRPC{Query: q, RequestId: hex.EncodeToString(requestID)}
writer := pbio.NewDelimitedWriter(stream) writer := pbio.NewDelimitedWriter(stream)
reader := pbio.NewDelimitedReader(stream, math.MaxInt32) reader := pbio.NewDelimitedReader(stream, math.MaxInt32)
@ -209,6 +207,8 @@ func (store *WakuStore) queryFrom(ctx context.Context, q *pb.HistoryQuery, selec
stream.Close() stream.Close()
// nwaku does not return a response if there are no results due to the way their
// protobuffer library works. This condition can be removed once they have proper proto3 support
if historyResponseRPC.Response == nil { if historyResponseRPC.Response == nil {
// Empty response // Empty response
return &pb.HistoryResponse{ return &pb.HistoryResponse{
@ -216,10 +216,14 @@ func (store *WakuStore) queryFrom(ctx context.Context, q *pb.HistoryQuery, selec
}, nil }, nil
} }
if err := historyResponseRPC.ValidateResponse(historyRequest.RequestId); err != nil {
return nil, err
}
return historyResponseRPC.Response, nil return historyResponseRPC.Response, nil
} }
func (store *WakuStore) localQuery(query *pb.HistoryQuery, requestID []byte) (*pb.HistoryResponse, error) { func (store *WakuStore) localQuery(historyQuery *pb.HistoryRPC) (*pb.HistoryResponse, error) {
logger := store.log logger := store.log
logger.Info("querying local message history") logger.Info("querying local message history")
@ -228,8 +232,8 @@ func (store *WakuStore) localQuery(query *pb.HistoryQuery, requestID []byte) (*p
} }
historyResponseRPC := &pb.HistoryRPC{ historyResponseRPC := &pb.HistoryRPC{
RequestId: hex.EncodeToString(requestID), RequestId: historyQuery.RequestId,
Response: store.FindMessages(query), Response: store.FindMessages(historyQuery.Query),
} }
if historyResponseRPC.Response == nil { if historyResponseRPC.Response == nil {
@ -243,21 +247,6 @@ func (store *WakuStore) localQuery(query *pb.HistoryQuery, requestID []byte) (*p
} }
func (store *WakuStore) Query(ctx context.Context, query Query, opts ...HistoryRequestOption) (*Result, error) { func (store *WakuStore) Query(ctx context.Context, query Query, opts ...HistoryRequestOption) (*Result, error) {
q := &pb.HistoryQuery{
PubsubTopic: query.Topic,
ContentFilters: []*pb.ContentFilter{},
StartTime: query.StartTime,
EndTime: query.EndTime,
PagingInfo: &pb.PagingInfo{},
}
for _, cf := range query.ContentTopics {
q.ContentFilters = append(q.ContentFilters, &pb.ContentFilter{ContentTopic: cf})
}
if len(q.ContentFilters) > MaxContentFilters {
return nil, ErrMaxContentFilters
}
params := new(HistoryRequestParameters) params := new(HistoryRequestParameters)
params.s = store params.s = store
@ -283,38 +272,53 @@ func (store *WakuStore) Query(ctx context.Context, query Query, opts ...HistoryR
} }
} }
historyRequest := &pb.HistoryRPC{
RequestId: hex.EncodeToString(params.requestID),
Query: &pb.HistoryQuery{
PubsubTopic: query.Topic,
ContentFilters: []*pb.ContentFilter{},
StartTime: query.StartTime,
EndTime: query.EndTime,
PagingInfo: &pb.PagingInfo{},
},
}
for _, cf := range query.ContentTopics {
historyRequest.Query.ContentFilters = append(historyRequest.Query.ContentFilters, &pb.ContentFilter{ContentTopic: cf})
}
if !params.localQuery && params.selectedPeer == "" { if !params.localQuery && params.selectedPeer == "" {
store.metrics.RecordError(peerNotFoundFailure) store.metrics.RecordError(peerNotFoundFailure)
return nil, ErrNoPeersAvailable return nil, ErrNoPeersAvailable
} }
if len(params.requestID) == 0 {
return nil, ErrInvalidID
}
if params.cursor != nil { if params.cursor != nil {
q.PagingInfo.Cursor = params.cursor historyRequest.Query.PagingInfo.Cursor = params.cursor
} }
if params.asc { if params.asc {
q.PagingInfo.Direction = pb.PagingInfo_FORWARD historyRequest.Query.PagingInfo.Direction = pb.PagingInfo_FORWARD
} else { } else {
q.PagingInfo.Direction = pb.PagingInfo_BACKWARD historyRequest.Query.PagingInfo.Direction = pb.PagingInfo_BACKWARD
} }
pageSize := params.pageSize pageSize := params.pageSize
if pageSize == 0 || pageSize > uint64(MaxPageSize) { if pageSize == 0 || pageSize > uint64(MaxPageSize) {
pageSize = MaxPageSize pageSize = MaxPageSize
} }
q.PagingInfo.PageSize = pageSize historyRequest.Query.PagingInfo.PageSize = pageSize
err := historyRequest.ValidateQuery()
if err != nil {
return nil, err
}
var response *pb.HistoryResponse var response *pb.HistoryResponse
var err error
if params.localQuery { if params.localQuery {
response, err = store.localQuery(q, params.requestID) response, err = store.localQuery(historyRequest)
} else { } else {
response, err = store.queryFrom(ctx, q, params.selectedPeer, params.requestID) response, err = store.queryFrom(ctx, historyRequest, params.selectedPeer)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -327,7 +331,7 @@ func (store *WakuStore) Query(ctx context.Context, query Query, opts ...HistoryR
result := &Result{ result := &Result{
store: store, store: store,
Messages: response.Messages, Messages: response.Messages,
query: q, query: historyRequest.Query,
peerID: params.selectedPeer, peerID: params.selectedPeer,
} }
@ -390,24 +394,27 @@ func (store *WakuStore) Next(ctx context.Context, r *Result) (*Result, error) {
}, nil }, nil
} }
q := &pb.HistoryQuery{ historyRequest := &pb.HistoryRPC{
PubsubTopic: r.Query().PubsubTopic, RequestId: hex.EncodeToString(protocol.GenerateRequestID()),
ContentFilters: r.Query().ContentFilters, Query: &pb.HistoryQuery{
StartTime: r.Query().StartTime, PubsubTopic: r.Query().PubsubTopic,
EndTime: r.Query().EndTime, ContentFilters: r.Query().ContentFilters,
PagingInfo: &pb.PagingInfo{ StartTime: r.Query().StartTime,
PageSize: r.Query().PagingInfo.PageSize, EndTime: r.Query().EndTime,
Direction: r.Query().PagingInfo.Direction, PagingInfo: &pb.PagingInfo{
Cursor: &pb.Index{ PageSize: r.Query().PagingInfo.PageSize,
Digest: r.Cursor().Digest, Direction: r.Query().PagingInfo.Direction,
ReceiverTime: r.Cursor().ReceiverTime, Cursor: &pb.Index{
SenderTime: r.Cursor().SenderTime, Digest: r.Cursor().Digest,
PubsubTopic: r.Cursor().PubsubTopic, ReceiverTime: r.Cursor().ReceiverTime,
SenderTime: r.Cursor().SenderTime,
PubsubTopic: r.Cursor().PubsubTopic,
},
}, },
}, },
} }
response, err := store.queryFrom(ctx, q, r.PeerID(), protocol.GenerateRequestID()) response, err := store.queryFrom(ctx, historyRequest, r.PeerID())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -420,7 +427,7 @@ func (store *WakuStore) Next(ctx context.Context, r *Result) (*Result, error) {
started: true, started: true,
store: store, store: store,
Messages: response.Messages, Messages: response.Messages,
query: q, query: historyRequest.Query,
peerID: r.PeerID(), peerID: r.PeerID(),
} }

View File

@ -20,27 +20,15 @@ const StoreID_v20beta4 = libp2pProtocol.ID("/vac/waku/store/2.0.0-beta4")
// MaxPageSize is the maximum number of waku messages to return per page // MaxPageSize is the maximum number of waku messages to return per page
const MaxPageSize = 20 const MaxPageSize = 20
// MaxContentFilters is the maximum number of allowed content filters in a query
const MaxContentFilters = 10
var ( var (
// ErrMaxContentFilters is returned when the number of content topics in the query
// exceeds the limit
ErrMaxContentFilters = errors.New("exceeds the maximum number of content filters allowed")
// ErrNoPeersAvailable is returned when there are no store peers in the peer store // ErrNoPeersAvailable is returned when there are no store peers in the peer store
// that could be used to retrieve message history // that could be used to retrieve message history
ErrNoPeersAvailable = errors.New("no suitable remote peers") ErrNoPeersAvailable = errors.New("no suitable remote peers")
// ErrInvalidID is returned when no RequestID is given
ErrInvalidID = errors.New("invalid request id")
// ErrFailedToResumeHistory is returned when the node attempted to retrieve historic // ErrFailedToResumeHistory is returned when the node attempted to retrieve historic
// messages to fill its own message history but for some reason it failed // messages to fill its own message history but for some reason it failed
ErrFailedToResumeHistory = errors.New("failed to resume the history") ErrFailedToResumeHistory = errors.New("failed to resume the history")
// ErrFailedQuery is emitted when the query fails to return results
ErrFailedQuery = errors.New("failed to resolve the query")
) )
type WakuSwap interface { type WakuSwap interface {

View File

@ -2,6 +2,7 @@ package store
import ( import (
"context" "context"
"encoding/hex"
"errors" "errors"
"math" "math"
"sync" "sync"
@ -33,10 +34,6 @@ func findMessages(query *pb.HistoryQuery, msgProvider MessageProvider) ([]*wpb.W
query.PagingInfo.PageSize = MaxPageSize query.PagingInfo.PageSize = MaxPageSize
} }
if len(query.ContentFilters) > MaxContentFilters {
return nil, nil, ErrMaxContentFilters
}
cursor, queryResult, err := msgProvider.Query(query) cursor, queryResult, err := msgProvider.Query(query)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -181,6 +178,18 @@ func (store *WakuStore) onRequest(stream network.Stream) {
return return
} }
if err := historyRPCRequest.ValidateQuery(); err != nil {
logger.Error("invalid request received", zap.Error(err))
store.metrics.RecordError(decodeRPCFailure)
if err := stream.Reset(); err != nil {
store.log.Error("resetting connection", zap.Error(err))
}
// TODO: If store protocol is updated to include error messages
// `err.Error()` can be returned as a response
return
}
logger = logger.With(zap.String("id", historyRPCRequest.RequestId)) logger = logger.With(zap.String("id", historyRPCRequest.RequestId))
if query := historyRPCRequest.Query; query != nil { if query := historyRPCRequest.Query; query != nil {
logger = logger.With(logging.Filters(query.GetContentFilters())) logger = logger.With(logging.Filters(query.GetContentFilters()))
@ -238,42 +247,59 @@ func (store *WakuStore) Stop() {
store.wg.Wait() store.wg.Wait()
} }
func (store *WakuStore) queryLoop(ctx context.Context, query *pb.HistoryQuery, candidateList []peer.ID) ([]*wpb.WakuMessage, error) { type queryLoopCandidateResponse struct {
// loops through the candidateList in order and sends the query to each until one of the query gets resolved successfully peerID peer.ID
// returns the number of retrieved messages, or error if all the requests fail response *pb.HistoryResponse
err error
}
func (store *WakuStore) queryLoop(ctx context.Context, query *pb.HistoryQuery, candidateList []peer.ID) ([]*queryLoopCandidateResponse, error) {
err := query.Validate()
if err != nil {
return nil, err
}
queryWg := sync.WaitGroup{} queryWg := sync.WaitGroup{}
queryWg.Add(len(candidateList)) queryWg.Add(len(candidateList))
resultChan := make(chan *pb.HistoryResponse, len(candidateList)) resultChan := make(chan *queryLoopCandidateResponse, len(candidateList))
// loops through the candidateList in order and sends the query to each until one of the query gets resolved successfully
// returns the number of retrieved messages, or error if all the requests fail
for _, peer := range candidateList { for _, peer := range candidateList {
func() { func() {
defer queryWg.Done() defer queryWg.Done()
result, err := store.queryFrom(ctx, query, peer, protocol.GenerateRequestID())
if err == nil { historyRequest := &pb.HistoryRPC{
resultChan <- result RequestId: hex.EncodeToString(protocol.GenerateRequestID()),
return Query: query,
} }
store.log.Error("resuming history", logging.HostID("peer", peer), zap.Error(err))
result := &queryLoopCandidateResponse{
peerID: peer,
}
response, err := store.queryFrom(ctx, historyRequest, peer)
if err != nil {
store.log.Error("resuming history", logging.HostID("peer", peer), zap.Error(err))
result.err = err
} else {
result.response = response
}
resultChan <- result
}() }()
} }
queryWg.Wait() queryWg.Wait()
close(resultChan) close(resultChan)
var messages []*wpb.WakuMessage var queryLoopResults []*queryLoopCandidateResponse
hasResults := false
for result := range resultChan { for result := range resultChan {
hasResults = true queryLoopResults = append(queryLoopResults, result)
messages = append(messages, result.Messages...)
} }
if hasResults { return queryLoopResults, nil
return messages, nil
}
return nil, ErrFailedQuery
} }
func (store *WakuStore) findLastSeen() (int64, error) { func (store *WakuStore) findLastSeen() (int64, error) {
@ -323,20 +349,31 @@ func (store *WakuStore) Resume(ctx context.Context, pubsubTopic string, peerList
return -1, ErrNoPeersAvailable return -1, ErrNoPeersAvailable
} }
messages, err := store.queryLoop(ctx, rpc, peerList) queryLoopResults, err := store.queryLoop(ctx, rpc, peerList)
if err != nil { if err != nil {
store.log.Error("resuming history", zap.Error(err)) store.log.Error("resuming history", zap.Error(err))
return -1, ErrFailedToResumeHistory return -1, ErrFailedToResumeHistory
} }
msgCount := 0 msgCount := 0
for _, msg := range messages { for _, r := range queryLoopResults {
if err = store.storeMessage(protocol.NewEnvelope(msg, store.timesource.Now().UnixNano(), pubsubTopic)); err == nil { if err == nil && r.response.GetError() != pb.HistoryResponse_NONE {
msgCount++ r.err = errors.New("invalid cursor")
}
if r.err != nil {
store.log.Warn("could not resume message history", zap.Error(r.err), logging.HostID("peer", r.peerID))
continue
}
for _, msg := range r.response.Messages {
if err = store.storeMessage(protocol.NewEnvelope(msg, store.timesource.Now().UnixNano(), pubsubTopic)); err == nil {
msgCount++
}
} }
} }
store.log.Info("retrieved messages since the last online time", zap.Int("messages", len(messages))) store.log.Info("retrieved messages since the last online time", zap.Int("messages", msgCount))
return msgCount, nil return msgCount, nil
} }