test_processpool.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"). You
  4. # may not use this file except in compliance with the License. A copy of
  5. # the License is located at
  6. #
  7. # http://aws.amazon.com/apache2.0/
  8. #
  9. # or in the "license" file accompanying this file. This file is
  10. # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
  11. # ANY KIND, either express or implied. See the License for the specific
  12. # language governing permissions and limitations under the License.
  13. import glob
  14. import os
  15. from io import BytesIO
  16. from multiprocessing.managers import BaseManager
  17. import botocore.exceptions
  18. import botocore.session
  19. from botocore.stub import Stubber
  20. from s3transfer.exceptions import CancelledError
  21. from s3transfer.processpool import ProcessPoolDownloader, ProcessTransferConfig
  22. from tests import FileCreator, mock, unittest
  23. class StubbedClient:
  24. def __init__(self):
  25. self._client = botocore.session.get_session().create_client(
  26. 's3',
  27. 'us-west-2',
  28. aws_access_key_id='foo',
  29. aws_secret_access_key='bar',
  30. )
  31. self._stubber = Stubber(self._client)
  32. self._stubber.activate()
  33. self._caught_stubber_errors = []
  34. def get_object(self, **kwargs):
  35. return self._client.get_object(**kwargs)
  36. def head_object(self, **kwargs):
  37. return self._client.head_object(**kwargs)
  38. def add_response(self, *args, **kwargs):
  39. self._stubber.add_response(*args, **kwargs)
  40. def add_client_error(self, *args, **kwargs):
  41. self._stubber.add_client_error(*args, **kwargs)
  42. class StubbedClientManager(BaseManager):
  43. pass
  44. StubbedClientManager.register('StubbedClient', StubbedClient)
  45. # Ideally a Mock would be used here. However, they cannot be pickled
  46. # for Windows. So instead we define a factory class at the module level that
  47. # can return a stubbed client we initialized in the setUp.
  48. class StubbedClientFactory:
  49. def __init__(self, stubbed_client):
  50. self._stubbed_client = stubbed_client
  51. def __call__(self, *args, **kwargs):
  52. # The __call__ is defined so we can provide an instance of the
  53. # StubbedClientFactory to mock.patch() and have the instance be
  54. # returned when the patched class is instantiated.
  55. return self
  56. def create_client(self):
  57. return self._stubbed_client
  58. class TestProcessPoolDownloader(unittest.TestCase):
  59. def setUp(self):
  60. # The stubbed client needs to run in a manager to be shared across
  61. # processes and have it properly consume the stubbed response across
  62. # processes.
  63. self.manager = StubbedClientManager()
  64. self.manager.start()
  65. self.stubbed_client = self.manager.StubbedClient()
  66. self.stubbed_client_factory = StubbedClientFactory(self.stubbed_client)
  67. self.client_factory_patch = mock.patch(
  68. 's3transfer.processpool.ClientFactory', self.stubbed_client_factory
  69. )
  70. self.client_factory_patch.start()
  71. self.files = FileCreator()
  72. self.config = ProcessTransferConfig(max_request_processes=1)
  73. self.downloader = ProcessPoolDownloader(config=self.config)
  74. self.bucket = 'mybucket'
  75. self.key = 'mykey'
  76. self.filename = self.files.full_path('filename')
  77. self.remote_contents = b'my content'
  78. self.stream = BytesIO(self.remote_contents)
  79. def tearDown(self):
  80. self.manager.shutdown()
  81. self.client_factory_patch.stop()
  82. self.files.remove_all()
  83. def assert_contents(self, filename, expected_contents):
  84. self.assertTrue(os.path.exists(filename))
  85. with open(filename, 'rb') as f:
  86. self.assertEqual(f.read(), expected_contents)
  87. def test_download_file(self):
  88. self.stubbed_client.add_response(
  89. 'head_object', {'ContentLength': len(self.remote_contents)}
  90. )
  91. self.stubbed_client.add_response('get_object', {'Body': self.stream})
  92. with self.downloader:
  93. self.downloader.download_file(self.bucket, self.key, self.filename)
  94. self.assert_contents(self.filename, self.remote_contents)
  95. def test_download_multiple_files(self):
  96. self.stubbed_client.add_response('get_object', {'Body': self.stream})
  97. self.stubbed_client.add_response(
  98. 'get_object', {'Body': BytesIO(self.remote_contents)}
  99. )
  100. with self.downloader:
  101. self.downloader.download_file(
  102. self.bucket,
  103. self.key,
  104. self.filename,
  105. expected_size=len(self.remote_contents),
  106. )
  107. other_file = self.files.full_path('filename2')
  108. self.downloader.download_file(
  109. self.bucket,
  110. self.key,
  111. other_file,
  112. expected_size=len(self.remote_contents),
  113. )
  114. self.assert_contents(self.filename, self.remote_contents)
  115. self.assert_contents(other_file, self.remote_contents)
  116. def test_download_file_ranged_download(self):
  117. half_of_content_length = int(len(self.remote_contents) / 2)
  118. self.stubbed_client.add_response(
  119. 'head_object', {'ContentLength': len(self.remote_contents)}
  120. )
  121. self.stubbed_client.add_response(
  122. 'get_object',
  123. {'Body': BytesIO(self.remote_contents[:half_of_content_length])},
  124. )
  125. self.stubbed_client.add_response(
  126. 'get_object',
  127. {'Body': BytesIO(self.remote_contents[half_of_content_length:])},
  128. )
  129. downloader = ProcessPoolDownloader(
  130. config=ProcessTransferConfig(
  131. multipart_chunksize=half_of_content_length,
  132. multipart_threshold=half_of_content_length,
  133. max_request_processes=1,
  134. )
  135. )
  136. with downloader:
  137. downloader.download_file(self.bucket, self.key, self.filename)
  138. self.assert_contents(self.filename, self.remote_contents)
  139. def test_download_file_extra_args(self):
  140. self.stubbed_client.add_response(
  141. 'head_object',
  142. {'ContentLength': len(self.remote_contents)},
  143. expected_params={
  144. 'Bucket': self.bucket,
  145. 'Key': self.key,
  146. 'VersionId': 'versionid',
  147. },
  148. )
  149. self.stubbed_client.add_response(
  150. 'get_object',
  151. {'Body': self.stream},
  152. expected_params={
  153. 'Bucket': self.bucket,
  154. 'Key': self.key,
  155. 'VersionId': 'versionid',
  156. },
  157. )
  158. with self.downloader:
  159. self.downloader.download_file(
  160. self.bucket,
  161. self.key,
  162. self.filename,
  163. extra_args={'VersionId': 'versionid'},
  164. )
  165. self.assert_contents(self.filename, self.remote_contents)
  166. def test_download_file_expected_size(self):
  167. self.stubbed_client.add_response('get_object', {'Body': self.stream})
  168. with self.downloader:
  169. self.downloader.download_file(
  170. self.bucket,
  171. self.key,
  172. self.filename,
  173. expected_size=len(self.remote_contents),
  174. )
  175. self.assert_contents(self.filename, self.remote_contents)
  176. def test_cleans_up_tempfile_on_failure(self):
  177. self.stubbed_client.add_client_error('get_object', 'NoSuchKey')
  178. with self.downloader:
  179. self.downloader.download_file(
  180. self.bucket,
  181. self.key,
  182. self.filename,
  183. expected_size=len(self.remote_contents),
  184. )
  185. self.assertFalse(os.path.exists(self.filename))
  186. # Any tempfile should have been erased as well
  187. possible_matches = glob.glob('%s*' % self.filename + os.extsep)
  188. self.assertEqual(possible_matches, [])
  189. def test_validates_extra_args(self):
  190. with self.downloader:
  191. with self.assertRaises(ValueError):
  192. self.downloader.download_file(
  193. self.bucket,
  194. self.key,
  195. self.filename,
  196. extra_args={'NotSupported': 'NotSupported'},
  197. )
  198. def test_result_with_success(self):
  199. self.stubbed_client.add_response('get_object', {'Body': self.stream})
  200. with self.downloader:
  201. future = self.downloader.download_file(
  202. self.bucket,
  203. self.key,
  204. self.filename,
  205. expected_size=len(self.remote_contents),
  206. )
  207. self.assertIsNone(future.result())
  208. def test_result_with_exception(self):
  209. self.stubbed_client.add_client_error('get_object', 'NoSuchKey')
  210. with self.downloader:
  211. future = self.downloader.download_file(
  212. self.bucket,
  213. self.key,
  214. self.filename,
  215. expected_size=len(self.remote_contents),
  216. )
  217. with self.assertRaises(botocore.exceptions.ClientError):
  218. future.result()
  219. def test_result_with_cancel(self):
  220. self.stubbed_client.add_response('get_object', {'Body': self.stream})
  221. with self.downloader:
  222. future = self.downloader.download_file(
  223. self.bucket,
  224. self.key,
  225. self.filename,
  226. expected_size=len(self.remote_contents),
  227. )
  228. future.cancel()
  229. with self.assertRaises(CancelledError):
  230. future.result()
  231. def test_shutdown_with_no_downloads(self):
  232. downloader = ProcessPoolDownloader()
  233. try:
  234. downloader.shutdown()
  235. except AttributeError:
  236. self.fail(
  237. 'The downloader should be able to be shutdown even though '
  238. 'the downloader was never started.'
  239. )
  240. def test_shutdown_with_no_downloads_and_ctrl_c(self):
  241. # Special shutdown logic happens if a KeyboardInterrupt is raised in
  242. # the context manager. However, this logic can not happen if the
  243. # downloader was never started. So a KeyboardInterrupt should be
  244. # the only exception propagated.
  245. with self.assertRaises(KeyboardInterrupt):
  246. with self.downloader:
  247. raise KeyboardInterrupt()