diff --git a/p2p/security/tls/transport.go b/p2p/security/tls/transport.go index 6be3c79a..85e75a15 100644 --- a/p2p/security/tls/transport.go +++ b/p2p/security/tls/transport.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "errors" "net" - "sync" ci "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" @@ -71,44 +70,8 @@ func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p pee return cs, err } -func (t *Transport) handshake( - ctx context.Context, - tlsConn *tls.Conn, - keyCh <-chan ci.PubKey, -) (sec.SecureConn, error) { - // There's no way to pass a context to tls.Conn.Handshake(). - // See https://github.com/golang/go/issues/18482. - // Close the connection instead. - select { - case <-ctx.Done(): - tlsConn.Close() - default: - } - - done := make(chan struct{}) - var wg sync.WaitGroup - - // Ensure that we do not return before - // either being done or having a context - // cancellation. - defer wg.Wait() - defer close(done) - - wg.Add(1) - go func() { - defer wg.Done() - select { - case <-done: - case <-ctx.Done(): - tlsConn.Close() - } - }() - - if err := tlsConn.Handshake(); err != nil { - // if the context was canceled, return the context error - if ctxErr := ctx.Err(); ctxErr != nil { - return nil, ctxErr - } +func (t *Transport) handshake(ctx context.Context, tlsConn *tls.Conn, keyCh <-chan ci.PubKey) (sec.SecureConn, error) { + if err := tlsConn.HandshakeContext(ctx); err != nil { return nil, err } @@ -122,15 +85,7 @@ func (t *Transport) handshake( return nil, errors.New("go-libp2p-tls BUG: expected remote pub key to be set") } - conn, err := t.setupConn(tlsConn, remotePubKey) - if err != nil { - // if the context was canceled, return the context error - if ctxErr := ctx.Err(); ctxErr != nil { - return nil, ctxErr - } - return nil, err - } - return conn, nil + return t.setupConn(tlsConn, remotePubKey) } func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.SecureConn, error) { diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index 8106f815..a290ed6f 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -121,6 +121,19 @@ func TestHandshakeSucceeds(t *testing.T) { }) } +// crypto/tls' cancellation logic works by spinning up a separate Go routine that watches the ctx. +// If the ctx is canceled, it kills the handshake. +// We need to make sure that the handshake doesn't complete before that Go routine picks up the cancellation. +type delayedConn struct { + net.Conn + delay time.Duration +} + +func (c *delayedConn) Read(b []byte) (int, error) { + time.Sleep(c.delay) + return c.Conn.Read(b) +} + func TestHandshakeConnectionCancelations(t *testing.T) { _, clientKey := createPeer(t) serverID, serverKey := createPeer(t) @@ -152,7 +165,7 @@ func TestHandshakeConnectionCancelations(t *testing.T) { go func() { ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err := serverTransport.SecureInbound(ctx, serverInsecureConn, "") + _, err := serverTransport.SecureInbound(ctx, &delayedConn{Conn: serverInsecureConn, delay: 5 * time.Millisecond}, "") errChan <- err }() _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID)