go-libp2p-pubsub/pubsub.go
vyzo b09c9d1a48 check and mark seen messages prior to validation
this allows us to avoid revalidating messages, either because they were concurrently
received or were previously rejected by the validator.
also allows us to filter invalid messages from gossip.
2018-06-12 08:00:17 -07:00

725 lines
16 KiB
Go

package floodsub
import (
"context"
"encoding/binary"
"fmt"
"sync/atomic"
"time"
pb "github.com/libp2p/go-floodsub/pb"
logging "github.com/ipfs/go-log"
host "github.com/libp2p/go-libp2p-host"
inet "github.com/libp2p/go-libp2p-net"
peer "github.com/libp2p/go-libp2p-peer"
protocol "github.com/libp2p/go-libp2p-protocol"
timecache "github.com/whyrusleeping/timecache"
)
// Default limits applied to message validation.
const (
// defaultValidateTimeout is the timeout given to a topic validator that is
// registered without an explicit timeout.
defaultValidateTimeout = 150 * time.Millisecond
// defaultValidateConcurrency is the capacity of a topic validator's own
// throttle channel when no explicit concurrency is requested.
defaultValidateConcurrency = 100
// defaultValidateThrottle is the capacity of the global validation
// throttle channel created by NewPubSub.
defaultValidateThrottle = 8192
)
// log is the package-level logger, registered under the name "pubsub".
var log = logging.Logger("pubsub")
// PubSub holds the state of a pubsub instance. Its channels feed the
// processLoop goroutine started by NewPubSub, and its maps are mutated by the
// handlers that loop invokes.
type PubSub struct {
// atomic counter for seqnos
// NOTE: Must be declared at the top of the struct as we perform atomic
// operations on this field.
//
// See: https://golang.org/pkg/sync/atomic/#pkg-note-BUG
counter uint64
// host is the libp2p host this instance is attached to; its ID is used as
// the source of locally published messages.
host host.Host
// rt is the router that is told about peers, topics, RPCs and publishes.
rt PubSubRouter
// incoming messages from other peers
incoming chan *RPC
// messages we are publishing out to our peers
publish chan *Message
// addSub is a control channel used to add subscriptions (removal goes
// through cancelCh)
addSub chan *addSubReq
// get list of topics we are subscribed to
getTopics chan *topicReq
// get the list of peers we are connected to, optionally filtered by topic
getPeers chan *listPeerReq
// send subscription here to cancel it
cancelCh chan *Subscription
// a notification channel for incoming streams from other peers
newPeers chan inet.Stream
// a notification channel for when our peers die
peerDead chan peer.ID
// The set of topics we are subscribed to
myTopics map[string]map[*Subscription]struct{}
// topics tracks which topics each of our peers are subscribed to
topics map[string]map[peer.ID]struct{}
// sendMsg handles messages that have been validated
sendMsg chan *sendReq
// addVal handles validator registration requests
addVal chan *addValReq
// rmVal handles validator unregistration requests
rmVal chan *rmValReq
// topicVals tracks per topic validators
topicVals map[string]*topicVal
// validateThrottle limits the number of active validation goroutines
validateThrottle chan struct{}
// peers maps each connected peer to its outbound RPC queue; the queue is
// handed to handleSendingMessages when the peer's stream arrives.
peers map[peer.ID]chan *RPC
// seenMessages records the IDs of messages that were already pushed; it is
// created by NewPubSub with a 30 second span.
seenMessages *timecache.TimeCache
// ctx is the context passed to NewPubSub; validation contexts derive from it.
ctx context.Context
}
// PubSubRouter is the message router component of PubSub
type PubSubRouter interface {
// Protocols lists the protocol IDs for which NewPubSub registers a stream
// handler on the host.
Protocols() []protocol.ID
// Attach is called once by NewPubSub with the PubSub that owns the router.
Attach(*PubSub)
// AddPeer is called from the processing loop when a new peer stream
// arrives, with the remote peer and the stream's protocol.
AddPeer(peer.ID, protocol.ID)
// RemovePeer is called from the processing loop when a peer is reported dead.
RemovePeer(peer.ID)
// HandleRPC is called for every incoming RPC after its subscription and
// publish entries have been processed.
HandleRPC(*RPC)
// Publish is called with the source peer and the message once the message
// has been delivered to local subscribers.
Publish(peer.ID, *pb.Message)
// Join is called when the first local subscription to topic is added.
Join(topic string)
// Leave is called when the last local subscription to topic is removed.
Leave(topic string)
}
// Message wraps a protobuf pb.Message; the embedded pointer exposes the
// protobuf accessors directly on Message.
type Message struct {
*pb.Message
}
// GetFrom returns the From field of the wrapped protobuf message converted
// to a peer.ID.
func (m *Message) GetFrom() peer.ID {
	from := m.Message.GetFrom()
	return peer.ID(from)
}
// RPC is a protobuf pb.RPC annotated with the peer it was received from.
type RPC struct {
pb.RPC
// unexported on purpose, not sending this over the wire
from peer.ID
}
// Option configures a PubSub during NewPubSub; a non-nil error aborts
// construction.
type Option func(*PubSub) error
// NewPubSub returns a new PubSub management object that uses rt as its
// message router.
//
// Options are applied after the defaults are set, so they may override them;
// the first option that fails aborts construction and its error is returned.
// On success the router is attached, a stream handler is registered on h for
// every protocol the router reports, the PubSub is registered for network
// notifications, and the processing loop is started with ctx.
func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option) (*PubSub, error) {
	ps := &PubSub{
		// Seed the seqno counter from the clock.
		counter: uint64(time.Now().UnixNano()),

		ctx:  ctx,
		host: h,
		rt:   rt,

		// Channels served by processLoop.
		incoming:  make(chan *RPC, 32),
		sendMsg:   make(chan *sendReq, 32),
		publish:   make(chan *Message),
		addSub:    make(chan *addSubReq),
		cancelCh:  make(chan *Subscription),
		getTopics: make(chan *topicReq),
		getPeers:  make(chan *listPeerReq),
		newPeers:  make(chan inet.Stream),
		peerDead:  make(chan peer.ID),
		addVal:    make(chan *addValReq),
		rmVal:     make(chan *rmValReq),

		validateThrottle: make(chan struct{}, defaultValidateThrottle),

		// Bookkeeping tables.
		peers:        make(map[peer.ID]chan *RPC),
		topics:       make(map[string]map[peer.ID]struct{}),
		myTopics:     make(map[string]map[*Subscription]struct{}),
		topicVals:    make(map[string]*topicVal),
		seenMessages: timecache.NewTimeCache(time.Second * 30),
	}

	for _, apply := range opts {
		if err := apply(ps); err != nil {
			return nil, err
		}
	}

	rt.Attach(ps)
	for _, proto := range rt.Protocols() {
		h.SetStreamHandler(proto, ps.handleNewStream)
	}
	h.Network().Notify((*PubSubNotif)(ps))

	go ps.processLoop(ctx)

	return ps, nil
}
// WithValidateThrottle returns an Option that replaces the global validation
// throttle with a channel of capacity n.
func WithValidateThrottle(n int) Option {
	return func(ps *PubSub) error {
		throttle := make(chan struct{}, n)
		ps.validateThrottle = throttle
		return nil
	}
}
// processLoop is the event loop of the PubSub: it serves every control and
// data channel until ctx is done.
func (p *PubSub) processLoop(ctx context.Context) {
	// On exit, close every outbound peer queue and drop the peer and topic
	// tables.
	defer func() {
		for _, queue := range p.peers {
			close(queue)
		}
		p.peers = nil
		p.topics = nil
	}()

	for {
		select {
		case stream := <-p.newPeers:
			remote := stream.Conn().RemotePeer()
			if old, exists := p.peers[remote]; exists {
				log.Error("already have connection to peer: ", remote)
				close(old)
			}

			// Start the sender for this stream and queue the hello packet
			// before the peer becomes visible in the peer table.
			outbound := make(chan *RPC, 32)
			go p.handleSendingMessages(ctx, stream, outbound)
			outbound <- p.getHelloPacket()

			p.peers[remote] = outbound
			p.rt.AddPeer(remote, stream.Protocol())

		case dead := <-p.peerDead:
			if queue, exists := p.peers[dead]; exists {
				close(queue)
			}
			delete(p.peers, dead)
			for _, subscribers := range p.topics {
				delete(subscribers, dead)
			}
			p.rt.RemovePeer(dead)

		case req := <-p.getTopics:
			var names []string
			for name := range p.myTopics {
				names = append(names, name)
			}
			req.resp <- names

		case sub := <-p.cancelCh:
			p.handleRemoveSubscription(sub)

		case sub := <-p.addSub:
			p.handleAddSubscription(sub)

		case req := <-p.getPeers:
			// An empty topic means "all connected peers".
			filter := req.topic != ""
			subscribers, known := p.topics[req.topic]
			if filter && !known {
				req.resp <- nil
				continue
			}

			var matched []peer.ID
			for pid := range p.peers {
				if filter {
					if _, subscribed := subscribers[pid]; !subscribed {
						continue
					}
				}
				matched = append(matched, pid)
			}
			req.resp <- matched

		case rpc := <-p.incoming:
			p.handleIncomingRPC(rpc)

		case msg := <-p.publish:
			p.pushMsg(p.getValidators(msg), p.host.ID(), msg)

		case validated := <-p.sendMsg:
			p.publishMessage(validated.from, validated.msg.Message)

		case req := <-p.addVal:
			p.addValidator(req)

		case req := <-p.rmVal:
			p.rmValidator(req)

		case <-ctx.Done():
			log.Info("pubsub processloop shutting down")
			return
		}
	}
}
// handleRemoveSubscription removes Subscription sub from bookkeeping.
// If this was the last Subscription for a given topic, it will also announce
// that this node is not subscribing to this topic anymore.
//
// A subscription that is not currently tracked (unknown topic, or already
// cancelled) is ignored, so cancelling twice cannot close sub.ch a second
// time and panic.
// Only called from processLoop.
func (p *PubSub) handleRemoveSubscription(sub *Subscription) {
	subs := p.myTopics[sub.topic]

	// Indexing a nil map is safe, so this covers both the "no such topic"
	// and the "already removed" cases.
	if _, tracked := subs[sub]; !tracked {
		return
	}

	sub.err = fmt.Errorf("subscription cancelled by calling sub.Cancel()")
	close(sub.ch)
	delete(subs, sub)

	if len(subs) == 0 {
		delete(p.myTopics, sub.topic)
		p.announce(sub.topic, false)
		p.rt.Leave(sub.topic)
	}
}
// handleAddSubscription registers the Subscription carried by req under its
// topic. When the topic has no subscriptions yet, the node first announces
// its interest to its peers and tells the router to join the topic. The
// subscription gets a 32-slot message channel and the shared cancel channel,
// and is then sent back on req.resp.
// Only called from processLoop.
func (p *PubSub) handleAddSubscription(req *addSubReq) {
	topic := req.sub.topic
	current := p.myTopics[topic]

	// First subscription for this topic: announce and join before tracking it.
	if len(current) == 0 {
		p.announce(topic, true)
		p.rt.Join(topic)
	}

	if current == nil {
		current = make(map[*Subscription]struct{})
		p.myTopics[topic] = current
	}

	req.sub.ch = make(chan *Message, 32)
	req.sub.cancelCh = p.cancelCh
	current[req.sub] = struct{}{}

	req.resp <- req.sub
}
// announce tells every connected peer whether this node subscribes (sub is
// true) or unsubscribes (sub is false) to topic. The send is non-blocking: a
// peer whose outbound queue is full does not receive the announcement.
// Only called from processLoop.
func (p *PubSub) announce(topic string, sub bool) {
	out := rpcWithSubs(&pb.RPC_SubOpts{
		Topicid:   &topic,
		Subscribe: &sub,
	})

	for id, queue := range p.peers {
		select {
		case queue <- out:
		default:
			log.Infof("dropping announce message to peer %s: queue full", id)
		}
	}
}
// notifySubs delivers msg to every local subscription of each topic the
// message carries. Delivery is non-blocking: a subscription whose channel is
// full misses the message.
// Only called from processLoop.
func (p *PubSub) notifySubs(msg *pb.Message) {
	for _, topic := range msg.GetTopicIDs() {
		for sub := range p.myTopics[topic] {
			delivery := &Message{msg}
			select {
			case sub.ch <- delivery:
			default:
				log.Infof("Can't deliver message to subscription for topic %s; subscriber too slow", topic)
			}
		}
	}
}
// seenMessage reports whether id is present in the seen-messages cache.
func (p *PubSub) seenMessage(id string) bool {
	cache := p.seenMessages
	return cache.Has(id)
}
// markSeen adds id to the seen-messages cache so that a later seenMessage
// call for the same id returns true.
func (p *PubSub) markSeen(id string) {
	cache := p.seenMessages
	cache.Add(id)
}
// subscribedToMsg reports whether this node is subscribed to at least one of
// the topics carried by msg.
func (p *PubSub) subscribedToMsg(msg *pb.Message) bool {
	if len(p.myTopics) > 0 {
		for _, topic := range msg.GetTopicIDs() {
			if _, subscribed := p.myTopics[topic]; subscribed {
				return true
			}
		}
	}
	return false
}
// handleIncomingRPC processes one RPC received from a peer. It first applies
// the peer's subscription changes to the per-topic peer table, then pushes
// every published message that belongs to a topic we subscribe to (others
// are dropped with a warning), and finally hands the RPC to the router.
func (p *PubSub) handleIncomingRPC(rpc *RPC) {
	for _, opt := range rpc.GetSubscriptions() {
		topic := opt.GetTopicid()
		subscribers, known := p.topics[topic]

		if !opt.GetSubscribe() {
			// Unsubscribe: only meaningful for a topic we already track.
			if known {
				delete(subscribers, rpc.from)
			}
			continue
		}

		if !known {
			subscribers = make(map[peer.ID]struct{})
			p.topics[topic] = subscribers
		}
		subscribers[rpc.from] = struct{}{}
	}

	for _, pmsg := range rpc.GetPublish() {
		if p.subscribedToMsg(pmsg) {
			msg := &Message{pmsg}
			p.pushMsg(p.getValidators(msg), rpc.from, msg)
			continue
		}
		log.Warning("received message we didn't subscribe to. Dropping.")
	}

	p.rt.HandleRPC(rpc)
}
// msgID builds the identifier used for seen-message tracking by
// concatenating the message's From and Seqno fields.
func msgID(pmsg *pb.Message) string {
	from, seqno := pmsg.GetFrom(), pmsg.GetSeqno()
	return string(from) + string(seqno)
}
// pushMsg forwards msg, which arrived from src, unless it was already seen.
// The message is marked as seen before any validation takes place. Without
// validators it is published immediately; otherwise validation runs in its
// own goroutine, and the message is dropped if the global throttle is full.
func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) {
	id := msgID(msg.Message)
	if p.seenMessage(id) {
		return
	}
	p.markSeen(id)

	if len(vals) == 0 {
		p.publishMessage(src, msg.Message)
		return
	}

	// Validation is asynchronous and globally throttled with the
	// validateThrottle semaphore. The global throttle bounds the concurrency
	// that incoming network traffic can cause; each validator additionally
	// has its own throttle so that a slow (or faulty) validator cannot starve
	// other topics; see validate below.
	select {
	case p.validateThrottle <- struct{}{}:
		go func() {
			p.validate(vals, src, msg)
			<-p.validateThrottle
		}()
	default:
		log.Warningf("message validation throttled; dropping message from %s", src)
	}
}
// validate runs every validator in vals against msg concurrently and hands
// the message to the processing loop only if all of them accept it.
//
// Each validator is guarded by its own throttle; if any throttle is full the
// message is dropped without waiting for validators that were already
// started. The message is also dropped as soon as one validator rejects it.
// The validators share a context derived from the PubSub context, which is
// cancelled when validate returns.
//
// The final hand-off gives up when that context is done, so a validation
// goroutine (and the global throttle slot it holds) is not stuck forever
// once the processing loop has stopped reading sendMsg.
func (p *PubSub) validate(vals []*topicVal, src peer.ID, msg *Message) {
	ctx, cancel := context.WithCancel(p.ctx)
	defer cancel()

	// Buffered for every validator so that none of them blocks on its result
	// after we have returned early.
	rch := make(chan bool, len(vals))

	for _, val := range vals {
		select {
		case val.validateThrottle <- struct{}{}:
			go func(val *topicVal) {
				rch <- val.validateMsg(ctx, msg)
				<-val.validateThrottle
			}(val)
		default:
			log.Debugf("validation throttled for topic %s", val.topic)
			log.Warningf("message validation throttled; dropping message from %s", src)
			return
		}
	}

	// Every validator was started; collect exactly one result from each.
	for range vals {
		if valid := <-rch; !valid {
			log.Warningf("message validation failed; dropping message from %s", src)
			return
		}
	}

	// all validators were successful, send the message
	select {
	case p.sendMsg <- &sendReq{from: src, msg: msg}:
	case <-ctx.Done():
		// The PubSub is shutting down; nobody will consume the message.
	}
}
// publishMessage delivers pmsg to the local subscribers first and then asks
// the router to publish it, passing from as the peer the message came from.
func (p *PubSub) publishMessage(from peer.ID, pmsg *pb.Message) {
	p.notifySubs(pmsg)

	router := p.rt
	router.Publish(from, pmsg)
}
// getValidators collects the registered validator of every topic carried by
// msg. Topics without a validator are skipped; the result is nil when no
// topic has one.
func (p *PubSub) getValidators(msg *Message) []*topicVal {
	var matched []*topicVal
	for _, topic := range msg.GetTopicIDs() {
		if val, registered := p.topicVals[topic]; registered {
			matched = append(matched, val)
		}
	}
	return matched
}
// addSubReq asks the processing loop to register sub; the registered
// subscription is sent back on resp.
type addSubReq struct {
sub *Subscription
resp chan *Subscription
}
// SubOpt configures a Subscription before it is registered; a non-nil error
// aborts the subscribe call.
type SubOpt func(sub *Subscription) error
// Subscribe returns a new Subscription for the given topic. It is a
// convenience wrapper around SubscribeByTopicDescriptor with a descriptor
// that only carries the topic name.
func (p *PubSub) Subscribe(topic string, opts ...SubOpt) (*Subscription, error) {
	return p.SubscribeByTopicDescriptor(&pb.TopicDescriptor{Name: &topic}, opts...)
}
// SubscribeByTopicDescriptor lets you subscribe to a topic using a
// pb.TopicDescriptor.
//
// Only descriptors with auth mode NONE and encryption mode NONE are
// accepted; anything else yields an error. Each option is applied to the new
// subscription in order and the first failing option aborts the call. If the
// PubSub context is done before the request can be handed to the processing
// loop, the context's error is returned instead of blocking forever.
func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubOpt) (*Subscription, error) {
	if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE {
		return nil, fmt.Errorf("auth mode not yet supported")
	}

	if td.GetEnc().GetMode() != pb.TopicDescriptor_EncOpts_NONE {
		return nil, fmt.Errorf("encryption mode not yet supported")
	}

	sub := &Subscription{
		topic: td.GetName(),
	}

	for _, opt := range opts {
		if err := opt(sub); err != nil {
			return nil, err
		}
	}

	// The response channel is buffered, so once the request has been
	// accepted the processing loop always answers without blocking.
	out := make(chan *Subscription, 1)
	select {
	case p.addSub <- &addSubReq{sub: sub, resp: out}:
	case <-p.ctx.Done():
		return nil, p.ctx.Err()
	}

	return <-out, nil
}
// topicReq asks the processing loop for the names of the topics this node is
// subscribed to; the answer is sent on resp.
type topicReq struct {
resp chan []string
}
// GetTopics returns the topics this node is subscribed to. It returns nil
// when there are none, and also when the PubSub context is done before the
// request can be handed to the processing loop (instead of blocking forever).
func (p *PubSub) GetTopics() []string {
	out := make(chan []string, 1)
	select {
	case p.getTopics <- &topicReq{resp: out}:
	case <-p.ctx.Done():
		return nil
	}
	return <-out
}
// Publish publishes data under the given topic.
//
// The message is stamped with this host's ID and with a sequence number
// taken from the atomically incremented counter, encoded as 8 big-endian
// bytes. If the PubSub context is done before the message can be handed to
// the processing loop, the context's error is returned instead of blocking
// forever; otherwise the result is nil.
func (p *PubSub) Publish(topic string, data []byte) error {
	seqno := make([]byte, 8)
	counter := atomic.AddUint64(&p.counter, 1)
	binary.BigEndian.PutUint64(seqno, counter)

	msg := &Message{
		&pb.Message{
			Data:     data,
			TopicIDs: []string{topic},
			From:     []byte(p.host.ID()),
			Seqno:    seqno,
		},
	}

	select {
	case p.publish <- msg:
		return nil
	case <-p.ctx.Done():
		return p.ctx.Err()
	}
}
// listPeerReq asks the processing loop for the connected peers; a non-empty
// topic restricts the answer to peers recorded as subscribed to that topic.
// The answer is sent on resp.
type listPeerReq struct {
resp chan []peer.ID
topic string
}
// sendReq is a request to call publishMessage.
// It is issued after message validation is done.
type sendReq struct {
// from is the peer the message was received from.
from peer.ID
// msg is the validated message.
msg *Message
}
// ListPeers returns a list of peers we are connected to. A non-empty topic
// restricts the list to the peers recorded as subscribed to that topic; an
// empty topic lists every connected peer.
//
// The result is nil when nothing matches, and also when the PubSub context
// is done before the request can be handed to the processing loop (instead
// of blocking forever).
func (p *PubSub) ListPeers(topic string) []peer.ID {
	// Buffered like the other request types, so the processing loop never
	// has to wait for the caller to pick up the answer.
	out := make(chan []peer.ID, 1)
	select {
	case p.getPeers <- &listPeerReq{resp: out, topic: topic}:
	case <-p.ctx.Done():
		return nil
	}
	return <-out
}
// per topic validators

// addValReq asks the processing loop to register validate for topic.
// A timeout or throttle that is not positive selects the package default.
// The outcome is sent on resp.
type addValReq struct {
topic string
validate Validator
timeout time.Duration
throttle int
resp chan error
}
// rmValReq asks the processing loop to remove the validator registered for
// topic; the outcome is sent on resp.
type rmValReq struct {
topic string
resp chan error
}
// topicVal is the validator registered for one topic together with its
// limits.
type topicVal struct {
topic string
validate Validator
// validateTimeout bounds a single call of validate (see validateMsg).
validateTimeout time.Duration
// validateThrottle is a semaphore limiting how many validations of this
// topic run at the same time.
validateThrottle chan struct{}
}
// Validator is a function that validates a message; it returns true when the
// message is acceptable. The context it receives carries the validator's
// timeout.
type Validator func(context.Context, *Message) bool
// ValidatorOpt is an option for RegisterTopicValidator; a non-nil error
// aborts the registration.
type ValidatorOpt func(addVal *addValReq) error
// WithValidatorTimeout returns a ValidatorOpt that records timeout as the
// timeout requested for the topic validator being registered.
func WithValidatorTimeout(timeout time.Duration) ValidatorOpt {
	setTimeout := func(req *addValReq) error {
		req.timeout = timeout
		return nil
	}
	return setTimeout
}
// WithValidatorConcurrency returns a ValidatorOpt that records n as the
// throttle (maximum concurrency) requested for the topic validator being
// registered.
func WithValidatorConcurrency(n int) ValidatorOpt {
	setThrottle := func(req *addValReq) error {
		req.throttle = n
		return nil
	}
	return setThrottle
}
// RegisterTopicValidator registers val as the validator for topic.
//
// Options are applied in order and the first failing option aborts the call.
// An error is returned when the topic already has a validator. If the PubSub
// context is done before the request can be handed to the processing loop,
// the context's error is returned instead of blocking forever.
func (p *PubSub) RegisterTopicValidator(topic string, val Validator, opts ...ValidatorOpt) error {
	addVal := &addValReq{
		topic:    topic,
		validate: val,
		// Buffered so the processing loop can answer without blocking.
		resp: make(chan error, 1),
	}

	for _, opt := range opts {
		if err := opt(addVal); err != nil {
			return err
		}
	}

	select {
	case p.addVal <- addVal:
	case <-p.ctx.Done():
		return p.ctx.Err()
	}
	return <-addVal.resp
}
// addValidator installs the validator described by req unless its topic
// already has one, in which case an error is sent on req.resp. A timeout or
// throttle that is not positive falls back to the package default. On
// success nil is sent on req.resp.
func (ps *PubSub) addValidator(req *addValReq) {
	if _, exists := ps.topicVals[req.topic]; exists {
		req.resp <- fmt.Errorf("Duplicate validator for topic %s", req.topic)
		return
	}

	timeout := defaultValidateTimeout
	if req.timeout > 0 {
		timeout = req.timeout
	}

	concurrency := defaultValidateConcurrency
	if req.throttle > 0 {
		concurrency = req.throttle
	}

	ps.topicVals[req.topic] = &topicVal{
		topic:            req.topic,
		validate:         req.validate,
		validateTimeout:  timeout,
		validateThrottle: make(chan struct{}, concurrency),
	}
	req.resp <- nil
}
// UnregisterTopicValidator removes a validator from a topic.
//
// It returns an error if there was no validator registered with the topic.
// If the PubSub context is done before the request can be handed to the
// processing loop, the context's error is returned instead of blocking
// forever.
func (p *PubSub) UnregisterTopicValidator(topic string) error {
	rmVal := &rmValReq{
		topic: topic,
		// Buffered so the processing loop can answer without blocking.
		resp: make(chan error, 1),
	}

	select {
	case p.rmVal <- rmVal:
	case <-p.ctx.Done():
		return p.ctx.Err()
	}
	return <-rmVal.resp
}
// rmValidator drops the validator registered for req.topic and sends nil on
// req.resp; when the topic has no validator it sends an error instead.
func (ps *PubSub) rmValidator(req *rmValReq) {
	if _, exists := ps.topicVals[req.topic]; !exists {
		req.resp <- fmt.Errorf("No validator for topic %s", req.topic)
		return
	}

	delete(ps.topicVals, req.topic)
	req.resp <- nil
}
// validateMsg runs the topic's validator on msg under a context that expires
// after the validator's timeout and returns the validator's verdict. A
// rejection is logged at debug level.
func (val *topicVal) validateMsg(ctx context.Context, msg *Message) bool {
	vctx, cancel := context.WithTimeout(ctx, val.validateTimeout)
	defer cancel()

	if val.validate(vctx, msg) {
		return true
	}

	log.Debugf("validation failed for topic %s", val.topic)
	return false
}