import argparse
import json
import os
import sys
import signal
import time
import logging
from multiprocessing import Pool, cpu_count
from operator import itemgetter

import boto3
import pyarrow as pa
import pyarrow.fs  # explicit submodule import so pa.fs.S3FileSystem resolves below
from boto_utils import get_session, json_lines_iterator, parse_s3_url
from botocore.exceptions import ClientError
from pyarrow.lib import ArrowException

from cse import decrypt, encrypt, is_kms_cse_encrypted
from events import (
    sanitize_message,
    emit_failure_event,
    emit_deletion_event,
    emit_skipped_event,
)
from json_handler import delete_matches_from_json_file
from parquet_handler import delete_matches_from_parquet_file
from s3 import (
    delete_old_versions,
    DeleteOldVersionsError,
    fetch_manifest,
    get_object_info,
    IntegrityCheckFailedError,
    rollback_object_version,
    save,
    validate_bucket_versioning,
    verify_object_versions_integrity,
)

FIVE_MB = 5 * 2**20
ROLE_SESSION_NAME = "s3f2"

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("LOG_LEVEL", logging.INFO))
formatter = logging.Formatter("[%(levelname)s] %(message)s")
handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(formatter)
logger.addHandler(handler)
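# Module overview: this worker long-polls the deletion queue, hands each message
# to execute() in a separate worker process, and rewrites the referenced S3
# object without the matched rows before uploading it as a new object version.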


def handle_error(
    sqs_msg,
    message_body,
    err_message,
    event_name="ObjectUpdateFailed",
    change_msg_visibility=True,
):
    logger.error(sanitize_message(err_message, message_body))
    try:
        emit_failure_event(message_body, err_message, event_name)
    except KeyError:
        logger.error("Unable to emit failure event due to invalid Job ID")
    except (json.decoder.JSONDecodeError, ValueError):
        logger.error("Unable to emit failure event due to invalid message")
    except ClientError as e:
        logger.error("Unable to emit failure event: %s", str(e))
    if change_msg_visibility:
        try:
            # A visibility timeout of 0 returns the message to the queue
            # immediately so it can be retried
            sqs_msg.change_visibility(VisibilityTimeout=0)
        except (
            sqs_msg.meta.client.exceptions.MessageNotInflight,
            sqs_msg.meta.client.exceptions.ReceiptHandleIsInvalid,
        ) as e:
            logger.error("Unable to change message visibility: %s", str(e))


def handle_skip(sqs_msg, message_body, skip_reason):
    sqs_msg.delete()
    logger.info(sanitize_message(skip_reason, message_body))
    emit_skipped_event(message_body, skip_reason)


def validate_message(message):
    body = json.loads(message)
    mandatory_keys = ["JobId", "Object", "Columns"]
    for k in mandatory_keys:
        if k not in body:
            raise ValueError("Malformed message. Missing key: {}".format(k))
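# A minimal message body accepted by validate_message looks like the following
# (illustrative values; execute() additionally reads the Format and Manifest keys
# plus the optional RoleArn, DeleteOldVersions and IgnoreObjectNotFoundExceptions keys):
#   {"JobId": "job-1234", "Object": "s3://my-bucket/data/file.parquet",
#    "Columns": [{"Column": "customer_id", "Type": "Simple"}]}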


def delete_matches_from_file(input_file, to_delete, file_format, compressed=False):
    logger.info("Generating new file without matches")
    if file_format == "json":
        return delete_matches_from_json_file(input_file, to_delete, compressed)
    return delete_matches_from_parquet_file(input_file, to_delete)
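# Note: the compressed flag is only forwarded to the JSON handler; the Parquet
# handler receives no compression hint (Parquet files record their own
# compression codec internally).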


def build_matches(cols, manifest_object):
    """
    Takes the columns and the manifest object, and returns the match_ids
    grouped by column.

    Input example:
    [{"Column":"customer_id", "Type":"Simple"}]

    Output example:
    [{"Column":"customer_id", "Type":"Simple", "MatchIds":[123, 234]}]
    """
    COMPOSITE_MATCH_TOKEN = "_S3F2COMP_"
    manifest = fetch_manifest(manifest_object)
    matches = {}
    for line in json_lines_iterator(manifest):
        if line["QueryableColumns"] not in matches:
            matches[line["QueryableColumns"]] = []
        is_simple = len(line["Columns"]) == 1
        match = line["MatchId"][0] if is_simple else line["MatchId"]
        matches[line["QueryableColumns"]].append(match)
    return list(
        map(
            lambda c: {
                "MatchIds": matches[
                    COMPOSITE_MATCH_TOKEN.join(c["Columns"])
                    if "Columns" in c
                    else c["Column"]
                ],
                **c,
            },
            cols,
        )
    )
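# Composite matches are grouped under the column names joined with the
# _S3F2COMP_ token. For example, a manifest line such as (illustrative values)
#   {"Columns": ["first_name", "last_name"], "MatchId": ["john", "doe"],
#    "QueryableColumns": "first_name_S3F2COMP_last_name"}
# is looked up for a column spec containing {"Columns": ["first_name", "last_name"]}.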


def execute(queue_url, message_body, receipt_handle):
    logger.info("Message received")
    queue = get_queue(queue_url)
    msg = queue.Message(receipt_handle)
    try:
        # Parse and validate incoming message
        validate_message(message_body)
        body = json.loads(message_body)
        session = get_session(body.get("RoleArn"), ROLE_SESSION_NAME)
        ignore_not_found_exceptions = body.get("IgnoreObjectNotFoundExceptions", False)
        client = session.client("s3")
        kms_client = session.client("kms")
        cols, object_path, job_id, file_format, manifest_object = itemgetter(
            "Columns", "Object", "JobId", "Format", "Manifest"
        )(body)
        input_bucket, input_key = parse_s3_url(object_path)
        validate_bucket_versioning(client, input_bucket)
        match_ids = build_matches(cols, manifest_object)
        s3 = pa.fs.S3FileSystem(
            region=os.getenv("AWS_DEFAULT_REGION"),
            session_name=ROLE_SESSION_NAME,
            external_id=ROLE_SESSION_NAME,
            role_arn=body.get("RoleArn"),
            load_frequency=60 * 60,
        )
        # Download the object in-memory and convert to PyArrow NativeFile
        logger.info("Downloading and opening %s object in-memory", object_path)
        with s3.open_input_stream(
            "{}/{}".format(input_bucket, input_key), buffer_size=FIVE_MB
        ) as f:
            source_version = f.metadata()["VersionId"].decode("utf-8")
            logger.info("Using object version %s as source", source_version)
            # Write new file in-memory
            compressed = object_path.endswith(".gz")
            object_info, _ = get_object_info(
                client, input_bucket, input_key, source_version
            )
            metadata = object_info["Metadata"]
            is_encrypted = is_kms_cse_encrypted(metadata)
            input_file = decrypt(f, metadata, kms_client) if is_encrypted else f
            out_sink, stats = delete_matches_from_file(
                input_file, match_ids, file_format, compressed
            )
        if stats["DeletedRows"] == 0:
            raise ValueError(
                "The object {} was processed successfully but no rows required deletion".format(
                    object_path
                )
            )
        with pa.BufferReader(out_sink.getvalue()) as output_buf:
            if is_encrypted:
                output_buf, metadata = encrypt(output_buf, metadata, kms_client)
            logger.info("Uploading new object version to S3")
            new_version = save(
                client,
                output_buf,
                input_bucket,
                input_key,
                metadata,
                source_version,
            )
        logger.info("New object version: %s", new_version)
        verify_object_versions_integrity(
            client, input_bucket, input_key, source_version, new_version
        )
        if body.get("DeleteOldVersions"):
            logger.info(
                "Deleting object {} versions older than version {}".format(
                    input_key, new_version
                )
            )
            delete_old_versions(client, input_bucket, input_key, new_version)
        msg.delete()
        emit_deletion_event(body, stats)
    except FileNotFoundError as e:
        err_message = "Apache Arrow S3FileSystem Error: {}".format(str(e))
        if ignore_not_found_exceptions:
            handle_skip(msg, body, "Ignored error: {}".format(err_message))
        else:
            handle_error(msg, message_body, err_message)
    except (KeyError, ArrowException) as e:
        err_message = "Apache Arrow processing error: {}".format(str(e))
        handle_error(msg, message_body, err_message)
    except IOError as e:
        err_message = "Unable to retrieve object: {}".format(str(e))
        handle_error(msg, message_body, err_message)
    except MemoryError as e:
        err_message = "Insufficient memory to work on object: {}".format(str(e))
        handle_error(msg, message_body, err_message)
    except ClientError as e:
        ignore_error = False
        err_message = "ClientError: {}".format(str(e))
        if e.operation_name == "PutObjectAcl":
            err_message += ". Redacted object uploaded successfully but unable to restore WRITE ACL"
        if e.operation_name == "ListObjectVersions":
            err_message += ". Could not verify redacted object version integrity"
        if e.operation_name == "HeadObject" and e.response["Error"]["Code"] == "404":
            ignore_error = ignore_not_found_exceptions
        if ignore_error:
            skip_reason = "Ignored error: {}".format(err_message)
            handle_skip(msg, body, skip_reason)
        else:
            handle_error(msg, message_body, err_message)
    except ValueError as e:
        err_message = "Unprocessable message: {}".format(str(e))
        handle_error(msg, message_body, err_message)
    except DeleteOldVersionsError as e:
        err_message = "Unable to delete previous versions: {}".format(str(e))
        handle_error(msg, message_body, err_message)
    except IntegrityCheckFailedError as e:
        err_description, client, bucket, key, version_id = e.args
        err_message = "Object version integrity check failed: {}".format(
            err_description
        )
        handle_error(msg, message_body, err_message)
        rollback_object_version(
            client,
            bucket,
            key,
            version_id,
            on_error=lambda err: handle_error(
                None, "{}", err, "ObjectRollbackFailed", False
            ),
        )
    except Exception as e:
        err_message = "Unknown error during message processing: {}".format(str(e))
        handle_error(msg, message_body, err_message)
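# The error handling above follows a single pattern: handle_skip() deletes the
# message and emits a skipped event, while handle_error() emits a failure event
# and makes the message visible again so the queue can drive a retry.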


def kill_handler(msgs, process_pool):
    logger.info("Received shutdown signal. Cleaning up %s messages", str(len(msgs)))
    process_pool.terminate()
    for msg in msgs:
        try:
            handle_error(msg, msg.body, "SIGINT/SIGTERM received during processing")
        except (ClientError, ValueError) as e:
            logger.error("Unable to gracefully cleanup message: %s", str(e))
    sys.exit(1 if len(msgs) > 0 else 0)
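# Exiting non-zero when messages were still in flight marks the shutdown as
# unclean; messages returned to the queue by handle_error() above will be
# retried once they become visible again.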


def get_queue(queue_url, **resource_kwargs):
    if not resource_kwargs.get("endpoint_url") and os.getenv("AWS_DEFAULT_REGION"):
        resource_kwargs["endpoint_url"] = "https://sqs.{}.{}".format(
            os.getenv("AWS_DEFAULT_REGION"), os.getenv("AWS_URL_SUFFIX")
        )
    sqs = boto3.resource("sqs", **resource_kwargs)
    return sqs.Queue(queue_url)
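# For example, with AWS_DEFAULT_REGION=eu-west-1 and AWS_URL_SUFFIX=amazonaws.com
# (typical values; both are assumed to be set in the task environment) the
# derived endpoint is https://sqs.eu-west-1.amazonaws.com.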


def main(queue_url, max_messages, wait_time, sleep_time):
    logger.info("CPU count for system: %s", cpu_count())
    messages = []
    queue = get_queue(queue_url)
    with Pool(maxtasksperchild=1) as pool:
        signal.signal(signal.SIGINT, lambda *_: kill_handler(messages, pool))
        signal.signal(signal.SIGTERM, lambda *_: kill_handler(messages, pool))
        while True:
            logger.info("Fetching messages...")
            messages = queue.receive_messages(
                WaitTimeSeconds=wait_time, MaxNumberOfMessages=max_messages
            )
            if len(messages) == 0:
                logger.info("No messages. Sleeping")
                time.sleep(sleep_time)
            else:
                processes = [(queue_url, m.body, m.receipt_handle) for m in messages]
                pool.starmap(execute, processes)
                messages = []
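# The loop long-polls for up to wait_time seconds per receive call and sleeps
# for sleep_time seconds when the queue is empty. maxtasksperchild=1 means each
# message is processed in a fresh worker process, so memory used while rewriting
# one object is released before the next message is handled.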


def parse_args(args):
    parser = argparse.ArgumentParser(
        description="Read and process new deletion tasks from a deletion queue"
    )
    parser.add_argument("--wait_time", type=int, default=5)
    parser.add_argument("--max_messages", type=int, default=1)
    parser.add_argument("--sleep_time", type=int, default=30)
    parser.add_argument(
        "--queue_url", type=str, default=os.getenv("DELETE_OBJECTS_QUEUE")
    )
    return parser.parse_args(args)


if __name__ == "__main__":
    opts = parse_args(sys.argv[1:])
    main(opts.queue_url, opts.max_messages, opts.wait_time, opts.sleep_time)
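# Example invocation (module filename and queue URL are illustrative; by default
# the queue URL is read from the DELETE_OBJECTS_QUEUE environment variable):
#   python main.py --queue_url https://sqs.eu-west-1.amazonaws.com/123456789012/delete-queue \
#       --max_messages 1 --wait_time 5 --sleep_time 30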