import logging
from functools import lru_cache
from urllib.parse import urlencode, quote_plus

from tenacity import (
    retry,
    retry_if_result,
    wait_exponential,
    stop_after_attempt,
    after_log,
)

from boto_utils import fetch_job_manifest, paginate
from botocore.exceptions import ClientError

from utils import remove_none, retry_wrapper

# BEGINNING OF s3transfer MONKEY PATCH
# https://github.com/boto/s3transfer/issues/82#issuecomment-837971614
import s3transfer.upload
import s3transfer.tasks


class PutObjectTask(s3transfer.tasks.Task):
    # Copied from s3transfer/upload.py, changed to return the result of client.put_object.
    def _main(self, client, fileobj, bucket, key, extra_args):
        with fileobj as body:
            return client.put_object(Bucket=bucket, Key=key, Body=body, **extra_args)


class CompleteMultipartUploadTask(s3transfer.tasks.Task):
    # Copied from s3transfer/tasks.py, changed to return a result.
    def _main(self, client, bucket, key, upload_id, parts, extra_args):
        print(f"Multipart upload {upload_id} for {key}.")
        return client.complete_multipart_upload(
            Bucket=bucket,
            Key=key,
            UploadId=upload_id,
            MultipartUpload={"Parts": parts},
            **extra_args,
        )


s3transfer.upload.PutObjectTask = PutObjectTask
s3transfer.upload.CompleteMultipartUploadTask = CompleteMultipartUploadTask
# END OF s3transfer MONKEY PATCH
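
# Effect of the patch (illustrative sketch, not part of the patched classes):
# upload_fileobj propagates the transfer future's result, so instead of None
# it now returns the underlying PutObject / CompleteMultipartUpload response:
#
#   resp = client.upload_fileobj(buf, "example-bucket", "example-key")
#   new_version_id = resp["VersionId"]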

logger = logging.getLogger(__name__)


def save(client, buf, bucket, key, metadata, source_version=None):
    """
    Save a buffer to S3, preserving any existing properties on the object
    """
    # Get Object Settings
    request_payer_args, _ = get_requester_payment(client, bucket)
    object_info_args, _ = get_object_info(client, bucket, key, source_version)
    tagging_args, _ = get_object_tags(client, bucket, key, source_version)
    acl_args, acl_resp = get_object_acl(client, bucket, key, source_version)
    extra_args = {
        **request_payer_args,
        **object_info_args,
        **tagging_args,
        **acl_args,
        **{"Metadata": metadata},
    }
    logger.info("Object settings: %s", extra_args)
    # Write Object Back to S3
    logger.info("Saving updated object to s3://%s/%s", bucket, key)
    resp = client.upload_fileobj(buf, bucket, key, ExtraArgs=extra_args)
    new_version_id = resp["VersionId"]
    logger.info("Object uploaded to S3")
    # GrantWrite cannot be set whilst uploading therefore ACLs need to be restored separately
    write_grantees = ",".join(get_grantees(acl_resp, "WRITE"))
    if write_grantees:
        logger.info("WRITE grant found. Restoring additional grantees for object")
        client.put_object_acl(
            Bucket=bucket,
            Key=key,
            VersionId=new_version_id,
            **{
                **request_payer_args,
                **acl_args,
                "GrantWrite": write_grantees,
            },
        )
    logger.info("Processing of file s3://%s/%s complete", bucket, key)
    return new_version_id
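
# Minimal usage sketch for save(); the bucket, key, and buffer below are
# illustrative, and the client is a plain boto3 S3 client:
#
#   import io
#   import boto3
#
#   s3 = boto3.client("s3")
#   buf = io.BytesIO(b"redacted object contents")
#   new_version = save(s3, buf, "example-bucket", "example-key", metadata={})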

@lru_cache()
def get_requester_payment(client, bucket):
    """
    Generates a dict containing the request payer args supported when calling S3.
    GetBucketRequestPayment call will be cached
    :returns tuple containing the info formatted for ExtraArgs and the raw response
    """
    request_payer = client.get_bucket_request_payment(Bucket=bucket)
    return (
        remove_none(
            {
                "RequestPayer": "requester"
                if request_payer["Payer"] == "Requester"
                else None,
            }
        ),
        request_payer,
    )

@lru_cache()
def get_object_info(client, bucket, key, version_id=None):
    """
    Generates a dict containing the non-ACL/Tagging args supported when uploading to S3.
    HeadObject call will be cached
    :returns tuple containing the info formatted for ExtraArgs and the raw response
    """
    kwargs = {"Bucket": bucket, "Key": key, **get_requester_payment(client, bucket)[0]}
    if version_id:
        kwargs["VersionId"] = version_id
    object_info = client.head_object(**kwargs)
    return (
        remove_none(
            {
                "CacheControl": object_info.get("CacheControl"),
                "ContentDisposition": object_info.get("ContentDisposition"),
                "ContentEncoding": object_info.get("ContentEncoding"),
                "ContentLanguage": object_info.get("ContentLanguage"),
                "ContentType": object_info.get("ContentType"),
                "Expires": object_info.get("Expires"),
                "Metadata": object_info.get("Metadata"),
                "ServerSideEncryption": object_info.get("ServerSideEncryption"),
                "StorageClass": object_info.get("StorageClass"),
                "SSECustomerAlgorithm": object_info.get("SSECustomerAlgorithm"),
                "SSEKMSKeyId": object_info.get("SSEKMSKeyId"),
                "WebsiteRedirectLocation": object_info.get("WebsiteRedirectLocation"),
            }
        ),
        object_info,
    )
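
# Note: S3 does not carry object properties over when a key is rewritten, so
# every setting surfaced here (ContentType, CacheControl, SSE settings, storage
# class, user metadata, ...) must be passed back explicitly via ExtraArgs on
# the subsequent upload, or it would be lost on the new version.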

@lru_cache()
def get_object_tags(client, bucket, key, version_id=None):
    """
    Generates a dict containing the Tagging args supported when uploading to S3
    GetObjectTagging call will be cached
    :returns tuple containing tagging formatted for ExtraArgs and the raw response
    """
    kwargs = {"Bucket": bucket, "Key": key}
    if version_id:
        kwargs["VersionId"] = version_id
    tagging = client.get_object_tagging(**kwargs)
    return (
        remove_none(
            {
                "Tagging": urlencode(
                    {tag["Key"]: tag["Value"] for tag in tagging["TagSet"]},
                    quote_via=quote_plus,
                )
            }
        ),
        tagging,
    )
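
# Illustrative example: a TagSet of [{"Key": "team", "Value": "data"},
# {"Key": "env", "Value": "prod"}] is encoded to the query-string form
# "team=data&env=prod" expected by the Tagging upload argument.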

@lru_cache()
def get_object_acl(client, bucket, key, version_id=None):
    """
    Generates a dict containing the ACL args supported when uploading to S3
    GetObjectAcl call will be cached
    :returns tuple containing ACL formatted for ExtraArgs and the raw response
    """
    kwargs = {"Bucket": bucket, "Key": key, **get_requester_payment(client, bucket)[0]}
    if version_id:
        kwargs["VersionId"] = version_id
    acl = client.get_object_acl(**kwargs)
    existing_owner = {"id={}".format(acl["Owner"]["ID"])}
    return (
        remove_none(
            {
                "GrantFullControl": ",".join(
                    existing_owner | get_grantees(acl, "FULL_CONTROL")
                ),
                "GrantRead": ",".join(get_grantees(acl, "READ")),
                "GrantReadACP": ",".join(get_grantees(acl, "READ_ACP")),
                "GrantWriteACP": ",".join(get_grantees(acl, "WRITE_ACP")),
            }
        ),
        acl,
    )

def get_grantees(acl, grant_type):
    # Maps a grantee type to its identifier key in the API response and the
    # corresponding prefix used in Grant* argument strings.
    prop_map = {
        "CanonicalUser": ("ID", "id"),
        "AmazonCustomerByEmail": ("EmailAddress", "emailAddress"),
        "Group": ("URI", "uri"),
    }
    filtered = [
        grant["Grantee"]
        for grant in acl.get("Grants", [])  # default to [] in case no grants are returned
        if grant["Permission"] == grant_type
    ]
    grantees = set()
    for grantee in filtered:
        identifier_type = grantee["Type"]
        identifier_prop = prop_map[identifier_type]
        grantees.add("{}={}".format(identifier_prop[1], grantee[identifier_prop[0]]))
    return grantees
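
# Illustrative example: a FULL_CONTROL grant for the canonical user "abc123"
# yields {"id=abc123"}, matching the "id=..." / "emailAddress=..." / "uri=..."
# grammar accepted by the Grant* upload arguments.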

@lru_cache()
def validate_bucket_versioning(client, bucket):
    resp = client.get_bucket_versioning(Bucket=bucket)
    versioning_enabled = resp.get("Status") == "Enabled"
    mfa_delete_enabled = resp.get("MFADelete") == "Enabled"
    if not versioning_enabled:
        raise ValueError("Bucket {} does not have versioning enabled".format(bucket))
    if mfa_delete_enabled:
        raise ValueError("Bucket {} has MFA Delete enabled".format(bucket))
    return True
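
# Context note (assumption about intent): MFA Delete is rejected because
# permanently deleting specific object versions from such a bucket requires an
# MFA token on each request, which an automated workflow cannot supply.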

@lru_cache()
def fetch_manifest(manifest_object):
    return fetch_job_manifest(manifest_object)

def delete_old_versions(client, input_bucket, input_key, new_version):
    try:
        resp = list(
            paginate(
                client,
                client.list_object_versions,
                ["Versions", "DeleteMarkers"],
                Bucket=input_bucket,
                Prefix=input_key,
                VersionIdMarker=new_version,
                KeyMarker=input_key,
            )
        )
        versions = [el[0] for el in resp if el[0] is not None]
        delete_markers = [el[1] for el in resp if el[1] is not None]
        versions.extend(delete_markers)
        sorted_versions = sorted(versions, key=lambda x: x["LastModified"])
        version_ids = [v["VersionId"] for v in sorted_versions]
        errors = []
        max_deletions = 1000
        for i in range(0, len(version_ids), max_deletions):
            objects = [
                {"Key": input_key, "VersionId": version_id}
                for version_id in version_ids[i : i + max_deletions]
            ]
            resp = delete_s3_objects(client, input_bucket, objects)
            errors.extend(resp.get("Errors", []))
        if len(errors) > 0:
            raise DeleteOldVersionsError(
                errors=[
                    "Delete object {} version {} failed: {}".format(
                        e["Key"], e["VersionId"], e["Message"]
                    )
                    for e in errors
                ]
            )
    except ClientError as e:
        raise DeleteOldVersionsError(errors=[str(e)])
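
# Note: the listing starts from the freshly written version (VersionIdMarker),
# so only versions older than it are returned, and deletions are issued in
# batches of 1000, the maximum number of keys a single DeleteObjects call accepts.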

@retry(
    wait=wait_exponential(multiplier=1, min=1, max=10),
    stop=stop_after_attempt(10),
    retry=(retry_if_result(lambda r: len(r.get("Errors", [])) > 0)),
    retry_error_callback=lambda r: r.outcome.result(),
    after=after_log(logger, logging.DEBUG),
)
def delete_s3_objects(client, bucket, objects):
    return client.delete_objects(
        Bucket=bucket,
        Delete={
            "Objects": objects,
            "Quiet": True,
        },
    )
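
# Retry semantics above: the call is retried with exponential backoff while the
# response still contains a non-empty "Errors" list; once the 10-attempt budget
# is exhausted, retry_error_callback returns the last response instead of raising
# RetryError, leaving the caller (delete_old_versions) to report the failures.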

def verify_object_versions_integrity(
    client, bucket, key, from_version_id, to_version_id
):
    def raise_exception(msg):
        raise IntegrityCheckFailedError(msg, client, bucket, key, to_version_id)

    conflict_error_template = "A {} ({}) was detected for the given object between read and write operations ({} and {})."
    not_found_error_template = "Previous version ({}) has been deleted."

    object_versions = retry_wrapper(client.list_object_versions)(
        Bucket=bucket,
        Prefix=key,
        VersionIdMarker=to_version_id,
        KeyMarker=key,
        MaxKeys=1,
    )

    versions = object_versions.get("Versions", [])
    delete_markers = object_versions.get("DeleteMarkers", [])
    all_versions = versions + delete_markers
    if not len(all_versions):
        return raise_exception(not_found_error_template.format(from_version_id))

    prev_version = all_versions[0]
    prev_version_id = prev_version["VersionId"]
    if prev_version_id != from_version_id:
        conflicting_version_type = (
            "delete marker" if "ETag" not in prev_version else "version"
        )
        return raise_exception(
            conflict_error_template.format(
                conflicting_version_type,
                prev_version_id,
                from_version_id,
                to_version_id,
            )
        )

    return True
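
# How the check works: listing with VersionIdMarker=to_version_id and MaxKeys=1
# yields the single version immediately preceding the newly written one. An
# empty result means the source version was deleted; any VersionId other than
# from_version_id means a write or delete marker was interleaved between the
# read and the write, so the integrity check fails.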

def rollback_object_version(client, bucket, key, version, on_error):
    """Delete newly created object version as soon as integrity conflict is detected"""
    try:
        return client.delete_object(Bucket=bucket, Key=key, VersionId=version)
    except ClientError as e:
        err_message = "ClientError: {}. Version rollback caused by version integrity conflict failed".format(
            str(e)
        )
        on_error(err_message)
    except Exception as e:
        err_message = "Unknown error: {}. Version rollback caused by version integrity conflict failed".format(
            str(e)
        )
        on_error(err_message)

class DeleteOldVersionsError(Exception):
    def __init__(self, errors):
        super().__init__("\n".join(errors))
        self.errors = errors


class IntegrityCheckFailedError(Exception):
    def __init__(self, message, client, bucket, key, version_id):
        self.message = message
        self.client = client
        self.bucket = bucket
        self.key = key
        self.version_id = version_id