import logging
from functools import lru_cache
from urllib.parse import urlencode, quote_plus

from tenacity import (
    retry,
    retry_if_result,
    wait_exponential,
    stop_after_attempt,
    after_log,
)

from boto_utils import fetch_job_manifest, paginate
from botocore.exceptions import ClientError

from utils import remove_none, retry_wrapper

# BEGINNING OF s3transfer MONKEY PATCH
# https://github.com/boto/s3transfer/issues/82#issuecomment-837971614
import s3transfer.upload
import s3transfer.tasks


class PutObjectTask(s3transfer.tasks.Task):
    # Copied from s3transfer/upload.py, changed to return the result of client.put_object.
    def _main(self, client, fileobj, bucket, key, extra_args):
        with fileobj as body:
            return client.put_object(Bucket=bucket, Key=key, Body=body, **extra_args)


class CompleteMultipartUploadTask(s3transfer.tasks.Task):
    # Copied from s3transfer/tasks.py, changed to return a result.
    def _main(self, client, bucket, key, upload_id, parts, extra_args):
        # Use the module logger rather than print so output respects the
        # configured logging level.
        logger.debug("Completing multipart upload %s for %s", upload_id, key)
        return client.complete_multipart_upload(
            Bucket=bucket,
            Key=key,
            UploadId=upload_id,
            MultipartUpload={"Parts": parts},
            **extra_args,
        )


s3transfer.upload.PutObjectTask = PutObjectTask
s3transfer.upload.CompleteMultipartUploadTask = CompleteMultipartUploadTask
# END OF s3transfer MONKEY PATCH
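
# Illustrative sketch of what the patch above enables (not used by this
# module; bucket and key names are hypothetical): with the patched tasks in
# place, upload_fileobj returns the raw PutObject/CompleteMultipartUpload
# response instead of None, so the new VersionId can be read directly.
#
#   import io
#   import boto3
#
#   s3 = boto3.client("s3")
#   resp = s3.upload_fileobj(io.BytesIO(b"data"), "example-bucket", "example-key")
#   new_version = resp["VersionId"]  # only present on versioned buckets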

logger = logging.getLogger(__name__)


def save(client, buf, bucket, key, metadata, source_version=None):
    """
    Save a buffer to S3, preserving any existing properties on the object
    """
    # Get Object Settings
    request_payer_args, _ = get_requester_payment(client, bucket)
    object_info_args, _ = get_object_info(client, bucket, key, source_version)
    tagging_args, _ = get_object_tags(client, bucket, key, source_version)
    acl_args, acl_resp = get_object_acl(client, bucket, key, source_version)
    extra_args = {
        **request_payer_args,
        **object_info_args,
        **tagging_args,
        **acl_args,
        **{"Metadata": metadata},
    }
    logger.info("Object settings: %s", extra_args)
    # Write Object Back to S3
    logger.info("Saving updated object to s3://%s/%s", bucket, key)
    # upload_fileobj returns the raw response because of the s3transfer monkey
    # patch above; without it the return value would be None.
    resp = client.upload_fileobj(buf, bucket, key, ExtraArgs=extra_args)
    new_version_id = resp["VersionId"]
    logger.info("Object uploaded to S3")
    # GrantWrite cannot be set while uploading, so WRITE ACLs need to be
    # restored separately
    write_grantees = ",".join(get_grantees(acl_resp, "WRITE"))
    if write_grantees:
        logger.info("WRITE grant found. Restoring additional grantees for object")
        client.put_object_acl(
            Bucket=bucket,
            Key=key,
            VersionId=new_version_id,
            **{
                **request_payer_args,
                **acl_args,
                "GrantWrite": write_grantees,
            },
        )
    logger.info("Processing of file s3://%s/%s complete", bucket, key)
    return new_version_id
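
# Hypothetical usage of save() (bucket, key, metadata and version ID are
# illustrative):
#
#   import io
#   import boto3
#
#   client = boto3.client("s3")
#   new_version = save(
#       client,
#       io.BytesIO(b"redacted contents"),
#       "example-bucket",
#       "example-key",
#       metadata={"retention": "standard"},
#       source_version="previous-version-id",
#   )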


@lru_cache()
def get_requester_payment(client, bucket):
    """
    Generates a dict containing the request payer args supported when calling S3.
    GetBucketRequestPayment call will be cached
    :returns tuple containing the info formatted for ExtraArgs and the raw response
    """
    request_payer = client.get_bucket_request_payment(Bucket=bucket)
    return (
        remove_none(
            {
                "RequestPayer": "requester"
                if request_payer["Payer"] == "Requester"
                else None,
            }
        ),
        request_payer,
    )
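
# Note: for a requester pays bucket the first element above is
# {"RequestPayer": "requester"}; otherwise remove_none leaves an empty dict,
# so the result can be merged into ExtraArgs unconditionally.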


@lru_cache()
def get_object_info(client, bucket, key, version_id=None):
    """
    Generates a dict containing the non-ACL/Tagging args supported when uploading to S3.
    HeadObject call will be cached
    :returns tuple containing the info formatted for ExtraArgs and the raw response
    """
    kwargs = {"Bucket": bucket, "Key": key, **get_requester_payment(client, bucket)[0]}
    if version_id:
        kwargs["VersionId"] = version_id
    object_info = client.head_object(**kwargs)
    return (
        remove_none(
            {
                "CacheControl": object_info.get("CacheControl"),
                "ContentDisposition": object_info.get("ContentDisposition"),
                "ContentEncoding": object_info.get("ContentEncoding"),
                "ContentLanguage": object_info.get("ContentLanguage"),
                "ContentType": object_info.get("ContentType"),
                "Expires": object_info.get("Expires"),
                "Metadata": object_info.get("Metadata"),
                "ServerSideEncryption": object_info.get("ServerSideEncryption"),
                "StorageClass": object_info.get("StorageClass"),
                "SSECustomerAlgorithm": object_info.get("SSECustomerAlgorithm"),
                "SSEKMSKeyId": object_info.get("SSEKMSKeyId"),
                "WebsiteRedirectLocation": object_info.get("WebsiteRedirectLocation"),
            }
        ),
        object_info,
    )


@lru_cache()
def get_object_tags(client, bucket, key, version_id=None):
    """
    Generates a dict containing the Tagging args supported when uploading to S3
    GetObjectTagging call will be cached
    :returns tuple containing tagging formatted for ExtraArgs and the raw response
    """
    kwargs = {"Bucket": bucket, "Key": key}
    if version_id:
        kwargs["VersionId"] = version_id
    tagging = client.get_object_tagging(**kwargs)
    return (
        remove_none(
            {
                "Tagging": urlencode(
                    {tag["Key"]: tag["Value"] for tag in tagging["TagSet"]},
                    quote_via=quote_plus,
                )
            }
        ),
        tagging,
    )
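
# Illustrative example of the Tagging encoding above (tag values are
# hypothetical): a TagSet of [{"Key": "team", "Value": "data eng"}] is encoded
# as the query-string style value "team=data+eng", which is the format the
# Tagging upload argument expects.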


@lru_cache()
def get_object_acl(client, bucket, key, version_id=None):
    """
    Generates a dict containing the ACL args supported when uploading to S3
    GetObjectAcl call will be cached
    :returns tuple containing ACL formatted for ExtraArgs and the raw response
    """
    kwargs = {"Bucket": bucket, "Key": key, **get_requester_payment(client, bucket)[0]}
    if version_id:
        kwargs["VersionId"] = version_id
    acl = client.get_object_acl(**kwargs)
    existing_owner = {"id={}".format(acl["Owner"]["ID"])}
    return (
        remove_none(
            {
                "GrantFullControl": ",".join(
                    existing_owner | get_grantees(acl, "FULL_CONTROL")
                ),
                "GrantRead": ",".join(get_grantees(acl, "READ")),
                "GrantReadACP": ",".join(get_grantees(acl, "READ_ACP")),
                "GrantWriteACP": ",".join(get_grantees(acl, "WRITE_ACP")),
            }
        ),
        acl,
    )


def get_grantees(acl, grant_type):
    """
    Returns the grantees holding the given permission, formatted as the
    "type=identifier" strings used by the S3 Grant* upload arguments
    """
    prop_map = {
        "CanonicalUser": ("ID", "id"),
        "AmazonCustomerByEmail": ("EmailAddress", "emailAddress"),
        "Group": ("URI", "uri"),
    }
    filtered = [
        grant["Grantee"]
        for grant in acl.get("Grants", [])
        if grant["Permission"] == grant_type
    ]
    grantees = set()
    for grantee in filtered:
        identifier_type = grantee["Type"]
        identifier_prop = prop_map[identifier_type]
        grantees.add("{}={}".format(identifier_prop[1], grantee[identifier_prop[0]]))
    return grantees
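
# Example of the strings produced above (the canonical ID is hypothetical; the
# group URI is the real AllUsers group): a CanonicalUser grant yields
# "id=abc123", while a Group grant yields
# "uri=http://acs.amazonaws.com/groups/global/AllUsers".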


@lru_cache()
def validate_bucket_versioning(client, bucket):
    """
    Checks the bucket has versioning enabled and MFA Delete disabled, raising
    ValueError otherwise
    """
    resp = client.get_bucket_versioning(Bucket=bucket)
    versioning_enabled = resp.get("Status") == "Enabled"
    mfa_delete_enabled = resp.get("MFADelete") == "Enabled"
    if not versioning_enabled:
        raise ValueError("Bucket {} does not have versioning enabled".format(bucket))
    if mfa_delete_enabled:
        raise ValueError("Bucket {} has MFA Delete enabled".format(bucket))
    return True


@lru_cache()
def fetch_manifest(manifest_object):
    return fetch_job_manifest(manifest_object)


def delete_old_versions(client, input_bucket, input_key, new_version):
    try:
        # Start listing from the marker, so only versions and delete markers
        # older than the newly written version are returned
        resp = list(
            paginate(
                client,
                client.list_object_versions,
                ["Versions", "DeleteMarkers"],
                Bucket=input_bucket,
                Prefix=input_key,
                VersionIdMarker=new_version,
                KeyMarker=input_key,
            )
        )
        versions = [el[0] for el in resp if el[0] is not None]
        delete_markers = [el[1] for el in resp if el[1] is not None]
        versions.extend(delete_markers)
        sorted_versions = sorted(versions, key=lambda x: x["LastModified"])
        version_ids = [v["VersionId"] for v in sorted_versions]
        errors = []
        # DeleteObjects accepts at most 1000 keys per request, so delete in batches
        max_deletions = 1000
        for i in range(0, len(version_ids), max_deletions):
            objects = [
                {"Key": input_key, "VersionId": version_id}
                for version_id in version_ids[i : i + max_deletions]
            ]
            resp = delete_s3_objects(client, input_bucket, objects)
            errors.extend(resp.get("Errors", []))
        if len(errors) > 0:
            raise DeleteOldVersionsError(
                errors=[
                    "Delete object {} version {} failed: {}".format(
                        e["Key"], e["VersionId"], e["Message"]
                    )
                    for e in errors
                ]
            )
    except ClientError as e:
        raise DeleteOldVersionsError(errors=[str(e)])


@retry(
    wait=wait_exponential(multiplier=1, min=1, max=10),
    stop=stop_after_attempt(10),
    retry=(retry_if_result(lambda r: len(r.get("Errors", [])) > 0)),
    retry_error_callback=lambda r: r.outcome.result(),
    after=after_log(logger, logging.DEBUG),
)
def delete_s3_objects(client, bucket, objects):
    """
    Deletes a batch of object versions, retrying with exponential backoff while
    the response reports per-object errors; after the final attempt the last
    response is returned so callers can inspect any remaining Errors
    """
    return client.delete_objects(
        Bucket=bucket,
        Delete={
            "Objects": objects,
            "Quiet": True,
        },
    )
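
# Hypothetical call (key and version ID are illustrative); the response may
# still contain an "Errors" list if deletions keep failing after all retries:
#
#   resp = delete_s3_objects(
#       client,
#       "example-bucket",
#       [{"Key": "example-key", "VersionId": "old-version-id"}],
#   )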


def verify_object_versions_integrity(
    client, bucket, key, from_version_id, to_version_id
):
    """
    Checks that to_version_id is the immediate successor of from_version_id,
    i.e. no other version or delete marker was created in between
    """

    def raise_exception(msg):
        raise IntegrityCheckFailedError(msg, client, bucket, key, to_version_id)

    conflict_error_template = (
        "A {} ({}) was detected for the given object between read and write "
        "operations ({} and {})."
    )
    not_found_error_template = "Previous version ({}) has been deleted."

    object_versions = retry_wrapper(client.list_object_versions)(
        Bucket=bucket,
        Prefix=key,
        VersionIdMarker=to_version_id,
        KeyMarker=key,
        MaxKeys=1,
    )

    versions = object_versions.get("Versions", [])
    delete_markers = object_versions.get("DeleteMarkers", [])
    all_versions = versions + delete_markers
    if not len(all_versions):
        return raise_exception(not_found_error_template.format(from_version_id))

    prev_version = all_versions[0]
    prev_version_id = prev_version["VersionId"]
    if prev_version_id != from_version_id:
        conflicting_version_type = (
            "delete marker" if "ETag" not in prev_version else "version"
        )
        return raise_exception(
            conflict_error_template.format(
                conflicting_version_type,
                prev_version_id,
                from_version_id,
                to_version_id,
            )
        )

    return True
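
# Hypothetical usage (version IDs are illustrative). Listing with
# VersionIdMarker=to_version_id and MaxKeys=1 returns the version immediately
# older than the new write; if that is not from_version_id, a concurrent
# version or delete marker landed in between and the check fails.
#
#   verify_object_versions_integrity(
#       client,
#       "example-bucket",
#       "example-key",
#       from_version_id="version-before-edit",
#       to_version_id="version-after-edit",
#   )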


def rollback_object_version(client, bucket, key, version, on_error):
    """Delete newly created object version as soon as integrity conflict is detected"""
    try:
        return client.delete_object(Bucket=bucket, Key=key, VersionId=version)
    except ClientError as e:
        err_message = (
            "ClientError: {}. Version rollback caused by version integrity "
            "conflict failed".format(str(e))
        )
        on_error(err_message)
    except Exception as e:
        err_message = (
            "Unknown error: {}. Version rollback caused by version integrity "
            "conflict failed".format(str(e))
        )
        on_error(err_message)


class DeleteOldVersionsError(Exception):
    def __init__(self, errors):
        super().__init__("\n".join(errors))
        self.errors = errors


class IntegrityCheckFailedError(Exception):
    def __init__(self, message, client, bucket, key, version_id):
        self.message = message
        self.client = client
        self.bucket = bucket
        self.key = key
        self.version_id = version_id