diff --git a/deluge/httpdownloader.py b/deluge/httpdownloader.py index 1e0f40565..f3afb3c66 100644 --- a/deluge/httpdownloader.py +++ b/deluge/httpdownloader.py @@ -41,16 +41,17 @@ class HTTPDownloader(client.HTTPDownloader): """ Factory class for downloading files and keeping track of progress. """ - def __init__(self, url, filename, part_callback=None): + def __init__(self, url, filename, part_callback=None, headers=None): """ :param url: str, the url to download from :param filename: str, the filename to save the file as :param part_callback: func, a function to be called when a part of data is received, it's signature should be: func(data, current_length, total_length) + :param headers: dict, any optional headers to send """ self.__part_callback = part_callback self.current_length = 0 - client.HTTPDownloader.__init__(self, url, filename) + client.HTTPDownloader.__init__(self, url, filename, headers=headers) def gotStatus(self, version, status, message): self.code = int(status) @@ -77,7 +78,7 @@ class HTTPDownloader(client.HTTPDownloader): return client.HTTPDownloader.pagePart(self, data) -def download_file(url, filename, callback=None): +def download_file(url, filename, callback=None, headers=None): """ Downloads a file from a specific URL and returns a Deferred. You can also specify a callback function to be called as parts are received. @@ -86,10 +87,14 @@ def download_file(url, filename, callback=None): :param filename: str, the filename to save the file as :param callback: func, a function to be called when a part of data is received, it's signature should be: func(data, current_length, total_length) + :param headers: dict, any optional headers to send + :raises t.w.e.PageRedirect: when server responds with a temporary redirect + or permanently moved. + :raises t.w.e.Error: for all other HTTP response errors (besides OK) """ url = str(url) scheme, host, port, path = client._parse(url) - factory = HTTPDownloader(url, filename, callback) + factory = HTTPDownloader(url, filename, callback, headers) if scheme == "https": from twisted.internet import ssl reactor.connectSSL(host, port, factory, ssl.ClientContextFactory())