From 17e835cd17d1e8749bb17654a40617cf360d17b9 Mon Sep 17 00:00:00 2001 From: Jeromy Date: Sat, 10 Sep 2016 16:03:53 -0700 Subject: [PATCH] respect contexts better --- floodsub.go | 79 +++++++++++++++--------- floodsub_test.go | 158 ++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 186 insertions(+), 51 deletions(-) diff --git a/floodsub.go b/floodsub.go index ab2835e..bc0c132 100644 --- a/floodsub.go +++ b/floodsub.go @@ -2,7 +2,9 @@ package floodsub import ( "bufio" + "context" "fmt" + "io" "sync" "time" @@ -43,6 +45,8 @@ type PubSub struct { lastMsg map[peer.ID]uint64 addSub chan *addSub + + ctx context.Context } type Message struct { @@ -60,9 +64,10 @@ type RPC struct { from peer.ID } -func NewFloodSub(h host.Host) *PubSub { +func NewFloodSub(ctx context.Context, h host.Host) *PubSub { ps := &PubSub{ host: h, + ctx: ctx, incoming: make(chan *RPC, 32), outgoing: make(chan *RPC), newPeers: make(chan inet.Stream), @@ -77,7 +82,7 @@ func NewFloodSub(h host.Host) *PubSub { h.SetStreamHandler(ID, ps.handleNewStream) h.Network().Notify(ps) - go ps.processLoop() + go ps.processLoop(ctx) return ps } @@ -99,47 +104,63 @@ func (p *PubSub) handleNewStream(s inet.Stream) { rpc := new(RPC) err := r.ReadMsg(&rpc.RPC) if err != nil { - log.Errorf("error reading rpc from %s: %s", s.Conn().RemotePeer(), err) - // TODO: cleanup of some sort + if err != io.EOF { + log.Errorf("error reading rpc from %s: %s", s.Conn().RemotePeer(), err) + } return } rpc.from = s.Conn().RemotePeer() - p.incoming <- rpc + select { + case p.incoming <- rpc: + case <-p.ctx.Done(): + return + } } } -func (p *PubSub) handleSendingMessages(s inet.Stream, in <-chan *RPC) { +func (p *PubSub) handleSendingMessages(ctx context.Context, s inet.Stream, in <-chan *RPC) { var dead bool bufw := bufio.NewWriter(s) wc := ggio.NewDelimitedWriter(bufw) + + writeMsg := func(msg proto.Message) error { + err := wc.WriteMsg(msg) + if err != nil { + return err + } + + return bufw.Flush() + } + defer wc.Close() - for rpc := range in { - if dead { - continue - } + for { + select { + case rpc, ok := <-in: + if !ok { + return + } + if dead { + // continue in order to drain messages + continue + } - err := wc.WriteMsg(&rpc.RPC) - if err != nil { - log.Errorf("writing message to %s: %s", s.Conn().RemotePeer(), err) - dead = true - go func() { - p.peerDead <- s.Conn().RemotePeer() - }() - } + err := writeMsg(&rpc.RPC) + if err != nil { + log.Errorf("writing message to %s: %s", s.Conn().RemotePeer(), err) + dead = true + go func() { + p.peerDead <- s.Conn().RemotePeer() + }() + } - err = bufw.Flush() - if err != nil { - log.Errorf("writing message to %s: %s", s.Conn().RemotePeer(), err) - dead = true - go func() { - p.peerDead <- s.Conn().RemotePeer() - }() + case <-ctx.Done(): + return } } } -func (p *PubSub) processLoop() { +func (p *PubSub) processLoop(ctx context.Context) { for { select { @@ -153,12 +174,11 @@ func (p *PubSub) processLoop() { } messages := make(chan *RPC, 32) - go p.handleSendingMessages(s, messages) + go p.handleSendingMessages(ctx, s, messages) messages <- p.getHelloPacket() p.peers[pid] = messages - fmt.Println("added peer: ", pid) case pid := <-p.peerDead: delete(p.peers, pid) case sub := <-p.addSub: @@ -186,6 +206,9 @@ func (p *PubSub) processLoop() { log.Error("publishing message: ", err) } } + case <-ctx.Done(): + log.Info("pubsub processloop shutting down") + return } } } diff --git a/floodsub_test.go b/floodsub_test.go index 4f45aa5..8e0d621 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -12,11 +12,11 @@ import ( netutil "github.com/libp2p/go-libp2p/p2p/test/util" ) -func getNetHosts(t *testing.T, n int) []host.Host { +func getNetHosts(t *testing.T, ctx context.Context, n int) []host.Host { var out []host.Host for i := 0; i < n; i++ { - h := netutil.GenHostSwarm(t, context.Background()) + h := netutil.GenHostSwarm(t, ctx) out = append(out, h) } @@ -31,6 +31,22 @@ func connect(t *testing.T, a, b host.Host) { } } +func sparseConnect(t *testing.T, hosts []host.Host) { + for i, a := range hosts { + for j := 0; j < 3; j++ { + n := rand.Intn(len(hosts)) + if n == i { + j-- + continue + } + + b := hosts[n] + + connect(t, a, b) + } + } +} + func connectAll(t *testing.T, hosts []host.Host) { for i, a := range hosts { for j, b := range hosts { @@ -43,13 +59,20 @@ func connectAll(t *testing.T, hosts []host.Host) { } } -func TestBasicFloodsub(t *testing.T) { - hosts := getNetHosts(t, 20) - +func getPubsubs(ctx context.Context, hs []host.Host) []*PubSub { var psubs []*PubSub - for _, h := range hosts { - psubs = append(psubs, NewFloodSub(h)) + for _, h := range hs { + psubs = append(psubs, NewFloodSub(ctx, h)) } + return psubs +} + +func TestBasicFloodsub(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + hosts := getNetHosts(t, ctx, 20) + + psubs := getPubsubs(ctx, hosts) var msgs []<-chan *Message for _, ps := range psubs { @@ -61,26 +84,12 @@ func TestBasicFloodsub(t *testing.T) { msgs = append(msgs, subch) } - connectAll(t, hosts) + //connectAll(t, hosts) + sparseConnect(t, hosts) time.Sleep(time.Millisecond * 100) - psubs[0].Publish("foobar", []byte("ipfs rocks")) - - for i, resp := range msgs { - fmt.Printf("reading message from peer %d\n", i) - msg := <-resp - fmt.Printf("%s - %d: topic %s, from %s: %s\n", time.Now(), i, msg.Topic, msg.From, string(msg.Data)) - } - - psubs[2].Publish("foobar", []byte("libp2p is cool too")) - for i, resp := range msgs { - fmt.Printf("reading message from peer %d\n", i) - msg := <-resp - fmt.Printf("%s - %d: topic %s, from %s: %s\n", time.Now(), i, msg.Topic, msg.From, string(msg.Data)) - } for i := 0; i < 100; i++ { - fmt.Println("loop: ", i) msg := []byte(fmt.Sprintf("%d the flooooooood %d", i, i)) owner := rand.Intn(len(psubs)) @@ -94,4 +103,107 @@ func TestBasicFloodsub(t *testing.T) { } } } + +} + +func TestMultihops(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 6) + + psubs := getPubsubs(ctx, hosts) + + connect(t, hosts[0], hosts[1]) + connect(t, hosts[1], hosts[2]) + connect(t, hosts[2], hosts[3]) + connect(t, hosts[3], hosts[4]) + connect(t, hosts[4], hosts[5]) + + var msgChs []<-chan *Message + for i := 1; i < 6; i++ { + ch, err := psubs[i].Subscribe("foobar") + if err != nil { + t.Fatal(err) + } + msgChs = append(msgChs, ch) + } + + time.Sleep(time.Millisecond * 100) + + msg := []byte("i like cats") + err := psubs[0].Publish("foobar", msg) + if err != nil { + t.Fatal(err) + } + + // last node in the chain should get the message + select { + case out := <-msgChs[4]: + if !bytes.Equal(out.GetData(), msg) { + t.Fatal("got wrong data") + } + case <-time.After(time.Second * 5): + t.Fatal("timed out waiting for message") + } +} + +func TestReconnects(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 10) + + psubs := getPubsubs(ctx, hosts) + + connect(t, hosts[0], hosts[1]) + connect(t, hosts[0], hosts[2]) + + A, err := psubs[1].Subscribe("cats") + if err != nil { + t.Fatal(err) + } + + B, err := psubs[2].Subscribe("cats") + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 100) + + msg := []byte("apples and oranges") + err = psubs[0].Publish("cats", msg) + if err != nil { + t.Fatal(err) + } + + assertReceive(t, A, msg) + assertReceive(t, B, msg) + + hosts[2].Close() + + msg2 := []byte("potato") + err = psubs[0].Publish("cats", msg2) + if err != nil { + t.Fatal(err) + } + + assertReceive(t, A, msg2) + + time.Sleep(time.Millisecond * 50) + _, ok := psubs[0].peers[hosts[2].ID()] + if ok { + t.Fatal("shouldnt have this peer anymore") + } +} + +func assertReceive(t *testing.T, ch <-chan *Message, exp []byte) { + select { + case msg := <-ch: + if !bytes.Equal(msg.GetData(), exp) { + t.Fatalf("got wrong message, expected %s but got %s", string(exp), string(msg.GetData())) + } + case <-time.After(time.Second * 5): + t.Fatal("timed out waiting for message of: ", exp) + } }