data_lake.py

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
  19. """
  20. This module contains integration with Azure Data Lake.
  21. AzureDataLakeHook communicates via a REST API compatible with WebHDFS. Make sure that a
  22. Airflow connection of type `azure_data_lake` exists. Authorization can be done by supplying a
  23. login (=Client ID), password (=Client Secret) and extra fields tenant (Tenant) and account_name (Account Name)
  24. (see connection `azure_data_lake_default` for an example).
  25. """
from typing import Any, Dict, Optional

from azure.datalake.store import core, lib, multithread

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook

class AzureDataLakeHook(BaseHook):
    """
    Interacts with Azure Data Lake.

    Client ID and client secret should be in user and password parameters.
    Tenant and account name should be extra field as
    {"tenant": "<TENANT>", "account_name": "ACCOUNT_NAME"}.

    :param azure_data_lake_conn_id: Reference to the
        :ref:`Azure Data Lake connection<howto/connection:adl>`.
    """
    conn_name_attr = 'azure_data_lake_conn_id'
    default_conn_name = 'azure_data_lake_default'
    conn_type = 'azure_data_lake'
    hook_name = 'Azure Data Lake'
    @staticmethod
    def get_connection_form_widgets() -> Dict[str, Any]:
        """Returns connection widgets to add to connection form"""
        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
        from flask_babel import lazy_gettext
        from wtforms import StringField

        return {
            "extra__azure_data_lake__tenant": StringField(
                lazy_gettext('Azure Tenant ID'), widget=BS3TextFieldWidget()
            ),
            "extra__azure_data_lake__account_name": StringField(
                lazy_gettext('Azure DataLake Store Name'), widget=BS3TextFieldWidget()
            ),
        }
    @staticmethod
    def get_ui_field_behaviour() -> Dict[str, Any]:
        """Returns custom field behaviour"""
        return {
            "hidden_fields": ['schema', 'port', 'host', 'extra'],
            "relabeling": {
                'login': 'Azure Client ID',
                'password': 'Azure Client Secret',
            },
            "placeholders": {
                'login': 'client id',
                'password': 'secret',
                'extra__azure_data_lake__tenant': 'tenant id',
                'extra__azure_data_lake__account_name': 'datalake store',
            },
        }
    def __init__(self, azure_data_lake_conn_id: str = default_conn_name) -> None:
        super().__init__()
        self.conn_id = azure_data_lake_conn_id
        self._conn: Optional[core.AzureDLFileSystem] = None
        self.account_name: Optional[str] = None
    def get_conn(self) -> core.AzureDLFileSystem:
        """Return an AzureDLFileSystem object."""
        if not self._conn:
            conn = self.get_connection(self.conn_id)
            service_options = conn.extra_dejson
            self.account_name = service_options.get('account_name') or service_options.get(
                'extra__azure_data_lake__account_name'
            )
            tenant = service_options.get('tenant') or service_options.get('extra__azure_data_lake__tenant')

            adl_creds = lib.auth(tenant_id=tenant, client_secret=conn.password, client_id=conn.login)
            self._conn = core.AzureDLFileSystem(adl_creds, store_name=self.account_name)
            self._conn.connect()
        return self._conn
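    # A minimal usage sketch (illustrative only; it assumes a configured connection named
    # ``azure_data_lake_default`` and a hypothetical ``raw/`` directory in the store):
    #
    #     hook = AzureDataLakeHook()
    #     adl = hook.get_conn()
    #     adl.ls('raw/')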
    def check_for_file(self, file_path: str) -> bool:
        """
        Check if a file exists on Azure Data Lake.

        :param file_path: Path and name of the file.
        :return: True if the file exists, False otherwise.
        :rtype: bool
        """
        try:
            files = self.get_conn().glob(file_path, details=False, invalidate_cache=True)
            return len(files) == 1
        except FileNotFoundError:
            return False
    def upload_file(
        self,
        local_path: str,
        remote_path: str,
        nthreads: int = 64,
        overwrite: bool = True,
        buffersize: int = 4194304,
        blocksize: int = 4194304,
        **kwargs,
    ) -> None:
        """
        Upload a file to Azure Data Lake.

        :param local_path: Local path. Can be a single file, a directory (in which case,
            upload recursively) or a glob pattern. Recursive glob patterns using `**`
            are not supported.
        :param remote_path: Remote path to upload to; if multiple files, this is the
            directory root to write within.
        :param nthreads: Number of threads to use. If None, uses the number of cores.
        :param overwrite: Whether to forcibly overwrite existing files/directories.
            If False and the remote path is a directory, will quit regardless of whether
            any files would be overwritten or not. If True, only matching filenames are
            actually overwritten.
        :param buffersize: int [2**22]
            Number of bytes for internal buffer. This block cannot be bigger than
            a chunk and cannot be smaller than a block.
        :param blocksize: int [2**22]
            Number of bytes for a block. Within each chunk, we write a smaller
            block for each API call. This block cannot be bigger than a chunk.
        """
        multithread.ADLUploader(
            self.get_conn(),
            lpath=local_path,
            rpath=remote_path,
            nthreads=nthreads,
            overwrite=overwrite,
            buffersize=buffersize,
            blocksize=blocksize,
            **kwargs,
        )
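    # Illustrative call (the connection id and paths are placeholders):
    #
    #     hook = AzureDataLakeHook(azure_data_lake_conn_id='azure_data_lake_default')
    #     hook.upload_file(local_path='/tmp/report.csv', remote_path='landing/report.csv')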
    def download_file(
        self,
        local_path: str,
        remote_path: str,
        nthreads: int = 64,
        overwrite: bool = True,
        buffersize: int = 4194304,
        blocksize: int = 4194304,
        **kwargs,
    ) -> None:
        """
        Download a file from Azure Data Lake.

        :param local_path: Local path. If downloading a single file, will write to this
            specific file, unless it is an existing directory, in which case a file is
            created within it. If downloading multiple files, this is the root
            directory to write within. Will create directories as required.
        :param remote_path: Remote path/globstring to use to find remote files.
            Recursive glob patterns using `**` are not supported.
        :param nthreads: Number of threads to use. If None, uses the number of cores.
        :param overwrite: Whether to forcibly overwrite existing files/directories.
            If False and the remote path is a directory, will quit regardless of whether
            any files would be overwritten or not. If True, only matching filenames are
            actually overwritten.
        :param buffersize: int [2**22]
            Number of bytes for internal buffer. This block cannot be bigger than
            a chunk and cannot be smaller than a block.
        :param blocksize: int [2**22]
            Number of bytes for a block. Within each chunk, we write a smaller
            block for each API call. This block cannot be bigger than a chunk.
        """
        multithread.ADLDownloader(
            self.get_conn(),
            lpath=local_path,
            rpath=remote_path,
            nthreads=nthreads,
            overwrite=overwrite,
            buffersize=buffersize,
            blocksize=blocksize,
            **kwargs,
        )
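    # Illustrative call, mirroring the upload example above (paths are placeholders):
    #
    #     hook.download_file(local_path='/tmp/report_copy.csv', remote_path='landing/report.csv')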
    def list(self, path: str) -> list:
        """
        List files in Azure Data Lake Storage.

        :param path: Full path/globstring to use to list files in ADLS.
        """
        if "*" in path:
            return self.get_conn().glob(path)
        else:
            return self.get_conn().walk(path)
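    # Illustrative calls (hypothetical paths): a globstring is passed to ``glob`` and
    # returns only matching files, while a plain directory path is walked recursively:
    #
    #     hook.list('landing/*.csv')
    #     hook.list('landing/')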
    def remove(self, path: str, recursive: bool = False, ignore_not_found: bool = True) -> None:
        """
        Remove files in Azure Data Lake Storage.

        :param path: A directory or file to remove in ADLS.
        :param recursive: Whether to loop into directories in the location and remove the files.
        :param ignore_not_found: If True, a missing file is only logged;
            if False, an AirflowException is raised instead.
        """
        try:
            self.get_conn().remove(path=path, recursive=recursive)
        except FileNotFoundError:
            if ignore_not_found:
                self.log.info("File %s not found", path)
            else:
                raise AirflowException(f"File {path} not found")
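    # Illustrative calls (hypothetical paths): delete a single file only if it exists,
    # then delete a whole directory tree, tolerating its absence because
    # ``ignore_not_found`` defaults to True:
    #
    #     if hook.check_for_file('landing/report.csv'):
    #         hook.remove('landing/report.csv')
    #     hook.remove('landing/old/', recursive=True)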