go-libp2p-pubsub/floodsub.go

354 lines
6.8 KiB
Go
Raw Normal View History

2016-09-10 03:13:50 +00:00
package floodsub
import (
2016-09-10 15:28:29 +00:00
"bufio"
"fmt"
"sync"
"time"
2016-09-10 15:14:17 +00:00
pb "github.com/whyrusleeping/go-floodsub/pb"
ggio "github.com/gogo/protobuf/io"
proto "github.com/gogo/protobuf/proto"
2016-09-10 15:28:29 +00:00
peer "github.com/ipfs/go-libp2p-peer"
logging "github.com/ipfs/go-log"
host "github.com/libp2p/go-libp2p/p2p/host"
inet "github.com/libp2p/go-libp2p/p2p/net"
protocol "github.com/libp2p/go-libp2p/p2p/protocol"
)
2016-09-10 03:13:50 +00:00
const ID = protocol.ID("/floodsub/1.0.0")
2016-09-10 15:14:17 +00:00
var (
2016-09-10 03:13:50 +00:00
AddSubMessageType = "sub"
UnsubMessageType = "unsub"
PubMessageType = "pub"
)
var log = logging.Logger("floodsub")
// PubSub is a floodsub router bound to a libp2p host. All of its maps are
// read and written from processLoop and the handlers processLoop calls, so
// other goroutines talk to it only through the channels below.
type PubSub struct {
	host host.Host

	// Inputs to processLoop.
	incoming chan *RPC        // RPCs read from remote peers' streams
	outgoing chan *RPC        // RPCs originated locally (publish / sub changes)
	newPeers chan inet.Stream // new outbound streams; not fed from this file (presumably the Notifiee hooks) — verify
	peerDead chan peer.ID     // peers whose send loop hit a write error
	addSub   chan *addSub     // subscribe / unsubscribe requests

	// Router state.
	myTopics map[string]chan *Message          // local subscriptions: topic -> delivery channel
	topics   map[string]map[peer.ID]struct{}   // remote subscriptions: topic -> set of peers
	peers    map[peer.ID]chan *RPC             // per-peer outbound queues
	lastMsg  map[peer.ID]uint64                // highest seqno seen per originating peer

	// NOTE(review): pubsubLk is not referenced anywhere in this file.
	pubsubLk sync.Mutex
}
type Message struct {
2016-09-10 15:14:17 +00:00
*pb.Message
}
2016-09-10 15:14:17 +00:00
func (m *Message) GetFrom() peer.ID {
return peer.ID(m.Message.GetFrom())
}
type RPC struct {
2016-09-10 15:14:17 +00:00
pb.RPC
// unexported on purpose, not sending this over the wire
from peer.ID
}
2016-09-10 03:13:50 +00:00
func NewFloodSub(h host.Host) *PubSub {
ps := &PubSub{
host: h,
incoming: make(chan *RPC, 32),
outgoing: make(chan *RPC),
newPeers: make(chan inet.Stream),
2016-09-10 15:14:17 +00:00
peerDead: make(chan peer.ID),
addSub: make(chan *addSub),
myTopics: make(map[string]chan *Message),
topics: make(map[string]map[peer.ID]struct{}),
peers: make(map[peer.ID]chan *RPC),
lastMsg: make(map[peer.ID]uint64),
}
h.SetStreamHandler(ID, ps.handleNewStream)
h.Network().Notify(ps)
go ps.processLoop()
return ps
}
func (p *PubSub) getHelloPacket() *RPC {
var rpc RPC
for t, _ := range p.myTopics {
rpc.Topics = append(rpc.Topics, t)
}
2016-09-10 15:14:17 +00:00
rpc.Type = &AddSubMessageType
return &rpc
}
func (p *PubSub) handleNewStream(s inet.Stream) {
defer s.Close()
2016-09-10 15:14:17 +00:00
r := ggio.NewDelimitedReader(s, 1<<20)
for {
rpc := new(RPC)
2016-09-10 15:14:17 +00:00
err := r.ReadMsg(&rpc.RPC)
if err != nil {
log.Errorf("error reading rpc from %s: %s", s.Conn().RemotePeer(), err)
// TODO: cleanup of some sort
return
}
rpc.from = s.Conn().RemotePeer()
p.incoming <- rpc
}
}
func (p *PubSub) handleSendingMessages(s inet.Stream, in <-chan *RPC) {
var dead bool
2016-09-10 15:28:29 +00:00
bufw := bufio.NewWriter(s)
wc := ggio.NewDelimitedWriter(bufw)
2016-09-10 15:14:17 +00:00
defer wc.Close()
for rpc := range in {
if dead {
continue
}
2016-09-10 15:14:17 +00:00
err := wc.WriteMsg(&rpc.RPC)
if err != nil {
log.Errorf("writing message to %s: %s", s.Conn().RemotePeer(), err)
dead = true
go func() {
p.peerDead <- s.Conn().RemotePeer()
}()
}
2016-09-10 15:28:29 +00:00
err = bufw.Flush()
if err != nil {
log.Errorf("writing message to %s: %s", s.Conn().RemotePeer(), err)
dead = true
go func() {
p.peerDead <- s.Conn().RemotePeer()
}()
}
}
}
func (p *PubSub) processLoop() {
for {
select {
case s := <-p.newPeers:
pid := s.Conn().RemotePeer()
_, ok := p.peers[pid]
if ok {
log.Error("already have connection to peer: ", pid)
s.Close()
continue
}
messages := make(chan *RPC, 32)
go p.handleSendingMessages(s, messages)
messages <- p.getHelloPacket()
p.peers[pid] = messages
fmt.Println("added peer: ", pid)
case pid := <-p.peerDead:
delete(p.peers, pid)
case sub := <-p.addSub:
2016-09-10 03:13:50 +00:00
p.handleSubscriptionChange(sub)
case rpc := <-p.incoming:
err := p.handleIncomingRPC(rpc)
if err != nil {
log.Error("handling RPC: ", err)
}
case rpc := <-p.outgoing:
2016-09-10 15:14:17 +00:00
switch rpc.GetType() {
2016-09-10 03:13:50 +00:00
case AddSubMessageType, UnsubMessageType:
for _, mch := range p.peers {
mch <- rpc
}
2016-09-10 03:13:50 +00:00
case PubMessageType:
//fmt.Println("publishing outgoing message")
err := p.recvMessage(rpc)
if err != nil {
log.Error("error receiving message: ", err)
}
err = p.publishMessage(rpc)
if err != nil {
log.Error("publishing message: ", err)
}
}
}
}
}
2016-09-10 03:13:50 +00:00
func (p *PubSub) handleSubscriptionChange(sub *addSub) {
ch, ok := p.myTopics[sub.topic]
out := &RPC{
2016-09-10 15:14:17 +00:00
RPC: pb.RPC{
Topics: []string{sub.topic},
},
2016-09-10 03:13:50 +00:00
}
if sub.cancel {
if !ok {
return
}
close(ch)
delete(p.myTopics, sub.topic)
2016-09-10 15:14:17 +00:00
out.Type = &UnsubMessageType
2016-09-10 03:13:50 +00:00
} else {
if ok {
// we don't allow multiple subs per topic at this point
sub.resp <- nil
return
}
resp := make(chan *Message, 16)
p.myTopics[sub.topic] = resp
sub.resp <- resp
2016-09-10 15:14:17 +00:00
out.Type = &AddSubMessageType
2016-09-10 03:13:50 +00:00
}
go func() {
p.outgoing <- out
}()
}
func (p *PubSub) recvMessage(rpc *RPC) error {
2016-09-10 15:14:17 +00:00
subch, ok := p.myTopics[rpc.Msg.GetTopic()]
if ok {
//fmt.Println("writing out to subscriber!")
2016-09-10 15:14:17 +00:00
subch <- &Message{rpc.Msg}
}
return nil
}
func (p *PubSub) handleIncomingRPC(rpc *RPC) error {
2016-09-10 15:14:17 +00:00
switch rpc.GetType() {
2016-09-10 03:13:50 +00:00
case AddSubMessageType:
for _, t := range rpc.Topics {
tmap, ok := p.topics[t]
if !ok {
tmap = make(map[peer.ID]struct{})
p.topics[t] = tmap
}
tmap[rpc.from] = struct{}{}
}
2016-09-10 03:13:50 +00:00
case UnsubMessageType:
for _, t := range rpc.Topics {
tmap, ok := p.topics[t]
if !ok {
return nil
}
delete(tmap, rpc.from)
}
2016-09-10 03:13:50 +00:00
case PubMessageType:
if rpc.Msg == nil {
return fmt.Errorf("nil pub message")
}
2016-09-10 03:13:50 +00:00
2016-09-10 15:14:17 +00:00
msg := &Message{rpc.Msg}
// Note: Obviously this is an incredibly insecure way of
// filtering out "messages we've already seen". But it works for a
// cool demo, so i'm not gonna waste time thinking about it any more
2016-09-10 15:14:17 +00:00
if p.lastMsg[msg.GetFrom()] >= msg.GetSeqno() {
//log.Error("skipping 'old' message")
return nil
}
2016-09-10 15:14:17 +00:00
if msg.GetFrom() == p.host.ID() {
log.Error("skipping message from self")
return nil
}
2016-09-10 15:14:17 +00:00
p.lastMsg[msg.GetFrom()] = msg.GetSeqno()
if err := p.recvMessage(rpc); err != nil {
log.Error("error receiving message: ", err)
}
err := p.publishMessage(rpc)
if err != nil {
log.Error("publish message: ", err)
}
}
return nil
}
func (p *PubSub) publishMessage(rpc *RPC) error {
2016-09-10 15:14:17 +00:00
tmap, ok := p.topics[rpc.Msg.GetTopic()]
if !ok {
return nil
}
for pid, _ := range tmap {
2016-09-10 15:14:17 +00:00
if pid == rpc.from || pid == peer.ID(rpc.Msg.GetFrom()) {
continue
}
mch, ok := p.peers[pid]
if !ok {
continue
}
go func() { mch <- rpc }()
}
return nil
}
type addSub struct {
2016-09-10 03:13:50 +00:00
topic string
cancel bool
resp chan chan *Message
}
// Subscribe registers a local subscription to topic and returns the channel
// on which its messages are delivered. Only one subscription per topic is
// allowed; a second attempt returns an error. The returned channel is
// closed by Unsub.
func (p *PubSub) Subscribe(topic string) (<-chan *Message, error) {
	req := &addSub{
		topic: topic,
		resp:  make(chan chan *Message),
	}
	p.addSub <- req

	if delivery := <-req.resp; delivery != nil {
		return delivery, nil
	}
	return nil, fmt.Errorf("error, duplicate subscription")
}
func (p *PubSub) Unsub(topic string) {
2016-09-10 03:13:50 +00:00
p.addSub <- &addSub{
topic: topic,
cancel: true,
}
}
func (p *PubSub) Publish(topic string, data []byte) error {
2016-09-10 15:14:17 +00:00
seqno := uint64(time.Now().UnixNano())
p.outgoing <- &RPC{
2016-09-10 15:14:17 +00:00
RPC: pb.RPC{
Msg: &pb.Message{
Data: data,
Topic: &topic,
From: proto.String(string(p.host.ID())),
Seqno: &seqno,
},
Type: &PubMessageType,
},
}
return nil
}