123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644 |
- # Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License"). You
- # may not use this file except in compliance with the License. A copy of
- # the License is located at
- #
- # http://aws.amazon.com/apache2.0/
- #
- # or in the "license" file accompanying this file. This file is
- # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
- # ANY KIND, either express or implied. See the License for the specific
- # language governing permissions and limitations under the License.
- import logging
- import threading
- from io import BytesIO
- import awscrt.http
- import botocore.awsrequest
- import botocore.session
- from awscrt.auth import AwsCredentials, AwsCredentialsProvider
- from awscrt.io import (
- ClientBootstrap,
- ClientTlsContext,
- DefaultHostResolver,
- EventLoopGroup,
- TlsContextOptions,
- )
- from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType
- from botocore import UNSIGNED
- from botocore.compat import urlsplit
- from botocore.config import Config
- from botocore.exceptions import NoCredentialsError
- from s3transfer.constants import GB, MB
- from s3transfer.exceptions import TransferNotDoneError
- from s3transfer.futures import BaseTransferFuture, BaseTransferMeta
- from s3transfer.utils import CallArgs, OSUtils, get_callbacks
- logger = logging.getLogger(__name__)
- class CRTCredentialProviderAdapter:
- def __init__(self, botocore_credential_provider):
- self._botocore_credential_provider = botocore_credential_provider
- self._loaded_credentials = None
- self._lock = threading.Lock()
- def __call__(self):
- credentials = self._get_credentials().get_frozen_credentials()
- return AwsCredentials(
- credentials.access_key, credentials.secret_key, credentials.token
- )
- def _get_credentials(self):
- with self._lock:
- if self._loaded_credentials is None:
- loaded_creds = (
- self._botocore_credential_provider.load_credentials()
- )
- if loaded_creds is None:
- raise NoCredentialsError()
- self._loaded_credentials = loaded_creds
- return self._loaded_credentials
- def create_s3_crt_client(
- region,
- botocore_credential_provider=None,
- num_threads=None,
- target_throughput=5 * GB / 8,
- part_size=8 * MB,
- use_ssl=True,
- verify=None,
- ):
- """
- :type region: str
- :param region: The region used for signing
- :type botocore_credential_provider:
- Optional[botocore.credentials.CredentialResolver]
- :param botocore_credential_provider: Provide credentials for CRT
- to sign the request if not set, the request will not be signed
- :type num_threads: Optional[int]
- :param num_threads: Number of worker threads generated. Default
- is the number of processors in the machine.
- :type target_throughput: Optional[int]
- :param target_throughput: Throughput target in Bytes.
- Default is 0.625 GB/s (which translates to 5 Gb/s).
- :type part_size: Optional[int]
- :param part_size: Size, in Bytes, of parts that files will be downloaded
- or uploaded in.
- :type use_ssl: boolean
- :param use_ssl: Whether or not to use SSL. By default, SSL is used.
- Note that not all services support non-ssl connections.
- :type verify: Optional[boolean/string]
- :param verify: Whether or not to verify SSL certificates.
- By default SSL certificates are verified. You can provide the
- following values:
- * False - do not validate SSL certificates. SSL will still be
- used (unless use_ssl is False), but SSL certificates
- will not be verified.
- * path/to/cert/bundle.pem - A filename of the CA cert bundle to
- use. Specify this argument if you want to use a custom CA cert
- bundle instead of the default one on your system.
- """
- event_loop_group = EventLoopGroup(num_threads)
- host_resolver = DefaultHostResolver(event_loop_group)
- bootstrap = ClientBootstrap(event_loop_group, host_resolver)
- provider = None
- tls_connection_options = None
- tls_mode = (
- S3RequestTlsMode.ENABLED if use_ssl else S3RequestTlsMode.DISABLED
- )
- if verify is not None:
- tls_ctx_options = TlsContextOptions()
- if verify:
- tls_ctx_options.override_default_trust_store_from_path(
- ca_filepath=verify
- )
- else:
- tls_ctx_options.verify_peer = False
- client_tls_option = ClientTlsContext(tls_ctx_options)
- tls_connection_options = client_tls_option.new_connection_options()
- if botocore_credential_provider:
- credentails_provider_adapter = CRTCredentialProviderAdapter(
- botocore_credential_provider
- )
- provider = AwsCredentialsProvider.new_delegate(
- credentails_provider_adapter
- )
- target_gbps = target_throughput * 8 / GB
- return S3Client(
- bootstrap=bootstrap,
- region=region,
- credential_provider=provider,
- part_size=part_size,
- tls_mode=tls_mode,
- tls_connection_options=tls_connection_options,
- throughput_target_gbps=target_gbps,
- )
- class CRTTransferManager:
- def __init__(self, crt_s3_client, crt_request_serializer, osutil=None):
- """A transfer manager interface for Amazon S3 on CRT s3 client.
- :type crt_s3_client: awscrt.s3.S3Client
- :param crt_s3_client: The CRT s3 client, handling all the
- HTTP requests and functions under then hood
- :type crt_request_serializer: s3transfer.crt.BaseCRTRequestSerializer
- :param crt_request_serializer: Serializer, generates unsigned crt HTTP
- request.
- :type osutil: s3transfer.utils.OSUtils
- :param osutil: OSUtils object to use for os-related behavior when
- using with transfer manager.
- """
- if osutil is None:
- self._osutil = OSUtils()
- self._crt_s3_client = crt_s3_client
- self._s3_args_creator = S3ClientArgsCreator(
- crt_request_serializer, self._osutil
- )
- self._future_coordinators = []
- self._semaphore = threading.Semaphore(128) # not configurable
- # A counter to create unique id's for each transfer submitted.
- self._id_counter = 0
- def __enter__(self):
- return self
- def __exit__(self, exc_type, exc_value, *args):
- cancel = False
- if exc_type:
- cancel = True
- self._shutdown(cancel)
- def download(
- self, bucket, key, fileobj, extra_args=None, subscribers=None
- ):
- if extra_args is None:
- extra_args = {}
- if subscribers is None:
- subscribers = {}
- callargs = CallArgs(
- bucket=bucket,
- key=key,
- fileobj=fileobj,
- extra_args=extra_args,
- subscribers=subscribers,
- )
- return self._submit_transfer("get_object", callargs)
- def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None):
- if extra_args is None:
- extra_args = {}
- if subscribers is None:
- subscribers = {}
- callargs = CallArgs(
- bucket=bucket,
- key=key,
- fileobj=fileobj,
- extra_args=extra_args,
- subscribers=subscribers,
- )
- return self._submit_transfer("put_object", callargs)
- def delete(self, bucket, key, extra_args=None, subscribers=None):
- if extra_args is None:
- extra_args = {}
- if subscribers is None:
- subscribers = {}
- callargs = CallArgs(
- bucket=bucket,
- key=key,
- extra_args=extra_args,
- subscribers=subscribers,
- )
- return self._submit_transfer("delete_object", callargs)
- def shutdown(self, cancel=False):
- self._shutdown(cancel)
- def _cancel_transfers(self):
- for coordinator in self._future_coordinators:
- if not coordinator.done():
- coordinator.cancel()
- def _finish_transfers(self):
- for coordinator in self._future_coordinators:
- coordinator.result()
- def _wait_transfers_done(self):
- for coordinator in self._future_coordinators:
- coordinator.wait_until_on_done_callbacks_complete()
- def _shutdown(self, cancel=False):
- if cancel:
- self._cancel_transfers()
- try:
- self._finish_transfers()
- except KeyboardInterrupt:
- self._cancel_transfers()
- except Exception:
- pass
- finally:
- self._wait_transfers_done()
- def _release_semaphore(self, **kwargs):
- self._semaphore.release()
- def _submit_transfer(self, request_type, call_args):
- on_done_after_calls = [self._release_semaphore]
- coordinator = CRTTransferCoordinator(transfer_id=self._id_counter)
- components = {
- 'meta': CRTTransferMeta(self._id_counter, call_args),
- 'coordinator': coordinator,
- }
- future = CRTTransferFuture(**components)
- afterdone = AfterDoneHandler(coordinator)
- on_done_after_calls.append(afterdone)
- try:
- self._semaphore.acquire()
- on_queued = self._s3_args_creator.get_crt_callback(
- future, 'queued'
- )
- on_queued()
- crt_callargs = self._s3_args_creator.get_make_request_args(
- request_type,
- call_args,
- coordinator,
- future,
- on_done_after_calls,
- )
- crt_s3_request = self._crt_s3_client.make_request(**crt_callargs)
- except Exception as e:
- coordinator.set_exception(e, True)
- on_done = self._s3_args_creator.get_crt_callback(
- future, 'done', after_subscribers=on_done_after_calls
- )
- on_done(error=e)
- else:
- coordinator.set_s3_request(crt_s3_request)
- self._future_coordinators.append(coordinator)
- self._id_counter += 1
- return future
- class CRTTransferMeta(BaseTransferMeta):
- """Holds metadata about the CRTTransferFuture"""
- def __init__(self, transfer_id=None, call_args=None):
- self._transfer_id = transfer_id
- self._call_args = call_args
- self._user_context = {}
- @property
- def call_args(self):
- return self._call_args
- @property
- def transfer_id(self):
- return self._transfer_id
- @property
- def user_context(self):
- return self._user_context
- class CRTTransferFuture(BaseTransferFuture):
- def __init__(self, meta=None, coordinator=None):
- """The future associated to a submitted transfer request via CRT S3 client
- :type meta: s3transfer.crt.CRTTransferMeta
- :param meta: The metadata associated to the transfer future.
- :type coordinator: s3transfer.crt.CRTTransferCoordinator
- :param coordinator: The coordinator associated to the transfer future.
- """
- self._meta = meta
- if meta is None:
- self._meta = CRTTransferMeta()
- self._coordinator = coordinator
- @property
- def meta(self):
- return self._meta
- def done(self):
- return self._coordinator.done()
- def result(self, timeout=None):
- self._coordinator.result(timeout)
- def cancel(self):
- self._coordinator.cancel()
- def set_exception(self, exception):
- """Sets the exception on the future."""
- if not self.done():
- raise TransferNotDoneError(
- 'set_exception can only be called once the transfer is '
- 'complete.'
- )
- self._coordinator.set_exception(exception, override=True)
- class BaseCRTRequestSerializer:
- def serialize_http_request(self, transfer_type, future):
- """Serialize CRT HTTP requests.
- :type transfer_type: string
- :param transfer_type: the type of transfer made,
- e.g 'put_object', 'get_object', 'delete_object'
- :type future: s3transfer.crt.CRTTransferFuture
- :rtype: awscrt.http.HttpRequest
- :returns: An unsigned HTTP request to be used for the CRT S3 client
- """
- raise NotImplementedError('serialize_http_request()')
- class BotocoreCRTRequestSerializer(BaseCRTRequestSerializer):
- def __init__(self, session, client_kwargs=None):
- """Serialize CRT HTTP request using botocore logic
- It also takes into account configuration from both the session
- and any keyword arguments that could be passed to
- `Session.create_client()` when serializing the request.
- :type session: botocore.session.Session
- :type client_kwargs: Optional[Dict[str, str]])
- :param client_kwargs: The kwargs for the botocore
- s3 client initialization.
- """
- self._session = session
- if client_kwargs is None:
- client_kwargs = {}
- self._resolve_client_config(session, client_kwargs)
- self._client = session.create_client(**client_kwargs)
- self._client.meta.events.register(
- 'request-created.s3.*', self._capture_http_request
- )
- self._client.meta.events.register(
- 'after-call.s3.*', self._change_response_to_serialized_http_request
- )
- self._client.meta.events.register(
- 'before-send.s3.*', self._make_fake_http_response
- )
- def _resolve_client_config(self, session, client_kwargs):
- user_provided_config = None
- if session.get_default_client_config():
- user_provided_config = session.get_default_client_config()
- if 'config' in client_kwargs:
- user_provided_config = client_kwargs['config']
- client_config = Config(signature_version=UNSIGNED)
- if user_provided_config:
- client_config = user_provided_config.merge(client_config)
- client_kwargs['config'] = client_config
- client_kwargs["service_name"] = "s3"
- def _crt_request_from_aws_request(self, aws_request):
- url_parts = urlsplit(aws_request.url)
- crt_path = url_parts.path
- if url_parts.query:
- crt_path = f'{crt_path}?{url_parts.query}'
- headers_list = []
- for name, value in aws_request.headers.items():
- if isinstance(value, str):
- headers_list.append((name, value))
- else:
- headers_list.append((name, str(value, 'utf-8')))
- crt_headers = awscrt.http.HttpHeaders(headers_list)
- # CRT requires body (if it exists) to be an I/O stream.
- crt_body_stream = None
- if aws_request.body:
- if hasattr(aws_request.body, 'seek'):
- crt_body_stream = aws_request.body
- else:
- crt_body_stream = BytesIO(aws_request.body)
- crt_request = awscrt.http.HttpRequest(
- method=aws_request.method,
- path=crt_path,
- headers=crt_headers,
- body_stream=crt_body_stream,
- )
- return crt_request
- def _convert_to_crt_http_request(self, botocore_http_request):
- # Logic that does CRTUtils.crt_request_from_aws_request
- crt_request = self._crt_request_from_aws_request(botocore_http_request)
- if crt_request.headers.get("host") is None:
- # If host is not set, set it for the request before using CRT s3
- url_parts = urlsplit(botocore_http_request.url)
- crt_request.headers.set("host", url_parts.netloc)
- if crt_request.headers.get('Content-MD5') is not None:
- crt_request.headers.remove("Content-MD5")
- return crt_request
- def _capture_http_request(self, request, **kwargs):
- request.context['http_request'] = request
- def _change_response_to_serialized_http_request(
- self, context, parsed, **kwargs
- ):
- request = context['http_request']
- parsed['HTTPRequest'] = request.prepare()
- def _make_fake_http_response(self, request, **kwargs):
- return botocore.awsrequest.AWSResponse(
- None,
- 200,
- {},
- FakeRawResponse(b""),
- )
- def _get_botocore_http_request(self, client_method, call_args):
- return getattr(self._client, client_method)(
- Bucket=call_args.bucket, Key=call_args.key, **call_args.extra_args
- )['HTTPRequest']
- def serialize_http_request(self, transfer_type, future):
- botocore_http_request = self._get_botocore_http_request(
- transfer_type, future.meta.call_args
- )
- crt_request = self._convert_to_crt_http_request(botocore_http_request)
- return crt_request
- class FakeRawResponse(BytesIO):
- def stream(self, amt=1024, decode_content=None):
- while True:
- chunk = self.read(amt)
- if not chunk:
- break
- yield chunk
- class CRTTransferCoordinator:
- """A helper class for managing CRTTransferFuture"""
- def __init__(self, transfer_id=None, s3_request=None):
- self.transfer_id = transfer_id
- self._s3_request = s3_request
- self._lock = threading.Lock()
- self._exception = None
- self._crt_future = None
- self._done_event = threading.Event()
- @property
- def s3_request(self):
- return self._s3_request
- def set_done_callbacks_complete(self):
- self._done_event.set()
- def wait_until_on_done_callbacks_complete(self, timeout=None):
- self._done_event.wait(timeout)
- def set_exception(self, exception, override=False):
- with self._lock:
- if not self.done() or override:
- self._exception = exception
- def cancel(self):
- if self._s3_request:
- self._s3_request.cancel()
- def result(self, timeout=None):
- if self._exception:
- raise self._exception
- try:
- self._crt_future.result(timeout)
- except KeyboardInterrupt:
- self.cancel()
- raise
- finally:
- if self._s3_request:
- self._s3_request = None
- self._crt_future.result(timeout)
- def done(self):
- if self._crt_future is None:
- return False
- return self._crt_future.done()
- def set_s3_request(self, s3_request):
- self._s3_request = s3_request
- self._crt_future = self._s3_request.finished_future
- class S3ClientArgsCreator:
- def __init__(self, crt_request_serializer, os_utils):
- self._request_serializer = crt_request_serializer
- self._os_utils = os_utils
- def get_make_request_args(
- self, request_type, call_args, coordinator, future, on_done_after_calls
- ):
- recv_filepath = None
- send_filepath = None
- s3_meta_request_type = getattr(
- S3RequestType, request_type.upper(), S3RequestType.DEFAULT
- )
- on_done_before_calls = []
- if s3_meta_request_type == S3RequestType.GET_OBJECT:
- final_filepath = call_args.fileobj
- recv_filepath = self._os_utils.get_temp_filename(final_filepath)
- file_ondone_call = RenameTempFileHandler(
- coordinator, final_filepath, recv_filepath, self._os_utils
- )
- on_done_before_calls.append(file_ondone_call)
- elif s3_meta_request_type == S3RequestType.PUT_OBJECT:
- send_filepath = call_args.fileobj
- data_len = self._os_utils.get_file_size(send_filepath)
- call_args.extra_args["ContentLength"] = data_len
- crt_request = self._request_serializer.serialize_http_request(
- request_type, future
- )
- return {
- 'request': crt_request,
- 'type': s3_meta_request_type,
- 'recv_filepath': recv_filepath,
- 'send_filepath': send_filepath,
- 'on_done': self.get_crt_callback(
- future, 'done', on_done_before_calls, on_done_after_calls
- ),
- 'on_progress': self.get_crt_callback(future, 'progress'),
- }
- def get_crt_callback(
- self,
- future,
- callback_type,
- before_subscribers=None,
- after_subscribers=None,
- ):
- def invoke_all_callbacks(*args, **kwargs):
- callbacks_list = []
- if before_subscribers is not None:
- callbacks_list += before_subscribers
- callbacks_list += get_callbacks(future, callback_type)
- if after_subscribers is not None:
- callbacks_list += after_subscribers
- for callback in callbacks_list:
- # The get_callbacks helper will set the first augment
- # by keyword, the other augments need to be set by keyword
- # as well
- if callback_type == "progress":
- callback(bytes_transferred=args[0])
- else:
- callback(*args, **kwargs)
- return invoke_all_callbacks
- class RenameTempFileHandler:
- def __init__(self, coordinator, final_filename, temp_filename, osutil):
- self._coordinator = coordinator
- self._final_filename = final_filename
- self._temp_filename = temp_filename
- self._osutil = osutil
- def __call__(self, **kwargs):
- error = kwargs['error']
- if error:
- self._osutil.remove_file(self._temp_filename)
- else:
- try:
- self._osutil.rename_file(
- self._temp_filename, self._final_filename
- )
- except Exception as e:
- self._osutil.remove_file(self._temp_filename)
- # the CRT future has done already at this point
- self._coordinator.set_exception(e)
- class AfterDoneHandler:
- def __init__(self, coordinator):
- self._coordinator = coordinator
- def __call__(self, **kwargs):
- self._coordinator.set_done_callbacks_complete()
|