main.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. import argparse
  2. import json
  3. import os
  4. import sys
  5. import signal
  6. import time
  7. import logging
  8. from multiprocessing import Pool, cpu_count
  9. from operator import itemgetter
  10. import boto3
  11. import pyarrow as pa
  12. from boto_utils import get_session, json_lines_iterator, parse_s3_url
  13. from botocore.exceptions import ClientError
  14. from pyarrow.lib import ArrowException
  15. from cse import decrypt, encrypt, is_kms_cse_encrypted
  16. from events import (
  17. sanitize_message,
  18. emit_failure_event,
  19. emit_deletion_event,
  20. emit_skipped_event,
  21. )
  22. from json_handler import delete_matches_from_json_file
  23. from parquet_handler import delete_matches_from_parquet_file
  24. from s3 import (
  25. delete_old_versions,
  26. DeleteOldVersionsError,
  27. fetch_manifest,
  28. get_object_info,
  29. IntegrityCheckFailedError,
  30. rollback_object_version,
  31. save,
  32. validate_bucket_versioning,
  33. verify_object_versions_integrity,
  34. )
# Buffer size used when streaming objects from S3 (see open_input_stream below).
FIVE_MB = 5 * 2**20
# Session name / external id used when assuming the job-provided IAM role.
ROLE_SESSION_NAME = "s3f2"

# Module logger: level configurable via LOG_LEVEL env var, output to stdout
# so container log collectors pick it up.
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("LOG_LEVEL", logging.INFO))
formatter = logging.Formatter("[%(levelname)s] %(message)s")
handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(formatter)
logger.addHandler(handler)
  43. def handle_error(
  44. sqs_msg,
  45. message_body,
  46. err_message,
  47. event_name="ObjectUpdateFailed",
  48. change_msg_visibility=True,
  49. ):
  50. logger.error(sanitize_message(err_message, message_body))
  51. try:
  52. emit_failure_event(message_body, err_message, event_name)
  53. except KeyError:
  54. logger.error("Unable to emit failure event due to invalid Job ID")
  55. except (json.decoder.JSONDecodeError, ValueError):
  56. logger.error("Unable to emit failure event due to invalid message")
  57. except ClientError as e:
  58. logger.error("Unable to emit failure event: %s", str(e))
  59. if change_msg_visibility:
  60. try:
  61. sqs_msg.change_visibility(VisibilityTimeout=0)
  62. except (
  63. sqs_msg.meta.client.exceptions.MessageNotInflight,
  64. sqs_msg.meta.client.exceptions.ReceiptHandleIsInvalid,
  65. ) as e:
  66. logger.error("Unable to change message visibility: %s", str(e))
  67. def handle_skip(sqs_msg, message_body, skip_reason):
  68. sqs_msg.delete()
  69. logger.info(sanitize_message(skip_reason, message_body))
  70. emit_skipped_event(message_body, skip_reason)
  71. def validate_message(message):
  72. body = json.loads(message)
  73. mandatory_keys = ["JobId", "Object", "Columns"]
  74. for k in mandatory_keys:
  75. if k not in body:
  76. raise ValueError("Malformed message. Missing key: %s", k)
  77. def delete_matches_from_file(input_file, to_delete, file_format, compressed=False):
  78. logger.info("Generating new file without matches")
  79. if file_format == "json":
  80. return delete_matches_from_json_file(input_file, to_delete, compressed)
  81. return delete_matches_from_parquet_file(input_file, to_delete)
  82. def build_matches(cols, manifest_object):
  83. """
  84. This function takes the columns and the manifests, and returns
  85. the match_ids grouped by column.
  86. Input example:
  87. [{"Column":"customer_id", "Type":"Simple"}]
  88. Output example:
  89. [{"Column":"customer_id", "Type":"Simple", "MatchIds":[123, 234]}]
  90. """
  91. COMPOSITE_MATCH_TOKEN = "_S3F2COMP_"
  92. manifest = fetch_manifest(manifest_object)
  93. matches = {}
  94. for line in json_lines_iterator(manifest):
  95. if not line["QueryableColumns"] in matches:
  96. matches[line["QueryableColumns"]] = []
  97. is_simple = len(line["Columns"]) == 1
  98. match = line["MatchId"][0] if is_simple else line["MatchId"]
  99. matches[line["QueryableColumns"]].append(match)
  100. return list(
  101. map(
  102. lambda c: {
  103. "MatchIds": matches[
  104. COMPOSITE_MATCH_TOKEN.join(c["Columns"])
  105. if "Columns" in c
  106. else c["Column"]
  107. ],
  108. **c,
  109. },
  110. cols,
  111. )
  112. )
