use tls.Conn.HandshakeContext instead of tls.Conn.Handshake (#106)

* use tls.Conn.HandshakeContext instead of tls.Conn.Handshake

* make sure that crypto/tls picks up the handshake ctx cancelation in tests
This commit is contained in:
Marten Seemann 2022-04-10 14:30:15 +01:00 committed by GitHub
parent b4e994803c
commit 7ee67dd8d4
2 changed files with 17 additions and 49 deletions

View File

@ -5,7 +5,6 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"net" "net"
"sync"
ci "github.com/libp2p/go-libp2p-core/crypto" ci "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peer"
@ -71,44 +70,8 @@ func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p pee
return cs, err return cs, err
} }
func (t *Transport) handshake( func (t *Transport) handshake(ctx context.Context, tlsConn *tls.Conn, keyCh <-chan ci.PubKey) (sec.SecureConn, error) {
ctx context.Context, if err := tlsConn.HandshakeContext(ctx); err != nil {
tlsConn *tls.Conn,
keyCh <-chan ci.PubKey,
) (sec.SecureConn, error) {
// There's no way to pass a context to tls.Conn.Handshake().
// See https://github.com/golang/go/issues/18482.
// Close the connection instead.
select {
case <-ctx.Done():
tlsConn.Close()
default:
}
done := make(chan struct{})
var wg sync.WaitGroup
// Ensure that we do not return before
// either being done or having a context
// cancellation.
defer wg.Wait()
defer close(done)
wg.Add(1)
go func() {
defer wg.Done()
select {
case <-done:
case <-ctx.Done():
tlsConn.Close()
}
}()
if err := tlsConn.Handshake(); err != nil {
// if the context was canceled, return the context error
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
}
return nil, err return nil, err
} }
@ -122,15 +85,7 @@ func (t *Transport) handshake(
return nil, errors.New("go-libp2p-tls BUG: expected remote pub key to be set") return nil, errors.New("go-libp2p-tls BUG: expected remote pub key to be set")
} }
conn, err := t.setupConn(tlsConn, remotePubKey) return t.setupConn(tlsConn, remotePubKey)
if err != nil {
// if the context was canceled, return the context error
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
}
return nil, err
}
return conn, nil
} }
func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.SecureConn, error) { func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.SecureConn, error) {

View File

@ -121,6 +121,19 @@ func TestHandshakeSucceeds(t *testing.T) {
}) })
} }
// crypto/tls' cancellation logic works by spinning up a separate Go routine that watches the ctx.
// If the ctx is canceled, it kills the handshake.
// We need to make sure that the handshake doesn't complete before that Go routine picks up the cancellation.
type delayedConn struct {
net.Conn
delay time.Duration
}
func (c *delayedConn) Read(b []byte) (int, error) {
time.Sleep(c.delay)
return c.Conn.Read(b)
}
func TestHandshakeConnectionCancelations(t *testing.T) { func TestHandshakeConnectionCancelations(t *testing.T) {
_, clientKey := createPeer(t) _, clientKey := createPeer(t)
serverID, serverKey := createPeer(t) serverID, serverKey := createPeer(t)
@ -152,7 +165,7 @@ func TestHandshakeConnectionCancelations(t *testing.T) {
go func() { go func() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
_, err := serverTransport.SecureInbound(ctx, serverInsecureConn, "") _, err := serverTransport.SecureInbound(ctx, &delayedConn{Conn: serverInsecureConn, delay: 5 * time.Millisecond}, "")
errChan <- err errChan <- err
}() }()
_, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID)