respect contexts better
This commit is contained in:
parent
c0d5b0ef26
commit
17e835cd17
79
floodsub.go
79
floodsub.go
|
@ -2,7 +2,9 @@ package floodsub
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -43,6 +45,8 @@ type PubSub struct {
|
|||
lastMsg map[peer.ID]uint64
|
||||
|
||||
addSub chan *addSub
|
||||
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
|
@ -60,9 +64,10 @@ type RPC struct {
|
|||
from peer.ID
|
||||
}
|
||||
|
||||
func NewFloodSub(h host.Host) *PubSub {
|
||||
func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
|
||||
ps := &PubSub{
|
||||
host: h,
|
||||
ctx: ctx,
|
||||
incoming: make(chan *RPC, 32),
|
||||
outgoing: make(chan *RPC),
|
||||
newPeers: make(chan inet.Stream),
|
||||
|
@ -77,7 +82,7 @@ func NewFloodSub(h host.Host) *PubSub {
|
|||
h.SetStreamHandler(ID, ps.handleNewStream)
|
||||
h.Network().Notify(ps)
|
||||
|
||||
go ps.processLoop()
|
||||
go ps.processLoop(ctx)
|
||||
|
||||
return ps
|
||||
}
|
||||
|
@ -99,47 +104,63 @@ func (p *PubSub) handleNewStream(s inet.Stream) {
|
|||
rpc := new(RPC)
|
||||
err := r.ReadMsg(&rpc.RPC)
|
||||
if err != nil {
|
||||
log.Errorf("error reading rpc from %s: %s", s.Conn().RemotePeer(), err)
|
||||
// TODO: cleanup of some sort
|
||||
if err != io.EOF {
|
||||
log.Errorf("error reading rpc from %s: %s", s.Conn().RemotePeer(), err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
rpc.from = s.Conn().RemotePeer()
|
||||
p.incoming <- rpc
|
||||
select {
|
||||
case p.incoming <- rpc:
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PubSub) handleSendingMessages(s inet.Stream, in <-chan *RPC) {
|
||||
func (p *PubSub) handleSendingMessages(ctx context.Context, s inet.Stream, in <-chan *RPC) {
|
||||
var dead bool
|
||||
bufw := bufio.NewWriter(s)
|
||||
wc := ggio.NewDelimitedWriter(bufw)
|
||||
|
||||
writeMsg := func(msg proto.Message) error {
|
||||
err := wc.WriteMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bufw.Flush()
|
||||
}
|
||||
|
||||
defer wc.Close()
|
||||
for rpc := range in {
|
||||
if dead {
|
||||
continue
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case rpc, ok := <-in:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if dead {
|
||||
// continue in order to drain messages
|
||||
continue
|
||||
}
|
||||
|
||||
err := wc.WriteMsg(&rpc.RPC)
|
||||
if err != nil {
|
||||
log.Errorf("writing message to %s: %s", s.Conn().RemotePeer(), err)
|
||||
dead = true
|
||||
go func() {
|
||||
p.peerDead <- s.Conn().RemotePeer()
|
||||
}()
|
||||
}
|
||||
err := writeMsg(&rpc.RPC)
|
||||
if err != nil {
|
||||
log.Errorf("writing message to %s: %s", s.Conn().RemotePeer(), err)
|
||||
dead = true
|
||||
go func() {
|
||||
p.peerDead <- s.Conn().RemotePeer()
|
||||
}()
|
||||
}
|
||||
|
||||
err = bufw.Flush()
|
||||
if err != nil {
|
||||
log.Errorf("writing message to %s: %s", s.Conn().RemotePeer(), err)
|
||||
dead = true
|
||||
go func() {
|
||||
p.peerDead <- s.Conn().RemotePeer()
|
||||
}()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PubSub) processLoop() {
|
||||
func (p *PubSub) processLoop(ctx context.Context) {
|
||||
|
||||
for {
|
||||
select {
|
||||
|
@ -153,12 +174,11 @@ func (p *PubSub) processLoop() {
|
|||
}
|
||||
|
||||
messages := make(chan *RPC, 32)
|
||||
go p.handleSendingMessages(s, messages)
|
||||
go p.handleSendingMessages(ctx, s, messages)
|
||||
messages <- p.getHelloPacket()
|
||||
|
||||
p.peers[pid] = messages
|
||||
|
||||
fmt.Println("added peer: ", pid)
|
||||
case pid := <-p.peerDead:
|
||||
delete(p.peers, pid)
|
||||
case sub := <-p.addSub:
|
||||
|
@ -186,6 +206,9 @@ func (p *PubSub) processLoop() {
|
|||
log.Error("publishing message: ", err)
|
||||
}
|
||||
}
|
||||
case <-ctx.Done():
|
||||
log.Info("pubsub processloop shutting down")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
158
floodsub_test.go
158
floodsub_test.go
|
@ -12,11 +12,11 @@ import (
|
|||
netutil "github.com/libp2p/go-libp2p/p2p/test/util"
|
||||
)
|
||||
|
||||
func getNetHosts(t *testing.T, n int) []host.Host {
|
||||
func getNetHosts(t *testing.T, ctx context.Context, n int) []host.Host {
|
||||
var out []host.Host
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
h := netutil.GenHostSwarm(t, context.Background())
|
||||
h := netutil.GenHostSwarm(t, ctx)
|
||||
out = append(out, h)
|
||||
}
|
||||
|
||||
|
@ -31,6 +31,22 @@ func connect(t *testing.T, a, b host.Host) {
|
|||
}
|
||||
}
|
||||
|
||||
func sparseConnect(t *testing.T, hosts []host.Host) {
|
||||
for i, a := range hosts {
|
||||
for j := 0; j < 3; j++ {
|
||||
n := rand.Intn(len(hosts))
|
||||
if n == i {
|
||||
j--
|
||||
continue
|
||||
}
|
||||
|
||||
b := hosts[n]
|
||||
|
||||
connect(t, a, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func connectAll(t *testing.T, hosts []host.Host) {
|
||||
for i, a := range hosts {
|
||||
for j, b := range hosts {
|
||||
|
@ -43,13 +59,20 @@ func connectAll(t *testing.T, hosts []host.Host) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestBasicFloodsub(t *testing.T) {
|
||||
hosts := getNetHosts(t, 20)
|
||||
|
||||
func getPubsubs(ctx context.Context, hs []host.Host) []*PubSub {
|
||||
var psubs []*PubSub
|
||||
for _, h := range hosts {
|
||||
psubs = append(psubs, NewFloodSub(h))
|
||||
for _, h := range hs {
|
||||
psubs = append(psubs, NewFloodSub(ctx, h))
|
||||
}
|
||||
return psubs
|
||||
}
|
||||
|
||||
func TestBasicFloodsub(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hosts := getNetHosts(t, ctx, 20)
|
||||
|
||||
psubs := getPubsubs(ctx, hosts)
|
||||
|
||||
var msgs []<-chan *Message
|
||||
for _, ps := range psubs {
|
||||
|
@ -61,26 +84,12 @@ func TestBasicFloodsub(t *testing.T) {
|
|||
msgs = append(msgs, subch)
|
||||
}
|
||||
|
||||
connectAll(t, hosts)
|
||||
//connectAll(t, hosts)
|
||||
sparseConnect(t, hosts)
|
||||
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
psubs[0].Publish("foobar", []byte("ipfs rocks"))
|
||||
|
||||
for i, resp := range msgs {
|
||||
fmt.Printf("reading message from peer %d\n", i)
|
||||
msg := <-resp
|
||||
fmt.Printf("%s - %d: topic %s, from %s: %s\n", time.Now(), i, msg.Topic, msg.From, string(msg.Data))
|
||||
}
|
||||
|
||||
psubs[2].Publish("foobar", []byte("libp2p is cool too"))
|
||||
for i, resp := range msgs {
|
||||
fmt.Printf("reading message from peer %d\n", i)
|
||||
msg := <-resp
|
||||
fmt.Printf("%s - %d: topic %s, from %s: %s\n", time.Now(), i, msg.Topic, msg.From, string(msg.Data))
|
||||
}
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
fmt.Println("loop: ", i)
|
||||
msg := []byte(fmt.Sprintf("%d the flooooooood %d", i, i))
|
||||
|
||||
owner := rand.Intn(len(psubs))
|
||||
|
@ -94,4 +103,107 @@ func TestBasicFloodsub(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestMultihops(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
hosts := getNetHosts(t, ctx, 6)
|
||||
|
||||
psubs := getPubsubs(ctx, hosts)
|
||||
|
||||
connect(t, hosts[0], hosts[1])
|
||||
connect(t, hosts[1], hosts[2])
|
||||
connect(t, hosts[2], hosts[3])
|
||||
connect(t, hosts[3], hosts[4])
|
||||
connect(t, hosts[4], hosts[5])
|
||||
|
||||
var msgChs []<-chan *Message
|
||||
for i := 1; i < 6; i++ {
|
||||
ch, err := psubs[i].Subscribe("foobar")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
msgChs = append(msgChs, ch)
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
msg := []byte("i like cats")
|
||||
err := psubs[0].Publish("foobar", msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// last node in the chain should get the message
|
||||
select {
|
||||
case out := <-msgChs[4]:
|
||||
if !bytes.Equal(out.GetData(), msg) {
|
||||
t.Fatal("got wrong data")
|
||||
}
|
||||
case <-time.After(time.Second * 5):
|
||||
t.Fatal("timed out waiting for message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconnects(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
hosts := getNetHosts(t, ctx, 10)
|
||||
|
||||
psubs := getPubsubs(ctx, hosts)
|
||||
|
||||
connect(t, hosts[0], hosts[1])
|
||||
connect(t, hosts[0], hosts[2])
|
||||
|
||||
A, err := psubs[1].Subscribe("cats")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
B, err := psubs[2].Subscribe("cats")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
msg := []byte("apples and oranges")
|
||||
err = psubs[0].Publish("cats", msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assertReceive(t, A, msg)
|
||||
assertReceive(t, B, msg)
|
||||
|
||||
hosts[2].Close()
|
||||
|
||||
msg2 := []byte("potato")
|
||||
err = psubs[0].Publish("cats", msg2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assertReceive(t, A, msg2)
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
_, ok := psubs[0].peers[hosts[2].ID()]
|
||||
if ok {
|
||||
t.Fatal("shouldnt have this peer anymore")
|
||||
}
|
||||
}
|
||||
|
||||
func assertReceive(t *testing.T, ch <-chan *Message, exp []byte) {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
if !bytes.Equal(msg.GetData(), exp) {
|
||||
t.Fatalf("got wrong message, expected %s but got %s", string(exp), string(msg.GetData()))
|
||||
}
|
||||
case <-time.After(time.Second * 5):
|
||||
t.Fatal("timed out waiting for message of: ", exp)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue