mirror of
https://github.com/logos-messaging/go-libp2p-pubsub.git
synced 2026-01-07 15:23:08 +00:00
complete validator functions
- make validators time out after 100ms - add context param to validator functions - add type Validator func(context.Context, *Message) bool - drop message if more than 10 messages are already being validated
This commit is contained in:
parent
89e6a06f3c
commit
02877cda71
35
floodsub.go
35
floodsub.go
@ -17,7 +17,11 @@ import (
|
|||||||
timecache "github.com/whyrusleeping/timecache"
|
timecache "github.com/whyrusleeping/timecache"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ID = protocol.ID("/floodsub/1.0.0")
|
const (
|
||||||
|
ID = protocol.ID("/floodsub/1.0.0")
|
||||||
|
maxConcurrency = 10
|
||||||
|
validateTimeoutMillis = 100
|
||||||
|
)
|
||||||
|
|
||||||
var log = logging.Logger("floodsub")
|
var log = logging.Logger("floodsub")
|
||||||
|
|
||||||
@ -57,6 +61,9 @@ type PubSub struct {
|
|||||||
// sendMsg handles messages that have been validated
|
// sendMsg handles messages that have been validated
|
||||||
sendMsg chan sendReq
|
sendMsg chan sendReq
|
||||||
|
|
||||||
|
// throttleValidate bounds the number of goroutines concurrently validating messages
|
||||||
|
throttleValidate chan struct{}
|
||||||
|
|
||||||
peers map[peer.ID]chan *RPC
|
peers map[peer.ID]chan *RPC
|
||||||
seenMessages *timecache.TimeCache
|
seenMessages *timecache.TimeCache
|
||||||
|
|
||||||
@ -100,6 +107,7 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
|
|||||||
peers: make(map[peer.ID]chan *RPC),
|
peers: make(map[peer.ID]chan *RPC),
|
||||||
seenMessages: timecache.NewTimeCache(time.Second * 30),
|
seenMessages: timecache.NewTimeCache(time.Second * 30),
|
||||||
counter: uint64(time.Now().UnixNano()),
|
counter: uint64(time.Now().UnixNano()),
|
||||||
|
throttleValidate: make(chan struct{}, maxConcurrency),
|
||||||
}
|
}
|
||||||
|
|
||||||
h.SetStreamHandler(ID, ps.handleNewStream)
|
h.SetStreamHandler(ID, ps.handleNewStream)
|
||||||
@ -181,14 +189,23 @@ func (p *PubSub) processLoop(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
case msg := <-p.publish:
|
case msg := <-p.publish:
|
||||||
subs := p.getSubscriptions(msg) // call before goroutine!
|
subs := p.getSubscriptions(msg) // call before goroutine!
|
||||||
|
|
||||||
|
select {
|
||||||
|
case p.throttleValidate <- struct{}{}:
|
||||||
go func() {
|
go func() {
|
||||||
|
defer func() { <-p.throttleValidate }()
|
||||||
|
|
||||||
if p.validate(subs, msg) {
|
if p.validate(subs, msg) {
|
||||||
p.sendMsg <- sendReq{
|
p.sendMsg <- sendReq{
|
||||||
from: p.host.ID(),
|
from: p.host.ID(),
|
||||||
msg: msg,
|
msg: msg,
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
default:
|
||||||
|
log.Warning("could not acquire validator; dropping message")
|
||||||
|
}
|
||||||
case req := <-p.sendMsg:
|
case req := <-p.sendMsg:
|
||||||
p.maybePublishMessage(req.from, req.msg.Message)
|
p.maybePublishMessage(req.from, req.msg.Message)
|
||||||
|
|
||||||
@ -328,7 +345,12 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine!
|
subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine!
|
||||||
|
|
||||||
|
select {
|
||||||
|
case p.throttleValidate <- struct{}{}:
|
||||||
go func(pmsg *pb.Message) {
|
go func(pmsg *pb.Message) {
|
||||||
|
defer func() { <-p.throttleValidate }()
|
||||||
|
|
||||||
if p.validate(subs, &Message{pmsg}) {
|
if p.validate(subs, &Message{pmsg}) {
|
||||||
p.sendMsg <- sendReq{
|
p.sendMsg <- sendReq{
|
||||||
from: rpc.from,
|
from: rpc.from,
|
||||||
@ -336,6 +358,9 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(pmsg)
|
}(pmsg)
|
||||||
|
default:
|
||||||
|
log.Warning("could not acquire validator; dropping message")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -348,7 +373,10 @@ func msgID(pmsg *pb.Message) string {
|
|||||||
// validate is called in a goroutine and calls the validate functions of all subs with msg as parameter.
|
// validate is called in a goroutine and calls the validate functions of all subs with msg as parameter.
|
||||||
func (p *PubSub) validate(subs []*Subscription, msg *Message) bool {
|
func (p *PubSub) validate(subs []*Subscription, msg *Message) bool {
|
||||||
for _, sub := range subs {
|
for _, sub := range subs {
|
||||||
if sub.validate != nil && !sub.validate(msg) {
|
ctx, cancel := context.WithTimeout(p.ctx, validateTimeoutMillis*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if sub.validate != nil && !sub.validate(ctx, msg) {
|
||||||
log.Debugf("validator for topic %s returned false", sub.topic)
|
log.Debugf("validator for topic %s returned false", sub.topic)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -432,9 +460,10 @@ type addSubReq struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SubOpt func(*Subscription) error
|
type SubOpt func(*Subscription) error
|
||||||
|
type Validator func(context.Context, *Message) bool
|
||||||
|
|
||||||
// WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further.
|
// WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further.
|
||||||
func WithValidator(validate func(*Message) bool) func(*Subscription) error {
|
func WithValidator(validate Validator) func(*Subscription) error {
|
||||||
return func(sub *Subscription) error {
|
return func(sub *Subscription) error {
|
||||||
sub.validate = validate
|
sub.validate = validate
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
128
floodsub_test.go
128
floodsub_test.go
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"sort"
|
"sort"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -343,7 +344,7 @@ func TestValidate(t *testing.T) {
|
|||||||
connect(t, hosts[0], hosts[1])
|
connect(t, hosts[0], hosts[1])
|
||||||
topic := "foobar"
|
topic := "foobar"
|
||||||
|
|
||||||
sub, err := psubs[1].Subscribe(topic, WithValidator(func(msg *Message) bool {
|
sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool {
|
||||||
return !bytes.Contains(msg.Data, []byte("illegal"))
|
return !bytes.Contains(msg.Data, []byte("illegal"))
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -384,6 +385,131 @@ func TestValidate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateCancel(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
hosts := getNetHosts(t, ctx, 2)
|
||||||
|
psubs := getPubsubs(ctx, hosts)
|
||||||
|
|
||||||
|
connect(t, hosts[0], hosts[1])
|
||||||
|
topic := "foobar"
|
||||||
|
|
||||||
|
sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool {
|
||||||
|
<-ctx.Done()
|
||||||
|
return true
|
||||||
|
}))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
|
||||||
|
testmsg := []byte("this is a legal message")
|
||||||
|
validates := true
|
||||||
|
|
||||||
|
p := psubs[0]
|
||||||
|
|
||||||
|
err = p.Publish(topic, testmsg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case msg := <-sub.ch:
|
||||||
|
if !validates {
|
||||||
|
t.Log(msg)
|
||||||
|
t.Error("expected message validation to filter out the message")
|
||||||
|
}
|
||||||
|
case <-time.After(333 * time.Millisecond):
|
||||||
|
if validates {
|
||||||
|
t.Error("expected message validation to accept the message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateOverload(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
hosts := getNetHosts(t, ctx, 2)
|
||||||
|
psubs := getPubsubs(ctx, hosts)
|
||||||
|
|
||||||
|
connect(t, hosts[0], hosts[1])
|
||||||
|
topic := "foobar"
|
||||||
|
|
||||||
|
block := make(chan struct{})
|
||||||
|
|
||||||
|
sub, err := psubs[1].Subscribe(topic, WithValidator(func(ctx context.Context, msg *Message) bool {
|
||||||
|
_, _ = <-block
|
||||||
|
return true
|
||||||
|
}))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
|
||||||
|
msgs := []struct {
|
||||||
|
msg []byte
|
||||||
|
validates bool
|
||||||
|
}{
|
||||||
|
{msg: []byte("this is a legal message"), validates: true},
|
||||||
|
{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
|
||||||
|
{msg: []byte("there also is nothing controversial about this message"), validates: true},
|
||||||
|
{msg: []byte("also fine"), validates: true},
|
||||||
|
{msg: []byte("still, all good"), validates: true},
|
||||||
|
{msg: []byte("this is getting boring"), validates: true},
|
||||||
|
{msg: []byte("foo"), validates: true},
|
||||||
|
{msg: []byte("foobar"), validates: true},
|
||||||
|
{msg: []byte("foofoo"), validates: true},
|
||||||
|
{msg: []byte("barfoo"), validates: true},
|
||||||
|
{msg: []byte("barbar"), validates: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msgs) != maxConcurrency+1 {
|
||||||
|
t.Fatalf("expected number of messages sent to be maxConcurrency+1. Got %d, expected %d", len(msgs), maxConcurrency+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
p := psubs[0]
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
for _, tc := range msgs {
|
||||||
|
select {
|
||||||
|
case msg := <-sub.ch:
|
||||||
|
if !tc.validates {
|
||||||
|
t.Log(msg)
|
||||||
|
t.Error("expected message validation to drop the message because all validator goroutines are taken")
|
||||||
|
}
|
||||||
|
case <-time.After(333 * time.Millisecond):
|
||||||
|
if tc.validates {
|
||||||
|
t.Error("expected message validation to accept the message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i, tc := range msgs {
|
||||||
|
err := p.Publish(topic, tc.msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait a bit to let pubsub's internal state machine start validating the message
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// unblock validator goroutines after we sent one too many
|
||||||
|
if i == len(msgs)-1 {
|
||||||
|
close(block)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) {
|
func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) {
|
||||||
peers := ps.ListPeers("")
|
peers := ps.ListPeers("")
|
||||||
set := make(map[peer.ID]struct{})
|
set := make(map[peer.ID]struct{})
|
||||||
|
|||||||
@ -9,7 +9,7 @@ type Subscription struct {
|
|||||||
ch chan *Message
|
ch chan *Message
|
||||||
cancelCh chan<- *Subscription
|
cancelCh chan<- *Subscription
|
||||||
err error
|
err error
|
||||||
validate func(*Message) bool
|
validate Validator
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sub *Subscription) Topic() string {
|
func (sub *Subscription) Topic() string {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user