fix: obey new stream timeout
This commit is contained in:
parent
fcf69647e9
commit
024f1af9ae
|
@ -631,10 +631,24 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
|
|||
}, nil
|
||||
}
|
||||
|
||||
selected, err := msmux.SelectOneOf(pidStrings, s)
|
||||
if err != nil {
|
||||
// Negotiate the protocol in the background, obeying the context.
|
||||
var selected string
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
selected, err = msmux.SelectOneOf(pidStrings, s)
|
||||
errCh <- err
|
||||
}()
|
||||
select {
|
||||
case err = <-errCh:
|
||||
if err != nil {
|
||||
s.Reset()
|
||||
return nil, err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
s.Reset()
|
||||
return nil, err
|
||||
// wait for the negotiation to cancel.
|
||||
<-errCh
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
selpid := protocol.ID(selected)
|
||||
|
|
|
@ -3,6 +3,7 @@ package basichost
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
@ -777,6 +778,49 @@ func TestHostAddrChangeDetection(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNegotiationCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
h1, h2 := getHostPair(ctx, t)
|
||||
defer h1.Close()
|
||||
defer h2.Close()
|
||||
|
||||
// pre-negotiation so we can make the negotiation hang.
|
||||
h1.Network().SetStreamHandler(func(s network.Stream) {
|
||||
<-ctx.Done() // wait till the test is done.
|
||||
s.Reset()
|
||||
})
|
||||
|
||||
ctx2, cancel2 := context.WithCancel(ctx)
|
||||
defer cancel2()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
s, err := h2.NewStream(ctx2, h1.ID(), "/testing")
|
||||
if s != nil {
|
||||
errCh <- fmt.Errorf("expected to fail negotiation")
|
||||
return
|
||||
}
|
||||
errCh <- err
|
||||
}()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
// ok, hung.
|
||||
}
|
||||
cancel2()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.Equal(t, err, context.Canceled)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
// failed to cancel
|
||||
t.Fatal("expected negotiation to be canceled")
|
||||
}
|
||||
}
|
||||
|
||||
func waitForAddrChangeEvent(ctx context.Context, sub event.Subscription, t *testing.T) event.EvtLocalAddressesUpdated {
|
||||
for {
|
||||
select {
|
||||
|
|
Loading…
Reference in New Issue