test_s3transfer.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780
  1. # Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the 'License'). You
  4. # may not use this file except in compliance with the License. A copy of
  5. # the License is located at
  6. #
  7. # http://aws.amazon.com/apache2.0/
  8. #
  9. # or in the 'license' file accompanying this file. This file is
  10. # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
  11. # ANY KIND, either express or implied. See the License for the specific
  12. # language governing permissions and limitations under the License.
  13. import os
  14. import shutil
  15. import socket
  16. import tempfile
  17. from concurrent import futures
  18. from contextlib import closing
  19. from io import BytesIO, StringIO
  20. from s3transfer import (
  21. MultipartDownloader,
  22. MultipartUploader,
  23. OSUtils,
  24. QueueShutdownError,
  25. ReadFileChunk,
  26. S3Transfer,
  27. ShutdownQueue,
  28. StreamReaderProgress,
  29. TransferConfig,
  30. disable_upload_callbacks,
  31. enable_upload_callbacks,
  32. random_file_extension,
  33. )
  34. from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError
  35. from tests import mock, unittest
  36. class InMemoryOSLayer(OSUtils):
  37. def __init__(self, filemap):
  38. self.filemap = filemap
  39. def get_file_size(self, filename):
  40. return len(self.filemap[filename])
  41. def open_file_chunk_reader(self, filename, start_byte, size, callback):
  42. return closing(BytesIO(self.filemap[filename]))
  43. def open(self, filename, mode):
  44. if 'wb' in mode:
  45. fileobj = BytesIO()
  46. self.filemap[filename] = fileobj
  47. return closing(fileobj)
  48. else:
  49. return closing(self.filemap[filename])
  50. def remove_file(self, filename):
  51. if filename in self.filemap:
  52. del self.filemap[filename]
  53. def rename_file(self, current_filename, new_filename):
  54. if current_filename in self.filemap:
  55. self.filemap[new_filename] = self.filemap.pop(current_filename)
  56. class SequentialExecutor:
  57. def __init__(self, max_workers):
  58. pass
  59. def __enter__(self):
  60. return self
  61. def __exit__(self, *args, **kwargs):
  62. pass
  63. # The real map() interface actually takes *args, but we specifically do
  64. # _not_ use this interface.
  65. def map(self, function, args):
  66. results = []
  67. for arg in args:
  68. results.append(function(arg))
  69. return results
  70. def submit(self, function):
  71. future = futures.Future()
  72. future.set_result(function())
  73. return future
  74. class TestOSUtils(unittest.TestCase):
  75. def setUp(self):
  76. self.tempdir = tempfile.mkdtemp()
  77. def tearDown(self):
  78. shutil.rmtree(self.tempdir)
  79. def test_get_file_size(self):
  80. with mock.patch('os.path.getsize') as m:
  81. OSUtils().get_file_size('myfile')
  82. m.assert_called_with('myfile')
  83. def test_open_file_chunk_reader(self):
  84. with mock.patch('s3transfer.ReadFileChunk') as m:
  85. OSUtils().open_file_chunk_reader('myfile', 0, 100, None)
  86. m.from_filename.assert_called_with(
  87. 'myfile', 0, 100, None, enable_callback=False
  88. )
  89. def test_open_file(self):
  90. fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w')
  91. self.assertTrue(hasattr(fileobj, 'write'))
  92. def test_remove_file_ignores_errors(self):
  93. with mock.patch('os.remove') as remove:
  94. remove.side_effect = OSError('fake error')
  95. OSUtils().remove_file('foo')
  96. remove.assert_called_with('foo')
  97. def test_remove_file_proxies_remove_file(self):
  98. with mock.patch('os.remove') as remove:
  99. OSUtils().remove_file('foo')
  100. remove.assert_called_with('foo')
  101. def test_rename_file(self):
  102. with mock.patch('s3transfer.compat.rename_file') as rename_file:
  103. OSUtils().rename_file('foo', 'newfoo')
  104. rename_file.assert_called_with('foo', 'newfoo')
  105. class TestReadFileChunk(unittest.TestCase):
  106. def setUp(self):
  107. self.tempdir = tempfile.mkdtemp()
  108. def tearDown(self):
  109. shutil.rmtree(self.tempdir)
  110. def test_read_entire_chunk(self):
  111. filename = os.path.join(self.tempdir, 'foo')
  112. with open(filename, 'wb') as f:
  113. f.write(b'onetwothreefourfivesixseveneightnineten')
  114. chunk = ReadFileChunk.from_filename(
  115. filename, start_byte=0, chunk_size=3
  116. )
  117. self.assertEqual(chunk.read(), b'one')
  118. self.assertEqual(chunk.read(), b'')
  119. def test_read_with_amount_size(self):
  120. filename = os.path.join(self.tempdir, 'foo')
  121. with open(filename, 'wb') as f:
  122. f.write(b'onetwothreefourfivesixseveneightnineten')
  123. chunk = ReadFileChunk.from_filename(
  124. filename, start_byte=11, chunk_size=4
  125. )
  126. self.assertEqual(chunk.read(1), b'f')
  127. self.assertEqual(chunk.read(1), b'o')
  128. self.assertEqual(chunk.read(1), b'u')
  129. self.assertEqual(chunk.read(1), b'r')
  130. self.assertEqual(chunk.read(1), b'')
  131. def test_reset_stream_emulation(self):
  132. filename = os.path.join(self.tempdir, 'foo')
  133. with open(filename, 'wb') as f:
  134. f.write(b'onetwothreefourfivesixseveneightnineten')
  135. chunk = ReadFileChunk.from_filename(
  136. filename, start_byte=11, chunk_size=4
  137. )
  138. self.assertEqual(chunk.read(), b'four')
  139. chunk.seek(0)
  140. self.assertEqual(chunk.read(), b'four')
  141. def test_read_past_end_of_file(self):
  142. filename = os.path.join(self.tempdir, 'foo')
  143. with open(filename, 'wb') as f:
  144. f.write(b'onetwothreefourfivesixseveneightnineten')
  145. chunk = ReadFileChunk.from_filename(
  146. filename, start_byte=36, chunk_size=100000
  147. )
  148. self.assertEqual(chunk.read(), b'ten')
  149. self.assertEqual(chunk.read(), b'')
  150. self.assertEqual(len(chunk), 3)
  151. def test_tell_and_seek(self):
  152. filename = os.path.join(self.tempdir, 'foo')
  153. with open(filename, 'wb') as f:
  154. f.write(b'onetwothreefourfivesixseveneightnineten')
  155. chunk = ReadFileChunk.from_filename(
  156. filename, start_byte=36, chunk_size=100000
  157. )
  158. self.assertEqual(chunk.tell(), 0)
  159. self.assertEqual(chunk.read(), b'ten')
  160. self.assertEqual(chunk.tell(), 3)
  161. chunk.seek(0)
  162. self.assertEqual(chunk.tell(), 0)
  163. def test_file_chunk_supports_context_manager(self):
  164. filename = os.path.join(self.tempdir, 'foo')
  165. with open(filename, 'wb') as f:
  166. f.write(b'abc')
  167. with ReadFileChunk.from_filename(
  168. filename, start_byte=0, chunk_size=2
  169. ) as chunk:
  170. val = chunk.read()
  171. self.assertEqual(val, b'ab')
  172. def test_iter_is_always_empty(self):
  173. # This tests the workaround for the httplib bug (see
  174. # the source for more info).
  175. filename = os.path.join(self.tempdir, 'foo')
  176. open(filename, 'wb').close()
  177. chunk = ReadFileChunk.from_filename(
  178. filename, start_byte=0, chunk_size=10
  179. )
  180. self.assertEqual(list(chunk), [])
  181. class TestReadFileChunkWithCallback(TestReadFileChunk):
  182. def setUp(self):
  183. super().setUp()
  184. self.filename = os.path.join(self.tempdir, 'foo')
  185. with open(self.filename, 'wb') as f:
  186. f.write(b'abc')
  187. self.amounts_seen = []
  188. def callback(self, amount):
  189. self.amounts_seen.append(amount)
  190. def test_callback_is_invoked_on_read(self):
  191. chunk = ReadFileChunk.from_filename(
  192. self.filename, start_byte=0, chunk_size=3, callback=self.callback
  193. )
  194. chunk.read(1)
  195. chunk.read(1)
  196. chunk.read(1)
  197. self.assertEqual(self.amounts_seen, [1, 1, 1])
  198. def test_callback_can_be_disabled(self):
  199. chunk = ReadFileChunk.from_filename(
  200. self.filename, start_byte=0, chunk_size=3, callback=self.callback
  201. )
  202. chunk.disable_callback()
  203. # Now reading from the ReadFileChunk should not invoke
  204. # the callback.
  205. chunk.read()
  206. self.assertEqual(self.amounts_seen, [])
  207. def test_callback_will_also_be_triggered_by_seek(self):
  208. chunk = ReadFileChunk.from_filename(
  209. self.filename, start_byte=0, chunk_size=3, callback=self.callback
  210. )
  211. chunk.read(2)
  212. chunk.seek(0)
  213. chunk.read(2)
  214. chunk.seek(1)
  215. chunk.read(2)
  216. self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2])
  217. class TestStreamReaderProgress(unittest.TestCase):
  218. def test_proxies_to_wrapped_stream(self):
  219. original_stream = StringIO('foobarbaz')
  220. wrapped = StreamReaderProgress(original_stream)
  221. self.assertEqual(wrapped.read(), 'foobarbaz')
  222. def test_callback_invoked(self):
  223. amounts_seen = []
  224. def callback(amount):
  225. amounts_seen.append(amount)
  226. original_stream = StringIO('foobarbaz')
  227. wrapped = StreamReaderProgress(original_stream, callback)
  228. self.assertEqual(wrapped.read(), 'foobarbaz')
  229. self.assertEqual(amounts_seen, [9])
  230. class TestMultipartUploader(unittest.TestCase):
  231. def test_multipart_upload_uses_correct_client_calls(self):
  232. client = mock.Mock()
  233. uploader = MultipartUploader(
  234. client,
  235. TransferConfig(),
  236. InMemoryOSLayer({'filename': b'foobar'}),
  237. SequentialExecutor,
  238. )
  239. client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
  240. client.upload_part.return_value = {'ETag': 'first'}
  241. uploader.upload_file('filename', 'bucket', 'key', None, {})
  242. # We need to check both the sequence of calls (create/upload/complete)
  243. # as well as the params passed between the calls, including
  244. # 1. The upload_id was plumbed through
  245. # 2. The collected etags were added to the complete call.
  246. client.create_multipart_upload.assert_called_with(
  247. Bucket='bucket', Key='key'
  248. )
  249. # Should be two parts.
  250. client.upload_part.assert_called_with(
  251. Body=mock.ANY,
  252. Bucket='bucket',
  253. UploadId='upload_id',
  254. Key='key',
  255. PartNumber=1,
  256. )
  257. client.complete_multipart_upload.assert_called_with(
  258. MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]},
  259. Bucket='bucket',
  260. UploadId='upload_id',
  261. Key='key',
  262. )
  263. def test_multipart_upload_injects_proper_kwargs(self):
  264. client = mock.Mock()
  265. uploader = MultipartUploader(
  266. client,
  267. TransferConfig(),
  268. InMemoryOSLayer({'filename': b'foobar'}),
  269. SequentialExecutor,
  270. )
  271. client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
  272. client.upload_part.return_value = {'ETag': 'first'}
  273. extra_args = {
  274. 'SSECustomerKey': 'fakekey',
  275. 'SSECustomerAlgorithm': 'AES256',
  276. 'StorageClass': 'REDUCED_REDUNDANCY',
  277. }
  278. uploader.upload_file('filename', 'bucket', 'key', None, extra_args)
  279. client.create_multipart_upload.assert_called_with(
  280. Bucket='bucket',
  281. Key='key',
  282. # The initial call should inject all the storage class params.
  283. SSECustomerKey='fakekey',
  284. SSECustomerAlgorithm='AES256',
  285. StorageClass='REDUCED_REDUNDANCY',
  286. )
  287. client.upload_part.assert_called_with(
  288. Body=mock.ANY,
  289. Bucket='bucket',
  290. UploadId='upload_id',
  291. Key='key',
  292. PartNumber=1,
  293. # We only have to forward certain **extra_args in subsequent
  294. # UploadPart calls.
  295. SSECustomerKey='fakekey',
  296. SSECustomerAlgorithm='AES256',
  297. )
  298. client.complete_multipart_upload.assert_called_with(
  299. MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]},
  300. Bucket='bucket',
  301. UploadId='upload_id',
  302. Key='key',
  303. )
  304. def test_multipart_upload_is_aborted_on_error(self):
  305. # If the create_multipart_upload succeeds and any upload_part
  306. # fails, then abort_multipart_upload will be called.
  307. client = mock.Mock()
  308. uploader = MultipartUploader(
  309. client,
  310. TransferConfig(),
  311. InMemoryOSLayer({'filename': b'foobar'}),
  312. SequentialExecutor,
  313. )
  314. client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
  315. client.upload_part.side_effect = Exception(
  316. "Some kind of error occurred."
  317. )
  318. with self.assertRaises(S3UploadFailedError):
  319. uploader.upload_file('filename', 'bucket', 'key', None, {})
  320. client.abort_multipart_upload.assert_called_with(
  321. Bucket='bucket', Key='key', UploadId='upload_id'
  322. )
  323. class TestMultipartDownloader(unittest.TestCase):
  324. maxDiff = None
  325. def test_multipart_download_uses_correct_client_calls(self):
  326. client = mock.Mock()
  327. response_body = b'foobarbaz'
  328. client.get_object.return_value = {'Body': BytesIO(response_body)}
  329. downloader = MultipartDownloader(
  330. client, TransferConfig(), InMemoryOSLayer({}), SequentialExecutor
  331. )
  332. downloader.download_file(
  333. 'bucket', 'key', 'filename', len(response_body), {}
  334. )
  335. client.get_object.assert_called_with(
  336. Range='bytes=0-', Bucket='bucket', Key='key'
  337. )
  338. def test_multipart_download_with_multiple_parts(self):
  339. client = mock.Mock()
  340. response_body = b'foobarbaz'
  341. client.get_object.return_value = {'Body': BytesIO(response_body)}
  342. # For testing purposes, we're testing with a multipart threshold
  343. # of 4 bytes and a chunksize of 4 bytes. Given b'foobarbaz',
  344. # this should result in 3 calls. In python slices this would be:
  345. # r[0:4], r[4:8], r[8:9]. But the Range param will be slightly
  346. # different because they use inclusive ranges.
  347. config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
  348. downloader = MultipartDownloader(
  349. client, config, InMemoryOSLayer({}), SequentialExecutor
  350. )
  351. downloader.download_file(
  352. 'bucket', 'key', 'filename', len(response_body), {}
  353. )
  354. # We're storing these in **extra because the assertEqual
  355. # below is really about verifying we have the correct value
  356. # for the Range param.
  357. extra = {'Bucket': 'bucket', 'Key': 'key'}
  358. self.assertEqual(
  359. client.get_object.call_args_list,
  360. # Note these are inclusive ranges.
  361. [
  362. mock.call(Range='bytes=0-3', **extra),
  363. mock.call(Range='bytes=4-7', **extra),
  364. mock.call(Range='bytes=8-', **extra),
  365. ],
  366. )
  367. def test_retry_on_failures_from_stream_reads(self):
  368. # If we get an exception during a call to the response body's .read()
  369. # method, we should retry the request.
  370. client = mock.Mock()
  371. response_body = b'foobarbaz'
  372. stream_with_errors = mock.Mock()
  373. stream_with_errors.read.side_effect = [
  374. socket.error("fake error"),
  375. response_body,
  376. ]
  377. client.get_object.return_value = {'Body': stream_with_errors}
  378. config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
  379. downloader = MultipartDownloader(
  380. client, config, InMemoryOSLayer({}), SequentialExecutor
  381. )
  382. downloader.download_file(
  383. 'bucket', 'key', 'filename', len(response_body), {}
  384. )
  385. # We're storing these in **extra because the assertEqual
  386. # below is really about verifying we have the correct value
  387. # for the Range param.
  388. extra = {'Bucket': 'bucket', 'Key': 'key'}
  389. self.assertEqual(
  390. client.get_object.call_args_list,
  391. # The first call to range=0-3 fails because of the
  392. # side_effect above where we make the .read() raise a
  393. # socket.error.
  394. # The second call to range=0-3 then succeeds.
  395. [
  396. mock.call(Range='bytes=0-3', **extra),
  397. mock.call(Range='bytes=0-3', **extra),
  398. mock.call(Range='bytes=4-7', **extra),
  399. mock.call(Range='bytes=8-', **extra),
  400. ],
  401. )
  402. def test_exception_raised_on_exceeded_retries(self):
  403. client = mock.Mock()
  404. response_body = b'foobarbaz'
  405. stream_with_errors = mock.Mock()
  406. stream_with_errors.read.side_effect = socket.error("fake error")
  407. client.get_object.return_value = {'Body': stream_with_errors}
  408. config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
  409. downloader = MultipartDownloader(
  410. client, config, InMemoryOSLayer({}), SequentialExecutor
  411. )
  412. with self.assertRaises(RetriesExceededError):
  413. downloader.download_file(
  414. 'bucket', 'key', 'filename', len(response_body), {}
  415. )
  416. def test_io_thread_failure_triggers_shutdown(self):
  417. client = mock.Mock()
  418. response_body = b'foobarbaz'
  419. client.get_object.return_value = {'Body': BytesIO(response_body)}
  420. os_layer = mock.Mock()
  421. mock_fileobj = mock.MagicMock()
  422. mock_fileobj.__enter__.return_value = mock_fileobj
  423. mock_fileobj.write.side_effect = Exception("fake IO error")
  424. os_layer.open.return_value = mock_fileobj
  425. downloader = MultipartDownloader(
  426. client, TransferConfig(), os_layer, SequentialExecutor
  427. )
  428. # We're verifying that the exception raised from the IO future
  429. # propagates back up via download_file().
  430. with self.assertRaisesRegex(Exception, "fake IO error"):
  431. downloader.download_file(
  432. 'bucket', 'key', 'filename', len(response_body), {}
  433. )
  434. def test_download_futures_fail_triggers_shutdown(self):
  435. class FailedDownloadParts(SequentialExecutor):
  436. def __init__(self, max_workers):
  437. self.is_first = True
  438. def submit(self, function):
  439. future = futures.Future()
  440. if self.is_first:
  441. # This is the download_parts_thread.
  442. future.set_exception(
  443. Exception("fake download parts error")
  444. )
  445. self.is_first = False
  446. return future
  447. client = mock.Mock()
  448. response_body = b'foobarbaz'
  449. client.get_object.return_value = {'Body': BytesIO(response_body)}
  450. downloader = MultipartDownloader(
  451. client, TransferConfig(), InMemoryOSLayer({}), FailedDownloadParts
  452. )
  453. with self.assertRaisesRegex(Exception, "fake download parts error"):
  454. downloader.download_file(
  455. 'bucket', 'key', 'filename', len(response_body), {}
  456. )
  457. class TestS3Transfer(unittest.TestCase):
  458. def setUp(self):
  459. self.client = mock.Mock()
  460. self.random_file_patch = mock.patch('s3transfer.random_file_extension')
  461. self.random_file = self.random_file_patch.start()
  462. self.random_file.return_value = 'RANDOM'
  463. def tearDown(self):
  464. self.random_file_patch.stop()
  465. def test_callback_handlers_register_on_put_item(self):
  466. osutil = InMemoryOSLayer({'smallfile': b'foobar'})
  467. transfer = S3Transfer(self.client, osutil=osutil)
  468. transfer.upload_file('smallfile', 'bucket', 'key')
  469. events = self.client.meta.events
  470. events.register_first.assert_called_with(
  471. 'request-created.s3',
  472. disable_upload_callbacks,
  473. unique_id='s3upload-callback-disable',
  474. )
  475. events.register_last.assert_called_with(
  476. 'request-created.s3',
  477. enable_upload_callbacks,
  478. unique_id='s3upload-callback-enable',
  479. )
  480. def test_upload_below_multipart_threshold_uses_put_object(self):
  481. fake_files = {
  482. 'smallfile': b'foobar',
  483. }
  484. osutil = InMemoryOSLayer(fake_files)
  485. transfer = S3Transfer(self.client, osutil=osutil)
  486. transfer.upload_file('smallfile', 'bucket', 'key')
  487. self.client.put_object.assert_called_with(
  488. Bucket='bucket', Key='key', Body=mock.ANY
  489. )
  490. def test_extra_args_on_uploaded_passed_to_api_call(self):
  491. extra_args = {'ACL': 'public-read'}
  492. fake_files = {'smallfile': b'hello world'}
  493. osutil = InMemoryOSLayer(fake_files)
  494. transfer = S3Transfer(self.client, osutil=osutil)
  495. transfer.upload_file(
  496. 'smallfile', 'bucket', 'key', extra_args=extra_args
  497. )
  498. self.client.put_object.assert_called_with(
  499. Bucket='bucket', Key='key', Body=mock.ANY, ACL='public-read'
  500. )
  501. def test_uses_multipart_upload_when_over_threshold(self):
  502. with mock.patch('s3transfer.MultipartUploader') as uploader:
  503. fake_files = {
  504. 'smallfile': b'foobar',
  505. }
  506. osutil = InMemoryOSLayer(fake_files)
  507. config = TransferConfig(
  508. multipart_threshold=2, multipart_chunksize=2
  509. )
  510. transfer = S3Transfer(self.client, osutil=osutil, config=config)
  511. transfer.upload_file('smallfile', 'bucket', 'key')
  512. uploader.return_value.upload_file.assert_called_with(
  513. 'smallfile', 'bucket', 'key', None, {}
  514. )
  515. def test_uses_multipart_download_when_over_threshold(self):
  516. with mock.patch('s3transfer.MultipartDownloader') as downloader:
  517. osutil = InMemoryOSLayer({})
  518. over_multipart_threshold = 100 * 1024 * 1024
  519. transfer = S3Transfer(self.client, osutil=osutil)
  520. callback = mock.sentinel.CALLBACK
  521. self.client.head_object.return_value = {
  522. 'ContentLength': over_multipart_threshold,
  523. }
  524. transfer.download_file(
  525. 'bucket', 'key', 'filename', callback=callback
  526. )
  527. downloader.return_value.download_file.assert_called_with(
  528. # Note how we're downloading to a temporary random file.
  529. 'bucket',
  530. 'key',
  531. 'filename.RANDOM',
  532. over_multipart_threshold,
  533. {},
  534. callback,
  535. )
  536. def test_download_file_with_invalid_extra_args(self):
  537. below_threshold = 20
  538. osutil = InMemoryOSLayer({})
  539. transfer = S3Transfer(self.client, osutil=osutil)
  540. self.client.head_object.return_value = {
  541. 'ContentLength': below_threshold
  542. }
  543. with self.assertRaises(ValueError):
  544. transfer.download_file(
  545. 'bucket',
  546. 'key',
  547. '/tmp/smallfile',
  548. extra_args={'BadValue': 'foo'},
  549. )
  550. def test_upload_file_with_invalid_extra_args(self):
  551. osutil = InMemoryOSLayer({})
  552. transfer = S3Transfer(self.client, osutil=osutil)
  553. bad_args = {"WebsiteRedirectLocation": "/foo"}
  554. with self.assertRaises(ValueError):
  555. transfer.upload_file(
  556. 'bucket', 'key', '/tmp/smallfile', extra_args=bad_args
  557. )
  558. def test_download_file_fowards_extra_args(self):
  559. extra_args = {
  560. 'SSECustomerKey': 'foo',
  561. 'SSECustomerAlgorithm': 'AES256',
  562. }
  563. below_threshold = 20
  564. osutil = InMemoryOSLayer({'smallfile': b'hello world'})
  565. transfer = S3Transfer(self.client, osutil=osutil)
  566. self.client.head_object.return_value = {
  567. 'ContentLength': below_threshold
  568. }
  569. self.client.get_object.return_value = {'Body': BytesIO(b'foobar')}
  570. transfer.download_file(
  571. 'bucket', 'key', '/tmp/smallfile', extra_args=extra_args
  572. )
  573. # Note that we need to invoke the HeadObject call
  574. # and the PutObject call with the extra_args.
  575. # This is necessary. Trying to HeadObject an SSE object
  576. # will return a 400 if you don't provide the required
  577. # params.
  578. self.client.get_object.assert_called_with(
  579. Bucket='bucket',
  580. Key='key',
  581. SSECustomerAlgorithm='AES256',
  582. SSECustomerKey='foo',
  583. )
  584. def test_get_object_stream_is_retried_and_succeeds(self):
  585. below_threshold = 20
  586. osutil = InMemoryOSLayer({'smallfile': b'hello world'})
  587. transfer = S3Transfer(self.client, osutil=osutil)
  588. self.client.head_object.return_value = {
  589. 'ContentLength': below_threshold
  590. }
  591. self.client.get_object.side_effect = [
  592. # First request fails.
  593. socket.error("fake error"),
  594. # Second succeeds.
  595. {'Body': BytesIO(b'foobar')},
  596. ]
  597. transfer.download_file('bucket', 'key', '/tmp/smallfile')
  598. self.assertEqual(self.client.get_object.call_count, 2)
  599. def test_get_object_stream_uses_all_retries_and_errors_out(self):
  600. below_threshold = 20
  601. osutil = InMemoryOSLayer({})
  602. transfer = S3Transfer(self.client, osutil=osutil)
  603. self.client.head_object.return_value = {
  604. 'ContentLength': below_threshold
  605. }
  606. # Here we're raising an exception every single time, which
  607. # will exhaust our retry count and propagate a
  608. # RetriesExceededError.
  609. self.client.get_object.side_effect = socket.error("fake error")
  610. with self.assertRaises(RetriesExceededError):
  611. transfer.download_file('bucket', 'key', 'smallfile')
  612. self.assertEqual(self.client.get_object.call_count, 5)
  613. # We should have also cleaned up the in progress file
  614. # we were downloading to.
  615. self.assertEqual(osutil.filemap, {})
  616. def test_download_below_multipart_threshold(self):
  617. below_threshold = 20
  618. osutil = InMemoryOSLayer({'smallfile': b'hello world'})
  619. transfer = S3Transfer(self.client, osutil=osutil)
  620. self.client.head_object.return_value = {
  621. 'ContentLength': below_threshold
  622. }
  623. self.client.get_object.return_value = {'Body': BytesIO(b'foobar')}
  624. transfer.download_file('bucket', 'key', 'smallfile')
  625. self.client.get_object.assert_called_with(Bucket='bucket', Key='key')
  626. def test_can_create_with_just_client(self):
  627. transfer = S3Transfer(client=mock.Mock())
  628. self.assertIsInstance(transfer, S3Transfer)
  629. class TestShutdownQueue(unittest.TestCase):
  630. def test_handles_normal_put_get_requests(self):
  631. q = ShutdownQueue()
  632. q.put('foo')
  633. self.assertEqual(q.get(), 'foo')
  634. def test_put_raises_error_on_shutdown(self):
  635. q = ShutdownQueue()
  636. q.trigger_shutdown()
  637. with self.assertRaises(QueueShutdownError):
  638. q.put('foo')
  639. class TestRandomFileExtension(unittest.TestCase):
  640. def test_has_proper_length(self):
  641. self.assertEqual(len(random_file_extension(num_digits=4)), 4)
  642. class TestCallbackHandlers(unittest.TestCase):
  643. def setUp(self):
  644. self.request = mock.Mock()
  645. def test_disable_request_on_put_object(self):
  646. disable_upload_callbacks(self.request, 'PutObject')
  647. self.request.body.disable_callback.assert_called_with()
  648. def test_disable_request_on_upload_part(self):
  649. disable_upload_callbacks(self.request, 'UploadPart')
  650. self.request.body.disable_callback.assert_called_with()
  651. def test_enable_object_on_put_object(self):
  652. enable_upload_callbacks(self.request, 'PutObject')
  653. self.request.body.enable_callback.assert_called_with()
  654. def test_enable_object_on_upload_part(self):
  655. enable_upload_callbacks(self.request, 'UploadPart')
  656. self.request.body.enable_callback.assert_called_with()
  657. def test_dont_disable_if_missing_interface(self):
  658. del self.request.body.disable_callback
  659. disable_upload_callbacks(self.request, 'PutObject')
  660. self.assertEqual(self.request.body.method_calls, [])
  661. def test_dont_enable_if_missing_interface(self):
  662. del self.request.body.enable_callback
  663. enable_upload_callbacks(self.request, 'PutObject')
  664. self.assertEqual(self.request.body.method_calls, [])
  665. def test_dont_disable_if_wrong_operation(self):
  666. disable_upload_callbacks(self.request, 'OtherOperation')
  667. self.assertFalse(self.request.body.disable_callback.called)
  668. def test_dont_enable_if_wrong_operation(self):
  669. enable_upload_callbacks(self.request, 'OtherOperation')
  670. self.assertFalse(self.request.body.enable_callback.called)