From f56473fb013ededdc63021ca80467be784e45358 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Wed, 8 Aug 2018 13:36:01 -0700 Subject: [PATCH] make sure reset works on half-closed streams --- p2p/net/mock/mock_stream.go | 73 ++++++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 26 deletions(-) diff --git a/p2p/net/mock/mock_stream.go b/p2p/net/mock/mock_stream.go index 79a3834c..6d056c8e 100644 --- a/p2p/net/mock/mock_stream.go +++ b/p2p/net/mock/mock_stream.go @@ -22,7 +22,7 @@ type stream struct { close chan struct{} closed chan struct{} - state error + writeErr error protocol protocol.ID } @@ -56,7 +56,7 @@ func (s *stream) Write(p []byte) (n int, err error) { t := time.Now().Add(delay) select { case <-s.closed: // bail out if we're closing. - return 0, s.state + return 0, s.writeErr case s.toDeliver <- &transportObject{msg: p, arrivalTime: t}: } return len(p), nil @@ -76,30 +76,28 @@ func (s *stream) Close() error { default: } <-s.closed - if s.state != ErrClosed { - return s.state + if s.writeErr != ErrClosed { + return s.writeErr } return nil } func (s *stream) Reset() error { - // Cancel any pending writes. - s.write.Close() + // Cancel any pending reads/writes with an error. + s.write.CloseWithError(ErrReset) + s.read.CloseWithError(ErrReset) select { case s.reset <- struct{}{}: default: } <-s.closed - if s.state != ErrReset { - return s.state - } + + // No meaningful error case here. return nil } func (s *stream) teardown() { - s.write.Close() - // at this point, no streams are writing. s.conn.removeStream(s) @@ -151,20 +149,21 @@ func (s *stream) transport() { // writeBuf writes the contents of buf through to the s.Writer. // done only when arrival time makes sense. - drainBuf := func() { + drainBuf := func() error { if buf.Len() > 0 { _, err := s.write.Write(buf.Bytes()) if err != nil { - return + return err } buf.Reset() } + return nil } // deliverOrWait is a helper func that processes // an incoming packet. it waits until the arrival time, // and then writes things out. - deliverOrWait := func(o *transportObject) { + deliverOrWait := func(o *transportObject) error { buffered := len(o.msg) + buf.Len() // Yes, we can end up extending a timer multiple times if we @@ -189,43 +188,65 @@ func (s *stream) transport() { select { case <-timer.C: case <-s.reset: - s.reset <- struct{}{} - return + select { + case s.reset <- struct{}{}: + default: + } + return ErrReset + } + if err := drainBuf(); err != nil { + return err } - drainBuf() // write this message. _, err := s.write.Write(o.msg) if err != nil { - log.Error("mock_stream", err) + return err } } else { buf.Write(o.msg) } + return nil } for { // Reset takes precedent. select { case <-s.reset: - s.state = ErrReset - s.read.CloseWithError(ErrReset) + s.writeErr = ErrReset return default: } select { case <-s.reset: - s.state = ErrReset - s.read.CloseWithError(ErrReset) + s.writeErr = ErrReset return case <-s.close: - s.state = ErrClosed - drainBuf() + if err := drainBuf(); err != nil { + s.resetWith(err) + return + } + s.writeErr = s.write.Close() + if s.writeErr == nil { + s.writeErr = ErrClosed + } return case o := <-s.toDeliver: - deliverOrWait(o) + if err := deliverOrWait(o); err != nil { + s.resetWith(err) + return + } case <-timer.C: // ok, due to write it out. - drainBuf() + if err := drainBuf(); err != nil { + s.resetWith(err) + return + } } } } + +func (s *stream) resetWith(err error) { + s.write.CloseWithError(err) + s.read.CloseWithError(err) + s.writeErr = err +}