Added support for sending headers when using download_file.

Allows HTTP conditional GET for servers which support it.
This commit is contained in:
John Garland 2009-07-01 10:02:45 +00:00
parent d592c0370c
commit d3d991d8aa
1 changed files with 9 additions and 4 deletions

View File

@ -41,16 +41,17 @@ class HTTPDownloader(client.HTTPDownloader):
"""
Factory class for downloading files and keeping track of progress.
"""
def __init__(self, url, filename, part_callback=None):
def __init__(self, url, filename, part_callback=None, headers=None):
"""
:param url: str, the url to download from
:param filename: str, the filename to save the file as
:param part_callback: func, a function to be called when a part of data
is received, it's signature should be: func(data, current_length, total_length)
:param headers: dict, any optional headers to send
"""
self.__part_callback = part_callback
self.current_length = 0
client.HTTPDownloader.__init__(self, url, filename)
client.HTTPDownloader.__init__(self, url, filename, headers=headers)
def gotStatus(self, version, status, message):
self.code = int(status)
@ -77,7 +78,7 @@ class HTTPDownloader(client.HTTPDownloader):
return client.HTTPDownloader.pagePart(self, data)
def download_file(url, filename, callback=None):
def download_file(url, filename, callback=None, headers=None):
"""
Downloads a file from a specific URL and returns a Deferred. You can also
specify a callback function to be called as parts are received.
@ -86,10 +87,14 @@ def download_file(url, filename, callback=None):
:param filename: str, the filename to save the file as
:param callback: func, a function to be called when a part of data is received,
it's signature should be: func(data, current_length, total_length)
:param headers: dict, any optional headers to send
:raises t.w.e.PageRedirect: when server responds with a temporary redirect
or permanently moved.
:raises t.w.e.Error: for all other HTTP response errors (besides OK)
"""
url = str(url)
scheme, host, port, path = client._parse(url)
factory = HTTPDownloader(url, filename, callback)
factory = HTTPDownloader(url, filename, callback, headers)
if scheme == "https":
from twisted.internet import ssl
reactor.connectSSL(host, port, factory, ssl.ClientContextFactory())