From e9e128cbf8e4752b33d158bd781f7c4de94703e1 Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Thu, 20 Feb 2020 10:57:02 +1100 Subject: [PATCH] Split Client dialers and listeners --- client.go | 61 +++++++++++------------ client_test.go | 2 +- socket.go | 129 +++++++++++++++---------------------------------- 3 files changed, 71 insertions(+), 121 deletions(-) diff --git a/client.go b/client.go index 8c1eb87f..60bd27e8 100644 --- a/client.go +++ b/client.go @@ -58,7 +58,8 @@ type Client struct { peerID PeerID defaultStorage *storage.Client onClose []func() - conns []socket + dialers []dialer + listeners []listener dhtServers []*dht.Server ipBlockList iplist.Ranger // Our BitTorrent protocol extension bytes, sent in our BT handshakes. @@ -92,7 +93,7 @@ func (cl *Client) PeerID() PeerID { } func (cl *Client) LocalPort() (port int) { - cl.eachListener(func(l socket) bool { + cl.eachListener(func(l listener) bool { _port := missinggo.AddrPort(l.Addr()) if _port == 0 { panic(l) @@ -227,28 +228,34 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) { } } - cl.conns, err = listenAll(cl.listenNetworks(), cl.config.ListenHost, cl.config.ListenPort, cl.config.ProxyURL, cl.firewallCallback) + sockets, err := listenAll(cl.listenNetworks(), cl.config.ListenHost, cl.config.ListenPort, cl.firewallCallback) if err != nil { return } + // Check for panics. cl.LocalPort() - for _, s := range cl.conns { + for _, _s := range sockets { + s := _s // Go is fucking retarded. + cl.onClose = append(cl.onClose, func() { s.Close() }) if peerNetworkEnabled(parseNetworkString(s.Addr().Network()), cl.config) { + cl.dialers = append(cl.dialers, s) + cl.listeners = append(cl.listeners, s) go cl.acceptConnections(s) } } go cl.forwardPort() if !cfg.NoDHT { - for _, s := range cl.conns { + for _, s := range sockets { if pc, ok := s.(net.PacketConn); ok { ds, err := cl.newDhtServer(pc) if err != nil { panic(err) } cl.dhtServers = append(cl.dhtServers, ds) + cl.onClose = append(cl.onClose, func() { ds.Close() }) } } } @@ -334,27 +341,17 @@ func (cl *Client) eachDhtServer(f func(*dht.Server)) { } } -func (cl *Client) closeSockets() { - cl.eachListener(func(l socket) bool { - l.Close() - return true - }) - cl.conns = nil -} - // Stops the client. All connections to peers are closed and all activity will // come to a halt. func (cl *Client) Close() { cl.lock() defer cl.unlock() cl.closed.Set() - cl.eachDhtServer(func(s *dht.Server) { s.Close() }) - cl.closeSockets() for _, t := range cl.torrents { t.close() } - for _, f := range cl.onClose { - f() + for i := range cl.onClose { + cl.onClose[len(cl.onClose)-1-i]() } cl.event.Broadcast() } @@ -521,18 +518,14 @@ func (cl *Client) dialFirst(ctx context.Context, addr string) (res dialResult) { func() { cl.lock() defer cl.unlock() - cl.eachListener(func(s socket) bool { + cl.eachDialer(func(s dialer) bool { func() { - network := s.Addr().Network() - if !peerNetworkEnabled(parseNetworkString(network), cl.config) { - return - } left++ //cl.logger.Printf("dialing %s on %s/%s", addr, s.Addr().Network(), s.Addr()) go func() { resCh <- dialResult{ cl.dialFromSocket(ctx, s, addr), - network, + s.LocalAddr().Network(), } }() }() @@ -566,11 +559,11 @@ func (cl *Client) dialFirst(ctx context.Context, addr string) (res dialResult) { return res } -func (cl *Client) dialFromSocket(ctx context.Context, s socket, addr string) net.Conn { - network := s.Addr().Network() +func (cl *Client) dialFromSocket(ctx context.Context, s dialer, addr string) net.Conn { + network := s.LocalAddr().Network() cte := cl.config.ConnTracker.Wait( ctx, - conntrack.Entry{network, s.Addr().String(), addr}, + conntrack.Entry{network, s.LocalAddr().String(), addr}, "dial torrent client", 0, ) @@ -1264,8 +1257,16 @@ func firstNotNil(ips ...net.IP) net.IP { return nil } -func (cl *Client) eachListener(f func(socket) bool) { - for _, s := range cl.conns { +func (cl *Client) eachDialer(f func(dialer) bool) { + for _, s := range cl.dialers { + if !f(s) { + break + } + } +} + +func (cl *Client) eachListener(f func(listener) bool) { + for _, s := range cl.listeners { if !f(s) { break } @@ -1273,7 +1274,7 @@ func (cl *Client) eachListener(f func(socket) bool) { } func (cl *Client) findListener(f func(net.Listener) bool) (ret net.Listener) { - cl.eachListener(func(l socket) bool { + cl.eachListener(func(l listener) bool { ret = l return !f(l) }) @@ -1310,7 +1311,7 @@ func (cl *Client) publicAddr(peer net.IP) IpPort { func (cl *Client) ListenAddrs() (ret []net.Addr) { cl.lock() defer cl.unlock() - cl.eachListener(func(l socket) bool { + cl.eachListener(func(l listener) bool { ret = append(ret, l.Addr()) return true }) diff --git a/client_test.go b/client_test.go index 0278d42c..9e070eab 100644 --- a/client_test.go +++ b/client_test.go @@ -910,7 +910,7 @@ func TestClientDynamicListenPortAllProtocols(t *testing.T) { defer cl.Close() port := cl.LocalPort() assert.NotEqual(t, 0, port) - cl.eachListener(func(s socket) bool { + cl.eachListener(func(s listener) bool { assert.Equal(t, port, missinggo.AddrPort(s.Addr())) return true }) diff --git a/socket.go b/socket.go index d61e3d32..c5f7dbcd 100644 --- a/socket.go +++ b/socket.go @@ -2,95 +2,72 @@ package torrent import ( "context" - "fmt" "net" - "net/url" "strconv" "github.com/anacrolix/missinggo" "github.com/anacrolix/missinggo/perf" "github.com/pkg/errors" - "golang.org/x/net/proxy" ) type dialer interface { dial(_ context.Context, addr string) (net.Conn, error) + LocalAddr() net.Addr +} + +type listener interface { + net.Listener } type socket interface { - net.Listener + listener dialer } -func getProxyDialer(proxyURL string) (proxy.Dialer, error) { - fixedURL, err := url.Parse(proxyURL) - if err != nil { - return nil, err - } - - return proxy.FromURL(fixedURL, proxy.Direct) -} - -func listen(n network, addr, proxyURL string, f firewallCallback) (socket, error) { +func listen(n network, addr string, f firewallCallback) (socket, error) { switch { case n.Tcp: - return listenTcp(n.String(), addr, proxyURL) + return listenTcp(n.String(), addr) case n.Udp: - return listenUtp(n.String(), addr, proxyURL, f) + return listenUtp(n.String(), addr, f) default: panic(n) } } -func listenTcp(network, address, proxyURL string) (s socket, err error) { +func listenTcp(network, address string) (s socket, err error) { l, err := net.Listen(network, address) - if err != nil { - return - } - defer func() { - if err != nil { - l.Close() - } - }() - - // If we don't need the proxy - then we should return default net.Dialer, - // otherwise, let's try to parse the proxyURL and return proxy.Dialer - if len(proxyURL) != 0 { - dl := disabledListener{l} - dialer, err := getProxyDialer(proxyURL) - if err != nil { - return nil, err - } - return tcpSocket{dl, func(ctx context.Context, addr string) (conn net.Conn, err error) { - defer perf.ScopeTimerErr(&err)() - return dialer.Dial(network, addr) - }}, nil - } - dialer := net.Dialer{} - return tcpSocket{l, func(ctx context.Context, addr string) (conn net.Conn, err error) { - defer perf.ScopeTimerErr(&err)() - return dialer.DialContext(ctx, network, addr) - }}, nil -} - -type disabledListener struct { - net.Listener -} - -func (dl disabledListener) Accept() (net.Conn, error) { - return nil, fmt.Errorf("tcp listener disabled due to proxy") + return tcpSocket{ + Listener: l, + network: network, + }, err } type tcpSocket struct { net.Listener - d func(ctx context.Context, addr string) (net.Conn, error) + network string + dialer net.Dialer } -func (me tcpSocket) dial(ctx context.Context, addr string) (net.Conn, error) { - return me.d(ctx, addr) +func (me tcpSocket) dial(ctx context.Context, addr string) (_ net.Conn, err error) { + defer perf.ScopeTimerErr(&err)() + return me.dialer.DialContext(ctx, me.network, addr) } -func listenAll(networks []network, getHost func(string) string, port int, proxyURL string, f firewallCallback) ([]socket, error) { +func (me tcpSocket) LocalAddr() net.Addr { + return tcpSocketLocalAddr{me.network, me.Listener.Addr().String()} +} + +type tcpSocketLocalAddr struct { + network string + s string +} + +func (me tcpSocketLocalAddr) Network() string { return me.network } + +func (me tcpSocketLocalAddr) String() string { return "" } + +func listenAll(networks []network, getHost func(string) string, port int, f firewallCallback) ([]socket, error) { if len(networks) == 0 { return nil, nil } @@ -99,7 +76,7 @@ func listenAll(networks []network, getHost func(string) string, port int, proxyU nahs = append(nahs, networkAndHost{n, getHost(n.String())}) } for { - ss, retry, err := listenAllRetry(nahs, port, proxyURL, f) + ss, retry, err := listenAllRetry(nahs, port, f) if !retry { return ss, err } @@ -111,10 +88,10 @@ type networkAndHost struct { Host string } -func listenAllRetry(nahs []networkAndHost, port int, proxyURL string, f firewallCallback) (ss []socket, retry bool, err error) { +func listenAllRetry(nahs []networkAndHost, port int, f firewallCallback) (ss []socket, retry bool, err error) { ss = make([]socket, 1, len(nahs)) portStr := strconv.FormatInt(int64(port), 10) - ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), proxyURL, f) + ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), f) if err != nil { return nil, false, errors.Wrap(err, "first listen") } @@ -128,7 +105,7 @@ func listenAllRetry(nahs []networkAndHost, port int, proxyURL string, f firewall }() portStr = strconv.FormatInt(int64(missinggo.AddrPort(ss[0].Addr())), 10) for _, nah := range nahs[1:] { - s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), proxyURL, f) + s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), f) if err != nil { return ss, missinggo.IsAddrInUse(err) && port == 0, @@ -141,45 +118,17 @@ func listenAllRetry(nahs []networkAndHost, port int, proxyURL string, f firewall type firewallCallback func(net.Addr) bool -func listenUtp(network, addr, proxyURL string, fc firewallCallback) (s socket, err error) { +func listenUtp(network, addr string, fc firewallCallback) (socket, error) { us, err := NewUtpSocket(network, addr, fc) - if err != nil { - return - } - - // If we don't need the proxy - then we should return default net.Dialer, - // otherwise, let's try to parse the proxyURL and return proxy.Dialer - if len(proxyURL) != 0 { - ds := disabledUtpSocket{us} - dialer, err := getProxyDialer(proxyURL) - if err != nil { - return nil, err - } - return utpSocketSocket{ds, network, dialer}, nil - } - - return utpSocketSocket{us, network, nil}, nil -} - -type disabledUtpSocket struct { - utpSocket -} - -func (ds disabledUtpSocket) Accept() (net.Conn, error) { - return nil, fmt.Errorf("utp listener disabled due to proxy") + return utpSocketSocket{us, network}, err } type utpSocketSocket struct { utpSocket network string - d proxy.Dialer } func (me utpSocketSocket) dial(ctx context.Context, addr string) (conn net.Conn, err error) { defer perf.ScopeTimerErr(&err)() - if me.d != nil { - return me.d.Dial(me.network, addr) - } - return me.utpSocket.DialContext(ctx, me.network, addr) }