feat: add WakuMessage validation in gossipsub

This also stores the waku message in a cache to avoid having to decode it twice
This commit is contained in:
Richard Ramos 2023-02-08 10:20:40 -04:00 committed by RichΛrd
parent 7b3f4aade7
commit 144dfa5b7b
3 changed files with 128 additions and 12 deletions

View File

@ -8,6 +8,7 @@ import (
"sync" "sync"
"github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/protocol"
"go.opencensus.io/stats" "go.opencensus.io/stats"
"go.opencensus.io/tag" "go.opencensus.io/tag"
@ -29,6 +30,18 @@ const WakuRelayID_v200 = protocol.ID("/vac/waku/relay/2.0.0")
var DefaultWakuTopic string = waku_proto.DefaultPubsubTopic().String() var DefaultWakuTopic string = waku_proto.DefaultPubsubTopic().String()
type cacheItem struct {
msg *pb.WakuMessage
id string
pubsubTopic string
}
type msgReq struct {
id string
pubsubTopic string
ch chan *pb.WakuMessage
}
type WakuRelay struct { type WakuRelay struct {
host host.Host host host.Host
opts []pubsub.Option opts []pubsub.Option
@ -49,6 +62,14 @@ type WakuRelay struct {
// TODO: convert to concurrent maps // TODO: convert to concurrent maps
subscriptions map[string][]*Subscription subscriptions map[string][]*Subscription
subscriptionsMutex sync.Mutex subscriptionsMutex sync.Mutex
msgCache map[string]*pb.WakuMessage
msgCacheCh chan cacheItem
deleteCacheCh chan string
getMsgCh chan msgReq
cancel context.CancelFunc
wg sync.WaitGroup
} }
func msgIdFn(pmsg *pubsub_pb.Message) string { func msgIdFn(pmsg *pubsub_pb.Message) string {
@ -65,6 +86,7 @@ func NewWakuRelay(h host.Host, bcaster v2.Broadcaster, minPeersToPublish int, ti
w.subscriptions = make(map[string][]*Subscription) w.subscriptions = make(map[string][]*Subscription)
w.bcaster = bcaster w.bcaster = bcaster
w.minPeersToPublish = minPeersToPublish w.minPeersToPublish = minPeersToPublish
w.wg = sync.WaitGroup{}
w.log = log.Named("relay") w.log = log.Named("relay")
// default options required by WakuRelay // default options required by WakuRelay
@ -91,6 +113,18 @@ func NewWakuRelay(h host.Host, bcaster v2.Broadcaster, minPeersToPublish int, ti
} }
func (w *WakuRelay) Start(ctx context.Context) error { func (w *WakuRelay) Start(ctx context.Context) error {
w.wg.Wait()
ctx, cancel := context.WithCancel(ctx)
w.cancel = cancel
w.msgCacheCh = make(chan cacheItem, 1000)
w.deleteCacheCh = make(chan string, 1000)
w.msgCache = make(map[string]*pb.WakuMessage, 1000)
w.getMsgCh = make(chan msgReq, 1000)
w.wg.Add(1)
go w.cacheWorker(ctx)
ps, err := pubsub.NewGossipSub(ctx, w.host, w.opts...) ps, err := pubsub.NewGossipSub(ctx, w.host, w.opts...)
if err != nil { if err != nil {
return err return err
@ -101,6 +135,56 @@ func (w *WakuRelay) Start(ctx context.Context) error {
return nil return nil
} }
func (w *WakuRelay) getMessageFromCache(topic string, id string) *pb.WakuMessage {
resultCh := make(chan *pb.WakuMessage, 1)
defer close(resultCh)
w.getMsgCh <- msgReq{
id: id,
pubsubTopic: topic,
ch: resultCh,
}
return <-resultCh
}
func (w *WakuRelay) AddToCache(pubsubTopic string, id string, msg *pb.WakuMessage) {
w.msgCacheCh <- cacheItem{
msg: msg,
id: id,
pubsubTopic: pubsubTopic,
}
}
func (w *WakuRelay) cacheWorker(ctx context.Context) {
defer w.wg.Done()
deleteCounter := 0
for {
select {
case <-ctx.Done():
return
case item := <-w.msgCacheCh:
w.msgCache[item.pubsubTopic+item.id] = item.msg
case req := <-w.getMsgCh:
key := req.pubsubTopic + req.id
req.ch <- w.msgCache[key]
w.deleteCacheCh <- key
case key := <-w.deleteCacheCh:
delete(w.msgCache, key)
deleteCounter++
// Shrink msg cache to avoid oom
if deleteCounter > 1000 {
newMsgCache := make(map[string]*pb.WakuMessage, 1000)
for k, v := range w.msgCache {
newMsgCache[k] = v
}
w.msgCache = newMsgCache
}
}
}
}
// PubSub returns the implementation of the pubsub system // PubSub returns the implementation of the pubsub system
func (w *WakuRelay) PubSub() *pubsub.PubSub { func (w *WakuRelay) PubSub() *pubsub.PubSub {
return w.pubsub return w.pubsub
@ -139,6 +223,20 @@ func (w *WakuRelay) upsertTopic(topic string) (*pubsub.Topic, error) {
return pubSubTopic, nil return pubSubTopic, nil
} }
func (w *WakuRelay) validatorFactory(pubsubTopic string) func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool {
return func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool {
msg := new(pb.WakuMessage)
err := proto.Unmarshal(message.Data, msg)
if err != nil {
return false
}
w.AddToCache(pubsubTopic, message.ID, msg)
return true
}
}
func (w *WakuRelay) subscribe(topic string) (subs *pubsub.Subscription, err error) { func (w *WakuRelay) subscribe(topic string) (subs *pubsub.Subscription, err error) {
sub, ok := w.relaySubs[topic] sub, ok := w.relaySubs[topic]
if !ok { if !ok {
@ -147,6 +245,11 @@ func (w *WakuRelay) subscribe(topic string) (subs *pubsub.Subscription, err erro
return nil, err return nil, err
} }
err = w.pubsub.RegisterTopicValidator(topic, w.validatorFactory(topic))
if err != nil {
return nil, err
}
sub, err = pubSubTopic.Subscribe() sub, err = pubSubTopic.Subscribe()
if err != nil { if err != nil {
return nil, err return nil, err
@ -204,7 +307,21 @@ func (w *WakuRelay) Publish(ctx context.Context, message *pb.WakuMessage) ([]byt
// Stop unmounts the relay protocol and stops all subscriptions // Stop unmounts the relay protocol and stops all subscriptions
func (w *WakuRelay) Stop() { func (w *WakuRelay) Stop() {
if w.cancel == nil {
return // Not started
}
w.host.RemoveStreamHandler(WakuRelayID_v200) w.host.RemoveStreamHandler(WakuRelayID_v200)
w.cancel()
w.wg.Wait()
close(w.msgCacheCh)
close(w.deleteCacheCh)
close(w.getMsgCh)
w.msgCache = nil
w.subscriptionsMutex.Lock() w.subscriptionsMutex.Lock()
defer w.subscriptionsMutex.Unlock() defer w.subscriptionsMutex.Unlock()
@ -228,10 +345,7 @@ func (w *WakuRelay) EnoughPeersToPublishToTopic(topic string) bool {
// SubscribeToTopic returns a Subscription to receive messages from a pubsub topic // SubscribeToTopic returns a Subscription to receive messages from a pubsub topic
func (w *WakuRelay) SubscribeToTopic(ctx context.Context, topic string) (*Subscription, error) { func (w *WakuRelay) SubscribeToTopic(ctx context.Context, topic string) (*Subscription, error) {
// Subscribes to a PubSub topic.
// NOTE The data field SHOULD be decoded as a WakuMessage.
sub, err := w.subscribe(topic) sub, err := w.subscribe(topic)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -314,7 +428,7 @@ func (w *WakuRelay) nextMessage(ctx context.Context, sub *pubsub.Subscription) <
return msgChannel return msgChannel
} }
func (w *WakuRelay) subscribeToTopic(ctx context.Context, t string, subscription *Subscription, sub *pubsub.Subscription) { func (w *WakuRelay) subscribeToTopic(ctx context.Context, pubsubTopic string, subscription *Subscription, sub *pubsub.Subscription) {
ctx, err := tag.New(ctx, tag.Insert(metrics.KeyType, "relay")) ctx, err := tag.New(ctx, tag.Insert(metrics.KeyType, "relay"))
if err != nil { if err != nil {
w.log.Error("creating tag map", zap.Error(err)) w.log.Error("creating tag map", zap.Error(err))
@ -322,7 +436,6 @@ func (w *WakuRelay) subscribeToTopic(ctx context.Context, t string, subscription
} }
subChannel := w.nextMessage(ctx, sub) subChannel := w.nextMessage(ctx, sub)
for { for {
select { select {
case <-subscription.quit: case <-subscription.quit:
@ -339,20 +452,16 @@ func (w *WakuRelay) subscribeToTopic(ctx context.Context, t string, subscription
} }
close(subscription.C) close(subscription.C)
}(t) }(pubsubTopic)
// TODO: if there are no more relay subscriptions, close the pubsub subscription // TODO: if there are no more relay subscriptions, close the pubsub subscription
case msg := <-subChannel: case msg := <-subChannel:
if msg == nil { if msg == nil {
return return
} }
stats.Record(ctx, metrics.Messages.M(1)) stats.Record(ctx, metrics.Messages.M(1))
wakuMessage := &pb.WakuMessage{}
if err := proto.Unmarshal(msg.Data, wakuMessage); err != nil {
w.log.Error("decoding message", zap.Error(err))
return
}
envelope := waku_proto.NewEnvelope(wakuMessage, w.timesource.Now().UnixNano(), string(t)) wakuMessage := w.getMessageFromCache(pubsubTopic, msg.ID)
envelope := waku_proto.NewEnvelope(wakuMessage, w.timesource.Now().UnixNano(), string(pubsubTopic))
w.log.Debug("waku.relay received", logging.HexString("hash", envelope.Hash())) w.log.Debug("waku.relay received", logging.HexString("hash", envelope.Hash()))

View File

@ -36,6 +36,7 @@ func TestWakuRelay(t *testing.T) {
require.Equal(t, testTopic, topics[0]) require.Equal(t, testTopic, topics[0])
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
go func() { go func() {
defer cancel() defer cancel()

View File

@ -372,6 +372,9 @@ func (r *WakuRLNRelay) addValidator(
zap.Binary("payload", wakuMessage.Payload), zap.Binary("payload", wakuMessage.Payload),
zap.Any("proof", wakuMessage.RateLimitProof), zap.Any("proof", wakuMessage.RateLimitProof),
) )
relay.AddToCache(pubsubTopic, message.ID, wakuMessage)
return true return true
case MessageValidationResult_Invalid: case MessageValidationResult_Invalid:
r.log.Debug("message could not be verified", r.log.Debug("message could not be verified",
@ -404,6 +407,9 @@ func (r *WakuRLNRelay) addValidator(
} }
} }
// In case there's a topic validator registered
_ = relay.PubSub().UnregisterTopicValidator(pubsubTopic)
return relay.PubSub().RegisterTopicValidator(pubsubTopic, validator) return relay.PubSub().RegisterTopicValidator(pubsubTopic, validator)
} }