mirror of
https://github.com/logos-messaging/go-libp2p-pubsub.git
synced 2026-01-05 22:33:10 +00:00
commit
c82e67dcd3
323
floodsub.go
323
floodsub.go
@ -17,7 +17,12 @@ import (
|
|||||||
timecache "github.com/whyrusleeping/timecache"
|
timecache "github.com/whyrusleeping/timecache"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ID = protocol.ID("/floodsub/1.0.0")
|
const (
|
||||||
|
ID = protocol.ID("/floodsub/1.0.0")
|
||||||
|
defaultValidateTimeout = 150 * time.Millisecond
|
||||||
|
defaultValidateConcurrency = 100
|
||||||
|
defaultValidateThrottle = 8192
|
||||||
|
)
|
||||||
|
|
||||||
var log = logging.Logger("floodsub")
|
var log = logging.Logger("floodsub")
|
||||||
|
|
||||||
@ -54,6 +59,18 @@ type PubSub struct {
|
|||||||
// topics tracks which topics each of our peers are subscribed to
|
// topics tracks which topics each of our peers are subscribed to
|
||||||
topics map[string]map[peer.ID]struct{}
|
topics map[string]map[peer.ID]struct{}
|
||||||
|
|
||||||
|
// sendMsg handles messages that have been validated
|
||||||
|
sendMsg chan *sendReq
|
||||||
|
|
||||||
|
// addVal handles validator registration requests
|
||||||
|
addVal chan *addValReq
|
||||||
|
|
||||||
|
// topicVals tracks per topic validators
|
||||||
|
topicVals map[string]*topicVal
|
||||||
|
|
||||||
|
// validateThrottle limits the number of active validation goroutines
|
||||||
|
validateThrottle chan struct{}
|
||||||
|
|
||||||
peers map[peer.ID]chan *RPC
|
peers map[peer.ID]chan *RPC
|
||||||
seenMessages *timecache.TimeCache
|
seenMessages *timecache.TimeCache
|
||||||
|
|
||||||
@ -78,24 +95,37 @@ type RPC struct {
|
|||||||
from peer.ID
|
from peer.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Option func(*PubSub) error
|
||||||
|
|
||||||
// NewFloodSub returns a new FloodSub management object
|
// NewFloodSub returns a new FloodSub management object
|
||||||
func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
|
func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) {
|
||||||
ps := &PubSub{
|
ps := &PubSub{
|
||||||
host: h,
|
host: h,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
incoming: make(chan *RPC, 32),
|
incoming: make(chan *RPC, 32),
|
||||||
publish: make(chan *Message),
|
publish: make(chan *Message),
|
||||||
newPeers: make(chan inet.Stream),
|
newPeers: make(chan inet.Stream),
|
||||||
peerDead: make(chan peer.ID),
|
peerDead: make(chan peer.ID),
|
||||||
cancelCh: make(chan *Subscription),
|
cancelCh: make(chan *Subscription),
|
||||||
getPeers: make(chan *listPeerReq),
|
getPeers: make(chan *listPeerReq),
|
||||||
addSub: make(chan *addSubReq),
|
addSub: make(chan *addSubReq),
|
||||||
getTopics: make(chan *topicReq),
|
getTopics: make(chan *topicReq),
|
||||||
myTopics: make(map[string]map[*Subscription]struct{}),
|
sendMsg: make(chan *sendReq, 32),
|
||||||
topics: make(map[string]map[peer.ID]struct{}),
|
addVal: make(chan *addValReq),
|
||||||
peers: make(map[peer.ID]chan *RPC),
|
validateThrottle: make(chan struct{}, defaultValidateThrottle),
|
||||||
seenMessages: timecache.NewTimeCache(time.Second * 30),
|
myTopics: make(map[string]map[*Subscription]struct{}),
|
||||||
counter: uint64(time.Now().UnixNano()),
|
topics: make(map[string]map[peer.ID]struct{}),
|
||||||
|
peers: make(map[peer.ID]chan *RPC),
|
||||||
|
topicVals: make(map[string]*topicVal),
|
||||||
|
seenMessages: timecache.NewTimeCache(time.Second * 30),
|
||||||
|
counter: uint64(time.Now().UnixNano()),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
err := opt(ps)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.SetStreamHandler(ID, ps.handleNewStream)
|
h.SetStreamHandler(ID, ps.handleNewStream)
|
||||||
@ -103,7 +133,14 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
|
|||||||
|
|
||||||
go ps.processLoop(ctx)
|
go ps.processLoop(ctx)
|
||||||
|
|
||||||
return ps
|
return ps, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithValidateThrottle(n int) Option {
|
||||||
|
return func(ps *PubSub) error {
|
||||||
|
ps.validateThrottle = make(chan struct{}, n)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// processLoop handles all inputs arriving on the channels
|
// processLoop handles all inputs arriving on the channels
|
||||||
@ -176,7 +213,15 @@ func (p *PubSub) processLoop(ctx context.Context) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
case msg := <-p.publish:
|
case msg := <-p.publish:
|
||||||
p.maybePublishMessage(p.host.ID(), msg.Message)
|
vals := p.getValidators(msg)
|
||||||
|
p.pushMsg(vals, p.host.ID(), msg)
|
||||||
|
|
||||||
|
case req := <-p.sendMsg:
|
||||||
|
p.maybePublishMessage(req.from, req.msg.Message)
|
||||||
|
|
||||||
|
case req := <-p.addVal:
|
||||||
|
p.addValidator(req)
|
||||||
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
log.Info("pubsub processloop shutting down")
|
log.Info("pubsub processloop shutting down")
|
||||||
return
|
return
|
||||||
@ -210,24 +255,22 @@ func (p *PubSub) handleRemoveSubscription(sub *Subscription) {
|
|||||||
// subscribes to the topic.
|
// subscribes to the topic.
|
||||||
// Only called from processLoop.
|
// Only called from processLoop.
|
||||||
func (p *PubSub) handleAddSubscription(req *addSubReq) {
|
func (p *PubSub) handleAddSubscription(req *addSubReq) {
|
||||||
subs := p.myTopics[req.topic]
|
sub := req.sub
|
||||||
|
subs := p.myTopics[sub.topic]
|
||||||
|
|
||||||
// announce we want this topic
|
// announce we want this topic
|
||||||
if len(subs) == 0 {
|
if len(subs) == 0 {
|
||||||
p.announce(req.topic, true)
|
p.announce(sub.topic, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// make new if not there
|
// make new if not there
|
||||||
if subs == nil {
|
if subs == nil {
|
||||||
p.myTopics[req.topic] = make(map[*Subscription]struct{})
|
p.myTopics[sub.topic] = make(map[*Subscription]struct{})
|
||||||
subs = p.myTopics[req.topic]
|
subs = p.myTopics[sub.topic]
|
||||||
}
|
}
|
||||||
|
|
||||||
sub := &Subscription{
|
sub.ch = make(chan *Message, 32)
|
||||||
ch: make(chan *Message, 32),
|
sub.cancelCh = p.cancelCh
|
||||||
topic: req.topic,
|
|
||||||
cancelCh: p.cancelCh,
|
|
||||||
}
|
|
||||||
|
|
||||||
p.myTopics[sub.topic][sub] = struct{}{}
|
p.myTopics[sub.topic][sub] = struct{}{}
|
||||||
|
|
||||||
@ -314,8 +357,11 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
p.maybePublishMessage(rpc.from, pmsg)
|
msg := &Message{pmsg}
|
||||||
|
vals := p.getValidators(msg)
|
||||||
|
p.pushMsg(vals, rpc.from, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -324,6 +370,75 @@ func msgID(pmsg *pb.Message) string {
|
|||||||
return string(pmsg.GetFrom()) + string(pmsg.GetSeqno())
|
return string(pmsg.GetFrom()) + string(pmsg.GetSeqno())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pushMsg pushes a message performing validation as necessary
|
||||||
|
func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) {
|
||||||
|
if len(vals) > 0 {
|
||||||
|
// validation is asynchronous and globally throttled with the throttleValidate semaphore.
|
||||||
|
// the purpose of the global throttle is to bound the goncurrency possible from incoming
|
||||||
|
// network traffic; each validator also has an individual throttle to preclude
|
||||||
|
// slow (or faulty) validators from starving other topics; see validate below.
|
||||||
|
select {
|
||||||
|
case p.validateThrottle <- struct{}{}:
|
||||||
|
go func() {
|
||||||
|
p.validate(vals, src, msg)
|
||||||
|
<-p.validateThrottle
|
||||||
|
}()
|
||||||
|
default:
|
||||||
|
log.Warningf("message validation throttled; dropping message from %s", src)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.maybePublishMessage(src, msg.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate performs validation and only sends the message if all validators succeed
|
||||||
|
func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) {
|
||||||
|
ctx, cancel := context.WithCancel(p.ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
rch := make(chan bool, len(vals))
|
||||||
|
rcount := 0
|
||||||
|
throttle := false
|
||||||
|
|
||||||
|
loop:
|
||||||
|
for _, val := range vals {
|
||||||
|
rcount++
|
||||||
|
|
||||||
|
select {
|
||||||
|
case val.validateThrottle <- struct{}{}:
|
||||||
|
go func(val *topicVal) {
|
||||||
|
rch <- val.validateMsg(ctx, msg)
|
||||||
|
<-val.validateThrottle
|
||||||
|
}(val)
|
||||||
|
|
||||||
|
default:
|
||||||
|
log.Debugf("validation throttled for topic %s", val.topic)
|
||||||
|
throttle = true
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if throttle {
|
||||||
|
log.Warningf("message validation throttled; dropping message from %s", src)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < rcount; i++ {
|
||||||
|
valid := <-rch
|
||||||
|
if !valid {
|
||||||
|
log.Warningf("message validation failed; dropping message from %s", src)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// all validators were successful, send the message
|
||||||
|
p.sendMsg <- &sendReq{
|
||||||
|
from: src,
|
||||||
|
msg: msg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (p *PubSub) maybePublishMessage(from peer.ID, pmsg *pb.Message) {
|
func (p *PubSub) maybePublishMessage(from peer.ID, pmsg *pb.Message) {
|
||||||
id := msgID(pmsg)
|
id := msgID(pmsg)
|
||||||
if p.seenMessage(id) {
|
if p.seenMessage(id) {
|
||||||
@ -348,7 +463,7 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for p, _ := range tmap {
|
for p := range tmap {
|
||||||
tosend[p] = struct{}{}
|
tosend[p] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -375,20 +490,38 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type addSubReq struct {
|
// getValidators returns all validators that apply to a given message
|
||||||
topic string
|
func (p *PubSub) getValidators(msg *Message) []*topicVal {
|
||||||
resp chan *Subscription
|
var vals []*topicVal
|
||||||
|
|
||||||
|
for _, topic := range msg.GetTopicIDs() {
|
||||||
|
val, ok := p.topicVals[topic]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
vals = append(vals, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
return vals
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type addSubReq struct {
|
||||||
|
sub *Subscription
|
||||||
|
resp chan *Subscription
|
||||||
|
}
|
||||||
|
|
||||||
|
type SubOpt func(sub *Subscription) error
|
||||||
|
|
||||||
// Subscribe returns a new Subscription for the given topic
|
// Subscribe returns a new Subscription for the given topic
|
||||||
func (p *PubSub) Subscribe(topic string) (*Subscription, error) {
|
func (p *PubSub) Subscribe(topic string, opts ...SubOpt) (*Subscription, error) {
|
||||||
td := pb.TopicDescriptor{Name: &topic}
|
td := pb.TopicDescriptor{Name: &topic}
|
||||||
|
|
||||||
return p.SubscribeByTopicDescriptor(&td)
|
return p.SubscribeByTopicDescriptor(&td, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubscribeByTopicDescriptor lets you subscribe a topic using a pb.TopicDescriptor
|
// SubscribeByTopicDescriptor lets you subscribe a topic using a pb.TopicDescriptor
|
||||||
func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscription, error) {
|
func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubOpt) (*Subscription, error) {
|
||||||
if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE {
|
if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE {
|
||||||
return nil, fmt.Errorf("auth mode not yet supported")
|
return nil, fmt.Errorf("auth mode not yet supported")
|
||||||
}
|
}
|
||||||
@ -397,10 +530,21 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscripti
|
|||||||
return nil, fmt.Errorf("encryption mode not yet supported")
|
return nil, fmt.Errorf("encryption mode not yet supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sub := &Subscription{
|
||||||
|
topic: td.GetName(),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
err := opt(sub)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
out := make(chan *Subscription, 1)
|
out := make(chan *Subscription, 1)
|
||||||
p.addSub <- &addSubReq{
|
p.addSub <- &addSubReq{
|
||||||
topic: td.GetName(),
|
sub: sub,
|
||||||
resp: out,
|
resp: out,
|
||||||
}
|
}
|
||||||
|
|
||||||
return <-out, nil
|
return <-out, nil
|
||||||
@ -439,6 +583,12 @@ type listPeerReq struct {
|
|||||||
topic string
|
topic string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sendReq is a request to call maybePublishMessage. It is issued after the subscription verification is done.
|
||||||
|
type sendReq struct {
|
||||||
|
from peer.ID
|
||||||
|
msg *Message
|
||||||
|
}
|
||||||
|
|
||||||
// ListPeers returns a list of peers we are connected to.
|
// ListPeers returns a list of peers we are connected to.
|
||||||
func (p *PubSub) ListPeers(topic string) []peer.ID {
|
func (p *PubSub) ListPeers(topic string) []peer.ID {
|
||||||
out := make(chan []peer.ID)
|
out := make(chan []peer.ID)
|
||||||
@ -448,3 +598,100 @@ func (p *PubSub) ListPeers(topic string) []peer.ID {
|
|||||||
}
|
}
|
||||||
return <-out
|
return <-out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// per topic validators
|
||||||
|
type addValReq struct {
|
||||||
|
topic string
|
||||||
|
validate Validator
|
||||||
|
timeout time.Duration
|
||||||
|
throttle int
|
||||||
|
resp chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
type topicVal struct {
|
||||||
|
topic string
|
||||||
|
validate Validator
|
||||||
|
validateTimeout time.Duration
|
||||||
|
validateThrottle chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validator is a function that validates a message
|
||||||
|
type Validator func(context.Context, *Message) bool
|
||||||
|
|
||||||
|
// ValidatorOpt is an option for RegisterTopicValidator
|
||||||
|
type ValidatorOpt func(addVal *addValReq) error
|
||||||
|
|
||||||
|
// WithValidatorTimeout is an option that sets the topic validator timeout
|
||||||
|
func WithValidatorTimeout(timeout time.Duration) ValidatorOpt {
|
||||||
|
return func(addVal *addValReq) error {
|
||||||
|
addVal.timeout = timeout
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithValidatorConcurrency is an option that sets topic validator throttle
|
||||||
|
func WithValidatorConcurrency(n int) ValidatorOpt {
|
||||||
|
return func(addVal *addValReq) error {
|
||||||
|
addVal.throttle = n
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterTopicValidator registers a validator for topic
|
||||||
|
func (p *PubSub) RegisterTopicValidator(topic string, val Validator, opts ...ValidatorOpt) error {
|
||||||
|
addVal := &addValReq{
|
||||||
|
topic: topic,
|
||||||
|
validate: val,
|
||||||
|
resp: make(chan error, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
err := opt(addVal)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.addVal <- addVal
|
||||||
|
return <-addVal.resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *PubSub) addValidator(req *addValReq) {
|
||||||
|
topic := req.topic
|
||||||
|
|
||||||
|
_, ok := ps.topicVals[topic]
|
||||||
|
if ok {
|
||||||
|
req.resp <- fmt.Errorf("Duplicate validator for topic %s", topic)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
val := &topicVal{
|
||||||
|
topic: topic,
|
||||||
|
validate: req.validate,
|
||||||
|
validateTimeout: defaultValidateTimeout,
|
||||||
|
validateThrottle: make(chan struct{}, defaultValidateConcurrency),
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.timeout > 0 {
|
||||||
|
val.validateTimeout = req.timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.throttle > 0 {
|
||||||
|
val.validateThrottle = make(chan struct{}, req.throttle)
|
||||||
|
}
|
||||||
|
|
||||||
|
ps.topicVals[topic] = val
|
||||||
|
req.resp <- nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (val *topicVal) validateMsg(ctx context.Context, msg *Message) bool {
|
||||||
|
vctx, cancel := context.WithTimeout(ctx, val.validateTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
valid := val.validate(vctx, msg)
|
||||||
|
if !valid {
|
||||||
|
log.Debugf("validation failed for topic %s", val.topic)
|
||||||
|
}
|
||||||
|
|
||||||
|
return valid
|
||||||
|
}
|
||||||
|
|||||||
192
floodsub_test.go
192
floodsub_test.go
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"sort"
|
"sort"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -80,10 +81,14 @@ func connectAll(t *testing.T, hosts []host.Host) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPubsubs(ctx context.Context, hs []host.Host) []*PubSub {
|
func getPubsubs(ctx context.Context, hs []host.Host, opts ...Option) []*PubSub {
|
||||||
var psubs []*PubSub
|
var psubs []*PubSub
|
||||||
for _, h := range hs {
|
for _, h := range hs {
|
||||||
psubs = append(psubs, NewFloodSub(ctx, h))
|
ps, err := NewFloodSub(ctx, h, opts...)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
psubs = append(psubs, ps)
|
||||||
}
|
}
|
||||||
return psubs
|
return psubs
|
||||||
}
|
}
|
||||||
@ -289,11 +294,14 @@ func TestSelfReceive(t *testing.T) {
|
|||||||
|
|
||||||
host := getNetHosts(t, ctx, 1)[0]
|
host := getNetHosts(t, ctx, 1)[0]
|
||||||
|
|
||||||
psub := NewFloodSub(ctx, host)
|
psub, err := NewFloodSub(ctx, host)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
msg := []byte("hello world")
|
msg := []byte("hello world")
|
||||||
|
|
||||||
err := psub.Publish("foobar", msg)
|
err = psub.Publish("foobar", msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -323,14 +331,181 @@ func TestOneToOne(t *testing.T) {
|
|||||||
|
|
||||||
connect(t, hosts[0], hosts[1])
|
connect(t, hosts[0], hosts[1])
|
||||||
|
|
||||||
ch, err := psubs[1].Subscribe("foobar")
|
sub, err := psubs[1].Subscribe("foobar")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(time.Millisecond * 50)
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
|
||||||
checkMessageRouting(t, "foobar", psubs, []*Subscription{ch})
|
checkMessageRouting(t, "foobar", psubs, []*Subscription{sub})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
hosts := getNetHosts(t, ctx, 2)
|
||||||
|
psubs := getPubsubs(ctx, hosts)
|
||||||
|
|
||||||
|
connect(t, hosts[0], hosts[1])
|
||||||
|
topic := "foobar"
|
||||||
|
|
||||||
|
err := psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, msg *Message) bool {
|
||||||
|
return !bytes.Contains(msg.Data, []byte("illegal"))
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := psubs[1].Subscribe(topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
|
||||||
|
msgs := []struct {
|
||||||
|
msg []byte
|
||||||
|
validates bool
|
||||||
|
}{
|
||||||
|
{msg: []byte("this is a legal message"), validates: true},
|
||||||
|
{msg: []byte("there also is nothing controversial about this message"), validates: true},
|
||||||
|
{msg: []byte("openly illegal content will be censored"), validates: false},
|
||||||
|
{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range msgs {
|
||||||
|
for _, p := range psubs {
|
||||||
|
err := p.Publish(topic, tc.msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case msg := <-sub.ch:
|
||||||
|
if !tc.validates {
|
||||||
|
t.Log(msg)
|
||||||
|
t.Error("expected message validation to filter out the message")
|
||||||
|
}
|
||||||
|
case <-time.After(333 * time.Millisecond):
|
||||||
|
if tc.validates {
|
||||||
|
t.Error("expected message validation to accept the message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateOverload(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
type msg struct {
|
||||||
|
msg []byte
|
||||||
|
validates bool
|
||||||
|
}
|
||||||
|
|
||||||
|
tcs := []struct {
|
||||||
|
msgs []msg
|
||||||
|
|
||||||
|
maxConcurrency int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
maxConcurrency: 10,
|
||||||
|
msgs: []msg{
|
||||||
|
{msg: []byte("this is a legal message"), validates: true},
|
||||||
|
{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
|
||||||
|
{msg: []byte("there also is nothing controversial about this message"), validates: true},
|
||||||
|
{msg: []byte("also fine"), validates: true},
|
||||||
|
{msg: []byte("still, all good"), validates: true},
|
||||||
|
{msg: []byte("this is getting boring"), validates: true},
|
||||||
|
{msg: []byte("foo"), validates: true},
|
||||||
|
{msg: []byte("foobar"), validates: true},
|
||||||
|
{msg: []byte("foofoo"), validates: true},
|
||||||
|
{msg: []byte("barfoo"), validates: true},
|
||||||
|
{msg: []byte("oh no!"), validates: false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
maxConcurrency: 2,
|
||||||
|
msgs: []msg{
|
||||||
|
{msg: []byte("this is a legal message"), validates: true},
|
||||||
|
{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
|
||||||
|
{msg: []byte("oh no!"), validates: false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tcs {
|
||||||
|
|
||||||
|
hosts := getNetHosts(t, ctx, 2)
|
||||||
|
psubs := getPubsubs(ctx, hosts)
|
||||||
|
|
||||||
|
connect(t, hosts[0], hosts[1])
|
||||||
|
topic := "foobar"
|
||||||
|
|
||||||
|
block := make(chan struct{})
|
||||||
|
|
||||||
|
err := psubs[1].RegisterTopicValidator(topic,
|
||||||
|
func(ctx context.Context, msg *Message) bool {
|
||||||
|
<-block
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
WithValidatorConcurrency(tc.maxConcurrency))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, err := psubs[1].Subscribe(topic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
|
||||||
|
if len(tc.msgs) != tc.maxConcurrency+1 {
|
||||||
|
t.Fatalf("expected number of messages sent to be maxConcurrency+1. Got %d, expected %d", len(tc.msgs), tc.maxConcurrency+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
p := psubs[0]
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
for _, tmsg := range tc.msgs {
|
||||||
|
select {
|
||||||
|
case msg := <-sub.ch:
|
||||||
|
if !tmsg.validates {
|
||||||
|
t.Log(msg)
|
||||||
|
t.Error("expected message validation to drop the message because all validator goroutines are taken")
|
||||||
|
}
|
||||||
|
case <-time.After(333 * time.Millisecond):
|
||||||
|
if tmsg.validates {
|
||||||
|
t.Error("expected message validation to accept the message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i, tmsg := range tc.msgs {
|
||||||
|
err := p.Publish(topic, tmsg.msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait a bit to let pubsub's internal state machine start validating the message
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// unblock validator goroutines after we sent one too many
|
||||||
|
if i == len(tc.msgs)-1 {
|
||||||
|
close(block)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) {
|
func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) {
|
||||||
@ -414,7 +589,10 @@ func TestSubReporting(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
host := getNetHosts(t, ctx, 1)[0]
|
host := getNetHosts(t, ctx, 1)[0]
|
||||||
psub := NewFloodSub(ctx, host)
|
psub, err := NewFloodSub(ctx, host)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
fooSub, err := psub.Subscribe("foo")
|
fooSub, err := psub.Subscribe("foo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user