complete validator functions

- make validators time out after 100ms
  - add context param to validator functions
  - add type Validator func(context.Context, *Message) bool
- drop message if 10 messages (maxConcurrency) are already being validated
This commit is contained in:
keks 2017-11-22 18:40:45 +01:00 committed by vyzo
parent 89e6a06f3c
commit 02877cda71
3 changed files with 190 additions and 35 deletions

View File

@ -17,7 +17,11 @@ import (
timecache "github.com/whyrusleeping/timecache"
)
const ID = protocol.ID("/floodsub/1.0.0")
const (
// ID is the floodsub protocol identifier; NewFloodSub registers the stream
// handler on the host under this ID.
ID = protocol.ID("/floodsub/1.0.0")
// maxConcurrency is the buffer size of PubSub.throttleValidate, i.e. the
// number of validation goroutines that may run at once. A message that
// arrives while all slots are taken is dropped with a warning.
maxConcurrency = 10
// validateTimeoutMillis is the timeout, in milliseconds, of the context
// handed to each subscription validator by PubSub.validate.
validateTimeoutMillis = 100
)
var log = logging.Logger("floodsub")
@ -57,6 +61,9 @@ type PubSub struct {
// sendMsg handles messages that have been validated
sendMsg chan sendReq
// throttleValidate bounds the number of goroutines concurrently validating messages
throttleValidate chan struct{}
peers map[peer.ID]chan *RPC
seenMessages *timecache.TimeCache
@ -84,22 +91,23 @@ type RPC struct {
// NewFloodSub returns a new FloodSub management object
func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
ps := &PubSub{
host: h,
ctx: ctx,
incoming: make(chan *RPC, 32),
publish: make(chan *Message),
newPeers: make(chan inet.Stream),
peerDead: make(chan peer.ID),
cancelCh: make(chan *Subscription),
getPeers: make(chan *listPeerReq),
addSub: make(chan *addSubReq),
getTopics: make(chan *topicReq),
sendMsg: make(chan sendReq),
myTopics: make(map[string]map[*Subscription]struct{}),
topics: make(map[string]map[peer.ID]struct{}),
peers: make(map[peer.ID]chan *RPC),
seenMessages: timecache.NewTimeCache(time.Second * 30),
counter: uint64(time.Now().UnixNano()),
host: h,
ctx: ctx,
incoming: make(chan *RPC, 32),
publish: make(chan *Message),
newPeers: make(chan inet.Stream),
peerDead: make(chan peer.ID),
cancelCh: make(chan *Subscription),
getPeers: make(chan *listPeerReq),
addSub: make(chan *addSubReq),
getTopics: make(chan *topicReq),
sendMsg: make(chan sendReq),
myTopics: make(map[string]map[*Subscription]struct{}),
topics: make(map[string]map[peer.ID]struct{}),
peers: make(map[peer.ID]chan *RPC),
seenMessages: timecache.NewTimeCache(time.Second * 30),
counter: uint64(time.Now().UnixNano()),
throttleValidate: make(chan struct{}, maxConcurrency),
}
h.SetStreamHandler(ID, ps.handleNewStream)
@ -181,14 +189,23 @@ func (p *PubSub) processLoop(ctx context.Context) {
}
case msg := <-p.publish:
subs := p.getSubscriptions(msg) // call before goroutine!
go func() {
if p.validate(subs, msg) {
p.sendMsg <- sendReq{
from: p.host.ID(),
msg: msg,
select {
case p.throttleValidate <- struct{}{}:
go func() {
defer func() { <-p.throttleValidate }()
if p.validate(subs, msg) {
p.sendMsg <- sendReq{
from: p.host.ID(),
msg: msg,
}
}
}
}()
}()
default:
log.Warning("could not acquire validator; dropping message")
}
case req := <-p.sendMsg:
p.maybePublishMessage(req.from, req.msg.Message)
@ -328,14 +345,22 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error {
}
subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine!
go func(pmsg *pb.Message) {
if p.validate(subs, &Message{pmsg}) {
p.sendMsg <- sendReq{
from: rpc.from,
msg: &Message{pmsg},
select {
case p.throttleValidate <- struct{}{}:
go func(pmsg *pb.Message) {
defer func() { <-p.throttleValidate }()
if p.validate(subs, &Message{pmsg}) {
p.sendMsg <- sendReq{
from: rpc.from,
msg: &Message{pmsg},
}
}
}
}(pmsg)
}(pmsg)
default:
log.Warning("could not acquire validator; dropping message")
}
}
return nil
}
@ -348,7 +373,10 @@ func msgID(pmsg *pb.Message) string {
// validate is called in a goroutine and calls the validate functions of all subs with msg as parameter.
func (p *PubSub) validate(subs []*Subscription, msg *Message) bool {
for _, sub := range subs {
if sub.validate != nil && !sub.validate(msg) {
ctx, cancel := context.WithTimeout(p.ctx, validateTimeoutMillis*time.Millisecond)
defer cancel()
if sub.validate != nil && !sub.validate(ctx, msg) {
log.Debugf("validator for topic %s returned false", sub.topic)
return false
}
@ -432,9 +460,10 @@ type addSubReq struct {
}
// SubOpt is an option function that configures a Subscription and may
// report an error. WithValidator returns a function of this shape, which
// is passed to Subscribe after the topic.
type SubOpt func(*Subscription) error
// Validator is the signature of a per-subscription message validation
// function. PubSub.validate calls it with a context that expires after
// validateTimeoutMillis milliseconds; returning false makes validate
// report false, so the message is not handed on to sendMsg.
type Validator func(context.Context, *Message) bool
// WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further.
func WithValidator(validate func(*Message) bool) func(*Subscription) error {
func WithValidator(validate Validator) func(*Subscription) error {
return func(sub *Subscription) error {
sub.validate = validate
return nil

View File

@ -6,6 +6,7 @@ import (
"fmt"
"math/rand"
"sort"
"sync"
"testing"
"time"
@ -343,7 +344,7 @@ func TestValidate(t *testing.T) {
connect(t, hosts[0], hosts[1])
topic := "foobar"
sub, err := psubs[1].Subscribe(topic, WithValidator(func(msg *Message) bool {
sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool {
return !bytes.Contains(msg.Data, []byte("illegal"))
}))
if err != nil {
@ -384,6 +385,131 @@ func TestValidate(t *testing.T) {
}
}
// TestValidateCancel subscribes with a validator that blocks until the
// context it is given is done and then accepts the message. The test
// publishes one message from the other peer and expects it to be delivered
// to the subscription within 333ms.
func TestValidateCancel(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 2)
	psubs := getPubsubs(ctx, hosts)
	connect(t, hosts[0], hosts[1])

	topic := "foobar"

	// The validator never decides on its own; it only returns once its
	// context has been cancelled or has expired.
	waitForDone := func(vctx context.Context, _ *Message) bool {
		<-vctx.Done()
		return true
	}
	sub, err := psubs[1].Subscribe(topic, WithValidator(waitForDone))
	if err != nil {
		t.Fatal(err)
	}

	// Give the subscription time to propagate before publishing.
	time.Sleep(50 * time.Millisecond)

	shouldDeliver := true
	publisher := psubs[0]
	payload := []byte("this is a legal message")
	if err := publisher.Publish(topic, payload); err != nil {
		t.Fatal(err)
	}

	select {
	case got := <-sub.ch:
		if !shouldDeliver {
			t.Log(got)
			t.Error("expected message validation to filter out the message")
		}
	case <-time.After(333 * time.Millisecond):
		if shouldDeliver {
			t.Error("expected message validation to accept the message")
		}
	}
}
// TestValidateOverload publishes maxConcurrency+1 messages, 10ms apart,
// while every validator call on the subscriber is blocked. The validators
// are released only after the last message has been published. A receiver
// goroutine then expects one delivery per table entry marked validates and
// a timeout for the final entry, which is expected to have been dropped
// because all validator slots were taken.
func TestValidateOverload(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 2)
	psubs := getPubsubs(ctx, hosts)
	connect(t, hosts[0], hosts[1])

	topic := "foobar"

	// release is closed to let every blocked validator call return.
	release := make(chan struct{})
	blockUntilReleased := func(_ context.Context, _ *Message) bool {
		<-release
		return true
	}
	sub, err := psubs[1].Subscribe(topic, WithValidator(blockUntilReleased))
	if err != nil {
		t.Fatal(err)
	}

	// Give the subscription time to propagate before publishing.
	time.Sleep(50 * time.Millisecond)

	type testCase struct {
		msg       []byte
		validates bool
	}
	msgs := []testCase{
		{msg: []byte("this is a legal message"), validates: true},
		{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
		{msg: []byte("there also is nothing controversial about this message"), validates: true},
		{msg: []byte("also fine"), validates: true},
		{msg: []byte("still, all good"), validates: true},
		{msg: []byte("this is getting boring"), validates: true},
		{msg: []byte("foo"), validates: true},
		{msg: []byte("foobar"), validates: true},
		{msg: []byte("foofoo"), validates: true},
		{msg: []byte("barfoo"), validates: true},
		{msg: []byte("barbar"), validates: false},
	}

	// The table must overflow the validator pool by exactly one message.
	if len(msgs) != maxConcurrency+1 {
		t.Fatalf("expected number of messages sent to be maxConcurrency+1. Got %d, expected %d", len(msgs), maxConcurrency+1)
	}

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		// One receive attempt per table entry; entries are matched by
		// position, not by message content.
		for i := range msgs {
			wantDelivery := msgs[i].validates
			select {
			case got := <-sub.ch:
				if !wantDelivery {
					t.Log(got)
					t.Error("expected message validation to drop the message because all validator goroutines are taken")
				}
			case <-time.After(333 * time.Millisecond):
				if wantDelivery {
					t.Error("expected message validation to accept the message")
				}
			}
		}
	}()

	publisher := psubs[0]
	lastIdx := len(msgs) - 1
	for i := range msgs {
		if err := publisher.Publish(topic, msgs[i].msg); err != nil {
			t.Fatal(err)
		}

		// wait a bit to let pubsub's internal state machine start validating the message
		time.Sleep(10 * time.Millisecond)

		// unblock validator goroutines after we sent one too many
		if i == lastIdx {
			close(release)
		}
	}

	wg.Wait()
}
func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) {
peers := ps.ListPeers("")
set := make(map[peer.ID]struct{})

View File

@ -9,7 +9,7 @@ type Subscription struct {
ch chan *Message
cancelCh chan<- *Subscription
err error
validate func(*Message) bool
validate Validator
}
func (sub *Subscription) Topic() string {