diff --git a/floodsub_test.go b/floodsub_test.go index 23734e8..51ee0b8 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -550,6 +550,45 @@ func TestSubscribeMultipleTimes(t *testing.T) { } } +func TestPeerDisconnect(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getNetHosts(t, ctx, 2) + psubs := getPubsubs(ctx, hosts) + + connect(t, hosts[0], hosts[1]) + + _, err := psubs[0].Subscribe("foo") + if err != nil { + t.Fatal(err) + } + + _, err = psubs[1].Subscribe("foo") + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 10) + + peers := psubs[0].ListPeers("foo") + assertPeerList(t, peers, hosts[1].ID()) + for _, c := range hosts[1].Network().ConnsToPeer(hosts[0].ID()) { + streams, err := c.GetStreams() + if err != nil { + t.Fatal(err) + } + for _, s := range streams { + s.Close() + } + } + + time.Sleep(time.Millisecond * 10) + + peers = psubs[0].ListPeers("foo") + assertPeerList(t, peers) +} + func assertPeerList(t *testing.T, peers []peer.ID, expected ...peer.ID) { sort.Sort(peer.IDSlice(peers)) sort.Sort(peer.IDSlice(expected))