optionally allow caller to validate messages

This commit is contained in:
keks 2017-11-08 20:00:52 +01:00 committed by vyzo
parent 0ec8f2fa08
commit 647bb98655
3 changed files with 137 additions and 20 deletions

View File

@ -54,6 +54,9 @@ type PubSub struct {
// topics tracks which topics each of our peers are subscribed to
topics map[string]map[peer.ID]struct{}
// sendMsg handles messages that have been validated
sendMsg chan sendReq
peers map[peer.ID]chan *RPC
seenMessages *timecache.TimeCache
@ -91,6 +94,7 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
getPeers: make(chan *listPeerReq),
addSub: make(chan *addSubReq),
getTopics: make(chan *topicReq),
sendMsg: make(chan sendReq),
myTopics: make(map[string]map[*Subscription]struct{}),
topics: make(map[string]map[peer.ID]struct{}),
peers: make(map[peer.ID]chan *RPC),
@ -176,7 +180,19 @@ func (p *PubSub) processLoop(ctx context.Context) {
continue
}
case msg := <-p.publish:
p.maybePublishMessage(p.host.ID(), msg.Message)
subs := p.getSubscriptions(msg) // call before goroutine!
go func() {
if p.validate(subs, msg) {
p.sendMsg <- sendReq{
from: p.host.ID(),
msg: msg,
}
}
}()
case req := <-p.sendMsg:
p.maybePublishMessage(req.from, req.msg.Message)
case <-ctx.Done():
log.Info("pubsub processloop shutting down")
return
@ -210,24 +226,22 @@ func (p *PubSub) handleRemoveSubscription(sub *Subscription) {
// subscribes to the topic.
// Only called from processLoop.
func (p *PubSub) handleAddSubscription(req *addSubReq) {
subs := p.myTopics[req.topic]
sub := req.sub
subs := p.myTopics[sub.topic]
// announce we want this topic
if len(subs) == 0 {
p.announce(req.topic, true)
p.announce(sub.topic, true)
}
// make new if not there
if subs == nil {
p.myTopics[req.topic] = make(map[*Subscription]struct{})
subs = p.myTopics[req.topic]
p.myTopics[sub.topic] = make(map[*Subscription]struct{})
subs = p.myTopics[sub.topic]
}
sub := &Subscription{
ch: make(chan *Message, 32),
topic: req.topic,
cancelCh: p.cancelCh,
}
sub.ch = make(chan *Message, 32)
sub.cancelCh = p.cancelCh
p.myTopics[sub.topic][sub] = struct{}{}
@ -314,7 +328,15 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error {
continue
}
p.maybePublishMessage(rpc.from, pmsg)
subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine!
go func() {
if p.validate(subs, &Message{pmsg}) {
p.sendMsg <- sendReq{
from: rpc.from,
msg: &*Message{pmsg},
}
}
}()
}
return nil
}
@ -324,6 +346,17 @@ func msgID(pmsg *pb.Message) string {
return string(pmsg.GetFrom()) + string(pmsg.GetSeqno())
}
// validate is called in a goroutine and calls the validate functions of all subs with msg as parameter.
func (p *PubSub) validate(subs []*Subscription, msg *Message) bool {
for _, sub := range subs {
if sub.validate != nil && !sub.validate(msg) {
return false
}
}
return true
}
func (p *PubSub) maybePublishMessage(from peer.ID, pmsg *pb.Message) {
id := msgID(pmsg)
if p.seenMessage(id) {
@ -375,20 +408,47 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error {
return nil
}
// getSubscriptions returns all subscriptions the would receive the given message.
func (p *PubSub) getSubscriptions(msg *Message) []*Subscription {
var subs []*Subscription
for _, topic := range msg.GetTopicIDs() {
tSubs, ok := p.myTopics[topic]
if !ok {
continue
}
for sub := range tSubs {
subs = append(subs, sub)
}
}
return subs
}
type addSubReq struct {
topic string
resp chan *Subscription
sub *Subscription
resp chan *Subscription
}
// WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further.
func WithValidator(validate func(*Message) bool) func(*Subscription) error {
return func(sub *Subscription) error {
sub.validate = validate
return nil
}
}
// Subscribe returns a new Subscription for the given topic
func (p *PubSub) Subscribe(topic string) (*Subscription, error) {
func (p *PubSub) Subscribe(topic string, opts ...func(*Subscription) error) (*Subscription, error) {
td := pb.TopicDescriptor{Name: &topic}
return p.SubscribeByTopicDescriptor(&td)
return p.SubscribeByTopicDescriptor(&td, opts...)
}
// SubscribeByTopicDescriptor lets you subscribe a topic using a pb.TopicDescriptor
func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscription, error) {
func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...func(*Subscription) error) (*Subscription, error) {
if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE {
return nil, fmt.Errorf("auth mode not yet supported")
}
@ -397,10 +457,21 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscripti
return nil, fmt.Errorf("encryption mode not yet supported")
}
sub := &Subscription{
topic: td.GetName(),
}
for _, opt := range opts {
err := opt(sub)
if err != nil {
return nil, err
}
}
out := make(chan *Subscription, 1)
p.addSub <- &addSubReq{
topic: td.GetName(),
resp: out,
sub: sub,
resp: out,
}
return <-out, nil
@ -439,6 +510,12 @@ type listPeerReq struct {
topic string
}
// sendReq is a request to call maybePublishMessage. It is issued after the subscription verification is done.
type sendReq struct {
from peer.ID
msg *Message
}
// ListPeers returns a list of peers we are connected to.
func (p *PubSub) ListPeers(topic string) []peer.ID {
out := make(chan []peer.ID)

View File

@ -323,14 +323,53 @@ func TestOneToOne(t *testing.T) {
connect(t, hosts[0], hosts[1])
ch, err := psubs[1].Subscribe("foobar")
sub, err := psubs[1].Subscribe("foobar")
if err != nil {
t.Fatal(err)
}
time.Sleep(time.Millisecond * 50)
checkMessageRouting(t, "foobar", psubs, []*Subscription{ch})
checkMessageRouting(t, "foobar", psubs, []*Subscription{sub})
}
func TestValidate(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hosts := getNetHosts(t, ctx, 2)
psubs := getPubsubs(ctx, hosts)
connect(t, hosts[0], hosts[1])
topic := "foobar"
sub, err := psubs[1].Subscribe(topic, WithValidator(func(msg *Message) bool {
return !bytes.Contains(msg.Data, []byte("illegal"))
}))
if err != nil {
t.Fatal(err)
}
time.Sleep(time.Millisecond * 50)
data := make([]byte, 16)
rand.Read(data)
data = append(data, []byte("illegal")...)
for _, p := range psubs {
err := p.Publish(topic, data)
if err != nil {
t.Fatal(err)
}
select {
case msg := <-sub.ch:
t.Log(msg)
t.Fatal("expected message validation to filter out the message")
case <-time.After(333 * time.Millisecond):
}
}
}
func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) {

View File

@ -9,6 +9,7 @@ type Subscription struct {
ch chan *Message
cancelCh chan<- *Subscription
err error
validate func(*Message) bool
}
func (sub *Subscription) Topic() string {