complete validator functions
- make validators time out after 100ms - add context param to validator functions - add type Validator func(context.Context, *Message) bool - drop message if more than 10 messages are already being validated
This commit is contained in:
parent
89e6a06f3c
commit
02877cda71
35
floodsub.go
35
floodsub.go
|
@ -17,7 +17,11 @@ import (
|
|||
timecache "github.com/whyrusleeping/timecache"
|
||||
)
|
||||
|
||||
const ID = protocol.ID("/floodsub/1.0.0")
|
||||
const (
|
||||
ID = protocol.ID("/floodsub/1.0.0")
|
||||
maxConcurrency = 10
|
||||
validateTimeoutMillis = 100
|
||||
)
|
||||
|
||||
var log = logging.Logger("floodsub")
|
||||
|
||||
|
@ -57,6 +61,9 @@ type PubSub struct {
|
|||
// sendMsg handles messages that have been validated
|
||||
sendMsg chan sendReq
|
||||
|
||||
// throttleValidate bounds the number of goroutines concurrently validating messages
|
||||
throttleValidate chan struct{}
|
||||
|
||||
peers map[peer.ID]chan *RPC
|
||||
seenMessages *timecache.TimeCache
|
||||
|
||||
|
@ -100,6 +107,7 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
|
|||
peers: make(map[peer.ID]chan *RPC),
|
||||
seenMessages: timecache.NewTimeCache(time.Second * 30),
|
||||
counter: uint64(time.Now().UnixNano()),
|
||||
throttleValidate: make(chan struct{}, maxConcurrency),
|
||||
}
|
||||
|
||||
h.SetStreamHandler(ID, ps.handleNewStream)
|
||||
|
@ -181,14 +189,23 @@ func (p *PubSub) processLoop(ctx context.Context) {
|
|||
}
|
||||
case msg := <-p.publish:
|
||||
subs := p.getSubscriptions(msg) // call before goroutine!
|
||||
|
||||
select {
|
||||
case p.throttleValidate <- struct{}{}:
|
||||
go func() {
|
||||
defer func() { <-p.throttleValidate }()
|
||||
|
||||
if p.validate(subs, msg) {
|
||||
p.sendMsg <- sendReq{
|
||||
from: p.host.ID(),
|
||||
msg: msg,
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
default:
|
||||
log.Warning("could not acquire validator; dropping message")
|
||||
}
|
||||
case req := <-p.sendMsg:
|
||||
p.maybePublishMessage(req.from, req.msg.Message)
|
||||
|
||||
|
@ -328,7 +345,12 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error {
|
|||
}
|
||||
|
||||
subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine!
|
||||
|
||||
select {
|
||||
case p.throttleValidate <- struct{}{}:
|
||||
go func(pmsg *pb.Message) {
|
||||
defer func() { <-p.throttleValidate }()
|
||||
|
||||
if p.validate(subs, &Message{pmsg}) {
|
||||
p.sendMsg <- sendReq{
|
||||
from: rpc.from,
|
||||
|
@ -336,6 +358,9 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error {
|
|||
}
|
||||
}
|
||||
}(pmsg)
|
||||
default:
|
||||
log.Warning("could not acquire validator; dropping message")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -348,7 +373,10 @@ func msgID(pmsg *pb.Message) string {
|
|||
// validate is called in a goroutine and calls the validate functions of all subs with msg as parameter.
|
||||
func (p *PubSub) validate(subs []*Subscription, msg *Message) bool {
|
||||
for _, sub := range subs {
|
||||
if sub.validate != nil && !sub.validate(msg) {
|
||||
ctx, cancel := context.WithTimeout(p.ctx, validateTimeoutMillis*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if sub.validate != nil && !sub.validate(ctx, msg) {
|
||||
log.Debugf("validator for topic %s returned false", sub.topic)
|
||||
return false
|
||||
}
|
||||
|
@ -432,9 +460,10 @@ type addSubReq struct {
|
|||
}
|
||||
|
||||
type SubOpt func(*Subscription) error
|
||||
type Validator func(context.Context, *Message) bool
|
||||
|
||||
// WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further.
|
||||
func WithValidator(validate func(*Message) bool) func(*Subscription) error {
|
||||
func WithValidator(validate Validator) func(*Subscription) error {
|
||||
return func(sub *Subscription) error {
|
||||
sub.validate = validate
|
||||
return nil
|
||||
|
|
128
floodsub_test.go
128
floodsub_test.go
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -343,7 +344,7 @@ func TestValidate(t *testing.T) {
|
|||
connect(t, hosts[0], hosts[1])
|
||||
topic := "foobar"
|
||||
|
||||
sub, err := psubs[1].Subscribe(topic, WithValidator(func(msg *Message) bool {
|
||||
sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool {
|
||||
return !bytes.Contains(msg.Data, []byte("illegal"))
|
||||
}))
|
||||
if err != nil {
|
||||
|
@ -384,6 +385,131 @@ func TestValidate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestValidateCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
hosts := getNetHosts(t, ctx, 2)
|
||||
psubs := getPubsubs(ctx, hosts)
|
||||
|
||||
connect(t, hosts[0], hosts[1])
|
||||
topic := "foobar"
|
||||
|
||||
sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool {
|
||||
<-ctx.Done()
|
||||
return true
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
|
||||
testmsg := []byte("this is a legal message")
|
||||
validates := true
|
||||
|
||||
p := psubs[0]
|
||||
|
||||
err = p.Publish(topic, testmsg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
select {
|
||||
case msg := <-sub.ch:
|
||||
if !validates {
|
||||
t.Log(msg)
|
||||
t.Error("expected message validation to filter out the message")
|
||||
}
|
||||
case <-time.After(333 * time.Millisecond):
|
||||
if validates {
|
||||
t.Error("expected message validation to accept the message")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOverload(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
hosts := getNetHosts(t, ctx, 2)
|
||||
psubs := getPubsubs(ctx, hosts)
|
||||
|
||||
connect(t, hosts[0], hosts[1])
|
||||
topic := "foobar"
|
||||
|
||||
block := make(chan struct{})
|
||||
|
||||
sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool {
|
||||
_, _ = <-block
|
||||
return true
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
|
||||
msgs := []struct {
|
||||
msg []byte
|
||||
validates bool
|
||||
}{
|
||||
{msg: []byte("this is a legal message"), validates: true},
|
||||
{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
|
||||
{msg: []byte("there also is nothing controversial about this message"), validates: true},
|
||||
{msg: []byte("also fine"), validates: true},
|
||||
{msg: []byte("still, all good"), validates: true},
|
||||
{msg: []byte("this is getting boring"), validates: true},
|
||||
{msg: []byte("foo"), validates: true},
|
||||
{msg: []byte("foobar"), validates: true},
|
||||
{msg: []byte("foofoo"), validates: true},
|
||||
{msg: []byte("barfoo"), validates: true},
|
||||
{msg: []byte("barbar"), validates: false},
|
||||
}
|
||||
|
||||
if len(msgs) != maxConcurrency+1 {
|
||||
t.Fatalf("expected number of messages sent to be maxConcurrency+1. Got %d, expected %d", len(msgs), maxConcurrency+1)
|
||||
}
|
||||
|
||||
p := psubs[0]
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
for _, tc := range msgs {
|
||||
select {
|
||||
case msg := <-sub.ch:
|
||||
if !tc.validates {
|
||||
t.Log(msg)
|
||||
t.Error("expected message validation to drop the message because all validator goroutines are taken")
|
||||
}
|
||||
case <-time.After(333 * time.Millisecond):
|
||||
if tc.validates {
|
||||
t.Error("expected message validation to accept the message")
|
||||
}
|
||||
}
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
for i, tc := range msgs {
|
||||
err := p.Publish(topic, tc.msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// wait a bit to let pubsub's internal state machine start validating the message
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// unblock validator goroutines after we sent one too many
|
||||
if i == len(msgs)-1 {
|
||||
close(block)
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) {
|
||||
peers := ps.ListPeers("")
|
||||
set := make(map[peer.ID]struct{})
|
||||
|
|
|
@ -9,7 +9,7 @@ type Subscription struct {
|
|||
ch chan *Message
|
||||
cancelCh chan<- *Subscription
|
||||
err error
|
||||
validate func(*Message) bool
|
||||
validate Validator
|
||||
}
|
||||
|
||||
func (sub *Subscription) Topic() string {
|
||||
|
|
Loading…
Reference in New Issue