crt.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644
  1. # Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"). You
  4. # may not use this file except in compliance with the License. A copy of
  5. # the License is located at
  6. #
  7. # http://aws.amazon.com/apache2.0/
  8. #
  9. # or in the "license" file accompanying this file. This file is
  10. # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
  11. # ANY KIND, either express or implied. See the License for the specific
  12. # language governing permissions and limitations under the License.
  13. import logging
  14. import threading
  15. from io import BytesIO
  16. import awscrt.http
  17. import botocore.awsrequest
  18. import botocore.session
  19. from awscrt.auth import AwsCredentials, AwsCredentialsProvider
  20. from awscrt.io import (
  21. ClientBootstrap,
  22. ClientTlsContext,
  23. DefaultHostResolver,
  24. EventLoopGroup,
  25. TlsContextOptions,
  26. )
  27. from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType
  28. from botocore import UNSIGNED
  29. from botocore.compat import urlsplit
  30. from botocore.config import Config
  31. from botocore.exceptions import NoCredentialsError
  32. from s3transfer.constants import GB, MB
  33. from s3transfer.exceptions import TransferNotDoneError
  34. from s3transfer.futures import BaseTransferFuture, BaseTransferMeta
  35. from s3transfer.utils import CallArgs, OSUtils, get_callbacks
  # Module-level logger, named after this module's import path.
  logger = logging.getLogger(name=__name__)
  class CRTCredentialProviderAdapter:
      """Adapt a botocore credential provider to the CRT delegate interface.

      An instance is a zero-argument callable (as passed to
      ``AwsCredentialsProvider.new_delegate``) that returns an
      ``awscrt.auth.AwsCredentials`` built from the botocore credentials.
      """

      def __init__(self, botocore_credential_provider):
          self._botocore_credential_provider = botocore_credential_provider
          # Filled in on first use and reused for every later call.
          self._loaded_credentials = None
          self._lock = threading.Lock()

      def __call__(self):
          frozen = self._get_credentials().get_frozen_credentials()
          return AwsCredentials(
              frozen.access_key,
              frozen.secret_key,
              frozen.token,
          )

      def _get_credentials(self):
          """Return the botocore credentials, loading them at most once.

          :raises NoCredentialsError: if the provider returns no credentials
              (nothing is cached in that case, so a later call retries).
          """
          with self._lock:
              if self._loaded_credentials is not None:
                  return self._loaded_credentials
              resolved = self._botocore_credential_provider.load_credentials()
              if resolved is None:
                  raise NoCredentialsError()
              self._loaded_credentials = resolved
              return resolved
  def create_s3_crt_client(
      region,
      botocore_credential_provider=None,
      num_threads=None,
      target_throughput=5 * GB / 8,
      part_size=8 * MB,
      use_ssl=True,
      verify=None,
  ):
      """Create an ``awscrt.s3.S3Client`` configured for s3transfer.

      :type region: str
      :param region: The region used for signing

      :type botocore_credential_provider:
          Optional[botocore.credentials.CredentialResolver]
      :param botocore_credential_provider: Provides credentials for CRT to
          sign the request. If not set, the request will not be signed.

      :type num_threads: Optional[int]
      :param num_threads: Number of worker threads generated. Default
          is the number of processors in the machine.

      :type target_throughput: Optional[int]
      :param target_throughput: Throughput target in Bytes.
          Default is 0.625 GB/s (which translates to 5 Gb/s).

      :type part_size: Optional[int]
      :param part_size: Size, in Bytes, of parts that files will be downloaded
          or uploaded in.

      :type use_ssl: boolean
      :param use_ssl: Whether or not to use SSL. By default, SSL is used.
          Note that not all services support non-ssl connections.

      :type verify: Optional[boolean/string]
      :param verify: Whether or not to verify SSL certificates.
          By default SSL certificates are verified. You can provide the
          following values:

          * None or True - verify SSL certificates with the default
            trust store (no custom TLS options are created).
          * False - do not validate SSL certificates. SSL will still be
            used (unless use_ssl is False), but SSL certificates
            will not be verified.
          * path/to/cert/bundle.pem - A filename of the CA cert bundle to
            use. Specify this argument if you want to use a custom CA cert
            bundle instead of the default one on your system.

      :rtype: awscrt.s3.S3Client
      """
      event_loop_group = EventLoopGroup(num_threads)
      host_resolver = DefaultHostResolver(event_loop_group)
      bootstrap = ClientBootstrap(event_loop_group, host_resolver)
      provider = None
      tls_connection_options = None

      tls_mode = (
          S3RequestTlsMode.ENABLED if use_ssl else S3RequestTlsMode.DISABLED
      )
      # ``True`` means "verify using the default trust store", i.e. the same
      # as ``None``. Custom TLS options are only needed for a CA bundle path
      # or to disable verification; ``True`` must never be handed to
      # ``override_default_trust_store_from_path`` as if it were a path.
      if verify is not None and verify is not True:
          tls_ctx_options = TlsContextOptions()
          if verify:
              tls_ctx_options.override_default_trust_store_from_path(
                  ca_filepath=verify
              )
          else:
              tls_ctx_options.verify_peer = False
          client_tls_option = ClientTlsContext(tls_ctx_options)
          tls_connection_options = client_tls_option.new_connection_options()

      if botocore_credential_provider:
          credentials_provider_adapter = CRTCredentialProviderAdapter(
              botocore_credential_provider
          )
          provider = AwsCredentialsProvider.new_delegate(
              credentials_provider_adapter
          )

      # ``target_throughput`` is in bytes per second; the CRT client takes
      # gigabits per second.
      target_gbps = target_throughput * 8 / GB
      return S3Client(
          bootstrap=bootstrap,
          region=region,
          credential_provider=provider,
          part_size=part_size,
          tls_mode=tls_mode,
          tls_connection_options=tls_connection_options,
          throughput_target_gbps=target_gbps,
      )
  class CRTTransferManager:
      def __init__(self, crt_s3_client, crt_request_serializer, osutil=None):
          """A transfer manager interface for Amazon S3 on CRT s3 client.

          :type crt_s3_client: awscrt.s3.S3Client
          :param crt_s3_client: The CRT s3 client, handling all the
              HTTP requests and functions under the hood

          :type crt_request_serializer: s3transfer.crt.BaseCRTRequestSerializer
          :param crt_request_serializer: Serializer, generates unsigned crt HTTP
              request.

          :type osutil: s3transfer.utils.OSUtils
          :param osutil: OSUtils object to use for os-related behavior when
              using with transfer manager. A default ``OSUtils()`` is used
              when omitted.
          """
          # Store a caller-supplied osutil too; only fall back to the default
          # when none is given (otherwise ``self._osutil`` would be unset).
          self._osutil = osutil if osutil is not None else OSUtils()
          self._crt_s3_client = crt_s3_client
          self._s3_args_creator = S3ClientArgsCreator(
              crt_request_serializer, self._osutil
          )
          self._future_coordinators = []
          # Bounds submitted-but-unfinished transfers: acquired in
          # _submit_transfer, released by an on-done callback.
          self._semaphore = threading.Semaphore(128)  # not configurable
          # A counter to create unique id's for each transfer submitted.
          self._id_counter = 0

      def __enter__(self):
          return self

      def __exit__(self, exc_type, exc_value, *args):
          # Cancel outstanding transfers if the ``with`` body raised.
          self._shutdown(cancel=bool(exc_type))

      def download(
          self, bucket, key, fileobj, extra_args=None, subscribers=None
      ):
          """Download ``bucket``/``key`` to the local file path ``fileobj``.

          :returns: The CRTTransferFuture tracking the download.
          """
          if extra_args is None:
              extra_args = {}
          if subscribers is None:
              subscribers = {}
          callargs = CallArgs(
              bucket=bucket,
              key=key,
              fileobj=fileobj,
              extra_args=extra_args,
              subscribers=subscribers,
          )
          return self._submit_transfer("get_object", callargs)

      def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None):
          """Upload the local file path ``fileobj`` to ``bucket``/``key``.

          :returns: The CRTTransferFuture tracking the upload.
          """
          if extra_args is None:
              extra_args = {}
          if subscribers is None:
              subscribers = {}
          callargs = CallArgs(
              bucket=bucket,
              key=key,
              fileobj=fileobj,
              extra_args=extra_args,
              subscribers=subscribers,
          )
          return self._submit_transfer("put_object", callargs)

      def delete(self, bucket, key, extra_args=None, subscribers=None):
          """Delete ``bucket``/``key``.

          :returns: The CRTTransferFuture tracking the delete.
          """
          if extra_args is None:
              extra_args = {}
          if subscribers is None:
              subscribers = {}
          callargs = CallArgs(
              bucket=bucket,
              key=key,
              extra_args=extra_args,
              subscribers=subscribers,
          )
          return self._submit_transfer("delete_object", callargs)

      def shutdown(self, cancel=False):
          """Wait for submitted transfers, optionally cancelling them first.

          Errors from individual transfers are not raised here.
          """
          self._shutdown(cancel)

      def _cancel_transfers(self):
          for coordinator in self._future_coordinators:
              if not coordinator.done():
                  coordinator.cancel()

      def _finish_transfers(self):
          for coordinator in self._future_coordinators:
              coordinator.result()

      def _wait_transfers_done(self):
          for coordinator in self._future_coordinators:
              coordinator.wait_until_on_done_callbacks_complete()

      def _shutdown(self, cancel=False):
          if cancel:
              self._cancel_transfers()
          try:
              self._finish_transfers()
          except KeyboardInterrupt:
              self._cancel_transfers()
          except Exception:
              # Deliberately best effort: a failed transfer reports its error
              # through its own future, not through shutdown.
              pass
          finally:
              # Always wait for every transfer's on-done callbacks to finish.
              self._wait_transfers_done()

      def _release_semaphore(self, **kwargs):
          self._semaphore.release()

      def _submit_transfer(self, request_type, call_args):
          # Run after the subscribers' ``done`` callbacks.
          on_done_after_calls = [self._release_semaphore]
          coordinator = CRTTransferCoordinator(transfer_id=self._id_counter)
          components = {
              'meta': CRTTransferMeta(self._id_counter, call_args),
              'coordinator': coordinator,
          }
          future = CRTTransferFuture(**components)
          afterdone = AfterDoneHandler(coordinator)
          on_done_after_calls.append(afterdone)

          try:
              # Blocks while 128 transfers are outstanding.
              self._semaphore.acquire()
              on_queued = self._s3_args_creator.get_crt_callback(
                  future, 'queued'
              )
              on_queued()
              crt_callargs = self._s3_args_creator.get_make_request_args(
                  request_type,
                  call_args,
                  coordinator,
                  future,
                  on_done_after_calls,
              )
              crt_s3_request = self._crt_s3_client.make_request(**crt_callargs)
          except Exception as e:
              # Submission failed: record the error and fire the done
              # callbacks here, which also releases the semaphore.
              coordinator.set_exception(e, True)
              on_done = self._s3_args_creator.get_crt_callback(
                  future, 'done', after_subscribers=on_done_after_calls
              )
              on_done(error=e)
          else:
              coordinator.set_s3_request(crt_s3_request)
          self._future_coordinators.append(coordinator)

          self._id_counter += 1
          return future
  class CRTTransferMeta(BaseTransferMeta):
      """Holds metadata about the CRTTransferFuture"""

      def __init__(self, transfer_id=None, call_args=None):
          self._call_args = call_args
          self._transfer_id = transfer_id
          # Starts empty; handed out by reference through ``user_context``.
          self._user_context = {}

      @property
      def call_args(self):
          """The CallArgs the transfer was submitted with."""
          return self._call_args

      @property
      def transfer_id(self):
          """Identifier the manager assigned to this transfer."""
          return self._transfer_id

      @property
      def user_context(self):
          """Mutable dict associated with this transfer."""
          return self._user_context
  class CRTTransferFuture(BaseTransferFuture):
      def __init__(self, meta=None, coordinator=None):
          """The future associated to a submitted transfer request via CRT S3 client

          :type meta: s3transfer.crt.CRTTransferMeta
          :param meta: The metadata associated to the transfer future. A fresh
              ``CRTTransferMeta()`` is used when omitted.

          :type coordinator: s3transfer.crt.CRTTransferCoordinator
          :param coordinator: The coordinator associated to the transfer future.
          """
          self._meta = CRTTransferMeta() if meta is None else meta
          self._coordinator = coordinator

      @property
      def meta(self):
          return self._meta

      def done(self):
          """Whether the underlying transfer has finished."""
          return self._coordinator.done()

      def result(self, timeout=None):
          """Wait for the transfer, raising its error if it failed."""
          self._coordinator.result(timeout)

      def cancel(self):
          """Request cancellation of the underlying transfer."""
          self._coordinator.cancel()

      def set_exception(self, exception):
          """Sets the exception on the future."""
          if self.done():
              # Override whatever outcome the transfer produced.
              self._coordinator.set_exception(exception, override=True)
              return
          raise TransferNotDoneError(
              'set_exception can only be called once the transfer is '
              'complete.'
          )
  class BaseCRTRequestSerializer:
      """Interface for turning a submitted transfer into a CRT HTTP request."""

      def serialize_http_request(self, transfer_type, future):
          """Serialize CRT HTTP requests.

          Subclasses must override this; the base version always raises.

          :type transfer_type: string
          :param transfer_type: the type of transfer made,
              e.g 'put_object', 'get_object', 'delete_object'

          :type future: s3transfer.crt.CRTTransferFuture

          :rtype: awscrt.http.HttpRequest
          :returns: An unsigned HTTP request to be used for the CRT S3 client
          :raises NotImplementedError: always, in this base class.
          """
          message = 'serialize_http_request()'
          raise NotImplementedError(message)
  class BotocoreCRTRequestSerializer(BaseCRTRequestSerializer):
      def __init__(self, session, client_kwargs=None):
          """Serialize CRT HTTP request using botocore logic

          It also takes into account configuration from both the session
          and any keyword arguments that could be passed to
          `Session.create_client()` when serializing the request.

          :type session: botocore.session.Session

          :type client_kwargs: Optional[Dict[str, Any]]
          :param client_kwargs: The kwargs for the botocore
              s3 client initialization. The dict is copied, not modified.
          """
          self._session = session
          # Work on a copy: _resolve_client_config writes into the dict, and
          # those writes must not leak into the caller's object.
          client_kwargs = {} if client_kwargs is None else dict(client_kwargs)
          self._resolve_client_config(session, client_kwargs)
          self._client = session.create_client(**client_kwargs)
          # Event hooks: capture the request when it is created, answer the
          # send with a canned response (presumably so nothing goes over the
          # wire), and return the prepared request as the call's result.
          self._client.meta.events.register(
              'request-created.s3.*', self._capture_http_request
          )
          self._client.meta.events.register(
              'after-call.s3.*', self._change_response_to_serialized_http_request
          )
          self._client.meta.events.register(
              'before-send.s3.*', self._make_fake_http_response
          )

      def _resolve_client_config(self, session, client_kwargs):
          """Set ``config`` and ``service_name`` in ``client_kwargs`` in place.

          The user's config (an explicit ``client_kwargs['config']`` wins over
          the session default) is merged with ``signature_version=UNSIGNED``
          so botocore produces unsigned requests.
          """
          user_provided_config = session.get_default_client_config() or None
          if 'config' in client_kwargs:
              user_provided_config = client_kwargs['config']

          client_config = Config(signature_version=UNSIGNED)
          if user_provided_config:
              client_config = user_provided_config.merge(client_config)
          client_kwargs['config'] = client_config
          client_kwargs["service_name"] = "s3"

      def _crt_request_from_aws_request(self, aws_request):
          """Build an ``awscrt.http.HttpRequest`` from a prepared botocore request."""
          url_parts = urlsplit(aws_request.url)
          crt_path = url_parts.path
          if url_parts.query:
              crt_path = f'{crt_path}?{url_parts.query}'
          # Header values may be bytes; CRT gets them decoded as UTF-8.
          headers_list = [
              (name, value if isinstance(value, str) else str(value, 'utf-8'))
              for name, value in aws_request.headers.items()
          ]
          crt_headers = awscrt.http.HttpHeaders(headers_list)

          # CRT requires body (if it exists) to be an I/O stream.
          crt_body_stream = None
          if aws_request.body:
              if hasattr(aws_request.body, 'seek'):
                  crt_body_stream = aws_request.body
              else:
                  crt_body_stream = BytesIO(aws_request.body)

          return awscrt.http.HttpRequest(
              method=aws_request.method,
              path=crt_path,
              headers=crt_headers,
              body_stream=crt_body_stream,
          )

      def _convert_to_crt_http_request(self, botocore_http_request):
          # Logic that does CRTUtils.crt_request_from_aws_request
          crt_request = self._crt_request_from_aws_request(botocore_http_request)
          if crt_request.headers.get("host") is None:
              # If host is not set, set it for the request before using CRT s3
              url_parts = urlsplit(botocore_http_request.url)
              crt_request.headers.set("host", url_parts.netloc)
          if crt_request.headers.get('Content-MD5') is not None:
              crt_request.headers.remove("Content-MD5")
          return crt_request

      def _capture_http_request(self, request, **kwargs):
          """Stash the created request in its own context for after-call."""
          request.context['http_request'] = request

      def _change_response_to_serialized_http_request(
          self, context, parsed, **kwargs
      ):
          """Expose the prepared request as the call's ``HTTPRequest`` key."""
          request = context['http_request']
          parsed['HTTPRequest'] = request.prepare()

      def _make_fake_http_response(self, request, **kwargs):
          """Return an empty 200 response in place of sending the request."""
          return botocore.awsrequest.AWSResponse(
              None,
              200,
              {},
              FakeRawResponse(b""),
          )

      def _get_botocore_http_request(self, client_method, call_args):
          return getattr(self._client, client_method)(
              Bucket=call_args.bucket, Key=call_args.key, **call_args.extra_args
          )['HTTPRequest']

      def serialize_http_request(self, transfer_type, future):
          """Serialize the transfer's botocore call into an unsigned CRT request."""
          botocore_http_request = self._get_botocore_http_request(
              transfer_type, future.meta.call_args
          )
          return self._convert_to_crt_http_request(botocore_http_request)
  class FakeRawResponse(BytesIO):
      """In-memory body that can also be consumed chunk-wise via ``stream()``.

      Used as the raw body of the canned response returned instead of
      sending a request.
      """

      def stream(self, amt=1024, decode_content=None):
          """Yield successive chunks of at most ``amt`` bytes until EOF.

          ``decode_content`` is accepted for interface compatibility and
          ignored.
          """
          yield from iter(lambda: self.read(amt), b"")
  class CRTTransferCoordinator:
      """A helper class for managing CRTTransferFuture"""

      def __init__(self, transfer_id=None, s3_request=None):
          self.transfer_id = transfer_id
          self._s3_request = s3_request
          # Guards updates of ``_exception`` in set_exception().
          self._lock = threading.Lock()
          self._exception = None
          # The CRT request's ``finished_future``; set by set_s3_request().
          self._crt_future = None
          # Set by set_done_callbacks_complete().
          self._done_event = threading.Event()

      @property
      def s3_request(self):
          return self._s3_request

      def set_done_callbacks_complete(self):
          self._done_event.set()

      def wait_until_on_done_callbacks_complete(self, timeout=None):
          """Block until set_done_callbacks_complete() has been called.

          :returns: True if that happened, False if ``timeout`` expired.
          """
          return self._done_event.wait(timeout)

      def set_exception(self, exception, override=False):
          """Record ``exception`` unless already done (or ``override``)."""
          with self._lock:
              if not self.done() or override:
                  self._exception = exception

      def cancel(self):
          if self._s3_request:
              self._s3_request.cancel()

      def result(self, timeout=None):
          """Wait for the CRT request and raise the transfer's error, if any.

          :param timeout: Passed to the CRT future's ``result``.
          :raises: the exception recorded via set_exception(), else whatever
              the CRT future raises (including a timeout); on
              ``KeyboardInterrupt`` the request is cancelled first and the
              interrupt is re-raised.
          """
          if self._exception:
              raise self._exception
          try:
              self._crt_future.result(timeout)
          except KeyboardInterrupt:
              self.cancel()
              # Give the cancelled request a chance to wind down, but make
              # the user's interrupt -- not a cancellation error -- propagate.
              try:
                  self._crt_future.result(timeout)
              except Exception:
                  pass
              raise
          finally:
              # Release the request once it has finished. Keep it while it is
              # still running (e.g. after a timeout) so cancel() still works.
              if self.done():
                  self._s3_request = None

      def done(self):
          if self._crt_future is None:
              return False
          return self._crt_future.done()

      def set_s3_request(self, s3_request):
          self._s3_request = s3_request
          self._crt_future = self._s3_request.finished_future
  class S3ClientArgsCreator:
      """Builds the keyword arguments for ``S3Client.make_request``."""

      def __init__(self, crt_request_serializer, os_utils):
          self._request_serializer = crt_request_serializer
          self._os_utils = os_utils

      def get_make_request_args(
          self, request_type, call_args, coordinator, future, on_done_after_calls
      ):
          """Return the kwargs for ``S3Client.make_request``.

          :param request_type: Botocore method name, e.g. ``'get_object'``;
              its upper-cased form selects the ``S3RequestType`` (falling back
              to ``S3RequestType.DEFAULT``).
          :param call_args: The transfer's CallArgs. For uploads its
              ``extra_args`` is replaced by a copy with ``ContentLength`` set.
          :param coordinator: Coordinator that records post-download rename
              failures.
          :param future: The CRTTransferFuture whose subscribers are invoked.
          :param on_done_after_calls: Callables run after the subscribers'
              ``done`` callbacks.
          """
          recv_filepath = None
          send_filepath = None
          s3_meta_request_type = getattr(
              S3RequestType, request_type.upper(), S3RequestType.DEFAULT
          )
          on_done_before_calls = []
          if s3_meta_request_type == S3RequestType.GET_OBJECT:
              # Download into a temporary file; RenameTempFileHandler moves it
              # to the requested path on success or removes it on failure.
              final_filepath = call_args.fileobj
              recv_filepath = self._os_utils.get_temp_filename(final_filepath)
              file_ondone_call = RenameTempFileHandler(
                  coordinator, final_filepath, recv_filepath, self._os_utils
              )
              on_done_before_calls.append(file_ondone_call)
          elif s3_meta_request_type == S3RequestType.PUT_OBJECT:
              send_filepath = call_args.fileobj
              data_len = self._os_utils.get_file_size(send_filepath)
              # Build a new dict rather than mutating the existing one, which
              # may be the caller's own ``extra_args``.
              call_args.extra_args = {
                  **call_args.extra_args,
                  "ContentLength": data_len,
              }

          crt_request = self._request_serializer.serialize_http_request(
              request_type, future
          )

          return {
              'request': crt_request,
              'type': s3_meta_request_type,
              'recv_filepath': recv_filepath,
              'send_filepath': send_filepath,
              'on_done': self.get_crt_callback(
                  future, 'done', on_done_before_calls, on_done_after_calls
              ),
              'on_progress': self.get_crt_callback(future, 'progress'),
          }

      def get_crt_callback(
          self,
          future,
          callback_type,
          before_subscribers=None,
          after_subscribers=None,
      ):
          """Return a callable that fans one CRT callback out to subscribers.

          Each invocation calls, in order: ``before_subscribers``, the
          future's subscriber callbacks for ``callback_type`` (looked up via
          ``get_callbacks`` at invocation time), then ``after_subscribers``.

          :param future: The CRTTransferFuture whose subscribers are invoked.
          :param callback_type: Subscriber callback name, e.g. ``'queued'``,
              ``'progress'`` or ``'done'``.
          """

          def invoke_all_callbacks(*args, **kwargs):
              callbacks_list = []
              if before_subscribers is not None:
                  callbacks_list += before_subscribers
              callbacks_list += get_callbacks(future, callback_type)
              if after_subscribers is not None:
                  callbacks_list += after_subscribers
              for callback in callbacks_list:
                  # The get_callbacks helper will set the first argument
                  # by keyword, the other arguments need to be set by keyword
                  # as well
                  if callback_type == "progress":
                      # Progress arrives as the first positional argument and
                      # is forwarded as ``bytes_transferred``.
                      callback(bytes_transferred=args[0])
                  else:
                      callback(*args, **kwargs)

          return invoke_all_callbacks
  class RenameTempFileHandler:
      """On-done callback that moves a finished download into place.

      On success the temporary file is renamed to the final path. On a
      transfer error, or if the rename itself fails, the temporary file is
      removed.
      """

      def __init__(self, coordinator, final_filename, temp_filename, osutil):
          self._coordinator = coordinator
          self._final_filename = final_filename
          self._temp_filename = temp_filename
          self._osutil = osutil

      def __call__(self, **kwargs):
          """Handle completion; ``kwargs['error']`` is falsy on success."""
          error = kwargs['error']
          if error:
              self._osutil.remove_file(self._temp_filename)
              return
          try:
              self._osutil.rename_file(
                  self._temp_filename, self._final_filename
              )
          except Exception as e:
              self._osutil.remove_file(self._temp_filename)
              # The CRT future may already be done here, and the coordinator
              # ignores non-override exceptions once done. Force it so the
              # rename failure is raised from result().
              self._coordinator.set_exception(e, override=True)
  class AfterDoneHandler:
      """On-done callback that flags the coordinator's callbacks as complete."""

      def __init__(self, coordinator):
          self._coordinator = coordinator

      def __call__(self, **kwargs):
          # Keyword arguments (e.g. ``error``) are accepted but not used.
          coordinator = self._coordinator
          coordinator.set_done_callbacks_complete()