def execute(queue_url, message_body, receipt_handle):
    """Process one deletion-task message end-to-end.

    Downloads the target S3 object, rewrites it without the rows listed in
    the job manifest, uploads the result as a new object version, verifies
    version integrity and optionally purges old versions. Every failure mode
    is translated into a failure/skip event (and message visibility reset)
    rather than raised to the caller.
    """
    logger.info("Message received")
    queue = get_queue(queue_url)
    msg = queue.Message(receipt_handle)
    try:
        # Parse and validate incoming message
        validate_message(message_body)
        body = json.loads(message_body)
        session = get_session(body.get("RoleArn"), ROLE_SESSION_NAME)
        ignore_not_found_exceptions = body.get("IgnoreObjectNotFoundExceptions", False)
        client = session.client("s3")
        kms_client = session.client("kms")
        cols, object_path, job_id, file_format, manifest_object = itemgetter(
            "Columns", "Object", "JobId", "Format", "Manifest"
        )(body)
        input_bucket, input_key = parse_s3_url(object_path)
        # Versioning must be enabled so the redacted copy becomes a new version
        # of the same key rather than overwriting the source irrecoverably.
        validate_bucket_versioning(client, input_bucket)
        match_ids = build_matches(cols, manifest_object)
        s3 = pa.fs.S3FileSystem(
            region=os.getenv("AWS_DEFAULT_REGION"),
            session_name=ROLE_SESSION_NAME,
            external_id=ROLE_SESSION_NAME,
            role_arn=body.get("RoleArn"),
            load_frequency=60 * 60,  # refresh assumed-role credentials hourly
        )
        # Download the object in-memory and convert to PyArrow NativeFile
        logger.info("Downloading and opening %s object in-memory", object_path)
        with s3.open_input_stream(
            "{}/{}".format(input_bucket, input_key), buffer_size=FIVE_MB
        ) as f:
            # Pin the exact version we read so the later integrity check can
            # detect concurrent writes between read and write.
            source_version = f.metadata()["VersionId"].decode("utf-8")
            logger.info("Using object version %s as source", source_version)
            # Write new file in-memory
            compressed = object_path.endswith(".gz")
            object_info, _ = get_object_info(
                client, input_bucket, input_key, source_version
            )
            metadata = object_info["Metadata"]
            # Objects may be client-side encrypted with KMS; decrypt before
            # editing and re-encrypt before upload.
            is_encrypted = is_kms_cse_encrypted(metadata)
            input_file = decrypt(f, metadata, kms_client) if is_encrypted else f
            out_sink, stats = delete_matches_from_file(
                input_file, match_ids, file_format, compressed
            )
        if stats["DeletedRows"] == 0:
            # Treat a no-op redaction as an error so the job surfaces it.
            raise ValueError(
                "The object {} was processed successfully but no rows required deletion".format(
                    object_path
                )
            )
        with pa.BufferReader(out_sink.getvalue()) as output_buf:
            if is_encrypted:
                output_buf, metadata = encrypt(output_buf, metadata, kms_client)
            logger.info("Uploading new object version to S3")
            new_version = save(
                client,
                output_buf,
                input_bucket,
                input_key,
                metadata,
                source_version,
            )
        logger.info("New object version: %s", new_version)
        # Fails (raising IntegrityCheckFailedError) if another writer created
        # a version between source_version and new_version.
        verify_object_versions_integrity(
            client, input_bucket, input_key, source_version, new_version
        )
        if body.get("DeleteOldVersions"):
            logger.info(
                "Deleting object {} versions older than version {}".format(
                    input_key, new_version
                )
            )
            delete_old_versions(client, input_bucket, input_key, new_version)
        # Only acknowledge the message once the object is fully processed.
        msg.delete()
        emit_deletion_event(body, stats)
    except FileNotFoundError as e:
        # Raised by the Arrow S3 filesystem when the object no longer exists.
        err_message = "Apache Arrow S3FileSystem Error: {}".format(str(e))
        if ignore_not_found_exceptions:
            handle_skip(msg, body, "Ignored error: {}".format(err_message))
        else:
            handle_error(msg, message_body, err_message)
    except (KeyError, ArrowException) as e:
        err_message = "Apache Arrow processing error: {}".format(str(e))
        handle_error(msg, message_body, err_message)
    except IOError as e:
        err_message = "Unable to retrieve object: {}".format(str(e))
        handle_error(msg, message_body, err_message)
    except MemoryError as e:
        err_message = "Insufficient memory to work on object: {}".format(str(e))
        handle_error(msg, message_body, err_message)
    except ClientError as e:
        # Enrich well-known AWS failures with actionable context.
        ignore_error = False
        err_message = "ClientError: {}".format(str(e))
        if e.operation_name == "PutObjectAcl":
            err_message += ". Redacted object uploaded successfully but unable to restore WRITE ACL"
        if e.operation_name == "ListObjectVersions":
            err_message += ". Could not verify redacted object version integrity"
        if e.operation_name == "HeadObject" and e.response["Error"]["Code"] == "404":
            # Missing object is skippable when the job opted in.
            ignore_error = ignore_not_found_exceptions
        if ignore_error:
            skip_reason = "Ignored error: {}".format(err_message)
            handle_skip(msg, body, skip_reason)
        else:
            handle_error(msg, message_body, err_message)
    except ValueError as e:
        err_message = "Unprocessable message: {}".format(str(e))
        handle_error(msg, message_body, err_message)
    except DeleteOldVersionsError as e:
        err_message = "Unable to delete previous versions: {}".format(str(e))
        handle_error(msg, message_body, err_message)
    except IntegrityCheckFailedError as e:
        # The exception carries everything needed to roll back the upload.
        err_description, client, bucket, key, version_id = e.args
        err_message = "Object version integrity check failed: {}".format(
            err_description
        )
        handle_error(msg, message_body, err_message)
        rollback_object_version(
            client,
            bucket,
            key,
            version_id,
            # No SQS message involved in rollback failures, hence None / "{}".
            on_error=lambda err: handle_error(
                None, "{}", err, "ObjectRollbackFailed", False
            ),
        )
    except Exception as e:
        # Last-resort catch-all so the worker process never dies on a message.
        err_message = "Unknown error during message processing: {}".format(str(e))
        handle_error(msg, message_body, err_message)
  240. def kill_handler(msgs, process_pool):
  241. logger.info("Received shutdown signal. Cleaning up %s messages", str(len(msgs)))
  242. process_pool.terminate()
  243. for msg in msgs:
  244. try:
  245. handle_error(msg, msg.body, "SIGINT/SIGTERM received during processing")
  246. except (ClientError, ValueError) as e:
  247. logger.error("Unable to gracefully cleanup message: %s", str(e))
  248. sys.exit(1 if len(msgs) > 0 else 0)
  249. def get_queue(queue_url, **resource_kwargs):
  250. if not resource_kwargs.get("endpoint_url") and os.getenv("AWS_DEFAULT_REGION"):
  251. resource_kwargs["endpoint_url"] = "https://sqs.{}.{}".format(
  252. os.getenv("AWS_DEFAULT_REGION"), os.getenv("AWS_URL_SUFFIX")
  253. )
  254. sqs = boto3.resource("sqs", **resource_kwargs)
  255. return sqs.Queue(queue_url)
  256. def main(queue_url, max_messages, wait_time, sleep_time):
  257. logger.info("CPU count for system: %s", cpu_count())
  258. messages = []
  259. queue = get_queue(queue_url)
  260. with Pool(maxtasksperchild=1) as pool:
  261. signal.signal(signal.SIGINT, lambda *_: kill_handler(messages, pool))
  262. signal.signal(signal.SIGTERM, lambda *_: kill_handler(messages, pool))
  263. while 1:
  264. logger.info("Fetching messages...")
  265. messages = queue.receive_messages(
  266. WaitTimeSeconds=wait_time, MaxNumberOfMessages=max_messages
  267. )
  268. if len(messages) == 0:
  269. logger.info("No messages. Sleeping")
  270. time.sleep(sleep_time)
  271. else:
  272. processes = [(queue_url, m.body, m.receipt_handle) for m in messages]
  273. pool.starmap(execute, processes)
  274. messages = []
  275. def parse_args(args):
  276. parser = argparse.ArgumentParser(
  277. description="Read and process new deletion tasks from a deletion queue"
  278. )
  279. parser.add_argument("--wait_time", type=int, default=5)
  280. parser.add_argument("--max_messages", type=int, default=1)
  281. parser.add_argument("--sleep_time", type=int, default=30)
  282. parser.add_argument(
  283. "--queue_url", type=str, default=os.getenv("DELETE_OBJECTS_QUEUE")
  284. )
  285. return parser.parse_args(args)
# Script entry point: parse CLI options and start the long-lived polling loop.
if __name__ == "__main__":
    opts = parse_args(sys.argv[1:])
    main(opts.queue_url, opts.max_messages, opts.wait_time, opts.sleep_time)