From c9e4735369dc100e8a62ca4a408c1b6738c43f40 Mon Sep 17 00:00:00 2001 From: Guilhem Fanton Date: Wed, 28 Oct 2020 14:39:45 +0100 Subject: [PATCH] fix: stream leaks Signed-off-by: Guilhem Fanton --- client.go | 7 +++---- svc.go | 6 ++---- svc_test.go | 2 +- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/client.go b/client.go index 6cd091b..219f688 100644 --- a/client.go +++ b/client.go @@ -68,7 +68,7 @@ func (rp *rendezvousPoint) Register(ctx context.Context, ns string, ttl int) (ti if err != nil { return 0, err } - defer s.Close() + defer s.Reset() r := ggio.NewDelimitedReader(s, inet.MessageSizeMax) w := ggio.NewDelimitedWriter(s) @@ -165,7 +165,7 @@ func (rp *rendezvousPoint) Discover(ctx context.Context, ns string, limit int, c if err != nil { return nil, nil, err } - defer s.Close() + defer s.Reset() r := ggio.NewDelimitedReader(s, inet.MessageSizeMax) w := ggio.NewDelimitedWriter(s) @@ -174,7 +174,6 @@ func (rp *rendezvousPoint) Discover(ctx context.Context, ns string, limit int, c } func discoverQuery(ns string, limit int, cookie []byte, r ggio.Reader, w ggio.Writer) ([]Registration, []byte, error) { - req := newDiscoverMessage(ns, limit, cookie) err := w.WriteMsg(req) if err != nil { @@ -222,7 +221,7 @@ func (rp *rendezvousPoint) DiscoverAsync(ctx context.Context, ns string) (<-chan } func discoverAsync(ctx context.Context, ns string, s inet.Stream, ch chan Registration) { - defer s.Close() + defer s.Reset() defer close(ch) r := ggio.NewDelimitedReader(s, inet.MessageSizeMax) diff --git a/svc.go b/svc.go index e7204b1..11ee5e6 100644 --- a/svc.go +++ b/svc.go @@ -38,6 +38,8 @@ func NewRendezvousService(host host.Host, db db.DB, rzs ...RendezvousSync) *Rend } func (rz *RendezvousService) handleStream(s inet.Stream) { + defer s.Reset() + pid := s.Conn().RemotePeer() log.Debugf("New stream from %s", pid.Pretty()) @@ -50,7 +52,6 @@ func (rz *RendezvousService) handleStream(s inet.Stream) { err := r.ReadMsg(&req) if err != nil { - s.Reset() return } @@ -63,7 +64,6 @@ func (rz *RendezvousService) handleStream(s inet.Stream) { err = w.WriteMsg(&res) if err != nil { log.Debugf("Error writing response: %s", err.Error()) - s.Reset() return } @@ -80,13 +80,11 @@ func (rz *RendezvousService) handleStream(s inet.Stream) { err = w.WriteMsg(&res) if err != nil { log.Debugf("Error writing response: %s", err.Error()) - s.Reset() return } default: log.Debugf("Unexpected message: %s", t.String()) - s.Reset() return } } diff --git a/svc_test.go b/svc_test.go index cc311d1..44751fb 100644 --- a/svc_test.go +++ b/svc_test.go @@ -140,7 +140,7 @@ func TestSVCRegistrationAndDiscovery(t *testing.T) { t.Fatal(err) } if len(rrs) != 3 { - t.Fatal("Expected 3 registrations") + t.Fatalf("Expected 3 registrations, got %d", len(rrs)) } for j, rr := range rrs {