Merge pull request #39 from libp2p/fix/38

mark the peer as dead if the inbound stream closes
This commit is contained in:
Steven Allen 2017-10-14 16:26:20 -07:00 committed by GitHub
commit 6d14add4a3
3 changed files with 55 additions and 10 deletions

18
comm.go
View File

@ -39,6 +39,10 @@ func (p *PubSub) handleNewStream(s inet.Stream) {
// but it doesn't hurt to send it. // but it doesn't hurt to send it.
s.Close() s.Close()
} }
select {
case p.peerDead <- s.Conn().RemotePeer():
case <-p.ctx.Done():
}
return return
} }
@ -54,7 +58,6 @@ func (p *PubSub) handleNewStream(s inet.Stream) {
} }
func (p *PubSub) handleSendingMessages(ctx context.Context, s inet.Stream, outgoing <-chan *RPC) { func (p *PubSub) handleSendingMessages(ctx context.Context, s inet.Stream, outgoing <-chan *RPC) {
var dead bool
bufw := bufio.NewWriter(s) bufw := bufio.NewWriter(s)
wc := ggio.NewDelimitedWriter(bufw) wc := ggio.NewDelimitedWriter(bufw)
@ -74,21 +77,16 @@ func (p *PubSub) handleSendingMessages(ctx context.Context, s inet.Stream, outgo
if !ok { if !ok {
return return
} }
if dead {
// continue in order to drain messages
continue
}
err := writeMsg(&rpc.RPC) err := writeMsg(&rpc.RPC)
if err != nil { if err != nil {
s.Reset() s.Reset()
log.Warningf("writing message to %s: %s", s.Conn().RemotePeer(), err) log.Warningf("writing message to %s: %s", s.Conn().RemotePeer(), err)
dead = true select {
go func() { case p.peerDead <- s.Conn().RemotePeer():
p.peerDead <- s.Conn().RemotePeer() case <-ctx.Done():
}() }
} }
case <-ctx.Done(): case <-ctx.Done():
return return
} }

View File

@ -103,6 +103,14 @@ func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
// processLoop handles all inputs arriving on the channels // processLoop handles all inputs arriving on the channels
func (p *PubSub) processLoop(ctx context.Context) { func (p *PubSub) processLoop(ctx context.Context) {
defer func() {
// Clean up go routines.
for _, ch := range p.peers {
close(ch)
}
p.peers = nil
p.topics = nil
}()
for { for {
select { select {
case s := <-p.newPeers: case s := <-p.newPeers:

View File

@ -550,6 +550,45 @@ func TestSubscribeMultipleTimes(t *testing.T) {
} }
} }
func TestPeerDisconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hosts := getNetHosts(t, ctx, 2)
psubs := getPubsubs(ctx, hosts)
connect(t, hosts[0], hosts[1])
_, err := psubs[0].Subscribe("foo")
if err != nil {
t.Fatal(err)
}
_, err = psubs[1].Subscribe("foo")
if err != nil {
t.Fatal(err)
}
time.Sleep(time.Millisecond * 10)
peers := psubs[0].ListPeers("foo")
assertPeerList(t, peers, hosts[1].ID())
for _, c := range hosts[1].Network().ConnsToPeer(hosts[0].ID()) {
streams, err := c.GetStreams()
if err != nil {
t.Fatal(err)
}
for _, s := range streams {
s.Close()
}
}
time.Sleep(time.Millisecond * 10)
peers = psubs[0].ListPeers("foo")
assertPeerList(t, peers)
}
func assertPeerList(t *testing.T, peers []peer.ID, expected ...peer.ID) { func assertPeerList(t *testing.T, peers []peer.ID, expected ...peer.ID) {
sort.Sort(peer.IDSlice(peers)) sort.Sort(peer.IDSlice(peers))
sort.Sort(peer.IDSlice(expected)) sort.Sort(peer.IDSlice(expected))