diff --git a/config.go b/config.go index 398fdd33..11e70d68 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,7 @@ package torrent import ( + "context" "net" "net/http" "net/url" @@ -90,6 +91,10 @@ type ClientConfig struct { // Defines proxy for HTTP requests, such as for trackers. It's commonly set from the result of // "net/http".ProxyURL(HTTPProxy). HTTPProxy func(*http.Request) (*url.URL, error) + // Defines DialContext func to use for HTTP tracker announcements + TrackerDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + // Defines ListenPacket func to use for UDP tracker announcements + TrackerListenPacket func(network, addr string) (net.PacketConn, error) // Takes a tracker's hostname and requests DNS A and AAAA records. // Used in case DNS lookups require a special setup (i.e., dns-over-https) LookupTrackerIp func(*url.URL) ([]net.IP, error) diff --git a/tracker/client.go b/tracker/client.go index 1df0aa43..2558da70 100644 --- a/tracker/client.go +++ b/tracker/client.go @@ -2,6 +2,7 @@ package tracker import ( "context" + "net" "net/url" "github.com/anacrolix/log" @@ -19,8 +20,9 @@ type AnnounceOpt = trHttp.AnnounceOpt type NewClientOpts struct { Http trHttp.NewClientOpts // Overrides the network in the scheme. Probably a legacy thing. - UdpNetwork string - Logger log.Logger + UdpNetwork string + Logger log.Logger + ListenPacket func(network, addr string) (net.PacketConn, error) } func NewClient(urlStr string, opts NewClientOpts) (Client, error) { @@ -37,9 +39,10 @@ func NewClient(urlStr string, opts NewClientOpts) (Client, error) { network = opts.UdpNetwork } cc, err := udp.NewConnClient(udp.NewConnClientOpts{ - Network: network, - Host: _url.Host, - Logger: opts.Logger, + Network: network, + Host: _url.Host, + Logger: opts.Logger, + ListenPacket: opts.ListenPacket, }) if err != nil { return nil, err diff --git a/tracker/http/client.go b/tracker/http/client.go index d0c27a02..cd18f65a 100644 --- a/tracker/http/client.go +++ b/tracker/http/client.go @@ -1,7 +1,9 @@ package http import ( + "context" "crypto/tls" + "net" "net/http" "net/url" ) @@ -12,9 +14,11 @@ type Client struct { } type ProxyFunc func(*http.Request) (*url.URL, error) +type DialContextFunc func(ctx context.Context, network, addr string) (net.Conn, error) type NewClientOpts struct { Proxy ProxyFunc + DialContext DialContextFunc ServerName string AllowKeepAlive bool } @@ -24,6 +28,7 @@ func NewClient(url_ *url.URL, opts NewClientOpts) Client { url_: url_, hc: &http.Client{ Transport: &http.Transport{ + DialContext: opts.DialContext, Proxy: opts.Proxy, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, diff --git a/tracker/tracker.go b/tracker/tracker.go index 7c3a5f6d..a9721c77 100644 --- a/tracker/tracker.go +++ b/tracker/tracker.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net" "net/http" "net/url" "time" @@ -33,13 +34,15 @@ type AnnounceEvent = udp.AnnounceEvent var ErrBadScheme = errors.New("unknown scheme") type Announce struct { - TrackerUrl string - Request AnnounceRequest - HostHeader string - HTTPProxy func(*http.Request) (*url.URL, error) - ServerName string - UserAgent string - UdpNetwork string + TrackerUrl string + Request AnnounceRequest + HostHeader string + HTTPProxy func(*http.Request) (*url.URL, error) + DialContext func(ctx context.Context, network, addr string) (net.Conn, error) + ListenPacket func(network, addr string) (net.PacketConn, error) + ServerName string + UserAgent string + UdpNetwork string // If the port is zero, it's assumed to be the same as the Request.Port. ClientIp4 krpc.NodeAddr // If the port is zero, it's assumed to be the same as the Request.Port. @@ -54,11 +57,13 @@ const DefaultTrackerAnnounceTimeout = 15 * time.Second func (me Announce) Do() (res AnnounceResponse, err error) { cl, err := NewClient(me.TrackerUrl, NewClientOpts{ Http: trHttp.NewClientOpts{ - Proxy: me.HTTPProxy, - ServerName: me.ServerName, + Proxy: me.HTTPProxy, + DialContext: me.DialContext, + ServerName: me.ServerName, }, - UdpNetwork: me.UdpNetwork, - Logger: me.Logger.WithContextValue(fmt.Sprintf("tracker client for %q", me.TrackerUrl)), + UdpNetwork: me.UdpNetwork, + Logger: me.Logger.WithContextValue(fmt.Sprintf("tracker client for %q", me.TrackerUrl)), + ListenPacket: me.ListenPacket, }) if err != nil { return diff --git a/tracker/udp/conn-client.go b/tracker/udp/conn-client.go index 2e1cac79..f500fc3f 100644 --- a/tracker/udp/conn-client.go +++ b/tracker/udp/conn-client.go @@ -9,6 +9,8 @@ import ( "github.com/anacrolix/missinggo/v2" ) +type listenPacketFunc func(network, addr string) (net.PacketConn, error) + type NewConnClientOpts struct { // The network to operate to use, such as "udp4", "udp", "udp6". Network string @@ -18,6 +20,8 @@ type NewConnClientOpts struct { Ipv6 *bool // Logger to use for internal errors. Logger log.Logger + // Custom function to use as a substitute for net.ListenPacket + ListenPacket listenPacketFunc } // Manages a Client with a specific connection. @@ -80,7 +84,13 @@ func (me clientWriter) Write(p []byte) (n int, err error) { } func NewConnClient(opts NewConnClientOpts) (cc *ConnClient, err error) { - conn, err := net.ListenPacket(opts.Network, ":0") + var conn net.PacketConn + if opts.ListenPacket != nil { + conn, err = opts.ListenPacket(opts.Network, ":0") + } else { + conn, err = net.ListenPacket(opts.Network, ":0") + } + if err != nil { return } diff --git a/tracker_scraper.go b/tracker_scraper.go index b441efb9..de65d9fb 100644 --- a/tracker_scraper.go +++ b/tracker_scraper.go @@ -156,17 +156,19 @@ func (me *trackerScraper) announce(ctx context.Context, event tracker.AnnounceEv defer cancel() me.t.logger.WithDefaultLevel(log.Debug).Printf("announcing to %q: %#v", me.u.String(), req) res, err := tracker.Announce{ - Context: ctx, - HTTPProxy: me.t.cl.config.HTTPProxy, - UserAgent: me.t.cl.config.HTTPUserAgent, - TrackerUrl: me.trackerUrl(ip), - Request: req, - HostHeader: me.u.Host, - ServerName: me.u.Hostname(), - UdpNetwork: me.u.Scheme, - ClientIp4: krpc.NodeAddr{IP: me.t.cl.config.PublicIp4}, - ClientIp6: krpc.NodeAddr{IP: me.t.cl.config.PublicIp6}, - Logger: me.t.logger, + Context: ctx, + HTTPProxy: me.t.cl.config.HTTPProxy, + DialContext: me.t.cl.config.TrackerDialContext, + ListenPacket: me.t.cl.config.TrackerListenPacket, + UserAgent: me.t.cl.config.HTTPUserAgent, + TrackerUrl: me.trackerUrl(ip), + Request: req, + HostHeader: me.u.Host, + ServerName: me.u.Hostname(), + UdpNetwork: me.u.Scheme, + ClientIp4: krpc.NodeAddr{IP: me.t.cl.config.PublicIp4}, + ClientIp6: krpc.NodeAddr{IP: me.t.cl.config.PublicIp6}, + Logger: me.t.logger, }.Do() me.t.logger.WithDefaultLevel(log.Debug).Printf("announce to %q returned %#v: %v", me.u.String(), res, err) if err != nil {