From e267d49e21d5ffdb6b6ebf8d5c08f068c6201061 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 5 Sep 2021 15:56:11 +0100 Subject: [PATCH] add the peer ID to SecureInbound --- p2p/net/upgrader/listener.go | 6 ++++-- p2p/net/upgrader/listener_test.go | 7 +++++-- p2p/net/upgrader/upgrader.go | 25 ++++++++++++++----------- p2p/net/upgrader/upgrader_test.go | 4 +++- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go index f03e42d8..d25a6565 100644 --- a/p2p/net/upgrader/listener.go +++ b/p2p/net/upgrader/listener.go @@ -5,9 +5,11 @@ import ( "fmt" "sync" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/transport" + logging "github.com/ipfs/go-log" tec "github.com/jbenet/go-temp-err-catcher" - "github.com/libp2p/go-libp2p-core/transport" manet "github.com/multiformats/go-multiaddr/net" ) @@ -106,7 +108,7 @@ func (l *listener) handleIncoming() { ctx, cancel := context.WithTimeout(l.ctx, transport.AcceptTimeout) defer cancel() - conn, err := l.upgrader.UpgradeInbound(ctx, l.transport, maconn) + conn, err := l.upgrader.Upgrade(ctx, l.transport, maconn, network.DirInbound, "") if err != nil { // Don't bother bubbling this up. We just failed // to completely negotiate the connection. diff --git a/p2p/net/upgrader/listener_test.go b/p2p/net/upgrader/listener_test.go index 73946865..38d0d91d 100644 --- a/p2p/net/upgrader/listener_test.go +++ b/p2p/net/upgrader/listener_test.go @@ -12,6 +12,7 @@ import ( "github.com/libp2p/go-libp2p-core/sec" "github.com/libp2p/go-libp2p-core/transport" st "github.com/libp2p/go-libp2p-transport-upgrader" + ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" @@ -26,8 +27,10 @@ type MuxAdapter struct { tpt sec.SecureTransport } -func (mux *MuxAdapter) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, bool, error) { - sconn, err := mux.tpt.SecureInbound(ctx, insecure) +var _ sec.SecureMuxer = &MuxAdapter{} + +func (mux *MuxAdapter) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, bool, error) { + sconn, err := mux.tpt.SecureInbound(ctx, insecure, p) return sconn, true, err } diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index de5de08b..22e78259 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -13,7 +13,7 @@ import ( ipnet "github.com/libp2p/go-libp2p-core/pnet" "github.com/libp2p/go-libp2p-core/sec" "github.com/libp2p/go-libp2p-core/transport" - "github.com/libp2p/go-libp2p-pnet" + pnet "github.com/libp2p/go-libp2p-pnet" manet "github.com/multiformats/go-multiaddr/net" ) @@ -51,20 +51,23 @@ func (u *Upgrader) UpgradeListener(t transport.Transport, list manet.Listener) t // UpgradeOutbound upgrades the given outbound multiaddr-net connection into a // full libp2p-transport connection. +// Deprecated: use Upgrade instead. func (u *Upgrader) UpgradeOutbound(ctx context.Context, t transport.Transport, maconn manet.Conn, p peer.ID) (transport.CapableConn, error) { - if p == "" { - return nil, ErrNilPeer - } - return u.upgrade(ctx, t, maconn, p, network.DirOutbound) + return u.Upgrade(ctx, t, maconn, network.DirOutbound, p) } // UpgradeInbound upgrades the given inbound multiaddr-net connection into a // full libp2p-transport connection. +// Deprecated: use Upgrade instead. func (u *Upgrader) UpgradeInbound(ctx context.Context, t transport.Transport, maconn manet.Conn) (transport.CapableConn, error) { - return u.upgrade(ctx, t, maconn, "", network.DirInbound) + return u.Upgrade(ctx, t, maconn, network.DirInbound, "") } -func (u *Upgrader) upgrade(ctx context.Context, t transport.Transport, maconn manet.Conn, p peer.ID, dir network.Direction) (transport.CapableConn, error) { +// Upgrade upgrades the multiaddr/net connection into a full libp2p-transport connection. +func (u *Upgrader) Upgrade(ctx context.Context, t transport.Transport, maconn manet.Conn, dir network.Direction, p peer.ID) (transport.CapableConn, error) { + if dir == network.DirOutbound && p == "" { + return nil, ErrNilPeer + } var stat network.Stat if cs, ok := maconn.(network.ConnStat); ok { stat = cs.Stat() @@ -83,7 +86,7 @@ func (u *Upgrader) upgrade(ctx context.Context, t transport.Transport, maconn ma return nil, ipnet.ErrNotInPrivateNetwork } - sconn, server, err := u.setupSecurity(ctx, conn, p) + sconn, server, err := u.setupSecurity(ctx, conn, p, dir) if err != nil { conn.Close() return nil, fmt.Errorf("failed to negotiate security protocol: %s", err) @@ -115,9 +118,9 @@ func (u *Upgrader) upgrade(ctx context.Context, t transport.Transport, maconn ma return tc, nil } -func (u *Upgrader) setupSecurity(ctx context.Context, conn net.Conn, p peer.ID) (sec.SecureConn, bool, error) { - if p == "" { - return u.Secure.SecureInbound(ctx, conn) +func (u *Upgrader) setupSecurity(ctx context.Context, conn net.Conn, p peer.ID, dir network.Direction) (sec.SecureConn, bool, error) { + if dir == network.DirInbound { + return u.Secure.SecureInbound(ctx, conn, p) } return u.Secure.SecureOutbound(ctx, conn, p) } diff --git a/p2p/net/upgrader/upgrader_test.go b/p2p/net/upgrader/upgrader_test.go index 795a0112..4bebf25e 100644 --- a/p2p/net/upgrader/upgrader_test.go +++ b/p2p/net/upgrader/upgrader_test.go @@ -8,12 +8,14 @@ import ( "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/mux" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/sec/insecure" "github.com/libp2p/go-libp2p-core/test" "github.com/libp2p/go-libp2p-core/transport" mplex "github.com/libp2p/go-libp2p-mplex" st "github.com/libp2p/go-libp2p-transport-upgrader" + ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" @@ -104,7 +106,7 @@ func dial(t *testing.T, upgrader *st.Upgrader, raddr ma.Multiaddr, p peer.ID) (t if err != nil { return nil, err } - return upgrader.UpgradeOutbound(context.Background(), nil, macon, p) + return upgrader.Upgrade(context.Background(), nil, macon, network.DirOutbound, p) } func TestOutboundConnectionGating(t *testing.T) {