mirror of https://github.com/status-im/go-waku.git
refactor: validate protobuffers for lightpush and relay (#824)
This commit is contained in:
parent
fa51d10b4b
commit
94f18c537c
|
@ -32,6 +32,7 @@ func createTestMsg(version uint32) *pb.WakuMessage {
|
||||||
message.Payload = []byte{0, 1, 2}
|
message.Payload = []byte{0, 1, 2}
|
||||||
message.Version = version
|
message.Version = version
|
||||||
message.Timestamp = 123456
|
message.Timestamp = 123456
|
||||||
|
message.ContentTopic = "abc"
|
||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -264,7 +265,8 @@ func TestDecoupledStoreFromRelay(t *testing.T) {
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
_, filter, err := wakuNode2.LegacyFilter().Subscribe(ctx, legacy_filter.ContentFilter{
|
_, filter, err := wakuNode2.LegacyFilter().Subscribe(ctx, legacy_filter.ContentFilter{
|
||||||
Topic: string(relay.DefaultWakuTopic),
|
Topic: string(relay.DefaultWakuTopic),
|
||||||
|
ContentTopics: []string{"abc"},
|
||||||
}, legacy_filter.WithPeer(wakuNode1.host.ID()))
|
}, legacy_filter.WithPeer(wakuNode1.host.ID()))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -281,7 +283,10 @@ func TestDecoupledStoreFromRelay(t *testing.T) {
|
||||||
go func() {
|
go func() {
|
||||||
// MSG1 should be pushed in NODE2 via filter
|
// MSG1 should be pushed in NODE2 via filter
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
env := <-filter.Chan
|
env, ok := <-filter.Chan
|
||||||
|
if !ok {
|
||||||
|
require.Fail(t, "no message")
|
||||||
|
}
|
||||||
require.Equal(t, msg.Timestamp, env.Message().Timestamp)
|
require.Equal(t, msg.Timestamp, env.Message().Timestamp)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
@ -104,10 +104,6 @@ func (fm *FilterMap) Notify(msg *pb.WakuMessage, requestID string) {
|
||||||
// Broadcasting message so it's stored
|
// Broadcasting message so it's stored
|
||||||
fm.broadcaster.Submit(envelope)
|
fm.broadcaster.Submit(envelope)
|
||||||
|
|
||||||
if msg.ContentTopic == "" {
|
|
||||||
filter.Chan <- envelope
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: In case of no topics we should either trigger here for all messages,
|
// TODO: In case of no topics we should either trigger here for all messages,
|
||||||
// or we should not allow such filter to exist in the first place.
|
// or we should not allow such filter to exist in the first place.
|
||||||
for _, contentTopic := range filter.ContentFilters {
|
for _, contentTopic := range filter.ContentFilters {
|
||||||
|
|
|
@ -49,14 +49,14 @@ func (m *metricsImpl) RecordMessage() {
|
||||||
type metricsErrCategory string
|
type metricsErrCategory string
|
||||||
|
|
||||||
var (
|
var (
|
||||||
decodeRPCFailure metricsErrCategory = "decode_rpc_failure"
|
decodeRPCFailure metricsErrCategory = "decode_rpc_failure"
|
||||||
writeRequestFailure metricsErrCategory = "write_request_failure"
|
writeRequestFailure metricsErrCategory = "write_request_failure"
|
||||||
writeResponseFailure metricsErrCategory = "write_response_failure"
|
writeResponseFailure metricsErrCategory = "write_response_failure"
|
||||||
dialFailure metricsErrCategory = "dial_failure"
|
dialFailure metricsErrCategory = "dial_failure"
|
||||||
messagePushFailure metricsErrCategory = "message_push_failure"
|
messagePushFailure metricsErrCategory = "message_push_failure"
|
||||||
emptyRequestBodyFailure metricsErrCategory = "empty_request_body_failure"
|
requestBodyFailure metricsErrCategory = "request_failure"
|
||||||
emptyResponseBodyFailure metricsErrCategory = "empty_response_body_failure"
|
responseBodyFailure metricsErrCategory = "response_body_failure"
|
||||||
peerNotFoundFailure metricsErrCategory = "peer_not_found_failure"
|
peerNotFoundFailure metricsErrCategory = "peer_not_found_failure"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RecordError increases the counter for different error types
|
// RecordError increases the counter for different error types
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
package pb
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
errMissingRequestID = errors.New("missing RequestId field")
|
||||||
|
errMissingQuery = errors.New("missing Query field")
|
||||||
|
errMissingMessage = errors.New("missing Message field")
|
||||||
|
errMissingPubsubTopic = errors.New("missing PubsubTopic field")
|
||||||
|
errRequestIDMismatch = errors.New("RequestID in response does not match request")
|
||||||
|
errMissingResponse = errors.New("missing Response field")
|
||||||
|
)
|
||||||
|
|
||||||
|
func (x *PushRPC) ValidateRequest() error {
|
||||||
|
if x.RequestId == "" {
|
||||||
|
return errMissingRequestID
|
||||||
|
}
|
||||||
|
|
||||||
|
if x.Query == nil {
|
||||||
|
return errMissingQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
if x.Query.PubsubTopic == "" {
|
||||||
|
return errMissingPubsubTopic
|
||||||
|
}
|
||||||
|
|
||||||
|
if x.Query.Message == nil {
|
||||||
|
return errMissingMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
return x.Query.Message.Validate()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *PushRPC) ValidateResponse(requestID string) error {
|
||||||
|
if x.RequestId == "" {
|
||||||
|
return errMissingRequestID
|
||||||
|
}
|
||||||
|
|
||||||
|
if x.RequestId != requestID {
|
||||||
|
return errRequestIDMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
if x.Response == nil {
|
||||||
|
return errMissingResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,35 @@
|
||||||
|
package pb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/waku-org/go-waku/waku/v2/protocol/pb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateRequest(t *testing.T) {
|
||||||
|
request := PushRPC{}
|
||||||
|
require.ErrorIs(t, request.ValidateRequest(), errMissingRequestID)
|
||||||
|
request.RequestId = "test"
|
||||||
|
require.ErrorIs(t, request.ValidateRequest(), errMissingQuery)
|
||||||
|
request.Query = &PushRequest{}
|
||||||
|
require.ErrorIs(t, request.ValidateRequest(), errMissingPubsubTopic)
|
||||||
|
request.Query.PubsubTopic = "test"
|
||||||
|
require.ErrorIs(t, request.ValidateRequest(), errMissingMessage)
|
||||||
|
request.Query.Message = &pb.WakuMessage{
|
||||||
|
Payload: []byte{1, 2, 3},
|
||||||
|
ContentTopic: "test",
|
||||||
|
}
|
||||||
|
require.NoError(t, request.ValidateRequest())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateResponse(t *testing.T) {
|
||||||
|
response := PushRPC{}
|
||||||
|
require.ErrorIs(t, response.ValidateResponse("test"), errMissingRequestID)
|
||||||
|
response.RequestId = "test1"
|
||||||
|
require.ErrorIs(t, response.ValidateResponse("test"), errRequestIDMismatch)
|
||||||
|
response.RequestId = "test"
|
||||||
|
require.ErrorIs(t, response.ValidateResponse("test"), errMissingResponse)
|
||||||
|
response.Response = &PushResponse{}
|
||||||
|
require.NoError(t, response.ValidateResponse("test"))
|
||||||
|
}
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
"github.com/libp2p/go-libp2p/core/host"
|
"github.com/libp2p/go-libp2p/core/host"
|
||||||
|
@ -81,7 +82,6 @@ func (wakuLP *WakuLightPush) onRequest(ctx context.Context) func(network.Stream)
|
||||||
logger := wakuLP.log.With(logging.HostID("peer", stream.Conn().RemotePeer()))
|
logger := wakuLP.log.With(logging.HostID("peer", stream.Conn().RemotePeer()))
|
||||||
requestPushRPC := &pb.PushRPC{}
|
requestPushRPC := &pb.PushRPC{}
|
||||||
|
|
||||||
writer := pbio.NewDelimitedWriter(stream)
|
|
||||||
reader := pbio.NewDelimitedReader(stream, math.MaxInt32)
|
reader := pbio.NewDelimitedReader(stream, math.MaxInt32)
|
||||||
|
|
||||||
err := reader.ReadMsg(requestPushRPC)
|
err := reader.ReadMsg(requestPushRPC)
|
||||||
|
@ -94,67 +94,69 @@ func (wakuLP *WakuLightPush) onRequest(ctx context.Context) func(network.Stream)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
responsePushRPC := &pb.PushRPC{
|
||||||
|
RequestId: requestPushRPC.RequestId,
|
||||||
|
Response: &pb.PushResponse{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := requestPushRPC.ValidateRequest(); err != nil {
|
||||||
|
responsePushRPC.Response.Info = err.Error()
|
||||||
|
wakuLP.metrics.RecordError(requestBodyFailure)
|
||||||
|
wakuLP.reply(stream, responsePushRPC, logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
logger = logger.With(zap.String("requestID", requestPushRPC.RequestId))
|
logger = logger.With(zap.String("requestID", requestPushRPC.RequestId))
|
||||||
|
|
||||||
responsePushRPC := &pb.PushRPC{}
|
logger.Info("push request")
|
||||||
responsePushRPC.RequestId = requestPushRPC.RequestId
|
|
||||||
|
|
||||||
if requestPushRPC.Query != nil {
|
pubSubTopic := requestPushRPC.Query.PubsubTopic
|
||||||
logger.Info("push request")
|
message := requestPushRPC.Query.Message
|
||||||
response := new(pb.PushResponse)
|
|
||||||
|
|
||||||
pubSubTopic := requestPushRPC.Query.PubsubTopic
|
wakuLP.metrics.RecordMessage()
|
||||||
message := requestPushRPC.Query.Message
|
|
||||||
|
|
||||||
wakuLP.metrics.RecordMessage()
|
// TODO: Assumes success, should probably be extended to check for network, peers, etc
|
||||||
|
// It might make sense to use WithReadiness option here?
|
||||||
|
|
||||||
// TODO: Assumes success, should probably be extended to check for network, peers, etc
|
_, err = wakuLP.relay.PublishToTopic(ctx, message, pubSubTopic)
|
||||||
// It might make sense to use WithReadiness option here?
|
if err != nil {
|
||||||
|
logger.Error("publishing message", zap.Error(err))
|
||||||
_, err := wakuLP.relay.PublishToTopic(ctx, message, pubSubTopic)
|
wakuLP.metrics.RecordError(messagePushFailure)
|
||||||
|
responsePushRPC.Response.Info = fmt.Sprintf("Could not publish message: %s", err.Error())
|
||||||
if err != nil {
|
return
|
||||||
logger.Error("publishing message", zap.Error(err))
|
|
||||||
wakuLP.metrics.RecordError(messagePushFailure)
|
|
||||||
response.Info = "Could not publish message"
|
|
||||||
} else {
|
|
||||||
response.IsSuccess = true
|
|
||||||
response.Info = "OK"
|
|
||||||
}
|
|
||||||
|
|
||||||
responsePushRPC.Response = response
|
|
||||||
|
|
||||||
err = writer.WriteMsg(responsePushRPC)
|
|
||||||
if err != nil {
|
|
||||||
wakuLP.metrics.RecordError(writeResponseFailure)
|
|
||||||
logger.Error("writing response", zap.Error(err))
|
|
||||||
if err := stream.Reset(); err != nil {
|
|
||||||
wakuLP.log.Error("resetting connection", zap.Error(err))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("response sent")
|
|
||||||
stream.Close()
|
|
||||||
} else {
|
} else {
|
||||||
wakuLP.metrics.RecordError(emptyRequestBodyFailure)
|
responsePushRPC.Response.IsSuccess = true
|
||||||
if err := stream.Reset(); err != nil {
|
responsePushRPC.Response.Info = "OK"
|
||||||
wakuLP.log.Error("resetting connection", zap.Error(err))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if requestPushRPC.Response != nil {
|
wakuLP.reply(stream, responsePushRPC, logger)
|
||||||
if requestPushRPC.Response.IsSuccess {
|
|
||||||
logger.Info("request success")
|
logger.Info("response sent")
|
||||||
} else {
|
|
||||||
logger.Info("request failure", zap.String("info=", requestPushRPC.Response.Info))
|
stream.Close()
|
||||||
}
|
|
||||||
|
if responsePushRPC.Response.IsSuccess {
|
||||||
|
logger.Info("request success")
|
||||||
} else {
|
} else {
|
||||||
wakuLP.metrics.RecordError(emptyResponseBodyFailure)
|
logger.Info("request failure", zap.String("info", responsePushRPC.Response.Info))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (wakuLP *WakuLightPush) reply(stream network.Stream, responsePushRPC *pb.PushRPC, logger *zap.Logger) {
|
||||||
|
writer := pbio.NewDelimitedWriter(stream)
|
||||||
|
err := writer.WriteMsg(responsePushRPC)
|
||||||
|
if err != nil {
|
||||||
|
wakuLP.metrics.RecordError(writeResponseFailure)
|
||||||
|
logger.Error("writing response", zap.Error(err))
|
||||||
|
if err := stream.Reset(); err != nil {
|
||||||
|
wakuLP.log.Error("resetting connection", zap.Error(err))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stream.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// request sends a message via lightPush protocol to either a specified peer or peer that is selected.
|
// request sends a message via lightPush protocol to either a specified peer or peer that is selected.
|
||||||
func (wakuLP *WakuLightPush) request(ctx context.Context, req *pb.PushRequest, params *lightPushParameters) (*pb.PushResponse, error) {
|
func (wakuLP *WakuLightPush) request(ctx context.Context, req *pb.PushRequest, params *lightPushParameters) (*pb.PushResponse, error) {
|
||||||
if params == nil {
|
if params == nil {
|
||||||
|
@ -201,6 +203,11 @@ func (wakuLP *WakuLightPush) request(ctx context.Context, req *pb.PushRequest, p
|
||||||
|
|
||||||
stream.Close()
|
stream.Close()
|
||||||
|
|
||||||
|
if err = pushResponseRPC.ValidateResponse(pushRequestRPC.RequestId); err != nil {
|
||||||
|
wakuLP.metrics.RecordError(responseBodyFailure)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return pushResponseRPC.Response, nil
|
return pushResponseRPC.Response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
package pb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
const MaxMetaAttrLength = 64
|
||||||
|
|
||||||
|
var (
|
||||||
|
errMissingPayload = errors.New("missing Payload field")
|
||||||
|
errMissingContentTopic = errors.New("missing ContentTopic field")
|
||||||
|
errInvalidMetaLength = errors.New("invalid length for Meta field")
|
||||||
|
)
|
||||||
|
|
||||||
|
func (msg *WakuMessage) Validate() error {
|
||||||
|
if len(msg.Payload) == 0 {
|
||||||
|
return errMissingPayload
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.ContentTopic == "" {
|
||||||
|
return errMissingContentTopic
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msg.Meta) > MaxMetaAttrLength {
|
||||||
|
return errInvalidMetaLength
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Unmarshal(data []byte) (*WakuMessage, error) {
|
||||||
|
msg := &WakuMessage{}
|
||||||
|
err := proto.Unmarshal(data, msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = msg.Validate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
|
||||||
|
}
|
|
@ -12,7 +12,6 @@ import (
|
||||||
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
||||||
pubsub "github.com/libp2p/go-libp2p-pubsub"
|
pubsub "github.com/libp2p/go-libp2p-pubsub"
|
||||||
"github.com/libp2p/go-libp2p/core/peer"
|
"github.com/libp2p/go-libp2p/core/peer"
|
||||||
proto "google.golang.org/protobuf/proto"
|
|
||||||
|
|
||||||
"github.com/waku-org/go-waku/waku/v2/hash"
|
"github.com/waku-org/go-waku/waku/v2/hash"
|
||||||
"github.com/waku-org/go-waku/waku/v2/protocol/pb"
|
"github.com/waku-org/go-waku/waku/v2/protocol/pb"
|
||||||
|
@ -62,8 +61,7 @@ func (w *WakuRelay) RemoveTopicValidator(topic string) {
|
||||||
|
|
||||||
func (w *WakuRelay) topicValidator(topic string) func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool {
|
func (w *WakuRelay) topicValidator(topic string) func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool {
|
||||||
return func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool {
|
return func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool {
|
||||||
msg := new(pb.WakuMessage)
|
msg, err := pb.Unmarshal(message.Data)
|
||||||
err := proto.Unmarshal(message.Data, msg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
|
@ -234,12 +234,15 @@ func (w *WakuRelay) PublishToTopic(ctx context.Context, message *pb.WakuMessage,
|
||||||
return nil, errors.New("message can't be null")
|
return nil, errors.New("message can't be null")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := message.Validate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if !w.EnoughPeersToPublishToTopic(topic) {
|
if !w.EnoughPeersToPublishToTopic(topic) {
|
||||||
return nil, errors.New("not enough peers to publish")
|
return nil, errors.New("not enough peers to publish")
|
||||||
}
|
}
|
||||||
|
|
||||||
pubSubTopic, err := w.upsertTopic(topic)
|
pubSubTopic, err := w.upsertTopic(topic)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -461,11 +464,13 @@ func (w *WakuRelay) pubsubTopicMsgHandler(pubsubTopic string, sub *pubsub.Subscr
|
||||||
sub.Cancel()
|
sub.Cancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
wakuMessage := &pb.WakuMessage{}
|
|
||||||
if err := proto.Unmarshal(msg.Data, wakuMessage); err != nil {
|
wakuMessage, err := pb.Unmarshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
w.log.Error("decoding message", zap.Error(err))
|
w.log.Error("decoding message", zap.Error(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
envelope := waku_proto.NewEnvelope(wakuMessage, w.timesource.Now().UnixNano(), pubsubTopic)
|
envelope := waku_proto.NewEnvelope(wakuMessage, w.timesource.Now().UnixNano(), pubsubTopic)
|
||||||
w.metrics.RecordMessage(envelope)
|
w.metrics.RecordMessage(envelope)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue