# encoding=utf-8
'''Document writers.'''
# Wpull. Copyright 2013-2015: Christopher Foo and others. License: GPL v3.
import abc
import email.utils
import gettext
import http.client
import itertools
import logging
import os
import re
import shutil
import time
from typing import cast, BinaryIO, Optional
from wpull.backport.logging import StyleAdapter
from wpull.body import Body
from wpull.document.css import CSSReader
from wpull.document.html import HTMLReader
from wpull.path import anti_clobber_dir_path, parse_content_disposition, \
PathNamer
import wpull.util
from wpull.protocol.abstract.request import BaseRequest, BaseResponse, \
SerializableMixin
from wpull.protocol.http.request import Response as HTTPResponse
from wpull.protocol.ftp.request import Response as FTPResponse
# Alias marking translatable user-facing messages (used for error texts).
_ = gettext.gettext
# Module logger. Messages in this module use '{0}'-style placeholders,
# presumably formatted by the StyleAdapter wrapper.
_logger = StyleAdapter(
    logging.getLogger(__name__)
)
class BaseWriter(object, metaclass=abc.ABCMeta):
    '''Abstract factory of per-document writer sessions.'''

    @abc.abstractmethod
    def session(self) -> 'BaseWriterSession':
        '''Create and return a new session for writing one document.'''
class BaseWriterSession(object, metaclass=abc.ABCMeta):
    '''Abstract interface for writing a single document.

    A Processor calls ``process_request`` before sending the request,
    ``process_response`` on the response, and then ``save_document`` or
    ``discard_document`` depending on whether the document is kept.
    '''

    @abc.abstractmethod
    def process_request(self, request: BaseRequest) -> BaseRequest:
        '''Optionally rewrite ``request`` before it reaches the Client.

        Called by a Processor after it builds the Request and before it
        submits the Request to a Client.

        Returns:
            Either ``request`` itself or a modified Request.
        '''

    @abc.abstractmethod
    def process_response(self, response: BaseResponse):
        '''Prepare for the given ``response``.

        Called by a Processor before any response or error handling
        takes place.
        '''

    @abc.abstractmethod
    def save_document(self, response: BaseResponse) -> str:
        '''Finish processing and keep the document.

        Called once the Processor decides the document is to be kept,
        for example on a "200 OK" response.

        Returns:
            The filename the document was saved to.
        '''

    @abc.abstractmethod
    def discard_document(self, response: BaseResponse):
        '''Drop the document instead of keeping it.

        Called once the Processor decides the document is to be removed,
        for example on a "404 Not Found" response.
        '''
class BaseFileWriterSession(BaseWriterSession):
    '''Base class for File Writer Sessions.

    A session computes the local filename for one document, optionally
    rewrites the request to resume a partial download, attaches a file
    body to the response, and finally keeps or removes that file.

    Args:
        path_namer: Maps a request URL to a local path.
        file_continuing: If True, resume an existing partial file.
        headers_included: If True, prepend the serialized response header
            to the saved file.
        local_timestamping: If True, apply the Last-Modified header as the
            file's modification time (HTTP only).
        adjust_extension: If True, append ``.html``/``.css`` when the
            response is detected as such.
        content_disposition: If True, take the filename from the
            Content-Disposition header.
        trust_server_names: If True, recompute the filename from the
            response's request.
    '''
    def __init__(self, path_namer: PathNamer,
                 file_continuing: bool,
                 headers_included: bool,
                 local_timestamping: bool,
                 adjust_extension: bool,
                 content_disposition: bool,
                 trust_server_names: bool):
        self._path_namer = path_namer
        self._file_continuing = file_continuing
        self._headers_included = headers_included
        self._local_timestamping = local_timestamping
        self._adjust_extension = adjust_extension
        self._content_disposition = content_disposition
        self._trust_server_names = trust_server_names
        # Local path of the document; computed lazily by process_request.
        self._filename = None
        # Set once the request has been rewritten to resume a file.
        self._file_continue_requested = False

    @classmethod
    def open_file(cls, filename: str, response: BaseResponse, mode='wb+'):
        '''Open a file object on to the Response Body.

        Args:
            filename: The path where the file is to be saved
            response: Response
            mode: The file mode

        This function will create the directories if not exist.
        '''
        _logger.debug('Saving file to {0}, mode={1}.',
                      filename, mode)

        dir_path = os.path.dirname(filename)

        if dir_path:
            # exist_ok avoids a check-then-create race with other sessions.
            os.makedirs(dir_path, exist_ok=True)

        response.body = Body(open(filename, mode))

    @classmethod
    def set_timestamp(cls, filename: str, response: HTTPResponse):
        '''Set the Last-Modified timestamp onto the given file.

        The access time is set to the current time. Nothing is changed
        when the header is missing or cannot be parsed.

        Args:
            filename: The path of the file
            response: Response
        '''
        last_modified = response.fields.get('Last-Modified')

        if not last_modified:
            return

        try:
            parsed_date = email.utils.parsedate_tz(last_modified)
        except (ValueError, TypeError):
            _logger.exception('Failed to parse date.')
            return

        # parsedate_tz signals most malformed input by returning None
        # rather than raising.
        if parsed_date is None:
            _logger.warning('Failed to parse date.')
            return

        # HTTP dates are GMT; mktime_tz honours the parsed zone offset,
        # whereas time.mktime would interpret the tuple as local time.
        timestamp = email.utils.mktime_tz(parsed_date)

        os.utime(filename, (time.time(), timestamp))

    @classmethod
    def save_headers(cls, filename: str, response: BaseResponse):
        '''Prepend the serialized response header to the saved file.

        The header followed by the body's content is written to a sibling
        file, which then replaces ``filename``. Responses that are not
        serializable are left untouched.

        Args:
            filename: The path of the file
            response: Response whose ``body`` holds the file content
        '''
        if not isinstance(response, SerializableMixin):
            return

        new_filename = filename + '-new'
        body = response.body
        body.flush()
        original_offset = body.tell()

        try:
            with open(new_filename, 'wb') as new_file:
                new_file.write(response.to_bytes())
                body.seek(0)
                shutil.copyfileobj(body, new_file)
        finally:
            # Restore the body position for later readers of the body.
            body.seek(original_offset)

        os.replace(new_filename, filename)

    def process_request(self, request: BaseRequest):
        '''Compute the filename and prepare resuming if enabled.

        The filename is computed on the first call only. When file
        continuing is enabled, the request is rewritten to resume an
        existing partial file at that time.

        Returns:
            The (possibly modified) request.
        '''
        if not self._filename:
            self._filename = self._compute_filename(request)

            if self._file_continuing and self._filename:
                self._process_file_continue_request(request)

        return request

    def _compute_filename(self, request: BaseRequest):
        '''Get the appropriate filename from the request.'''
        path = self._path_namer.get_filename(request.url_info)

        if os.path.isdir(path):
            # A directory already occupies the name; use a file beside it.
            path += '.f'
        else:
            dir_name, name = os.path.split(path)
            path = os.path.join(anti_clobber_dir_path(dir_name), name)

        return path

    def _process_file_continue_request(self, request: BaseRequest):
        '''Modify the request to resume downloading file.'''
        if os.path.exists(self._filename):
            size = os.path.getsize(self._filename)
            request.set_continue(size)
            self._file_continue_requested = True

            _logger.debug('Continue file from {0}.', size)
        else:
            _logger.debug('No file to continue.')

    def process_response(self, response: BaseResponse):
        '''Attach a file body to ``response`` and apply renaming options.

        Does nothing if no filename has been computed. For FTP, the file is
        resumed or opened fresh. For other schemes, a resume is attempted
        when requested. Otherwise, on 2xx or >= 400 status codes the
        renaming/extension options are applied and the file is opened.
        '''
        if not self._filename:
            return

        if response.request.url_info.scheme == 'ftp':
            response = cast(FTPResponse, response)

            if self._file_continue_requested:
                self._process_file_continue_ftp_response(response)
            else:
                self.open_file(self._filename, response)
        else:
            response = cast(HTTPResponse, response)
            code = response.status_code

            if self._file_continue_requested:
                self._process_file_continue_response(response)
            elif 200 <= code <= 299 or 400 <= code:
                if self._trust_server_names:
                    self._rename_with_last_response(response)

                if self._content_disposition:
                    self._rename_with_content_disposition(response)

                if self._adjust_extension:
                    self._append_filename_extension(response)

                self.open_file(self._filename, response)

    def _process_file_continue_response(self, response: HTTPResponse):
        '''Process a partial content response.'''
        code = response.status_code

        if code == http.client.PARTIAL_CONTENT:
            self.open_file(self._filename, response, mode='ab+')
        else:
            self._raise_cannot_continue_error()

    def _process_file_continue_ftp_response(self, response: FTPResponse):
        '''Process a restarted content response.'''
        if response.request.restart_value and response.restart_value:
            self.open_file(self._filename, response, mode='ab+')
        else:
            self._raise_cannot_continue_error()

    def _raise_cannot_continue_error(self):
        '''Raise an error when server cannot continue a file.'''
        # XXX: I cannot find where wget refuses to resume a file
        # when the server does not support range requests. Wget has
        # enums that appear to define this case, it is checked throughout
        # the code, but the HTTP function doesn't even use them.
        # FIXME: unit test is needed for this case
        raise IOError(
            _('Server not able to continue file download: {filename}.')
            .format(filename=self._filename))

    def _append_filename_extension(self, response: BaseResponse):
        '''Append an HTML/CSS file suffix as needed.'''
        if not self._filename:
            return

        if response.request.url_info.scheme not in ('http', 'https'):
            return

        if not re.search(r'\.[hH][tT][mM][lL]?$', self._filename) and \
                HTMLReader.is_response(response):
            self._filename += '.html'
        elif not re.search(r'\.[cC][sS][sS]$', self._filename) and \
                CSSReader.is_response(response):
            self._filename += '.css'

    def _rename_with_content_disposition(self, response: HTTPResponse):
        '''Rename using the Content-Disposition header.'''
        if not self._filename:
            return

        if response.request.url_info.scheme not in ('http', 'https'):
            return

        header_value = response.fields.get('Content-Disposition')

        if not header_value:
            return

        filename = parse_content_disposition(header_value)

        if filename:
            # Keep the directory; only the base name comes from the server.
            dir_path = os.path.dirname(self._filename)
            new_filename = self._path_namer.safe_filename(filename)
            self._filename = os.path.join(dir_path, new_filename)

    def _rename_with_last_response(self, response):
        '''Recompute the filename from the response's request URL.'''
        if not self._filename:
            return

        if response.request.url_info.scheme not in ('http', 'https'):
            return

        self._filename = self._compute_filename(response.request)

    def save_document(self, response: BaseResponse):
        '''Finalize the saved file and return its filename.

        Returns:
            The filename if the file exists, otherwise None.
        '''
        if self._filename and os.path.exists(self._filename):
            if self._headers_included:
                self.save_headers(self._filename, response)

            # Timestamp last so that rewriting the file for headers does
            # not overwrite the applied modification time.
            if self._local_timestamping and \
                    response.request.url_info.scheme in ('http', 'https'):
                self.set_timestamp(self._filename, cast(HTTPResponse, response))

            return self._filename

    def discard_document(self, response: BaseResponse):
        '''Delete the file if one was written.'''
        if self._filename and os.path.exists(self._filename):
            os.remove(self._filename)
class BaseFileWriter(BaseWriter):
    '''Base class for saving documents to disk.

    Args:
        path_namer: The path namer.
        file_continuing: If True, the writer will modify requests to fetch
            the remaining portion of the file
        headers_included: If True, the writer will include the HTTP header
            responses on top of the document
        local_timestamping: If True, the writer will set the Last-Modified
            timestamp on downloaded files
        adjust_extension: If True, HTML or CSS file extension will be added
            whenever it is detected as so.
        content_disposition: If True, the filename is extracted from
            the Content-Disposition header.
        trust_server_names: If True and there is redirection, use the last
            given response for the filename.
    '''
    def __init__(self, path_namer: PathNamer,
                 file_continuing: bool=False,
                 headers_included: bool=False,
                 local_timestamping: bool=True,
                 adjust_extension: bool=False,
                 content_disposition: bool=False,
                 trust_server_names: bool=False):
        self._path_namer = path_namer
        self._file_continuing = file_continuing
        self._headers_included = headers_included
        self._local_timestamping = local_timestamping
        self._adjust_extension = adjust_extension
        self._content_disposition = content_disposition
        self._trust_server_names = trust_server_names

    # abc.abstractproperty is deprecated; stacking property over
    # abstractmethod is the documented replacement.
    @property
    @abc.abstractmethod
    def session_class(self) -> type:
        '''Return the class of File Writer Session.

        This should be overridden by subclasses.
        '''

    def session(self) -> BaseFileWriterSession:
        '''Return a new File Writer Session configured with this
        writer's options.'''
        return self.session_class(
            self._path_namer,
            self._file_continuing,
            self._headers_included,
            self._local_timestamping,
            self._adjust_extension,
            self._content_disposition,
            self._trust_server_names,
        )
