test_crt.py
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import glob
import os

from s3transfer.subscribers import BaseSubscriber
from s3transfer.utils import OSUtils
from tests import HAS_CRT, assert_files_equal, requires_crt
from tests.integration import BaseTransferManagerIntegTest

if HAS_CRT:
    from awscrt.exceptions import AwsCrtError

    import s3transfer.crt


class RecordingSubscriber(BaseSubscriber):
    """Records which subscriber callbacks fired and how many bytes moved."""

    def __init__(self):
        self.on_queued_called = False
        self.on_done_called = False
        self.bytes_transferred = 0

    def on_queued(self, **kwargs):
        self.on_queued_called = True

    def on_progress(self, future, bytes_transferred, **kwargs):
        self.bytes_transferred += bytes_transferred

    def on_done(self, **kwargs):
        self.on_done_called = True
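

# The test class below only runs when the CRT is available; requires_crt
# skips it otherwise (the same condition checked by HAS_CRT above).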
@requires_crt
class TestCRTS3Transfers(BaseTransferManagerIntegTest):
    """Tests for the high level s3transfer based on CRT implementation."""

    def _create_s3_transfer(self):
        self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer(
            self.session
        )
        credential_resolver = self.session.get_component('credential_provider')
        self.s3_crt_client = s3transfer.crt.create_s3_crt_client(
            self.session.get_config_variable("region"), credential_resolver
        )
        self.record_subscriber = RecordingSubscriber()
        self.osutil = OSUtils()
        return s3transfer.crt.CRTTransferManager(
            self.s3_crt_client, self.request_serializer
        )
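
    # Helpers used below: one checks that an object's ACL grants public read
    # access, the other checks that the recording subscriber saw the
    # queued/done callbacks and (optionally) the expected byte count.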

    def _assert_has_public_read_acl(self, response):
        grants = response['Grants']
        public_read = [
            g['Grantee'].get('URI', '')
            for g in grants
            if g['Permission'] == 'READ'
        ]
        self.assertIn('groups/global/AllUsers', public_read[0])

    def _assert_subscribers_called(self, expected_bytes_transferred=None):
        self.assertTrue(self.record_subscriber.on_queued_called)
        self.assertTrue(self.record_subscriber.on_done_called)
        if expected_bytes_transferred:
            self.assertEqual(
                self.record_subscriber.bytes_transferred,
                expected_bytes_transferred,
            )
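
    # Upload tests: 1 MB stays below the multipart chunksize and 20 MB goes
    # above it (per the test names), so both the single-part and multipart
    # upload paths get exercised.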

    def test_upload_below_multipart_chunksize(self):
        transfer = self._create_s3_transfer()
        file_size = 1024 * 1024
        filename = self.files.create_file_with_size(
            'foo.txt', filesize=file_size
        )
        self.addCleanup(self.delete_object, 'foo.txt')

        with transfer:
            future = transfer.upload(
                filename,
                self.bucket_name,
                'foo.txt',
                subscribers=[self.record_subscriber],
            )
            future.result()

        self.assertTrue(self.object_exists('foo.txt'))
        self._assert_subscribers_called(file_size)

    def test_upload_above_multipart_chunksize(self):
        transfer = self._create_s3_transfer()
        file_size = 20 * 1024 * 1024
        filename = self.files.create_file_with_size(
            '20mb.txt', filesize=file_size
        )
        self.addCleanup(self.delete_object, '20mb.txt')

        with transfer:
            future = transfer.upload(
                filename,
                self.bucket_name,
                '20mb.txt',
                subscribers=[self.record_subscriber],
            )
            future.result()

        self.assertTrue(self.object_exists('20mb.txt'))
        self._assert_subscribers_called(file_size)

    def test_upload_file_above_threshold_with_acl(self):
        transfer = self._create_s3_transfer()
        file_size = 6 * 1024 * 1024
        filename = self.files.create_file_with_size(
            '6mb.txt', filesize=file_size
        )
        extra_args = {'ACL': 'public-read'}
        self.addCleanup(self.delete_object, '6mb.txt')

        with transfer:
            future = transfer.upload(
                filename,
                self.bucket_name,
                '6mb.txt',
                extra_args=extra_args,
                subscribers=[self.record_subscriber],
            )
            future.result()

        self.assertTrue(self.object_exists('6mb.txt'))
        response = self.client.get_object_acl(
            Bucket=self.bucket_name, Key='6mb.txt'
        )
        self._assert_has_public_read_acl(response)
        self._assert_subscribers_called(file_size)

    def test_upload_file_above_threshold_with_ssec(self):
        key_bytes = os.urandom(32)
        extra_args = {
            'SSECustomerKey': key_bytes,
            'SSECustomerAlgorithm': 'AES256',
        }
        file_size = 6 * 1024 * 1024
        transfer = self._create_s3_transfer()
        filename = self.files.create_file_with_size(
            '6mb.txt', filesize=file_size
        )
        self.addCleanup(self.delete_object, '6mb.txt')

        with transfer:
            future = transfer.upload(
                filename,
                self.bucket_name,
                '6mb.txt',
                extra_args=extra_args,
                subscribers=[self.record_subscriber],
            )
            future.result()

        # A HeadObject request will fail if the object has a customer key
        # associated with it and the key is not provided in the request,
        # so we can use it to verify our functionality.
        original_extra_args = {
            'SSECustomerKey': key_bytes,
            'SSECustomerAlgorithm': 'AES256',
        }
        self.wait_object_exists('6mb.txt', original_extra_args)
        response = self.client.head_object(
            Bucket=self.bucket_name, Key='6mb.txt', **original_extra_args
        )
        self.assertEqual(response['SSECustomerAlgorithm'], 'AES256')
        self._assert_subscribers_called(file_size)
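
    # Download tests: each one verifies the downloaded content and that the
    # recording subscriber saw the number of bytes actually written to disk.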

    def test_can_send_extra_params_on_download(self):
        # We're using the customer-provided SSE feature of S3 to test the
        # extra_args functionality of the download path.
        key_bytes = os.urandom(32)
        extra_args = {
            'SSECustomerKey': key_bytes,
            'SSECustomerAlgorithm': 'AES256',
        }
        filename = self.files.create_file('foo.txt', 'hello world')
        self.upload_file(filename, 'foo.txt', extra_args)
        transfer = self._create_s3_transfer()

        download_path = os.path.join(self.files.rootdir, 'downloaded.txt')
        with transfer:
            future = transfer.download(
                self.bucket_name,
                'foo.txt',
                download_path,
                extra_args=extra_args,
                subscribers=[self.record_subscriber],
            )
            future.result()

        file_size = self.osutil.get_file_size(download_path)
        self._assert_subscribers_called(file_size)
        with open(download_path, 'rb') as f:
            self.assertEqual(f.read(), b'hello world')

    def test_download_below_threshold(self):
        transfer = self._create_s3_transfer()
        filename = self.files.create_file_with_size(
            'foo.txt', filesize=1024 * 1024
        )
        self.upload_file(filename, 'foo.txt')
        download_path = os.path.join(self.files.rootdir, 'downloaded.txt')

        with transfer:
            future = transfer.download(
                self.bucket_name,
                'foo.txt',
                download_path,
                subscribers=[self.record_subscriber],
            )
            future.result()

        file_size = self.osutil.get_file_size(download_path)
        self._assert_subscribers_called(file_size)
        assert_files_equal(filename, download_path)

    def test_download_above_threshold(self):
        transfer = self._create_s3_transfer()
        filename = self.files.create_file_with_size(
            'foo.txt', filesize=20 * 1024 * 1024
        )
        self.upload_file(filename, 'foo.txt')
        download_path = os.path.join(self.files.rootdir, 'downloaded.txt')

        with transfer:
            future = transfer.download(
                self.bucket_name,
                'foo.txt',
                download_path,
                subscribers=[self.record_subscriber],
            )
            future.result()

        assert_files_equal(filename, download_path)
        file_size = self.osutil.get_file_size(download_path)
        self._assert_subscribers_called(file_size)

    def test_delete(self):
        transfer = self._create_s3_transfer()
        filename = self.files.create_file_with_size(
            'foo.txt', filesize=1024 * 1024
        )
        self.upload_file(filename, 'foo.txt')

        with transfer:
            future = transfer.delete(self.bucket_name, 'foo.txt')
            future.result()

        self.assertTrue(self.object_not_exists('foo.txt'))
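
    # The "many files" tests queue ten transfers inside a single
    # CRTTransferManager context without calling result() on each future;
    # they rely on exiting the context to wait for all transfers to finish.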

    def test_many_files_download(self):
        transfer = self._create_s3_transfer()
        filename = self.files.create_file_with_size(
            '1mb.txt', filesize=1024 * 1024
        )
        self.upload_file(filename, '1mb.txt')

        filenames = []
        base_filename = os.path.join(self.files.rootdir, 'file')
        for i in range(10):
            filenames.append(base_filename + str(i))

        with transfer:
            for download_path in filenames:
                transfer.download(self.bucket_name, '1mb.txt', download_path)

        for download_path in filenames:
            assert_files_equal(filename, download_path)

    def test_many_files_upload(self):
        transfer = self._create_s3_transfer()
        keys = []
        filenames = []
        base_key = 'foo'
        suffix = '.txt'
        for i in range(10):
            key = base_key + str(i) + suffix
            keys.append(key)
            filename = self.files.create_file_with_size(
                key, filesize=1024 * 1024
            )
            filenames.append(filename)
            self.addCleanup(self.delete_object, key)

        with transfer:
            for filename, key in zip(filenames, keys):
                transfer.upload(filename, self.bucket_name, key)

        for key in keys:
            self.assertTrue(self.object_exists(key))

    def test_many_files_delete(self):
        transfer = self._create_s3_transfer()
        keys = []
        base_key = 'foo'
        suffix = '.txt'
        filename = self.files.create_file_with_size(
            '1mb.txt', filesize=1024 * 1024
        )
        for i in range(10):
            key = base_key + str(i) + suffix
            keys.append(key)
            self.upload_file(filename, key)

        with transfer:
            for key in keys:
                transfer.delete(self.bucket_name, key)

        for key in keys:
            self.assertTrue(self.object_not_exists(key))
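
    # Cancellation tests: raising KeyboardInterrupt while still inside the
    # manager's context cancels the in-flight transfer, so the future's
    # result() surfaces an AwsCrtError named AWS_ERROR_S3_CANCELED.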

    def test_upload_cancel(self):
        transfer = self._create_s3_transfer()
        filename = self.files.create_file_with_size(
            '20mb.txt', filesize=20 * 1024 * 1024
        )
        future = None
        try:
            with transfer:
                future = transfer.upload(
                    filename, self.bucket_name, '20mb.txt'
                )
                raise KeyboardInterrupt()
        except KeyboardInterrupt:
            pass

        with self.assertRaises(AwsCrtError) as cm:
            future.result()
        self.assertEqual(cm.exception.name, 'AWS_ERROR_S3_CANCELED')
        self.assertTrue(self.object_not_exists('20mb.txt'))

    def test_download_cancel(self):
        transfer = self._create_s3_transfer()
        filename = self.files.create_file_with_size(
            'foo.txt', filesize=20 * 1024 * 1024
        )
        self.upload_file(filename, 'foo.txt')
        download_path = os.path.join(self.files.rootdir, 'downloaded.txt')

        future = None
        try:
            with transfer:
                future = transfer.download(
                    self.bucket_name,
                    'foo.txt',
                    download_path,
                    subscribers=[self.record_subscriber],
                )
                raise KeyboardInterrupt()
        except KeyboardInterrupt:
            pass

        with self.assertRaises(AwsCrtError) as err:
            future.result()
        self.assertEqual(err.exception.name, 'AWS_ERROR_S3_CANCELED')

        possible_matches = glob.glob('%s*' % download_path)
        self.assertEqual(possible_matches, [])
        self._assert_subscribers_called()