fix: stream leaks

Signed-off-by: Guilhem Fanton <guilhem.fanton@gmail.com>
This commit is contained in:
Guilhem Fanton 2020-10-28 14:39:45 +01:00
parent 6c6522afbc
commit c9e4735369
3 changed files with 6 additions and 9 deletions

View File

@ -68,7 +68,7 @@ func (rp *rendezvousPoint) Register(ctx context.Context, ns string, ttl int) (ti
if err != nil {
return 0, err
}
defer s.Close()
defer s.Reset()
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
w := ggio.NewDelimitedWriter(s)
@ -165,7 +165,7 @@ func (rp *rendezvousPoint) Discover(ctx context.Context, ns string, limit int, c
if err != nil {
return nil, nil, err
}
defer s.Close()
defer s.Reset()
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
w := ggio.NewDelimitedWriter(s)
@ -174,7 +174,6 @@ func (rp *rendezvousPoint) Discover(ctx context.Context, ns string, limit int, c
}
func discoverQuery(ns string, limit int, cookie []byte, r ggio.Reader, w ggio.Writer) ([]Registration, []byte, error) {
req := newDiscoverMessage(ns, limit, cookie)
err := w.WriteMsg(req)
if err != nil {
@ -222,7 +221,7 @@ func (rp *rendezvousPoint) DiscoverAsync(ctx context.Context, ns string) (<-chan
}
func discoverAsync(ctx context.Context, ns string, s inet.Stream, ch chan Registration) {
defer s.Close()
defer s.Reset()
defer close(ch)
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)

6
svc.go
View File

@ -38,6 +38,8 @@ func NewRendezvousService(host host.Host, db db.DB, rzs ...RendezvousSync) *Rend
}
func (rz *RendezvousService) handleStream(s inet.Stream) {
defer s.Reset()
pid := s.Conn().RemotePeer()
log.Debugf("New stream from %s", pid.Pretty())
@ -50,7 +52,6 @@ func (rz *RendezvousService) handleStream(s inet.Stream) {
err := r.ReadMsg(&req)
if err != nil {
s.Reset()
return
}
@ -63,7 +64,6 @@ func (rz *RendezvousService) handleStream(s inet.Stream) {
err = w.WriteMsg(&res)
if err != nil {
log.Debugf("Error writing response: %s", err.Error())
s.Reset()
return
}
@ -80,13 +80,11 @@ func (rz *RendezvousService) handleStream(s inet.Stream) {
err = w.WriteMsg(&res)
if err != nil {
log.Debugf("Error writing response: %s", err.Error())
s.Reset()
return
}
default:
log.Debugf("Unexpected message: %s", t.String())
s.Reset()
return
}
}

View File

@ -140,7 +140,7 @@ func TestSVCRegistrationAndDiscovery(t *testing.T) {
t.Fatal(err)
}
if len(rrs) != 3 {
t.Fatal("Expected 3 registrations")
t.Fatalf("Expected 3 registrations, got %d", len(rrs))
}
for j, rr := range rrs {