# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
  13. """Speeds up S3 throughput by using processes
  14. Getting Started
  15. ===============
  16. The :class:`ProcessPoolDownloader` can be used to download a single file by
  17. calling :meth:`ProcessPoolDownloader.download_file`:
  18. .. code:: python
  19. from s3transfer.processpool import ProcessPoolDownloader
  20. with ProcessPoolDownloader() as downloader:
  21. downloader.download_file('mybucket', 'mykey', 'myfile')
  22. This snippet downloads the S3 object located in the bucket ``mybucket`` at the
  23. key ``mykey`` to the local file ``myfile``. Any errors encountered during the
  24. transfer are not propagated. To determine if a transfer succeeded or
  25. failed, use the `Futures`_ interface.
  26. The :class:`ProcessPoolDownloader` can be used to download multiple files as
  27. well:
  28. .. code:: python
  29. from s3transfer.processpool import ProcessPoolDownloader
  30. with ProcessPoolDownloader() as downloader:
  31. downloader.download_file('mybucket', 'mykey', 'myfile')
  32. downloader.download_file('mybucket', 'myotherkey', 'myotherfile')
  33. When running this snippet, the downloading of ``mykey`` and ``myotherkey``
  34. happen in parallel. The first ``download_file`` call does not block the
  35. second ``download_file`` call. The snippet blocks when exiting
  36. the context manager and blocks until both downloads are complete.
  37. Alternatively, the ``ProcessPoolDownloader`` can be instantiated
  38. and explicitly be shutdown using :meth:`ProcessPoolDownloader.shutdown`:
  39. .. code:: python
  40. from s3transfer.processpool import ProcessPoolDownloader
  41. downloader = ProcessPoolDownloader()
  42. downloader.download_file('mybucket', 'mykey', 'myfile')
  43. downloader.download_file('mybucket', 'myotherkey', 'myotherfile')
  44. downloader.shutdown()
  45. For this code snippet, the call to ``shutdown`` blocks until both
  46. downloads are complete.
  47. Additional Parameters
  48. =====================
  49. Additional parameters can be provided to the ``download_file`` method:
  50. * ``extra_args``: A dictionary containing any additional client arguments
  51. to include in the
  52. `GetObject <https://botocore.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.get_object>`_
  53. API request. For example:
  54. .. code:: python
  55. from s3transfer.processpool import ProcessPoolDownloader
  56. with ProcessPoolDownloader() as downloader:
  57. downloader.download_file(
  58. 'mybucket', 'mykey', 'myfile',
  59. extra_args={'VersionId': 'myversion'})
  60. * ``expected_size``: By default, the downloader will make a HeadObject
  61. call to determine the size of the object. To opt-out of this additional
  62. API call, you can provide the size of the object in bytes:
  63. .. code:: python
  64. from s3transfer.processpool import ProcessPoolDownloader
  65. MB = 1024 * 1024
  66. with ProcessPoolDownloader() as downloader:
  67. downloader.download_file(
  68. 'mybucket', 'mykey', 'myfile', expected_size=2 * MB)
  69. Futures
  70. =======
  71. When ``download_file`` is called, it immediately returns a
  72. :class:`ProcessPoolTransferFuture`. The future can be used to poll the state
  73. of a particular transfer. To get the result of the download,
  74. call :meth:`ProcessPoolTransferFuture.result`. The method blocks
  75. until the transfer completes, whether it succeeds or fails. For example:
  76. .. code:: python
  77. from s3transfer.processpool import ProcessPoolDownloader
  78. with ProcessPoolDownloader() as downloader:
  79. future = downloader.download_file('mybucket', 'mykey', 'myfile')
  80. print(future.result())
  81. If the download succeeds, the future returns ``None``:
  82. .. code:: python
  83. None
  84. If the download fails, the exception causing the failure is raised. For
  85. example, if ``mykey`` did not exist, the following error would be raised
  86. .. code:: python
  87. botocore.exceptions.ClientError: An error occurred (404) when calling the HeadObject operation: Not Found
  88. .. note::
  89. :meth:`ProcessPoolTransferFuture.result` can only be called while the
  90. ``ProcessPoolDownloader`` is running (e.g. before calling ``shutdown`` or
  91. inside the context manager).
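
For example, the following sketch (an illustration of the note above built
only from the calls already shown, not an additional API) collects the
futures for several downloads and polls each of them before the context
manager exits:

.. code:: python

    from s3transfer.processpool import ProcessPoolDownloader

    downloads = [('mykey', 'myfile'), ('myotherkey', 'myotherfile')]
    with ProcessPoolDownloader() as downloader:
        futures = [
            downloader.download_file('mybucket', key, filename)
            for key, filename in downloads
        ]
        # result() is called here, while the downloader is still running.
        for future in futures:
            future.result()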

Process Pool Configuration
==========================

By default, the downloader has the following configuration options:

* ``multipart_threshold``: The threshold size for performing ranged downloads
  in bytes. By default, ranged downloads happen for S3 objects that are
  greater than or equal to 8 MB in size.

* ``multipart_chunksize``: The size of each ranged download in bytes. By
  default, the size of each ranged download is 8 MB.

* ``max_request_processes``: The maximum number of processes used to download
  S3 objects. By default, the maximum is 10 processes.

To change the default configuration, use the :class:`ProcessTransferConfig`:

.. code:: python

    from s3transfer.processpool import ProcessPoolDownloader
    from s3transfer.processpool import ProcessTransferConfig

    config = ProcessTransferConfig(
        multipart_threshold=64 * 1024 * 1024,  # 64 MB
        max_request_processes=50
    )
    downloader = ProcessPoolDownloader(config=config)

Client Configuration
====================

The process pool downloader creates ``botocore`` clients on your behalf. To
affect how a client is created, pass the keyword arguments that would have
been used in the :meth:`botocore.session.Session.create_client` call:

.. code:: python

    from s3transfer.processpool import ProcessPoolDownloader

    downloader = ProcessPoolDownloader(
        client_kwargs={'region_name': 'us-west-2'})

This snippet ensures that all clients created by the ``ProcessPoolDownloader``
use ``us-west-2`` as their region.
"""
import collections
import contextlib
import logging
import multiprocessing
import signal
import threading
from copy import deepcopy

import botocore.session
from botocore.config import Config

from s3transfer.compat import MAXINT, BaseManager
from s3transfer.constants import ALLOWED_DOWNLOAD_ARGS, MB, PROCESS_USER_AGENT
from s3transfer.exceptions import CancelledError, RetriesExceededError
from s3transfer.futures import BaseTransferFuture, BaseTransferMeta
from s3transfer.utils import (
    S3_RETRYABLE_DOWNLOAD_ERRORS,
    CallArgs,
    OSUtils,
    calculate_num_parts,
    calculate_range_parameter,
)

logger = logging.getLogger(__name__)

SHUTDOWN_SIGNAL = 'SHUTDOWN'

# The DownloadFileRequest tuple is submitted from the ProcessPoolDownloader
# to the GetObjectSubmitter in order for the submitter to begin submitting
# GetObjectJobs to the GetObjectWorkers.
DownloadFileRequest = collections.namedtuple(
    'DownloadFileRequest',
    [
        'transfer_id',  # The unique id for the transfer
        'bucket',  # The bucket to download the object from
        'key',  # The key to download the object from
        'filename',  # The user-requested download location
        'extra_args',  # Extra arguments to provide to client calls
        'expected_size',  # The user-provided expected size of the download
    ],
)

# The GetObjectJob tuple is submitted from the GetObjectSubmitter
# to the GetObjectWorkers to download the file or parts of the file.
GetObjectJob = collections.namedtuple(
    'GetObjectJob',
    [
        'transfer_id',  # The unique id for the transfer
        'bucket',  # The bucket to download the object from
        'key',  # The key to download the object from
        'temp_filename',  # The temporary file to write the content to via
        # completed GetObject calls.
        'extra_args',  # Extra arguments to provide to the GetObject call
        'offset',  # The offset at which to write the content in the temp file
        'filename',  # The user-requested download location. The worker
        # of the final GetObjectJob will move the file located at
        # temp_filename to the location of filename.
    ],
)
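
# Illustration only (hypothetical values, nothing below is used at runtime):
# with the default 8 MB chunksize, a 20 MB object would be split by the
# submitter into three GetObjectJobs that share the same transfer_id and
# temp_filename, with offsets of 0, 8 MB, and 16 MB and a Range extra_arg
# covering each job's slice. For example, the second job would look roughly
# like:
#
#   GetObjectJob(transfer_id=0, bucket='mybucket', key='mykey',
#                temp_filename='myfile.<random suffix>', offset=8388608,
#                extra_args={'Range': 'bytes=8388608-16777215'},
#                filename='myfile')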

@contextlib.contextmanager
def ignore_ctrl_c():
    original_handler = _add_ignore_handler_for_interrupts()
    yield
    signal.signal(signal.SIGINT, original_handler)


def _add_ignore_handler_for_interrupts():
    # Windows is unable to pickle signal.signal directly so it needs to
    # be wrapped in a function defined at the module level
    return signal.signal(signal.SIGINT, signal.SIG_IGN)

class ProcessTransferConfig:
    def __init__(
        self,
        multipart_threshold=8 * MB,
        multipart_chunksize=8 * MB,
        max_request_processes=10,
    ):
        """Configuration for the ProcessPoolDownloader

        :param multipart_threshold: The threshold for which ranged downloads
            occur.

        :param multipart_chunksize: The chunk size of each ranged download.

        :param max_request_processes: The maximum number of processes that
            will be making S3 API transfer-related requests at a time.
        """
        self.multipart_threshold = multipart_threshold
        self.multipart_chunksize = multipart_chunksize
        self.max_request_processes = max_request_processes

class ProcessPoolDownloader:
    def __init__(self, client_kwargs=None, config=None):
        """Downloads S3 objects using process pools

        :type client_kwargs: dict
        :param client_kwargs: The keyword arguments to provide when
            instantiating S3 clients. The arguments must match the keyword
            arguments provided to the
            `botocore.session.Session.create_client()` method.

        :type config: ProcessTransferConfig
        :param config: Configuration for the downloader
        """
        if client_kwargs is None:
            client_kwargs = {}
        self._client_factory = ClientFactory(client_kwargs)

        self._transfer_config = config
        if config is None:
            self._transfer_config = ProcessTransferConfig()

        self._download_request_queue = multiprocessing.Queue(1000)
        self._worker_queue = multiprocessing.Queue(1000)
        self._osutil = OSUtils()

        self._started = False
        self._start_lock = threading.Lock()

        # These below are initialized in the start() method
        self._manager = None
        self._transfer_monitor = None
        self._submitter = None
        self._workers = []

    def download_file(
        self, bucket, key, filename, extra_args=None, expected_size=None
    ):
        """Downloads the object's contents to a file

        :type bucket: str
        :param bucket: The name of the bucket to download from

        :type key: str
        :param key: The name of the key to download from

        :type filename: str
        :param filename: The name of a file to download to.

        :type extra_args: dict
        :param extra_args: Extra arguments that may be passed to the
            client operation

        :type expected_size: int
        :param expected_size: The expected size in bytes of the download. If
            provided, the downloader will not call HeadObject to determine the
            object's size and will use the provided value instead. The size is
            needed to determine whether to do a multipart download.

        :rtype: s3transfer.futures.TransferFuture
        :returns: Transfer future representing the download
        """
        self._start_if_needed()
        if extra_args is None:
            extra_args = {}
        self._validate_all_known_args(extra_args)
        transfer_id = self._transfer_monitor.notify_new_transfer()
        download_file_request = DownloadFileRequest(
            transfer_id=transfer_id,
            bucket=bucket,
            key=key,
            filename=filename,
            extra_args=extra_args,
            expected_size=expected_size,
        )
        logger.debug(
            'Submitting download file request: %s.', download_file_request
        )
        self._download_request_queue.put(download_file_request)
        call_args = CallArgs(
            bucket=bucket,
            key=key,
            filename=filename,
            extra_args=extra_args,
            expected_size=expected_size,
        )
        future = self._get_transfer_future(transfer_id, call_args)
        return future

    def shutdown(self):
        """Shut down the downloader

        It will wait until all downloads are complete before returning.
        """
        self._shutdown_if_needed()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, *args):
        if isinstance(exc_value, KeyboardInterrupt):
            if self._transfer_monitor is not None:
                self._transfer_monitor.notify_cancel_all_in_progress()
        self.shutdown()

    def _start_if_needed(self):
        with self._start_lock:
            if not self._started:
                self._start()

    def _start(self):
        self._start_transfer_monitor_manager()
        self._start_submitter()
        self._start_get_object_workers()
        self._started = True

    def _validate_all_known_args(self, provided):
        for kwarg in provided:
            if kwarg not in ALLOWED_DOWNLOAD_ARGS:
                download_args = ', '.join(ALLOWED_DOWNLOAD_ARGS)
                raise ValueError(
                    f"Invalid extra_args key '{kwarg}', "
                    f"must be one of: {download_args}"
                )

    def _get_transfer_future(self, transfer_id, call_args):
        meta = ProcessPoolTransferMeta(
            call_args=call_args, transfer_id=transfer_id
        )
        future = ProcessPoolTransferFuture(
            monitor=self._transfer_monitor, meta=meta
        )
        return future

    def _start_transfer_monitor_manager(self):
        logger.debug('Starting the TransferMonitorManager.')
        self._manager = TransferMonitorManager()
        # We do not want Ctrl-C's to cause the manager to shut down
        # immediately as worker processes will still need to communicate with
        # it when they are shutting down. So instead we ignore Ctrl-C and let
        # the manager be explicitly shut down when shutting down the
        # downloader.
        self._manager.start(_add_ignore_handler_for_interrupts)
        self._transfer_monitor = self._manager.TransferMonitor()

    def _start_submitter(self):
        logger.debug('Starting the GetObjectSubmitter.')
        self._submitter = GetObjectSubmitter(
            transfer_config=self._transfer_config,
            client_factory=self._client_factory,
            transfer_monitor=self._transfer_monitor,
            osutil=self._osutil,
            download_request_queue=self._download_request_queue,
            worker_queue=self._worker_queue,
        )
        self._submitter.start()

    def _start_get_object_workers(self):
        logger.debug(
            'Starting %s GetObjectWorkers.',
            self._transfer_config.max_request_processes,
        )
        for _ in range(self._transfer_config.max_request_processes):
            worker = GetObjectWorker(
                queue=self._worker_queue,
                client_factory=self._client_factory,
                transfer_monitor=self._transfer_monitor,
                osutil=self._osutil,
            )
            worker.start()
            self._workers.append(worker)

    def _shutdown_if_needed(self):
        with self._start_lock:
            if self._started:
                self._shutdown()

    def _shutdown(self):
        self._shutdown_submitter()
        self._shutdown_get_object_workers()
        self._shutdown_transfer_monitor_manager()
        self._started = False

    def _shutdown_transfer_monitor_manager(self):
        logger.debug('Shutting down the TransferMonitorManager.')
        self._manager.shutdown()

    def _shutdown_submitter(self):
        logger.debug('Shutting down the GetObjectSubmitter.')
        self._download_request_queue.put(SHUTDOWN_SIGNAL)
        self._submitter.join()

    def _shutdown_get_object_workers(self):
        logger.debug('Shutting down the GetObjectWorkers.')
        for _ in self._workers:
            self._worker_queue.put(SHUTDOWN_SIGNAL)
        for worker in self._workers:
            worker.join()

class ProcessPoolTransferFuture(BaseTransferFuture):
    def __init__(self, monitor, meta):
        """The future associated to a submitted process pool transfer request

        :type monitor: TransferMonitor
        :param monitor: The monitor associated to the process pool downloader

        :type meta: ProcessPoolTransferMeta
        :param meta: The metadata associated to the request. This object
            is visible to the requester.
        """
        self._monitor = monitor
        self._meta = meta

    @property
    def meta(self):
        return self._meta

    def done(self):
        return self._monitor.is_done(self._meta.transfer_id)

    def result(self):
        try:
            return self._monitor.poll_for_result(self._meta.transfer_id)
        except KeyboardInterrupt:
            # For the multiprocessing Manager, a thread is given a single
            # connection to reuse in communicating between the thread in the
            # main process and the Manager's process. If a Ctrl-C happens when
            # polling for the result, it will make the main thread stop trying
            # to receive from the connection, but the Manager process will not
            # know that the main process has stopped trying to receive and
            # will not close the connection. As a result, if another message
            # is sent to the Manager process, the listener in the Manager
            # process will not process the new message as it is still trying
            # to process the previous message (that was Ctrl-C'd) and thus
            # cause the thread in the main process to hang on its send.
            # The only way around this is to create a new connection and send
            # messages from that new connection instead.
            self._monitor._connect()
            self.cancel()
            raise

    def cancel(self):
        self._monitor.notify_exception(
            self._meta.transfer_id, CancelledError()
        )

class ProcessPoolTransferMeta(BaseTransferMeta):
    """Holds metadata about the ProcessPoolTransferFuture"""

    def __init__(self, transfer_id, call_args):
        self._transfer_id = transfer_id
        self._call_args = call_args
        self._user_context = {}

    @property
    def call_args(self):
        return self._call_args

    @property
    def transfer_id(self):
        return self._transfer_id

    @property
    def user_context(self):
        return self._user_context

class ClientFactory:
    def __init__(self, client_kwargs=None):
        """Creates S3 clients for processes

        Botocore sessions and clients are not pickleable so they cannot be
        inherited across Process boundaries. Instead, they must be
        instantiated once a process is running.
        """
        self._client_kwargs = client_kwargs
        if self._client_kwargs is None:
            self._client_kwargs = {}

        client_config = deepcopy(self._client_kwargs.get('config', Config()))
        if not client_config.user_agent_extra:
            client_config.user_agent_extra = PROCESS_USER_AGENT
        else:
            client_config.user_agent_extra += " " + PROCESS_USER_AGENT
        self._client_kwargs['config'] = client_config

    def create_client(self):
        """Create a botocore S3 client"""
        return botocore.session.Session().create_client(
            's3', **self._client_kwargs
        )
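
# Illustrative sketch only (the region value is an arbitrary example): each
# worker process builds its own client after it has been spawned, e.g.
#
#     factory = ClientFactory({'region_name': 'us-west-2'})
#     client = factory.create_client()  # called from within Process.run()
#
# so no botocore session or client ever has to cross a process boundary.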

class TransferMonitor:
    def __init__(self):
        """Monitors transfers for cross-process communication

        Notifications can be sent to the monitor and information can be
        retrieved from the monitor for a particular transfer. This abstraction
        is run in a ``multiprocessing.managers.BaseManager`` in order to be
        shared across processes.
        """
        # TODO: Add logic that removes the TransferState if the transfer is
        # marked as done and the reference to the future is no longer being
        # held onto. Without this logic, this dictionary will continue to
        # grow in size with no limit.
        self._transfer_states = {}
        self._id_count = 0
        self._init_lock = threading.Lock()

    def notify_new_transfer(self):
        with self._init_lock:
            transfer_id = self._id_count
            self._transfer_states[transfer_id] = TransferState()
            self._id_count += 1
            return transfer_id

    def is_done(self, transfer_id):
        """Determine whether a particular transfer is complete

        :param transfer_id: Unique identifier for the transfer
        :return: True, if done. False, otherwise.
        """
        return self._transfer_states[transfer_id].done

    def notify_done(self, transfer_id):
        """Notify that a particular transfer is complete

        :param transfer_id: Unique identifier for the transfer
        """
        self._transfer_states[transfer_id].set_done()

    def poll_for_result(self, transfer_id):
        """Poll for the result of a transfer

        :param transfer_id: Unique identifier for the transfer
        :return: If the transfer succeeded, it will return the result. If the
            transfer failed, it will raise the exception associated to the
            failure.
        """
        self._transfer_states[transfer_id].wait_till_done()
        exception = self._transfer_states[transfer_id].exception
        if exception:
            raise exception
        return None

    def notify_exception(self, transfer_id, exception):
        """Notify that an exception was encountered for a transfer

        :param transfer_id: Unique identifier for the transfer
        :param exception: The exception encountered for that transfer
        """
        # TODO: Not all exceptions are pickleable so if we are running
        # this in a multiprocessing.BaseManager we will want to
        # make sure to update this signature to ensure pickleability of the
        # arguments or have the ProxyObject do the serialization.
        self._transfer_states[transfer_id].exception = exception

    def notify_cancel_all_in_progress(self):
        for transfer_state in self._transfer_states.values():
            if not transfer_state.done:
                transfer_state.exception = CancelledError()

    def get_exception(self, transfer_id):
        """Retrieve the exception encountered for the transfer

        :param transfer_id: Unique identifier for the transfer
        :return: The exception encountered for that transfer. Otherwise
            if there were no exceptions, returns None.
        """
        return self._transfer_states[transfer_id].exception

    def notify_expected_jobs_to_complete(self, transfer_id, num_jobs):
        """Notify the number of jobs expected for a transfer

        :param transfer_id: Unique identifier for the transfer
        :param num_jobs: The number of jobs to complete the transfer
        """
        self._transfer_states[transfer_id].jobs_to_complete = num_jobs

    def notify_job_complete(self, transfer_id):
        """Notify that a single job is completed for a transfer

        :param transfer_id: Unique identifier for the transfer
        :return: The number of jobs remaining to complete the transfer
        """
        return self._transfer_states[transfer_id].decrement_jobs_to_complete()

class TransferState:
    """Represents the current state of an individual transfer"""

    # NOTE: Ideally the TransferState object would be used directly by the
    # various different abstractions in the ProcessPoolDownloader and remove
    # the need for the TransferMonitor. However, it would then impose the
    # constraint that two hops are required to make or get any changes in the
    # state of a transfer across processes: one hop to get a proxy object for
    # the TransferState and then a second hop to communicate calling the
    # specific TransferState method.
    def __init__(self):
        self._exception = None
        self._done_event = threading.Event()
        self._job_lock = threading.Lock()
        self._jobs_to_complete = 0

    @property
    def done(self):
        return self._done_event.is_set()

    def set_done(self):
        self._done_event.set()

    def wait_till_done(self):
        self._done_event.wait(MAXINT)

    @property
    def exception(self):
        return self._exception

    @exception.setter
    def exception(self, val):
        self._exception = val

    @property
    def jobs_to_complete(self):
        return self._jobs_to_complete

    @jobs_to_complete.setter
    def jobs_to_complete(self, val):
        self._jobs_to_complete = val

    def decrement_jobs_to_complete(self):
        with self._job_lock:
            self._jobs_to_complete -= 1
            return self._jobs_to_complete
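
# Worked example (hypothetical, for illustration only): a ranged download
# submitted as three GetObjectJobs has jobs_to_complete set to 3 by the
# submitter via the TransferMonitor. Each worker calls notify_job_complete()
# when its job finishes, which decrements the counter (3 -> 2 -> 1 -> 0); the
# worker that observes 0 renames the temporary file into place and marks the
# transfer as done.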

class TransferMonitorManager(BaseManager):
    pass


TransferMonitorManager.register('TransferMonitor', TransferMonitor)

class BaseS3TransferProcess(multiprocessing.Process):
    def __init__(self, client_factory):
        super().__init__()
        self._client_factory = client_factory
        self._client = None

    def run(self):
        # Clients are not pickleable so their instantiation cannot happen
        # in the __init__ for processes that are created under the
        # spawn method.
        self._client = self._client_factory.create_client()
        with ignore_ctrl_c():
            # By default these processes are run as child processes of the
            # main process. Any Ctrl-C encountered in the main process is
            # propagated to the child process and can interrupt it at any
            # time. To avoid any potentially bad states caused by an
            # interrupt (i.e. a transfer failing to signal that it is done or
            # the communication protocol becoming out of sync with the
            # TransferMonitor), we ignore all Ctrl-C's and allow the main
            # process to notify these child processes when to stop processing
            # jobs.
            self._do_run()

    def _do_run(self):
        raise NotImplementedError('_do_run()')

class GetObjectSubmitter(BaseS3TransferProcess):
    def __init__(
        self,
        transfer_config,
        client_factory,
        transfer_monitor,
        osutil,
        download_request_queue,
        worker_queue,
    ):
        """Submit GetObjectJobs to fulfill a download file request

        :param transfer_config: Configuration for transfers.
        :param client_factory: ClientFactory for creating S3 clients.
        :param transfer_monitor: Monitor for notifying and retrieving state
            of transfer.
        :param osutil: OSUtils object to use for os-related behavior when
            performing the transfer.
        :param download_request_queue: Queue to retrieve download file
            requests.
        :param worker_queue: Queue to submit GetObjectJobs for workers
            to perform.
        """
        super().__init__(client_factory)
        self._transfer_config = transfer_config
        self._transfer_monitor = transfer_monitor
        self._osutil = osutil
        self._download_request_queue = download_request_queue
        self._worker_queue = worker_queue

    def _do_run(self):
        while True:
            download_file_request = self._download_request_queue.get()
            if download_file_request == SHUTDOWN_SIGNAL:
                logger.debug('Submitter shutdown signal received.')
                return
            try:
                self._submit_get_object_jobs(download_file_request)
            except Exception as e:
                logger.debug(
                    'Exception caught when submitting jobs for '
                    'download file request %s: %s',
                    download_file_request,
                    e,
                    exc_info=True,
                )
                self._transfer_monitor.notify_exception(
                    download_file_request.transfer_id, e
                )
                self._transfer_monitor.notify_done(
                    download_file_request.transfer_id
                )

    def _submit_get_object_jobs(self, download_file_request):
        size = self._get_size(download_file_request)
        temp_filename = self._allocate_temp_file(download_file_request, size)
        if size < self._transfer_config.multipart_threshold:
            self._submit_single_get_object_job(
                download_file_request, temp_filename
            )
        else:
            self._submit_ranged_get_object_jobs(
                download_file_request, temp_filename, size
            )

    def _get_size(self, download_file_request):
        expected_size = download_file_request.expected_size
        if expected_size is None:
            expected_size = self._client.head_object(
                Bucket=download_file_request.bucket,
                Key=download_file_request.key,
                **download_file_request.extra_args,
            )['ContentLength']
        return expected_size

    def _allocate_temp_file(self, download_file_request, size):
        temp_filename = self._osutil.get_temp_filename(
            download_file_request.filename
        )
        self._osutil.allocate(temp_filename, size)
        return temp_filename

    def _submit_single_get_object_job(
        self, download_file_request, temp_filename
    ):
        self._notify_jobs_to_complete(download_file_request.transfer_id, 1)
        self._submit_get_object_job(
            transfer_id=download_file_request.transfer_id,
            bucket=download_file_request.bucket,
            key=download_file_request.key,
            temp_filename=temp_filename,
            offset=0,
            extra_args=download_file_request.extra_args,
            filename=download_file_request.filename,
        )

    def _submit_ranged_get_object_jobs(
        self, download_file_request, temp_filename, size
    ):
        part_size = self._transfer_config.multipart_chunksize
        num_parts = calculate_num_parts(size, part_size)
        self._notify_jobs_to_complete(
            download_file_request.transfer_id, num_parts
        )
        for i in range(num_parts):
            offset = i * part_size
            range_parameter = calculate_range_parameter(
                part_size, i, num_parts
            )
            get_object_kwargs = {'Range': range_parameter}
            get_object_kwargs.update(download_file_request.extra_args)
            self._submit_get_object_job(
                transfer_id=download_file_request.transfer_id,
                bucket=download_file_request.bucket,
                key=download_file_request.key,
                temp_filename=temp_filename,
                offset=offset,
                extra_args=get_object_kwargs,
                filename=download_file_request.filename,
            )

    def _submit_get_object_job(self, **get_object_job_kwargs):
        self._worker_queue.put(GetObjectJob(**get_object_job_kwargs))

    def _notify_jobs_to_complete(self, transfer_id, jobs_to_complete):
        logger.debug(
            'Notifying %s job(s) to complete for transfer_id %s.',
            jobs_to_complete,
            transfer_id,
        )
        self._transfer_monitor.notify_expected_jobs_to_complete(
            transfer_id, jobs_to_complete
        )

class GetObjectWorker(BaseS3TransferProcess):
    # TODO: It may make sense to expose these class variables as configuration
    # options if users want to tweak them.
    _MAX_ATTEMPTS = 5
    _IO_CHUNKSIZE = 2 * MB

    def __init__(self, queue, client_factory, transfer_monitor, osutil):
        """Fulfills GetObjectJobs

        Downloads the S3 object, writes it to the specified file, and
        renames the file to its final location if it completes the final
        job for a particular transfer.

        :param queue: Queue for retrieving GetObjectJob's
        :param client_factory: ClientFactory for creating S3 clients
        :param transfer_monitor: Monitor for notifying
        :param osutil: OSUtils object to use for os-related behavior when
            performing the transfer.
        """
        super().__init__(client_factory)
        self._queue = queue
        self._client_factory = client_factory
        self._transfer_monitor = transfer_monitor
        self._osutil = osutil

    def _do_run(self):
        while True:
            job = self._queue.get()
            if job == SHUTDOWN_SIGNAL:
                logger.debug('Worker shutdown signal received.')
                return
            if not self._transfer_monitor.get_exception(job.transfer_id):
                self._run_get_object_job(job)
            else:
                logger.debug(
                    'Skipping get object job %s because there was a previous '
                    'exception.',
                    job,
                )
            remaining = self._transfer_monitor.notify_job_complete(
                job.transfer_id
            )
            logger.debug(
                '%s jobs remaining for transfer_id %s.',
                remaining,
                job.transfer_id,
            )
            if not remaining:
                self._finalize_download(
                    job.transfer_id, job.temp_filename, job.filename
                )

    def _run_get_object_job(self, job):
        try:
            self._do_get_object(
                bucket=job.bucket,
                key=job.key,
                temp_filename=job.temp_filename,
                extra_args=job.extra_args,
                offset=job.offset,
            )
        except Exception as e:
            logger.debug(
                'Exception caught when downloading object for '
                'get object job %s: %s',
                job,
                e,
                exc_info=True,
            )
            self._transfer_monitor.notify_exception(job.transfer_id, e)

    def _do_get_object(self, bucket, key, extra_args, temp_filename, offset):
        last_exception = None
        for i in range(self._MAX_ATTEMPTS):
            try:
                response = self._client.get_object(
                    Bucket=bucket, Key=key, **extra_args
                )
                self._write_to_file(temp_filename, offset, response['Body'])
                return
            except S3_RETRYABLE_DOWNLOAD_ERRORS as e:
                logger.debug(
                    'Retryable exception caught (%s), '
                    'retrying request (attempt %s / %s)',
                    e,
                    i + 1,
                    self._MAX_ATTEMPTS,
                    exc_info=True,
                )
                last_exception = e
        raise RetriesExceededError(last_exception)

    def _write_to_file(self, filename, offset, body):
        with open(filename, 'rb+') as f:
            f.seek(offset)
            chunks = iter(lambda: body.read(self._IO_CHUNKSIZE), b'')
            for chunk in chunks:
                f.write(chunk)

    def _finalize_download(self, transfer_id, temp_filename, filename):
        if self._transfer_monitor.get_exception(transfer_id):
            self._osutil.remove_file(temp_filename)
        else:
            self._do_file_rename(transfer_id, temp_filename, filename)
        self._transfer_monitor.notify_done(transfer_id)

    def _do_file_rename(self, transfer_id, temp_filename, filename):
        try:
            self._osutil.rename_file(temp_filename, filename)
        except Exception as e:
            self._transfer_monitor.notify_exception(transfer_id, e)
            self._osutil.remove_file(temp_filename)