From a109f8da858ed65b0bc16d5793af724572db89b8 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Thu, 21 Jun 2018 14:13:48 -0700 Subject: [PATCH] expose methods from underlying connection types This sucks but I can't think of a better way to do this. We really do want to expose these features and doing so through type assertions is very go-like. --- net.go | 88 ++++++++++++++++++++++++++++++++++++++--------------- net_test.go | 4 +++ 2 files changed, 68 insertions(+), 24 deletions(-) diff --git a/net.go b/net.go index 104a767..a6d0e76 100644 --- a/net.go +++ b/net.go @@ -28,8 +28,64 @@ type Conn interface { RemoteMultiaddr() ma.Multiaddr } -// WrapNetConn wraps a net.Conn object with a Multiaddr -// friendly Conn. +type halfOpen interface { + net.Conn + CloseRead() error + CloseWrite() error +} + +func wrap(nconn net.Conn, laddr, raddr ma.Multiaddr) Conn { + endpts := maEndpoints{ + laddr: laddr, + raddr: raddr, + } + // This sucks. However, it's the only way to reliably expose the + // underlying methods. This way, users that need access to, e.g., + // CloseRead and CloseWrite, can do so via type assertions. + switch nconn := nconn.(type) { + case *net.TCPConn: + return &struct { + *net.TCPConn + maEndpoints + }{nconn, endpts} + case *net.UDPConn: + return &struct { + *net.UDPConn + maEndpoints + }{nconn, endpts} + case *net.IPConn: + return &struct { + *net.IPConn + maEndpoints + }{nconn, endpts} + case *net.UnixConn: + return &struct { + *net.UnixConn + maEndpoints + }{nconn, endpts} + case halfOpen: + return &struct { + halfOpen + maEndpoints + }{nconn, endpts} + default: + return &struct { + net.Conn + maEndpoints + }{nconn, endpts} + } +} + +// WrapNetConn wraps a net.Conn object with a Multiaddr friendly Conn. +// +// This function does it's best to avoid "hiding" methods exposed by the wrapped +// type. Guarantees: +// +// * If the wrapped connection exposes the "half-open" closer methods +// (CloseWrite, CloseRead), these will be available on the wrapped connection +// via type assertions. +// * If the wrapped connection is a UnixConn, IPConn, TCPConn, or UDPConn, all +// methods on these wrapped connections will be available via type assertions. func WrapNetConn(nconn net.Conn) (Conn, error) { if nconn == nil { return nil, fmt.Errorf("failed to convert nconn.LocalAddr: nil") @@ -45,30 +101,23 @@ func WrapNetConn(nconn net.Conn) (Conn, error) { return nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err) } - return &maConn{ - Conn: nconn, - laddr: laddr, - raddr: raddr, - }, nil + return wrap(nconn, laddr, raddr), nil } -// maConn implements the Conn interface. It's a thin wrapper -// around a net.Conn -type maConn struct { - net.Conn +type maEndpoints struct { laddr ma.Multiaddr raddr ma.Multiaddr } // LocalMultiaddr returns the local address associated with // this connection -func (c *maConn) LocalMultiaddr() ma.Multiaddr { +func (c *maEndpoints) LocalMultiaddr() ma.Multiaddr { return c.laddr } // RemoteMultiaddr returns the remote address associated with // this connection -func (c *maConn) RemoteMultiaddr() ma.Multiaddr { +func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr { return c.raddr } @@ -135,12 +184,7 @@ func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, er return nil, err } } - - return &maConn{ - Conn: nconn, - laddr: local, - raddr: remote, - }, nil + return wrap(nconn, local, remote), nil } // Dial connects to a remote address. It uses an underlying net.Conn, @@ -204,11 +248,7 @@ func (l *maListener) Accept() (Conn, error) { return nil, fmt.Errorf("failed to convert connn.RemoteAddr: %s", err) } - return &maConn{ - Conn: nconn, - laddr: l.laddr, - raddr: raddr, - }, nil + return wrap(nconn, l.laddr, raddr), nil } // Multiaddr returns the listener's (local) Multiaddr. diff --git a/net_test.go b/net_test.go index a073e3b..9454fe2 100644 --- a/net_test.go +++ b/net_test.go @@ -407,12 +407,14 @@ func TestWrapNetConn(t *testing.T) { defer wg.Done() cB, err := listener.Accept() checkErr(err, "failed to accept") + _ = cB.(halfOpen) cB.Close() }() cA, err := net.Dial("tcp", listener.Addr().String()) checkErr(err, "failed to dial") defer cA.Close() + _ = cA.(halfOpen) lmaddr, err := FromNetAddr(cA.LocalAddr()) checkErr(err, "failed to get local addr") @@ -422,6 +424,8 @@ func TestWrapNetConn(t *testing.T) { mcA, err := WrapNetConn(cA) checkErr(err, "failed to wrap conn") + _ = mcA.(halfOpen) + if mcA.LocalAddr().String() != cA.LocalAddr().String() { t.Error("wrapped conn local addr differs") }