class OverwriteFileWriter(BaseFileWriter):
    '''File writer whose sessions write over any existing file.'''

    @property
    def session_class(self):
        '''Session type created by this writer.'''
        session_type = OverwriteFileWriterSession
        return session_type
class OverwriteFileWriterSession(BaseFileWriterSession):
    '''Session that uses the inherited file-writing behavior unchanged.'''
class IgnoreFileWriter(BaseFileWriter):
    '''File writer meant to leave already existing files alone.'''

    @property
    def session_class(self):
        '''Session type created by this writer.'''
        session_type = IgnoreFileWriterSession
        return session_type
class IgnoreFileWriterSession(BaseFileWriterSession):
    '''Session that refuses requests once its target file exists.'''

    def process_request(self, request):
        '''Delegate to the base class unless the file already exists.

        Returns:
            ``None`` when a filename has already been computed and a file
            exists at that path; otherwise the base class result.
        '''
        # NOTE(review): self._filename starts as None and is only assigned
        # by the base process_request, so on a session's first call this
        # check cannot trigger. Confirm callers handle a None return before
        # changing this to check the freshly computed filename.
        if self._filename and os.path.exists(self._filename):
            return None

        return super().process_request(request)
class AntiClobberFileWriter(BaseFileWriter):
    '''File writer that picks a fresh filename when the original is taken.'''

    @property
    def session_class(self):
        '''Session type created by this writer.'''
        session_type = AntiClobberFileWriterSession
        return session_type
class AntiClobberFileWriterSession(BaseFileWriterSession):
    '''Session that never reuses an existing path.

    If the computed path exists, ``.1``, ``.2``, ... is appended until an
    unused path is found.
    '''
    def _compute_filename(self, request: BaseRequest):
        head, tail = os.path.split(
            self._path_namer.get_filename(request.url_info))
        base_path = os.path.join(anti_clobber_dir_path(head), tail)

        numbered_paths = ('{0}.{1}'.format(base_path, index)
                          for index in itertools.count(1))

        # The unsuffixed path is tried first, then the numbered ones.
        return next(
            path for path in itertools.chain([base_path], numbered_paths)
            if not os.path.exists(path)
        )
class TimestampingFileWriter(BaseFileWriter):
    '''File writer whose sessions ask for a document only if the server's
    copy is newer than the local one.'''

    @property
    def session_class(self) -> BaseFileWriterSession:
        '''Session type created by this writer.'''
        session_type = TimestampingFileWriterSession
        return session_type
class TimestampingFileWriterSession(BaseFileWriterSession):
    '''Session that makes HTTP requests conditional on the local file's
    modification time.'''

    def process_request(self, request: BaseRequest):
        '''Add an ``If-Modified-Since`` header derived from the local file.

        The modification time of ``<filename>.orig`` is used when that file
        exists; otherwise the modification time of the file itself is used.
        No header is added when neither exists, when no filename was
        computed, or for non-HTTP requests.

        Returns:
            The (possibly modified) request.
        '''
        request = super().process_request(request)

        if not self._filename:
            # os.path.exists(None) would raise TypeError below.
            return request

        orig_file = '{0}.orig'.format(self._filename)

        if os.path.exists(orig_file):
            modified_time = os.path.getmtime(orig_file)
        elif os.path.exists(self._filename):
            modified_time = os.path.getmtime(self._filename)
        else:
            modified_time = None

        _logger.debug('Checking for last modified={0}.', modified_time)

        # Compare with None so a modification time of 0 is still honoured;
        # header fields only exist on HTTP requests.
        if modified_time is not None and \
                request.url_info.scheme in ('http', 'https'):
            # HTTP-date must name the zone as "GMT", not "-0000".
            date_str = email.utils.formatdate(modified_time, usegmt=True)
            request.fields['If-Modified-Since'] = date_str

        return request
