# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import glob
import os

from s3transfer.subscribers import BaseSubscriber
from s3transfer.utils import OSUtils
from tests import HAS_CRT, assert_files_equal, requires_crt
from tests.integration import BaseTransferManagerIntegTest

if HAS_CRT:
    from awscrt.exceptions import AwsCrtError

    import s3transfer.crt


class RecordingSubscriber(BaseSubscriber):
    """Records transfer lifecycle callbacks so tests can assert on them."""

    def __init__(self):
        self.on_queued_called = False
        self.on_done_called = False
        self.bytes_transferred = 0

    def on_queued(self, **kwargs):
        self.on_queued_called = True

    def on_progress(self, future, bytes_transferred, **kwargs):
        self.bytes_transferred += bytes_transferred

    def on_done(self, **kwargs):
        self.on_done_called = True


@requires_crt
class TestCRTS3Transfers(BaseTransferManagerIntegTest):
    """Tests for the high-level s3transfer based on the CRT implementation."""

    def _create_s3_transfer(self):
        self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer(
            self.session
        )
        credential_resolver = self.session.get_component(
            'credential_provider'
        )
        self.s3_crt_client = s3transfer.crt.create_s3_crt_client(
            self.session.get_config_variable("region"), credential_resolver
        )
        self.record_subscriber = RecordingSubscriber()
        self.osutil = OSUtils()
        return s3transfer.crt.CRTTransferManager(
            self.s3_crt_client, self.request_serializer
        )

    def _assert_has_public_read_acl(self, response):
        grants = response['Grants']
        public_read = [
            g['Grantee'].get('URI', '')
            for g in grants
            if g['Permission'] == 'READ'
        ]
        self.assertIn('groups/global/AllUsers', public_read[0])

    def _assert_subscribers_called(self, expected_bytes_transferred=None):
        self.assertTrue(self.record_subscriber.on_queued_called)
        self.assertTrue(self.record_subscriber.on_done_called)
        if expected_bytes_transferred:
            self.assertEqual(
                self.record_subscriber.bytes_transferred,
                expected_bytes_transferred,
            )

    def test_upload_below_multipart_chunksize(self):
        transfer = self._create_s3_transfer()
        file_size = 1024 * 1024
        filename = self.files.create_file_with_size(
            'foo.txt', filesize=file_size
        )
        self.addCleanup(self.delete_object, 'foo.txt')

        with transfer:
            future = transfer.upload(
                filename,
                self.bucket_name,
                'foo.txt',
                subscribers=[self.record_subscriber],
            )
            future.result()

        self.assertTrue(self.object_exists('foo.txt'))
        self._assert_subscribers_called(file_size)

    def test_upload_above_multipart_chunksize(self):
        transfer = self._create_s3_transfer()
        file_size = 20 * 1024 * 1024
        filename = self.files.create_file_with_size(
            '20mb.txt', filesize=file_size
        )
        self.addCleanup(self.delete_object, '20mb.txt')

        with transfer:
            future = transfer.upload(
                filename,
                self.bucket_name,
                '20mb.txt',
                subscribers=[self.record_subscriber],
            )
            future.result()

        self.assertTrue(self.object_exists('20mb.txt'))
        self._assert_subscribers_called(file_size)

    def test_upload_file_above_threshold_with_acl(self):
        transfer = self._create_s3_transfer()
        file_size = 6 * 1024 * 1024
        filename = self.files.create_file_with_size(
            '6mb.txt', filesize=file_size
        )
        extra_args = {'ACL': 'public-read'}
        self.addCleanup(self.delete_object, '6mb.txt')

        with transfer:
            future = transfer.upload(
                filename,
                self.bucket_name,
                '6mb.txt',
                extra_args=extra_args,
                subscribers=[self.record_subscriber],
            )
            future.result()

        self.assertTrue(self.object_exists('6mb.txt'))
        response = self.client.get_object_acl(
            Bucket=self.bucket_name, Key='6mb.txt'
        )
        self._assert_has_public_read_acl(response)
        self._assert_subscribers_called(file_size)

    def test_upload_file_above_threshold_with_ssec(self):
        key_bytes = os.urandom(32)
        extra_args = {
            'SSECustomerKey': key_bytes,
            'SSECustomerAlgorithm': 'AES256',
        }
        file_size = 6 * 1024 * 1024
        transfer = self._create_s3_transfer()
        filename = self.files.create_file_with_size(
            '6mb.txt', filesize=file_size
        )
        self.addCleanup(self.delete_object, '6mb.txt')

        with transfer:
            future = transfer.upload(
                filename,
                self.bucket_name,
                '6mb.txt',
                extra_args=extra_args,
                subscribers=[self.record_subscriber],
            )
            future.result()

        # A HeadObject request fails if the object was stored with a
        # customer-provided key and that key is not supplied in the
        # request, so we use it to verify the upload used SSE-C.
        original_extra_args = {
            'SSECustomerKey': key_bytes,
            'SSECustomerAlgorithm': 'AES256',
        }
        self.wait_object_exists('6mb.txt', original_extra_args)
        response = self.client.head_object(
            Bucket=self.bucket_name, Key='6mb.txt', **original_extra_args
        )
        self.assertEqual(response['SSECustomerAlgorithm'], 'AES256')
        self._assert_subscribers_called(file_size)

    def test_can_send_extra_params_on_download(self):
        # We use S3's customer-provided SSE feature to exercise the
        # extra_args functionality on download.
        key_bytes = os.urandom(32)
        extra_args = {
            'SSECustomerKey': key_bytes,
            'SSECustomerAlgorithm': 'AES256',
        }
        filename = self.files.create_file('foo.txt', 'hello world')
        self.upload_file(filename, 'foo.txt', extra_args)

        transfer = self._create_s3_transfer()
        download_path = os.path.join(self.files.rootdir, 'downloaded.txt')
        with transfer:
            future = transfer.download(
                self.bucket_name,
                'foo.txt',
                download_path,
                extra_args=extra_args,
                subscribers=[self.record_subscriber],
            )
            future.result()

        file_size = self.osutil.get_file_size(download_path)
        self._assert_subscribers_called(file_size)
        with open(download_path, 'rb') as f:
            self.assertEqual(f.read(), b'hello world')

    def test_download_below_threshold(self):
        transfer = self._create_s3_transfer()
        filename = self.files.create_file_with_size(
            'foo.txt', filesize=1024 * 1024
        )
        self.upload_file(filename, 'foo.txt')

        download_path = os.path.join(self.files.rootdir, 'downloaded.txt')
        with transfer:
            future = transfer.download(
                self.bucket_name,
                'foo.txt',
                download_path,
                subscribers=[self.record_subscriber],
            )
            future.result()

        file_size = self.osutil.get_file_size(download_path)
        self._assert_subscribers_called(file_size)
        assert_files_equal(filename, download_path)

    def test_download_above_threshold(self):
        transfer = self._create_s3_transfer()
        filename = self.files.create_file_with_size(
            'foo.txt', filesize=20 * 1024 * 1024
        )
        self.upload_file(filename, 'foo.txt')

        download_path = os.path.join(self.files.rootdir, 'downloaded.txt')
        with transfer:
            future = transfer.download(
                self.bucket_name,
                'foo.txt',
                download_path,
                subscribers=[self.record_subscriber],
            )
            future.result()

        assert_files_equal(filename, download_path)
        file_size = self.osutil.get_file_size(download_path)
        self._assert_subscribers_called(file_size)

    def test_delete(self):
        transfer = self._create_s3_transfer()
        filename = self.files.create_file_with_size(
            'foo.txt', filesize=1024 * 1024
        )
        self.upload_file(filename, 'foo.txt')

        with transfer:
            future = transfer.delete(self.bucket_name, 'foo.txt')
            future.result()
        self.assertTrue(self.object_not_exists('foo.txt'))

    def test_many_files_download(self):
        transfer = self._create_s3_transfer()
        filename = self.files.create_file_with_size(
            '1mb.txt', filesize=1024 * 1024
        )
        self.upload_file(filename, '1mb.txt')

        filenames = []
        base_filename = os.path.join(self.files.rootdir, 'file')
        for i in range(10):
            filenames.append(base_filename + str(i))

        with transfer:
            # Use a distinct loop variable so `filename` still refers to
            # the original source file when we compare below.
            for download_path in filenames:
                transfer.download(self.bucket_name, '1mb.txt', download_path)
        for download_path in filenames:
            assert_files_equal(filename, download_path)

    def test_many_files_upload(self):
        transfer = self._create_s3_transfer()
        keys = []
        filenames = []
        base_key = 'foo'
        suffix = '.txt'
        for i in range(10):
            key = base_key + str(i) + suffix
            keys.append(key)
            filename = self.files.create_file_with_size(
                key, filesize=1024 * 1024
            )
            filenames.append(filename)
            self.addCleanup(self.delete_object, key)

        with transfer:
            for filename, key in zip(filenames, keys):
                transfer.upload(filename, self.bucket_name, key)

        for key in keys:
            self.assertTrue(self.object_exists(key))

    def test_many_files_delete(self):
        transfer = self._create_s3_transfer()
        keys = []
        base_key = 'foo'
        suffix = '.txt'
        filename = self.files.create_file_with_size(
            '1mb.txt', filesize=1024 * 1024
        )
        for i in range(10):
            key = base_key + str(i) + suffix
            keys.append(key)
            self.upload_file(filename, key)

        with transfer:
            for key in keys:
                transfer.delete(self.bucket_name, key)
        for key in keys:
            self.assertTrue(self.object_not_exists(key))

    def test_upload_cancel(self):
        transfer = self._create_s3_transfer()
        filename = self.files.create_file_with_size(
            '20mb.txt', filesize=20 * 1024 * 1024
        )
        future = None
        try:
            with transfer:
                future = transfer.upload(
                    filename, self.bucket_name, '20mb.txt'
                )
                raise KeyboardInterrupt()
        except KeyboardInterrupt:
            pass

        with self.assertRaises(AwsCrtError) as cm:
            future.result()
        self.assertEqual(cm.exception.name, 'AWS_ERROR_S3_CANCELED')
        self.assertTrue(self.object_not_exists('20mb.txt'))

    def test_download_cancel(self):
        transfer = self._create_s3_transfer()
        filename = self.files.create_file_with_size(
            'foo.txt', filesize=20 * 1024 * 1024
        )
        self.upload_file(filename, 'foo.txt')
        download_path = os.path.join(self.files.rootdir, 'downloaded.txt')

        future = None
        try:
            with transfer:
                future = transfer.download(
                    self.bucket_name,
                    'foo.txt',
                    download_path,
                    subscribers=[self.record_subscriber],
                )
                raise KeyboardInterrupt()
        except KeyboardInterrupt:
            pass

        with self.assertRaises(AwsCrtError) as err:
            future.result()
        self.assertEqual(err.exception.name, 'AWS_ERROR_S3_CANCELED')

        # No partial download should be left on disk after cancellation.
        possible_matches = glob.glob('%s*' % download_path)
        self.assertEqual(possible_matches, [])
        self._assert_subscribers_called()
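
# ----------------------------------------------------------------------
# For reference, a minimal sketch (not part of the test suite) of the
# CRT transfer flow these tests exercise, mirroring the calls made in
# _create_s3_transfer above. The bucket, key, and local path arguments
# are placeholders; this assumes awscrt is installed and that botocore
# can resolve a region and credentials from the environment.


def _example_crt_upload(bucket, key, local_path):
    import botocore.session

    session = botocore.session.get_session()
    serializer = s3transfer.crt.BotocoreCRTRequestSerializer(session)
    crt_client = s3transfer.crt.create_s3_crt_client(
        session.get_config_variable('region'),
        session.get_component('credential_provider'),
    )
    # Exiting the manager's context waits for all in-flight transfers.
    with s3transfer.crt.CRTTransferManager(crt_client, serializer) as manager:
        future = manager.upload(
            local_path, bucket, key, subscribers=[RecordingSubscriber()]
        )
        future.result()  # Block until this transfer completes.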