# test_s3transfer.py
# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import hashlib
import math
import os
import random
import shutil
import string
import tempfile
import threading

from botocore.client import Config

import s3transfer
from tests.integration import BaseTransferManagerIntegTest
  23. def assert_files_equal(first, second):
  24. if os.path.getsize(first) != os.path.getsize(second):
  25. raise AssertionError(f"Files are not equal: {first}, {second}")
  26. first_md5 = md5_checksum(first)
  27. second_md5 = md5_checksum(second)
  28. if first_md5 != second_md5:
  29. raise AssertionError(
  30. "Files are not equal: {}(md5={}) != {}(md5={})".format(
  31. first, first_md5, second, second_md5
  32. )
  33. )
  34. def md5_checksum(filename):
  35. checksum = hashlib.md5()
  36. with open(filename, 'rb') as f:
  37. for chunk in iter(lambda: f.read(8192), b''):
  38. checksum.update(chunk)
  39. return checksum.hexdigest()
  40. def random_bucket_name(prefix='boto3-transfer', num_chars=10):
  41. base = string.ascii_lowercase + string.digits
  42. random_bytes = bytearray(os.urandom(num_chars))
  43. return prefix + ''.join([base[b % len(base)] for b in random_bytes])
  44. class FileCreator:
  45. def __init__(self):
  46. self.rootdir = tempfile.mkdtemp()
  47. def remove_all(self):
  48. shutil.rmtree(self.rootdir)
  49. def create_file(self, filename, contents, mode='w'):
  50. """Creates a file in a tmpdir
  51. ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
  52. It will be translated into a full path in a tmp dir.
  53. ``mode`` is the mode the file should be opened either as ``w`` or
  54. `wb``.
  55. Returns the full path to the file.
  56. """
  57. full_path = os.path.join(self.rootdir, filename)
  58. if not os.path.isdir(os.path.dirname(full_path)):
  59. os.makedirs(os.path.dirname(full_path))
  60. with open(full_path, mode) as f:
  61. f.write(contents)
  62. return full_path
  63. def create_file_with_size(self, filename, filesize):
  64. filename = self.create_file(filename, contents='')
  65. chunksize = 8192
  66. with open(filename, 'wb') as f:
  67. for i in range(int(math.ceil(filesize / float(chunksize)))):
  68. f.write(b'a' * chunksize)
  69. return filename
  70. def append_file(self, filename, contents):
  71. """Append contents to a file
  72. ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
  73. It will be translated into a full path in a tmp dir.
  74. Returns the full path to the file.
  75. """
  76. full_path = os.path.join(self.rootdir, filename)
  77. if not os.path.isdir(os.path.dirname(full_path)):
  78. os.makedirs(os.path.dirname(full_path))
  79. with open(full_path, 'a') as f:
  80. f.write(contents)
  81. return full_path
  82. def full_path(self, filename):
  83. """Translate relative path to full path in temp dir.
  84. f.full_path('foo/bar.txt') -> /tmp/asdfasd/foo/bar.txt
  85. """
  86. return os.path.join(self.rootdir, filename)
  87. class TestS3Transfers(BaseTransferManagerIntegTest):
  88. """Tests for the high level s3transfer module."""
  89. def create_s3_transfer(self, config=None):
  90. return s3transfer.S3Transfer(self.client, config=config)
  91. def assert_has_public_read_acl(self, response):
  92. grants = response['Grants']
  93. public_read = [
  94. g['Grantee'].get('URI', '')
  95. for g in grants
  96. if g['Permission'] == 'READ'
  97. ]
  98. self.assertIn('groups/global/AllUsers', public_read[0])
  99. def test_upload_below_threshold(self):
  100. config = s3transfer.TransferConfig(multipart_threshold=2 * 1024 * 1024)
  101. transfer = self.create_s3_transfer(config)
  102. filename = self.files.create_file_with_size(
  103. 'foo.txt', filesize=1024 * 1024
  104. )
  105. transfer.upload_file(filename, self.bucket_name, 'foo.txt')
  106. self.addCleanup(self.delete_object, 'foo.txt')
  107. self.assertTrue(self.object_exists('foo.txt'))
  108. def test_upload_above_threshold(self):
  109. config = s3transfer.TransferConfig(multipart_threshold=2 * 1024 * 1024)
  110. transfer = self.create_s3_transfer(config)
  111. filename = self.files.create_file_with_size(
  112. '20mb.txt', filesize=20 * 1024 * 1024
  113. )
  114. transfer.upload_file(filename, self.bucket_name, '20mb.txt')
  115. self.addCleanup(self.delete_object, '20mb.txt')
  116. self.assertTrue(self.object_exists('20mb.txt'))
  117. def test_upload_file_above_threshold_with_acl(self):
  118. config = s3transfer.TransferConfig(multipart_threshold=5 * 1024 * 1024)
  119. transfer = self.create_s3_transfer(config)
  120. filename = self.files.create_file_with_size(
  121. '6mb.txt', filesize=6 * 1024 * 1024
  122. )
  123. extra_args = {'ACL': 'public-read'}
  124. transfer.upload_file(
  125. filename, self.bucket_name, '6mb.txt', extra_args=extra_args
  126. )
  127. self.addCleanup(self.delete_object, '6mb.txt')
  128. self.assertTrue(self.object_exists('6mb.txt'))
  129. response = self.client.get_object_acl(
  130. Bucket=self.bucket_name, Key='6mb.txt'
  131. )
  132. self.assert_has_public_read_acl(response)
  133. def test_upload_file_above_threshold_with_ssec(self):
  134. key_bytes = os.urandom(32)
  135. extra_args = {
  136. 'SSECustomerKey': key_bytes,
  137. 'SSECustomerAlgorithm': 'AES256',
  138. }
  139. config = s3transfer.TransferConfig(multipart_threshold=5 * 1024 * 1024)
  140. transfer = self.create_s3_transfer(config)
  141. filename = self.files.create_file_with_size(
  142. '6mb.txt', filesize=6 * 1024 * 1024
  143. )
  144. transfer.upload_file(
  145. filename, self.bucket_name, '6mb.txt', extra_args=extra_args
  146. )
  147. self.addCleanup(self.delete_object, '6mb.txt')
  148. self.wait_object_exists('6mb.txt', extra_args)
  149. # A head object will fail if it has a customer key
  150. # associated with it and it's not provided in the HeadObject
  151. # request so we can use this to verify our functionality.
  152. response = self.client.head_object(
  153. Bucket=self.bucket_name, Key='6mb.txt', **extra_args
  154. )
  155. self.assertEqual(response['SSECustomerAlgorithm'], 'AES256')
  156. def test_progress_callback_on_upload(self):
  157. self.amount_seen = 0
  158. lock = threading.Lock()
  159. def progress_callback(amount):
  160. with lock:
  161. self.amount_seen += amount
  162. transfer = self.create_s3_transfer()
  163. filename = self.files.create_file_with_size(
  164. '20mb.txt', filesize=20 * 1024 * 1024
  165. )
  166. transfer.upload_file(
  167. filename, self.bucket_name, '20mb.txt', callback=progress_callback
  168. )
  169. self.addCleanup(self.delete_object, '20mb.txt')
  170. # The callback should have been called enough times such that
  171. # the total amount of bytes we've seen (via the "amount"
  172. # arg to the callback function) should be the size
  173. # of the file we uploaded.
  174. self.assertEqual(self.amount_seen, 20 * 1024 * 1024)
  175. def test_callback_called_once_with_sigv4(self):
  176. # Verify #98, where the callback was being invoked
  177. # twice when using signature version 4.
  178. self.amount_seen = 0
  179. lock = threading.Lock()
  180. def progress_callback(amount):
  181. with lock:
  182. self.amount_seen += amount
  183. client = self.session.create_client(
  184. 's3', self.region, config=Config(signature_version='s3v4')
  185. )
  186. transfer = s3transfer.S3Transfer(client)
  187. filename = self.files.create_file_with_size(
  188. '10mb.txt', filesize=10 * 1024 * 1024
  189. )
  190. transfer.upload_file(
  191. filename, self.bucket_name, '10mb.txt', callback=progress_callback
  192. )
  193. self.addCleanup(self.delete_object, '10mb.txt')
  194. self.assertEqual(self.amount_seen, 10 * 1024 * 1024)
  195. def test_can_send_extra_params_on_upload(self):
  196. transfer = self.create_s3_transfer()
  197. filename = self.files.create_file_with_size('foo.txt', filesize=1024)
  198. transfer.upload_file(
  199. filename,
  200. self.bucket_name,
  201. 'foo.txt',
  202. extra_args={'ACL': 'public-read'},
  203. )
  204. self.addCleanup(self.delete_object, 'foo.txt')
  205. self.wait_object_exists('foo.txt')
  206. response = self.client.get_object_acl(
  207. Bucket=self.bucket_name, Key='foo.txt'
  208. )
  209. self.assert_has_public_read_acl(response)
  210. def test_can_configure_threshold(self):
  211. config = s3transfer.TransferConfig(multipart_threshold=6 * 1024 * 1024)
  212. transfer = self.create_s3_transfer(config)
  213. filename = self.files.create_file_with_size(
  214. 'foo.txt', filesize=8 * 1024 * 1024
  215. )
  216. transfer.upload_file(filename, self.bucket_name, 'foo.txt')
  217. self.addCleanup(self.delete_object, 'foo.txt')
  218. self.assertTrue(self.object_exists('foo.txt'))
  219. def test_can_send_extra_params_on_download(self):
  220. # We're picking the customer provided sse feature
  221. # of S3 to test the extra_args functionality of
  222. # S3.
  223. key_bytes = os.urandom(32)
  224. extra_args = {
  225. 'SSECustomerKey': key_bytes,
  226. 'SSECustomerAlgorithm': 'AES256',
  227. }
  228. filename = self.files.create_file('foo.txt', 'hello world')
  229. self.upload_file(filename, 'foo.txt', extra_args)
  230. transfer = self.create_s3_transfer()
  231. download_path = os.path.join(self.files.rootdir, 'downloaded.txt')
  232. transfer.download_file(
  233. self.bucket_name, 'foo.txt', download_path, extra_args=extra_args
  234. )
  235. with open(download_path, 'rb') as f:
  236. self.assertEqual(f.read(), b'hello world')
  237. def test_progress_callback_on_download(self):
  238. self.amount_seen = 0
  239. lock = threading.Lock()
  240. def progress_callback(amount):
  241. with lock:
  242. self.amount_seen += amount
  243. transfer = self.create_s3_transfer()
  244. filename = self.files.create_file_with_size(
  245. '20mb.txt', filesize=20 * 1024 * 1024
  246. )
  247. self.upload_file(filename, '20mb.txt')
  248. download_path = os.path.join(self.files.rootdir, 'downloaded.txt')
  249. transfer.download_file(
  250. self.bucket_name,
  251. '20mb.txt',
  252. download_path,
  253. callback=progress_callback,
  254. )
  255. self.assertEqual(self.amount_seen, 20 * 1024 * 1024)
  256. def test_download_below_threshold(self):
  257. transfer = self.create_s3_transfer()
  258. filename = self.files.create_file_with_size(
  259. 'foo.txt', filesize=1024 * 1024
  260. )
  261. self.upload_file(filename, 'foo.txt')
  262. download_path = os.path.join(self.files.rootdir, 'downloaded.txt')
  263. transfer.download_file(self.bucket_name, 'foo.txt', download_path)
  264. assert_files_equal(filename, download_path)
  265. def test_download_above_threshold(self):
  266. transfer = self.create_s3_transfer()
  267. filename = self.files.create_file_with_size(
  268. 'foo.txt', filesize=20 * 1024 * 1024
  269. )
  270. self.upload_file(filename, 'foo.txt')
  271. download_path = os.path.join(self.files.rootdir, 'downloaded.txt')
  272. transfer.download_file(self.bucket_name, 'foo.txt', download_path)
  273. assert_files_equal(filename, download_path)