diff --git a/comm.go b/comm.go index 5657abb..adda4d4 100644 --- a/comm.go +++ b/comm.go @@ -39,6 +39,10 @@ func (p *PubSub) handleNewStream(s inet.Stream) { // but it doesn't hurt to send it. s.Close() } + select { + case p.peerDead <- s.Conn().RemotePeer(): + case <-p.ctx.Done(): + } return } @@ -54,7 +58,6 @@ func (p *PubSub) handleNewStream(s inet.Stream) { } func (p *PubSub) handleSendingMessages(ctx context.Context, s inet.Stream, outgoing <-chan *RPC) { - var dead bool bufw := bufio.NewWriter(s) wc := ggio.NewDelimitedWriter(bufw) @@ -74,21 +77,16 @@ func (p *PubSub) handleSendingMessages(ctx context.Context, s inet.Stream, outgo if !ok { return } - if dead { - // continue in order to drain messages - continue - } err := writeMsg(&rpc.RPC) if err != nil { s.Reset() log.Warningf("writing message to %s: %s", s.Conn().RemotePeer(), err) - dead = true - go func() { - p.peerDead <- s.Conn().RemotePeer() - }() + select { + case p.peerDead <- s.Conn().RemotePeer(): + case <-ctx.Done(): + } } - case <-ctx.Done(): return } diff --git a/floodsub.go b/floodsub.go index 466e73b..5c56d1a 100644 --- a/floodsub.go +++ b/floodsub.go @@ -103,6 +103,14 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub { // processLoop handles all inputs arriving on the channels func (p *PubSub) processLoop(ctx context.Context) { + defer func() { + // Clean up go routines. + for _, ch := range p.peers { + close(ch) + } + p.peers = nil + p.topics = nil + }() for { select { case s := <-p.newPeers: diff --git a/floodsub_test.go b/floodsub_test.go index 23734e8..51ee0b8 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -550,6 +550,45 @@ func TestSubscribeMultipleTimes(t *testing.T) { } } +func TestPeerDisconnect(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts) + + connect(t, hosts[0], hosts[1]) + + _, err := psubs[0].Subscribe("foo") + if err != nil { + t.Fatal(err) + } + + _, err = psubs[1].Subscribe("foo") + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 10) + + peers := psubs[0].ListPeers("foo") + assertPeerList(t, peers, hosts[1].ID()) + for _, c := range hosts[1].Network().ConnsToPeer(hosts[0].ID()) { + streams, err := c.GetStreams() + if err != nil { + t.Fatal(err) + } + for _, s := range streams { + s.Close() + } + } + + time.Sleep(time.Millisecond * 10) + + peers = psubs[0].ListPeers("foo") + assertPeerList(t, peers) +} + func assertPeerList(t *testing.T, peers []peer.ID, expected ...peer.ID) { sort.Sort(peer.IDSlice(peers)) sort.Sort(peer.IDSlice(expected))