use tls.Conn.HandshakeContext instead of tls.Conn.Handshake (#106)
* use tls.Conn.HandshakeContext instead of tls.Conn.Handshake * make sure that crypto/tls picks up the handshake ctx cancelation in tests
This commit is contained in:
parent
b4e994803c
commit
7ee67dd8d4
|
@ -5,7 +5,6 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
|
|
||||||
ci "github.com/libp2p/go-libp2p-core/crypto"
|
ci "github.com/libp2p/go-libp2p-core/crypto"
|
||||||
"github.com/libp2p/go-libp2p-core/peer"
|
"github.com/libp2p/go-libp2p-core/peer"
|
||||||
|
@ -71,44 +70,8 @@ func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p pee
|
||||||
return cs, err
|
return cs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) handshake(
|
func (t *Transport) handshake(ctx context.Context, tlsConn *tls.Conn, keyCh <-chan ci.PubKey) (sec.SecureConn, error) {
|
||||||
ctx context.Context,
|
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||||
tlsConn *tls.Conn,
|
|
||||||
keyCh <-chan ci.PubKey,
|
|
||||||
) (sec.SecureConn, error) {
|
|
||||||
// There's no way to pass a context to tls.Conn.Handshake().
|
|
||||||
// See https://github.com/golang/go/issues/18482.
|
|
||||||
// Close the connection instead.
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
tlsConn.Close()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
|
|
||||||
// Ensure that we do not return before
|
|
||||||
// either being done or having a context
|
|
||||||
// cancellation.
|
|
||||||
defer wg.Wait()
|
|
||||||
defer close(done)
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-ctx.Done():
|
|
||||||
tlsConn.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
|
||||||
// if the context was canceled, return the context error
|
|
||||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
|
||||||
return nil, ctxErr
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,15 +85,7 @@ func (t *Transport) handshake(
|
||||||
return nil, errors.New("go-libp2p-tls BUG: expected remote pub key to be set")
|
return nil, errors.New("go-libp2p-tls BUG: expected remote pub key to be set")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := t.setupConn(tlsConn, remotePubKey)
|
return t.setupConn(tlsConn, remotePubKey)
|
||||||
if err != nil {
|
|
||||||
// if the context was canceled, return the context error
|
|
||||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
|
||||||
return nil, ctxErr
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.SecureConn, error) {
|
func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.SecureConn, error) {
|
||||||
|
|
|
@ -121,6 +121,19 @@ func TestHandshakeSucceeds(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// crypto/tls' cancellation logic works by spinning up a separate Go routine that watches the ctx.
|
||||||
|
// If the ctx is canceled, it kills the handshake.
|
||||||
|
// We need to make sure that the handshake doesn't complete before that Go routine picks up the cancellation.
|
||||||
|
type delayedConn struct {
|
||||||
|
net.Conn
|
||||||
|
delay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *delayedConn) Read(b []byte) (int, error) {
|
||||||
|
time.Sleep(c.delay)
|
||||||
|
return c.Conn.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandshakeConnectionCancelations(t *testing.T) {
|
func TestHandshakeConnectionCancelations(t *testing.T) {
|
||||||
_, clientKey := createPeer(t)
|
_, clientKey := createPeer(t)
|
||||||
serverID, serverKey := createPeer(t)
|
serverID, serverKey := createPeer(t)
|
||||||
|
@ -152,7 +165,7 @@ func TestHandshakeConnectionCancelations(t *testing.T) {
|
||||||
go func() {
|
go func() {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
cancel()
|
cancel()
|
||||||
_, err := serverTransport.SecureInbound(ctx, serverInsecureConn, "")
|
_, err := serverTransport.SecureInbound(ctx, &delayedConn{Conn: serverInsecureConn, delay: 5 * time.Millisecond}, "")
|
||||||
errChan <- err
|
errChan <- err
|
||||||
}()
|
}()
|
||||||
_, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID)
|
_, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID)
|
||||||
|
|
Loading…
Reference in New Issue