diff --git a/blacklist.go b/blacklist.go new file mode 100644 index 0000000..17bb241 --- /dev/null +++ b/blacklist.go @@ -0,0 +1,28 @@ +package pubsub + +import ( + peer "github.com/libp2p/go-libp2p-peer" +) + +// Blacklist is an interface for peer blacklisting. +type Blacklist interface { + Add(peer.ID) + Contains(peer.ID) bool +} + +// MapBlacklist is a blacklist implementation using a perfect map +type MapBlacklist map[peer.ID]struct{} + +// NewMapBlacklist creates a new MapBlacklist +func NewMapBlacklist() Blacklist { + return MapBlacklist(make(map[peer.ID]struct{})) +} + +func (b MapBlacklist) Add(p peer.ID) { + b[p] = struct{}{} +} + +func (b MapBlacklist) Contains(p peer.ID) bool { + _, ok := b[p] + return ok +} diff --git a/pubsub.go b/pubsub.go index 790db65..3c5cebe 100644 --- a/pubsub.go +++ b/pubsub.go @@ -99,7 +99,7 @@ type PubSub struct { eval chan func() // peer blacklist - blacklist map[peer.ID]struct{} + blacklist Blacklist blacklistPeer chan peer.ID peers map[peer.ID]chan *RPC @@ -183,7 +183,7 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option topics: make(map[string]map[peer.ID]struct{}), peers: make(map[peer.ID]chan *RPC), topicVals: make(map[string]*topicVal), - blacklist: make(map[peer.ID]struct{}), + blacklist: NewMapBlacklist(), blacklistPeer: make(chan peer.ID), seenMessages: timecache.NewTimeCache(TimeCacheDuration), counter: uint64(time.Now().UnixNano()), @@ -268,6 +268,15 @@ func WithStrictSignatureVerification(required bool) Option { } } +// WithBlacklist provides an implementation of the blacklist; the default is a +// MapBlacklist +func WithBlacklist(b Blacklist) Option { + return func(p *PubSub) error { + p.blacklist = b + return nil + } +} + // processLoop handles all inputs arriving on the channels func (p *PubSub) processLoop(ctx context.Context) { defer func() { @@ -282,14 +291,12 @@ func (p *PubSub) processLoop(ctx context.Context) { for { select { case pid := <-p.newPeers: - _, ok := p.peers[pid] - if ok { + if p.blacklist.Contains(pid) { log.Warning("already have connection to peer: ", pid) continue } - _, ok = p.blacklist[pid] - if ok { + if p.blacklist.Contains(pid) { log.Warning("ignoring connection from blacklisted peer: ", pid) continue } @@ -309,8 +316,7 @@ func (p *PubSub) processLoop(ctx context.Context) { continue } - _, ok = p.blacklist[pid] - if ok { + if p.blacklist.Contains(pid) { log.Warning("closing stream for blacklisted peer: ", pid) close(ch) s.Reset() @@ -396,7 +402,7 @@ func (p *PubSub) processLoop(ctx context.Context) { case pid := <-p.blacklistPeer: log.Infof("Blacklisting peer %s", pid) - p.blacklist[pid] = struct{}{} + p.blacklist.Add(pid) ch, ok := p.peers[pid] if ok { @@ -602,13 +608,13 @@ func msgID(pmsg *pb.Message) string { // pushMsg pushes a message performing validation as necessary func (p *PubSub) pushMsg(vals []*topicVal, src peer.ID, msg *Message) { // reject messages from blacklisted peers - if _, ok := p.blacklist[src]; ok { + if p.blacklist.Contains(src) { log.Warningf("dropping message from blacklisted peer %s", src) return } // even if they are forwarded by good peers - if _, ok := p.blacklist[msg.GetFrom()]; ok { + if p.blacklist.Contains(msg.GetFrom()) { log.Warningf("dropping message from blacklisted source %s", src) return }