Merge pull request #55 from libp2p/feat/validators

Message Validators
This commit is contained in:
Steven Allen 2018-01-24 06:03:20 +00:00 committed by GitHub
commit c82e67dcd3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 470 additions and 45 deletions

View File

@ -17,7 +17,12 @@ import (
timecache "github.com/whyrusleeping/timecache"
)
const ID = protocol.ID("/floodsub/1.0.0")
const (
ID = protocol.ID("/floodsub/1.0.0")
defaultValidateTimeout = 150 * time.Millisecond
defaultValidateConcurrency = 100
defaultValidateThrottle = 8192
)
var log = logging.Logger("floodsub")
@ -54,6 +59,18 @@ type PubSub struct {
// topics tracks which topics each of our peers are subscribed to
topics map[string]map[peer.ID]struct{}
// sendMsg handles messages that have been validated
sendMsg chan *sendReq
// addVal handles validator registration requests
addVal chan *addValReq
// topicVals tracks per topic validators
topicVals map[string]*topicVal
// validateThrottle limits the number of active validation goroutines
validateThrottle chan struct{}
peers map[peer.ID]chan *RPC
seenMessages *timecache.TimeCache
@ -78,24 +95,37 @@ type RPC struct {
from peer.ID
}
type Option func(*PubSub) error
// NewFloodSub returns a new FloodSub management object
func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) {
ps := &PubSub{
host: h,
ctx: ctx,
incoming: make(chan *RPC, 32),
publish: make(chan *Message),
newPeers: make(chan inet.Stream),
peerDead: make(chan peer.ID),
cancelCh: make(chan *Subscription),
getPeers: make(chan *listPeerReq),
addSub: make(chan *addSubReq),
getTopics: make(chan *topicReq),
myTopics: make(map[string]map[*Subscription]struct{}),
topics: make(map[string]map[peer.ID]struct{}),
peers: make(map[peer.ID]chan *RPC),
seenMessages: timecache.NewTimeCache(time.Second * 30),
counter: uint64(time.Now().UnixNano()),
host: h,
ctx: ctx,
incoming: make(chan *RPC, 32),
publish: make(chan *Message),
newPeers: make(chan inet.Stream),
peerDead: make(chan peer.ID),
cancelCh: make(chan *Subscription),
getPeers: make(chan *listPeerReq),
addSub: make(chan *addSubReq),
getTopics: make(chan *topicReq),
sendMsg: make(chan *sendReq, 32),
addVal: make(chan *addValReq),
validateThrottle: make(chan struct{}, defaultValidateThrottle),
myTopics: make(map[string]map[*Subscription]struct{}),
topics: make(map[string]map[peer.ID]struct{}),
peers: make(map[peer.ID]chan *RPC),
topicVals: make(map[string]*topicVal),
seenMessages: timecache.NewTimeCache(time.Second * 30),
counter: uint64(time.Now().UnixNano()),
}
for _, opt := range opts {
err := opt(ps)
if err != nil {
return nil, err
}
}
h.SetStreamHandler(ID, ps.handleNewStream)
@ -103,7 +133,14 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
go ps.processLoop(ctx)
return ps
return ps, nil
}
// WithValidateThrottle returns an Option that sets the capacity of the global
// validation throttle, i.e. the maximum number of messages that may be
// undergoing validation at the same time across all topics.
//
// n must not be negative; a negative value makes the option return an error
// instead of panicking when the throttle channel is allocated. Note that n == 0
// yields an unbuffered throttle channel, which causes every message that has
// validators to be dropped by the non-blocking acquire in pushMsg.
func WithValidateThrottle(n int) Option {
	return func(ps *PubSub) error {
		if n < 0 {
			return fmt.Errorf("validate throttle must not be negative: %d", n)
		}
		ps.validateThrottle = make(chan struct{}, n)
		return nil
	}
}
// processLoop handles all inputs arriving on the channels
@ -176,7 +213,15 @@ func (p *PubSub) processLoop(ctx context.Context) {
continue
}
case msg := <-p.publish:
p.maybePublishMessage(p.host.ID(), msg.Message)
vals := p.getValidators(msg)
p.pushMsg(vals, p.host.ID(), msg)
case req := <-p.sendMsg:
p.maybePublishMessage(req.from, req.msg.Message)
case req := <-p.addVal:
p.addValidator(req)
case <-ctx.Done():
log.Info("pubsub processloop shutting down")
return
@ -210,24 +255,22 @@ func (p *PubSub) handleRemoveSubscription(sub *Subscription) {
// subscribes to the topic.
// Only called from processLoop.
func (p *PubSub) handleAddSubscription(req *addSubReq) {
subs := p.myTopics[req.topic]
sub := req.sub
subs := p.myTopics[sub.topic]
// announce we want this topic
if len(subs) == 0 {
p.announce(req.topic, true)
p.announce(sub.topic, true)
}
// make new if not there
if subs == nil {
p.myTopics[req.topic] = make(map[*Subscription]struct{})
subs = p.myTopics[req.topic]
p.myTopics[sub.topic] = make(map[*Subscription]struct{})
subs = p.myTopics[sub.topic]
}
sub := &Subscription{
ch: make(chan *Message, 32),
topic: req.topic,
cancelCh: p.cancelCh,
}
sub.ch = make(chan *Message, 32)
sub.cancelCh = p.cancelCh
p.myTopics[sub.topic][sub] = struct{}{}
@ -314,8 +357,11 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error {
continue
}
p.maybePublishMessage(rpc.from, pmsg)
msg := &Message{pmsg}
vals := p.getValidators(msg)
p.pushMsg(vals, rpc.from, msg)
}
return nil
}
@ -324,6 +370,75 @@ func msgID(pmsg *pb.Message) string {
return string(pmsg.GetFrom()) + string(pmsg.GetSeqno())
}
// pushMsg forwards a message towards publication, validating it first when
// any validators apply.
//
// With no validators the message is handed to maybePublishMessage
// synchronously. Otherwise validation runs in its own goroutine; if the global
// throttle has no free slot the message is dropped with a warning.
func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) {
	if len(vals) == 0 {
		// nothing to validate, publish right away
		p.maybePublishMessage(src, msg.Message)
		return
	}

	// Validation is asynchronous and globally throttled with the validateThrottle
	// semaphore. The purpose of the global throttle is to bound the concurrency
	// possible from incoming network traffic; each validator also has an
	// individual throttle to preclude slow (or faulty) validators from starving
	// other topics; see validate below.
	select {
	case p.validateThrottle <- struct{}{}:
		go func() {
			// give the global slot back once validation has finished
			defer func() { <-p.validateThrottle }()
			p.validate(vals, src, msg)
		}()
	default:
		log.Warningf("message validation throttled; dropping message from %s", src)
	}
}
// validate runs every validator in vals against msg concurrently and hands the
// message to processLoop (via sendMsg) only if all of them accept it.
//
// Each validator is guarded by its own per-topic throttle. If any validator
// has no free slot, the message is dropped immediately; validators that were
// already started keep running but see a cancelled context. The message is
// likewise dropped as soon as one validator rejects it.
//
// The function never blocks past the lifetime of p.ctx: both waiting for
// validator results and handing the message over to processLoop are abandoned
// when the pubsub context is cancelled, so the goroutine (and the global
// throttle slot held by the caller) is not leaked after shutdown.
func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) {
	ctx, cancel := context.WithCancel(p.ctx)
	defer cancel()

	// buffered for every validator so that late results never block their
	// goroutine after we have returned early
	rch := make(chan bool, len(vals))
	rcount := 0

	for _, val := range vals {
		select {
		case val.validateThrottle <- struct{}{}:
			rcount++
			go func(val *topicVal) {
				rch <- val.validateMsg(ctx, msg)
				<-val.validateThrottle
			}(val)

		default:
			log.Debugf("validation throttled for topic %s", val.topic)
			log.Warningf("message validation throttled; dropping message from %s", src)
			return
		}
	}

	for i := 0; i < rcount; i++ {
		select {
		case valid := <-rch:
			if !valid {
				log.Warningf("message validation failed; dropping message from %s", src)
				return
			}
		case <-ctx.Done():
			// pubsub is shutting down; nobody is left to publish the message
			return
		}
	}

	// all validators were successful, send the message
	select {
	case p.sendMsg <- &sendReq{
		from: src,
		msg:  msg,
	}:
	case <-ctx.Done():
		// processLoop has exited (or is exiting); do not block forever
	}
}
func (p *PubSub) maybePublishMessage(from peer.ID, pmsg *pb.Message) {
id := msgID(pmsg)
if p.seenMessage(id) {
@ -348,7 +463,7 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error {
continue
}
for p, _ := range tmap {
for p := range tmap {
tosend[p] = struct{}{}
}
}
@ -375,20 +490,38 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error {
return nil
}
type addSubReq struct {
topic string
resp chan *Subscription
// getValidators collects the validators registered for the topics msg is
// addressed to. Topics without a registered validator are skipped; the result
// is nil when no topic of the message has one.
func (p *PubSub) getValidators(msg *Message) []*topicVal {
	var matched []*topicVal

	for _, name := range msg.GetTopicIDs() {
		if tv, found := p.topicVals[name]; found {
			matched = append(matched, tv)
		}
	}

	return matched
}
type addSubReq struct {
sub *Subscription
resp chan *Subscription
}
type SubOpt func(sub *Subscription) error
// Subscribe returns a new Subscription for the given topic
func (p *PubSub) Subscribe(topic string) (*Subscription, error) {
func (p *PubSub) Subscribe(topic string, opts ...SubOpt) (*Subscription, error) {
td := pb.TopicDescriptor{Name: &topic}
return p.SubscribeByTopicDescriptor(&td)
return p.SubscribeByTopicDescriptor(&td, opts...)
}
// SubscribeByTopicDescriptor lets you subscribe a topic using a pb.TopicDescriptor
func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscription, error) {
func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubOpt) (*Subscription, error) {
if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE {
return nil, fmt.Errorf("auth mode not yet supported")
}
@ -397,10 +530,21 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscripti
return nil, fmt.Errorf("encryption mode not yet supported")
}
sub := &Subscription{
topic: td.GetName(),
}
for _, opt := range opts {
err := opt(sub)
if err != nil {
return nil, err
}
}
out := make(chan *Subscription, 1)
p.addSub <- &addSubReq{
topic: td.GetName(),
resp: out,
sub: sub,
resp: out,
}
return <-out, nil
@ -439,6 +583,12 @@ type listPeerReq struct {
topic string
}
// sendReq is a request to call maybePublishMessage. It is sent on the sendMsg
// channel by validate once every validator has accepted the message, and is
// consumed by processLoop.
type sendReq struct {
// from is the peer the message was received from (the local host ID for
// messages published locally).
from peer.ID
// msg is the validated message to publish.
msg *Message
}
// ListPeers returns a list of peers we are connected to.
func (p *PubSub) ListPeers(topic string) []peer.ID {
out := make(chan []peer.ID)
@ -448,3 +598,100 @@ func (p *PubSub) ListPeers(topic string) []peer.ID {
}
return <-out
}
// per topic validators

// addValReq is a validator registration request. It is built by
// RegisterTopicValidator, adjusted by any ValidatorOpt, and handled by
// addValidator inside processLoop.
type addValReq struct {
// topic is the topic the validator is registered for.
topic string
// validate is the validator function to install.
validate Validator
// timeout bounds a single validator invocation; values <= 0 select
// defaultValidateTimeout.
timeout time.Duration
// throttle is the maximum number of concurrent invocations of this
// validator; values <= 0 select defaultValidateConcurrency.
throttle int
// resp receives the outcome of the registration: nil on success, or an
// error if the topic already has a validator.
resp chan error
}
// topicVal is the validator registered for a single topic together with its
// limits.
type topicVal struct {
// topic is the name of the topic this validator applies to.
topic string
// validate is the user supplied validator function.
validate Validator
// validateTimeout is the timeout applied to the context handed to each
// validator invocation.
validateTimeout time.Duration
// validateThrottle is a semaphore bounding the number of concurrent
// invocations of this validator.
validateThrottle chan struct{}
}
// Validator is a function that validates a message. It returns true to accept
// the message and false to have it dropped. The context it receives is bounded
// by the topic's validation timeout; the timeout is not enforced beyond that,
// so the validator is expected to honour the context itself.
type Validator func(context.Context, *Message) bool
// ValidatorOpt is an option for RegisterTopicValidator. It is applied to the
// pending registration request before the request is submitted; returning a
// non-nil error aborts the registration.
type ValidatorOpt func(addVal *addValReq) error
// WithValidatorTimeout is an option that sets the topic validator timeout,
// i.e. the deadline placed on the context of every validator invocation.
// A timeout that is not positive leaves the default timeout in effect.
func WithValidatorTimeout(timeout time.Duration) ValidatorOpt {
	setTimeout := func(req *addValReq) error {
		req.timeout = timeout
		return nil
	}
	return setTimeout
}
// WithValidatorConcurrency is an option that sets the topic validator throttle,
// i.e. how many invocations of the validator may run at the same time.
// A value of n that is not positive leaves the default concurrency in effect.
func WithValidatorConcurrency(n int) ValidatorOpt {
	setThrottle := func(req *addValReq) error {
		req.throttle = n
		return nil
	}
	return setThrottle
}
// RegisterTopicValidator registers a validator for topic.
//
// The options are applied in order to the registration request; the first
// option that fails aborts the registration and its error is returned. The
// request is then handed to processLoop, which installs the validator. Only
// one validator may be registered per topic; registering a second one returns
// an error.
//
// If the pubsub context is cancelled before processLoop accepts the request,
// the context's error is returned instead of blocking forever.
func (p *PubSub) RegisterTopicValidator(topic string, val Validator, opts ...ValidatorOpt) error {
	addVal := &addValReq{
		topic:    topic,
		validate: val,
		resp:     make(chan error, 1),
	}

	for _, opt := range opts {
		if err := opt(addVal); err != nil {
			return err
		}
	}

	select {
	case p.addVal <- addVal:
	case <-p.ctx.Done():
		// processLoop is gone, nobody would ever pick up the request
		return p.ctx.Err()
	}

	// processLoop has accepted the request and always answers on resp
	return <-addVal.resp
}
// addValidator installs the validator described by req for its topic and
// reports the outcome on req.resp: an error if the topic already has a
// validator, nil otherwise. Timeout and concurrency fall back to the package
// defaults unless the request carries positive values.
func (p *PubSub) addValidator(req *addValReq) {
	if _, taken := p.topicVals[req.topic]; taken {
		req.resp <- fmt.Errorf("Duplicate validator for topic %s", req.topic)
		return
	}

	timeout := defaultValidateTimeout
	if req.timeout > 0 {
		timeout = req.timeout
	}

	concurrency := defaultValidateConcurrency
	if req.throttle > 0 {
		concurrency = req.throttle
	}

	p.topicVals[req.topic] = &topicVal{
		topic:            req.topic,
		validate:         req.validate,
		validateTimeout:  timeout,
		validateThrottle: make(chan struct{}, concurrency),
	}

	req.resp <- nil
}
// validateMsg invokes the topic's validator on msg with a context derived from
// ctx and limited by the topic's validation timeout, and returns the
// validator's verdict. A rejection is logged at debug level. The call is
// synchronous, so the timeout only takes effect if the validator observes the
// context.
func (val *topicVal) validateMsg(ctx context.Context, msg *Message) bool {
	vctx, cancel := context.WithTimeout(ctx, val.validateTimeout)
	defer cancel()

	if val.validate(vctx, msg) {
		return true
	}

	log.Debugf("validation failed for topic %s", val.topic)
	return false
}

View File

@ -6,6 +6,7 @@ import (
"fmt"
"math/rand"
"sort"
"sync"
"testing"
"time"
@ -80,10 +81,14 @@ func connectAll(t *testing.T, hosts []host.Host) {
}
}
func getPubsubs(ctx context.Context, hs []host.Host) []*PubSub {
func getPubsubs(ctx context.Context, hs []host.Host, opts ...Option) []*PubSub {
var psubs []*PubSub
for _, h := range hs {
psubs = append(psubs, NewFloodSub(ctx, h))
ps, err := NewFloodSub(ctx, h, opts...)
if err != nil {
panic(err)
}
psubs = append(psubs, ps)
}
return psubs
}
@ -289,11 +294,14 @@ func TestSelfReceive(t *testing.T) {
host := getNetHosts(t, ctx, 1)[0]
psub := NewFloodSub(ctx, host)
psub, err := NewFloodSub(ctx, host)
if err != nil {
t.Fatal(err)
}
msg := []byte("hello world")
err := psub.Publish("foobar", msg)
err = psub.Publish("foobar", msg)
if err != nil {
t.Fatal(err)
}
@ -323,14 +331,181 @@ func TestOneToOne(t *testing.T) {
connect(t, hosts[0], hosts[1])
ch, err := psubs[1].Subscribe("foobar")
sub, err := psubs[1].Subscribe("foobar")
if err != nil {
t.Fatal(err)
}
time.Sleep(time.Millisecond * 50)
checkMessageRouting(t, "foobar", psubs, []*Subscription{ch})
checkMessageRouting(t, "foobar", psubs, []*Subscription{sub})
}
// TestValidate checks that a topic validator filters messages: two connected
// nodes are set up, the second one registers a validator and subscribes, and
// every test message is then published from both nodes. Messages the validator
// rejects must never show up on the subscription; accepted ones must arrive
// before the receive timeout.
func TestValidate(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hosts := getNetHosts(t, ctx, 2)
psubs := getPubsubs(ctx, hosts)
connect(t, hosts[0], hosts[1])
topic := "foobar"
// the validator rejects any payload containing the byte sequence "illegal"
err := psubs[1].RegisterTopicValidator(topic, func(ctx context.Context, msg *Message) bool {
return !bytes.Contains(msg.Data, []byte("illegal"))
})
if err != nil {
t.Fatal(err)
}
sub, err := psubs[1].Subscribe(topic)
if err != nil {
t.Fatal(err)
}
// presumably gives the subscription time to propagate to the other node
time.Sleep(time.Millisecond * 50)
msgs := []struct {
msg []byte
validates bool
}{
{msg: []byte("this is a legal message"), validates: true},
{msg: []byte("there also is nothing controversial about this message"), validates: true},
{msg: []byte("openly illegal content will be censored"), validates: false},
{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
}
for _, tc := range msgs {
// publish from the remote node as well as from the validating node itself
for _, p := range psubs {
err := p.Publish(topic, tc.msg)
if err != nil {
t.Fatal(err)
}
select {
case msg := <-sub.ch:
// delivery is only acceptable for messages expected to validate
if !tc.validates {
t.Log(msg)
t.Error("expected message validation to filter out the message")
}
case <-time.After(333 * time.Millisecond):
// no delivery within the timeout is only acceptable for rejected messages
if tc.validates {
t.Error("expected message validation to accept the message")
}
}
}
}
}
// TestValidateOverload checks the per-topic validator throttle. For each test
// case the validator is registered with a concurrency limit of maxConcurrency
// and blocks until released. maxConcurrency+1 messages are then published, so
// that all validator slots are occupied when the last message arrives; that
// last message is expected to be dropped while all earlier ones are delivered
// once the validators are unblocked.
func TestValidateOverload(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
type msg struct {
msg []byte
validates bool
}
tcs := []struct {
msgs []msg
maxConcurrency int
}{
{
maxConcurrency: 10,
msgs: []msg{
{msg: []byte("this is a legal message"), validates: true},
{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
{msg: []byte("there also is nothing controversial about this message"), validates: true},
{msg: []byte("also fine"), validates: true},
{msg: []byte("still, all good"), validates: true},
{msg: []byte("this is getting boring"), validates: true},
{msg: []byte("foo"), validates: true},
{msg: []byte("foobar"), validates: true},
{msg: []byte("foofoo"), validates: true},
{msg: []byte("barfoo"), validates: true},
{msg: []byte("oh no!"), validates: false},
},
},
{
maxConcurrency: 2,
msgs: []msg{
{msg: []byte("this is a legal message"), validates: true},
{msg: []byte("but subversive actors will use leetspeek to spread 1ll3g4l content"), validates: true},
{msg: []byte("oh no!"), validates: false},
},
},
}
for _, tc := range tcs {
// fresh pair of connected nodes for every test case
hosts := getNetHosts(t, ctx, 2)
psubs := getPubsubs(ctx, hosts)
connect(t, hosts[0], hosts[1])
topic := "foobar"
// block keeps every validator invocation parked until it is closed
block := make(chan struct{})
err := psubs[1].RegisterTopicValidator(topic,
func(ctx context.Context, msg *Message) bool {
<-block
return true
},
WithValidatorConcurrency(tc.maxConcurrency))
if err != nil {
t.Fatal(err)
}
sub, err := psubs[1].Subscribe(topic)
if err != nil {
t.Fatal(err)
}
// presumably gives the subscription time to propagate to the other node
time.Sleep(time.Millisecond * 50)
// sanity check of the test table: exactly one message more than there are validator slots
if len(tc.msgs) != tc.maxConcurrency+1 {
t.Fatalf("expected number of messages sent to be maxConcurrency+1. Got %d, expected %d", len(tc.msgs), tc.maxConcurrency+1)
}
p := psubs[0]
var wg sync.WaitGroup
wg.Add(1)
// receiver: expects one delivery per message marked as validating and a timeout for the dropped one
go func() {
for _, tmsg := range tc.msgs {
select {
case msg := <-sub.ch:
if !tmsg.validates {
t.Log(msg)
t.Error("expected message validation to drop the message because all validator goroutines are taken")
}
case <-time.After(333 * time.Millisecond):
if tmsg.validates {
t.Error("expected message validation to accept the message")
}
}
}
wg.Done()
}()
for i, tmsg := range tc.msgs {
err := p.Publish(topic, tmsg.msg)
if err != nil {
t.Fatal(err)
}
// wait a bit to let pubsub's internal state machine start validating the message
time.Sleep(10 * time.Millisecond)
// unblock validator goroutines after we sent one too many
if i == len(tc.msgs)-1 {
close(block)
}
}
wg.Wait()
}
}
func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) {
@ -414,7 +589,10 @@ func TestSubReporting(t *testing.T) {
defer cancel()
host := getNetHosts(t, ctx, 1)[0]
psub := NewFloodSub(ctx, host)
psub, err := NewFloodSub(ctx, host)
if err != nil {
t.Fatal(err)
}
fooSub, err := psub.Subscribe("foo")
if err != nil {