class NullWriterSession(BaseWriterSession):
    '''Session that writes nothing and passes objects through unchanged.'''

    def process_request(self, request):
        '''Return ``request`` unchanged.'''
        unchanged = request
        return unchanged

    def process_response(self, response):
        '''Return ``response`` unchanged; no body is attached.'''
        unchanged = response
        return unchanged

    def discard_document(self, response):
        '''Do nothing.'''

    def save_document(self, response):
        '''Do nothing; no filename is produced.'''
class NullWriter(BaseWriter):
    '''Writer whose sessions discard every document.'''

    def session(self) -> NullWriterSession:
        '''Return a fresh no-op session.'''
        new_session = NullWriterSession()
        return new_session
class MuxBody(Body):
    '''Body that mirrors every write into a second stream.

    Data written, flushed or closed through this object is applied to
    ``stream`` first and then to the underlying :class:`Body`.

    Args:
        stream: The additional stream receiving a copy of the data.
        kwargs: Passed to :class:`Body`.
    '''
    def __init__(self, stream: BinaryIO, **kwargs):
        super().__init__(**kwargs)
        self._stream = stream

    def write(self, data: bytes) -> int:
        '''Write ``data`` to both destinations; return the body's result.'''
        self._stream.write(data)
        return super().__getattr__('write')(data)

    def writelines(self, lines):
        '''Write each item of ``lines`` to both destinations.'''
        # Materialize first: a one-shot iterator would otherwise be
        # exhausted by the mirror stream and leave the body empty.
        lines = list(lines)

        for line in lines:
            self._stream.write(line)

        return super().__getattr__('writelines')(lines)

    def flush(self):
        '''Flush both destinations.'''
        self._stream.flush()
        return super().__getattr__('flush')()

    def close(self):
        '''Close both destinations.

        The body is closed even if closing the mirror stream raises.
        '''
        try:
            self._stream.close()
        finally:
            result = super().__getattr__('close')()

        return result
class SingleDocumentWriterSession(BaseWriterSession):
    '''Session that writes a document into a shared stream.

    Args:
        stream: Destination for the document's data.
        headers_included: If True, write the serialized response header
            before the body when the response supports serialization.
    '''
    def __init__(self, stream: BinaryIO, headers_included: bool):
        self._stream = stream
        self._headers_included = headers_included

    def process_request(self, request):
        '''Return ``request`` unchanged.'''
        return request

    def process_response(self, response: BaseResponse):
        '''Write the header if requested and route the body to the stream.

        A readable stream is used directly as the body's file. A
        non-readable stream is mirrored through :class:`MuxBody`;
        presumably because Body needs to read its content back — confirm.
        '''
        wants_header = (self._headers_included
                        and isinstance(response, SerializableMixin))

        if wants_header:
            self._stream.write(response.to_bytes())

        if self._stream.readable():
            response.body = Body(self._stream)
        else:
            response.body = MuxBody(self._stream)

        return response

    def discard_document(self, response):
        '''Flush the body; nothing already written is removed.'''
        response.body.flush()

    def save_document(self, response):
        '''Flush the body.'''
        response.body.flush()
class SingleDocumentWriter(BaseWriter):
    '''Writer that sends the data of every document into one stream.

    Args:
        stream: Shared destination stream.
        headers_included: If True, response headers are written as well.
    '''
    def __init__(self, stream: BinaryIO, headers_included: bool=False):
        self._stream = stream
        self._headers_included = headers_included

    def session(self) -> SingleDocumentWriterSession:
        '''Return a new session bound to the shared stream.'''
        return SingleDocumentWriterSession(
            self._stream,
            self._headers_included,
        )