[HTTPDownloader] Refactor HTTPDownloader to Agent
As of Twisted 16.7.0, `twisted.web.client.HTTPDownloader` has been deprecated, which filled the test results with many lines of deprecation warnings. This refactor switches to `twisted.web.client.Agent`, as recommended by Twisted.
parent 089c667d7f
commit c7e61f8c34
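For context, here is a minimal sketch of the `Agent`-based request flow this commit adopts. It is an illustration rather than part of the diff: the URL, the callback names, and the exact wrapper order are assumptions modelled on the `_download_file` changes further down.

```python
from __future__ import print_function

from twisted.internet import reactor
from twisted.web import client
from twisted.web.http_headers import Headers

# Wrap a plain Agent so that gzip-compressed responses and redirects are
# handled transparently, mirroring what the refactored _download_file does.
agent = client.Agent(reactor)
agent = client.ContentDecoderAgent(agent, [(b'gzip', client.GzipDecoder)])
agent = client.RedirectAgent(agent)


def on_response(response):
    # readBody collects the whole response body into a single bytes object.
    return client.readBody(response)


def on_body(body):
    print('received %d bytes' % len(body))


d = agent.request(b'GET', b'http://example.com/', Headers({b'User-Agent': [b'Deluge']}))
d.addCallback(on_response)
d.addCallback(on_body)
d.addBoth(lambda _: reactor.stop())
reactor.run()
```

Unlike `HTTPDownloader`, which was both a protocol factory and the download logic, `Agent` only issues the request; streaming the body, following redirects and decompressing content are layered on through wrappers and body protocols, which is what the new `HTTPDownloaderAgent` and `BodyHandler` classes below provide.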
@@ -15,33 +15,80 @@ import os.path
 import zlib
 
 from twisted.internet import reactor
+from twisted.internet.defer import Deferred
 from twisted.python.failure import Failure
 from twisted.web import client, http
-from twisted.web.client import URI
+from twisted.web._newclient import HTTPClientParser
 from twisted.web.error import PageRedirect
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IAgent
+from zope.interface import implementer
 
-from deluge.common import get_version, utf8_encode_structure
-
-try:
-    from urllib.parse import urljoin
-except ImportError:
-    # PY2 fallback
-    from urlparse import urljoin  # pylint: disable=ungrouped-imports
+from deluge.common import get_version
 
 log = logging.getLogger(__name__)
 
 
-class HTTPDownloader(client.HTTPDownloader):
+class CompressionDecoder(client.GzipDecoder):
+    """A compression decoder for gzip, x-gzip and deflate"""
+    def deliverBody(self, protocol):  # NOQA: N802
+        self.original.deliverBody(CompressionDecoderProtocol(protocol, self.original))
+
+
+class CompressionDecoderProtocol(client._GzipProtocol):
+    """A compression decoder protocol for CompressionDecoder"""
+    def __init__(self, protocol, response):
+        super(CompressionDecoderProtocol, self).__init__(protocol, response)
+        self._zlibDecompress = zlib.decompressobj(32 + zlib.MAX_WBITS)
+
+
+class BodyHandler(HTTPClientParser, object):
+    """An HTTP parser that saves the response on a file"""
+    def __init__(self, request, finished, length, agent):
+        """
+
+        :param request: the request to which this parser is for
+        :type request: twisted.web.iweb.IClientRequest
+        :param finished: a Deferred to handle the finished response
+        :type finished: twisted.internet.defer.Deferred
+        :param length: the length of the response
+        :type length: int
+        :param agent: the agent from which the request was sent
+        :type agent: twisted.web.iweb.IAgent
+        """
+        super(BodyHandler, self).__init__(request, finished)
+        self.agent = agent
+        self.finished = finished
+        self.total_length = length
+        self.current_length = 0
+        self.data = b''
+
+    def dataReceived(self, data):  # NOQA: N802
+        self.current_length += len(data)
+        self.data += data
+        if self.agent.part_callback:
+            self.agent.part_callback(data, self.current_length, self.total_length)
+
+    def connectionLost(self, reason):  # NOQA: N802
+        with open(self.agent.filename, 'wb') as _file:
+            _file.write(self.data)
+        self.finished.callback(self.agent.filename)
+        self.state = u'DONE'
+        HTTPClientParser.connectionLost(self, reason)
+
+
+@implementer(IAgent)
+class HTTPDownloaderAgent(object):
     """
-    Factory class for downloading files and keeping track of progress.
+    A File Downloader Agent
     """
     def __init__(
-        self, url, filename, part_callback=None, headers=None,
-        force_filename=False, allow_compression=True,
+        self, agent, filename, part_callback=None,
+        force_filename=False, allow_compression=True, handle_redirect=True,
     ):
         """
-        :param url: the url to download from
-        :type url: string
+        :param agent: the agent which will send the requests
+        :type agent: twisted.web.client.Agent
         :param filename: the filename to save the file as
        :type filename: string
         :param force_filename: forces use of the supplied filename, regardless of header content
@@ -49,46 +96,39 @@ class HTTPDownloader(client.HTTPDownloader):
         :param part_callback: a function to be called when a part of data
             is received, it's signature should be: func(data, current_length, total_length)
         :type part_callback: function
-        :param headers: any optional headers to send
-        :type headers: dictionary
         """
 
+        self.handle_redirect = handle_redirect
+        self.agent = agent
+        self.filename = filename
         self.part_callback = part_callback
         self.current_length = 0
         self.total_length = 0
-        self.decoder = None
-        self.value = filename
         self.force_filename = force_filename
         self.allow_compression = allow_compression
-        self.code = None
-        agent = 'Deluge/%s (http://deluge-torrent.org)' % get_version()
-        client.HTTPDownloader.__init__(
-            self, url, filename, headers=headers, agent=agent.encode('utf-8'))
+        self.decoder = None
 
-    def gotHeaders(self, headers):  # NOQA: N802
-        self.code = int(self.status)
-        if self.code == http.OK:
-            if b'content-length' in headers:
-                self.total_length = int(headers[b'content-length'][0])
-            else:
-                self.total_length = 0
+    def request_callback(self, response):
+        finished = Deferred()
 
-            encodings_accepted = [b'gzip', b'x-gzip', b'deflate']
-            if (
-                self.allow_compression and b'content-encoding' in headers
-                and headers[b'content-encoding'][0] in encodings_accepted
-            ):
-                # Adding 32 to the wbits enables gzip & zlib decoding (with automatic header detection)
-                # Adding 16 just enables gzip decoding (no zlib)
-                self.decoder = zlib.decompressobj(zlib.MAX_WBITS + 32)
+        if not self.handle_redirect and response.code in (
+            http.MOVED_PERMANENTLY,
+            http.FOUND,
+            http.SEE_OTHER,
+            http.TEMPORARY_REDIRECT,
+        ):
+            location = response.headers.getRawHeaders(b'location')[0]
+            error = PageRedirect(response.code, location=location)
+            finished.errback(Failure(error))
+        else:
+            headers = response.headers
+            body_length = int(headers.getRawHeaders(b'content-length', default=[0])[0])
 
-            if b'content-disposition' in headers and not self.force_filename:
-                content_disp = headers[b'content-disposition'][0].decode('utf-8')
+            if headers.hasHeader(b'content-disposition') and not self.force_filename:
+                content_disp = headers.getRawHeaders(b'content-disposition')[0].decode('utf-8')
                 content_disp_params = cgi.parse_header(content_disp)[1]
                 if 'filename' in content_disp_params:
                     new_file_name = content_disp_params['filename']
                     new_file_name = sanitise_filename(new_file_name)
-                    new_file_name = os.path.join(os.path.split(self.value)[0], new_file_name)
+                    new_file_name = os.path.join(os.path.split(self.filename)[0], new_file_name)
 
                     count = 1
                     fileroot = os.path.splitext(new_file_name)[0]
@@ -98,39 +138,39 @@ class HTTPDownloader(client.HTTPDownloader):
                         new_file_name = '%s-%s%s' % (fileroot, count, fileext)
                         count += 1
 
-                    self.fileName = new_file_name
-                    self.value = new_file_name
+                    self.filename = new_file_name
 
-        elif self.code in (
-            http.MOVED_PERMANENTLY,
-            http.FOUND,
-            http.SEE_OTHER,
-            http.TEMPORARY_REDIRECT,
-        ):
-            location = headers[b'location'][0]
-            error = PageRedirect(self.code, location=location)
-            self.noPage(Failure(error))
+            response.deliverBody(BodyHandler(response.request, finished, body_length, self))
 
-        return client.HTTPDownloader.gotHeaders(self, headers)
+        return finished
 
-    def pagePart(self, data):  # NOQA: N802
-        if self.code == http.OK:
-            self.current_length += len(data)
-            if self.decoder:
-                data = self.decoder.decompress(data)
-            if self.part_callback:
-                self.part_callback(data, self.current_length, self.total_length)
+    def request(self, method, uri, headers=None, body_producer=None):
+        """
 
-        return client.HTTPDownloader.pagePart(self, data)
+        :param method: the HTTP method to use
+        :param uri: the url to download from
+        :type uri: string
+        :param headers: any optional headers to send
+        :type headers: twisted.web.http_headers.Headers
+        :param body_producer:
+        :return:
+        """
+        if headers is None:
+            headers = Headers()
 
-    def pageEnd(self):  # NOQA: N802
-        if self.decoder:
-            data = self.decoder.flush()
-            self.current_length -= len(data)
-            self.decoder = None
-            self.pagePart(data)
+        if not headers.hasHeader(b'User-Agent'):
+            version = get_version()
+            user_agent = 'Deluge/%s (https://deluge-torrent.org)' % version
+            headers.addRawHeader('User-Agent', user_agent)
 
-        return client.HTTPDownloader.pageEnd(self)
+        d = self.agent.request(
+            method=method,
+            uri=uri,
+            headers=headers,
+            bodyProducer=body_producer,
+        )
+        d.addCallback(self.request_callback)
+        return d
 
 
 def sanitise_filename(filename):
@@ -161,7 +201,10 @@ def sanitise_filename(filename):
     return filename
 
 
-def _download_file(url, filename, callback=None, headers=None, force_filename=False, allow_compression=True):
+def _download_file(
+    url, filename, callback=None, headers=None,
+    force_filename=False, allow_compression=True, handle_redirects=True,
+):
     """
     Downloads a file from a specific URL and returns a Deferred. A callback
     function can be specified to be called as parts are received.
@@ -185,42 +228,24 @@ def _download_file(url, filename, callback=None, headers=None, force_filename=Fa
 
     """
 
+    agent = client.Agent(reactor)
+
     if allow_compression:
-        if not headers:
-            headers = {}
-        headers['accept-encoding'] = 'deflate, gzip, x-gzip'
+        enc_accepted = ['gzip', 'x-gzip', 'deflate']
+        decoders = [(enc.encode(), CompressionDecoder) for enc in enc_accepted]
+        agent = client.ContentDecoderAgent(agent, decoders)
+    if handle_redirects:
+        agent = client.RedirectAgent(agent)
 
-    url = url.encode('utf8')
-    headers = utf8_encode_structure(headers) if headers else headers
-    factory = HTTPDownloader(url, filename, callback, headers, force_filename, allow_compression)
+    agent = HTTPDownloaderAgent(agent, filename, callback, force_filename, allow_compression, handle_redirects)
 
-    uri = URI.fromBytes(url)
-    host = uri.host
-    port = uri.port
+    # The Headers init expects dict values to be a list.
+    if headers:
+        for name, value in list(headers.items()):
+            if not isinstance(value, list):
+                headers[name] = [value]
 
-    if uri.scheme == b'https':
-        from twisted.internet import ssl
-        # ClientTLSOptions in Twisted >= 14, see ticket #2765 for details on this addition.
-        try:
-            from twisted.internet._sslverify import ClientTLSOptions
-        except ImportError:
-            ctx_factory = ssl.ClientContextFactory()
-        else:
-            class TLSSNIContextFactory(ssl.ClientContextFactory):  # pylint: disable=no-init
-                """
-                A custom context factory to add a server name for TLS connections.
-                """
-                def getContext(self):  # NOQA: N802
-                    ctx = ssl.ClientContextFactory.getContext(self)
-                    ClientTLSOptions(host, ctx)
-                    return ctx
-            ctx_factory = TLSSNIContextFactory()
-
-        reactor.connectSSL(host, port, factory, ctx_factory)
-    else:
-        reactor.connectTCP(host, port, factory)
-
-    return factory.deferred
+    return agent.request(b'GET', url.encode(), Headers(headers))
 
 
 def download_file(
@@ -255,26 +280,17 @@ def download_file(
         return result
 
     def on_download_fail(failure):
-        if failure.check(PageRedirect) and handle_redirects:
-            new_url = urljoin(url, failure.getErrorMessage().split(' to ')[1])
-            result = _download_file(
-                new_url, filename, callback=callback, headers=headers,
-                force_filename=force_filename,
-                allow_compression=allow_compression,
-            )
-            result.addCallbacks(on_download_success, on_download_fail)
-        else:
-            # Log the failure and pass to the caller
-            log.warning(
-                'Error occurred downloading file from "%s": %s',
-                url, failure.getErrorMessage(),
-            )
-            result = failure
+        log.warning(
+            'Error occurred downloading file from "%s": %s',
+            url, failure.getErrorMessage(),
+        )
+        result = failure
         return result
 
     d = _download_file(
         url, filename, callback=callback, headers=headers,
         force_filename=force_filename, allow_compression=allow_compression,
+        handle_redirects=handle_redirects,
     )
     d.addCallbacks(on_download_success, on_download_fail)
     return d
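For reference, a minimal sketch of how the refactored entry point is driven from a running reactor. This is illustrative only: the module path `deluge.httpdownloader` is assumed, the URL and filename are made up, and the part callback follows the `func(data, current_length, total_length)` signature documented above.

```python
from __future__ import print_function

from twisted.internet import reactor

from deluge.httpdownloader import download_file


def on_part(data, current_length, total_length):
    # Called for each chunk delivered to BodyHandler.dataReceived().
    print('%d/%d bytes' % (current_length, total_length))


def on_done(filename):
    # The Deferred fires with the path the response body was written to.
    print('saved to %s' % filename)


def on_fail(failure):
    print('download failed: %s' % failure.getErrorMessage())


d = download_file('http://example.com/file.torrent', 'file.torrent', callback=on_part)
d.addCallbacks(on_done, on_fail)
d.addBoth(lambda _: reactor.stop())
reactor.run()
```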
Binary file not shown (new image added for the tests, 1.1 KiB).
@@ -59,6 +59,14 @@ class TrackerIconsTestCase(BaseTestCase):
         d.addCallback(self.assertEqual, icon)
         return d
 
+    def test_get_seo_ico_with_sni(self):
+        # seo using certificates with SNI support only
+        icon = TrackerIcon(common.get_test_data_file('seo.ico'))
+        d = self.icons.fetch('www.seo.com')
+        d.addCallback(self.assertNotIdentical, None)
+        d.addCallback(self.assertEqual, icon)
+        return d
+
     def test_get_empty_string_tracker(self):
         d = self.icons.fetch('')
         d.addCallback(self.assertIdentical, None)