handlers.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. """
  2. DataMapper handlers
  3. """
  4. import json
  5. import os
  6. import boto3
  7. from boto_utils import DecimalEncoder, get_user_info, running_job_exists
  8. from decorators import (
  9. with_logging,
  10. request_validator,
  11. catch_errors,
  12. add_cors_headers,
  13. json_body_loader,
  14. load_schema,
  15. )
# AWS clients/resources created once at import time (reused across Lambda invocations).
dynamodb_resource = boto3.resource("dynamodb")
# DynamoDB table holding the data mapper configuration items.
table = dynamodb_resource.Table(os.getenv("DataMapperTable"))
glue_client = boto3.client("glue")

# SerDe libraries accepted when validating a Glue table's storage format.
PARQUET_HIVE_SERDE = "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"
JSON_HIVE_SERDE = "org.apache.hive.hcatalog.data.JsonSerDe"
JSON_OPENX_SERDE = "org.openx.data.jsonserde.JsonSerDe"
SUPPORTED_SERDE_LIBS = [PARQUET_HIVE_SERDE, JSON_HIVE_SERDE, JSON_OPENX_SERDE]
  23. @with_logging
  24. @add_cors_headers
  25. @request_validator(load_schema("get_data_mapper"))
  26. @catch_errors
  27. def get_data_mapper_handler(event, context):
  28. data_mapper_id = event["pathParameters"]["data_mapper_id"]
  29. item = table.get_item(Key={"DataMapperId": data_mapper_id}).get("Item")
  30. if not item:
  31. return {"statusCode": 404}
  32. return {"statusCode": 200, "body": json.dumps(item, cls=DecimalEncoder)}
  33. @with_logging
  34. @add_cors_headers
  35. @request_validator(load_schema("list_data_mappers"))
  36. @catch_errors
  37. def get_data_mappers_handler(event, context):
  38. qs = event.get("queryStringParameters")
  39. if not qs:
  40. qs = {}
  41. page_size = int(qs.get("page_size", 10))
  42. scan_params = {"Limit": page_size}
  43. start_at = qs.get("start_at")
  44. if start_at:
  45. scan_params["ExclusiveStartKey"] = {"DataMapperId": start_at}
  46. items = table.scan(**scan_params).get("Items", [])
  47. if len(items) < page_size:
  48. next_start = None
  49. else:
  50. next_start = items[-1]["DataMapperId"]
  51. return {
  52. "statusCode": 200,
  53. "body": json.dumps(
  54. {"DataMappers": items, "NextStart": next_start}, cls=DecimalEncoder
  55. ),
  56. }
  57. @with_logging
  58. @add_cors_headers
  59. @json_body_loader
  60. @request_validator(load_schema("create_data_mapper"))
  61. @catch_errors
  62. def put_data_mapper_handler(event, context):
  63. path_params = event["pathParameters"]
  64. body = event["body"]
  65. validate_mapper(body)
  66. item = {
  67. "DataMapperId": path_params["data_mapper_id"],
  68. "Columns": body["Columns"],
  69. "QueryExecutor": body["QueryExecutor"],
  70. "QueryExecutorParameters": body["QueryExecutorParameters"],
  71. "CreatedBy": get_user_info(event),
  72. "RoleArn": body["RoleArn"],
  73. "Format": body.get("Format", "parquet"),
  74. "DeleteOldVersions": body.get("DeleteOldVersions", True),
  75. "IgnoreObjectNotFoundExceptions": body.get(
  76. "IgnoreObjectNotFoundExceptions", False
  77. ),
  78. }
  79. table.put_item(Item=item)
  80. return {"statusCode": 201, "body": json.dumps(item)}
  81. @with_logging
  82. @add_cors_headers
  83. @request_validator(load_schema("delete_data_mapper"))
  84. @catch_errors
  85. def delete_data_mapper_handler(event, context):
  86. if running_job_exists():
  87. raise ValueError("Cannot delete Data Mappers whilst there is a job in progress")
  88. data_mapper_id = event["pathParameters"]["data_mapper_id"]
  89. table.delete_item(Key={"DataMapperId": data_mapper_id})
  90. return {"statusCode": 204}
  91. def validate_mapper(mapper):
  92. existing_s3_locations = get_existing_s3_locations(mapper["DataMapperId"])
  93. if mapper["QueryExecutorParameters"].get("DataCatalogProvider") == "glue":
  94. table_details = get_table_details_from_mapper(mapper)
  95. new_location = get_glue_table_location(table_details)
  96. serde_lib, serde_params = get_glue_table_format(table_details)
  97. for partition in mapper["QueryExecutorParameters"].get("PartitionKeys", []):
  98. if partition not in get_glue_table_partition_keys(table_details):
  99. raise ValueError("Partition Key {} doesn't exist".format(partition))
  100. if any([is_overlap(new_location, e) for e in existing_s3_locations]):
  101. raise ValueError(
  102. "A data mapper already exists which covers this S3 location"
  103. )
  104. if serde_lib not in SUPPORTED_SERDE_LIBS:
  105. raise ValueError(
  106. "The format for the specified table is not supported. The SerDe lib must be one of {}".format(
  107. ", ".join(SUPPORTED_SERDE_LIBS)
  108. )
  109. )
  110. if serde_lib == JSON_OPENX_SERDE:
  111. not_allowed_json_params = {
  112. "ignore.malformed.json": "TRUE",
  113. "dots.in.keys": "TRUE",
  114. }
  115. for param, value in not_allowed_json_params.items():
  116. if param in serde_params and serde_params[param] == value:
  117. raise ValueError(
  118. "The parameter {} cannot be {} for SerDe library {}".format(
  119. param, value, JSON_OPENX_SERDE
  120. )
  121. )
  122. if any([k for k, v in serde_params.items() if k.startswith("mapping.")]):
  123. raise ValueError(
  124. "Column mappings are not supported for SerDe library {}".format(
  125. JSON_OPENX_SERDE
  126. )
  127. )
  128. def get_existing_s3_locations(current_data_mapper_id):
  129. items = table.scan()["Items"]
  130. glue_mappers = [
  131. get_table_details_from_mapper(mapper)
  132. for mapper in items
  133. if mapper["QueryExecutorParameters"].get("DataCatalogProvider") == "glue"
  134. and mapper["DataMapperId"] != current_data_mapper_id
  135. ]
  136. return [get_glue_table_location(m) for m in glue_mappers]
  137. def get_table_details_from_mapper(mapper):
  138. db = mapper["QueryExecutorParameters"]["Database"]
  139. table_name = mapper["QueryExecutorParameters"]["Table"]
  140. return glue_client.get_table(DatabaseName=db, Name=table_name)
  141. def get_glue_table_location(t):
  142. return t["Table"]["StorageDescriptor"]["Location"]
  143. def get_glue_table_format(t):
  144. return (
  145. t["Table"]["StorageDescriptor"]["SerdeInfo"]["SerializationLibrary"],
  146. t["Table"]["StorageDescriptor"]["SerdeInfo"]["Parameters"],
  147. )
  148. def get_glue_table_partition_keys(t):
  149. return [x["Name"] for x in t["Table"]["PartitionKeys"]]
  150. def is_overlap(a, b):
  151. return a in b or b in a