generate_queries.py

  1. """
  2. Task for generating Athena queries from glue catalog aka Query Planning
  3. """
  4. import json
  5. import os
  6. import boto3
  7. from operator import itemgetter
  8. from boto_utils import paginate, batch_sqs_msgs, deserialize_item, DecimalEncoder
  9. from decorators import with_logging
  10. ddb = boto3.resource("dynamodb")
  11. ddb_client = boto3.client("dynamodb")
  12. glue_client = boto3.client("glue")
  13. s3 = boto3.resource("s3")
  14. sqs = boto3.resource("sqs")
  15. queue = sqs.Queue(os.getenv("QueryQueue"))
  16. jobs_table = ddb.Table(os.getenv("JobTable", "S3F2_Jobs"))
  17. data_mapper_table_name = os.getenv("DataMapperTable", "S3F2_DataMappers")
  18. deletion_queue_table_name = os.getenv("DeletionQueueTable", "S3F2_DeletionQueue")
  19. manifests_bucket_name = os.getenv("ManifestsBucket", "S3F2-manifests-bucket")
  20. glue_db = os.getenv("GlueDatabase", "s3f2_manifests_database")
  21. glue_table = os.getenv("JobManifestsGlueTable", "s3f2_manifests_table")
MANIFEST_KEY = "manifests/{job_id}/{data_mapper_id}/manifest.json"
COMPOSITE_JOIN_TOKEN = "_S3F2COMP_"
ARRAYSTRUCT = "array<struct>"
ARRAYSTRUCT_PREFIX = "array<struct<"
ARRAYSTRUCT_SUFFIX = ">>"
STRUCT = "struct"
STRUCT_PREFIX = "struct<"
STRUCT_SUFFIX = ">"
SCHEMA_INVALID = "Column schema is not valid"
ALLOWED_TYPES = [
    "bigint",
    "char",
    "decimal",
    "double",
    "float",
    "int",
    "smallint",
    "string",
    "tinyint",
    "varchar",
]


@with_logging
def handler(event, context):
    job_id = event["ExecutionName"]
    deletion_items = get_deletion_queue()
    manifests_partitions = []
    data_mappers = get_data_mappers()
    total_queries = 0
    for data_mapper in data_mappers:
        query_executor = data_mapper["QueryExecutor"]
        if query_executor == "athena":
            queries = generate_athena_queries(data_mapper, deletion_items, job_id)
            if len(queries) > 0:
                manifests_partitions.append([job_id, data_mapper["DataMapperId"]])
        else:
            raise NotImplementedError(
                "Unsupported data mapper query executor: '{}'".format(query_executor)
            )
        batch_sqs_msgs(queue, queries)
        total_queries += len(queries)
    write_partitions(manifests_partitions)
    return {
        "GeneratedQueries": total_queries,
        "DeletionQueueSize": len(deletion_items),
        "Manifests": [
            "s3://{}/{}".format(
                manifests_bucket_name,
                MANIFEST_KEY.format(
                    job_id=partition_tuple[0], data_mapper_id=partition_tuple[1]
                ),
            )
            for partition_tuple in manifests_partitions
        ],
    }
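

# Illustrative sketch (not part of the task logic): the handler returns a summary
# of the planning step. The counts, job ID and data mapper ID below are made-up
# placeholders.
#
# {
#     "GeneratedQueries": 4,
#     "DeletionQueueSize": 2,
#     "Manifests": [
#         "s3://<manifests-bucket>/manifests/<job_id>/<data_mapper_id>/manifest.json"
#     ],
# }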
  77. def build_manifest_row(columns, match_id, item_id, item_createdat, is_composite):
  78. """
  79. Function for building each row of the manifest that will be written to S3.
  80. * What are 'queryablematchid' and 'queryablecolumns'?
  81. A convenience stringified value of match_id and its column when the match
  82. is simple, or a stringified joint value when composite (for instance,
  83. "John_S3F2COMP_Doe" and "first_name_S3F2COMP_last_name"). The purpose of
  84. these fields is optimise query execution by doing the SQL JOINs over strings only.
  85. * What are MatchId and Columns?
  86. Original values to be used by the ECS task instead.
  87. Note that the MatchId is declared as array<string> in the Glue Table as it's
  88. not possible to declare it as array of generic types and the design is for
  89. using a single table schema for each match/column tuple, despite
  90. the current column type.
  91. This means that using the "MatchId" field in Athena will always coherce its values
  92. to strings, for instance [1234] => ["1234"]. That's ok because when working with
  93. the manifest, the Fargate task will read and parse the JSON directly and therefore
  94. will use its original type (for instance, int over strings to do the comparison).
  95. """
  96. iterable_match = match_id if is_composite else [match_id]
  97. queryable = COMPOSITE_JOIN_TOKEN.join(str(x) for x in iterable_match)
  98. queryable_cols = COMPOSITE_JOIN_TOKEN.join(str(x) for x in columns)
  99. return (
  100. json.dumps(
  101. {
  102. "Columns": columns,
  103. "MatchId": iterable_match,
  104. "DeletionQueueItemId": item_id,
  105. "CreatedAt": item_createdat,
  106. "QueryableColumns": queryable_cols,
  107. "QueryableMatchId": queryable,
  108. },
  109. cls=DecimalEncoder,
  110. )
  111. + "\n"
  112. )
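

# Illustrative sketch: two manifest rows as produced by build_manifest_row, one
# simple and one composite. Each row is written as a single JSON line (wrapped
# here for readability). Column names and values are made-up placeholders; only
# "first_name"/"last_name"/"John"/"Doe" echo the docstring example above.
#
# {"Columns": ["customer_id"], "MatchId": [12345], "DeletionQueueItemId": "id-1",
#  "CreatedAt": 1622000000, "QueryableColumns": "customer_id",
#  "QueryableMatchId": "12345"}
# {"Columns": ["first_name", "last_name"], "MatchId": ["John", "Doe"],
#  "DeletionQueueItemId": "id-2", "CreatedAt": 1622000000,
#  "QueryableColumns": "first_name_S3F2COMP_last_name",
#  "QueryableMatchId": "John_S3F2COMP_Doe"}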


def generate_athena_queries(data_mapper, deletion_items, job_id):
    """
    For the given data mapper, generates the list of parameters needed for each
    query execution. The matches for the given column are saved in an external
    S3 object (aka manifest) so that their number can grow into the thousands
    without hitting the DynamoDB item size limit, the SQS message size limit, or
    the Athena query size limit. The manifest's S3 path is then referenced in
    the SQS message.
    """
    manifest_key = MANIFEST_KEY.format(
        job_id=job_id, data_mapper_id=data_mapper["DataMapperId"]
    )
    db = data_mapper["QueryExecutorParameters"]["Database"]
    table_name = data_mapper["QueryExecutorParameters"]["Table"]
    table = get_table(db, table_name)
    columns_tree = get_columns_tree(table)
    all_partition_keys = [p["Name"] for p in table.get("PartitionKeys", [])]
    partition_keys = data_mapper["QueryExecutorParameters"].get(
        "PartitionKeys", all_partition_keys
    )
    columns = [c for c in data_mapper["Columns"]]
    msg = {
        "DataMapperId": data_mapper["DataMapperId"],
        "QueryExecutor": data_mapper["QueryExecutor"],
        "Format": data_mapper["Format"],
        "Database": db,
        "Table": table_name,
        "Columns": columns,
        "PartitionKeys": [],
        "DeleteOldVersions": data_mapper.get("DeleteOldVersions", True),
        "IgnoreObjectNotFoundExceptions": data_mapper.get(
            "IgnoreObjectNotFoundExceptions", False
        ),
    }
    if data_mapper.get("RoleArn", None):
        msg["RoleArn"] = data_mapper["RoleArn"]
    # Work out which deletion items should be included in this query
    applicable_match_ids = [
        item
        for item in deletion_items
        if msg["DataMapperId"] in item.get("DataMappers", [])
        or len(item.get("DataMappers", [])) == 0
    ]
    if len(applicable_match_ids) == 0:
        return []
    # Compile a list of MatchIds grouped by Column
    columns_with_matches = {}
    manifest = ""
    for item in applicable_match_ids:
        mid, item_id, item_createdat = itemgetter(
            "MatchId", "DeletionQueueItemId", "CreatedAt"
        )(item)
        is_simple = not isinstance(mid, list)
        if is_simple:
            for column in msg["Columns"]:
                casted = cast_to_type(mid, column, table_name, columns_tree)
                if column not in columns_with_matches:
                    columns_with_matches[column] = {
                        "Column": column,
                        "Type": "Simple",
                    }
                manifest += build_manifest_row(
                    [column], casted, item_id, item_createdat, False
                )
        else:
            sorted_mid = sorted(mid, key=lambda x: x["Column"])
            query_columns = list(map(lambda x: x["Column"], sorted_mid))
            column_key = COMPOSITE_JOIN_TOKEN.join(query_columns)
            composite_match = list(
                map(
                    lambda x: cast_to_type(
                        x["Value"], x["Column"], table_name, columns_tree
                    ),
                    sorted_mid,
                )
            )
            if column_key not in columns_with_matches:
                columns_with_matches[column_key] = {
                    "Columns": query_columns,
                    "Type": "Composite",
                }
            manifest += build_manifest_row(
                query_columns, composite_match, item_id, item_createdat, True
            )
    s3.Bucket(manifests_bucket_name).put_object(Body=manifest, Key=manifest_key)
    msg["Columns"] = list(columns_with_matches.values())
    msg["Manifest"] = "s3://{}/{}".format(manifests_bucket_name, manifest_key)
    if len(partition_keys) == 0:
        return [msg]
    # For every partition combo of every table, create a query
    partitions = set()
    for partition in get_partitions(db, table_name):
        current = tuple(
            (
                all_partition_keys[i],
                cast_to_type(v, all_partition_keys[i], table_name, columns_tree),
            )
            for i, v in enumerate(partition["Values"])
            if all_partition_keys[i] in partition_keys
        )
        partitions.add(current)
    ret = []
    for current in partitions:
        current_dict = [{"Key": k, "Value": v} for k, v in current]
        ret.append({**msg, "PartitionKeys": current_dict})
    return ret
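

# Illustrative sketch of one generated query message, i.e. one element of the
# list sent to SQS for a partitioned table. The keys mirror the "msg" dict built
# above; all names and values are made-up placeholders.
#
# {
#     "DataMapperId": "my-mapper",
#     "QueryExecutor": "athena",
#     "Format": "parquet",
#     "Database": "my_db",
#     "Table": "my_table",
#     "Columns": [{"Column": "customer_id", "Type": "Simple"}],
#     "PartitionKeys": [{"Key": "year", "Value": 2021}],
#     "DeleteOldVersions": True,
#     "IgnoreObjectNotFoundExceptions": False,
#     "Manifest": "s3://<manifests-bucket>/manifests/<job_id>/my-mapper/manifest.json",
# }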


def get_deletion_queue():
    results = paginate(
        ddb_client, ddb_client.scan, "Items", TableName=deletion_queue_table_name
    )
    return [deserialize_item(result) for result in results]


def get_data_mappers():
    results = paginate(
        ddb_client, ddb_client.scan, "Items", TableName=data_mapper_table_name
    )
    for result in results:
        yield deserialize_item(result)


def get_table(db, table_name):
    return glue_client.get_table(DatabaseName=db, Name=table_name)["Table"]


def get_columns_tree(table):
    return list(
        map(
            column_mapper,
            table["StorageDescriptor"]["Columns"] + table.get("PartitionKeys", []),
        )
    )


def get_partitions(db, table_name):
    return paginate(
        glue_client,
        glue_client.get_partitions,
        ["Partitions"],
        DatabaseName=db,
        TableName=table_name,
        ExcludeColumnSchema=True,
    )


def write_partitions(partitions):
    """
    In order for the manifests to be used by Athena in a JOIN, we make them
    available as partitions keyed by the (JobId, DataMapperId) tuple.
    """
    max_create_batch_size = 100
    for i in range(0, len(partitions), max_create_batch_size):
        glue_client.batch_create_partition(
            DatabaseName=glue_db,
            TableName=glue_table,
            PartitionInputList=[
                {
                    "Values": partition_tuple,
                    "StorageDescriptor": {
                        "Columns": [
                            {"Name": "columns", "Type": "array<string>"},
                            {"Name": "matchid", "Type": "array<string>"},
                            {"Name": "deletionqueueitemid", "Type": "string"},
                            {"Name": "createdat", "Type": "int"},
                            {"Name": "queryablecolumns", "Type": "string"},
                            {"Name": "queryablematchid", "Type": "string"},
                        ],
                        "Location": "s3://{}/manifests/{}/{}/".format(
                            manifests_bucket_name,
                            partition_tuple[0],
                            partition_tuple[1],
                        ),
                        "InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
                        "OutputFormat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
                        "Compressed": False,
                        "SerdeInfo": {
                            "SerializationLibrary": "org.openx.data.jsonserde.JsonSerDe",
                        },
                        "StoredAsSubDirectories": False,
                    },
                }
                for partition_tuple in partitions[i : i + max_create_batch_size]
            ],
        )
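

# Note: each partition_tuple is the [job_id, data_mapper_id] pair collected in
# handler(), so a manifest registered for job "J" and data mapper "M" becomes
# queryable in Athena under s3://<manifests-bucket>/manifests/J/M/ ("J" and "M"
# are placeholders).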


def get_inner_children(str, prefix, suffix):
    """
    Function to get the inner children from a complex type string
    "struct<name:string,age:int>" => "name:string,age:int"
    """
    if not str.endswith(suffix):
        raise ValueError(SCHEMA_INVALID)
    return str[len(prefix) : -len(suffix)]


def get_nested_children(str, nested_type):
    """
    Function to get the next nested child type from a children string
    starting with a complex type such as struct or array
    "struct<name:string,age:int,s:struct<n:int>>,b:string" =>
    "struct<name:string,age:int,s:struct<n:int>>"
    """
    is_struct = nested_type == STRUCT
    prefix = STRUCT_PREFIX if is_struct else ARRAYSTRUCT_PREFIX
    suffix = STRUCT_SUFFIX if is_struct else ARRAYSTRUCT_SUFFIX
    n_opened_tags = len(suffix)
    end_index = -1
    to_parse = str[len(prefix) :]
    for i in range(len(to_parse)):
        char = to_parse[i : (i + 1)]
        if char == "<":
            n_opened_tags += 1
        if char == ">":
            n_opened_tags -= 1
        if n_opened_tags == 0:
            end_index = i
            break
    if end_index < 0:
        raise ValueError(SCHEMA_INVALID)
    return str[0 : (end_index + len(prefix) + 1)]


def get_nested_type(str):
    """
    Function to get the next nested child type from a children string
    starting with a non-complex type
    "string,a:int" => "string"
    """
    upper_index = str.find(",")
    return str[0:upper_index] if upper_index >= 0 else str


def set_no_identifier_to_node_and_its_children(node):
    """
    Function to set canBeIdentifier=false to item and its children
    Example:
    {
        name: "arr",
        type: "array<struct>",
        canBeIdentifier: false,
        children: [
            { name: "field", type: "int", canBeIdentifier: true },
            { name: "n", type: "string", canBeIdentifier: true }
        ]
    } => {
        name: "arr",
        type: "array<struct>",
        canBeIdentifier: false,
        children: [
            { name: "field", type: "int", canBeIdentifier: false },
            { name: "n", type: "string", canBeIdentifier: false }
        ]
    }
    """
    node["CanBeIdentifier"] = False
    for child in node.get("Children", []):
        set_no_identifier_to_node_and_its_children(child)


def column_mapper(col):
    """
    Function to map Columns from AWS Glue schema to tree
    Example 1:
    { Name: "Name", Type: "int" } =>
    { name: "Name", type: "int", canBeIdentifier: true }
    Example 2:
    { Name: "complex", Type: "struct<a:string,b:struct<c:int>>" } =>
    { name: "complex", type: "struct", children: [
        { name: "a", type: "string", canBeIdentifier: false },
        { name: "b", type: "struct", children: [
            { name: "c", type: "int", canBeIdentifier: false }
        ], canBeIdentifier: false }
    ], canBeIdentifier: false }
    """
    prefix = suffix = None
    result_type = col["Type"]
    has_children = False
    if result_type.startswith(ARRAYSTRUCT_PREFIX):
        result_type = ARRAYSTRUCT
        prefix = ARRAYSTRUCT_PREFIX
        suffix = ARRAYSTRUCT_SUFFIX
        has_children = True
    elif result_type.startswith(STRUCT_PREFIX):
        result_type = STRUCT
        prefix = STRUCT_PREFIX
        suffix = STRUCT_SUFFIX
        has_children = True
    type_is_decimal_with_precision = result_type.startswith("decimal(")
    result = {
        "Name": col["Name"],
        "Type": result_type,
        "CanBeIdentifier": col["CanBeIdentifier"]
        if "CanBeIdentifier" in col
        else result_type in ALLOWED_TYPES or type_is_decimal_with_precision,
    }
    if has_children:
        result["Children"] = []
        children_to_parse = get_inner_children(col["Type"], prefix, suffix)
        while len(children_to_parse) > 0:
            sep = ":"
            name = children_to_parse[0 : children_to_parse.index(sep)]
            rest = children_to_parse[len(name) + len(sep) :]
            nested_type = "other"
            if rest.startswith(STRUCT_PREFIX):
                nested_type = STRUCT
            elif rest.startswith(ARRAYSTRUCT_PREFIX):
                nested_type = ARRAYSTRUCT
            c_type = (
                get_nested_type(rest)
                if nested_type == "other"
                else get_nested_children(rest, nested_type)
            )
            result["Children"].append(
                column_mapper(
                    {
                        "Name": name,
                        "Type": c_type,
                        "CanBeIdentifier": c_type in ALLOWED_TYPES,
                    }
                )
            )
            children_to_parse = children_to_parse[len(name) + len(sep) + len(c_type) :]
            if children_to_parse.startswith(","):
                children_to_parse = children_to_parse[1:]
        if result_type != STRUCT:
            set_no_identifier_to_node_and_its_children(result)
    return result


def get_column_info(col, columns_tree):
    current = columns_tree
    col_array = col.split(".")
    found = None
    for col_segment in col_array:
        found = next((x for x in current if x["Name"] == col_segment), None)
        if not found:
            return None, False
        current = found["Children"] if "Children" in found else []
    return found["Type"], found["CanBeIdentifier"]
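

# Illustrative sketch: nested columns are addressed with dot notation against the
# tree built by column_mapper. The names below are made-up placeholders.
#
# tree = get_columns_tree({"StorageDescriptor": {"Columns": [
#     {"Name": "user", "Type": "struct<id:int,email:string>"}]}})
# get_column_info("user.id", tree)   # => ("int", True)
# get_column_info("missing", tree)   # => (None, False)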


def cast_to_type(val, col, table_name, columns_tree):
    col_type, can_be_identifier = get_column_info(col, columns_tree)
    if not col_type:
        raise ValueError("Column {} not found at table {}".format(col, table_name))
    elif not can_be_identifier:
        raise ValueError(
            "Column {} at table {} is not a supported column type for querying".format(
                col, table_name
            )
        )
    if col_type in ("bigint", "int", "smallint", "tinyint"):
        return int(val)
    if col_type in ("double", "float"):
        return float(val)
    return str(val)
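

# Illustrative sketch: match values typically arrive as strings or Decimals after
# DynamoDB deserialization and are cast to the column's Glue type before being
# written to the manifest. Column and table names are made-up placeholders.
#
# cast_to_type("12345", "customer_id", "customers", tree)  # => 12345 (int column)
# cast_to_type("John", "first_name", "customers", tree)    # => "John" (string column)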