
solve conflicts

Sjim, 2 years ago
parent commit 4113a9adea
100 changed files with 10,589 additions and 1 deletion
  1. Azure/add_azure_account.py (+78, -0)
  2. Azure/add_azure_account_and_set_role_assignment.py (+202, -0)
  3. Azure/adls.py (+99, -0)
  4. Azure/adx.py (+79, -0)
  5. Azure/azure_clients.py (+129, -0)
  6. Azure/azure_rm.py (+855, -0)
  7. Azure/azure_rm_aks_facts.py (+161, -0)
  8. Azure/azure_service_principal_attribute.py (+94, -0)
  9. Azure/azure_storage.py (+187, -0)
  10. Azure/azure_system_helpers.py (+155, -0)
  11. Azure/classAzureProvider.py (+283, -0)
  12. Azure/client.py (+445, -0)
  13. Azure/container_volume.py (+96, -0)
  14. Azure/data_lake.py (+227, -0)
  15. Azure/reproduce-14067.py (+91, -0)
  16. Azure/submit_azureml_pytest.py (+395, -0)
  17. Azure/test_adx.py (+193, -0)
  18. Azure/test_azure.py (+237, -0)
  19. Azure/test_azure_batch.py (+167, -0)
  20. Azure/test_azure_container_instance.py (+99, -0)
  21. Azure/test_azure_cosmos.py (+236, -0)
  22. Azure/test_azure_data_factory.py (+594, -0)
  23. Azure/test_azure_data_lake.py (+133, -0)
  24. Azure/test_azure_fileshare.py (+257, -0)
  25. Azure/test_azure_fileshare_to_gcs.py (+120, -0)
  26. Azure/test_azure_helper.py (+1427, -0)
  27. Azure/test_azure_key_vault.py (+143, -0)
  28. Azure/test_base_azure.py (+79, -0)
  29. Azure/test_oracle_to_azure_data_lake.py (+124, -0)
  30. Azure/validate_azure_dladmin_identity.py (+469, -0)
  31. File/outbuf.py (+118, -0)
  32. File/utils.py (+891, -1)
  33. Target/Azure/AddUp/Azure-blob-storage_4.py (+21, -0)
  34. Target/Azure/AddUp/Azure-blob-storage_5.py (+15, -0)
  35. Target/Azure/AddUp/blob-upload-1_1.py (+29, -0)
  36. Target/Azure/AddUp/blob-upload-2_3.py (+9, -0)
  37. Target/Azure/AddUp/blob-upload-2_4.py (+5, -0)
  38. Target/Azure/AddUp/blob-upload-2_5.py (+13, -0)
  39. Target/Azure/AddUp/blob-upload-2_6.py (+34, -0)
  40. Target/Azure/AddUp/blob-upload-2_7.py (+22, -0)
  41. Target/Azure/AddUp/blob-upload-2_9.py (+9, -0)
  42. Target/Azure/AddUp/blob-upload_1.py (+58, -0)
  43. Target/Azure/AddUp/circuitbreaker_1.py (+103, -0)
  44. Target/Azure/AddUp/datafactory_4.py (+97, -0)
  45. Target/Azure/AddUp/file_advanced_samples_2.py (+28, -0)
  46. Target/Azure/AddUp/file_advanced_samples_3.py (+20, -0)
  47. Target/Azure/AddUp/file_advanced_samples_4.py (+22, -0)
  48. Target/Azure/AddUp/file_advanced_samples_6.py (+65, -0)
  49. Target/Azure/AddUp/file_basic_samples_2.py (+22, -0)
  50. Target/Azure/AddUp/file_basic_samples_3.py (+105, -0)
  51. Target/Azure/AddUp/file_basic_samples_4.py (+29, -0)
  52. Target/Azure/AddUp/python-quick-start_3.py (+40, -0)
  53. Target/Azure/AddUp/table_advanced_samples_2.py (+24, -0)
  54. Target/Azure/AddUp/table_advanced_samples_4.py (+18, -0)
  55. Target/Azure/AddUp/table_advanced_samples_5.py (+21, -0)
  56. Target/Azure/AddUp/table_advanced_samples_7.py (+50, -0)
  57. Target/Azure/AddUp/table_basic_samples_2.py (+58, -0)
  58. Target/Azure/DLfile_6.py (+32, -0)
  59. Target/Azure/add_azure_account_1.py (+12, -0)
  60. Target/Azure/add_azure_account_and_set_role_assignment_1.py (+16, -0)
  61. Target/Azure/add_azure_account_and_set_role_assignment_2.py (+18, -0)
  62. Target/Azure/add_azure_account_and_set_role_assignment_3.py (+20, -0)
  63. Target/Azure/add_azure_account_and_set_role_assignment_4.py (+26, -0)
  64. Target/Azure/add_azure_account_and_set_role_assignment_5.py (+20, -0)
  65. Target/Azure/add_azure_account_and_set_role_assignment_6.py (+63, -0)
  66. Target/Azure/adls_2.py (+3, -0)
  67. Target/Azure/adls_4.py (+4, -0)
  68. Target/Azure/azure_clients_1.py (+10, -0)
  69. Target/Azure/azure_clients_2.py (+11, -0)
  70. Target/Azure/azure_clients_3.py (+11, -0)
  71. Target/Azure/azure_clients_4.py (+11, -0)
  72. Target/Azure/azure_clients_5.py (+11, -0)
  73. Target/Azure/azure_clients_6.py (+11, -0)
  74. Target/Azure/azure_clients_7.py (+11, -0)
  75. Target/Azure/azure_clients_8.py (+13, -0)
  76. Target/Azure/azure_clients_9.py (+11, -0)
  77. Target/Azure/azure_rm_14.py (+25, -0)
  78. Target/Azure/azure_rm_15.py (+98, -0)
  79. Target/Azure/azure_rm_2.py (+59, -0)
  80. Target/Azure/azure_rm_9.py (+11, -0)
  81. Target/Azure/azure_rm_aks_facts_4.py (+17, -0)
  82. Target/Azure/azure_service_principal_attribute_1.py (+32, -0)
  83. Target/Azure/azure_storage_11.py (+7, -0)
  84. Target/Azure/azure_storage_12.py (+28, -0)
  85. Target/Azure/azure_storage_5.py (+15, -0)
  86. Target/Azure/azure_storage_8.py (+10, -0)
  87. Target/Azure/azure_system_helpers_2.py (+25, -0)
  88. Target/Azure/azure_system_helpers_3.py (+9, -0)
  89. Target/Azure/azure_system_helpers_4.py (+3, -0)
  90. Target/Azure/azure_system_helpers_5.py (+3, -0)
  91. Target/Azure/azure_system_helpers_6.py (+3, -0)
  92. Target/Azure/azure_system_helpers_7.py (+15, -0)
  93. Target/Azure/azure_system_helpers_8.py (+16, -0)
  94. Target/Azure/blob-adapter_2.py (+22, -0)
  95. Target/Azure/blob-adapter_3.py (+4, -0)
  96. Target/Azure/blob-adapter_4.py (+10, -0)
  97. Target/Azure/blob-permission_3.py (+17, -0)
  98. Target/Azure/blob-upload-1_3.py (+46, -0)
  99. Target/Azure/blob-upload-1_4.py (+8, -0)
  100. Target/Azure/blob-upload-2_4.py (+12, -0)

+ 78 - 0
Azure/add_azure_account.py

@@ -0,0 +1,78 @@
+import numpy as np
+import boto3
+import requests
+import sys
+import json
+import time
+
+"""
+To run this, use: python add_azure_account.py <CloudCheckrApiKey> <NameOfCloudCheckrAccount> <AzureDirectoryId> <AzureApplicationId> <AzureApplicationSecretKey> <AzureSubscriptionId>
+
+The input parameters, in order, are: cloudcheckr-admin-api-key unique-account-name-in-cloudcheckr azure-active-directory-id azure-application-id azure-application-secret azure-subscription-id
+
+The CloudCheckr admin api key is a 64-character string.
+The CloudCheckr Account name is the name of the new account in CloudCheckr.
+The azure-active-directory-id is the GUID directory id. This will generally be the same for all subscriptions that have the same parent. (The parent being their associated CSP or EA account that contains the cost data.)
+The azure-application-id is the GUID id of the application that was created previously. It can be re-used for multiple subscriptions, but you have to give the application permissions in each subscription.
+The azure-application-secret is the secret key that was created previously. It is shown only once, when the key is generated. It can last 1 year, 2 years, or forever.
+The azure-subscription-id is the GUID that corresponds to the id of the subscription. This subscription will be different for every account that is added to CloudCheckr.
+
+"""
+
+def create_azure_account(env, admin_api_key, account_name, azure_ad_id, azure_app_id, azure_api_access_key, azure_subscription_id):
+	"""
+	Creates an Azure Account in CloudCheckr. It will populate it with azure subscription credentials that were provided.
+	"""
+
+	api_url = env + "/api/account.json/add_azure_inventory_account"
+
+	add_azure_account_info = json.dumps({"account_name": account_name, "azure_ad_id": azure_ad_id, "azure_app_id": azure_app_id, "azure_api_access_key": azure_api_access_key, "azure_subscription_id": azure_subscription_id})
+
+	r7 = requests.post(api_url, headers = {"Content-Type": "application/json", "access_key": admin_api_key}, data = add_azure_account_info)
+
+	print(r7.json())
+
+def main():
+	try:
+		admin_api_key = str(sys.argv[1])
+	except IndexError:
+		print("Must include an admin api key in the command line")
+		return
+
+	try:
+		account_name = str(sys.argv[2])
+	except IndexError:
+		print("Must include a cloudcheckr account name")
+		return
+
+	try:
+		azure_ad_id = str(sys.argv[3])
+	except IndexError:
+		print("Must include an Azure Directory Id")
+		return
+
+	try:
+		azure_app_id = str(sys.argv[4])
+	except IndexError:
+		print("Must include an Azure Application Id")
+		return
+
+	try:
+		azure_api_access_key = str(sys.argv[5])
+	except IndexError:
+		print("Must include an Azure Api Access Key")
+		return
+
+	try:
+		azure_subscription_id = str(sys.argv[6])
+	except IndexError:
+		print("Must include an Azure Subscription Id")
+		return
+
+	# can change this to eu.cloudcheckr.com or au.cloudcheckr.com for different environments
+	env = "https://api.cloudcheckr.com"
+
+	create_azure_account(env, admin_api_key, account_name, azure_ad_id, azure_app_id, azure_api_access_key, azure_subscription_id)
+
+if __name__ == "__main__":
+	main()
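For reference, a minimal sketch of calling create_azure_account() directly from Python rather than via the command line; every value below is a placeholder, not a real key or GUID.

    # Hypothetical usage sketch for create_azure_account(); all argument values are placeholders.
    create_azure_account(
        "https://api.cloudcheckr.com",  # or eu./au.cloudcheckr.com, as noted in main()
        admin_api_key="<64-character-cloudcheckr-admin-api-key>",
        account_name="my-azure-subscription",
        azure_ad_id="<directory-guid>",
        azure_app_id="<application-guid>",
        azure_api_access_key="<application-secret>",
        azure_subscription_id="<subscription-guid>",
    )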

+ 202 - 0
Azure/add_azure_account_and_set_role_assignment.py

@@ -0,0 +1,202 @@
+import numpy as np
+import boto3
+import requests
+import sys
+import json
+import time
+import uuid
+
+"""
+To run this, use: python add_azure_account_and_set_role_assignment.py <CloudCheckrApiKey> <NameOfCloudCheckrAccount> <AzureDirectoryId> <AzureSubscriptionId> <AzureAdminApplicationId> <AzureAdminApplicationSecret> <AzureCloudCheckrApplicationName> <AzureCloudCheckrApplicationSecret>
+
+The input parameters, in order, are: cloudcheckr-admin-api-key unique-account-name-in-cloudcheckr azure-active-directory-id azure-subscription-id azure-admin-application-id azure-admin-application-secret azure-cloudcheckr-application-name azure-cloudcheckr-application-secret
+
+The CloudCheckr admin api key is a 64-character string.
+The CloudCheckr Account name is the name of the new account in CloudCheckr.
+The azure-active-directory-id is the GUID directory id. This will generally be the same for all subscriptions that have the same parent. (The parent being their associated CSP or EA account that contains the cost data.)
+The azure-subscription-id is the GUID that corresponds to the id of the subscription. This subscription will be different for every account that is added to CloudCheckr.
+The azure-admin-application-id is the GUID id of the application that was created previously and has admin permissions. It needs to be able to set application role assignments for the specified subscription, and to read from the Microsoft Graph API with the Application.Read.All, Application.ReadWrite.All, and Directory.Read.All permissions.
+The azure-admin-application-secret is the secret key that was created previously for the application with admin permissions. It is shown only once, when the key is generated. It can last 1 year, 2 years, or forever.
+The azure-cloudcheckr-application-name is the name of the application that was created specifically for CloudCheckr. It will get the Reader role assigned to it.
+The azure-cloudcheckr-application-secret is the secret key that was created previously for the CloudCheckr application. It is shown only once, when the key is generated. It can last 1 year, 2 years, or forever.
+
+"""
+
+
+def create_azure_account(env, CloudCheckrApiKey, account_name, AzureDirectoryId, AzureCloudCheckrApplicationId,
+                         AzureCloudCheckrApplicationSecret, AzureSubscriptionId):
+    """
+    Creates an Azure Account in CloudCheckr. It will populate it with azure subscription credentials that were provided.
+    """
+
+    api_url = env + "/api/account.json/add_azure_inventory_account"
+
+    add_azure_account_info = json.dumps(
+        {"account_name": account_name, "azure_ad_id": AzureDirectoryId, "azure_app_id": AzureCloudCheckrApplicationId,
+         "azure_api_access_key": AzureCloudCheckrApplicationSecret, "azure_subscription_id": AzureSubscriptionId})
+
+    r7 = requests.post(api_url, headers={"Content-Type": "application/json", "access_key": CloudCheckrApiKey},
+                       data=add_azure_account_info)
+
+    print(r7.json())
+
+
+def get_azure_reader_role_id(AzureApiBearerToken, AzureSubscriptionId):
+    """
+    Gets the id of the reader role for this subscription.
+
+    https://docs.microsoft.com/en-us/rest/api/authorization/roleassignments/list
+    """
+
+    api_url = "https://management.azure.com/subscriptions/" + AzureSubscriptionId + "/providers/Microsoft.Authorization/roleDefinitions?api-version=2015-07-01&$filter=roleName eq 'Reader'"
+    authorization_value = "Bearer " + AzureApiBearerToken
+
+    response = requests.get(api_url, headers={"Authorization": authorization_value})
+
+    if "value" in response.json():
+        value = (response.json()["value"])[0]
+        if "id" in value:
+            return value["id"]
+    print("Failed to get the Azure Reader Role Id")
+    return None
+
+
+def get_azure_cloudcheckr_service_principal_id(AzureGraphApiBearerToken, AzureCloudCheckrApplicationName):
+    """
+    Gets the service principal id of the Azure Application that was specifically created for CloudCheckr.
+    Note: This is not the application id. The service principal id is required for the role assignment.
+    This uses the Microsoft Graph API.
+
+    https://docs.microsoft.com/en-us/graph/api/serviceprincipal-list?view=graph-rest-1.0&tabs=http
+    """
+
+    api_url = "https://graph.microsoft.com/v1.0/servicePrincipals?$filter=displayName eq '" + AzureCloudCheckrApplicationName + "'"
+    authorization_value = "Bearer " + AzureGraphApiBearerToken
+
+    response = requests.get(api_url, headers={"Authorization": authorization_value})
+
+    if "value" in response.json():
+        value = (response.json()["value"])[0]
+        if ("id" in value) and ("appId" in value):
+            return value["id"], value["appId"]
+    print("Failed to get the Azure CloudCheckr Application Service principal Id")
+    return None
+
+
+def set_azure_cloudcheckr_application_service_assignment(AzureApiBearerToken, AzureReaderRoleId,
+                                                         AzureCloudCheckrApplicationServicePrincipalId,
+                                                         AzureSubscriptionId):
+    """
+    Sets the previously created CloudCheckr application to have a reader role assignment.
+
+    https://docs.microsoft.com/en-us/azure/role-based-access-control/role-assignments-rest
+    """
+
+    RoleAssignmentId = str(uuid.uuid1())
+
+    api_url = "https://management.azure.com/subscriptions/" + AzureSubscriptionId + "/providers/Microsoft.Authorization/roleAssignments/" + RoleAssignmentId + "?api-version=2015-07-01"
+    authorization_value = "Bearer " + AzureApiBearerToken
+    role_assignment_data = json.dumps({"properties": {"principalId": AzureCloudCheckrApplicationServicePrincipalId,
+                                                      "roleDefinitionId": AzureReaderRoleId}})
+
+    response = requests.put(api_url, headers={"Authorization": authorization_value, "Content-Type": "application/json"},
+                            data=role_assignment_data)
+    print(response.json())
+
+    if "properties" in response.json():
+        properties = response.json()["properties"]
+        if "roleDefinitionId" in properties:
+            return properties["roleDefinitionId"]
+    print("Failed to set role assignment for the CloudCheckr Application to the specified subscription")
+    return None
+
+
+def get_azure_bearer_token(resource_url, azure_directory_id, azure_admin_application_id,
+                           azure_admin_application_secret):
+    """
+    Uses OAuth 2.0 to get the bearer token based on the client id and client secret.
+    """
+
+    api_url = "https://login.microsoftonline.com/" + azure_directory_id + "/oauth2/token"
+
+    client = {'grant_type': 'client_credentials',
+              'client_id': azure_admin_application_id,
+              'client_secret': azure_admin_application_secret,
+              'resource': resource_url,
+              }
+
+    response = requests.post(api_url, data=client)
+
+    if "access_token" in response.json():
+        return response.json()["access_token"]
+    print("Could not get Bearer token")
+    return None
+
+
+def main():
+    try:
+        CloudCheckrApiKey = str(sys.argv[1])
+    except IndexError:
+        print("Must include an admin api key in the command line")
+        return
+
+    try:
+        NameOfCloudCheckrAccount = str(sys.argv[2])
+    except IndexError:
+        print("Must include a cloudcheckr account name")
+        return
+
+    try:
+        AzureDirectoryId = str(sys.argv[3])
+    except IndexError:
+        print("Must include an Azure Directory Id")
+        return
+
+    try:
+        AzureSubscriptionId = str(sys.argv[4])
+    except IndexError:
+        print("Must include an Azure Subscription Id")
+        return
+
+    try:
+        AzureAdminApplicationId = str(sys.argv[5])
+    except IndexError:
+        print("Must include an Azure Admin Application Id")
+        return
+
+    try:
+        AzureAdminApplicationSecret = str(sys.argv[6])
+    except IndexError:
+        print("Must include an Azure Admin Application Secret")
+        return
+
+    try:
+        AzureCloudCheckrApplicationName = str(sys.argv[7])
+    except IndexError:
+        print("Must include an Azure CloudCheckr Application Name")
+        return
+
+    try:
+        AzureCloudCheckrApplicationSecret = str(sys.argv[8])
+    except IndexError:
+        print("Must include an Azure CloudCheckr Application Secret")
+        return
+
+    env = "https://glacier.cloudcheckr.com"
+
+    AzureApiBearerToken = get_azure_bearer_token("https://management.azure.com/", AzureDirectoryId,
+                                                 AzureAdminApplicationId, AzureAdminApplicationSecret)
+    AzureGraphApiBearerToken = get_azure_bearer_token("https://graph.microsoft.com/", AzureDirectoryId,
+                                                      AzureAdminApplicationId, AzureAdminApplicationSecret)
+    AzureReaderRoleId = get_azure_reader_role_id(AzureApiBearerToken, AzureSubscriptionId)
+    AzureCloudCheckrApplicationServicePrincipalId, AzureCloudCheckrApplicationId = get_azure_cloudcheckr_service_principal_id(
+        AzureGraphApiBearerToken, AzureCloudCheckrApplicationName)
+    set_azure_cloudcheckr_application_service_assignment(AzureApiBearerToken, AzureReaderRoleId,
+                                                         AzureCloudCheckrApplicationServicePrincipalId,
+                                                         AzureSubscriptionId)
+    create_azure_account(env, CloudCheckrApiKey, NameOfCloudCheckrAccount, AzureDirectoryId,
+                         AzureCloudCheckrApplicationId, AzureCloudCheckrApplicationSecret, AzureSubscriptionId)
+
+
+if __name__ == "__main__":
+    main()
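Worth noting: main() does not check whether get_azure_bearer_token() succeeded before using the tokens, even though the helper returns None on failure. A hedged sketch of such a guard (not part of the committed script):

    # Sketch only: abort early if the management-API token could not be obtained,
    # since every later call in main() depends on it.
    AzureApiBearerToken = get_azure_bearer_token("https://management.azure.com/", AzureDirectoryId,
                                                 AzureAdminApplicationId, AzureAdminApplicationSecret)
    if AzureApiBearerToken is None:
        sys.exit("Could not obtain an Azure management API bearer token")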

+ 99 - 0
Azure/adls.py

@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.models import BaseOperator
+from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class ADLSDeleteOperator(BaseOperator):
+    """
+    Delete files in the specified path.
+
+        .. seealso::
+            For more information on how to use this operator, take a look at the guide:
+            :ref:`howto/operator:ADLSDeleteOperator`
+
+    :param path: A directory or file to remove
+    :param recursive: Whether to loop into directories in the location and remove the files
+    :param ignore_not_found: Whether to raise error if file to delete is not found
+    :param azure_data_lake_conn_id: Reference to the :ref:`Azure Data Lake connection<howto/connection:adl>`.
+    """
+
+    template_fields: Sequence[str] = ('path',)
+    ui_color = '#901dd2'
+
+    def __init__(
+        self,
+        *,
+        path: str,
+        recursive: bool = False,
+        ignore_not_found: bool = True,
+        azure_data_lake_conn_id: str = 'azure_data_lake_default',
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.path = path
+        self.recursive = recursive
+        self.ignore_not_found = ignore_not_found
+        self.azure_data_lake_conn_id = azure_data_lake_conn_id
+
+    def execute(self, context: "Context") -> Any:
+        hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id)
+        return hook.remove(path=self.path, recursive=self.recursive, ignore_not_found=self.ignore_not_found)
+
+
+class ADLSListOperator(BaseOperator):
+    """
+    List all files from the specified path
+
+    This operator returns a python list with the names of files which can be used by
+     `xcom` in the downstream tasks.
+
+    :param path: The Azure Data Lake path to find the objects. Supports glob
+        strings (templated)
+    :param azure_data_lake_conn_id: Reference to the :ref:`Azure Data Lake connection<howto/connection:adl>`.
+
+    **Example**:
+        The following Operator would list all the Parquet files from ``folder/output/``
+        folder in the specified ADLS account ::
+
+            adls_files = ADLSListOperator(
+                task_id='adls_files',
+                path='folder/output/*.parquet',
+                azure_data_lake_conn_id='azure_data_lake_default'
+            )
+    """
+
+    template_fields: Sequence[str] = ('path',)
+    ui_color = '#901dd2'
+
+    def __init__(
+        self, *, path: str, azure_data_lake_conn_id: str = 'azure_data_lake_default', **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        self.path = path
+        self.azure_data_lake_conn_id = azure_data_lake_conn_id
+
+    def execute(self, context: "Context") -> list:
+        hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id)
+        self.log.info('Getting list of ADLS files in path: %s', self.path)
+        return hook.list(path=self.path)
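ADLSDeleteOperator has no usage example in its docstring, so here is a minimal DAG sketch; the dag id, paths, and connection id are placeholders, and the import path is the one used by current versions of the Microsoft Azure provider package (it may differ in older releases).

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.microsoft.azure.operators.adls import ADLSDeleteOperator, ADLSListOperator

    with DAG(dag_id="adls_cleanup_example", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
        # List the Parquet files first, then remove the whole output folder.
        list_files = ADLSListOperator(
            task_id="list_files",
            path="folder/output/*.parquet",
            azure_data_lake_conn_id="azure_data_lake_default",
        )
        delete_files = ADLSDeleteOperator(
            task_id="delete_files",
            path="folder/output",
            recursive=True,
            azure_data_lake_conn_id="azure_data_lake_default",
        )
        list_files >> delete_files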

+ 79 - 0
Azure/adx.py

@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+"""This module contains Azure Data Explorer operators"""
+from typing import TYPE_CHECKING, Optional, Sequence, Union
+
+from azure.kusto.data._models import KustoResultTable
+
+from airflow.configuration import conf
+from airflow.models import BaseOperator
+from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class AzureDataExplorerQueryOperator(BaseOperator):
+    """
+    Operator for querying Azure Data Explorer (Kusto).
+
+    :param query: KQL query to run (templated).
+    :param database: Database to run the query on (templated).
+    :param options: Optional query options. See:
+      https://docs.microsoft.com/en-us/azure/kusto/api/netfx/request-properties#list-of-clientrequestproperties
+    :param azure_data_explorer_conn_id: Reference to the
+        :ref:`Azure Data Explorer connection<howto/connection:adx>`.
+    """
+
+    ui_color = '#00a1f2'
+    template_fields: Sequence[str] = ('query', 'database')
+    template_ext: Sequence[str] = ('.kql',)
+
+    def __init__(
+        self,
+        *,
+        query: str,
+        database: str,
+        options: Optional[dict] = None,
+        azure_data_explorer_conn_id: str = 'azure_data_explorer_default',
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.query = query
+        self.database = database
+        self.options = options
+        self.azure_data_explorer_conn_id = azure_data_explorer_conn_id
+
+    def get_hook(self) -> AzureDataExplorerHook:
+        """Returns new instance of AzureDataExplorerHook"""
+        return AzureDataExplorerHook(self.azure_data_explorer_conn_id)
+
+    def execute(self, context: "Context") -> Union[KustoResultTable, str]:
+        """
+        Run KQL Query on Azure Data Explorer (Kusto).
+        Returns `PrimaryResult` of Query v2 HTTP response contents
+        (https://docs.microsoft.com/en-us/azure/kusto/api/rest/response2)
+        """
+        hook = self.get_hook()
+        response = hook.run_query(self.query, self.database, self.options)
+        if conf.getboolean('core', 'enable_xcom_pickling'):
+            return response.primary_results[0]
+        else:
+            return str(response.primary_results[0])
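A minimal usage sketch for the operator above; the query, database, and connection id are placeholders, and the import path is the one used by current versions of the Microsoft Azure provider package.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.microsoft.azure.operators.adx import AzureDataExplorerQueryOperator

    with DAG(dag_id="adx_query_example", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
        # Runs the KQL query and pushes the PrimaryResult table (or its string form) to XCom.
        run_kql = AzureDataExplorerQueryOperator(
            task_id="run_kql",
            query="StormEvents | take 10",
            database="Samples",
            azure_data_explorer_conn_id="azure_data_explorer_default",
        )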

+ 129 - 0
Azure/azure_clients.py

@@ -0,0 +1,129 @@
+import json
+from azure.identity import ClientSecretCredential
+from azure.common.credentials import ServicePrincipalCredentials
+from azure.mgmt.compute import ComputeManagementClient
+from azure.mgmt.network import NetworkManagementClient
+from azure.mgmt.storage import StorageManagementClient
+from azure.mgmt.privatedns import PrivateDnsManagementClient
+from azure.mgmt.dns import DnsManagementClient
+from azure.storage.blob import BlobServiceClient
+from azure.storage.queue import QueueServiceClient
+from azure.mgmt.datalake.store import DataLakeStoreAccountManagementClient
+from azure.mgmt.resource import ResourceManagementClient
+
+############ Resource Management Client ########
+def get_resourcegroup_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    subscription_id = parameters.get('azure_subscription_id')
+
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    resourcegroup_client = ResourceManagementClient(token_credential, subscription_id)
+    return resourcegroup_client
+
+########### Resource Management Client End #####
+def get_compute_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    subscription_id = parameters.get('azure_subscription_id')
+
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    compute_client = ComputeManagementClient(token_credential,
+                                             subscription_id)
+    return compute_client
+
+
+def get_network_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    subscription_id = parameters.get('azure_subscription_id')
+
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    network_client = NetworkManagementClient(token_credential,
+                                             subscription_id)
+    return network_client
+
+
+def get_dns_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    subscription_id = parameters.get('azure_subscription_id')
+
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    dns_client = PrivateDnsManagementClient(token_credential,
+                                            subscription_id)
+    return dns_client
+
+
+def get_dns_ops_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    subscription_id = parameters.get('azure_subscription_id')
+
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    dns_ops_client = DnsManagementClient(token_credential,
+                                            subscription_id)
+    return dns_ops_client
+
+
+
+def get_blob_service_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    account_name = parameters.get('storage_account_name')
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    blob_service_client = BlobServiceClient(
+        account_url="https://%s.blob.core.windows.net" % account_name,
+        credential=token_credential)
+    return blob_service_client
+
+
+def get_queue_service_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    account_name = parameters.get('storage_account_name')
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    queue_service_client = QueueServiceClient(
+        account_url="https://%s.queue.core.windows.net" % account_name,
+        credential=token_credential)
+    return queue_service_client
+
+def get_datalake_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    subscription_id = parameters.get('azure_subscription_id')
+    credentials = ServicePrincipalCredentials(
+        client_id=client_id,
+        secret=secret,
+        tenant=tenant_id)
+
+    datalake_client = DataLakeStoreAccountManagementClient(credentials,
+                                                           subscription_id)
+    return datalake_client
+
+def get_storage_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    subscription_id = parameters.get('azure_subscription_id')
+
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    storage_client = StorageManagementClient(token_credential,
+                                             subscription_id)
+    return storage_client
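All of the factory functions above read the same keys from a plain dict, so a caller might look like the sketch below; the dict values are placeholders, and `parameters` stands in for whatever configuration mapping the surrounding project passes around.

    # Hypothetical parameters dict; the keys match what the factory functions read.
    parameters = {
        "azure_tenant_id": "<tenant-guid>",
        "azure_client_id": "<application-guid>",
        "azure_client_secret": "<client-secret>",
        "azure_subscription_id": "<subscription-guid>",
        "storage_account_name": "<storage-account>",
    }

    compute_client = get_compute_client(parameters)
    for vm in compute_client.virtual_machines.list_all():
        print(vm.name)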

+ 855 - 0
Azure/azure_rm.py

@@ -0,0 +1,855 @@
+#!/usr/bin/env python
+#
+# Copyright (c) 2016 Matt Davis, <mdavis@ansible.com>
+#                    Chris Houseknecht, <house@redhat.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.
+#
+
+'''
+Azure External Inventory Script
+===============================
+Generates dynamic inventory by making API requests to the Azure Resource
+Manager using the Azure Python SDK. For instructions on installing the
+Azure Python SDK, see http://azure-sdk-for-python.readthedocs.org/
+
+Authentication
+--------------
+The order of precedence is command line arguments, environment variables,
+and finally the [default] profile found in ~/.azure/credentials.
+
+If using a credentials file, it should be an ini formatted file with one or
+more sections, which we refer to as profiles. The script looks for a
+[default] section, if a profile is not specified either on the command line
+or with an environment variable. The keys in a profile will match the
+list of command line arguments below.
+
+For command line arguments and environment variables specify a profile found
+in your ~/.azure/credentials file, or a service principal or Active Directory
+user.
+
+Command line arguments:
+ - profile
+ - client_id
+ - secret
+ - subscription_id
+ - tenant
+ - ad_user
+ - password
+ - cloud_environment
+
+Environment variables:
+ - AZURE_PROFILE
+ - AZURE_CLIENT_ID
+ - AZURE_SECRET
+ - AZURE_SUBSCRIPTION_ID
+ - AZURE_TENANT
+ - AZURE_AD_USER
+ - AZURE_PASSWORD
+ - AZURE_CLOUD_ENVIRONMENT
+
+Run for Specific Host
+-----------------------
+When run for a specific host using the --host option, a resource group is
+required. For a specific host, this script returns the following variables:
+
+{
+  "ansible_host": "XXX.XXX.XXX.XXX",
+  "computer_name": "computer_name2",
+  "fqdn": null,
+  "id": "/subscriptions/subscription-id/resourceGroups/galaxy-production/providers/Microsoft.Compute/virtualMachines/object-name",
+  "image": {
+    "offer": "CentOS",
+    "publisher": "OpenLogic",
+    "sku": "7.1",
+    "version": "latest"
+  },
+  "location": "westus",
+  "mac_address": "00-00-5E-00-53-FE",
+  "name": "object-name",
+  "network_interface": "interface-name",
+  "network_interface_id": "/subscriptions/subscription-id/resourceGroups/galaxy-production/providers/Microsoft.Network/networkInterfaces/object-name1",
+  "network_security_group": null,
+  "network_security_group_id": null,
+  "os_disk": {
+    "name": "object-name",
+    "operating_system_type": "Linux"
+  },
+  "plan": null,
+  "powerstate": "running",
+  "private_ip": "172.26.3.6",
+  "private_ip_alloc_method": "Static",
+  "provisioning_state": "Succeeded",
+  "public_ip": "XXX.XXX.XXX.XXX",
+  "public_ip_alloc_method": "Static",
+  "public_ip_id": "/subscriptions/subscription-id/resourceGroups/galaxy-production/providers/Microsoft.Network/publicIPAddresses/object-name",
+  "public_ip_name": "object-name",
+  "resource_group": "galaxy-production",
+  "security_group": "object-name",
+  "security_group_id": "/subscriptions/subscription-id/resourceGroups/galaxy-production/providers/Microsoft.Network/networkSecurityGroups/object-name",
+  "tags": {
+      "db": "database"
+  },
+  "type": "Microsoft.Compute/virtualMachines",
+  "virtual_machine_size": "Standard_DS4"
+}
+
+Groups
+------
+When run in --list mode, instances are grouped by the following categories:
+ - azure
+ - location
+ - resource_group
+ - security_group
+ - tag key
+ - tag key_value
+
+Control groups using azure_rm.ini or set environment variables:
+
+AZURE_GROUP_BY_RESOURCE_GROUP=yes
+AZURE_GROUP_BY_LOCATION=yes
+AZURE_GROUP_BY_SECURITY_GROUP=yes
+AZURE_GROUP_BY_TAG=yes
+
+Select hosts within specific resource groups by assigning a comma separated list to:
+
+AZURE_RESOURCE_GROUPS=resource_group_a,resource_group_b
+
+Select hosts for specific tag key by assigning a comma separated list of tag keys to:
+
+AZURE_TAGS=key1,key2,key3
+
+Select hosts for specific locations:
+
+AZURE_LOCATIONS=eastus,westus,eastus2
+
+Or, select hosts for specific tag key:value pairs by assigning a comma separated list key:value pairs to:
+
+AZURE_TAGS=key1:value1,key2:value2
+
+If you don't need the powerstate, you can improve performance by turning off powerstate fetching:
+AZURE_INCLUDE_POWERSTATE=no
+
+azure_rm.ini
+------------
+As mentioned above, you can control execution using environment variables or a .ini file. A sample
+azure_rm.ini is included. The name of the .ini file is the basename of the inventory script (in this case
+'azure_rm') with a .ini extension. It also assumes the .ini file is alongside the script. To specify
+a different path for the .ini file, define the AZURE_INI_PATH environment variable:
+
+  export AZURE_INI_PATH=/path/to/custom.ini
+
+Powerstate:
+-----------
+The powerstate attribute indicates whether or not a host is running. If the value is 'running', the machine is
+up. If the value is anything other than 'running', the machine is down, and will be unreachable.
+
+Examples:
+---------
+  Execute /bin/uname on all instances in the galaxy-qa resource group
+  $ ansible -i azure_rm.py galaxy-qa -m shell -a "/bin/uname -a"
+
+  Use the inventory script to print instance specific information
+  $ contrib/inventory/azure_rm.py --host my_instance_host_name --pretty
+
+  Use with a playbook
+  $ ansible-playbook -i contrib/inventory/azure_rm.py my_playbook.yml --limit galaxy-qa
+
+
+Insecure Platform Warning
+-------------------------
+If you receive InsecurePlatformWarning from urllib3, install the
+requests security packages:
+
+    pip install requests[security]
+
+
+author:
+    - Chris Houseknecht (@chouseknecht)
+    - Matt Davis (@nitzmahone)
+
+Company: Ansible by Red Hat
+
+Version: 1.0.0
+'''
+
+import argparse
+import ConfigParser
+import json
+import os
+import re
+import sys
+import inspect
+import traceback
+
+
+from packaging.version import Version
+
+from os.path import expanduser
+import ansible.module_utils.six.moves.urllib.parse as urlparse
+
+HAS_AZURE = True
+HAS_AZURE_EXC = None
+
+try:
+    from msrestazure.azure_exceptions import CloudError
+    from msrestazure import azure_cloud
+    from azure.mgmt.compute import __version__ as azure_compute_version
+    from azure.common import AzureMissingResourceHttpError, AzureHttpError
+    from azure.common.credentials import ServicePrincipalCredentials, UserPassCredentials
+    from azure.mgmt.network import NetworkManagementClient
+    from azure.mgmt.resource.resources import ResourceManagementClient
+    from azure.mgmt.compute import ComputeManagementClient
+except ImportError as exc:
+    HAS_AZURE_EXC = exc
+    HAS_AZURE = False
+
+
+AZURE_CREDENTIAL_ENV_MAPPING = dict(
+    profile='AZURE_PROFILE',
+    subscription_id='AZURE_SUBSCRIPTION_ID',
+    client_id='AZURE_CLIENT_ID',
+    secret='AZURE_SECRET',
+    tenant='AZURE_TENANT',
+    ad_user='AZURE_AD_USER',
+    password='AZURE_PASSWORD',
+    cloud_environment='AZURE_CLOUD_ENVIRONMENT',
+)
+
+AZURE_CONFIG_SETTINGS = dict(
+    resource_groups='AZURE_RESOURCE_GROUPS',
+    tags='AZURE_TAGS',
+    locations='AZURE_LOCATIONS',
+    include_powerstate='AZURE_INCLUDE_POWERSTATE',
+    group_by_resource_group='AZURE_GROUP_BY_RESOURCE_GROUP',
+    group_by_location='AZURE_GROUP_BY_LOCATION',
+    group_by_security_group='AZURE_GROUP_BY_SECURITY_GROUP',
+    group_by_tag='AZURE_GROUP_BY_TAG'
+)
+
+AZURE_MIN_VERSION = "2.0.0"
+
+
+def azure_id_to_dict(id):
+    pieces = re.sub(r'^\/', '', id).split('/')
+    result = {}
+    index = 0
+    while index < len(pieces) - 1:
+        result[pieces[index]] = pieces[index + 1]
+        index += 1
+    return result
+
+
+class AzureRM(object):
+
+    def __init__(self, args):
+        self._args = args
+        self._cloud_environment = None
+        self._compute_client = None
+        self._resource_client = None
+        self._network_client = None
+
+        self.debug = False
+        if args.debug:
+            self.debug = True
+
+        self.credentials = self._get_credentials(args)
+        if not self.credentials:
+            self.fail("Failed to get credentials. Either pass as parameters, set environment variables, "
+                      "or define a profile in ~/.azure/credentials.")
+
+        # if cloud_environment specified, look up/build Cloud object
+        raw_cloud_env = self.credentials.get('cloud_environment')
+        if not raw_cloud_env:
+            self._cloud_environment = azure_cloud.AZURE_PUBLIC_CLOUD  # SDK default
+        else:
+            # try to look up "well-known" values via the name attribute on azure_cloud members
+            all_clouds = [x[1] for x in inspect.getmembers(azure_cloud) if isinstance(x[1], azure_cloud.Cloud)]
+            matched_clouds = [x for x in all_clouds if x.name == raw_cloud_env]
+            if len(matched_clouds) == 1:
+                self._cloud_environment = matched_clouds[0]
+            elif len(matched_clouds) > 1:
+                self.fail("Azure SDK failure: more than one cloud matched for cloud_environment name '{0}'".format(raw_cloud_env))
+            else:
+                if not urlparse.urlparse(raw_cloud_env).scheme:
+                    self.fail("cloud_environment must be an endpoint discovery URL or one of {0}".format([x.name for x in all_clouds]))
+                try:
+                    self._cloud_environment = azure_cloud.get_cloud_from_metadata_endpoint(raw_cloud_env)
+                except Exception as e:
+                    self.fail("cloud_environment {0} could not be resolved: {1}".format(raw_cloud_env, e.message))
+
+        if self.credentials.get('subscription_id', None) is None:
+            self.fail("Credentials did not include a subscription_id value.")
+        self.log("setting subscription_id")
+        self.subscription_id = self.credentials['subscription_id']
+
+        if self.credentials.get('client_id') is not None and \
+           self.credentials.get('secret') is not None and \
+           self.credentials.get('tenant') is not None:
+            self.azure_credentials = ServicePrincipalCredentials(client_id=self.credentials['client_id'],
+                                                                 secret=self.credentials['secret'],
+                                                                 tenant=self.credentials['tenant'],
+                                                                 cloud_environment=self._cloud_environment)
+        elif self.credentials.get('ad_user') is not None and self.credentials.get('password') is not None:
+            tenant = self.credentials.get('tenant')
+            if not tenant:
+                tenant = 'common'
+            self.azure_credentials = UserPassCredentials(self.credentials['ad_user'],
+                                                         self.credentials['password'],
+                                                         tenant=tenant,
+                                                         cloud_environment=self._cloud_environment)
+        else:
+            self.fail("Failed to authenticate with provided credentials. Some attributes were missing. "
+                      "Credentials must include client_id, secret and tenant or ad_user and password.")
+
+    def log(self, msg):
+        if self.debug:
+            print(msg + u'\n')
+
+    def fail(self, msg):
+        raise Exception(msg)
+
+    def _get_profile(self, profile="default"):
+        path = expanduser("~")
+        path += "/.azure/credentials"
+        try:
+            config = ConfigParser.ConfigParser()
+            config.read(path)
+        except Exception as exc:
+            self.fail("Failed to access {0}. Check that the file exists and you have read "
+                      "access. {1}".format(path, str(exc)))
+        credentials = dict()
+        for key in AZURE_CREDENTIAL_ENV_MAPPING:
+            try:
+                credentials[key] = config.get(profile, key, raw=True)
+            except:
+                pass
+
+        if credentials.get('client_id') is not None or credentials.get('ad_user') is not None:
+            return credentials
+
+        return None
+
+    def _get_env_credentials(self):
+        env_credentials = dict()
+        for attribute, env_variable in AZURE_CREDENTIAL_ENV_MAPPING.items():
+            env_credentials[attribute] = os.environ.get(env_variable, None)
+
+        if env_credentials['profile'] is not None:
+            credentials = self._get_profile(env_credentials['profile'])
+            return credentials
+
+        if env_credentials['client_id'] is not None or env_credentials['ad_user'] is not None:
+            return env_credentials
+
+        return None
+
+    def _get_credentials(self, params):
+        # Get authentication credentials.
+        # Precedence: cmd line parameters-> environment variables-> default profile in ~/.azure/credentials.
+
+        self.log('Getting credentials')
+
+        arg_credentials = dict()
+        for attribute, env_variable in AZURE_CREDENTIAL_ENV_MAPPING.items():
+            arg_credentials[attribute] = getattr(params, attribute)
+
+        # try module params
+        if arg_credentials['profile'] is not None:
+            self.log('Retrieving credentials with profile parameter.')
+            credentials = self._get_profile(arg_credentials['profile'])
+            return credentials
+
+        if arg_credentials['client_id'] is not None:
+            self.log('Received credentials from parameters.')
+            return arg_credentials
+
+        if arg_credentials['ad_user'] is not None:
+            self.log('Received credentials from parameters.')
+            return arg_credentials
+
+        # try environment
+        env_credentials = self._get_env_credentials()
+        if env_credentials:
+            self.log('Received credentials from env.')
+            return env_credentials
+
+        # try default profile from ~./azure/credentials
+        default_credentials = self._get_profile()
+        if default_credentials:
+            self.log('Retrieved default profile credentials from ~/.azure/credentials.')
+            return default_credentials
+
+        return None
+
+    def _register(self, key):
+        try:
+            # We have to perform the one-time registration here. Otherwise, we receive an error the first
+            # time we attempt to use the requested client.
+            resource_client = self.rm_client
+            resource_client.providers.register(key)
+        except Exception as exc:
+            self.log("One-time registration of {0} failed - {1}".format(key, str(exc)))
+            self.log("You might need to register {0} using an admin account".format(key))
+            self.log(("To register a provider using the Python CLI: "
+                      "https://docs.microsoft.com/azure/azure-resource-manager/"
+                      "resource-manager-common-deployment-errors#noregisteredproviderfound"))
+
+    @property
+    def network_client(self):
+        self.log('Getting network client')
+        if not self._network_client:
+            self._network_client = NetworkManagementClient(
+                self.azure_credentials,
+                self.subscription_id,
+                base_url=self._cloud_environment.endpoints.resource_manager,
+                api_version='2017-06-01'
+            )
+            self._register('Microsoft.Network')
+        return self._network_client
+
+    @property
+    def rm_client(self):
+        self.log('Getting resource manager client')
+        if not self._resource_client:
+            self._resource_client = ResourceManagementClient(
+                self.azure_credentials,
+                self.subscription_id,
+                base_url=self._cloud_environment.endpoints.resource_manager,
+                api_version='2017-05-10'
+            )
+        return self._resource_client
+
+    @property
+    def compute_client(self):
+        self.log('Getting compute client')
+        if not self._compute_client:
+            self._compute_client = ComputeManagementClient(
+                self.azure_credentials,
+                self.subscription_id,
+                base_url=self._cloud_environment.endpoints.resource_manager,
+                api_version='2017-03-30'
+            )
+            self._register('Microsoft.Compute')
+        return self._compute_client
+
+
+class AzureInventory(object):
+
+    def __init__(self):
+
+        self._args = self._parse_cli_args()
+
+        try:
+            rm = AzureRM(self._args)
+        except Exception as e:
+            sys.exit("{0}".format(str(e)))
+
+        self._compute_client = rm.compute_client
+        self._network_client = rm.network_client
+        self._resource_client = rm.rm_client
+        self._security_groups = None
+
+        self.resource_groups = []
+        self.tags = None
+        self.locations = None
+        self.replace_dash_in_groups = False
+        self.group_by_resource_group = True
+        self.group_by_location = True
+        self.group_by_security_group = True
+        self.group_by_tag = True
+        self.include_powerstate = True
+
+        self._inventory = dict(
+            _meta=dict(
+                hostvars=dict()
+            ),
+            azure=[]
+        )
+
+        self._get_settings()
+
+        if self._args.resource_groups:
+            self.resource_groups = self._args.resource_groups.split(',')
+
+        if self._args.tags:
+            self.tags = self._args.tags.split(',')
+
+        if self._args.locations:
+            self.locations = self._args.locations.split(',')
+
+        if self._args.no_powerstate:
+            self.include_powerstate = False
+
+        self.get_inventory()
+        print(self._json_format_dict(pretty=self._args.pretty))
+        sys.exit(0)
+
+    def _parse_cli_args(self):
+        # Parse command line arguments
+        parser = argparse.ArgumentParser(
+            description='Produce an Ansible Inventory file for an Azure subscription')
+        parser.add_argument('--list', action='store_true', default=True,
+                            help='List instances (default: True)')
+        parser.add_argument('--debug', action='store_true', default=False,
+                            help='Send debug messages to STDOUT')
+        parser.add_argument('--host', action='store',
+                            help='Get all information about an instance')
+        parser.add_argument('--pretty', action='store_true', default=False,
+                            help='Pretty print JSON output (default: False)')
+        parser.add_argument('--profile', action='store',
+                            help='Azure profile contained in ~/.azure/credentials')
+        parser.add_argument('--subscription_id', action='store',
+                            help='Azure Subscription Id')
+        parser.add_argument('--client_id', action='store',
+                            help='Azure Client Id ')
+        parser.add_argument('--secret', action='store',
+                            help='Azure Client Secret')
+        parser.add_argument('--tenant', action='store',
+                            help='Azure Tenant Id')
+        parser.add_argument('--ad_user', action='store',
+                            help='Active Directory User')
+        parser.add_argument('--password', action='store',
+                            help='password')
+        parser.add_argument('--cloud_environment', action='store',
+                            help='Azure Cloud Environment name or metadata discovery URL')
+        parser.add_argument('--resource-groups', action='store',
+                            help='Return inventory for comma separated list of resource group names')
+        parser.add_argument('--tags', action='store',
+                            help='Return inventory for comma separated list of tag key:value pairs')
+        parser.add_argument('--locations', action='store',
+                            help='Return inventory for comma separated list of locations')
+        parser.add_argument('--no-powerstate', action='store_true', default=False,
+                            help='Do not include the power state of each virtual host')
+        return parser.parse_args()
+
+    def get_inventory(self):
+        if len(self.resource_groups) > 0:
+            # get VMs for requested resource groups
+            for resource_group in self.resource_groups:
+                try:
+                    virtual_machines = self._compute_client.virtual_machines.list(resource_group)
+                except Exception as exc:
+                    sys.exit("Error: fetching virtual machines for resource group {0} - {1}".format(resource_group, str(exc)))
+                if self._args.host or self.tags:
+                    selected_machines = self._selected_machines(virtual_machines)
+                    self._load_machines(selected_machines)
+                else:
+                    self._load_machines(virtual_machines)
+        else:
+            # get all VMs within the subscription
+            try:
+                virtual_machines = self._compute_client.virtual_machines.list_all()
+            except Exception as exc:
+                sys.exit("Error: fetching virtual machines - {0}".format(str(exc)))
+
+            if self._args.host or self.tags or self.locations:
+                selected_machines = self._selected_machines(virtual_machines)
+                self._load_machines(selected_machines)
+            else:
+                self._load_machines(virtual_machines)
+
+    def _load_machines(self, machines):
+        for machine in machines:
+            id_dict = azure_id_to_dict(machine.id)
+
+            # TODO - The API is returning an ID value containing resource group name in ALL CAPS. If/when it gets
+            #       fixed, we should remove the .lower(). Opened Issue
+            #       #574: https://github.com/Azure/azure-sdk-for-python/issues/574
+            resource_group = id_dict['resourceGroups'].lower()
+
+            if self.group_by_security_group:
+                self._get_security_groups(resource_group)
+
+            host_vars = dict(
+                ansible_host=None,
+                private_ip=None,
+                private_ip_alloc_method=None,
+                public_ip=None,
+                public_ip_name=None,
+                public_ip_id=None,
+                public_ip_alloc_method=None,
+                fqdn=None,
+                location=machine.location,
+                name=machine.name,
+                type=machine.type,
+                id=machine.id,
+                tags=machine.tags,
+                network_interface_id=None,
+                network_interface=None,
+                resource_group=resource_group,
+                mac_address=None,
+                plan=(machine.plan.name if machine.plan else None),
+                virtual_machine_size=machine.hardware_profile.vm_size,
+                computer_name=(machine.os_profile.computer_name if machine.os_profile else None),
+                provisioning_state=machine.provisioning_state,
+            )
+
+            host_vars['os_disk'] = dict(
+                name=machine.storage_profile.os_disk.name,
+                operating_system_type=machine.storage_profile.os_disk.os_type.value
+            )
+
+            if self.include_powerstate:
+                host_vars['powerstate'] = self._get_powerstate(resource_group, machine.name)
+
+            if machine.storage_profile.image_reference:
+                host_vars['image'] = dict(
+                    offer=machine.storage_profile.image_reference.offer,
+                    publisher=machine.storage_profile.image_reference.publisher,
+                    sku=machine.storage_profile.image_reference.sku,
+                    version=machine.storage_profile.image_reference.version
+                )
+
+            # Add windows details
+            if machine.os_profile is not None and machine.os_profile.windows_configuration is not None:
+                host_vars['windows_auto_updates_enabled'] = \
+                    machine.os_profile.windows_configuration.enable_automatic_updates
+                host_vars['windows_timezone'] = machine.os_profile.windows_configuration.time_zone
+                host_vars['windows_rm'] = None
+                if machine.os_profile.windows_configuration.win_rm is not None:
+                    host_vars['windows_rm'] = dict(listeners=None)
+                    if machine.os_profile.windows_configuration.win_rm.listeners is not None:
+                        host_vars['windows_rm']['listeners'] = []
+                        for listener in machine.os_profile.windows_configuration.win_rm.listeners:
+                            host_vars['windows_rm']['listeners'].append(dict(protocol=listener.protocol,
+                                                                             certificate_url=listener.certificate_url))
+
+            for interface in machine.network_profile.network_interfaces:
+                interface_reference = self._parse_ref_id(interface.id)
+                network_interface = self._network_client.network_interfaces.get(
+                    interface_reference['resourceGroups'],
+                    interface_reference['networkInterfaces'])
+                if network_interface.primary:
+                    if self.group_by_security_group and \
+                       self._security_groups[resource_group].get(network_interface.id, None):
+                        host_vars['security_group'] = \
+                            self._security_groups[resource_group][network_interface.id]['name']
+                        host_vars['security_group_id'] = \
+                            self._security_groups[resource_group][network_interface.id]['id']
+                    host_vars['network_interface'] = network_interface.name
+                    host_vars['network_interface_id'] = network_interface.id
+                    host_vars['mac_address'] = network_interface.mac_address
+                    for ip_config in network_interface.ip_configurations:
+                        host_vars['private_ip'] = ip_config.private_ip_address
+                        host_vars['private_ip_alloc_method'] = ip_config.private_ip_allocation_method
+                        if ip_config.public_ip_address:
+                            public_ip_reference = self._parse_ref_id(ip_config.public_ip_address.id)
+                            public_ip_address = self._network_client.public_ip_addresses.get(
+                                public_ip_reference['resourceGroups'],
+                                public_ip_reference['publicIPAddresses'])
+                            host_vars['ansible_host'] = public_ip_address.ip_address
+                            host_vars['public_ip'] = public_ip_address.ip_address
+                            host_vars['public_ip_name'] = public_ip_address.name
+                            host_vars['public_ip_alloc_method'] = public_ip_address.public_ip_allocation_method
+                            host_vars['public_ip_id'] = public_ip_address.id
+                            if public_ip_address.dns_settings:
+                                host_vars['fqdn'] = public_ip_address.dns_settings.fqdn
+
+            self._add_host(host_vars)
+
+    def _selected_machines(self, virtual_machines):
+        selected_machines = []
+        for machine in virtual_machines:
+            if self._args.host and self._args.host == machine.name:
+                selected_machines.append(machine)
+            if self.tags and self._tags_match(machine.tags, self.tags):
+                selected_machines.append(machine)
+            if self.locations and machine.location in self.locations:
+                selected_machines.append(machine)
+        return selected_machines
+
+    def _get_security_groups(self, resource_group):
+        ''' For a given resource_group build a mapping of network_interface.id to security_group name '''
+        if not self._security_groups:
+            self._security_groups = dict()
+        if not self._security_groups.get(resource_group):
+            self._security_groups[resource_group] = dict()
+            for group in self._network_client.network_security_groups.list(resource_group):
+                if group.network_interfaces:
+                    for interface in group.network_interfaces:
+                        self._security_groups[resource_group][interface.id] = dict(
+                            name=group.name,
+                            id=group.id
+                        )
+
+    def _get_powerstate(self, resource_group, name):
+        try:
+            vm = self._compute_client.virtual_machines.get(resource_group,
+                                                           name,
+                                                           expand='instanceview')
+        except Exception as exc:
+            sys.exit("Error: fetching instanceview for host {0} - {1}".format(name, str(exc)))
+
+        return next((s.code.replace('PowerState/', '')
+                    for s in vm.instance_view.statuses if s.code.startswith('PowerState')), None)
+
+    def _add_host(self, vars):
+
+        host_name = self._to_safe(vars['name'])
+        resource_group = self._to_safe(vars['resource_group'])
+        security_group = None
+        if vars.get('security_group'):
+            security_group = self._to_safe(vars['security_group'])
+
+        if self.group_by_resource_group:
+            if not self._inventory.get(resource_group):
+                self._inventory[resource_group] = []
+            self._inventory[resource_group].append(host_name)
+
+        if self.group_by_location:
+            if not self._inventory.get(vars['location']):
+                self._inventory[vars['location']] = []
+            self._inventory[vars['location']].append(host_name)
+
+        if self.group_by_security_group and security_group:
+            if not self._inventory.get(security_group):
+                self._inventory[security_group] = []
+            self._inventory[security_group].append(host_name)
+
+        self._inventory['_meta']['hostvars'][host_name] = vars
+        self._inventory['azure'].append(host_name)
+
+        if self.group_by_tag and vars.get('tags'):
+            for key, value in vars['tags'].items():
+                safe_key = self._to_safe(key)
+                safe_value = safe_key + '_' + self._to_safe(value)
+                if not self._inventory.get(safe_key):
+                    self._inventory[safe_key] = []
+                if not self._inventory.get(safe_value):
+                    self._inventory[safe_value] = []
+                self._inventory[safe_key].append(host_name)
+                self._inventory[safe_value].append(host_name)
+
+    def _json_format_dict(self, pretty=False):
+        # convert inventory to json
+        if pretty:
+            return json.dumps(self._inventory, sort_keys=True, indent=2)
+        else:
+            return json.dumps(self._inventory)
+
+    def _get_settings(self):
+        # Load settings from the .ini, if it exists. Otherwise,
+        # look for environment values.
+        file_settings = self._load_settings()
+        if file_settings:
+            for key in AZURE_CONFIG_SETTINGS:
+                if key in ('resource_groups', 'tags', 'locations') and file_settings.get(key):
+                    values = file_settings.get(key).split(',')
+                    if len(values) > 0:
+                        setattr(self, key, values)
+                elif file_settings.get(key):
+                    val = self._to_boolean(file_settings[key])
+                    setattr(self, key, val)
+        else:
+            env_settings = self._get_env_settings()
+            for key in AZURE_CONFIG_SETTINGS:
+                if key in ('resource_groups', 'tags', 'locations') and env_settings.get(key):
+                    values = env_settings.get(key).split(',')
+                    if len(values) > 0:
+                        setattr(self, key, values)
+                elif env_settings.get(key, None) is not None:
+                    val = self._to_boolean(env_settings[key])
+                    setattr(self, key, val)
+
+    def _parse_ref_id(self, reference):
+        response = {}
+        keys = reference.strip('/').split('/')
+        for index in range(len(keys)):
+            if index < len(keys) - 1 and index % 2 == 0:
+                response[keys[index]] = keys[index + 1]
+        return response
+
+    def _to_boolean(self, value):
+        if value in ['Yes', 'yes', 1, 'True', 'true', True]:
+            result = True
+        elif value in ['No', 'no', 0, 'False', 'false', False]:
+            result = False
+        else:
+            # anything else (for example a bare number) is treated as enabled
+            result = True
+        return result
+
+    def _get_env_settings(self):
+        env_settings = dict()
+        for attribute, env_variable in AZURE_CONFIG_SETTINGS.items():
+            env_settings[attribute] = os.environ.get(env_variable, None)
+        return env_settings
+
+    def _load_settings(self):
+        basename = os.path.splitext(os.path.basename(__file__))[0]
+        default_path = os.path.join(os.path.dirname(__file__), (basename + '.ini'))
+        path = os.path.expanduser(os.path.expandvars(os.environ.get('AZURE_INI_PATH', default_path)))
+        config = None
+        settings = None
+        try:
+            config = ConfigParser.ConfigParser()
+            config.read(path)
+        except:
+            pass
+
+        if config is not None:
+            settings = dict()
+            for key in AZURE_CONFIG_SETTINGS:
+                try:
+                    settings[key] = config.get('azure', key, raw=True)
+                except:
+                    pass
+
+        return settings
+
+    def _tags_match(self, tag_obj, tag_args):
+        '''
+        Return True if the tags object from a VM contains the requested tag values.
+
+        :param tag_obj:  Dictionary of string:string pairs
+        :param tag_args: List of strings in the form key:value
+        :return: boolean
+        '''
+
+        if not tag_obj:
+            return False
+
+        matches = 0
+        for arg in tag_args:
+            arg_key = arg
+            arg_value = None
+            if re.search(r':', arg):
+                arg_key, arg_value = arg.split(':')
+            if arg_value and tag_obj.get(arg_key, None) == arg_value:
+                matches += 1
+            elif not arg_value and tag_obj.get(arg_key, None) is not None:
+                matches += 1
+        if matches == len(tag_args):
+            return True
+        return False
+
+    def _to_safe(self, word):
+        ''' Converts 'bad' characters in a string to underscores so they can be used as Ansible groups '''
+        regex = "[^A-Za-z0-9\_"
+        if not self.replace_dash_in_groups:
+            regex += "\-"
+        return re.sub(regex + "]", "_", word)
+
+
+def main():
+    if not HAS_AZURE:
+        sys.exit("The Azure python sdk is not installed (try `pip install 'azure>={0}' --upgrade`) - {1}".format(AZURE_MIN_VERSION, HAS_AZURE_EXC))
+
+    AzureInventory()
+
+
+if __name__ == '__main__':
+    main()
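
For reference, a minimal standalone sketch of the reference-ID splitting that `_parse_ref_id` performs above. The pairing logic is copied from the method; the network-interface ID is made up for illustration.

def parse_ref_id(reference):
    # Even path segments become keys, the following odd segment becomes the value.
    response = {}
    keys = reference.strip('/').split('/')
    for index in range(len(keys)):
        if index < len(keys) - 1 and index % 2 == 0:
            response[keys[index]] = keys[index + 1]
    return response

nic_id = ("/subscriptions/00000000-0000-0000-0000-000000000000"
          "/resourceGroups/demo-rg/providers/Microsoft.Network"
          "/networkInterfaces/demo-nic0")

parsed = parse_ref_id(nic_id)
print(parsed['resourceGroups'])       # demo-rg
print(parsed['networkInterfaces'])    # demo-nic0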

+ 161 - 0
Azure/azure_rm_aks_facts.py

@@ -0,0 +1,161 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2018 Yuwei Zhou, <yuwzho@microsoft.com>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+                    'status': ['preview'],
+                    'supported_by': 'community'}
+
+DOCUMENTATION = '''
+---
+module: azure_rm_aks_facts
+
+version_added: "2.6"
+
+short_description: Get Azure Kubernetes Service facts.
+
+description:
+    - Get facts for a specific Azure Kubernetes Service or all Azure Kubernetes Services.
+
+options:
+    name:
+        description:
+            - Limit results to a specific Azure Kubernetes Service.
+    resource_group:
+        description:
+            - The resource group to search for the desired Azure Kubernetes Service
+    tags:
+        description:
+            - Limit results by providing a list of tags. Format tags as 'key' or 'key:value'.
+
+extends_documentation_fragment:
+    - azure
+
+author:
+    - "Yuwei Zhou (@yuwzho)"
+'''
+
+EXAMPLES = '''
+    - name: Get facts for one Azure Kubernetes Service
+      azure_rm_aks_facts:
+        name: Testing
+        resource_group: TestRG
+
+    - name: Get facts for all Azure Kubernetes Services
+      azure_rm_aks_facts:
+
+    - name: Get facts by tags
+      azure_rm_aks_facts:
+        tags:
+          - testing
+'''
+
+RETURN = '''
+azure_aks:
+    description: List of Azure Kubernetes Service dicts.
+    returned: always
+    type: list
+'''
+
+from ansible.module_utils.azure_rm_common import AzureRMModuleBase
+
+try:
+    from msrestazure.azure_exceptions import CloudError
+    from azure.common import AzureHttpError
+except:
+    # handled in azure_rm_common
+    pass
+
+AZURE_OBJECT_CLASS = 'managedClusters'
+
+
+class AzureRMManagedClusterFacts(AzureRMModuleBase):
+    """Utility class to get Azure Kubernetes Service facts"""
+
+    def __init__(self):
+
+        self.module_args = dict(
+            name=dict(type='str'),
+            resource_group=dict(type='str'),
+            tags=dict(type='list')
+        )
+
+        self.results = dict(
+            changed=False,
+            aks=[]
+        )
+
+        self.name = None
+        self.resource_group = None
+        self.tags = None
+
+        super(AzureRMManagedClusterFacts, self).__init__(
+            derived_arg_spec=self.module_args,
+            supports_tags=False,
+            facts_module=True
+        )
+
+    def exec_module(self, **kwargs):
+
+        for key in self.module_args:
+            setattr(self, key, kwargs[key])
+
+        self.results['aks'] = (
+            self.get_item() if self.name
+            else self.list_items()
+        )
+
+        return self.results
+
+    def get_item(self):
+        """Get a single Azure Kubernetes Service"""
+
+        self.log('Get properties for {0}'.format(self.name))
+
+        item = None
+        result = []
+
+        try:
+            item = self.containerservice_client.managed_clusters.get(
+                self.resource_group, self.name)
+        except CloudError:
+            pass
+
+        if item and self.has_tags(item.tags, self.tags):
+            result = [self.serialize_obj(item, AZURE_OBJECT_CLASS)]
+
+        return result
+
+    def list_items(self):
+        """Get all Azure Kubernetes Services"""
+
+        self.log('List all Azure Kubernetes Services')
+
+        try:
+            response = self.containerservice_client.managed_clusters.list(
+                self.resource_group)
+        except AzureHttpError as exc:
+            self.fail('Failed to list all items - {0}'.format(str(exc)))
+
+        results = []
+        for item in response:
+            if self.has_tags(item.tags, self.tags):
+                results.append(self.serialize_obj(item, AZURE_OBJECT_CLASS))
+
+        return results
+
+
+def main():
+    """Main module execution code path"""
+
+    AzureRMManagedClusterFacts()
+
+
+if __name__ == '__main__':
+    main()
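
A self-contained sketch of the flow exec_module() drives above: a single named lookup when name is given, otherwise a listing filtered by tags. The Cluster stand-in and sample data are hypothetical, and only the bare 'key' tag form is handled here for brevity (the real has_tags() helper comes from AzureRMModuleBase and is not shown in this diff).

from dataclasses import dataclass, asdict

@dataclass
class Cluster:            # stand-in for a serialized managed cluster
    name: str
    tags: dict

def aks_facts(clusters, name=None, tags=None):
    if name:
        return [asdict(c) for c in clusters if c.name == name]
    wanted = tags or []
    return [asdict(c) for c in clusters
            if all(t in (c.tags or {}) for t in wanted)]

sample = [Cluster("Testing", {"testing": ""}), Cluster("prod", {})]
print(aks_facts(sample, name="Testing"))
print(aks_facts(sample, tags=["testing"]))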

+ 94 - 0
Azure/azure_service_principal_attribute.py

@@ -0,0 +1,94 @@
+# (c) 2018 Yunge Zhu, <yungez@microsoft.com>
+# (c) 2017 Ansible Project
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+DOCUMENTATION = """
+lookup: azure_service_principal_attribute
+
+requirements:
+  - azure-graphrbac
+
+author:
+  - Yunge Zhu <yungez@microsoft.com>
+
+version_added: "2.7"
+
+short_description: Look up Azure service principal attributes.
+
+description:
+  - Describes object id of your Azure service principal account.
+options:
+  azure_client_id:
+    description: azure service principal client id.
+  azure_secret:
+    description: azure service principal secret
+  azure_tenant:
+    description: azure tenant
+  azure_cloud_environment:
+    description: azure cloud environment
+"""
+
+EXAMPLES = """
+set_fact:
+  object_id: "{{ lookup('azure_service_principal_attribute',
+                         azure_client_id=azure_client_id,
+                         azure_secret=azure_secret,
+                         azure_tenant=azure_secret) }}"
+"""
+
+RETURN = """
+_raw:
+  description:
+    Returns object id of service principal.
+"""
+
+from ansible.errors import AnsibleError
+from ansible.plugins import AnsiblePlugin
+from ansible.plugins.lookup import LookupBase
+from ansible.module_utils._text import to_native
+
+try:
+    from azure.common.credentials import ServicePrincipalCredentials
+    from azure.graphrbac import GraphRbacManagementClient
+    from msrestazure import azure_cloud
+    from msrestazure.azure_exceptions import CloudError
+except ImportError:
+    raise AnsibleError(
+        "The lookup azure_service_principal_attribute requires azure.graphrbac, msrest")
+
+
+class LookupModule(LookupBase):
+    def run(self, terms, variables, **kwargs):
+
+        self.set_options(direct=kwargs)
+
+        credentials = {}
+        credentials['azure_client_id'] = self.get_option('azure_client_id', None)
+        credentials['azure_secret'] = self.get_option('azure_secret', None)
+        credentials['azure_tenant'] = self.get_option('azure_tenant', 'common')
+
+        if credentials['azure_client_id'] is None or credentials['azure_secret'] is None:
+            raise AnsibleError("Must specify azure_client_id and azure_secret")
+
+        _cloud_environment = azure_cloud.AZURE_PUBLIC_CLOUD
+        if self.get_option('azure_cloud_environment', None) is not None:
+            _cloud_environment = azure_cloud.get_cloud_from_metadata_endpoint(self.get_option('azure_cloud_environment'))
+
+        try:
+            azure_credentials = ServicePrincipalCredentials(client_id=credentials['azure_client_id'],
+                                                            secret=credentials['azure_secret'],
+                                                            tenant=credentials['azure_tenant'],
+                                                            resource=_cloud_environment.endpoints.active_directory_graph_resource_id)
+
+            client = GraphRbacManagementClient(azure_credentials, credentials['azure_tenant'],
+                                               base_url=_cloud_environment.endpoints.active_directory_graph_resource_id)
+
+            response = list(client.service_principals.list(filter="appId eq '{0}'".format(credentials['azure_client_id'])))
+            sp = response[0]
+
+            return sp.object_id.split(',')
+        except CloudError as ex:
+            raise AnsibleError("Failed to get service principal object id: %s" % to_native(ex))
+        return False
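
A rough sketch of what the lookup does once authenticated: filter service principals by appId and return the object_id of the first match. The stub client below replaces GraphRbacManagementClient purely for illustration; only the OData filter string is taken verbatim from the plugin.

class StubServicePrincipal:
    def __init__(self, app_id, object_id):
        self.app_id = app_id
        self.object_id = object_id

class StubSPCollection:
    def __init__(self, sps):
        self._sps = sps
    def list(self, filter=None):
        # The real client sends the OData filter to AAD Graph; here we just
        # parse the "appId eq '<guid>'" form used above.
        wanted = filter.split("'")[1]
        return (sp for sp in self._sps if sp.app_id == wanted)

client_id = "11111111-1111-1111-1111-111111111111"
sps = StubSPCollection([StubServicePrincipal(client_id, "sp-object-id")])
response = list(sps.list(filter="appId eq '{0}'".format(client_id)))
print(response[0].object_id)   # sp-object-id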

+ 187 - 0
Azure/azure_storage.py

@@ -0,0 +1,187 @@
+import mimetypes
+import os.path
+import uuid
+
+from azure.common import AzureMissingResourceHttpError
+from azure.storage.blob import BlobServiceClient
+from django.conf import settings
+from django.core.exceptions import ImproperlyConfigured
+from django.core.files.base import ContentFile
+from django.core.files.storage import Storage
+from django.utils.deconstruct import deconstructible
+from django.utils.timezone import localtime
+from datetime import datetime
+
+
+from pac.blobs.utils import AttachedFile
+
+try:
+    import azure  # noqa
+except ImportError:
+    raise ImproperlyConfigured(
+        "Could not load Azure bindings. "
+        "See https://github.com/WindowsAzure/azure-sdk-for-python")
+
+
+def clean_name(name):
+    return os.path.normpath(name).replace("\\", "/")
+
+
+def setting(name, default=None):
+    """
+    Helper function to get a Django setting by name. If the setting doesn't
+    exist, the given default is returned.
+    :param name: Name of setting
+    :type name: str
+    :param default: Value if setting is unfound
+    :returns: Setting's value
+    """
+    return getattr(settings, name, default)
+
+
+@deconstructible
+class AzureStorage(Storage):
+    # account_name = setting("AZURE_ACCOUNT_NAME")
+    # account_key = setting("AZURE_ACCOUNT_KEY")
+
+    azure_container_url = setting("AZURE_CONTAINER_URL")
+    azure_ssl = setting("AZURE_SSL")
+
+    def __init__(self, container=None, *args, **kwargs):
+        super(AzureStorage, self).__init__(*args, **kwargs)
+        self._connection = None
+
+        if container is None:
+            self.azure_container = setting("AZURE_CONTAINER")
+        else:
+            self.azure_container = container
+
+    def get_available_name(self, name, *args, **kwargs):
+
+        return {"original_name": name, "uuid": str(uuid.uuid4())}
+
+    @property
+    def connection(self):
+
+        if self._connection is None:
+            connect_str = setting("AZURE_STORAGE_CONNECTION_STRING")
+
+            # Create the BlobServiceClient used for all subsequent blob and container operations
+            blob_service_client = BlobServiceClient.from_connection_string(connect_str)
+
+            self._connection = blob_service_client
+
+        return self._connection
+
+    @property
+    def azure_protocol(self):
+        """
+        :return: http | https | None
+        :rtype: str | None
+        """
+        if self.azure_ssl:
+            return 'https'
+        return 'http' if self.azure_ssl is not None else None
+
+    def __get_blob_properties(self, name):
+        """
+        :param name: Filename
+        :rtype: azure.storage.blob.models.Blob | None
+        """
+        try:
+            return self.connection.get_blob_properties(
+                self.azure_container,
+                name
+            )
+        except AzureMissingResourceHttpError:
+            return None
+
+    def _open(self, container, name, mode="rb"):
+        """
+        :param str container: Container name
+        :param str name: Filename
+        :param str mode: File mode (ignored; the blob is always read in full)
+        :rtype: ContentFile
+        """
+        print(f'Retrieving blob: container={self.azure_container}, blob={name}')
+        blob_client = self.connection.get_blob_client(container=container, blob=name)
+        contents = blob_client.download_blob().readall()
+        return ContentFile(contents)
+
+    def exists(self, name):
+        """
+        :param name: File name
+        :rtype: bool
+        """
+        return False  # self.__get_blob_properties(name) is not None
+
+    def delete(self, name):
+        """
+        :param name: File name
+        :return: None
+        """
+        try:
+            self.connection.delete_blob(self.azure_container, name)
+        except AzureMissingResourceHttpError:
+            pass
+
+    def size(self, name):
+        """
+        :param name:
+        :rtype: int
+        """
+        blob = self.connection.get_blob_properties(self.azure_container, name)
+        return blob.properties.content_length
+
+    def _save(self, name, content):
+        """
+        :param name:
+        :param File content:
+        :return:
+        """
+        original_name = name.get("original_name")
+        blob_file_name = datetime.now().strftime("%Y%m%d-%H:%M:%S.%f_") + original_name
+        # blob_name = "{}.{}".format(name.get("uuid"), original_name.partition(".")[-1])
+
+        if hasattr(content.file, 'content_type'):
+            content_type = content.file.content_type
+        else:
+            content_type = mimetypes.guess_type(original_name)[0]
+
+        if hasattr(content, 'chunks'):
+            content_data = b''.join(chunk for chunk in content.chunks())
+        else:
+            content_data = content.read()
+
+        print(f'Saving blob: container={self.azure_container}, blob={blob_file_name}')
+        blob_client = self.connection.get_blob_client(container=self.azure_container, blob=blob_file_name)
+        obj = blob_client.upload_blob(content_data)
+        # create_blob_from_bytes(self.azure_container, name, content_data,
+        #
+        #                                        content_settings=ContentSettings(content_type=content_type))
+        af = AttachedFile(original_name, self.azure_container, blob_file_name)
+        return af
+
+    def url(self, name):
+        """
+
+        :param str name: Filename
+        :return: path
+        """
+        return self.connection.make_blob_url(
+            container_name=self.azure_container,
+            blob_name=name,
+            protocol=self.azure_protocol,
+        )
+
+    def modified_time(self, name):
+        """
+        :param name:
+        :rtype: datetime.datetime
+        """
+        blob = self.__get_blob_properties(name)
+
+        return localtime(blob.properties.last_modified).replace(tzinfo=None)
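
For reference, a minimal sketch of the v12 upload/download round trip that _save() and _open() wrap, assuming the azure-storage-blob package and a connection string exposed through an (arbitrarily named) AZURE_STORAGE_CONNECTION_STRING environment variable; the container and blob names are made up.

import os
from azure.storage.blob import BlobServiceClient

service = BlobServiceClient.from_connection_string(
    os.environ["AZURE_STORAGE_CONNECTION_STRING"])

# Upload some bytes, then read them back through the same blob client.
blob = service.get_blob_client(container="pac-files", blob="demo/report.txt")
blob.upload_blob(b"hello from the storage backend", overwrite=True)

print(blob.download_blob().readall())  # b'hello from the storage backend'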

+ 155 - 0
Azure/azure_system_helpers.py

@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+import os
+import random
+import string
+from contextlib import contextmanager
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.models import Connection
+from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook
+from airflow.utils.process_utils import patch_environ
+from tests.test_utils import AIRFLOW_MAIN_FOLDER
+from tests.test_utils.system_tests_class import SystemTest
+
+AZURE_DAG_FOLDER = os.path.join(
+    AIRFLOW_MAIN_FOLDER, "airflow", "providers", "microsoft", "azure", "example_dags"
+)
+WASB_CONNECTION_ID = os.environ.get("WASB_CONNECTION_ID", "wasb_default")
+
+DATA_LAKE_CONNECTION_ID = os.environ.get("AZURE_DATA_LAKE_CONNECTION_ID", 'azure_data_lake_default')
+DATA_LAKE_CONNECTION_TYPE = os.environ.get("AZURE_DATA_LAKE_CONNECTION_TYPE", 'azure_data_lake')
+
+
+@contextmanager
+def provide_wasb_default_connection(key_file_path: str):
+    """
+    Context manager to provide a temporary value for wasb_default connection
+
+    :param key_file_path: Path to the .json file with wasb_default credentials.
+    """
+    if not key_file_path.endswith(".json"):
+        raise AirflowException("Use a JSON key file.")
+    with open(key_file_path) as credentials:
+        creds = json.load(credentials)
+    conn = Connection(
+        conn_id=WASB_CONNECTION_ID,
+        conn_type="wasb",
+        host=creds.get("host", None),
+        login=creds.get("login", None),
+        password=creds.get("password", None),
+        extra=json.dumps(creds.get('extra', None)),
+    )
+    with patch_environ({f"AIRFLOW_CONN_{conn.conn_id.upper()}": conn.get_uri()}):
+        yield
+
+
+@contextmanager
+def provide_azure_data_lake_default_connection(key_file_path: str):
+    """
+    Context manager to provide a temporary value for azure_data_lake_default connection
+    :param key_file_path: Path to the .json file with azure_data_lake_default credentials.
+    """
+    required_fields = {'login', 'password', 'extra'}
+
+    if not key_file_path.endswith(".json"):
+        raise AirflowException("Use a JSON key file.")
+    with open(key_file_path) as credentials:
+        creds = json.load(credentials)
+    missing_keys = required_fields - creds.keys()
+    if missing_keys:
+        message = f"{missing_keys} fields are missing"
+        raise AirflowException(message)
+    conn = Connection(
+        conn_id=DATA_LAKE_CONNECTION_ID,
+        conn_type=DATA_LAKE_CONNECTION_TYPE,
+        host=creds.get("host", None),
+        login=creds.get("login", None),
+        password=creds.get("password", None),
+        extra=json.dumps(creds.get('extra', None)),
+    )
+    with patch_environ({f"AIRFLOW_CONN_{conn.conn_id.upper()}": conn.get_uri()}):
+        yield
+
+
+@contextmanager
+def provide_azure_fileshare(share_name: str, azure_fileshare_conn_id: str, file_name: str, directory: str):
+    AzureSystemTest.prepare_share(
+        share_name=share_name,
+        azure_fileshare_conn_id=azure_fileshare_conn_id,
+        file_name=file_name,
+        directory=directory,
+    )
+    try:
+        yield
+    finally:
+        AzureSystemTest.delete_share(share_name=share_name, azure_fileshare_conn_id=azure_fileshare_conn_id)
+
+
+@pytest.mark.system("azure")
+class AzureSystemTest(SystemTest):
+    @classmethod
+    def create_share(cls, share_name: str, azure_fileshare_conn_id: str):
+        hook = AzureFileShareHook(azure_fileshare_conn_id=azure_fileshare_conn_id)
+        hook.create_share(share_name)
+
+    @classmethod
+    def delete_share(cls, share_name: str, azure_fileshare_conn_id: str):
+        hook = AzureFileShareHook(azure_fileshare_conn_id=azure_fileshare_conn_id)
+        hook.delete_share(share_name)
+
+    @classmethod
+    def create_directory(cls, share_name: str, azure_fileshare_conn_id: str, directory: str):
+        hook = AzureFileShareHook(azure_fileshare_conn_id=azure_fileshare_conn_id)
+        hook.create_directory(share_name=share_name, directory_name=directory)
+
+    @classmethod
+    def upload_file_from_string(
+        cls,
+        string_data: str,
+        share_name: str,
+        azure_fileshare_conn_id: str,
+        file_name: str,
+        directory: str,
+    ):
+        hook = AzureFileShareHook(azure_fileshare_conn_id=azure_fileshare_conn_id)
+        hook.load_string(
+            string_data=string_data,
+            share_name=share_name,
+            directory_name=directory,
+            file_name=file_name,
+        )
+
+    @classmethod
+    def prepare_share(cls, share_name: str, azure_fileshare_conn_id: str, file_name: str, directory: str):
+        """
+        Create share with a file in given directory. If directory is None, file is in root dir.
+        """
+        cls.create_share(share_name=share_name, azure_fileshare_conn_id=azure_fileshare_conn_id)
+        cls.create_directory(
+            share_name=share_name, azure_fileshare_conn_id=azure_fileshare_conn_id, directory=directory
+        )
+        string_data = "".join(random.choice(string.ascii_letters) for _ in range(1024))
+        cls.upload_file_from_string(
+            string_data=string_data,
+            share_name=share_name,
+            azure_fileshare_conn_id=azure_fileshare_conn_id,
+            file_name=file_name,
+            directory=directory,
+        )
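
The helpers above all follow the same pattern: build a Connection, then expose it through an AIRFLOW_CONN_<ID> environment variable for the duration of a test. A minimal stand-in for that env-patching step (Airflow's patch_environ) is sketched below; the variable name and value are illustrative.

import os
from contextlib import contextmanager

@contextmanager
def temp_environ(key, value):
    # Set an environment variable for the lifetime of the with-block,
    # restoring the previous value (or its absence) afterwards.
    previous = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = previous

with temp_environ("AIRFLOW_CONN_WASB_DEFAULT", "wasb://login:password@host"):
    print(os.environ["AIRFLOW_CONN_WASB_DEFAULT"])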

+ 283 - 0
Azure/classAzureProvider.py

@@ -0,0 +1,283 @@
+import base64
+import json
+import random
+# import re
+import os
+# import shutil
+import sys
+# import tempfile
+
+from azure.devops.connection import Connection
+from msrest.authentication import BasicAuthentication
+
+from bdscan import classSCMProvider
+from bdscan import globals
+
+# from bdscan import utils
+
+# import azure
+# import azure.devops
+import requests
+from azure.devops.v6_0.git import GitPushRef, GitRefUpdate, GitPush, GitCommitRef, GitPullRequest, \
+    GitPullRequestCommentThread, Comment, GitPullRequestSearchCriteria
+
+
+class AzureProvider(classSCMProvider.SCMProvider):
+    def __init__(self):
+        super().__init__()
+        self.scm = 'azure'
+
+        self.azure_base_url = ''
+        self.azure_api_token = ''
+        self.azure_pull_request_id = ''
+        self.azure_project = ''
+        self.azure_project_id = ''
+        self.azure_repo_id = ''
+        self.azure_build_source_branch = ''
+
+        self.azure_credentials = None
+        self.azure_connection = None
+
+        self.azure_git_client = None
+
+    def init(self):
+        globals.printdebug(f"DEBUG: Initializing Azure DevOps SCM Provider")
+
+        self.azure_base_url = os.getenv('SYSTEM_COLLECTIONURI')
+        self.azure_api_token = os.getenv('SYSTEM_ACCESSTOKEN')
+        if not self.azure_api_token:
+            self.azure_api_token = os.getenv('AZURE_API_TOKEN')
+        self.azure_pull_request_id = os.getenv('SYSTEM_PULLREQUEST_PULLREQUESTID')
+        self.azure_project = os.getenv('SYSTEM_TEAMPROJECT')
+        self.azure_project_id = os.getenv('SYSTEM_TEAMPROJECTID')
+        self.azure_repo_id = os.getenv('BUILD_REPOSITORY_ID')
+        self.azure_build_source_branch = os.getenv('BUILD_SOURCEBRANCH')
+
+        globals.printdebug(f'DEBUG: Azure DevOps base_url={self.azure_base_url} api_token={self.azure_api_token} '
+                           f'pull_request_id={self.azure_pull_request_id} project={self.azure_project} '
+                           f'project_id={self.azure_project_id} repo_id={self.azure_repo_id}')
+
+        if not self.azure_base_url or not self.azure_project or not self.azure_repo_id or not self.azure_api_token \
+                or not self.azure_project_id:
+            print(f'BD-Scan-Action: ERROR: Azure DevOps requires that SYSTEM_COLLECTIONURI, SYSTEM_TEAMPROJECT, '
+                  'SYSTEM_TEAMPROJECTID, SYSTEM_ACCESSTOKEN or AZURE_API_TOKEN, and BUILD_REPOSITORY_ID be set.')
+            sys.exit(1)
+
+        if globals.args.comment_on_pr and not self.azure_pull_request_id:
+            print(f'BD-Scan-Action: ERROR: Azure DevOps requires that SYSTEM_PULLREQUEST_PULLREQUESTID be set '
+                  'when operating on a pull request')
+            sys.exit(1)
+
+        if globals.args.fix_pr and not self.azure_build_source_branch:
+            print(f'BD-Scan-Action: ERROR: Azure DevOps requires that BUILD_SOURCEBRANCH be set '
+                  'when operating on a pull request')
+            sys.exit(1)
+
+        self.azure_credentials = BasicAuthentication('', self.azure_api_token)
+        self.azure_connection = Connection(base_url=self.azure_base_url, creds=self.azure_credentials)
+
+        # Get a client (the "core" client provides access to projects, teams, etc)
+        self.azure_git_client = self.azure_connection.clients.get_git_client()
+
+        return True
+
+    def azure_create_branch(self, from_ref, branch_name):
+        authorization = str(base64.b64encode(bytes(':' + self.azure_api_token, 'ascii')), 'ascii')
+
+        url = f"{self.azure_base_url}/_apis/git/repositories/{self.azure_repo_id}/refs?api-version=6.0"
+
+        headers = {
+            'Authorization': 'Basic ' + authorization
+        }
+
+        body = [
+            {
+                'name': f"refs/heads/{branch_name}",
+                'oldObjectId': '0000000000000000000000000000000000000000',
+                'newObjectId': from_ref
+            }
+        ]
+
+        if globals.debug > 0:
+            print("DEBUG: perform API Call to ADO: " + url + " : " + json.dumps(body, indent=4, sort_keys=True) + "\n")
+        r = requests.post(url, json=body, headers=headers)
+
+        if r.status_code == 200:
+            if globals.debug > 0:
+                print(f"DEBUG: Success creating branch")
+                print(r.text)
+            return True
+        else:
+            print(f"BD-Scan-Action: ERROR: Failure creating branch: Error {r.status_code}")
+            print(r.text)
+            return False
+
+    def comp_commit_file_and_create_fixpr(self, comp, files_to_patch):
+        if len(files_to_patch) == 0:
+            print('BD-Scan-Action: WARN: Unable to apply fix patch - cannot determine containing package file')
+            return False
+
+        new_branch_seed = '%030x' % random.randrange(16 ** 30)
+        new_branch_name = f"synopsys-enablement-{new_branch_seed}"
+
+        globals.printdebug(f"DEBUG: Get commit for head of {self.azure_build_source_branch}'")
+
+        commits = self.azure_git_client.get_commits(self.azure_repo_id, None)
+        head_commit = commits[0]
+
+        globals.printdebug(f"DEBUG: Head commit={head_commit.commit_id}")
+
+        globals.printdebug(f"DEBUG: Creating new ref 'refs/heads/{new_branch_name}'")
+        self.azure_create_branch(head_commit.commit_id, new_branch_name)
+
+        gitRefUpdate = GitRefUpdate()
+        gitRefUpdate.name = f"refs/heads/{new_branch_name}"
+        gitRefUpdate.old_object_id = head_commit.commit_id
+
+        gitPush = GitPush()
+        gitPush.commits = []
+        gitPush.ref_updates = [gitRefUpdate]
+
+        # for file_to_patch in globals.files_to_patch:
+        for pkgfile in files_to_patch:
+            globals.printdebug(f"DEBUG: Upload file '{pkgfile}'")
+            try:
+                with open(files_to_patch[pkgfile], 'r') as fp:
+                    new_contents = fp.read()
+            except Exception as exc:
+                print(f"BD-Scan-Action: ERROR: Unable to open package file '{files_to_patch[pkgfile]}'"
+                      f" - {str(exc)}")
+                return False
+
+            gitCommitRef = GitCommitRef()
+            gitCommitRef.comment = "Added Synopsys pipeline template"
+            gitCommitRef.changes = [
+                {
+                    'changeType': 'edit',
+                    'item': {
+                        'path': pkgfile
+                    },
+                    'newContent': {
+                        'content': new_contents,
+                        'contentType': 'rawText'
+                    }
+                }
+            ]
+
+            gitPush.commits.append(gitCommitRef)
+
+            # globals.printdebug(f"DEBUG: Update file '{pkgfile}' with commit message '{commit_message}'")
+            # file = repo.update_file(pkgfile, commit_message, new_contents, orig_contents.sha, branch=new_branch_name)
+
+        push = self.azure_git_client.create_push(gitPush, self.azure_repo_id)
+
+        if not push:
+            print(f"BD-Scan-Action: ERROR: Create push failed")
+            sys.exit(1)
+
+        pr_title = f"Black Duck: Upgrade {comp.name} to version {comp.goodupgrade} fix known security vulerabilities"
+        pr_body = f"\n# Synopsys Black Duck Auto Pull Request\n" \
+                  f"Upgrade {comp.name} from version {comp.version} to " \
+                  f"{comp.goodupgrade} in order to fix security vulnerabilities:\n\n"
+
+        gitPullRequest = GitPullRequest()
+        gitPullRequest.source_ref_name = f"refs/heads/{new_branch_name}"
+        gitPullRequest.target_ref_name = self.azure_build_source_branch
+        gitPullRequest.title = pr_title
+        gitPullRequest.description = pr_body
+
+        pull = self.azure_git_client.create_pull_request(gitPullRequest, self.azure_repo_id)
+
+        if not pull:
+            print(f"BD-Scan-Action: ERROR: Create pull request failed")
+            sys.exit(1)
+
+        return True
+
+    def comp_fix_pr(self, comp):
+        ret = True
+        globals.printdebug(f"DEBUG: Fix '{comp.name}' version '{comp.version}' in "
+                           f"file '{comp.projfiles}' using ns '{comp.ns}' to version "
+                           f"'{comp.goodupgrade}'")
+
+        pull_request_title = f"Black Duck: Upgrade {comp.name} to version " \
+                             f"{comp.goodupgrade} to fix known security vulnerabilities"
+
+        search_criteria = None  # GitPullRequestSearchCriteria()
+
+        pulls = self.azure_git_client.get_pull_requests(self.azure_repo_id, search_criteria)
+        for pull in pulls:
+            if pull_request_title in pull.title:
+                globals.printdebug(f"DEBUG: Skipping pull request for {comp.name}' version "
+                                   f"'{comp.goodupgrade} as it is already present")
+                return
+
+        files_to_patch = comp.do_upgrade_dependency()
+
+        if len(files_to_patch) == 0:
+            print('BD-Scan-Action: WARN: Unable to apply fix patch - cannot determine containing package file')
+            return False
+
+        if not self.comp_commit_file_and_create_fixpr(comp, files_to_patch):
+            ret = False
+        return ret
+
+    def pr_comment(self, comment):
+        pr_threads = self.azure_git_client.get_threads(self.azure_repo_id, self.azure_pull_request_id)
+        existing_thread = None
+        existing_comment = None
+        for pr_thread in pr_threads:
+            for pr_thread_comment in pr_thread.comments:
+                if pr_thread_comment.content and globals.comment_on_pr_header in pr_thread_comment.content:
+                    existing_thread = pr_thread
+                    existing_comment = pr_thread_comment
+
+        comments_markdown = f"# {globals.comment_on_pr_header}\n{comment}"
+
+        if len(comments_markdown) > 65535:
+            comments_markdown = comments_markdown[:65535]
+
+        if existing_comment is not None:
+            globals.printdebug(f"DEBUG: Update/edit existing comment for PR #{self.azure_pull_request_id}\n"
+                               f"{comments_markdown}")
+
+            pr_thread_comment = Comment()
+            pr_thread_comment.parent_comment_id = 0
+            pr_thread_comment.content = comments_markdown
+            pr_thread_comment.comment_type = 1
+
+            retval = self.azure_git_client.update_comment(pr_thread_comment, self.azure_repo_id,
+                                                          self.azure_pull_request_id, existing_thread.id,
+                                                          existing_comment.id)
+
+            globals.printdebug(f"DEBUG: Updated thread, retval={retval}")
+        else:
+            globals.printdebug(f"DEBUG: Create new thread for PR #{self.azure_pull_request_id}")
+
+            pr_thread_comment = Comment()
+            pr_thread_comment.parent_comment_id = 0
+            pr_thread_comment.content = comments_markdown
+            pr_thread_comment.comment_type = 1
+
+            pr_thread = GitPullRequestCommentThread()
+            pr_thread.comments = [pr_thread_comment]
+            pr_thread.status = 1
+
+            retval = self.azure_git_client.create_thread(pr_thread, self.azure_repo_id, self.azure_pull_request_id)
+
+            globals.printdebug(f"DEBUG: Created thread, retval={retval}")
+        return True
+
+    def set_commit_status(self, is_ok):
+        globals.printdebug(f"WARNING: Azure DevOps does not support set_commit_status")
+        return
+
+    def check_files_in_pull_request(self):
+        globals.printdebug(f"WARNING: Azure DevOps does not support querying changed files, returning True")
+        found = True
+        return found
+
+    def check_files_in_commit(self):
+        globals.printdebug(f"WARNING: Azure DevOps does not support querying committed files, returning True")
+        found = True
+        return found
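
azure_create_branch() above authenticates its raw REST call with a Basic credential built from an empty username and the personal access token. A tiny sketch of that header construction, using a placeholder token rather than a real PAT:

import base64

def ado_auth_header(personal_access_token):
    # Azure DevOps accepts 'Basic base64(":<PAT>")' for PAT authentication.
    token = base64.b64encode((':' + personal_access_token).encode('ascii'))
    return {'Authorization': 'Basic ' + token.decode('ascii')}

print(ado_auth_header("not-a-real-pat"))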

+ 445 - 0
Azure/client.py

@@ -0,0 +1,445 @@
+import os, sys, json
+from libfuncs import (
+    split_line
+)
+from crypto_client import (
+    create_keys,
+    save_file,
+    load_key_from_file,
+    generate_sign,
+    generate_sign_hex,
+    verify_sign,
+    encrypt,
+    decrypt,
+    make_md5
+)
+
+from azure_key_vault_client import (
+    get_secret as azure_get_secret,
+    set_secret as azure_set_secret,
+    del_secret as azure_del_secret,
+    get_secrets_keys as azure_get_secrets_keys,
+    get_deleted_secret as azure_get_deleted_secret,
+    purge_deleted_secret as azure_purge_deleted_secret,
+    azure_environ_name as eaen
+
+)
+from functools import (
+    wraps
+)
+
+from enum import (
+    Enum,
+    auto
+)
+
+
+class enumbase(Enum):
+    @property
+    def info(self):
+        return f"{self.name}:{self.value}"
+
+
+class autoname(enumbase):
+    def _generate_next_value_(name, start, count, last_values):
+        return name.lower()
+
+
+class azure_key_vault(object):
+    SPLIT_SYMBOL = ";"
+
+    class secret(object):
+        class ATTER_NAMES(autoname):
+            NAME = auto()
+            NAMES = auto()
+            ITEMS = auto()
+            KEY_VAULT_NAME = auto()
+            KEY_NAME = auto()
+            KEY_VALUE = auto()
+
+        def __init__(self, name):
+            [setattr(self, item.value, "") for item in self.ATTER_NAMES]
+            self.name = name
+
+    def __init__(self, names):
+        if names and isinstance(names, str):
+            names = names.split(self.SPLIT_SYMBOL)
+
+        if not names:
+            raise ValueError(f"input args({names}) is invalid.")
+
+        setattr(self, self.secret.ATTER_NAMES.NAMES.value, set(names))
+        for name in self.names:
+            setattr(self, name, self.secret(name))
+
+    def name_to_str(self, name):
+        return name if isinstance(name, str) else name.value
+
+    def add(self, name):
+        name = self.name_to_str(name)
+        setattr(self, name, self.secret(name))
+
+    def get(self, name):
+        name = self.name_to_str(name)
+        return getattr(self, name)
+
+    def set(self, name, key_vault_name, key_name, key_value=""):
+        secret = self.get(name)
+        assert secret, f"not found secret({name})"
+
+        secret.key_vault_name = key_vault_name
+        secret.key_name = key_name
+        secret.key_value = key_value
+
+    def is_exists(self, name):
+        name = self.name_to_str(name)
+
+        return name in self.names
+
+    def __getattr__(self, name):
+        if name == self.secret.ATTER_NAMES.ITEMS.value:
+            return [getattr(self, name) for name in self.names]
+        elif name == self.secret.ATTER_NAMES.NAMES.value:
+            return self.names
+
+
+class safemsgclient(object):
+    key_memory_id = "memory_id"
+    key_head_flag = "memkey_"
+
+    class azure_names(autoname):
+        '''
+           SIGN_KEY: private key
+           VERIFY_KEY: SIGN_KEY's public key
+           ENCRYPT_KEY: public key
+           DECRYPT_KEY: ENCRYPT_KEY's public key
+        '''
+        SIGN_KEY = auto()
+        VERIFY_KEY = auto()
+        ENCRYPT_KEY = auto()
+        DECRYPT_KEY = auto()
+
+    class key_source(autoname):
+        FILE = auto()
+        KEY_VAULT = auto()
+        MEMORY = auto()
+
+    def __init__(self, key_source=key_source.FILE, azure_names=azure_names, use_mempool=True, *args, **kwargs):
+        self.set_key_source(key_source)
+        setattr(self, "use_mempool", use_mempool)
+        self.__mempool_secrets = {}
+        if key_source == key_source.KEY_VAULT:
+            self.__init_azure_env_id()
+            self.__init_azure_key_value_name(azure_names)
+
+    def clear_mempool_secrets(self):
+        self.__mempool_secrets = {}
+
+    def use_mempool_secret(self):
+        self.use_mempool = True
+
+    def unuse_mempool_secret(self):
+        self.use_mempool = False
+
+    def __init_azure_env_id(self):
+        for item in eaen:
+            setattr(self, item.name, item)
+
+    def __init_azure_key_value_name(self, azure_names=azure_names):
+        setattr(self, "azure_key_vault", azure_key_vault([item.value for item in self.azure_names]))
+
+    def create_keys(self, num=2048, **kwargs):
+        return create_keys(num)
+
+    def save(self, key, filename, **kwargs):
+        if filename:
+            return save_file(key, filename)
+        return False
+
+    def load_key(self, filename, **kwargs):
+        secret = None
+        if filename:
+            if self.use_mempool:
+                secret = self.get_memory_key_value(filename)
+            if not secret:
+                secret = load_key_from_file(filename)
+                self.set_memory_key_value(filename, secret)
+            return secret
+        return None
+
+    def pre_azure_key(f):
+        @wraps(f)
+        def use_azure(*args, **kwargs):
+            self = args[0]
+            args = list(args[1:])
+            key_source = getattr(self, "key_source", None)
+
+            key = None
+            if args[0] and len(args[0]) > 0:
+                key = args[0]
+            elif key_source == self.key_source.MEMORY:
+                memory_id = kwargs.get(self.key_memory_id)
+                key = self.get_memory_key_value(memory_id)
+            elif key_source == self.key_source.FILE:
+                filename = kwargs.get("filename")
+                key = self.load_key(filename)
+            elif key_source == self.key_source.KEY_VAULT:
+                azure_name = kwargs.get("azure_name")
+                key_vault = self.azure_key_vault.get(azure_name)
+                key = self.get_azure_secret_value(key_vault.key_vault_name, key_vault.key_name)
+
+            args[0] = key
+
+            return f(self, *args, **kwargs)
+
+        return use_azure
+
+    def load_key_from_file(self, filename):
+        return load_key_from_file(filename)
+
+    @pre_azure_key
+    def verify_sign(self, pubkey, message, sign, secret=None, **kwargs):
+        return verify_sign(pubkey, message, sign, secret)
+
+    @pre_azure_key
+    def generate_sign(self, privkey, unsign_message, secret=None, **kwargs):
+        return generate_sign(privkey, unsign_message, secret)
+
+    @pre_azure_key
+    def generate_sign_hex(self, privkey, unsign_message, secret=None, **kwargs):
+        return generate_sign_hex(privkey, unsign_message, secret)
+
+    @pre_azure_key
+    def encrypt(self, pubkey, message, secret=None, **kwargs):
+        return encrypt(pubkey, message, secret)
+
+    @pre_azure_key
+    def decrypt(self, privkey, encrypt_message, secret=None, sentinel=None, **kwargs):
+        return decrypt(privkey, encrypt_message, secret, sentinel)
+
+    @classmethod
+    def make_md5(cls, message):
+        return make_md5(message)
+
+    ''' set azure env value
+    '''
+
+    def set_azure_client_id(self, id):
+        self.AZURE_CLIENT_ID.env = id
+
+    def get_azure_client_id(self):
+        return self.AZURE_CLIENT_ID.env
+
+    def set_azure_tenant_id(self, id):
+        self.AZURE_TENANT_ID.env = id
+
+    def get_azure_tenant_id(self):
+        return self.AZURE_TENANT_ID.env
+
+    def set_azure_client_secret(self, secret):
+        self.AZURE_CLIENT_SECRET.env = secret
+
+    def get_azure_client_secret(self):
+        return self.AZURE_CLIENT_SECRET.env
+
+    def set_azure_secret_ids(self, client_id, tenant_id, secret):
+        self.set_azure_client_id(client_id)
+        self.set_azure_tenant_id(tenant_id)
+        self.set_azure_client_secret(secret)
+
+    def encode_azure_secret_ids(self, pub_key, client_id, tenant_id, secret, encode_secret=None):
+        kwargs = {
+            "AZURE_CLIENT_ID": client_id, \
+            "AZURE_TENANT_ID": tenant_id, \
+            "AZURE_CLIENT_SECRET": secret
+        }
+        datas = json.dumps(kwargs)
+        ids_datas = self.encrypt(pub_key, datas, encode_secret)
+        return ids_datas.encode().hex()
+
+    def decode_azure_secret_ids(self, datas, pri_key, encode_secret=None):
+        encrypt_datas = bytes.fromhex(datas).decode()
+        decrypt_datas = self.decrypt(pri_key, encrypt_datas, encode_secret)
+        ids = json.loads(decrypt_datas)
+        return (ids.get("AZURE_CLIENT_ID"), \
+                ids.get("AZURE_TENANT_ID"), \
+                ids.get("AZURE_CLIENT_SECRET"))
+
+    def save_azure_secret_ids_to_file(self, ids_filename, pub_key, client_id, tenant_id, secret, encode_secret=None):
+        datas = self.encode_azure_secret_ids(pub_key, client_id, tenant_id, secret, encode_secret)
+
+        with open(ids_filename, 'w') as pf:
+            pf.write(datas)
+            return True
+        return False
+
+    def load_azure_secret_ids_from_file(self, ids_filename, pri_key, encode_secret=None):
+        encrypt_datas = None
+        with open(ids_filename, 'r') as pf:
+            encrypt_datas = pf.read()
+
+        assert encrypt_datas, f"load ids file(ids_filename) failed."
+
+        return self.decode_azure_secret_ids(encrypt_datas, pri_key, encode_secret)
+
+    def set_azure_secret_ids_with_file(self, ids_filename, pri_key, encode_secret=None):
+        key_source = getattr(self, "key_source", None)
+        assert key_source == key_source.KEY_VAULT, f"client key source is not {key_source.KEY_VAULT.name}"
+        client_id, tenant_id, secret = self.load_azure_secret_ids_from_file(ids_filename, pri_key, encode_secret)
+        self.set_azure_secret_ids(client_id, tenant_id, secret)
+
+    def get_azure_envs(self):
+        '''
+            @dev show all environ info of azure
+            @return all environ info for azure
+        '''
+        return {item.name: getattr(self, item.name).env for item in eaen}
+
+    '''
+       azure key vault operate, must connect azure with azure cli or environ id
+       connect to azure:
+            case 1: az login -u USERNAME -p PASSWORD
+            case 2:  use set_azure_secret_ids to set environ id
+       CRUD operate: get_azure_secret, set_azure_secret, del_azure_secret
+    '''
+
+    def get_azure_secret(self, vault_name, key_name, version=None, **kwargs):
+        '''
+        @dev get secret from azure key vault
+        @param vault_name key vault name
+        @param key_name sercrt's key
+        @param version version of the secret to get. if unspecified, gets the latest version
+        @return secret(KeyVaultSecret)
+        '''
+        update_mempool = True
+        secret = None
+        key = self.create_memory_key_with_args(vault_name, key_name, version)
+        if self.use_mempool:
+            secret = self.get_memory_key_value(key)
+            if not secret:
+                secret = azure_get_secret(vault_name, key_name, version, **kwargs)
+            else:
+                update_mempool = False
+        else:
+            secret = azure_get_secret(vault_name, key_name, version, **kwargs)
+
+        if update_mempool:
+            self.set_memory_key_value(key, secret)
+        return secret
+
+    def get_azure_secret_value(self, vault_name, key_name, version=None, **kwargs):
+        '''
+        @dev get secret from azure key vault
+        @param vault_name name of key vault
+        @param key_name the name of secret
+        @param key_value the value of secret
+        @return value of secret(KeyVaultSecret)
+        '''
+        secret = None
+        update_mempool = True
+        key = self.create_memory_key_with_args(vault_name, key_name, version, "value")
+        if self.use_mempool:
+            secret = self.get_memory_key_value(key)
+            if not secret:
+                secret = azure_get_secret(vault_name, key_name, version, **kwargs).value
+            else:
+                update_mempool = False
+        else:
+            secret = azure_get_secret(vault_name, key_name, version, **kwargs).value
+
+        if update_mempool:
+            self.set_memory_key_value(key, secret)
+        return secret
+
+    def set_azure_secret(self, vault_name, key_name, key_value, **kwargs):
+        '''
+        @def set a secret value. If name is in use, create a new version of the secret. If not, create a new secret.
+        @param vault_name name of key vault
+        @param key_name the name of secret
+        @param key_value the value of secret
+        @param kwargs
+            enabled (bool) – Whether the secret is enabled for use.
+            tags (dict[str, str]) – Application specific metadata in the form of key-value pairs.
+            content_type (str) – An arbitrary string indicating the type of the secret, e.g. ‘password’
+            not_before (datetime) – Not before date of the secret in UTC
+            expires_on (datetime) – Expiry date of the secret in UTC
+        @return KeyVaultSecret
+        '''
+
+        ret = azure_set_secret(vault_name, key_name, key_value, **kwargs)
+        self.del_memory_value(self.create_memory_key_with_args(vault_name, key_name))
+        return ret
+
+    def del_azure_secret(self, vault_name, key_name, **kwargs):
+        self.del_memory_value(self.create_memory_key_with_args(vault_name, key_name))
+        return azure_del_secret(vault_name, key_name, **kwargs)
+
+    def get_azure_deleted_secret(self, vault_name, key_name, **kwargs):
+        '''
+        @dev get secret from azure key vault
+        @param vault_name key vault name
+        @param key_name sercrt's key
+        @return secret(DeletedSecret)
+        '''
+        return azure_get_deleted_secret(vault_name, key_name, **kwargs)
+
+    def get_azure_deleted_secret_id(self, vault_name, key_name, **kwargs):
+        '''
+        @dev get secret from azure key vault
+        @param vault_name key vault name
+        @param key_name sercrt's key
+        @return id of secret(DeletedSecret)
+        '''
+        return self.get_azure_deleted_secret(vault_name, key_name, **kwargs).id
+
+    def purge_deleted_secret(self, vault_name, key_name, **kwargs):
+        '''
+        @dev purge deleted secret from azure key vault
+        @param vault_name key vault name
+        @param key_name sercrt's key
+        '''
+        return azure_purge_deleted_secret(vault_name, key_name, **kwargs)
+
+    def set_azure_key_path(self, azure_name, key_vault_name, key_name):
+        if not self.azure_key_vault.is_exists(azure_name):
+            self.azure_key_vault.add(azure_name)
+
+        return self.azure_key_vault.set(azure_name, key_vault_name, key_name)
+
+    def get_azure_secrets_keys(self, vault_name):
+        return azure_get_secrets_keys(vault_name)
+
+    def create_memory_key(self, name):
+        if name.startswith(self.key_head_flag):
+            return name
+        return self.make_md5(f"{self.key_head_flag}_{name}")
+
+    def create_memory_key_with_args(self, *args):
+        name = '_'.join([str(arg) for arg in args])
+        return self.create_memory_key(name)
+
+    def set_key_source(self, key_source=key_source.MEMORY):
+        setattr(self, "key_source", key_source)
+
+    def set_memory_key_value(self, name, value):
+        return self.__mempool_secrets.update({self.create_memory_key(name): value})
+
+    @split_line
+    def get_memory_key_value(self, name):
+        return self.__mempool_secrets.get(self.create_memory_key(name), None)
+
+    @split_line
+    def del_memory_value(self, key_start):
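+        # Invalidate by blanking matching entries instead of deleting keys,
+        # so the dict is not mutated while it is being iterated.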
+        for key in self.__mempool_secrets:
+            if key.startswith(key_start):
+                self.__mempool_secrets[key] = None
+
+    def __getattr__(self, name):
+        # Called only when normal attribute lookup fails, so fall back to a
+        # safe no-op client instead of raising AttributeError.
+        return safemsgclient()
+
+    def __call__(self, *args, **kwargs):
+        pass

+ 96 - 0
Azure/container_volume.py

@@ -0,0 +1,96 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Any, Dict
+
+from azure.mgmt.containerinstance.models import AzureFileVolume, Volume
+
+from airflow.hooks.base import BaseHook
+
+
+class AzureContainerVolumeHook(BaseHook):
+    """
+    A hook which wraps an Azure Volume.
+
+    :param azure_container_volume_conn_id: Reference to the
+        :ref:`Azure Container Volume connection id <howto/connection:azure_container_volume>`
+        of an Azure account from which container volumes should be used.
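+
+    Example (illustrative only; assumes an `azure_container_volume_default`
+    connection is configured, and the share/account names are placeholders)::
+
+        hook = AzureContainerVolumeHook()
+        volume = hook.get_file_volume(
+            mount_name="data",
+            share_name="myshare",
+            storage_account_name="mystorageaccount",
+        )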
+    """
+
+    conn_name_attr = "azure_container_volume_conn_id"
+    default_conn_name = 'azure_container_volume_default'
+    conn_type = 'azure_container_volume'
+    hook_name = 'Azure Container Volume'
+
+    def __init__(self, azure_container_volume_conn_id: str = 'azure_container_volume_default') -> None:
+        super().__init__()
+        self.conn_id = azure_container_volume_conn_id
+
+    @staticmethod
+    def get_connection_form_widgets() -> Dict[str, Any]:
+        """Returns connection widgets to add to connection form"""
+        from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget
+        from flask_babel import lazy_gettext
+        from wtforms import PasswordField
+
+        return {
+            "extra__azure_container_volume__connection_string": PasswordField(
+                lazy_gettext('Blob Storage Connection String (optional)'), widget=BS3PasswordFieldWidget()
+            ),
+        }
+
+    @staticmethod
+    def get_ui_field_behaviour() -> Dict[str, Any]:
+        """Returns custom field behaviour"""
+        return {
+            "hidden_fields": ['schema', 'port', 'host', "extra"],
+            "relabeling": {
+                'login': 'Azure Client ID',
+                'password': 'Azure Secret',
+            },
+            "placeholders": {
+                'login': 'client_id (token credentials auth)',
+                'password': 'secret (token credentials auth)',
+                'extra__azure_container_volume__connection_string': 'connection string auth',
+            },
+        }
+
+    def get_storagekey(self) -> str:
+        """Get Azure File Volume storage key"""
+        conn = self.get_connection(self.conn_id)
+        service_options = conn.extra_dejson
+
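+        # Prefer the AccountKey parsed from a connection string, if one was
+        # supplied in the connection extras; otherwise fall back to the password.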
+        if 'extra__azure_container_volume__connection_string' in service_options:
+            for keyvalue in service_options['extra__azure_container_volume__connection_string'].split(";"):
+                key, value = keyvalue.split("=", 1)
+                if key == "AccountKey":
+                    return value
+        return conn.password
+
+    def get_file_volume(
+        self, mount_name: str, share_name: str, storage_account_name: str, read_only: bool = False
+    ) -> Volume:
+        """Get Azure File Volume"""
+        return Volume(
+            name=mount_name,
+            azure_file=AzureFileVolume(
+                share_name=share_name,
+                storage_account_name=storage_account_name,
+                read_only=read_only,
+                storage_account_key=self.get_storagekey(),
+            ),
+        )

+ 227 - 0
Azure/data_lake.py

@@ -0,0 +1,227 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+"""
+This module contains integration with Azure Data Lake.
+
+AzureDataLakeHook communicates via a REST API compatible with WebHDFS. Make sure that an
+Airflow connection of type `azure_data_lake` exists. Authorization can be done by supplying a
+login (=Client ID), password (=Client Secret) and extra fields tenant (Tenant) and account_name (Account Name)
+(see connection `azure_data_lake_default` for an example).
+"""
+from typing import Any, Dict, Optional
+
+from azure.datalake.store import core, lib, multithread
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+
+
+class AzureDataLakeHook(BaseHook):
+    """
+    Interacts with Azure Data Lake.
+
+    Client ID and client secret should be in user and password parameters.
+    Tenant and account name should be in the extra field as
+    {"tenant": "<TENANT>", "account_name": "ACCOUNT_NAME"}.
+
+    :param azure_data_lake_conn_id: Reference to the :ref:`Azure Data Lake connection<howto/connection:adl>`.
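+
+    Example (illustrative only; assumes a configured ``azure_data_lake_default``
+    connection and that the remote path exists)::
+
+        hook = AzureDataLakeHook()
+        if hook.check_for_file("landing/data.csv"):
+            hook.download_file(local_path="data.csv", remote_path="landing/data.csv")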
+    """
+
+    conn_name_attr = 'azure_data_lake_conn_id'
+    default_conn_name = 'azure_data_lake_default'
+    conn_type = 'azure_data_lake'
+    hook_name = 'Azure Data Lake'
+
+    @staticmethod
+    def get_connection_form_widgets() -> Dict[str, Any]:
+        """Returns connection widgets to add to connection form"""
+        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
+        from flask_babel import lazy_gettext
+        from wtforms import StringField
+
+        return {
+            "extra__azure_data_lake__tenant": StringField(
+                lazy_gettext('Azure Tenant ID'), widget=BS3TextFieldWidget()
+            ),
+            "extra__azure_data_lake__account_name": StringField(
+                lazy_gettext('Azure DataLake Store Name'), widget=BS3TextFieldWidget()
+            ),
+        }
+
+    @staticmethod
+    def get_ui_field_behaviour() -> Dict[str, Any]:
+        """Returns custom field behaviour"""
+        return {
+            "hidden_fields": ['schema', 'port', 'host', 'extra'],
+            "relabeling": {
+                'login': 'Azure Client ID',
+                'password': 'Azure Client Secret',
+            },
+            "placeholders": {
+                'login': 'client id',
+                'password': 'secret',
+                'extra__azure_data_lake__tenant': 'tenant id',
+                'extra__azure_data_lake__account_name': 'datalake store',
+            },
+        }
+
+    def __init__(self, azure_data_lake_conn_id: str = default_conn_name) -> None:
+        super().__init__()
+        self.conn_id = azure_data_lake_conn_id
+        self._conn: Optional[core.AzureDLFileSystem] = None
+        self.account_name: Optional[str] = None
+
+    def get_conn(self) -> core.AzureDLFileSystem:
+        """Return a AzureDLFileSystem object."""
+        if not self._conn:
+            conn = self.get_connection(self.conn_id)
+            service_options = conn.extra_dejson
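+            # Accept both the plain extra keys and the prefixed
+            # `extra__azure_data_lake__*` variants written by the connection form.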
+            self.account_name = service_options.get('account_name') or service_options.get(
+                'extra__azure_data_lake__account_name'
+            )
+            tenant = service_options.get('tenant') or service_options.get('extra__azure_data_lake__tenant')
+
+            adl_creds = lib.auth(tenant_id=tenant, client_secret=conn.password, client_id=conn.login)
+            self._conn = core.AzureDLFileSystem(adl_creds, store_name=self.account_name)
+            self._conn.connect()
+        return self._conn
+
+    def check_for_file(self, file_path: str) -> bool:
+        """
+        Check if a file exists on Azure Data Lake.
+
+        :param file_path: Path and name of the file.
+        :return: True if the file exists, False otherwise.
+        :rtype: bool
+        """
+        try:
+            files = self.get_conn().glob(file_path, details=False, invalidate_cache=True)
+            return len(files) == 1
+        except FileNotFoundError:
+            return False
+
+    def upload_file(
+        self,
+        local_path: str,
+        remote_path: str,
+        nthreads: int = 64,
+        overwrite: bool = True,
+        buffersize: int = 4194304,
+        blocksize: int = 4194304,
+        **kwargs,
+    ) -> None:
+        """
+        Upload a file to Azure Data Lake.
+
+        :param local_path: local path. Can be single file, directory (in which case,
+            upload recursively) or glob pattern. Recursive glob patterns using `**`
+            are not supported.
+        :param remote_path: Remote path to upload to; if multiple files, this is the
+            directory root to write within.
+        :param nthreads: Number of threads to use. If None, uses the number of cores.
+        :param overwrite: Whether to forcibly overwrite existing files/directories.
+            If False and remote path is a directory, will quit regardless if any files
+            would be overwritten or not. If True, only matching filenames are actually
+            overwritten.
+        :param buffersize: int [2**22]
+            Number of bytes for internal buffer. This block cannot be bigger than
+            a chunk and cannot be smaller than a block.
+        :param blocksize: int [2**22]
+            Number of bytes for a block. Within each chunk, we write a smaller
+            block for each API call. This block cannot be bigger than a chunk.
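+
+        Example (illustrative; paths are placeholders)::
+
+            hook = AzureDataLakeHook()
+            hook.upload_file(local_path="data.csv", remote_path="landing/data.csv")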
+        """
+        multithread.ADLUploader(
+            self.get_conn(),
+            lpath=local_path,
+            rpath=remote_path,
+            nthreads=nthreads,
+            overwrite=overwrite,
+            buffersize=buffersize,
+            blocksize=blocksize,
+            **kwargs,
+        )
+
+    def download_file(
+        self,
+        local_path: str,
+        remote_path: str,
+        nthreads: int = 64,
+        overwrite: bool = True,
+        buffersize: int = 4194304,
+        blocksize: int = 4194304,
+        **kwargs,
+    ) -> None:
+        """
+        Download a file from Azure Data Lake.
+
+        :param local_path: local path. If downloading a single file, will write to this
+            specific file, unless it is an existing directory, in which case a file is
+            created within it. If downloading multiple files, this is the root
+            directory to write within. Will create directories as required.
+        :param remote_path: remote path/globstring to use to find remote files.
+            Recursive glob patterns using `**` are not supported.
+        :param nthreads: Number of threads to use. If None, uses the number of cores.
+        :param overwrite: Whether to forcibly overwrite existing files/directories.
+            If False and remote path is a directory, will quit regardless if any files
+            would be overwritten or not. If True, only matching filenames are actually
+            overwritten.
+        :param buffersize: int [2**22]
+            Number of bytes for internal buffer. This block cannot be bigger than
+            a chunk and cannot be smaller than a block.
+        :param blocksize: int [2**22]
+            Number of bytes for a block. Within each chunk, we write a smaller
+            block for each API call. This block cannot be bigger than a chunk.
+        """
+        multithread.ADLDownloader(
+            self.get_conn(),
+            lpath=local_path,
+            rpath=remote_path,
+            nthreads=nthreads,
+            overwrite=overwrite,
+            buffersize=buffersize,
+            blocksize=blocksize,
+            **kwargs,
+        )
+
+    def list(self, path: str) -> list:
+        """
+        List files in Azure Data Lake Storage
+
+        :param path: full path/globstring to use to list files in ADLS
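+
+        Example (illustrative; the glob pattern is a placeholder)::
+
+            hook = AzureDataLakeHook()
+            files = hook.list("landing/*.csv")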
+        """
+        if "*" in path:
+            return self.get_conn().glob(path)
+        else:
+            return self.get_conn().walk(path)
+
+    def remove(self, path: str, recursive: bool = False, ignore_not_found: bool = True) -> None:
+        """
+        Remove files in Azure Data Lake Storage
+
+        :param path: A directory or file to remove in ADLS
+        :param recursive: Whether to loop into directories in the location and remove the files
+        :param ignore_not_found: Whether to raise error if file to delete is not found
+        """
+        try:
+            self.get_conn().remove(path=path, recursive=recursive)
+        except FileNotFoundError:
+            if ignore_not_found:
+                self.log.info("File %s not found", path)
+            else:
+                raise AirflowException(f"File {path} not found")

+ 91 - 0
Azure/reproduce-14067.py

@@ -0,0 +1,91 @@
+# Sample code to reproduce timeout issue described in
+# https://github.com/Azure/azure-sdk-for-python/issues/14067
+
+# Run mitmproxy and enter interception mode:
+# press "i" and then specify ".*" as the intercept filter.
+# Allow the first request to AAD through by pressing "A" (or "a" twice).
+
+mitm_proxy= "http://127.0.0.1:8080"
+container_name= "issue14067"
+blob_to_read = "debug.log"
+
+# Configure a proxy
+# https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-configure-proxy?tabs=cmd
+import io
+from io import BytesIO
+import os
+os.environ["HTTP_PROXY"] = mitm_proxy
+os.environ["HTTPS_PROXY"] = mitm_proxy
+
+# Retrieve the storage account and the storage key
+import json
+settings= {}
+with open('./settings.json') as f:
+    settings = json.load(f)
+account_name = settings["STORAGE_ACCOUNT_NAME"]
+
+# Configure identity that has "Storage Blob Data Reader" access
+os.environ["AZURE_CLIENT_ID"] = settings["AZURE_CLIENT_ID"]
+os.environ["AZURE_CLIENT_SECRET"] = settings["AZURE_CLIENT_SECRET"]
+os.environ["AZURE_TENANT_ID"] = settings["AZURE_TENANT_ID"]
+
+# Create the client
+from azure.storage.blob.aio import (
+    BlobServiceClient,
+    ContainerClient,
+    BlobClient,
+)
+from azure.core.exceptions import (
+    ResourceNotFoundError,
+    ClientAuthenticationError
+)
+from azure.identity.aio import DefaultAzureCredential
+
+async def download_blob_using_blobservice(account_name: str, credential: DefaultAzureCredential, container_name:str , blob_name: str, file_stream: io.BytesIO):
+    try:
+        # Timeout didn't work on this code...
+        blob_service = BlobServiceClient(f"https://{account_name}.blob.core.windows.net", credential=credential, connection_timeout=1, read_timeout=1)
+        blob_client = blob_service.get_blob_client(container_name, blob_name)
+        storage_stream_downloader = await blob_client.download_blob()
+        await storage_stream_downloader.readinto(file_stream)
+        return
+    except ResourceNotFoundError:
+        raise KeyError(blob_name)
+    except ClientAuthenticationError:
+        raise
+
+
+async def download_blob_using_blobclient(account_name: str, credential:DefaultAzureCredential, container_name:str , blob_name: str, file_stream: io.BytesIO):
+    try:
+        blob_client = BlobClient(f"https://{account_name}.blob.core.windows.net", credential=credential, container_name=container_name, blob_name=blob_name, connection_timeout=1, read_timeout=1)
+        storage_stream_downloader = await blob_client.download_blob()
+        await storage_stream_downloader.readinto(file_stream)
+        return
+    except ResourceNotFoundError:
+        raise KeyError(blob_name)
+    except ClientAuthenticationError:
+        raise
+
+# Execute method
+from io import (
+    BytesIO,
+    TextIOWrapper
+)
+import asyncio
+
+def execute_code(loop, timeout=None):
+    with BytesIO() as file_stream:
+        service_principal = DefaultAzureCredential(exclude_cli_credential=True)
+        future = asyncio.run_coroutine_threadsafe(
+            download_blob_using_blobclient(account_name,service_principal, container_name, blob_to_read, file_stream),
+            loop=loop)
+        future.result(timeout)
+        file_stream.flush()
+        file_stream.seek(0)
+        bw=TextIOWrapper(file_stream).read()
+        print(bw)
+        return
+
+loop = asyncio.get_event_loop()
+future = loop.run_in_executor(None, execute_code, loop)
+loop.run_until_complete(future)

+ 395 - 0
Azure/submit_azureml_pytest.py

@@ -0,0 +1,395 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""
+This python script sets up an environment on AzureML and submits a
+script to it to run pytest.  It is usually intended to be used as
+part of a DevOps pipeline which runs testing on a github repo but
+can also be used from command line.
+
+Many parameters are set to default values and some are expected to be passed
+in from either the DevOps pipeline or command line.
+If calling from command line, there are some parameters you must pass in for
+your job to run.
+
+
+Args:
+    Required:
+    --clustername (str): the Azure cluster for this run. It can already exist
+                         or it will be created.
+    --subid       (str): the Azure subscription id
+
+    Optional but suggested, this info will be stored on Azure as
+    text information as part of the experiment:
+    --pr          (str): the Github PR number
+    --reponame    (str): the Github repository name
+    --branch      (str): the branch being run
+                    It is also possible to put any text string in these.
+Example:
+    Usually, this script is run by a DevOps pipeline. It can also be
+    run from cmd line.
+    >>> python tests/ci/refac.py --clustername 'cluster-d3-v2'
+                                 --subid '12345678-9012-3456-abcd-123456789012'
+                                 --pr '666'
+                                 --reponame 'Recommenders'
+                                 --branch 'staging'
+"""
+import argparse
+import logging
+
+from azureml.core.authentication import AzureCliAuthentication
+from azureml.core import Workspace
+from azureml.core import Experiment
+from azureml.core.runconfig import RunConfiguration
+from azureml.core.conda_dependencies import CondaDependencies
+from azureml.core.script_run_config import ScriptRunConfig
+from azureml.core.compute import ComputeTarget, AmlCompute
+from azureml.core.compute_target import ComputeTargetException
+from azureml.core.workspace import WorkspaceException
+
+
+def setup_workspace(workspace_name, subscription_id, resource_group, cli_auth,
+                    location):
+    """
+    This sets up an Azure Workspace.
+    An existing Azure Workspace is used or a new one is created if needed for
+    the pytest run.
+
+    Args:
+        workspace_name  (str): Centralized location on Azure to work
+                               with all the artifacts used by AzureML
+                               service
+        subscription_id (str): the Azure subscription id
+        resource_group  (str): Azure Resource Groups are logical collections of
+                         assets associated with a project. Resource groups
+                         make it easy to track or delete all resources
+                         associated with a project by tracking or deleting
+                         the Resource group.
+        cli_auth         Azure authentication
+        location        (str): Azure datacenter location, e.g. "EastUS"
+
+    Returns:
+        ws: workspace reference
+    """
+    logger.debug('setup: workspace_name is {}'.format(workspace_name))
+    logger.debug('setup: resource_group is {}'.format(resource_group))
+    logger.debug('setup: subid is {}'.format(subscription_id))
+    logger.debug('setup: location is {}'.format(location))
+
+    try:
+        # use existing workspace if there is one
+        ws = Workspace.get(
+            name=workspace_name,
+            subscription_id=subscription_id,
+            resource_group=resource_group,
+            auth=cli_auth
+        )
+    except WorkspaceException:
+        # this call might take a minute or two.
+        logger.debug('Creating new workspace')
+        ws = Workspace.create(
+            name=workspace_name,
+            subscription_id=subscription_id,
+            resource_group=resource_group,
+            # create_resource_group=True,
+            location=location,
+            auth=cli_auth
+        )
+    return ws
+
+
+def setup_persistent_compute_target(workspace, cluster_name, vm_size,
+                                    max_nodes):
+    """
+    Set up a persistent compute target on AzureML.
+    A persistent compute target runs noticeably faster than a
+    regular compute target for subsequent runs.  The benefit
+    is that AzureML manages turning the compute on/off as needed for
+    each job so the user does not need to do this.
+
+    Args:
+        workspace    (str): Centralized location on Azure to work with
+                            all the artifacts used by AzureML service
+        cluster_name (str): the Azure cluster for this run. It can
+                            already exist or it will be created.
+        vm_size      (str): Azure VM size, like STANDARD_D3_V2
+        max_nodes    (int): Number of VMs, max_nodes=4 will
+                            autoscale up to 4 VMs
+    Returns:
+        cpu_cluster : cluster reference
+    """
+    # setting vmsize and num nodes creates a persistent AzureML
+    # compute resource
+
+    logger.debug("setup: cluster_name {}".format(cluster_name))
+    # https://docs.microsoft.com/en-us/azure/machine-learning/service/how-to-set-up-training-targets
+
+    try:
+        cpu_cluster = ComputeTarget(workspace=workspace, name=cluster_name)
+        logger.debug('setup: Found existing cluster, use it.')
+    except ComputeTargetException:
+        logger.debug('setup: create cluster')
+        compute_config = AmlCompute.provisioning_configuration(
+                       vm_size=vm_size,
+                       max_nodes=max_nodes)
+        cpu_cluster = ComputeTarget.create(workspace,
+                                           cluster_name,
+                                           compute_config)
+    cpu_cluster.wait_for_completion(show_output=True)
+    return cpu_cluster
+
+
+def create_run_config(cpu_cluster, docker_proc_type, conda_env_file):
+    """
+    AzureML requires the run environment to be setup prior to submission.
+    This configures a docker persistent compute.  Even though
+    it is called Persistent compute, AzureML handles startup/shutdown
+    of the compute environment.
+
+    Args:
+        cpu_cluster      (str) : Names the cluster for the test
+                                 In the case of unit tests, any of
+                                 the following:
+                                 - Reco_cpu_test
+                                 - Reco_gpu_test
+        docker_proc_type (str) : processor type, cpu or gpu
+        conda_env_file   (str) : filename which contains info to
+                                 set up conda env
+    Return:
+          run_amlcompute : AzureML run config
+    """
+
+    # runconfig with max_run_duration_seconds did not work, check why:
+    # run_amlcompute = RunConfiguration(max_run_duration_seconds=60*30)
+    run_amlcompute = RunConfiguration()
+    run_amlcompute.target = cpu_cluster
+    run_amlcompute.environment.docker.enabled = True
+    run_amlcompute.environment.docker.base_image = docker_proc_type
+
+    # Use conda_dependencies.yml to create a conda environment in
+    # the Docker image for execution
+    # False means the user will provide a conda file for setup
+    # True means the user will manually configure the environment
+    run_amlcompute.environment.python.user_managed_dependencies = False
+    run_amlcompute.environment.python.conda_dependencies = CondaDependencies(
+            conda_dependencies_file_path=conda_env_file)
+    return run_amlcompute
+
+
+def create_experiment(workspace, experiment_name):
+    """
+    AzureML requires an experiment as a container of trials.
+    This will either create a new experiment or use an
+    existing one.
+
+    Args:
+        workspace (str) : name of AzureML workspace
+        experiment_name (str) : AzureML experiment name
+    Return:
+        exp - AzureML experiment
+    """
+
+    logger.debug('create: experiment_name {}'.format(experiment_name))
+    exp = Experiment(workspace=workspace, name=experiment_name)
+    return exp
+
+
+def submit_experiment_to_azureml(test, test_folder, test_markers, junitxml,
+                                 run_config, experiment):
+
+    """
+    Submitting the experiment to AzureML actually runs the script.
+
+    Args:
+        test         (str) - pytest script, folder/test
+                             such as ./tests/ci/run_pytest.py
+        test_folder  (str) - folder where tests to run are stored,
+                             like ./tests/unit
+        test_markers (str) - test markers used by pytest
+                             "not notebooks and not spark and not gpu"
+        junitxml     (str) - file of output summary of tests run
+                             note "--junitxml" is required as part
+                             of the string
+                             Example: "--junitxml reports/test-unit.xml"
+        run_config - environment configuration
+        experiment - instance of an Experiment, a collection of
+                     trials where each trial is a run.
+    Return:
+          run : AzureML run or trial
+    """
+
+    logger.debug('submit: testfolder {}'.format(test_folder))
+    logger.debug('junitxml: {}'.format(junitxml))
+    project_folder = "."
+
+    script_run_config = ScriptRunConfig(source_directory=project_folder,
+                                        script=test,
+                                        run_config=run_config,
+                                        arguments=["--testfolder",
+                                                   test_folder,
+                                                   "--testmarkers",
+                                                   test_markers,
+                                                   "--xmlname",
+                                                   junitxml]
+                                        )
+    run = experiment.submit(script_run_config)
+    # waits only for configuration to complete
+    run.wait_for_completion(show_output=True, wait_post_processing=True)
+
+    # test logs can also be found on azure
+    # go to azure portal to see log in azure ws and look for experiment name
+    # and look for individual run
+    logger.debug('files {}'.format(run.get_file_names()))
+
+    return run
+
+
+def create_arg_parser():
+    """
+    Many arguments have sensible defaults, since argparse makes defaults easy
+    to use. The user can override any of them from the command line.
+    """
+
+    parser = argparse.ArgumentParser(description='Process some inputs')
+    # script to run pytest
+    parser.add_argument("--test",
+                        action="store",
+                        default="./tests/ci/run_pytest.py",
+                        help="location of script to run pytest")
+    # test folder
+    parser.add_argument("--testfolder",
+                        action="store",
+                        default="./tests/unit",
+                        help="folder where tests are stored")
+    # pytest test markers
+    parser.add_argument("--testmarkers",
+                        action="store",
+                        default="not notebooks and not spark and not gpu",
+                        help="pytest markers indicate tests to run")
+    # test summary file
+    parser.add_argument("--junitxml",
+                        action="store",
+                        default="reports/test-unit.xml",
+                        help="file for returned test results")
+    # max num nodes in Azure cluster
+    parser.add_argument("--maxnodes",
+                        action="store",
+                        default=4,
+                        help="specify the maximum number of nodes for the run")
+    # Azure resource group
+    parser.add_argument("--rg",
+                        action="store",
+                        default="recommender",
+                        help="Azure Resource Group")
+    # AzureML workspace Name
+    parser.add_argument("--wsname",
+                        action="store",
+                        default="RecoWS",
+                        help="AzureML workspace name")
+    # AzureML clustername
+    parser.add_argument("--clustername",
+                        action="store",
+                        default="amlcompute",
+                        help="Set name of Azure cluster")
+    # Azure VM size
+    parser.add_argument("--vmsize",
+                        action="store",
+                        default="STANDARD_D3_V2",
+                        help="Set the size of the VM either STANDARD_D3_V2")
+    # cpu or gpu
+    parser.add_argument("--dockerproc",
+                        action="store",
+                        default="cpu",
+                        help="Base image used in docker container")
+    # Azure subscription id, when used in a pipeline, it is stored in keyvault
+    parser.add_argument("--subid",
+                        action="store",
+                        default="123456",
+                        help="Azure Subscription ID")
+    # ./reco.yaml is created in the azure devops pipeline.
+    # Not recommended to change this.
+    parser.add_argument("--condafile",
+                        action="store",
+                        default="./reco.yaml",
+                        help="file with environment variables")
+    # AzureML experiment name
+    parser.add_argument("--expname",
+                        action="store",
+                        default="persistentAML",
+                        help="experiment name on Azure")
+    # Azure datacenter location
+    parser.add_argument("--location",
+                        default="EastUS",
+                        help="Azure location")
+    # github repo, stored in AzureML experiment for info purposes
+    parser.add_argument("--reponame",
+                        action="store",
+                        default="--reponame MyGithubRepo",
+                        help="GitHub repo being tested")
+    # github branch, stored in AzureML experiment for info purposes
+    parser.add_argument("--branch",
+                        action="store",
+                        default="--branch MyGithubBranch",
+                        help=" Identify the branch test test is run on")
+    # github pull request, stored in AzureML experiment for info purposes
+    parser.add_argument("--pr",
+                        action="store",
+                        default="--pr PRTestRun",
+                        help="If a pr triggered the test, list it here")
+
+    args = parser.parse_args()
+
+    return args
+
+
+if __name__ == "__main__":
+    logger = logging.getLogger('submit_azureml_pytest.py')
+    # logger.setLevel(logging.DEBUG)
+    # logging.basicConfig(level=logging.DEBUG)
+    args = create_arg_parser()
+
+    if args.dockerproc == "cpu":
+        from azureml.core.runconfig import DEFAULT_CPU_IMAGE
+        docker_proc_type = DEFAULT_CPU_IMAGE
+    else:
+        from azureml.core.runconfig import DEFAULT_GPU_IMAGE
+        docker_proc_type = DEFAULT_GPU_IMAGE
+
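+    # Authenticate with the credentials of the local `az login` session.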
+    cli_auth = AzureCliAuthentication()
+
+    workspace = setup_workspace(workspace_name=args.wsname,
+                                subscription_id=args.subid,
+                                resource_group=args.rg,
+                                cli_auth=cli_auth,
+                                location=args.location)
+
+    cpu_cluster = setup_persistent_compute_target(
+                      workspace=workspace,
+                      cluster_name=args.clustername,
+                      vm_size=args.vmsize,
+                      max_nodes=args.maxnodes)
+
+    run_config = create_run_config(cpu_cluster=cpu_cluster,
+                                   docker_proc_type=docker_proc_type,
+                                   conda_env_file=args.condafile)
+
+    logger.info('exp: In Azure, look for experiment named {}'.format(
+                args.expname))
+
+    # create new or use existing experiment
+    experiment = Experiment(workspace=workspace, name=args.expname)
+    run = submit_experiment_to_azureml(test=args.test,
+                                       test_folder=args.testfolder,
+                                       test_markers=args.testmarkers,
+                                       junitxml=args.junitxml,
+                                       run_config=run_config,
+                                       experiment=experiment)
+
+    # add helpful information to experiment on Azure
+    run.tag('RepoName', args.reponame)
+    run.tag('Branch', args.branch)
+    run.tag('PR', args.pr)
+    # download files from AzureML
+    run.download_files(prefix='reports', output_paths='./reports')
+    run.complete()

+ 193 - 0
Azure/test_adx.py

@@ -0,0 +1,193 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import json
+import unittest
+from unittest import mock
+
+import pytest
+from azure.kusto.data.request import ClientRequestProperties, KustoClient, KustoConnectionStringBuilder
+
+from airflow.exceptions import AirflowException
+from airflow.models import Connection
+from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook
+from airflow.utils import db
+from airflow.utils.session import create_session
+
+ADX_TEST_CONN_ID = 'adx_test_connection_id'
+
+
+class TestAzureDataExplorerHook(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        with create_session() as session:
+            session.query(Connection).filter(Connection.conn_id == ADX_TEST_CONN_ID).delete()
+
+    def test_conn_missing_method(self):
+        db.merge_conn(
+            Connection(
+                conn_id=ADX_TEST_CONN_ID,
+                conn_type='azure_data_explorer',
+                login='client_id',
+                password='client secret',
+                host='https://help.kusto.windows.net',
+                extra=json.dumps({}),
+            )
+        )
+        with pytest.raises(AirflowException) as ctx:
+            AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+        assert 'is missing: `extra__azure_data_explorer__auth_method`' in str(ctx.value)
+
+    def test_conn_unknown_method(self):
+        db.merge_conn(
+            Connection(
+                conn_id=ADX_TEST_CONN_ID,
+                conn_type='azure_data_explorer',
+                login='client_id',
+                password='client secret',
+                host='https://help.kusto.windows.net',
+                extra=json.dumps({'extra__azure_data_explorer__auth_method': 'AAD_OTHER'}),
+            )
+        )
+        with pytest.raises(AirflowException) as ctx:
+            AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+        assert 'Unknown authentication method: AAD_OTHER' in str(ctx.value)
+
+    def test_conn_missing_cluster(self):
+        db.merge_conn(
+            Connection(
+                conn_id=ADX_TEST_CONN_ID,
+                conn_type='azure_data_explorer',
+                login='client_id',
+                password='client secret',
+                extra=json.dumps({}),
+            )
+        )
+        with pytest.raises(AirflowException) as ctx:
+            AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+        assert 'Host connection option is required' in str(ctx.value)
+
+    @mock.patch.object(KustoClient, '__init__')
+    def test_conn_method_aad_creds(self, mock_init):
+        mock_init.return_value = None
+        db.merge_conn(
+            Connection(
+                conn_id=ADX_TEST_CONN_ID,
+                conn_type='azure_data_explorer',
+                login='client_id',
+                password='client secret',
+                host='https://help.kusto.windows.net',
+                extra=json.dumps(
+                    {
+                        'extra__azure_data_explorer__tenant': 'tenant',
+                        'extra__azure_data_explorer__auth_method': 'AAD_CREDS',
+                    }
+                ),
+            )
+        )
+        AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+        mock_init.assert_called_with(
+            KustoConnectionStringBuilder.with_aad_user_password_authentication(
+                'https://help.kusto.windows.net', 'client_id', 'client secret', 'tenant'
+            )
+        )
+
+    @mock.patch.object(KustoClient, '__init__')
+    def test_conn_method_aad_app(self, mock_init):
+        mock_init.return_value = None
+        db.merge_conn(
+            Connection(
+                conn_id=ADX_TEST_CONN_ID,
+                conn_type='azure_data_explorer',
+                login='app_id',
+                password='app key',
+                host='https://help.kusto.windows.net',
+                extra=json.dumps(
+                    {
+                        'extra__azure_data_explorer__tenant': 'tenant',
+                        'extra__azure_data_explorer__auth_method': 'AAD_APP',
+                    }
+                ),
+            )
+        )
+        AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+        mock_init.assert_called_with(
+            KustoConnectionStringBuilder.with_aad_application_key_authentication(
+                'https://help.kusto.windows.net', 'app_id', 'app key', 'tenant'
+            )
+        )
+
+    @mock.patch.object(KustoClient, '__init__')
+    def test_conn_method_aad_app_cert(self, mock_init):
+        mock_init.return_value = None
+        db.merge_conn(
+            Connection(
+                conn_id=ADX_TEST_CONN_ID,
+                conn_type='azure_data_explorer',
+                login='client_id',
+                host='https://help.kusto.windows.net',
+                extra=json.dumps(
+                    {
+                        'extra__azure_data_explorer__tenant': 'tenant',
+                        'extra__azure_data_explorer__auth_method': 'AAD_APP_CERT',
+                        'extra__azure_data_explorer__certificate': 'PEM',
+                        'extra__azure_data_explorer__thumbprint': 'thumbprint',
+                    }
+                ),
+            )
+        )
+        AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+        mock_init.assert_called_with(
+            KustoConnectionStringBuilder.with_aad_application_certificate_authentication(
+                'https://help.kusto.windows.net', 'client_id', 'PEM', 'thumbprint', 'tenant'
+            )
+        )
+
+    @mock.patch.object(KustoClient, '__init__')
+    def test_conn_method_aad_device(self, mock_init):
+        mock_init.return_value = None
+        db.merge_conn(
+            Connection(
+                conn_id=ADX_TEST_CONN_ID,
+                conn_type='azure_data_explorer',
+                host='https://help.kusto.windows.net',
+                extra=json.dumps({'extra__azure_data_explorer__auth_method': 'AAD_DEVICE'}),
+            )
+        )
+        AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+        mock_init.assert_called_with(
+            KustoConnectionStringBuilder.with_aad_device_authentication('https://help.kusto.windows.net')
+        )
+
+    @mock.patch.object(KustoClient, 'execute')
+    def test_run_query(self, mock_execute):
+        mock_execute.return_value = None
+        db.merge_conn(
+            Connection(
+                conn_id=ADX_TEST_CONN_ID,
+                conn_type='azure_data_explorer',
+                host='https://help.kusto.windows.net',
+                extra=json.dumps({'extra__azure_data_explorer__auth_method': 'AAD_DEVICE'}),
+            )
+        )
+        hook = AzureDataExplorerHook(azure_data_explorer_conn_id=ADX_TEST_CONN_ID)
+        hook.run_query('Database', 'Logs | schema', options={'option1': 'option_value'})
+        properties = ClientRequestProperties()
+        properties.set_option('option1', 'option_value')
+        mock_execute.assert_called_with('Database', 'Logs | schema', properties=properties)

+ 237 - 0
Azure/test_azure.py

@@ -0,0 +1,237 @@
+from subprocess import DEVNULL
+from typing import Generator
+from unittest.mock import Mock
+
+from azure.identity import DefaultAzureCredential
+from azure.storage.blob import (
+    BlobClient,
+    BlobProperties,
+    BlobServiceClient,
+    ContainerClient,
+)
+from pytest import fixture
+from pytest_mock import MockFixture
+
+from opta.core.azure import Azure
+from opta.layer import Layer, StructuredConfig
+
+
+@fixture()
+def azure_layer() -> Mock:
+    layer = Mock(spec=Layer)
+    layer.parent = None
+    layer.cloud = "azurerm"
+    layer.name = "blah"
+    layer.providers = {
+        "azurerm": {
+            "location": "centralus",
+            "tenant_id": "blahbc17-blah-blah-blah-blah291d395b",
+            "subscription_id": "blah99ae-blah-blah-blah-blahd2a04788",
+        }
+    }
+    layer.root.return_value = layer
+    layer.gen_providers.return_value = {
+        "terraform": {
+            "backend": {
+                "azurerm": {
+                    "resource_group_name": "dummy_resource_group",
+                    "storage_account_name": "dummy_storage_account",
+                    "container_name": "dummy_container_name",
+                    "key": "dummy_key",
+                }
+            }
+        },
+        "provider": {
+            "azurerm": {
+                "location": "centralus",
+                "tenant_id": "blahbc17-blah-blah-blah-blah291d395b",
+                "subscription_id": "blah99ae-blah-blah-blah-blahd2a04788",
+            }
+        },
+    }
+    layer.get_cluster_name.return_value = "mocked_cluster_name"
+    return layer
+
+
+@fixture(autouse=True)
+def reset_azure_creds() -> Generator:
+    Azure.credentials = None
+    yield
+
+
+class TestAzure:
+    def test_azure_set_kube_config(self, mocker: MockFixture, azure_layer: Mock) -> None:
+        mocked_ensure_installed = mocker.patch("opta.core.azure.ensure_installed")
+        mocker.patch(
+            "opta.core.azure.Azure.cluster_exist", return_value=True,
+        )
+        mocked_nice_run = mocker.patch("opta.core.azure.nice_run",)
+
+        Azure(azure_layer).set_kube_config()
+
+        mocked_ensure_installed.assert_has_calls([mocker.call("az")])
+        mocked_nice_run.assert_has_calls(
+            [
+                mocker.call(
+                    [
+                        "az",
+                        "aks",
+                        "get-credentials",
+                        "--resource-group",
+                        "dummy_resource_group",
+                        "--name",
+                        "mocked_cluster_name",
+                        "--admin",
+                        "--overwrite-existing",
+                        "--context",
+                        "dummy_resource_group-mocked_cluster_name",
+                    ],
+                    stdout=DEVNULL,
+                    check=True,
+                ),
+            ]
+        )
+
+    def test_get_credentials(self, mocker: MockFixture) -> None:
+        mocked_default_creds = mocker.patch("opta.core.azure.DefaultAzureCredential")
+        Azure.get_credentials()
+        mocked_default_creds.assert_called_once_with()
+
+    def test_get_remote_config(self, mocker: MockFixture, azure_layer: Mock) -> None:
+        mocked_creds = mocker.Mock()
+        mocked_default_creds = mocker.patch(
+            "opta.core.azure.DefaultAzureCredential", return_value=mocked_creds
+        )
+
+        mocked_container_client_instance = mocker.Mock()
+        mocked_container_client_instance.download_blob = mocker.Mock()
+        download_stream_mock = mocker.Mock()
+        download_stream_mock.readall = mocker.Mock(
+            return_value='{"opta_version":"1", "date": "mock_date", "original_spec": "mock_spec", "defaults": {}}'
+        )
+        mocked_container_client_instance.download_blob.return_value = download_stream_mock
+        mocked_container_client = mocker.patch(
+            "opta.core.azure.ContainerClient",
+            return_value=mocked_container_client_instance,
+        )
+        mocked_structured_config: StructuredConfig = {
+            "opta_version": "1",
+            "date": "mock_date",
+            "original_spec": "mock_spec",
+            "defaults": {},
+        }
+
+        assert Azure(azure_layer).get_remote_config() == mocked_structured_config
+
+        azure_layer.gen_providers.assert_called_once_with(0)
+        mocked_default_creds.assert_called_once_with()
+        mocked_container_client.assert_called_once_with(
+            account_url="https://dummy_storage_account.blob.core.windows.net",
+            container_name="dummy_container_name",
+            credential=mocked_creds,
+        )
+        mocked_container_client_instance.download_blob.assert_called_once_with(
+            f"opta_config/{azure_layer.name}"
+        )
+
+    def test_upload_opta_config(self, mocker: MockFixture, azure_layer: Mock) -> None:
+        Azure.credentials = None
+        mocked_creds = mocker.Mock()
+        mocked_default_creds = mocker.patch(
+            "opta.core.azure.DefaultAzureCredential", return_value=mocked_creds
+        )
+
+        mocked_container_client_instance = mocker.Mock()
+        mocked_container_client = mocker.patch(
+            "opta.core.azure.ContainerClient",
+            return_value=mocked_container_client_instance,
+        )
+        azure_layer.structured_config = mocker.Mock(return_value={"a": 1})
+
+        Azure(azure_layer).upload_opta_config()
+
+        azure_layer.gen_providers.assert_called_once_with(0)
+        mocked_default_creds.assert_called_once_with()
+        mocked_container_client.assert_called_once_with(
+            account_url="https://dummy_storage_account.blob.core.windows.net",
+            container_name="dummy_container_name",
+            credential=mocked_creds,
+        )
+        mocked_container_client_instance.upload_blob.assert_called_once_with(
+            name=f"opta_config/{azure_layer.name}", data='{"a": 1}', overwrite=True
+        )
+
+    def test_delete_opta_config(self, mocker: MockFixture, azure_layer: Mock) -> None:
+        Azure.credentials = None
+        mocked_creds = mocker.Mock()
+        mocked_default_creds = mocker.patch(
+            "opta.core.azure.DefaultAzureCredential", return_value=mocked_creds
+        )
+
+        mocked_container_client_instance = mocker.Mock()
+        mocked_container_client = mocker.patch(
+            "opta.core.azure.ContainerClient",
+            return_value=mocked_container_client_instance,
+        )
+
+        Azure(azure_layer).delete_opta_config()
+
+        azure_layer.gen_providers.assert_called_once_with(0)
+        mocked_default_creds.assert_called_once_with()
+        mocked_container_client.assert_called_once_with(
+            account_url="https://dummy_storage_account.blob.core.windows.net",
+            container_name="dummy_container_name",
+            credential=mocked_creds,
+        )
+        mocked_container_client_instance.delete_blob.assert_called_once_with(
+            f"opta_config/{azure_layer.name}", delete_snapshots="include"
+        )
+
+    def test_delete_remote_state(self, mocker: MockFixture, azure_layer: Mock) -> None:
+        Azure.credentials = None
+        mocked_creds = mocker.Mock()
+        mocked_default_creds = mocker.patch(
+            "opta.core.azure.DefaultAzureCredential", return_value=mocked_creds
+        )
+
+        mocked_container_client_instance = mocker.Mock()
+        mocked_container_client = mocker.patch(
+            "opta.core.azure.ContainerClient",
+            return_value=mocked_container_client_instance,
+        )
+
+        Azure(azure_layer).delete_remote_state()
+
+        azure_layer.gen_providers.assert_called_once_with(0)
+        mocked_default_creds.assert_called_once_with()
+        mocked_container_client.assert_called_once_with(
+            account_url="https://dummy_storage_account.blob.core.windows.net",
+            container_name="dummy_container_name",
+            credential=mocked_creds,
+        )
+        mocked_container_client_instance.delete_blob.assert_called_once_with(
+            azure_layer.name, delete_snapshots="include"
+        )
+
+    def test_get_terraform_lock_id(self, mocker: MockFixture, azure_layer: Mock) -> None:
+        mocker.patch(
+            "opta.core.azure.DefaultAzureCredential",
+            return_value=mocker.Mock(spec=DefaultAzureCredential),
+        )
+
+        mock_blob_service_client = mocker.Mock(spec=BlobServiceClient)
+        mock_container_client = mocker.Mock(spec=ContainerClient)
+        mock_blob_client = mocker.Mock(spec=BlobClient)
+        mock_blob_properties = mocker.Mock(spec=BlobProperties)
+        mock_blob_properties.metadata = {
+            "Terraformlockid": "J3siSUQiOiAibW9ja19sb2NrX2lkIn0n"
+        }
+
+        mocker.patch(
+            "opta.core.azure.BlobServiceClient", return_value=mock_blob_service_client
+        )
+        mock_blob_service_client.get_container_client.return_value = mock_container_client
+        mock_container_client.get_blob_client.return_value = mock_blob_client
+        mock_blob_client.get_blob_properties.return_value = mock_blob_properties
+
+        Azure(azure_layer).get_terraform_lock_id()

+ 167 - 0
Azure/test_azure_batch.py

@@ -0,0 +1,167 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import json
+import unittest
+from unittest import mock
+
+from azure.batch import BatchServiceClient, models as batch_models
+
+from airflow.models import Connection
+from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
+from airflow.utils import db
+
+
+class TestAzureBatchHook(unittest.TestCase):
+    # set up the test environment
+    def setUp(self):
+        # set up the test variable
+        self.test_vm_conn_id = "test_azure_batch_vm"
+        self.test_cloud_conn_id = "test_azure_batch_cloud"
+        self.test_account_name = "test_account_name"
+        self.test_account_key = "test_account_key"
+        self.test_account_url = "http://test-endpoint:29000"
+        self.test_vm_size = "test-vm-size"
+        self.test_vm_publisher = "test.vm.publisher"
+        self.test_vm_offer = "test.vm.offer"
+        self.test_vm_sku = "test-sku"
+        self.test_cloud_os_family = "test-family"
+        self.test_cloud_os_version = "test-version"
+        self.test_node_agent_sku = "test-node-agent-sku"
+
+        # connect with vm configuration
+        db.merge_conn(
+            Connection(
+                conn_id=self.test_vm_conn_id,
+                conn_type="azure_batch",
+                extra=json.dumps({"extra__azure_batch__account_url": self.test_account_url}),
+            )
+        )
+        # connect with cloud service
+        db.merge_conn(
+            Connection(
+                conn_id=self.test_cloud_conn_id,
+                conn_type="azure_batch",
+                extra=json.dumps({"extra__azure_batch__account_url": self.test_account_url}),
+            )
+        )
+
+    def test_connection_and_client(self):
+        hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
+        assert isinstance(hook._connection(), Connection)
+        assert isinstance(hook.get_conn(), BatchServiceClient)
+
+    def test_configure_pool_with_vm_config(self):
+        hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
+        pool = hook.configure_pool(
+            pool_id='mypool',
+            vm_size="test_vm_size",
+            target_dedicated_nodes=1,
+            vm_publisher="test.vm.publisher",
+            vm_offer="test.vm.offer",
+            sku_starts_with="test-sku",
+        )
+        assert isinstance(pool, batch_models.PoolAddParameter)
+
+    def test_configure_pool_with_cloud_config(self):
+        hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
+        pool = hook.configure_pool(
+            pool_id='mypool',
+            vm_size="test_vm_size",
+            target_dedicated_nodes=1,
+            vm_publisher="test.vm.publisher",
+            vm_offer="test.vm.offer",
+            sku_starts_with="test-sku",
+        )
+        assert isinstance(pool, batch_models.PoolAddParameter)
+
+    def test_configure_pool_with_latest_vm(self):
+        with mock.patch(
+            "airflow.providers.microsoft.azure.hooks."
+            "batch.AzureBatchHook._get_latest_verified_image_vm_and_sku"
+        ) as mock_getvm:
+            hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
+            getvm_instance = mock_getvm
+            getvm_instance.return_value = ['test-image', 'test-sku']
+            pool = hook.configure_pool(
+                pool_id='mypool',
+                vm_size="test_vm_size",
+                use_latest_image_and_sku=True,
+                vm_publisher="test.vm.publisher",
+                vm_offer="test.vm.offer",
+                sku_starts_with="test-sku",
+            )
+            assert isinstance(pool, batch_models.PoolAddParameter)
+
+    @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+    def test_create_pool_with_vm_config(self, mock_batch):
+        hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
+        mock_instance = mock_batch.return_value.pool.add
+        pool = hook.configure_pool(
+            pool_id='mypool',
+            vm_size="test_vm_size",
+            target_dedicated_nodes=1,
+            vm_publisher="test.vm.publisher",
+            vm_offer="test.vm.offer",
+            sku_starts_with="test-sku",
+        )
+        hook.create_pool(pool=pool)
+        mock_instance.assert_called_once_with(pool)
+
+    @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+    def test_create_pool_with_cloud_config(self, mock_batch):
+        hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
+        mock_instance = mock_batch.return_value.pool.add
+        pool = hook.configure_pool(
+            pool_id='mypool',
+            vm_size="test_vm_size",
+            target_dedicated_nodes=1,
+            vm_publisher="test.vm.publisher",
+            vm_offer="test.vm.offer",
+            sku_starts_with="test-sku",
+        )
+        hook.create_pool(pool=pool)
+        mock_instance.assert_called_once_with(pool)
+
+    @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+    def test_wait_for_all_nodes(self, mock_batch):
+        # TODO: Add test
+        pass
+
+    @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
+    def test_job_configuration_and_create_job(self, mock_batch):
+        hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
+        mock_instance = mock_batch.return_value.job.add
+        job = hook.configure_job(job_id='myjob', pool_id='mypool')
+        hook.create_job(job)
+        assert isinstance(job, batch_models.JobAddParameter)
+        mock_instance.assert_called_once_with(job)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient')
+    def test_add_single_task_to_job(self, mock_batch):
+        hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
+        mock_instance = mock_batch.return_value.task.add
+        task = hook.configure_task(task_id="mytask", command_line="echo hello")
+        hook.add_single_task_to_job(job_id='myjob', task=task)
+        assert isinstance(task, batch_models.TaskAddParameter)
+        mock_instance.assert_called_once_with(job_id="myjob", task=task)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient')
+    def test_wait_for_all_task_to_complete(self, mock_batch):
+        # TODO: Add test
+        pass
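The two TODO stubs above are left empty. The helpers they would cover are polling loops; a generic sketch of that pattern (names and defaults here are illustrative, not the provider's actual implementation) looks like:

    import time

    def wait_for_node_states(list_nodes, target_states, timeout=300, interval=5):
        """Poll list_nodes() until every node reports a state in target_states, or give up after timeout."""
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            nodes = list(list_nodes())
            if nodes and all(node.state in target_states for node in nodes):
                return nodes
            time.sleep(interval)
        raise TimeoutError("nodes did not reach the expected state before the timeout")

A test for the real hook would mock the Batch client's node/task listing calls and assert that the loop returns once all items report the expected state.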

+ 99 - 0
Azure/test_azure_container_instance.py

@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+import unittest
+from unittest.mock import patch
+
+from azure.mgmt.containerinstance.models import (
+    Container,
+    ContainerGroup,
+    Logs,
+    ResourceRequests,
+    ResourceRequirements,
+)
+
+from airflow.models import Connection
+from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook
+from airflow.utils import db
+
+
+class TestAzureContainerInstanceHook(unittest.TestCase):
+    def setUp(self):
+        db.merge_conn(
+            Connection(
+                conn_id='azure_container_instance_test',
+                conn_type='azure_container_instances',
+                login='login',
+                password='key',
+                extra=json.dumps({'tenantId': 'tenant_id', 'subscriptionId': 'subscription_id'}),
+            )
+        )
+
+        self.resources = ResourceRequirements(requests=ResourceRequests(memory_in_gb='4', cpu='1'))
+        with patch(
+            'azure.common.credentials.ServicePrincipalCredentials.__init__', autospec=True, return_value=None
+        ):
+            with patch('azure.mgmt.containerinstance.ContainerInstanceManagementClient'):
+                self.hook = AzureContainerInstanceHook(conn_id='azure_container_instance_test')
+
+    @patch('azure.mgmt.containerinstance.models.ContainerGroup')
+    @patch('azure.mgmt.containerinstance.operations.ContainerGroupsOperations.create_or_update')
+    def test_create_or_update(self, create_or_update_mock, container_group_mock):
+        self.hook.create_or_update('resource_group', 'aci-test', container_group_mock)
+        create_or_update_mock.assert_called_once_with('resource_group', 'aci-test', container_group_mock)
+
+    @patch('azure.mgmt.containerinstance.operations.ContainerGroupsOperations.get')
+    def test_get_state(self, get_state_mock):
+        self.hook.get_state('resource_group', 'aci-test')
+        get_state_mock.assert_called_once_with('resource_group', 'aci-test', raw=False)
+
+    @patch('azure.mgmt.containerinstance.operations.ContainerOperations.list_logs')
+    def test_get_logs(self, list_logs_mock):
+        expected_messages = ['log line 1\n', 'log line 2\n', 'log line 3\n']
+        logs = Logs(content=''.join(expected_messages))
+        list_logs_mock.return_value = logs
+
+        logs = self.hook.get_logs('resource_group', 'name', 'name')
+
+        assert logs == expected_messages
+
+    @patch('azure.mgmt.containerinstance.operations.ContainerGroupsOperations.delete')
+    def test_delete(self, delete_mock):
+        self.hook.delete('resource_group', 'aci-test')
+        delete_mock.assert_called_once_with('resource_group', 'aci-test')
+
+    @patch('azure.mgmt.containerinstance.operations.ContainerGroupsOperations.list_by_resource_group')
+    def test_exists_with_existing(self, list_mock):
+        list_mock.return_value = [
+            ContainerGroup(
+                os_type='Linux',
+                containers=[Container(name='test1', image='hello-world', resources=self.resources)],
+            )
+        ]
+        assert not self.hook.exists('test', 'test1')
+
+    @patch('azure.mgmt.containerinstance.operations.ContainerGroupsOperations.list_by_resource_group')
+    def test_exists_with_not_existing(self, list_mock):
+        list_mock.return_value = [
+            ContainerGroup(
+                os_type='Linux',
+                containers=[Container(name='test1', image='hello-world', resources=self.resources)],
+            )
+        ]
+        assert not self.hook.exists('test', 'not found')

+ 236 - 0
Azure/test_azure_cosmos.py

@@ -0,0 +1,236 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import json
+import logging
+import unittest
+import uuid
+from unittest import mock
+
+import pytest
+from azure.cosmos.cosmos_client import CosmosClient
+
+from airflow.exceptions import AirflowException
+from airflow.models import Connection
+from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook
+from airflow.utils import db
+
+
+class TestAzureCosmosDbHook(unittest.TestCase):
+
+    # Set up an environment to test with
+    def setUp(self):
+        # set up some test variables
+        self.test_end_point = 'https://test_endpoint:443'
+        self.test_master_key = 'magic_test_key'
+        self.test_database_name = 'test_database_name'
+        self.test_collection_name = 'test_collection_name'
+        self.test_database_default = 'test_database_default'
+        self.test_collection_default = 'test_collection_default'
+        db.merge_conn(
+            Connection(
+                conn_id='azure_cosmos_test_key_id',
+                conn_type='azure_cosmos',
+                login=self.test_end_point,
+                password=self.test_master_key,
+                extra=json.dumps(
+                    {
+                        'database_name': self.test_database_default,
+                        'collection_name': self.test_collection_default,
+                    }
+                ),
+            )
+        )
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient', autospec=True)
+    def test_client(self, mock_cosmos):
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
+        assert hook._conn is None
+        assert isinstance(hook.get_conn(), CosmosClient)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
+    def test_create_database(self, mock_cosmos):
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
+        hook.create_database(self.test_database_name)
+        expected_calls = [mock.call().create_database('test_database_name')]
+        mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
+        mock_cosmos.assert_has_calls(expected_calls)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
+    def test_create_database_exception(self, mock_cosmos):
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
+        with pytest.raises(AirflowException):
+            hook.create_database(None)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
+    def test_create_container_exception(self, mock_cosmos):
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
+        with pytest.raises(AirflowException):
+            hook.create_collection(None)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
+    def test_create_container(self, mock_cosmos):
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
+        hook.create_collection(self.test_collection_name, self.test_database_name)
+        expected_calls = [
+            mock.call().get_database_client('test_database_name').create_container('test_collection_name')
+        ]
+        mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
+        mock_cosmos.assert_has_calls(expected_calls)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
+    def test_create_container_default(self, mock_cosmos):
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
+        hook.create_collection(self.test_collection_name)
+        expected_calls = [
+            mock.call().get_database_client('test_database_name').create_container('test_collection_name')
+        ]
+        mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
+        mock_cosmos.assert_has_calls(expected_calls)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
+    def test_upsert_document_default(self, mock_cosmos):
+        test_id = str(uuid.uuid4())
+        # fmt: off
+        (mock_cosmos
+         .return_value
+         .get_database_client
+         .return_value
+         .get_container_client
+         .return_value
+         .upsert_item
+         .return_value) = {'id': test_id}
+        # fmt: on
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
+        returned_item = hook.upsert_document({'id': test_id})
+        expected_calls = [
+            mock.call()
+            .get_database_client('test_database_name')
+            .get_container_client('test_collection_name')
+            .upsert_item({'id': test_id})
+        ]
+        mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
+        mock_cosmos.assert_has_calls(expected_calls)
+        logging.getLogger().info(returned_item)
+        assert returned_item['id'] == test_id
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
+    def test_upsert_document(self, mock_cosmos):
+        test_id = str(uuid.uuid4())
+        # fmt: off
+        (mock_cosmos
+         .return_value
+         .get_database_client
+         .return_value
+         .get_container_client
+         .return_value
+         .upsert_item
+         .return_value) = {'id': test_id}
+        # fmt: on
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
+        returned_item = hook.upsert_document(
+            {'data1': 'somedata'},
+            database_name=self.test_database_name,
+            collection_name=self.test_collection_name,
+            document_id=test_id,
+        )
+
+        expected_calls = [
+            mock.call()
+            .get_database_client('test_database_name')
+            .get_container_client('test_collection_name')
+            .upsert_item({'data1': 'somedata', 'id': test_id})
+        ]
+
+        mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
+        mock_cosmos.assert_has_calls(expected_calls)
+        logging.getLogger().info(returned_item)
+        assert returned_item['id'] == test_id
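The expected call above pins down the behaviour under test: when a document_id is passed explicitly, it is injected into the payload before the upsert, while a payload that already carries an id is forwarded unchanged. A minimal sketch of that id handling, under the assumption that this is essentially all the hook adds (the helper name is illustrative):

    import uuid
    from typing import Optional

    def with_document_id(document: dict, document_id: Optional[str] = None) -> dict:
        """Attach an id to the payload: use the explicit one when given, else generate a uuid4."""
        if document_id is not None:
            document["id"] = document_id
        elif "id" not in document:
            document["id"] = str(uuid.uuid4())
        return document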
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
+    def test_insert_documents(self, mock_cosmos):
+        test_id1 = str(uuid.uuid4())
+        test_id2 = str(uuid.uuid4())
+        test_id3 = str(uuid.uuid4())
+        documents = [
+            {'id': test_id1, 'data': 'data1'},
+            {'id': test_id2, 'data': 'data2'},
+            {'id': test_id3, 'data': 'data3'},
+        ]
+
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
+        returned_item = hook.insert_documents(documents)
+        expected_calls = [
+            mock.call()
+            .get_database_client('test_database_name')
+            .get_container_client('test_collection_name')
+            .create_item({'data': 'data1', 'id': test_id1}),
+            mock.call()
+            .get_database_client('test_database_name')
+            .get_container_client('test_collection_name')
+            .create_item({'data': 'data2', 'id': test_id2}),
+            mock.call()
+            .get_database_client('test_database_name')
+            .get_container_client('test_collection_name')
+            .create_item({'data': 'data3', 'id': test_id3}),
+        ]
+        logging.getLogger().info(returned_item)
+        mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
+        mock_cosmos.assert_has_calls(expected_calls, any_order=True)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
+    def test_delete_database(self, mock_cosmos):
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
+        hook.delete_database(self.test_database_name)
+        expected_calls = [mock.call().delete_database('test_database_name')]
+        mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
+        mock_cosmos.assert_has_calls(expected_calls)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
+    def test_delete_database_exception(self, mock_cosmos):
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
+        with pytest.raises(AirflowException):
+            hook.delete_database(None)
+
+    @mock.patch('azure.cosmos.cosmos_client.CosmosClient')
+    def test_delete_container_exception(self, mock_cosmos):
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
+        with pytest.raises(AirflowException):
+            hook.delete_collection(None)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
+    def test_delete_container(self, mock_cosmos):
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
+        hook.delete_collection(self.test_collection_name, self.test_database_name)
+        expected_calls = [
+            mock.call().get_database_client('test_database_name').delete_container('test_collection_name')
+        ]
+        mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
+        mock_cosmos.assert_has_calls(expected_calls)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
+    def test_delete_container_default(self, mock_cosmos):
+        hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
+        hook.delete_collection(self.test_collection_name)
+        expected_calls = [
+            mock.call().get_database_client('test_database_name').delete_container('test_collection_name')
+        ]
+        mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
+        mock_cosmos.assert_has_calls(expected_calls)

+ 594 - 0
Azure/test_azure_data_factory.py

@@ -0,0 +1,594 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import json
+from typing import Type
+from unittest.mock import MagicMock, PropertyMock, patch
+
+import pytest
+from azure.identity import ClientSecretCredential, DefaultAzureCredential
+from azure.mgmt.datafactory.models import FactoryListResponse
+from pytest import fixture
+
+from airflow.exceptions import AirflowException
+from airflow.models.connection import Connection
+from airflow.providers.microsoft.azure.hooks.data_factory import (
+    AzureDataFactoryHook,
+    AzureDataFactoryPipelineRunException,
+    AzureDataFactoryPipelineRunStatus,
+    provide_targeted_factory,
+)
+from airflow.utils import db
+
+DEFAULT_RESOURCE_GROUP = "defaultResourceGroup"
+RESOURCE_GROUP = "testResourceGroup"
+
+DEFAULT_FACTORY = "defaultFactory"
+FACTORY = "testFactory"
+
+DEFAULT_CONNECTION_CLIENT_SECRET = "azure_data_factory_test_client_secret"
+DEFAULT_CONNECTION_DEFAULT_CREDENTIAL = "azure_data_factory_test_default_credential"
+
+MODEL = object()
+NAME = "testName"
+ID = "testId"
+
+
+def setup_module():
+    connection_client_secret = Connection(
+        conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+        conn_type="azure_data_factory",
+        login="clientId",
+        password="clientSecret",
+        extra=json.dumps(
+            {
+                "extra__azure_data_factory__tenantId": "tenantId",
+                "extra__azure_data_factory__subscriptionId": "subscriptionId",
+                "extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP,
+                "extra__azure_data_factory__factory_name": DEFAULT_FACTORY,
+            }
+        ),
+    )
+    connection_default_credential = Connection(
+        conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL,
+        conn_type="azure_data_factory",
+        extra=json.dumps(
+            {
+                "extra__azure_data_factory__subscriptionId": "subscriptionId",
+                "extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP,
+                "extra__azure_data_factory__factory_name": DEFAULT_FACTORY,
+            }
+        ),
+    )
+    connection_missing_subscription_id = Connection(
+        conn_id="azure_data_factory_missing_subscription_id",
+        conn_type="azure_data_factory",
+        login="clientId",
+        password="clientSecret",
+        extra=json.dumps(
+            {
+                "extra__azure_data_factory__tenantId": "tenantId",
+                "extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP,
+                "extra__azure_data_factory__factory_name": DEFAULT_FACTORY,
+            }
+        ),
+    )
+    connection_missing_tenant_id = Connection(
+        conn_id="azure_data_factory_missing_tenant_id",
+        conn_type="azure_data_factory",
+        login="clientId",
+        password="clientSecret",
+        extra=json.dumps(
+            {
+                "extra__azure_data_factory__subscriptionId": "subscriptionId",
+                "extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP,
+                "extra__azure_data_factory__factory_name": DEFAULT_FACTORY,
+            }
+        ),
+    )
+
+    db.merge_conn(connection_client_secret)
+    db.merge_conn(connection_default_credential)
+    db.merge_conn(connection_missing_subscription_id)
+    db.merge_conn(connection_missing_tenant_id)
+
+
+@fixture
+def hook():
+    client = AzureDataFactoryHook(azure_data_factory_conn_id=DEFAULT_CONNECTION_CLIENT_SECRET)
+    client._conn = MagicMock(
+        spec=[
+            "factories",
+            "linked_services",
+            "datasets",
+            "pipelines",
+            "pipeline_runs",
+            "triggers",
+            "trigger_runs",
+        ]
+    )
+
+    return client
+
+
+def parametrize(explicit_factory, implicit_factory):
+    def wrapper(func):
+        return pytest.mark.parametrize(
+            ("user_args", "sdk_args"),
+            (explicit_factory, implicit_factory),
+            ids=("explicit factory", "implicit factory"),
+        )(func)
+
+    return wrapper
+
+
+def test_provide_targeted_factory():
+    def echo(_, resource_group_name=None, factory_name=None):
+        return resource_group_name, factory_name
+
+    conn = MagicMock()
+    hook = MagicMock()
+    hook.get_connection.return_value = conn
+
+    conn.extra_dejson = {}
+    assert provide_targeted_factory(echo)(hook, RESOURCE_GROUP, FACTORY) == (RESOURCE_GROUP, FACTORY)
+
+    conn.extra_dejson = {
+        "extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP,
+        "extra__azure_data_factory__factory_name": DEFAULT_FACTORY,
+    }
+    assert provide_targeted_factory(echo)(hook) == (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY)
+    assert provide_targeted_factory(echo)(hook, RESOURCE_GROUP, None) == (RESOURCE_GROUP, DEFAULT_FACTORY)
+    assert provide_targeted_factory(echo)(hook, None, FACTORY) == (DEFAULT_RESOURCE_GROUP, FACTORY)
+    assert provide_targeted_factory(echo)(hook, None, None) == (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY)
+
+    with pytest.raises(AirflowException):
+        conn.extra_dejson = {}
+        provide_targeted_factory(echo)(hook)
+
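test_provide_targeted_factory above pins down the precedence rule: explicit arguments win, connection extras fill the gaps, and an AirflowException is raised when neither is available. A stripped-down decorator with that behaviour, shown only to make the logic under test concrete (this is a sketch, not the provider's implementation):

    import functools

    from airflow.exceptions import AirflowException

    def provide_targeted_factory_sketch(func):
        """Fill in resource_group_name/factory_name from connection extras when not passed explicitly."""

        @functools.wraps(func)
        def wrapper(hook, resource_group_name=None, factory_name=None, *args, **kwargs):
            extras = hook.get_connection(hook.conn_id).extra_dejson

            def resolve(explicit, extra_key, label):
                value = explicit or extras.get(extra_key)
                if not value:
                    raise AirflowException(f"Could not determine the targeted data factory {label}.")
                return value

            resource_group_name = resolve(
                resource_group_name, "extra__azure_data_factory__resource_group_name", "resource group"
            )
            factory_name = resolve(factory_name, "extra__azure_data_factory__factory_name", "factory")
            return func(hook, resource_group_name, factory_name, *args, **kwargs)

        return wrapper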
+
+@pytest.mark.parametrize(
+    ("connection_id", "credential_type"),
+    [
+        (DEFAULT_CONNECTION_CLIENT_SECRET, ClientSecretCredential),
+        (DEFAULT_CONNECTION_DEFAULT_CREDENTIAL, DefaultAzureCredential),
+    ],
+)
+def test_get_connection_by_credential_client_secret(connection_id: str, credential_type: Type):
+    hook = AzureDataFactoryHook(connection_id)
+
+    with patch.object(hook, "_create_client") as mock_create_client:
+        mock_create_client.return_value = MagicMock()
+        connection = hook.get_conn()
+        assert connection is not None
+        mock_create_client.assert_called_once()
+        assert isinstance(mock_create_client.call_args[0][0], credential_type)
+        assert mock_create_client.call_args[0][1] == "subscriptionId"
+
+
+@parametrize(
+    explicit_factory=((RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY)),
+    implicit_factory=((), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY)),
+)
+def test_get_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.get_factory(*user_args)
+
+    hook._conn.factories.get.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, MODEL)),
+    implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, MODEL)),
+)
+def test_create_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.create_factory(*user_args)
+
+    hook._conn.factories.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, MODEL)),
+    implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, MODEL)),
+)
+def test_update_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
+    with patch.object(hook, "_factory_exists") as mock_factory_exists:
+        mock_factory_exists.return_value = True
+        hook.update_factory(*user_args)
+
+    hook._conn.factories.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, MODEL)),
+    implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, MODEL)),
+)
+def test_update_factory_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
+    with patch.object(hook, "_factory_exists") as mock_factory_exists:
+        mock_factory_exists.return_value = False
+
+    with pytest.raises(AirflowException, match=r"Factory .+ does not exist"):
+        hook.update_factory(*user_args)
+
+
+@parametrize(
+    explicit_factory=((RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY)),
+    implicit_factory=((), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY)),
+)
+def test_delete_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.delete_factory(*user_args)
+
+    hook._conn.factories.delete.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_get_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.get_linked_service(*user_args)
+
+    hook._conn.linked_services.get.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_create_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.create_linked_service(*user_args)
+
+    hook._conn.linked_services.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args):
+    with patch.object(hook, "_linked_service_exists") as mock_linked_service_exists:
+        mock_linked_service_exists.return_value = True
+        hook.update_linked_service(*user_args)
+
+    hook._conn.linked_services.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_linked_service_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
+    with patch.object(hook, "_linked_service_exists") as mock_linked_service_exists:
+        mock_linked_service_exists.return_value = False
+
+    with pytest.raises(AirflowException, match=r"Linked service .+ does not exist"):
+        hook.update_linked_service(*user_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_delete_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.delete_linked_service(*user_args)
+
+    hook._conn.linked_services.delete.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_get_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.get_dataset(*user_args)
+
+    hook._conn.datasets.get.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_create_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.create_dataset(*user_args)
+
+    hook._conn.datasets.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
+    with patch.object(hook, "_dataset_exists") as mock_dataset_exists:
+        mock_dataset_exists.return_value = True
+        hook.update_dataset(*user_args)
+
+    hook._conn.datasets.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_dataset_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
+    with patch.object(hook, "_dataset_exists") as mock_dataset_exists:
+        mock_dataset_exists.return_value = False
+
+    with pytest.raises(AirflowException, match=r"Dataset .+ does not exist"):
+        hook.update_dataset(*user_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_delete_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.delete_dataset(*user_args)
+
+    hook._conn.datasets.delete.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_get_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.get_pipeline(*user_args)
+
+    hook._conn.pipelines.get.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_create_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.create_pipeline(*user_args)
+
+    hook._conn.pipelines.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
+    with patch.object(hook, "_pipeline_exists") as mock_pipeline_exists:
+        mock_pipeline_exists.return_value = True
+        hook.update_pipeline(*user_args)
+
+    hook._conn.pipelines.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_pipeline_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
+    with patch.object(hook, "_pipeline_exists") as mock_pipeline_exists:
+        mock_pipeline_exists.return_value = False
+
+    with pytest.raises(AirflowException, match=r"Pipeline .+ does not exist"):
+        hook.update_pipeline(*user_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_delete_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.delete_pipeline(*user_args)
+
+    hook._conn.pipelines.delete.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_run_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.run_pipeline(*user_args)
+
+    hook._conn.pipelines.create_run.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, ID)),
+    implicit_factory=((ID,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, ID)),
+)
+def test_get_pipeline_run(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.get_pipeline_run(*user_args)
+
+    hook._conn.pipeline_runs.get.assert_called_with(*sdk_args)
+
+
+_wait_for_pipeline_run_status_test_args = [
+    (AzureDataFactoryPipelineRunStatus.SUCCEEDED, AzureDataFactoryPipelineRunStatus.SUCCEEDED, True),
+    (AzureDataFactoryPipelineRunStatus.FAILED, AzureDataFactoryPipelineRunStatus.SUCCEEDED, False),
+    (AzureDataFactoryPipelineRunStatus.CANCELLED, AzureDataFactoryPipelineRunStatus.SUCCEEDED, False),
+    (AzureDataFactoryPipelineRunStatus.IN_PROGRESS, AzureDataFactoryPipelineRunStatus.SUCCEEDED, "timeout"),
+    (AzureDataFactoryPipelineRunStatus.QUEUED, AzureDataFactoryPipelineRunStatus.SUCCEEDED, "timeout"),
+    (AzureDataFactoryPipelineRunStatus.CANCELING, AzureDataFactoryPipelineRunStatus.SUCCEEDED, "timeout"),
+    (AzureDataFactoryPipelineRunStatus.SUCCEEDED, AzureDataFactoryPipelineRunStatus.TERMINAL_STATUSES, True),
+    (AzureDataFactoryPipelineRunStatus.FAILED, AzureDataFactoryPipelineRunStatus.TERMINAL_STATUSES, True),
+    (AzureDataFactoryPipelineRunStatus.CANCELLED, AzureDataFactoryPipelineRunStatus.TERMINAL_STATUSES, True),
+]
+
+
+@pytest.mark.parametrize(
+    argnames=("pipeline_run_status", "expected_status", "expected_output"),
+    argvalues=_wait_for_pipeline_run_status_test_args,
+    ids=[
+        f"run_status_{argval[0]}_expected_{argval[1]}"
+        if isinstance(argval[1], str)
+        else f"run_status_{argval[0]}_expected_AnyTerminalStatus"
+        for argval in _wait_for_pipeline_run_status_test_args
+    ],
+)
+def test_wait_for_pipeline_run_status(hook, pipeline_run_status, expected_status, expected_output):
+    config = {"run_id": ID, "timeout": 3, "check_interval": 1, "expected_statuses": expected_status}
+
+    with patch.object(AzureDataFactoryHook, "get_pipeline_run") as mock_pipeline_run:
+        mock_pipeline_run.return_value.status = pipeline_run_status
+
+        if expected_output != "timeout":
+            assert hook.wait_for_pipeline_run_status(**config) == expected_output
+        else:
+            with pytest.raises(AzureDataFactoryPipelineRunException):
+                hook.wait_for_pipeline_run_status(**config)
+
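The parametrisation above captures the contract: a terminal run status resolves to True or False against the expected statuses, while a non-terminal status keeps the call polling until the timeout raises (the provider raises AzureDataFactoryPipelineRunException; TimeoutError stands in below). A condensed sketch of that logic, assuming expected_statuses is a collection of status strings (illustrative only, not the hook's code):

    import time

    TERMINAL_STATUSES = {"Succeeded", "Failed", "Cancelled"}

    def wait_for_status(get_status, expected_statuses, timeout=60, check_interval=10):
        """Return True/False once a terminal status is seen; raise if nothing terminal arrives in time."""
        deadline = time.monotonic() + timeout
        while True:
            status = get_status()
            if status in expected_statuses:
                return True
            if status in TERMINAL_STATUSES:
                return False
            if time.monotonic() > deadline:
                raise TimeoutError("pipeline run did not reach an expected status before the timeout")
            time.sleep(check_interval)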
+
+@parametrize(
+    explicit_factory=((ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, ID)),
+    implicit_factory=((ID,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, ID)),
+)
+def test_cancel_pipeline_run(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.cancel_pipeline_run(*user_args)
+
+    hook._conn.pipeline_runs.cancel.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_get_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.get_trigger(*user_args)
+
+    hook._conn.triggers.get.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_create_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.create_trigger(*user_args)
+
+    hook._conn.triggers.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    with patch.object(hook, "_trigger_exists") as mock_trigger_exists:
+        mock_trigger_exists.return_value = True
+        hook.update_trigger(*user_args)
+
+    hook._conn.triggers.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_trigger_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
+    with patch.object(hook, "_trigger_exists") as mock_trigger_exists:
+        mock_trigger_exists.return_value = False
+
+    with pytest.raises(AirflowException, match=r"Trigger .+ does not exist"):
+        hook.update_trigger(*user_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_delete_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.delete_trigger(*user_args)
+
+    hook._conn.triggers.delete.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_start_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.start_trigger(*user_args)
+
+    hook._conn.triggers.begin_start.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_stop_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.stop_trigger(*user_args)
+
+    hook._conn.triggers.begin_stop.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, ID)),
+    implicit_factory=((NAME, ID), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, ID)),
+)
+def test_rerun_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.rerun_trigger(*user_args)
+
+    hook._conn.trigger_runs.rerun.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, ID)),
+    implicit_factory=((NAME, ID), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, ID)),
+)
+def test_cancel_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.cancel_trigger(*user_args)
+
+    hook._conn.trigger_runs.cancel.assert_called_with(*sdk_args)
+
+
+@pytest.mark.parametrize(
+    argnames="factory_list_result",
+    argvalues=[iter([FactoryListResponse]), iter([])],
+    ids=["factory_exists", "factory_does_not_exist"],
+)
+def test_connection_success(hook, factory_list_result):
+    hook.get_conn().factories.list.return_value = factory_list_result
+    status, msg = hook.test_connection()
+
+    assert status is True
+    assert msg == "Successfully connected to Azure Data Factory."
+
+
+def test_connection_failure(hook):
+    hook.get_conn().factories.list = PropertyMock(side_effect=Exception("Authentication failed."))
+    status, msg = hook.test_connection()
+
+    assert status is False
+    assert msg == "Authentication failed."
+
+
+def test_connection_failure_missing_subscription_id():
+    hook = AzureDataFactoryHook("azure_data_factory_missing_subscription_id")
+    status, msg = hook.test_connection()
+
+    assert status is False
+    assert msg == "A Subscription ID is required to connect to Azure Data Factory."
+
+
+def test_connection_failure_missing_tenant_id():
+    hook = AzureDataFactoryHook("azure_data_factory_missing_tenant_id")
+    status, msg = hook.test_connection()
+
+    assert status is False
+    assert msg == "A Tenant ID is required when authenticating with Client ID and Secret."

+ 133 - 0
Azure/test_azure_data_lake.py

@@ -0,0 +1,133 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import json
+import unittest
+from unittest import mock
+
+from airflow.models import Connection
+from airflow.utils import db
+
+
+class TestAzureDataLakeHook(unittest.TestCase):
+    def setUp(self):
+        db.merge_conn(
+            Connection(
+                conn_id='adl_test_key',
+                conn_type='azure_data_lake',
+                login='client_id',
+                password='client secret',
+                extra=json.dumps({"tenant": "tenant", "account_name": "accountname"}),
+            )
+        )
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.lib', autospec=True)
+    def test_conn(self, mock_lib):
+        from azure.datalake.store import core
+
+        from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
+
+        hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key')
+        assert hook._conn is None
+        assert hook.conn_id == 'adl_test_key'
+        assert isinstance(hook.get_conn(), core.AzureDLFileSystem)
+        assert mock_lib.auth.called
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem', autospec=True)
+    @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.lib', autospec=True)
+    def test_check_for_blob(self, mock_lib, mock_filesystem):
+        from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
+
+        hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key')
+        hook.check_for_file('file_path')
+        assert mock_filesystem.return_value.glob.called
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.multithread.ADLUploader', autospec=True)
+    @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.lib', autospec=True)
+    def test_upload_file(self, mock_lib, mock_uploader):
+        from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
+
+        hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key')
+        hook.upload_file(
+            local_path='tests/hooks/test_adl_hook.py',
+            remote_path='/test_adl_hook.py',
+            nthreads=64,
+            overwrite=True,
+            buffersize=4194304,
+            blocksize=4194304,
+        )
+        mock_uploader.assert_called_once_with(
+            hook.get_conn(),
+            lpath='tests/hooks/test_adl_hook.py',
+            rpath='/test_adl_hook.py',
+            nthreads=64,
+            overwrite=True,
+            buffersize=4194304,
+            blocksize=4194304,
+        )
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.multithread.ADLDownloader', autospec=True)
+    @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.lib', autospec=True)
+    def test_download_file(self, mock_lib, mock_downloader):
+        from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
+
+        hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key')
+        hook.download_file(
+            local_path='test_adl_hook.py',
+            remote_path='/test_adl_hook.py',
+            nthreads=64,
+            overwrite=True,
+            buffersize=4194304,
+            blocksize=4194304,
+        )
+        mock_downloader.assert_called_once_with(
+            hook.get_conn(),
+            lpath='test_adl_hook.py',
+            rpath='/test_adl_hook.py',
+            nthreads=64,
+            overwrite=True,
+            buffersize=4194304,
+            blocksize=4194304,
+        )
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem', autospec=True)
+    @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.lib', autospec=True)
+    def test_list_glob(self, mock_lib, mock_fs):
+        from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
+
+        hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key')
+        hook.list('file_path/*')
+        mock_fs.return_value.glob.assert_called_once_with('file_path/*')
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem', autospec=True)
+    @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.lib', autospec=True)
+    def test_list_walk(self, mock_lib, mock_fs):
+        from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
+
+        hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key')
+        hook.list('file_path/some_folder/')
+        mock_fs.return_value.walk.assert_called_once_with('file_path/some_folder/')
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem', autospec=True)
+    @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.lib', autospec=True)
+    def test_remove(self, mock_lib, mock_fs):
+        from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
+
+        hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key')
+        hook.remove('filepath', True)
+        mock_fs.return_value.remove.assert_called_once_with('filepath', recursive=True)

+ 257 - 0
Azure/test_azure_fileshare.py

@@ -0,0 +1,257 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+"""
+Tests for the Azure File Share integration.
+
+Azure File Share is a cloud variant of an SMB file share. Make sure that an Airflow connection of
+type `azure_fileshare` exists. Authorization can be done by supplying a login (=storage account name)
+and password (=storage account key), or a login and SAS token in the extra field
+(see connection `azure_fileshare_default` for an example).
+"""
+
+import json
+import unittest
+from unittest import mock
+
+import pytest
+from azure.storage.file import Directory, File
+
+from airflow.models import Connection
+from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook
+from airflow.utils import db
+
+
+class TestAzureFileshareHook(unittest.TestCase):
+    def setUp(self):
+        db.merge_conn(
+            Connection(
+                conn_id='azure_fileshare_test_key',
+                conn_type='azure_file_share',
+                login='login',
+                password='key',
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id='azure_fileshare_extras',
+                conn_type='azure_fileshare',
+                login='login',
+                extra=json.dumps(
+                    {
+                        'extra__azure_fileshare__sas_token': 'token',
+                        'extra__azure_fileshare__protocol': 'http',
+                    }
+                ),
+            )
+        )
+        db.merge_conn(
+            # Neither password nor sas_token present
+            Connection(
+                conn_id='azure_fileshare_missing_credentials',
+                conn_type='azure_fileshare',
+                login='login',
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id='azure_fileshare_extras_deprecated',
+                conn_type='azure_fileshare',
+                login='login',
+                extra=json.dumps(
+                    {
+                        'sas_token': 'token',
+                    }
+                ),
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id='azure_fileshare_extras_deprecated_empty_wasb_extra',
+                conn_type='azure_fileshare',
+                login='login',
+                password='password',
+                extra=json.dumps(
+                    {
+                        'extra__azure_fileshare__shared_access_key': '',
+                    }
+                ),
+            )
+        )
+
+        db.merge_conn(
+            Connection(
+                conn_id='azure_fileshare_extras_wrong',
+                conn_type='azure_fileshare',
+                login='login',
+                extra=json.dumps(
+                    {
+                        'wrong_key': 'token',
+                    }
+                ),
+            )
+        )
+
+    def test_key_and_connection(self):
+        from azure.storage.file import FileService
+
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_test_key')
+        assert hook.conn_id == 'azure_fileshare_test_key'
+        assert hook._conn is None
+        print(hook.get_conn())
+        assert isinstance(hook.get_conn(), FileService)
+
+    def test_sas_token(self):
+        from azure.storage.file import FileService
+
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras')
+        assert hook.conn_id == 'azure_fileshare_extras'
+        assert isinstance(hook.get_conn(), FileService)
+
+    def test_deprecated_sas_token(self):
+        from azure.storage.file import FileService
+
+        with pytest.warns(DeprecationWarning):
+            hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras_deprecated')
+            assert hook.conn_id == 'azure_fileshare_extras_deprecated'
+            assert isinstance(hook.get_conn(), FileService)
+
+    def test_deprecated_wasb_connection(self):
+        from azure.storage.file import FileService
+
+        with pytest.warns(DeprecationWarning):
+            hook = AzureFileShareHook(
+                azure_fileshare_conn_id='azure_fileshare_extras_deprecated_empty_wasb_extra'
+            )
+            assert hook.conn_id == 'azure_fileshare_extras_deprecated_empty_wasb_extra'
+            assert isinstance(hook.get_conn(), FileService)
+
+    def test_wrong_extras(self):
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras_wrong')
+        assert hook.conn_id == 'azure_fileshare_extras_wrong'
+        with pytest.raises(TypeError, match=".*wrong_key.*"):
+            hook.get_conn()
+
+    def test_missing_credentials(self):
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_missing_credentials')
+        assert hook.conn_id == 'azure_fileshare_missing_credentials'
+        with pytest.raises(ValueError, match=".*account_key or sas_token.*"):
+            hook.get_conn()
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True)
+    def test_check_for_file(self, mock_service):
+        mock_instance = mock_service.return_value
+        mock_instance.exists.return_value = True
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras')
+        assert hook.check_for_file('share', 'directory', 'file', timeout=3)
+        mock_instance.exists.assert_called_once_with('share', 'directory', 'file', timeout=3)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True)
+    def test_check_for_directory(self, mock_service):
+        mock_instance = mock_service.return_value
+        mock_instance.exists.return_value = True
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras')
+        assert hook.check_for_directory('share', 'directory', timeout=3)
+        mock_instance.exists.assert_called_once_with('share', 'directory', timeout=3)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True)
+    def test_load_file(self, mock_service):
+        mock_instance = mock_service.return_value
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras')
+        hook.load_file('path', 'share', 'directory', 'file', max_connections=1)
+        mock_instance.create_file_from_path.assert_called_once_with(
+            'share', 'directory', 'file', 'path', max_connections=1
+        )
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True)
+    def test_load_string(self, mock_service):
+        mock_instance = mock_service.return_value
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras')
+        hook.load_string('big string', 'share', 'directory', 'file', timeout=1)
+        mock_instance.create_file_from_text.assert_called_once_with(
+            'share', 'directory', 'file', 'big string', timeout=1
+        )
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True)
+    def test_load_stream(self, mock_service):
+        mock_instance = mock_service.return_value
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras')
+        hook.load_stream('stream', 'share', 'directory', 'file', 42, timeout=1)
+        mock_instance.create_file_from_stream.assert_called_once_with(
+            'share', 'directory', 'file', 'stream', 42, timeout=1
+        )
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True)
+    def test_list_directories_and_files(self, mock_service):
+        mock_instance = mock_service.return_value
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras')
+        hook.list_directories_and_files('share', 'directory', timeout=1)
+        mock_instance.list_directories_and_files.assert_called_once_with('share', 'directory', timeout=1)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True)
+    def test_list_files(self, mock_service):
+        mock_instance = mock_service.return_value
+        mock_instance.list_directories_and_files.return_value = [
+            File("file1"),
+            File("file2"),
+            Directory("dir1"),
+            Directory("dir2"),
+        ]
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras')
+        files = hook.list_files('share', 'directory', timeout=1)
+        assert files == ["file1", "file2"]
+        mock_instance.list_directories_and_files.assert_called_once_with('share', 'directory', timeout=1)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True)
+    def test_create_directory(self, mock_service):
+        mock_instance = mock_service.return_value
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras')
+        hook.create_directory('share', 'directory', timeout=1)
+        mock_instance.create_directory.assert_called_once_with('share', 'directory', timeout=1)
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True)
+    def test_get_file(self, mock_service):
+        mock_instance = mock_service.return_value
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras')
+        hook.get_file('path', 'share', 'directory', 'file', max_connections=1)
+        mock_instance.get_file_to_path.assert_called_once_with(
+            'share', 'directory', 'file', 'path', max_connections=1
+        )
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True)
+    def test_get_file_to_stream(self, mock_service):
+        mock_instance = mock_service.return_value
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras')
+        hook.get_file_to_stream('stream', 'share', 'directory', 'file', max_connections=1)
+        mock_instance.get_file_to_stream.assert_called_once_with(
+            'share', 'directory', 'file', 'stream', max_connections=1
+        )
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True)
+    def test_create_share(self, mock_service):
+        mock_instance = mock_service.return_value
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras')
+        hook.create_share('my_share')
+        mock_instance.create_share.assert_called_once_with('my_share')
+
+    @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True)
+    def test_delete_share(self, mock_service):
+        mock_instance = mock_service.return_value
+        hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras')
+        hook.delete_share('my_share')
+        mock_instance.delete_share.assert_called_once_with('my_share')

+ 120 - 0
Azure/test_azure_fileshare_to_gcs.py

@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from unittest import mock
+
+from airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs import AzureFileShareToGCSOperator
+
+TASK_ID = 'test-azure-fileshare-to-gcs'
+AZURE_FILESHARE_SHARE = 'test-share'
+AZURE_FILESHARE_DIRECTORY_NAME = '/path/to/dir'
+GCS_PATH_PREFIX = 'gs://gcs-bucket/data/'
+MOCK_FILES = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]
+AZURE_FILESHARE_CONN_ID = 'azure_fileshare_default'
+GCS_CONN_ID = 'google_cloud_default'
+IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
+
+
+class TestAzureFileShareToGCSOperator(unittest.TestCase):
+    def test_init(self):
+        """Test AzureFileShareToGCSOperator instance is properly initialized."""
+
+        operator = AzureFileShareToGCSOperator(
+            task_id=TASK_ID,
+            share_name=AZURE_FILESHARE_SHARE,
+            directory_name=AZURE_FILESHARE_DIRECTORY_NAME,
+            azure_fileshare_conn_id=AZURE_FILESHARE_CONN_ID,
+            gcp_conn_id=GCS_CONN_ID,
+            dest_gcs=GCS_PATH_PREFIX,
+            google_impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        assert operator.task_id == TASK_ID
+        assert operator.share_name == AZURE_FILESHARE_SHARE
+        assert operator.directory_name == AZURE_FILESHARE_DIRECTORY_NAME
+        assert operator.azure_fileshare_conn_id == AZURE_FILESHARE_CONN_ID
+        assert operator.gcp_conn_id == GCS_CONN_ID
+        assert operator.dest_gcs == GCS_PATH_PREFIX
+        assert operator.google_impersonation_chain == IMPERSONATION_CHAIN
+
+    @mock.patch('airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs.AzureFileShareHook')
+    @mock.patch('airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs.GCSHook')
+    def test_execute(self, gcs_mock_hook, azure_fileshare_mock_hook):
+        """Test the execute function when the run is successful."""
+
+        operator = AzureFileShareToGCSOperator(
+            task_id=TASK_ID,
+            share_name=AZURE_FILESHARE_SHARE,
+            directory_name=AZURE_FILESHARE_DIRECTORY_NAME,
+            azure_fileshare_conn_id=AZURE_FILESHARE_CONN_ID,
+            gcp_conn_id=GCS_CONN_ID,
+            dest_gcs=GCS_PATH_PREFIX,
+            google_impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        azure_fileshare_mock_hook.return_value.list_files.return_value = MOCK_FILES
+
+        uploaded_files = operator.execute(None)
+
+        gcs_mock_hook.return_value.upload.assert_has_calls(
+            [
+                mock.call('gcs-bucket', 'data/TEST1.csv', mock.ANY, gzip=False),
+                mock.call('gcs-bucket', 'data/TEST3.csv', mock.ANY, gzip=False),
+                mock.call('gcs-bucket', 'data/TEST2.csv', mock.ANY, gzip=False),
+            ],
+            any_order=True,
+        )
+
+        azure_fileshare_mock_hook.assert_called_once_with(AZURE_FILESHARE_CONN_ID)
+
+        gcs_mock_hook.assert_called_once_with(
+            gcp_conn_id=GCS_CONN_ID,
+            delegate_to=None,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        assert sorted(MOCK_FILES) == sorted(uploaded_files)
+
+    @mock.patch('airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs.AzureFileShareHook')
+    @mock.patch('airflow.providers.google.cloud.transfers.azure_fileshare_to_gcs.GCSHook')
+    def test_execute_with_gzip(self, gcs_mock_hook, azure_fileshare_mock_hook):
+        """Test the execute function when the run is successful."""
+
+        operator = AzureFileShareToGCSOperator(
+            task_id=TASK_ID,
+            share_name=AZURE_FILESHARE_SHARE,
+            directory_name=AZURE_FILESHARE_DIRECTORY_NAME,
+            azure_fileshare_conn_id=AZURE_FILESHARE_CONN_ID,
+            gcp_conn_id=GCS_CONN_ID,
+            dest_gcs=GCS_PATH_PREFIX,
+            google_impersonation_chain=IMPERSONATION_CHAIN,
+            gzip=True,
+        )
+
+        azure_fileshare_mock_hook.return_value.list_files.return_value = MOCK_FILES
+
+        operator.execute(None)
+
+        gcs_mock_hook.return_value.upload.assert_has_calls(
+            [
+                mock.call('gcs-bucket', 'data/TEST1.csv', mock.ANY, gzip=True),
+                mock.call('gcs-bucket', 'data/TEST3.csv', mock.ANY, gzip=True),
+                mock.call('gcs-bucket', 'data/TEST2.csv', mock.ANY, gzip=True),
+            ],
+            any_order=True,
+        )

+ 1427 - 0
Azure/test_azure_helper.py

@@ -0,0 +1,1427 @@
+# This file is part of cloud-init. See LICENSE file for license information.
+
+import copy
+import os
+import re
+import unittest
+from textwrap import dedent
+from xml.etree import ElementTree
+from xml.sax.saxutils import escape, unescape
+
+from cloudinit.sources.helpers import azure as azure_helper
+from cloudinit.tests.helpers import CiTestCase, ExitStack, mock, populate_dir
+
+from cloudinit.util import load_file
+from cloudinit.sources.helpers.azure import WALinuxAgentShim as wa_shim
+
+GOAL_STATE_TEMPLATE = """\
+<?xml version="1.0" encoding="utf-8"?>
+<GoalState xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:noNamespaceSchemaLocation="goalstate10.xsd">
+  <Version>2012-11-30</Version>
+  <Incarnation>{incarnation}</Incarnation>
+  <Machine>
+    <ExpectedState>Started</ExpectedState>
+    <StopRolesDeadlineHint>300000</StopRolesDeadlineHint>
+    <LBProbePorts>
+      <Port>16001</Port>
+    </LBProbePorts>
+    <ExpectHealthReport>FALSE</ExpectHealthReport>
+  </Machine>
+  <Container>
+    <ContainerId>{container_id}</ContainerId>
+    <RoleInstanceList>
+      <RoleInstance>
+        <InstanceId>{instance_id}</InstanceId>
+        <State>Started</State>
+        <Configuration>
+          <HostingEnvironmentConfig>
+            http://100.86.192.70:80/...hostingEnvironmentConfig...
+          </HostingEnvironmentConfig>
+          <SharedConfig>http://100.86.192.70:80/..SharedConfig..</SharedConfig>
+          <ExtensionsConfig>
+            http://100.86.192.70:80/...extensionsConfig...
+          </ExtensionsConfig>
+          <FullConfig>http://100.86.192.70:80/...fullConfig...</FullConfig>
+          <Certificates>{certificates_url}</Certificates>
+          <ConfigName>68ce47.0.68ce47.0.utl-trusty--292258.1.xml</ConfigName>
+        </Configuration>
+      </RoleInstance>
+    </RoleInstanceList>
+  </Container>
+</GoalState>
+"""
+
+HEALTH_REPORT_XML_TEMPLATE = '''\
+<?xml version="1.0" encoding="utf-8"?>
+<Health xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xmlns:xsd="http://www.w3.org/2001/XMLSchema">
+  <GoalStateIncarnation>{incarnation}</GoalStateIncarnation>
+  <Container>
+    <ContainerId>{container_id}</ContainerId>
+    <RoleInstanceList>
+      <Role>
+        <InstanceId>{instance_id}</InstanceId>
+        <Health>
+          <State>{health_status}</State>
+          {health_detail_subsection}
+        </Health>
+      </Role>
+    </RoleInstanceList>
+  </Container>
+</Health>
+'''
+
+HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE = dedent('''\
+    <Details>
+      <SubStatus>{health_substatus}</SubStatus>
+      <Description>{health_description}</Description>
+    </Details>
+    ''')
+
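+# Trim limit that the tests below expect for the health report's Description field.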
+HEALTH_REPORT_DESCRIPTION_TRIM_LEN = 512
+
+
+class SentinelException(Exception):
+    pass
+
+
+class TestFindEndpoint(CiTestCase):
+
+    def setUp(self):
+        super(TestFindEndpoint, self).setUp()
+        patches = ExitStack()
+        self.addCleanup(patches.close)
+
+        self.load_file = patches.enter_context(
+            mock.patch.object(azure_helper.util, 'load_file'))
+
+        self.dhcp_options = patches.enter_context(
+            mock.patch.object(wa_shim, '_load_dhclient_json'))
+
+        self.networkd_leases = patches.enter_context(
+            mock.patch.object(wa_shim, '_networkd_get_value_from_leases'))
+        self.networkd_leases.return_value = None
+
+    def test_missing_file(self):
+        """wa_shim find_endpoint uses default endpoint if leasefile not found
+        """
+        self.assertEqual(wa_shim.find_endpoint(), "168.63.129.16")
+
+    def test_missing_special_azure_line(self):
+        """wa_shim find_endpoint uses default endpoint if leasefile is found
+        but does not contain DHCP Option 245 (whose value is the endpoint)
+        """
+        self.load_file.return_value = ''
+        self.dhcp_options.return_value = {'eth0': {'key': 'value'}}
+        self.assertEqual(wa_shim.find_endpoint(), "168.63.129.16")
+
+    @staticmethod
+    def _build_lease_content(encoded_address):
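+        # Resolve the dhclient name for DHCP option 245, whose value encodes
+        # the endpoint address.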
+        endpoint = azure_helper._get_dhcp_endpoint_option_name()
+        return '\n'.join([
+            'lease {',
+            ' interface "eth0";',
+            ' option {0} {1};'.format(endpoint, encoded_address),
+            '}'])
+
+    def test_from_dhcp_client(self):
+        self.dhcp_options.return_value = {"eth0": {"unknown_245": "5:4:3:2"}}
+        self.assertEqual('5.4.3.2', wa_shim.find_endpoint(None))
+
+    @mock.patch('cloudinit.sources.helpers.azure.util.is_FreeBSD')
+    def test_latest_lease_used(self, m_is_freebsd):
+        m_is_freebsd.return_value = False  # To avoid hitting load_file
+        encoded_addresses = ['5:4:3:2', '4:3:2:1']
+        file_content = '\n'.join([self._build_lease_content(encoded_address)
+                                  for encoded_address in encoded_addresses])
+        self.load_file.return_value = file_content
+        self.assertEqual(encoded_addresses[-1].replace(':', '.'),
+                         wa_shim.find_endpoint("foobar"))
+
+
+class TestExtractIpAddressFromLeaseValue(CiTestCase):
+
+    def test_hex_string(self):
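+        # '62:4c:36:20' is the colon-separated hex encoding of 98.76.54.32.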
+        ip_address, encoded_address = '98.76.54.32', '62:4c:36:20'
+        self.assertEqual(
+            ip_address, wa_shim.get_ip_from_lease_value(encoded_address))
+
+    def test_hex_string_with_single_character_part(self):
+        ip_address, encoded_address = '4.3.2.1', '4:3:2:1'
+        self.assertEqual(
+            ip_address, wa_shim.get_ip_from_lease_value(encoded_address))
+
+    def test_packed_string(self):
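+        # 'bL6 ' decodes byte-wise: b=98, L=76, 6=54, space=32.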
+        ip_address, encoded_address = '98.76.54.32', 'bL6 '
+        self.assertEqual(
+            ip_address, wa_shim.get_ip_from_lease_value(encoded_address))
+
+    def test_packed_string_with_escaped_quote(self):
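+        # After dropping the backslash escape, the bytes decode as
+        # d=100, H=72, "=34, l=108.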
+        ip_address, encoded_address = '100.72.34.108', 'dH\\"l'
+        self.assertEqual(
+            ip_address, wa_shim.get_ip_from_lease_value(encoded_address))
+
+    def test_packed_string_containing_a_colon(self):
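+        # A literal ':' (58) inside the packed value must not be mistaken
+        # for a hex-notation separator.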
+        ip_address, encoded_address = '100.72.58.108', 'dH:l'
+        self.assertEqual(
+            ip_address, wa_shim.get_ip_from_lease_value(encoded_address))
+
+
+class TestGoalStateParsing(CiTestCase):
+
+    default_parameters = {
+        'incarnation': 1,
+        'container_id': 'MyContainerId',
+        'instance_id': 'MyInstanceId',
+        'certificates_url': 'MyCertificatesUrl',
+    }
+
+    def _get_formatted_goal_state_xml_string(self, **kwargs):
+        parameters = self.default_parameters.copy()
+        parameters.update(kwargs)
+        xml = GOAL_STATE_TEMPLATE.format(**parameters)
+        if parameters['certificates_url'] is None:
+            new_xml_lines = []
+            for line in xml.splitlines():
+                if 'Certificates' in line:
+                    continue
+                new_xml_lines.append(line)
+            xml = '\n'.join(new_xml_lines)
+        return xml
+
+    def _get_goal_state(self, m_azure_endpoint_client=None, **kwargs):
+        if m_azure_endpoint_client is None:
+            m_azure_endpoint_client = mock.MagicMock()
+        xml = self._get_formatted_goal_state_xml_string(**kwargs)
+        return azure_helper.GoalState(xml, m_azure_endpoint_client)
+
+    def test_incarnation_parsed_correctly(self):
+        incarnation = '123'
+        goal_state = self._get_goal_state(incarnation=incarnation)
+        self.assertEqual(incarnation, goal_state.incarnation)
+
+    def test_container_id_parsed_correctly(self):
+        container_id = 'TestContainerId'
+        goal_state = self._get_goal_state(container_id=container_id)
+        self.assertEqual(container_id, goal_state.container_id)
+
+    def test_instance_id_parsed_correctly(self):
+        instance_id = 'TestInstanceId'
+        goal_state = self._get_goal_state(instance_id=instance_id)
+        self.assertEqual(instance_id, goal_state.instance_id)
+
+    def test_instance_id_byte_swap(self):
+        """Return true when previous_iid is byteswapped current_iid"""
+        previous_iid = "D0DF4C54-4ECB-4A4B-9954-5BDF3ED5C3B8"
+        current_iid = "544CDFD0-CB4E-4B4A-9954-5BDF3ED5C3B8"
+        self.assertTrue(
+            azure_helper.is_byte_swapped(previous_iid, current_iid))
+
+    def test_instance_id_no_byte_swap_same_instance_id(self):
+        previous_iid = "D0DF4C54-4ECB-4A4B-9954-5BDF3ED5C3B8"
+        current_iid = "D0DF4C54-4ECB-4A4B-9954-5BDF3ED5C3B8"
+        self.assertFalse(
+            azure_helper.is_byte_swapped(previous_iid, current_iid))
+
+    def test_instance_id_no_byte_swap_diff_instance_id(self):
+        previous_iid = "D0DF4C54-4ECB-4A4B-9954-5BDF3ED5C3B8"
+        current_iid = "G0DF4C54-4ECB-4A4B-9954-5BDF3ED5C3B8"
+        self.assertFalse(
+            azure_helper.is_byte_swapped(previous_iid, current_iid))
+
+    def test_certificates_xml_parsed_and_fetched_correctly(self):
+        m_azure_endpoint_client = mock.MagicMock()
+        certificates_url = 'TestCertificatesUrl'
+        goal_state = self._get_goal_state(
+            m_azure_endpoint_client=m_azure_endpoint_client,
+            certificates_url=certificates_url)
+        certificates_xml = goal_state.certificates_xml
+        self.assertEqual(1, m_azure_endpoint_client.get.call_count)
+        self.assertEqual(
+            certificates_url,
+            m_azure_endpoint_client.get.call_args[0][0])
+        self.assertTrue(
+            m_azure_endpoint_client.get.call_args[1].get(
+                'secure', False))
+        self.assertEqual(
+            m_azure_endpoint_client.get.return_value.contents,
+            certificates_xml)
+
+    def test_missing_certificates_skips_http_get(self):
+        m_azure_endpoint_client = mock.MagicMock()
+        goal_state = self._get_goal_state(
+            m_azure_endpoint_client=m_azure_endpoint_client,
+            certificates_url=None)
+        certificates_xml = goal_state.certificates_xml
+        self.assertEqual(0, m_azure_endpoint_client.get.call_count)
+        self.assertIsNone(certificates_xml)
+
+    def test_invalid_goal_state_xml_raises_parse_error(self):
+        xml = 'random non-xml data'
+        with self.assertRaises(ElementTree.ParseError):
+            azure_helper.GoalState(xml, mock.MagicMock())
+
+    def test_missing_container_id_in_goal_state_xml_raises_exc(self):
+        xml = self._get_formatted_goal_state_xml_string()
+        xml = re.sub('<ContainerId>.*</ContainerId>', '', xml)
+        with self.assertRaises(azure_helper.InvalidGoalStateXMLException):
+            azure_helper.GoalState(xml, mock.MagicMock())
+
+    def test_missing_instance_id_in_goal_state_xml_raises_exc(self):
+        xml = self._get_formatted_goal_state_xml_string()
+        xml = re.sub('<InstanceId>.*</InstanceId>', '', xml)
+        with self.assertRaises(azure_helper.InvalidGoalStateXMLException):
+            azure_helper.GoalState(xml, mock.MagicMock())
+
+    def test_missing_incarnation_in_goal_state_xml_raises_exc(self):
+        xml = self._get_formatted_goal_state_xml_string()
+        xml = re.sub('<Incarnation>.*</Incarnation>', '', xml)
+        with self.assertRaises(azure_helper.InvalidGoalStateXMLException):
+            azure_helper.GoalState(xml, mock.MagicMock())
+
+
+class TestAzureEndpointHttpClient(CiTestCase):
+
+    regular_headers = {
+        'x-ms-agent-name': 'WALinuxAgent',
+        'x-ms-version': '2012-11-30',
+    }
+
+    def setUp(self):
+        super(TestAzureEndpointHttpClient, self).setUp()
+        patches = ExitStack()
+        self.addCleanup(patches.close)
+        self.m_http_with_retries = patches.enter_context(
+            mock.patch.object(azure_helper, 'http_with_retries'))
+
+    def test_non_secure_get(self):
+        client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+        url = 'MyTestUrl'
+        response = client.get(url, secure=False)
+        self.assertEqual(1, self.m_http_with_retries.call_count)
+        self.assertEqual(self.m_http_with_retries.return_value, response)
+        self.assertEqual(
+            mock.call(url, headers=self.regular_headers),
+            self.m_http_with_retries.call_args)
+
+    def test_non_secure_get_raises_exception(self):
+        client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+        url = 'MyTestUrl'
+        self.m_http_with_retries.side_effect = SentinelException
+        self.assertRaises(SentinelException, client.get, url, secure=False)
+        self.assertEqual(1, self.m_http_with_retries.call_count)
+
+    def test_secure_get(self):
+        url = 'MyTestUrl'
+        m_certificate = mock.MagicMock()
+        expected_headers = self.regular_headers.copy()
+        expected_headers.update({
+            "x-ms-cipher-name": "DES_EDE3_CBC",
+            "x-ms-guest-agent-public-x509-cert": m_certificate,
+        })
+        client = azure_helper.AzureEndpointHttpClient(m_certificate)
+        response = client.get(url, secure=True)
+        self.assertEqual(1, self.m_http_with_retries.call_count)
+        self.assertEqual(self.m_http_with_retries.return_value, response)
+        self.assertEqual(
+            mock.call(url, headers=expected_headers),
+            self.m_http_with_retries.call_args)
+
+    def test_secure_get_raises_exception(self):
+        url = 'MyTestUrl'
+        client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+        self.m_http_with_retries.side_effect = SentinelException
+        self.assertRaises(SentinelException, client.get, url, secure=True)
+        self.assertEqual(1, self.m_http_with_retries.call_count)
+
+    def test_post(self):
+        m_data = mock.MagicMock()
+        url = 'MyTestUrl'
+        client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+        response = client.post(url, data=m_data)
+        self.assertEqual(1, self.m_http_with_retries.call_count)
+        self.assertEqual(self.m_http_with_retries.return_value, response)
+        self.assertEqual(
+            mock.call(url, data=m_data, headers=self.regular_headers),
+            self.m_http_with_retries.call_args)
+
+    def test_post_raises_exception(self):
+        m_data = mock.MagicMock()
+        url = 'MyTestUrl'
+        client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+        self.m_http_with_retries.side_effect = SentinelException
+        self.assertRaises(SentinelException, client.post, url, data=m_data)
+        self.assertEqual(1, self.m_http_with_retries.call_count)
+
+    def test_post_with_extra_headers(self):
+        url = 'MyTestUrl'
+        client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+        extra_headers = {'test': 'header'}
+        client.post(url, extra_headers=extra_headers)
+        expected_headers = self.regular_headers.copy()
+        expected_headers.update(extra_headers)
+        self.assertEqual(1, self.m_http_with_retries.call_count)
+        self.assertEqual(
+            mock.call(url, data=mock.ANY, headers=expected_headers),
+            self.m_http_with_retries.call_args)
+
+    def test_post_with_sleep_with_extra_headers_raises_exception(self):
+        m_data = mock.MagicMock()
+        url = 'MyTestUrl'
+        extra_headers = {'test': 'header'}
+        client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+        self.m_http_with_retries.side_effect = SentinelException
+        self.assertRaises(
+            SentinelException, client.post,
+            url, data=m_data, extra_headers=extra_headers)
+        self.assertEqual(1, self.m_http_with_retries.call_count)
+
+
+class TestAzureHelperHttpWithRetries(CiTestCase):
+
+    with_logs = True
+
+    max_readurl_attempts = 240
+    default_readurl_timeout = 5
+    periodic_logging_attempts = 12
+
+    def setUp(self):
+        super(TestAzureHelperHttpWithRetries, self).setUp()
+        patches = ExitStack()
+        self.addCleanup(patches.close)
+
+        self.m_readurl = patches.enter_context(
+            mock.patch.object(
+                azure_helper.url_helper, 'readurl', mock.MagicMock()))
+        patches.enter_context(
+            mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock()))
+
+    def test_http_with_retries(self):
+        self.m_readurl.return_value = 'TestResp'
+        self.assertEqual(
+            azure_helper.http_with_retries('testurl'),
+            self.m_readurl.return_value)
+        self.assertEqual(self.m_readurl.call_count, 1)
+
+    def test_http_with_retries_propagates_readurl_exc_and_logs_exc(
+            self):
+        self.m_readurl.side_effect = SentinelException
+
+        self.assertRaises(
+            SentinelException, azure_helper.http_with_retries, 'testurl')
+        self.assertEqual(self.m_readurl.call_count, self.max_readurl_attempts)
+
+        self.assertIsNotNone(
+            re.search(
+                r'Failed HTTP request with Azure endpoint \S* during '
+                r'attempt \d+ with exception: \S*',
+                self.logs.getvalue()))
+        self.assertIsNone(
+            re.search(
+                r'Successful HTTP request with Azure endpoint \S* after '
+                r'\d+ attempts',
+                self.logs.getvalue()))
+
+    def test_http_with_retries_delayed_success_due_to_temporary_readurl_exc(
+            self):
+        self.m_readurl.side_effect = \
+            [SentinelException] * self.periodic_logging_attempts + \
+            ['TestResp']
+        self.m_readurl.return_value = 'TestResp'
+
+        response = azure_helper.http_with_retries('testurl')
+        self.assertEqual(
+            response,
+            self.m_readurl.return_value)
+        self.assertEqual(
+            self.m_readurl.call_count,
+            self.periodic_logging_attempts + 1)
+
+    def test_http_with_retries_long_delay_logs_periodic_failure_msg(self):
+        self.m_readurl.side_effect = \
+            [SentinelException] * self.periodic_logging_attempts + \
+            ['TestResp']
+        self.m_readurl.return_value = 'TestResp'
+
+        azure_helper.http_with_retries('testurl')
+
+        self.assertEqual(
+            self.m_readurl.call_count,
+            self.periodic_logging_attempts + 1)
+        self.assertIsNotNone(
+            re.search(
+                r'Failed HTTP request with Azure endpoint \S* during '
+                r'attempt \d+ with exception: \S*',
+                self.logs.getvalue()))
+        self.assertIsNotNone(
+            re.search(
+                r'Successful HTTP request with Azure endpoint \S* after '
+                r'\d+ attempts',
+                self.logs.getvalue()))
+
+    def test_http_with_retries_short_delay_does_not_log_periodic_failure_msg(
+            self):
+        self.m_readurl.side_effect = \
+            [SentinelException] * \
+            (self.periodic_logging_attempts - 1) + \
+            ['TestResp']
+        self.m_readurl.return_value = 'TestResp'
+
+        azure_helper.http_with_retries('testurl')
+        self.assertEqual(
+            self.m_readurl.call_count,
+            self.periodic_logging_attempts)
+
+        self.assertIsNone(
+            re.search(
+                r'Failed HTTP request with Azure endpoint \S* during '
+                r'attempt \d+ with exception: \S*',
+                self.logs.getvalue()))
+        self.assertIsNotNone(
+            re.search(
+                r'Successful HTTP request with Azure endpoint \S* after '
+                r'\d+ attempts',
+                self.logs.getvalue()))
+
+    def test_http_with_retries_calls_url_helper_readurl_with_args_kwargs(self):
+        testurl = mock.MagicMock()
+        kwargs = {
+            'headers': mock.MagicMock(),
+            'data': mock.MagicMock(),
+            # timeout kwarg should not be modified or deleted if present
+            'timeout': mock.MagicMock()
+        }
+        azure_helper.http_with_retries(testurl, **kwargs)
+        self.m_readurl.assert_called_once_with(testurl, **kwargs)
+
+    def test_http_with_retries_adds_timeout_kwarg_if_not_present(self):
+        testurl = mock.MagicMock()
+        kwargs = {
+            'headers': mock.MagicMock(),
+            'data': mock.MagicMock()
+        }
+        expected_kwargs = copy.deepcopy(kwargs)
+        expected_kwargs['timeout'] = self.default_readurl_timeout
+
+        azure_helper.http_with_retries(testurl, **kwargs)
+        self.m_readurl.assert_called_once_with(testurl, **expected_kwargs)
+
+    def test_http_with_retries_deletes_retries_kwargs_passed_in(
+            self):
+        """http_with_retries already implements retry logic,
+        so url_helper.readurl should not have retries.
+        http_with_retries should delete kwargs that
+        cause url_helper.readurl to retry.
+        """
+        testurl = mock.MagicMock()
+        kwargs = {
+            'headers': mock.MagicMock(),
+            'data': mock.MagicMock(),
+            'timeout': mock.MagicMock(),
+            'retries': mock.MagicMock(),
+            'infinite': mock.MagicMock()
+        }
+        expected_kwargs = copy.deepcopy(kwargs)
+        expected_kwargs.pop('retries', None)
+        expected_kwargs.pop('infinite', None)
+
+        azure_helper.http_with_retries(testurl, **kwargs)
+        self.m_readurl.assert_called_once_with(testurl, **expected_kwargs)
+        self.assertIn(
+            'retries kwarg passed in for communication with Azure endpoint.',
+            self.logs.getvalue())
+        self.assertIn(
+            'infinite kwarg passed in for communication with Azure endpoint.',
+            self.logs.getvalue())
+
+
+class TestOpenSSLManager(CiTestCase):
+
+    def setUp(self):
+        super(TestOpenSSLManager, self).setUp()
+        patches = ExitStack()
+        self.addCleanup(patches.close)
+
+        self.subp = patches.enter_context(
+            mock.patch.object(azure_helper.subp, 'subp'))
+        try:
+            self.open = patches.enter_context(
+                mock.patch('__builtin__.open'))
+        except ImportError:
+            self.open = patches.enter_context(
+                mock.patch('builtins.open'))
+
+    @mock.patch.object(azure_helper, 'cd', mock.MagicMock())
+    @mock.patch.object(azure_helper.temp_utils, 'mkdtemp')
+    def test_openssl_manager_creates_a_tmpdir(self, mkdtemp):
+        manager = azure_helper.OpenSSLManager()
+        self.assertEqual(mkdtemp.return_value, manager.tmpdir)
+
+    def test_generate_certificate_uses_tmpdir(self):
+        subp_directory = {}
+
+        def capture_directory(*args, **kwargs):
+            subp_directory['path'] = os.getcwd()
+
+        self.subp.side_effect = capture_directory
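+        # capture_directory records the cwd at the moment OpenSSLManager
+        # invokes subp, so the assert below checks the command ran from
+        # inside the manager's tmpdir.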
+        manager = azure_helper.OpenSSLManager()
+        self.assertEqual(manager.tmpdir, subp_directory['path'])
+        manager.clean_up()
+
+    @mock.patch.object(azure_helper, 'cd', mock.MagicMock())
+    @mock.patch.object(azure_helper.temp_utils, 'mkdtemp', mock.MagicMock())
+    @mock.patch.object(azure_helper.util, 'del_dir')
+    def test_clean_up(self, del_dir):
+        manager = azure_helper.OpenSSLManager()
+        manager.clean_up()
+        self.assertEqual([mock.call(manager.tmpdir)], del_dir.call_args_list)
+
+
+class TestOpenSSLManagerActions(CiTestCase):
+
+    def setUp(self):
+        super(TestOpenSSLManagerActions, self).setUp()
+
+        self.allowed_subp = True
+
+    def _data_file(self, name):
+        path = 'tests/data/azure'
+        return os.path.join(path, name)
+
+    @unittest.skip("todo move to cloud_test")
+    def test_pubkey_extract(self):
+        cert = load_file(self._data_file('pubkey_extract_cert'))
+        good_key = load_file(self._data_file('pubkey_extract_ssh_key'))
+        sslmgr = azure_helper.OpenSSLManager()
+        key = sslmgr._get_ssh_key_from_cert(cert)
+        self.assertEqual(good_key, key)
+
+        good_fingerprint = '073E19D14D1C799224C6A0FD8DDAB6A8BF27D473'
+        fingerprint = sslmgr._get_fingerprint_from_cert(cert)
+        self.assertEqual(good_fingerprint, fingerprint)
+
+    @unittest.skip("todo move to cloud_test")
+    @mock.patch.object(azure_helper.OpenSSLManager, '_decrypt_certs_from_xml')
+    def test_parse_certificates(self, mock_decrypt_certs):
+        """Azure control plane puts private keys as well as certificates
+           into the Certificates XML object. Make sure only the public keys
+           from certs are extracted and that fingerprints are converted to
+           the form specified in the ovf-env.xml file.
+        """
+        cert_contents = load_file(self._data_file('parse_certificates_pem'))
+        fingerprints = load_file(self._data_file(
+            'parse_certificates_fingerprints')
+        ).splitlines()
+        mock_decrypt_certs.return_value = cert_contents
+        sslmgr = azure_helper.OpenSSLManager()
+        keys_by_fp = sslmgr.parse_certificates('')
+        for fp in keys_by_fp.keys():
+            self.assertIn(fp, fingerprints)
+        for fp in fingerprints:
+            self.assertIn(fp, keys_by_fp)
+
+
+class TestGoalStateHealthReporter(CiTestCase):
+
+    maxDiff = None
+
+    default_parameters = {
+        'incarnation': 1634,
+        'container_id': 'MyContainerId',
+        'instance_id': 'MyInstanceId'
+    }
+
+    test_azure_endpoint = 'TestEndpoint'
+    test_health_report_url = 'http://{0}/machine?comp=health'.format(
+        test_azure_endpoint)
+    test_default_headers = {'Content-Type': 'text/xml; charset=utf-8'}
+
+    provisioning_success_status = 'Ready'
+    provisioning_not_ready_status = 'NotReady'
+    provisioning_failure_substatus = 'ProvisioningFailed'
+    provisioning_failure_err_description = (
+        'Test error message containing provisioning failure details')
+
+    def setUp(self):
+        super(TestGoalStateHealthReporter, self).setUp()
+        patches = ExitStack()
+        self.addCleanup(patches.close)
+
+        patches.enter_context(
+            mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock()))
+        self.read_file_or_url = patches.enter_context(
+            mock.patch.object(azure_helper.url_helper, 'read_file_or_url'))
+
+        self.post = patches.enter_context(
+            mock.patch.object(azure_helper.AzureEndpointHttpClient,
+                              'post'))
+
+        self.GoalState = patches.enter_context(
+            mock.patch.object(azure_helper, 'GoalState'))
+        self.GoalState.return_value.container_id = \
+            self.default_parameters['container_id']
+        self.GoalState.return_value.instance_id = \
+            self.default_parameters['instance_id']
+        self.GoalState.return_value.incarnation = \
+            self.default_parameters['incarnation']
+
+    def _text_from_xpath_in_xroot(self, xroot, xpath):
+        element = xroot.find(xpath)
+        if element is not None:
+            return element.text
+        return None
+
+    def _get_formatted_health_report_xml_string(self, **kwargs):
+        return HEALTH_REPORT_XML_TEMPLATE.format(**kwargs)
+
+    def _get_formatted_health_detail_subsection_xml_string(self, **kwargs):
+        return HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE.format(**kwargs)
+
+    def _get_report_ready_health_document(self):
+        return self._get_formatted_health_report_xml_string(
+            incarnation=escape(str(self.default_parameters['incarnation'])),
+            container_id=escape(self.default_parameters['container_id']),
+            instance_id=escape(self.default_parameters['instance_id']),
+            health_status=escape(self.provisioning_success_status),
+            health_detail_subsection='')
+
+    def _get_report_failure_health_document(self):
+        health_detail_subsection = \
+            self._get_formatted_health_detail_subsection_xml_string(
+                health_substatus=escape(self.provisioning_failure_substatus),
+                health_description=escape(
+                    self.provisioning_failure_err_description))
+
+        return self._get_formatted_health_report_xml_string(
+            incarnation=escape(str(self.default_parameters['incarnation'])),
+            container_id=escape(self.default_parameters['container_id']),
+            instance_id=escape(self.default_parameters['instance_id']),
+            health_status=escape(self.provisioning_not_ready_status),
+            health_detail_subsection=health_detail_subsection)
+
+    def test_send_ready_signal_sends_post_request(self):
+        with mock.patch.object(
+                azure_helper.GoalStateHealthReporter,
+                'build_report') as m_build_report:
+            client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+            reporter = azure_helper.GoalStateHealthReporter(
+                azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+                client, self.test_azure_endpoint)
+            reporter.send_ready_signal()
+
+            self.assertEqual(1, self.post.call_count)
+            self.assertEqual(
+                mock.call(
+                    self.test_health_report_url,
+                    data=m_build_report.return_value,
+                    extra_headers=self.test_default_headers),
+                self.post.call_args)
+
+    def test_send_failure_signal_sends_post_request(self):
+        with mock.patch.object(
+                azure_helper.GoalStateHealthReporter,
+                'build_report') as m_build_report:
+            client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+            reporter = azure_helper.GoalStateHealthReporter(
+                azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+                client, self.test_azure_endpoint)
+            reporter.send_failure_signal(
+                description=self.provisioning_failure_err_description)
+
+            self.assertEqual(1, self.post.call_count)
+            self.assertEqual(
+                mock.call(
+                    self.test_health_report_url,
+                    data=m_build_report.return_value,
+                    extra_headers=self.test_default_headers),
+                self.post.call_args)
+
+    def test_build_report_for_ready_signal_health_document(self):
+        health_document = self._get_report_ready_health_document()
+        reporter = azure_helper.GoalStateHealthReporter(
+            azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+            azure_helper.AzureEndpointHttpClient(mock.MagicMock()),
+            self.test_azure_endpoint)
+        generated_health_document = reporter.build_report(
+            incarnation=self.default_parameters['incarnation'],
+            container_id=self.default_parameters['container_id'],
+            instance_id=self.default_parameters['instance_id'],
+            status=self.provisioning_success_status)
+
+        self.assertEqual(health_document, generated_health_document)
+
+        generated_xroot = ElementTree.fromstring(generated_health_document)
+        self.assertEqual(
+            self._text_from_xpath_in_xroot(
+                generated_xroot, './GoalStateIncarnation'),
+            str(self.default_parameters['incarnation']))
+        self.assertEqual(
+            self._text_from_xpath_in_xroot(
+                generated_xroot, './Container/ContainerId'),
+            str(self.default_parameters['container_id']))
+        self.assertEqual(
+            self._text_from_xpath_in_xroot(
+                generated_xroot,
+                './Container/RoleInstanceList/Role/InstanceId'),
+            str(self.default_parameters['instance_id']))
+        self.assertEqual(
+            self._text_from_xpath_in_xroot(
+                generated_xroot,
+                './Container/RoleInstanceList/Role/Health/State'),
+            escape(self.provisioning_success_status))
+        self.assertIsNone(
+            self._text_from_xpath_in_xroot(
+                generated_xroot,
+                './Container/RoleInstanceList/Role/Health/Details'))
+        self.assertIsNone(
+            self._text_from_xpath_in_xroot(
+                generated_xroot,
+                './Container/RoleInstanceList/Role/Health/Details/SubStatus'))
+        self.assertIsNone(
+            self._text_from_xpath_in_xroot(
+                generated_xroot,
+                './Container/RoleInstanceList/Role/Health/Details/Description')
+        )
+
+    def test_build_report_for_failure_signal_health_document(self):
+        health_document = self._get_report_failure_health_document()
+        reporter = azure_helper.GoalStateHealthReporter(
+            azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+            azure_helper.AzureEndpointHttpClient(mock.MagicMock()),
+            self.test_azure_endpoint)
+        generated_health_document = reporter.build_report(
+            incarnation=self.default_parameters['incarnation'],
+            container_id=self.default_parameters['container_id'],
+            instance_id=self.default_parameters['instance_id'],
+            status=self.provisioning_not_ready_status,
+            substatus=self.provisioning_failure_substatus,
+            description=self.provisioning_failure_err_description)
+
+        self.assertEqual(health_document, generated_health_document)
+
+        generated_xroot = ElementTree.fromstring(generated_health_document)
+        self.assertEqual(
+            self._text_from_xpath_in_xroot(
+                generated_xroot, './GoalStateIncarnation'),
+            str(self.default_parameters['incarnation']))
+        self.assertEqual(
+            self._text_from_xpath_in_xroot(
+                generated_xroot, './Container/ContainerId'),
+            self.default_parameters['container_id'])
+        self.assertEqual(
+            self._text_from_xpath_in_xroot(
+                generated_xroot,
+                './Container/RoleInstanceList/Role/InstanceId'),
+            self.default_parameters['instance_id'])
+        self.assertEqual(
+            self._text_from_xpath_in_xroot(
+                generated_xroot,
+                './Container/RoleInstanceList/Role/Health/State'),
+            escape(self.provisioning_not_ready_status))
+        self.assertEqual(
+            self._text_from_xpath_in_xroot(
+                generated_xroot,
+                './Container/RoleInstanceList/Role/Health/Details/'
+                'SubStatus'),
+            escape(self.provisioning_failure_substatus))
+        self.assertEqual(
+            self._text_from_xpath_in_xroot(
+                generated_xroot,
+                './Container/RoleInstanceList/Role/Health/Details/'
+                'Description'),
+            escape(self.provisioning_failure_err_description))
+
+    def test_send_ready_signal_calls_build_report(self):
+        with mock.patch.object(
+            azure_helper.GoalStateHealthReporter, 'build_report'
+        ) as m_build_report:
+            reporter = azure_helper.GoalStateHealthReporter(
+                azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+                azure_helper.AzureEndpointHttpClient(mock.MagicMock()),
+                self.test_azure_endpoint)
+            reporter.send_ready_signal()
+
+            self.assertEqual(1, m_build_report.call_count)
+            self.assertEqual(
+                mock.call(
+                    incarnation=self.default_parameters['incarnation'],
+                    container_id=self.default_parameters['container_id'],
+                    instance_id=self.default_parameters['instance_id'],
+                    status=self.provisioning_success_status),
+                m_build_report.call_args)
+
+    def test_send_failure_signal_calls_build_report(self):
+        with mock.patch.object(
+            azure_helper.GoalStateHealthReporter, 'build_report'
+        ) as m_build_report:
+            reporter = azure_helper.GoalStateHealthReporter(
+                azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+                azure_helper.AzureEndpointHttpClient(mock.MagicMock()),
+                self.test_azure_endpoint)
+            reporter.send_failure_signal(
+                description=self.provisioning_failure_err_description)
+
+            self.assertEqual(1, m_build_report.call_count)
+            self.assertEqual(
+                mock.call(
+                    incarnation=self.default_parameters['incarnation'],
+                    container_id=self.default_parameters['container_id'],
+                    instance_id=self.default_parameters['instance_id'],
+                    status=self.provisioning_not_ready_status,
+                    substatus=self.provisioning_failure_substatus,
+                    description=self.provisioning_failure_err_description),
+                m_build_report.call_args)
+
+    def test_build_report_escapes_chars(self):
+        incarnation = 'jd8\'9*&^<\'A><A[p&o+\"SD()*&&&LKAJSD23'
+        container_id = '&&<\"><><ds8\'9+7&d9a86!@($09asdl;<>'
+        instance_id = 'Opo>>>jas\'&d;[p&fp\"a<<!!@&&'
+        health_status = '&<897\"6&>&aa\'sd!@&!)((*<&>'
+        health_substatus = '&as\"d<<a&s>d<\'^@!5&6<7'
+        health_description = '&&&>!#$\"&&<as\'1!@$d&>><>&\"sd<67<]>>'
+
+        health_detail_subsection = \
+            self._get_formatted_health_detail_subsection_xml_string(
+                health_substatus=escape(health_substatus),
+                health_description=escape(health_description))
+        health_document = self._get_formatted_health_report_xml_string(
+            incarnation=escape(incarnation),
+            container_id=escape(container_id),
+            instance_id=escape(instance_id),
+            health_status=escape(health_status),
+            health_detail_subsection=health_detail_subsection)
+
+        reporter = azure_helper.GoalStateHealthReporter(
+            azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+            azure_helper.AzureEndpointHttpClient(mock.MagicMock()),
+            self.test_azure_endpoint)
+        generated_health_document = reporter.build_report(
+            incarnation=incarnation,
+            container_id=container_id,
+            instance_id=instance_id,
+            status=health_status,
+            substatus=health_substatus,
+            description=health_description)
+
+        self.assertEqual(health_document, generated_health_document)
+
+    def test_build_report_conforms_to_length_limits(self):
+        reporter = azure_helper.GoalStateHealthReporter(
+            azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+            azure_helper.AzureEndpointHttpClient(mock.MagicMock()),
+            self.test_azure_endpoint)
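+        # Well over HEALTH_REPORT_DESCRIPTION_TRIM_LEN (512 chars), so the
+        # generated description must come back trimmed to that length.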
+        long_err_msg = 'a9&ea8>>>e as1< d\"q2*&(^%\'a=5<' * 100
+        generated_health_document = reporter.build_report(
+            incarnation=self.default_parameters['incarnation'],
+            container_id=self.default_parameters['container_id'],
+            instance_id=self.default_parameters['instance_id'],
+            status=self.provisioning_not_ready_status,
+            substatus=self.provisioning_failure_substatus,
+            description=long_err_msg)
+
+        generated_xroot = ElementTree.fromstring(generated_health_document)
+        generated_health_report_description = self._text_from_xpath_in_xroot(
+            generated_xroot,
+            './Container/RoleInstanceList/Role/Health/Details/Description')
+        self.assertEqual(
+            len(unescape(generated_health_report_description)),
+            HEALTH_REPORT_DESCRIPTION_TRIM_LEN)
+
+    def test_trim_description_then_escape_conforms_to_len_limits_worst_case(
+            self):
+        """When unescaped characters are XML-escaped, the length increases.
+        Char      Escape String
+        <         &lt;
+        >         &gt;
+        "         &quot;
+        '         &apos;
+        &         &amp;
+
+        We (step 1) trim the health report XML's description field,
+        and then (step 2) XML-escape the health report XML's description field.
+
+        The health report XML's description field limit within cloud-init
+        is HEALTH_REPORT_DESCRIPTION_TRIM_LEN.
+
+        The Azure platform's limit on the health report XML's description field
+        is 4096 chars.
+
+        For worst-case chars (' and "), each character is XML-escaped to a
+        6-character entity (&apos; / &quot;), so the escaped text can be up
+        to 6x as long as the trimmed text.
+
+        Ensure that (1) trimming and then (2) XML-escaping does not blow past
+        the Azure platform's limit for health report XML's description field
+        (4096 chars).
+        """
+        reporter = azure_helper.GoalStateHealthReporter(
+            azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+            azure_helper.AzureEndpointHttpClient(mock.MagicMock()),
+            self.test_azure_endpoint)
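+        # 20,000 worst-case quote characters; trimmed to 512 and escaped to
+        # 6-character entities, the description stays within the 4096-char
+        # platform limit (512 * 6 = 3072).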
+        long_err_msg = '\'\"' * 10000
+        generated_health_document = reporter.build_report(
+            incarnation=self.default_parameters['incarnation'],
+            container_id=self.default_parameters['container_id'],
+            instance_id=self.default_parameters['instance_id'],
+            status=self.provisioning_not_ready_status,
+            substatus=self.provisioning_failure_substatus,
+            description=long_err_msg)
+
+        generated_xroot = ElementTree.fromstring(generated_health_document)
+        generated_health_report_description = self._text_from_xpath_in_xroot(
+            generated_xroot,
+            './Container/RoleInstanceList/Role/Health/Details/Description')
+        # The escaped description string should be less than
+        # the Azure platform limit for the escaped description string.
+        self.assertLessEqual(len(generated_health_report_description), 4096)
+
+
+class TestWALinuxAgentShim(CiTestCase):
+
+    def setUp(self):
+        super(TestWALinuxAgentShim, self).setUp()
+        patches = ExitStack()
+        self.addCleanup(patches.close)
+
+        self.AzureEndpointHttpClient = patches.enter_context(
+            mock.patch.object(azure_helper, 'AzureEndpointHttpClient'))
+        self.find_endpoint = patches.enter_context(
+            mock.patch.object(wa_shim, 'find_endpoint'))
+        self.GoalState = patches.enter_context(
+            mock.patch.object(azure_helper, 'GoalState'))
+        self.OpenSSLManager = patches.enter_context(
+            mock.patch.object(azure_helper, 'OpenSSLManager', autospec=True))
+        patches.enter_context(
+            mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock()))
+
+        self.test_incarnation = 'TestIncarnation'
+        self.test_container_id = 'TestContainerId'
+        self.test_instance_id = 'TestInstanceId'
+        self.GoalState.return_value.incarnation = self.test_incarnation
+        self.GoalState.return_value.container_id = self.test_container_id
+        self.GoalState.return_value.instance_id = self.test_instance_id
+
+    def test_http_client_does_not_use_certificate_for_report_ready(self):
+        shim = wa_shim()
+        shim.register_with_azure_and_fetch_data()
+        self.assertEqual(
+            [mock.call(None)],
+            self.AzureEndpointHttpClient.call_args_list)
+
+    def test_http_client_does_not_use_certificate_for_report_failure(self):
+        shim = wa_shim()
+        shim.register_with_azure_and_report_failure(description='TestDesc')
+        self.assertEqual(
+            [mock.call(None)],
+            self.AzureEndpointHttpClient.call_args_list)
+
+    def test_correct_url_used_for_goalstate_during_report_ready(self):
+        self.find_endpoint.return_value = 'test_endpoint'
+        shim = wa_shim()
+        shim.register_with_azure_and_fetch_data()
+        m_get = self.AzureEndpointHttpClient.return_value.get
+        self.assertEqual(
+            [mock.call('http://test_endpoint/machine/?comp=goalstate')],
+            m_get.call_args_list)
+        self.assertEqual(
+            [mock.call(
+                m_get.return_value.contents,
+                self.AzureEndpointHttpClient.return_value,
+                False
+            )],
+            self.GoalState.call_args_list)
+
+    def test_correct_url_used_for_goalstate_during_report_failure(self):
+        self.find_endpoint.return_value = 'test_endpoint'
+        shim = wa_shim()
+        shim.register_with_azure_and_report_failure(description='TestDesc')
+        m_get = self.AzureEndpointHttpClient.return_value.get
+        self.assertEqual(
+            [mock.call('http://test_endpoint/machine/?comp=goalstate')],
+            m_get.call_args_list)
+        self.assertEqual(
+            [mock.call(
+                m_get.return_value.contents,
+                self.AzureEndpointHttpClient.return_value,
+                False
+            )],
+            self.GoalState.call_args_list)
+
+    def test_certificates_used_to_determine_public_keys(self):
+        # if register_with_azure_and_fetch_data() isn't passed some info about
+        # the user's public keys, there's no point in even trying to parse the
+        # certificates
+        shim = wa_shim()
+        mypk = [{'fingerprint': 'fp1', 'path': 'path1'},
+                {'fingerprint': 'fp3', 'path': 'path3', 'value': ''}]
+        certs = {'fp1': 'expected-key',
+                 'fp2': 'should-not-be-found',
+                 'fp3': 'expected-no-value-key',
+                 }
+        sslmgr = self.OpenSSLManager.return_value
+        sslmgr.parse_certificates.return_value = certs
+        data = shim.register_with_azure_and_fetch_data(pubkey_info=mypk)
+        self.assertEqual(
+            [mock.call(self.GoalState.return_value.certificates_xml)],
+            sslmgr.parse_certificates.call_args_list)
+        self.assertIn('expected-key', data['public-keys'])
+        self.assertIn('expected-no-value-key', data['public-keys'])
+        self.assertNotIn('should-not-be-found', data['public-keys'])
+
+    def test_absent_certificates_produces_empty_public_keys(self):
+        mypk = [{'fingerprint': 'fp1', 'path': 'path1'}]
+        self.GoalState.return_value.certificates_xml = None
+        shim = wa_shim()
+        data = shim.register_with_azure_and_fetch_data(pubkey_info=mypk)
+        self.assertEqual([], data['public-keys'])
+
+    def test_correct_url_used_for_report_ready(self):
+        self.find_endpoint.return_value = 'test_endpoint'
+        shim = wa_shim()
+        shim.register_with_azure_and_fetch_data()
+        expected_url = 'http://test_endpoint/machine?comp=health'
+        self.assertEqual(
+            [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)],
+            self.AzureEndpointHttpClient.return_value.post
+                .call_args_list)
+
+    def test_correct_url_used_for_report_failure(self):
+        self.find_endpoint.return_value = 'test_endpoint'
+        shim = wa_shim()
+        shim.register_with_azure_and_report_failure(description='TestDesc')
+        expected_url = 'http://test_endpoint/machine?comp=health'
+        self.assertEqual(
+            [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)],
+            self.AzureEndpointHttpClient.return_value.post
+                .call_args_list)
+
+    def test_goal_state_values_used_for_report_ready(self):
+        shim = wa_shim()
+        shim.register_with_azure_and_fetch_data()
+        posted_document = (
+            self.AzureEndpointHttpClient.return_value.post
+                .call_args[1]['data']
+        )
+        self.assertIn(self.test_incarnation, posted_document)
+        self.assertIn(self.test_container_id, posted_document)
+        self.assertIn(self.test_instance_id, posted_document)
+
+    def test_goal_state_values_used_for_report_failure(self):
+        shim = wa_shim()
+        shim.register_with_azure_and_report_failure(description='TestDesc')
+        posted_document = (
+            self.AzureEndpointHttpClient.return_value.post
+                .call_args[1]['data']
+        )
+        self.assertIn(self.test_incarnation, posted_document)
+        self.assertIn(self.test_container_id, posted_document)
+        self.assertIn(self.test_instance_id, posted_document)
+
+    def test_xml_elems_in_report_ready_post(self):
+        shim = wa_shim()
+        shim.register_with_azure_and_fetch_data()
+        health_document = HEALTH_REPORT_XML_TEMPLATE.format(
+            incarnation=escape(self.test_incarnation),
+            container_id=escape(self.test_container_id),
+            instance_id=escape(self.test_instance_id),
+            health_status=escape('Ready'),
+            health_detail_subsection='')
+        posted_document = (
+            self.AzureEndpointHttpClient.return_value.post
+                .call_args[1]['data'])
+        self.assertEqual(health_document, posted_document)
+
+    def test_xml_elems_in_report_failure_post(self):
+        shim = wa_shim()
+        shim.register_with_azure_and_report_failure(description='TestDesc')
+        health_document = HEALTH_REPORT_XML_TEMPLATE.format(
+            incarnation=escape(self.test_incarnation),
+            container_id=escape(self.test_container_id),
+            instance_id=escape(self.test_instance_id),
+            health_status=escape('NotReady'),
+            health_detail_subsection=HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE
+            .format(
+                health_substatus=escape('ProvisioningFailed'),
+                health_description=escape('TestDesc')))
+        posted_document = (
+            self.AzureEndpointHttpClient.return_value.post
+                .call_args[1]['data'])
+        self.assertEqual(health_document, posted_document)
+
+    @mock.patch.object(azure_helper, 'GoalStateHealthReporter', autospec=True)
+    def test_register_with_azure_and_fetch_data_calls_send_ready_signal(
+            self, m_goal_state_health_reporter):
+        shim = wa_shim()
+        shim.register_with_azure_and_fetch_data()
+        self.assertEqual(
+            1,
+            m_goal_state_health_reporter.return_value.send_ready_signal
+            .call_count)
+
+    @mock.patch.object(azure_helper, 'GoalStateHealthReporter', autospec=True)
+    def test_register_with_azure_and_report_failure_calls_send_failure_signal(
+            self, m_goal_state_health_reporter):
+        shim = wa_shim()
+        shim.register_with_azure_and_report_failure(description='TestDesc')
+        m_goal_state_health_reporter.return_value.send_failure_signal \
+            .assert_called_once_with(description='TestDesc')
+
+    def test_register_with_azure_and_report_failure_does_not_need_certificates(
+            self):
+        shim = wa_shim()
+        with mock.patch.object(
+            shim, '_fetch_goal_state_from_azure', autospec=True
+        ) as m_fetch_goal_state_from_azure:
+            shim.register_with_azure_and_report_failure(description='TestDesc')
+            m_fetch_goal_state_from_azure.assert_called_once_with(
+                need_certificate=False)
+
+    def test_clean_up_can_be_called_at_any_time(self):
+        shim = wa_shim()
+        shim.clean_up()
+
+    def test_openssl_manager_not_instantiated_by_shim_report_status(self):
+        shim = wa_shim()
+        shim.register_with_azure_and_fetch_data()
+        shim.register_with_azure_and_report_failure(description='TestDesc')
+        shim.clean_up()
+        self.OpenSSLManager.assert_not_called()
+
+    def test_clean_up_after_report_ready(self):
+        shim = wa_shim()
+        shim.register_with_azure_and_fetch_data()
+        shim.clean_up()
+        self.OpenSSLManager.return_value.clean_up.assert_not_called()
+
+    def test_clean_up_after_report_failure(self):
+        shim = wa_shim()
+        shim.register_with_azure_and_report_failure(description='TestDesc')
+        shim.clean_up()
+        self.OpenSSLManager.return_value.clean_up.assert_not_called()
+
+    def test_fetch_goalstate_during_report_ready_raises_exc_on_get_exc(self):
+        self.AzureEndpointHttpClient.return_value.get \
+            .side_effect = SentinelException
+        shim = wa_shim()
+        self.assertRaises(SentinelException,
+                          shim.register_with_azure_and_fetch_data)
+
+    def test_fetch_goalstate_during_report_failure_raises_exc_on_get_exc(self):
+        self.AzureEndpointHttpClient.return_value.get \
+            .side_effect = SentinelException
+        shim = wa_shim()
+        self.assertRaises(SentinelException,
+                          shim.register_with_azure_and_report_failure,
+                          description='TestDesc')
+
+    def test_fetch_goalstate_during_report_ready_raises_exc_on_parse_exc(self):
+        self.GoalState.side_effect = SentinelException
+        shim = wa_shim()
+        self.assertRaises(SentinelException,
+                          shim.register_with_azure_and_fetch_data)
+
+    def test_fetch_goalstate_during_report_failure_raises_exc_on_parse_exc(
+            self):
+        self.GoalState.side_effect = SentinelException
+        shim = wa_shim()
+        self.assertRaises(SentinelException,
+                          shim.register_with_azure_and_report_failure,
+                          description='TestDesc')
+
+    def test_failure_to_send_report_ready_health_doc_bubbles_up(self):
+        self.AzureEndpointHttpClient.return_value.post \
+            .side_effect = SentinelException
+        shim = wa_shim()
+        self.assertRaises(SentinelException,
+                          shim.register_with_azure_and_fetch_data)
+
+    def test_failure_to_send_report_failure_health_doc_bubbles_up(self):
+        self.AzureEndpointHttpClient.return_value.post \
+            .side_effect = SentinelException
+        shim = wa_shim()
+        self.assertRaises(SentinelException,
+                          shim.register_with_azure_and_report_failure,
+                          description='TestDesc')
+
+
+class TestGetMetadataGoalStateXMLAndReportReadyToFabric(CiTestCase):
+
+    def setUp(self):
+        super(TestGetMetadataGoalStateXMLAndReportReadyToFabric, self).setUp()
+        patches = ExitStack()
+        self.addCleanup(patches.close)
+
+        self.m_shim = patches.enter_context(
+            mock.patch.object(azure_helper, 'WALinuxAgentShim'))
+
+    def test_data_from_shim_returned(self):
+        ret = azure_helper.get_metadata_from_fabric()
+        self.assertEqual(
+            self.m_shim.return_value.register_with_azure_and_fetch_data
+                .return_value,
+            ret)
+
+    def test_success_calls_clean_up(self):
+        azure_helper.get_metadata_from_fabric()
+        self.assertEqual(1, self.m_shim.return_value.clean_up.call_count)
+
+    def test_failure_in_registration_propagates_exc_and_calls_clean_up(
+            self):
+        self.m_shim.return_value.register_with_azure_and_fetch_data \
+            .side_effect = SentinelException
+        self.assertRaises(SentinelException,
+                          azure_helper.get_metadata_from_fabric)
+        self.assertEqual(1, self.m_shim.return_value.clean_up.call_count)
+
+    def test_calls_shim_register_with_azure_and_fetch_data(self):
+        m_pubkey_info = mock.MagicMock()
+        azure_helper.get_metadata_from_fabric(pubkey_info=m_pubkey_info)
+        self.assertEqual(
+            1,
+            self.m_shim.return_value
+                .register_with_azure_and_fetch_data.call_count)
+        self.assertEqual(
+            mock.call(pubkey_info=m_pubkey_info),
+            self.m_shim.return_value
+                .register_with_azure_and_fetch_data.call_args)
+
+    def test_instantiates_shim_with_kwargs(self):
+        m_fallback_lease_file = mock.MagicMock()
+        m_dhcp_options = mock.MagicMock()
+        azure_helper.get_metadata_from_fabric(
+            fallback_lease_file=m_fallback_lease_file,
+            dhcp_opts=m_dhcp_options)
+        self.assertEqual(1, self.m_shim.call_count)
+        self.assertEqual(
+            mock.call(
+                fallback_lease_file=m_fallback_lease_file,
+                dhcp_options=m_dhcp_options),
+            self.m_shim.call_args)
+
+
+class TestGetMetadataGoalStateXMLAndReportFailureToFabric(CiTestCase):
+
+    def setUp(self):
+        super(
+            TestGetMetadataGoalStateXMLAndReportFailureToFabric, self).setUp()
+        patches = ExitStack()
+        self.addCleanup(patches.close)
+
+        self.m_shim = patches.enter_context(
+            mock.patch.object(azure_helper, 'WALinuxAgentShim'))
+
+    def test_success_calls_clean_up(self):
+        azure_helper.report_failure_to_fabric()
+        self.assertEqual(
+            1,
+            self.m_shim.return_value.clean_up.call_count)
+
+    def test_failure_in_shim_report_failure_propagates_exc_and_calls_clean_up(
+            self):
+        self.m_shim.return_value.register_with_azure_and_report_failure \
+            .side_effect = SentinelException
+        self.assertRaises(SentinelException,
+                          azure_helper.report_failure_to_fabric)
+        self.assertEqual(
+            1,
+            self.m_shim.return_value.clean_up.call_count)
+
+    def test_report_failure_to_fabric_with_desc_calls_shim_report_failure(
+            self):
+        azure_helper.report_failure_to_fabric(description='TestDesc')
+        self.m_shim.return_value.register_with_azure_and_report_failure \
+            .assert_called_once_with(description='TestDesc')
+
+    def test_report_failure_to_fabric_with_no_desc_calls_shim_report_failure(
+            self):
+        azure_helper.report_failure_to_fabric()
+        # default err message description should be shown to the user
+        # if no description is passed in
+        self.m_shim.return_value.register_with_azure_and_report_failure \
+            .assert_called_once_with(
+                description=azure_helper
+                .DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE)
+
+    def test_report_failure_to_fabric_empty_desc_calls_shim_report_failure(
+            self):
+        azure_helper.report_failure_to_fabric(description='')
+        # default err message description should be shown to the user
+        # if an empty description is passed in
+        self.m_shim.return_value.register_with_azure_and_report_failure \
+            .assert_called_once_with(
+                description=azure_helper
+                .DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE)
+
+    def test_instantiates_shim_with_kwargs(self):
+        m_fallback_lease_file = mock.MagicMock()
+        m_dhcp_options = mock.MagicMock()
+        azure_helper.report_failure_to_fabric(
+            fallback_lease_file=m_fallback_lease_file,
+            dhcp_opts=m_dhcp_options)
+        self.m_shim.assert_called_once_with(
+            fallback_lease_file=m_fallback_lease_file,
+            dhcp_options=m_dhcp_options)
+
+
+class TestExtractIpAddressFromNetworkd(CiTestCase):
+
+    azure_lease = dedent("""\
+    # This is private data. Do not parse.
+    ADDRESS=10.132.0.5
+    NETMASK=255.255.255.255
+    ROUTER=10.132.0.1
+    SERVER_ADDRESS=169.254.169.254
+    NEXT_SERVER=10.132.0.1
+    MTU=1460
+    T1=43200
+    T2=75600
+    LIFETIME=86400
+    DNS=169.254.169.254
+    NTP=169.254.169.254
+    DOMAINNAME=c.ubuntu-foundations.internal
+    DOMAIN_SEARCH_LIST=c.ubuntu-foundations.internal google.internal
+    HOSTNAME=tribaal-test-171002-1349.c.ubuntu-foundations.internal
+    ROUTES=10.132.0.1/32,0.0.0.0 0.0.0.0/0,10.132.0.1
+    CLIENTID=ff405663a200020000ab11332859494d7a8b4c
+    OPTION_245=624c3620
+    """)
+
+    def setUp(self):
+        super(TestExtractIpAddressFromNetworkd, self).setUp()
+        self.lease_d = self.tmp_dir()
+
+    def test_no_valid_leases_is_none(self):
+        """No valid leases should return None."""
+        self.assertIsNone(
+            wa_shim._networkd_get_value_from_leases(self.lease_d))
+
+    def test_option_245_is_found_in_single(self):
+        """A single valid lease with 245 option should return it."""
+        populate_dir(self.lease_d, {'9': self.azure_lease})
+        self.assertEqual(
+            '624c3620', wa_shim._networkd_get_value_from_leases(self.lease_d))
+
+    def test_option_245_not_found_returns_None(self):
+        """A valid lease, but no option 245 should return None."""
+        populate_dir(
+            self.lease_d,
+            {'9': self.azure_lease.replace("OPTION_245", "OPTION_999")})
+        self.assertIsNone(
+            wa_shim._networkd_get_value_from_leases(self.lease_d))
+
+    def test_multiple_returns_first(self):
+        """Somewhat arbitrarily return the first address when multiple.
+
+        Most important at the moment is that this is consistent behavior
+        rather than changing randomly as in order of a dictionary."""
+        myval = "624c3601"
+        populate_dir(
+            self.lease_d,
+            {'9': self.azure_lease,
+             '2': self.azure_lease.replace("624c3620", myval)})
+        self.assertEqual(
+            myval, wa_shim._networkd_get_value_from_leases(self.lease_d))
+
+
+# vi: ts=4 expandtab

+ 143 - 0
Azure/test_azure_key_vault.py

@@ -0,0 +1,143 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest import TestCase, mock
+
+from azure.core.exceptions import ResourceNotFoundError
+
+from airflow.providers.microsoft.azure.secrets.key_vault import AzureKeyVaultBackend
+
+
+class TestAzureKeyVaultBackend(TestCase):
+    @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.get_conn_value')
+    def test_get_connections(self, mock_get_value):
+        mock_get_value.return_value = 'scheme://user:pass@host:100'
+        conn_list = AzureKeyVaultBackend().get_connections('fake_conn')
+        conn = conn_list[0]
+        assert conn.host == 'host'
+
+    @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.DefaultAzureCredential')
+    @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.SecretClient')
+    def test_get_conn_uri(self, mock_secret_client, mock_azure_cred):
+        mock_cred = mock.Mock()
+        mock_sec_client = mock.Mock()
+        mock_azure_cred.return_value = mock_cred
+        mock_secret_client.return_value = mock_sec_client
+
+        mock_sec_client.get_secret.return_value = mock.Mock(
+            value='postgresql://airflow:airflow@host:5432/airflow'
+        )
+
+        backend = AzureKeyVaultBackend(vault_url="https://example-akv-resource-name.vault.azure.net/")
+        returned_uri = backend.get_conn_uri(conn_id='hi')
+        mock_secret_client.assert_called_once_with(
+            credential=mock_cred, vault_url='https://example-akv-resource-name.vault.azure.net/'
+        )
+        assert returned_uri == 'postgresql://airflow:airflow@host:5432/airflow'
+
+    @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client')
+    def test_get_conn_uri_non_existent_key(self, mock_client):
+        """
+        Test that if the key with connection ID is not present,
+        AzureKeyVaultBackend.get_connections should return None
+        """
+        conn_id = 'test_mysql'
+        mock_client.get_secret.side_effect = ResourceNotFoundError
+        backend = AzureKeyVaultBackend(vault_url="https://example-akv-resource-name.vault.azure.net/")
+
+        assert backend.get_conn_uri(conn_id=conn_id) is None
+        assert [] == backend.get_connections(conn_id=conn_id)
+
+    @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client')
+    def test_get_variable(self, mock_client):
+        mock_client.get_secret.return_value = mock.Mock(value='world')
+        backend = AzureKeyVaultBackend()
+        returned_uri = backend.get_variable('hello')
+        mock_client.get_secret.assert_called_with(name='airflow-variables-hello')
+        assert 'world' == returned_uri
+
+    @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client')
+    def test_get_variable_non_existent_key(self, mock_client):
+        """
+        Test that if Variable key is not present,
+        AzureKeyVaultBackend.get_variables should return None
+        """
+        mock_client.get_secret.side_effect = ResourceNotFoundError
+        backend = AzureKeyVaultBackend()
+        assert backend.get_variable('test_mysql') is None
+
+    @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client')
+    def test_get_secret_value_not_found(self, mock_client):
+        """
+        Test that a non-existent secret returns None
+        """
+        mock_client.get_secret.side_effect = ResourceNotFoundError
+        backend = AzureKeyVaultBackend()
+        assert (
+            backend._get_secret(path_prefix=backend.connections_prefix, secret_id='test_non_existent') is None
+        )
+
+    @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client')
+    def test_get_secret_value(self, mock_client):
+        """
+        Test that get_secret returns the secret value
+        """
+        mock_client.get_secret.return_value = mock.Mock(value='super-secret')
+        backend = AzureKeyVaultBackend()
+        secret_val = backend._get_secret('af-secrets', 'test_mysql_password')
+        mock_client.get_secret.assert_called_with(name='af-secrets-test-mysql-password')
+        assert secret_val == 'super-secret'
+
+    @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret')
+    def test_connection_prefix_none_value(self, mock_get_secret):
+        """
+        Test that if Connections prefix is None,
+        AzureKeyVaultBackend.get_connections should return None
+        AzureKeyVaultBackend._get_secret should not be called
+        """
+        kwargs = {'connections_prefix': None}
+
+        backend = AzureKeyVaultBackend(**kwargs)
+        assert backend.get_conn_uri('test_mysql') is None
+        mock_get_secret.assert_not_called()
+
+    @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret')
+    def test_variable_prefix_none_value(self, mock_get_secret):
+        """
+        Test that if Variables prefix is None,
+        AzureKeyVaultBackend.get_variables should return None
+        AzureKeyVaultBackend._get_secret should not be called
+        """
+        kwargs = {'variables_prefix': None}
+
+        backend = AzureKeyVaultBackend(**kwargs)
+        assert backend.get_variable('hello') is None
+        mock_get_secret.assert_not_called()
+
+    @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret')
+    def test_config_prefix_none_value(self, mock_get_secret):
+        """
+        Test that if Config prefix is None,
+        AzureKeyVaultBackend.get_config should return None
+        AzureKeyVaultBackend._get_secret should not be called
+        """
+        kwargs = {'config_prefix': None}
+
+        backend = AzureKeyVaultBackend(**kwargs)
+        assert backend.get_config('test_mysql') is None
+        mock_get_secret.assert_not_called()
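
A rough usage sketch of the backend exercised above, assuming the placeholder vault URL from these tests and the secret-name prefixes they imply (a real lookup authenticates through DefaultAzureCredential):

    from airflow.providers.microsoft.azure.secrets.key_vault import AzureKeyVaultBackend

    backend = AzureKeyVaultBackend(
        connections_prefix='airflow-connections',
        variables_prefix='airflow-variables',
        vault_url='https://example-akv-resource-name.vault.azure.net/',
    )

    # Looks up the secret 'airflow-connections-my-conn'; returns None if it is absent.
    conn_uri = backend.get_conn_uri(conn_id='my_conn')
    # Looks up the secret 'airflow-variables-hello'; returns None if it is absent.
    value = backend.get_variable('hello')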

+ 79 - 0
Azure/test_base_azure.py

@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import unittest
+from unittest.mock import Mock, patch
+
+from airflow.models import Connection
+from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
+
+
+class TestBaseAzureHook(unittest.TestCase):
+    @patch('airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_auth_file')
+    @patch(
+        'airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection',
+        return_value=Connection(conn_id='azure_default', extra='{ "key_path": "key_file.json" }'),
+    )
+    def test_get_conn_with_key_path(self, mock_connection, mock_get_client_from_auth_file):
+        mock_sdk_client = Mock()
+
+        auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn()
+
+        mock_get_client_from_auth_file.assert_called_once_with(
+            client_class=mock_sdk_client, auth_path=mock_connection.return_value.extra_dejson['key_path']
+        )
+        assert auth_sdk_client == mock_get_client_from_auth_file.return_value
+
+    @patch('airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_json_dict')
+    @patch(
+        'airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection',
+        return_value=Connection(conn_id='azure_default', extra='{ "key_json": { "test": "test" } }'),
+    )
+    def test_get_conn_with_key_json(self, mock_connection, mock_get_client_from_json_dict):
+        mock_sdk_client = Mock()
+
+        auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn()
+
+        mock_get_client_from_json_dict.assert_called_once_with(
+            client_class=mock_sdk_client, config_dict=mock_connection.return_value.extra_dejson['key_json']
+        )
+        assert auth_sdk_client == mock_get_client_from_json_dict.return_value
+
+    @patch('airflow.providers.microsoft.azure.hooks.base_azure.ServicePrincipalCredentials')
+    @patch(
+        'airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection',
+        return_value=Connection(
+            conn_id='azure_default',
+            login='my_login',
+            password='my_password',
+            extra='{ "tenantId": "my_tenant", "subscriptionId": "my_subscription" }',
+        ),
+    )
+    def test_get_conn_with_credentials(self, mock_connection, mock_spc):
+        mock_sdk_client = Mock()
+
+        auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn()
+
+        mock_spc.assert_called_once_with(
+            client_id=mock_connection.return_value.login,
+            secret=mock_connection.return_value.password,
+            tenant=mock_connection.return_value.extra_dejson['tenantId'],
+        )
+        mock_sdk_client.assert_called_once_with(
+            credentials=mock_spc.return_value,
+            subscription_id=mock_connection.return_value.extra_dejson['subscriptionId'],
+        )
+        assert auth_sdk_client == mock_sdk_client.return_value
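
The three cases above map to the three shapes of the azure_default connection that AzureBaseHook.get_conn() dispatches on. A minimal sketch of the corresponding Connection objects (values are the placeholders used in the tests, not real credentials):

    from airflow.models import Connection

    # 1) Service-principal auth file referenced by path; resolved via get_client_from_auth_file.
    auth_file_conn = Connection(conn_id='azure_default', extra='{"key_path": "key_file.json"}')

    # 2) Inline auth config dict; resolved via get_client_from_json_dict.
    json_dict_conn = Connection(conn_id='azure_default', extra='{"key_json": {"test": "test"}}')

    # 3) Explicit login/password plus tenant and subscription; resolved via ServicePrincipalCredentials.
    credentials_conn = Connection(
        conn_id='azure_default',
        login='my_login',
        password='my_password',
        extra='{"tenantId": "my_tenant", "subscriptionId": "my_subscription"}',
    )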

+ 124 - 0
Azure/test_oracle_to_azure_data_lake.py

@@ -0,0 +1,124 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import unittest
+from tempfile import TemporaryDirectory
+from unittest import mock
+from unittest.mock import MagicMock
+
+import unicodecsv as csv
+
+from airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake import (
+    OracleToAzureDataLakeOperator,
+)
+
+
+class TestOracleToAzureDataLakeTransfer(unittest.TestCase):
+
+    mock_module_path = 'airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake'
+
+    def test_write_temp_file(self):
+        task_id = "some_test_id"
+        sql = "some_sql"
+        sql_params = {':p_data': "2018-01-01"}
+        oracle_conn_id = "oracle_conn_id"
+        filename = "some_filename"
+        azure_data_lake_conn_id = 'azure_data_lake_conn_id'
+        azure_data_lake_path = 'azure_data_lake_path'
+        delimiter = '|'
+        encoding = 'utf-8'
+        cursor_description = [
+            ('id', "<class 'cx_Oracle.NUMBER'>", 39, None, 38, 0, 0),
+            ('description', "<class 'cx_Oracle.STRING'>", 60, 240, None, None, 1),
+        ]
+        cursor_rows = [[1, 'description 1'], [2, 'description 2']]
+        mock_cursor = MagicMock()
+        mock_cursor.description = cursor_description
+        mock_cursor.__iter__.return_value = cursor_rows
+
+        op = OracleToAzureDataLakeOperator(
+            task_id=task_id,
+            filename=filename,
+            oracle_conn_id=oracle_conn_id,
+            sql=sql,
+            sql_params=sql_params,
+            azure_data_lake_conn_id=azure_data_lake_conn_id,
+            azure_data_lake_path=azure_data_lake_path,
+            delimiter=delimiter,
+            encoding=encoding,
+        )
+
+        with TemporaryDirectory(prefix='airflow_oracle_to_azure_op_') as temp:
+            op._write_temp_file(mock_cursor, os.path.join(temp, filename))
+
+            assert os.path.exists(os.path.join(temp, filename)) == 1
+
+            with open(os.path.join(temp, filename), 'rb') as csvfile:
+                temp_file = csv.reader(csvfile, delimiter=delimiter, encoding=encoding)
+
+                rownum = 0
+                for row in temp_file:
+                    if rownum == 0:
+                        assert row[0] == 'id'
+                        assert row[1] == 'description'
+                    else:
+                        assert row[0] == str(cursor_rows[rownum - 1][0])
+                        assert row[1] == cursor_rows[rownum - 1][1]
+                    rownum = rownum + 1
+
+    @mock.patch(mock_module_path + '.OracleHook', autospec=True)
+    @mock.patch(mock_module_path + '.AzureDataLakeHook', autospec=True)
+    def test_execute(self, mock_data_lake_hook, mock_oracle_hook):
+        task_id = "some_test_id"
+        sql = "some_sql"
+        sql_params = {':p_data': "2018-01-01"}
+        oracle_conn_id = "oracle_conn_id"
+        filename = "some_filename"
+        azure_data_lake_conn_id = 'azure_data_lake_conn_id'
+        azure_data_lake_path = 'azure_data_lake_path'
+        delimiter = '|'
+        encoding = 'latin-1'
+        cursor_description = [
+            ('id', "<class 'cx_Oracle.NUMBER'>", 39, None, 38, 0, 0),
+            ('description', "<class 'cx_Oracle.STRING'>", 60, 240, None, None, 1),
+        ]
+        cursor_rows = [[1, 'description 1'], [2, 'description 2']]
+        cursor_mock = MagicMock()
+        cursor_mock.description.return_value = cursor_description
+        cursor_mock.__iter__.return_value = cursor_rows
+        mock_oracle_conn = MagicMock()
+        mock_oracle_conn.cursor().return_value = cursor_mock
+        mock_oracle_hook.get_conn().return_value = mock_oracle_conn
+
+        op = OracleToAzureDataLakeOperator(
+            task_id=task_id,
+            filename=filename,
+            oracle_conn_id=oracle_conn_id,
+            sql=sql,
+            sql_params=sql_params,
+            azure_data_lake_conn_id=azure_data_lake_conn_id,
+            azure_data_lake_path=azure_data_lake_path,
+            delimiter=delimiter,
+            encoding=encoding,
+        )
+
+        op.execute(None)
+
+        mock_oracle_hook.assert_called_once_with(oracle_conn_id=oracle_conn_id)
+        mock_data_lake_hook.assert_called_once_with(azure_data_lake_conn_id=azure_data_lake_conn_id)
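
For reference, given the mocked cursor above, the pipe-delimited temp file that _write_temp_file is expected to produce looks like:

    id|description
    1|description 1
    2|description 2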

+ 469 - 0
Azure/validate_azure_dladmin_identity.py

@@ -0,0 +1,469 @@
+#!/usr/bin/env python3
+###
+# CLOUDERA CDP Control (cdpctl)
+#
+# (C) Cloudera, Inc. 2021-2021
+# All rights reserved.
+#
+# Applicable Open Source License: GNU AFFERO GENERAL PUBLIC LICENSE
+#
+# NOTE: Cloudera open source products are modular software products
+# made up of hundreds of individual components, each of which was
+# individually copyrighted.  Each Cloudera open source product is a
+# collective work under U.S. Copyright Law. Your license to use the
+# collective work is as provided in your written agreement with
+# Cloudera.  Used apart from the collective work, this file is
+# licensed for your use pursuant to the open source license
+# identified above.
+#
+# This code is provided to you pursuant a written agreement with
+# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
+# this code. If you do not have a written agreement with Cloudera nor
+# with an authorized and properly licensed third party, you do not
+# have any rights to access nor to use this code.
+#
+# Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
+# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
+# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
+# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
+# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
+# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
+# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
+# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
+# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
+# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
+# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
+# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
+# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
+# DATA.
+#
+# Source File Name:  validate_azure_dladmin_identity.py
+###
+"""Validation of Azure Datalake Admin Identity."""
+from typing import Any, Dict, List
+
+import pytest
+from azure.mgmt.authorization import AuthorizationManagementClient
+from azure.mgmt.resource import ResourceManagementClient
+
+from cdpctl.validation import fail, get_config_value
+from cdpctl.validation.azure_utils import (
+    check_for_actions,
+    get_client,
+    get_role_assignments,
+    get_storage_container_scope,
+    parse_adls_path,
+)
+from cdpctl.validation.infra.issues import (
+    AZURE_IDENTITY_MISSING_ACTIONS_FOR_LOCATION,
+    AZURE_IDENTITY_MISSING_DATA_ACTIONS_FOR_LOCATION,
+)
+
+
+@pytest.fixture(autouse=True, name="resource_client")
+def resource_client_fixture(config: Dict[str, Any]) -> ResourceManagementClient:
+    """Return an Azure Resource Client."""
+    return get_client("resource", config)
+
+
+@pytest.fixture(autouse=True, name="auth_client")
+def auth_client_fixture(config: Dict[str, Any]) -> AuthorizationManagementClient:
+    """Return an Azure Auth Client."""
+    return get_client("auth", config)
+
+
+@pytest.mark.azure
+@pytest.mark.infra
+def azure_dladmin_actions_for_logs_storage_validation(
+    config: Dict[str, Any],
+    auth_client: AuthorizationManagementClient,
+    resource_client: ResourceManagementClient,
+    azure_data_required_actions,
+) -> None:  # pragma: no cover
+    """Datalake Admin Identity has required Actions on logs storage location."""  # noqa: D401,E501
+    _azure_dladmin_logs_storage_actions_check(
+        config=config,
+        auth_client=auth_client,
+        resource_client=resource_client,
+        azure_data_required_actions=azure_data_required_actions,
+    )
+
+
+def _azure_dladmin_logs_storage_actions_check(
+    config: Dict[str, Any],
+    auth_client: AuthorizationManagementClient,
+    resource_client: ResourceManagementClient,
+    azure_data_required_actions: List[str],
+) -> None:  # pragma: no cover
+    # noqa: D401,E501
+    sub_id: str = get_config_value(config=config, key="infra:azure:subscription_id")
+    rg_name: str = get_config_value(config=config, key="infra:azure:metagroup:name")
+    storage_name: str = get_config_value(config=config, key="env:azure:storage:name")
+    log_path: str = get_config_value(config=config, key="env:azure:storage:path:logs")
+    datalake_admin: str = get_config_value(
+        config=config, key="env:azure:role:name:datalake_admin"
+    )
+
+    parsed_logger_path = parse_adls_path(log_path)
+    container_name = parsed_logger_path[1]
+
+    role_assignments = get_role_assignments(
+        auth_client=auth_client,
+        resource_client=resource_client,
+        identity_name=datalake_admin,
+        subscription_id=sub_id,
+        resource_group=rg_name,
+    )
+
+    proper_scope = get_storage_container_scope(
+        sub_id, rg_name, storage_name, container_name
+    )
+
+    missing_actions, _ = check_for_actions(
+        auth_client=auth_client,
+        role_assigments=role_assignments,
+        proper_scope=proper_scope,
+        required_actions=azure_data_required_actions,
+        required_data_actions=[],
+    )
+
+    if missing_actions:
+        fail(
+            AZURE_IDENTITY_MISSING_ACTIONS_FOR_LOCATION,
+            subjects=[
+                datalake_admin,
+                f"storageAccounts/{storage_name}/blobServices/default/containers/{container_name}",  # noqa: E501
+            ],
+            resources=missing_actions,
+        )
+
+
+@pytest.mark.azure
+@pytest.mark.infra
+def azure_dladmin_data_actions_for_logs_storage_validation(
+    config: Dict[str, Any],
+    auth_client: AuthorizationManagementClient,
+    resource_client: ResourceManagementClient,
+    azure_data_required_data_actions,
+) -> None:  # pragma: no cover
+    """Datalake Admin Identity has required Data Actions on logs storage location."""  # noqa: D401,E501
+    _azure_dladmin_logs_storage_data_actions_check(
+        config=config,
+        auth_client=auth_client,
+        resource_client=resource_client,
+        azure_data_required_data_actions=azure_data_required_data_actions,
+    )
+
+
+def _azure_dladmin_logs_storage_data_actions_check(
+    config: Dict[str, Any],
+    auth_client: AuthorizationManagementClient,
+    resource_client: ResourceManagementClient,
+    azure_data_required_data_actions: List[str],
+) -> None:  # pragma: no cover
+    # noqa: D401,E501
+    sub_id: str = get_config_value(config=config, key="infra:azure:subscription_id")
+    rg_name: str = get_config_value(config=config, key="infra:azure:metagroup:name")
+    storage_name: str = get_config_value(config=config, key="env:azure:storage:name")
+    log_path: str = get_config_value(config=config, key="env:azure:storage:path:logs")
+    datalake_admin: str = get_config_value(
+        config=config, key="env:azure:role:name:datalake_admin"
+    )
+
+    parsed_logger_path = parse_adls_path(log_path)
+    container_name = parsed_logger_path[1]
+
+    role_assignments = get_role_assignments(
+        auth_client=auth_client,
+        resource_client=resource_client,
+        identity_name=datalake_admin,
+        subscription_id=sub_id,
+        resource_group=rg_name,
+    )
+
+    proper_scope = get_storage_container_scope(
+        sub_id, rg_name, storage_name, container_name
+    )
+
+    _, missing_data_actions = check_for_actions(
+        auth_client=auth_client,
+        role_assigments=role_assignments,
+        proper_scope=proper_scope,
+        required_actions=[],
+        required_data_actions=azure_data_required_data_actions,
+    )
+    if missing_data_actions:
+        fail(
+            AZURE_IDENTITY_MISSING_DATA_ACTIONS_FOR_LOCATION,
+            subjects=[
+                datalake_admin,
+                f"storageAccounts/{storage_name}/blobServices/default/containers/{container_name}",  # noqa: E501
+            ],
+            resources=missing_data_actions,
+        )
+
+
+@pytest.mark.azure
+@pytest.mark.infra
+def azure_dladmin_actions_for_data_storage_validation(
+    config: Dict[str, Any],
+    auth_client: AuthorizationManagementClient,
+    resource_client: ResourceManagementClient,
+    azure_data_required_actions,
+) -> None:  # pragma: no cover
+    """Datalake Admin Identity has required Actions on data storage location."""  # noqa: D401,E501
+    _azure_dladmin_data_storage_actions_check(
+        config=config,
+        auth_client=auth_client,
+        resource_client=resource_client,
+        azure_data_required_actions=azure_data_required_actions,
+    )
+
+
+def _azure_dladmin_data_storage_actions_check(
+    config: Dict[str, Any],
+    auth_client: AuthorizationManagementClient,
+    resource_client: ResourceManagementClient,
+    azure_data_required_actions: List[str],
+) -> None:  # pragma: no cover
+    # noqa: D401,E501
+    sub_id: str = get_config_value(config=config, key="infra:azure:subscription_id")
+    rg_name: str = get_config_value(config=config, key="infra:azure:metagroup:name")
+    storage_name: str = get_config_value(config=config, key="env:azure:storage:name")
+    data_path: str = get_config_value(config=config, key="env:azure:storage:path:data")
+    datalake_admin: str = get_config_value(
+        config=config, key="env:azure:role:name:datalake_admin"
+    )
+
+    parsed_data_path = parse_adls_path(data_path)
+    container_name = parsed_data_path[1]
+
+    role_assignments = get_role_assignments(
+        auth_client=auth_client,
+        resource_client=resource_client,
+        identity_name=datalake_admin,
+        subscription_id=sub_id,
+        resource_group=rg_name,
+    )
+
+    proper_scope = get_storage_container_scope(
+        sub_id, rg_name, storage_name, container_name
+    )
+
+    missing_actions, _ = check_for_actions(
+        auth_client=auth_client,
+        role_assigments=role_assignments,
+        proper_scope=proper_scope,
+        required_actions=azure_data_required_actions,
+        required_data_actions=[],
+    )
+
+    if missing_actions:
+        fail(
+            AZURE_IDENTITY_MISSING_ACTIONS_FOR_LOCATION,
+            subjects=[
+                datalake_admin,
+                f"storageAccounts/{storage_name}/blobServices/default/containers/{container_name}",  # noqa: E501
+            ],
+            resources=missing_actions,
+        )
+
+
+@pytest.mark.azure
+@pytest.mark.infra
+def azure_dladmin_data_actions_for_data_storage_validation(
+    config: Dict[str, Any],
+    auth_client: AuthorizationManagementClient,
+    resource_client: ResourceManagementClient,
+    azure_data_required_data_actions,
+) -> None:  # pragma: no cover
+    """Datalake Admin Identity has required Data Actions on data storage location."""  # noqa: D401,E501
+    _azure_dladmin_data_storage_data_actions_check(
+        config=config,
+        auth_client=auth_client,
+        resource_client=resource_client,
+        azure_data_required_data_actions=azure_data_required_data_actions,
+    )
+
+
+def _azure_dladmin_data_storage_data_actions_check(
+    config: Dict[str, Any],
+    auth_client: AuthorizationManagementClient,
+    resource_client: ResourceManagementClient,
+    azure_data_required_data_actions: List[str],
+) -> None:  # pragma: no cover
+    # noqa: D401,E501
+    sub_id: str = get_config_value(config=config, key="infra:azure:subscription_id")
+    rg_name: str = get_config_value(config=config, key="infra:azure:metagroup:name")
+    storage_name: str = get_config_value(config=config, key="env:azure:storage:name")
+    data_path: str = get_config_value(config=config, key="env:azure:storage:path:data")
+    datalake_admin: str = get_config_value(
+        config=config, key="env:azure:role:name:datalake_admin"
+    )
+
+    parsed_data_path = parse_adls_path(data_path)
+    container_name = parsed_data_path[1]
+
+    role_assignments = get_role_assignments(
+        auth_client=auth_client,
+        resource_client=resource_client,
+        identity_name=datalake_admin,
+        subscription_id=sub_id,
+        resource_group=rg_name,
+    )
+
+    proper_scope = get_storage_container_scope(
+        sub_id, rg_name, storage_name, container_name
+    )
+
+    _, missing_data_actions = check_for_actions(
+        auth_client=auth_client,
+        role_assigments=role_assignments,
+        proper_scope=proper_scope,
+        required_actions=[],
+        required_data_actions=azure_data_required_data_actions,
+    )
+    if missing_data_actions:
+        fail(
+            AZURE_IDENTITY_MISSING_DATA_ACTIONS_FOR_LOCATION,
+            subjects=[
+                datalake_admin,
+                f"storageAccounts/{storage_name}/blobServices/default/containers/{container_name}",  # noqa: E501
+            ],
+            resources=missing_data_actions,
+        )
+
+
+@pytest.mark.azure
+@pytest.mark.infra
+def azure_dladmin_actions_for_backup_storage_validation(
+    config: Dict[str, Any],
+    auth_client: AuthorizationManagementClient,
+    resource_client: ResourceManagementClient,
+    azure_data_required_actions,
+) -> None:  # pragma: no cover
+    """Datalake Admin Identity has required Actions on backup storage location."""  # noqa: D401,E501
+    _azure_dladmin_backup_storage_actions_check(
+        config=config,
+        auth_client=auth_client,
+        resource_client=resource_client,
+        azure_data_required_actions=azure_data_required_actions,
+    )
+
+
+def _azure_dladmin_backup_storage_actions_check(
+    config: Dict[str, Any],
+    auth_client: AuthorizationManagementClient,
+    resource_client: ResourceManagementClient,
+    azure_data_required_actions: List[str],
+) -> None:  # pragma: no cover
+    # noqa: D401,E501
+    sub_id: str = get_config_value(config=config, key="infra:azure:subscription_id")
+    rg_name: str = get_config_value(config=config, key="infra:azure:metagroup:name")
+    storage_name: str = get_config_value(config=config, key="env:azure:storage:name")
+    backup_path: str = get_config_value(
+        config=config, key="env:azure:storage:path:backup"
+    )
+    datalake_admin: str = get_config_value(
+        config=config, key="env:azure:role:name:datalake_admin"
+    )
+
+    parsed_logger_path = parse_adls_path(backup_path)
+    container_name = parsed_logger_path[1]
+
+    role_assignments = get_role_assignments(
+        auth_client=auth_client,
+        resource_client=resource_client,
+        identity_name=datalake_admin,
+        subscription_id=sub_id,
+        resource_group=rg_name,
+    )
+
+    proper_scope = get_storage_container_scope(
+        sub_id, rg_name, storage_name, container_name
+    )
+
+    missing_actions, _ = check_for_actions(
+        auth_client=auth_client,
+        role_assigments=role_assignments,
+        proper_scope=proper_scope,
+        required_actions=azure_data_required_actions,
+        required_data_actions=[],
+    )
+
+    if missing_actions:
+        fail(
+            AZURE_IDENTITY_MISSING_ACTIONS_FOR_LOCATION,
+            subjects=[
+                datalake_admin,
+                f"storageAccounts/{storage_name}/blobServices/default/containers/{container_name}",  # noqa: E501
+            ],
+            resources=missing_actions,
+        )
+
+
+@pytest.mark.azure
+@pytest.mark.infra
+def azure_dladmin_data_actions_for_backup_storage_validation(
+    config: Dict[str, Any],
+    auth_client: AuthorizationManagementClient,
+    resource_client: ResourceManagementClient,
+    azure_data_required_data_actions,
+) -> None:  # pragma: no cover
+    """Datalake Admin Identity has required Data Actions on backup storage location."""  # noqa: D401,E501
+    _azure_dladmin_backup_storage_data_actions_check(
+        config=config,
+        auth_client=auth_client,
+        resource_client=resource_client,
+        azure_data_required_data_actions=azure_data_required_data_actions,
+    )
+
+
+def _azure_dladmin_backup_storage_data_actions_check(
+    config: Dict[str, Any],
+    auth_client: AuthorizationManagementClient,
+    resource_client: ResourceManagementClient,
+    azure_data_required_data_actions: List[str],
+) -> None:  # pragma: no cover
+    # noqa: D401,E501
+    sub_id: str = get_config_value(config=config, key="infra:azure:subscription_id")
+    rg_name: str = get_config_value(config=config, key="infra:azure:metagroup:name")
+    storage_name: str = get_config_value(config=config, key="env:azure:storage:name")
+    backup_path: str = get_config_value(
+        config=config, key="env:azure:storage:path:backup"
+    )
+    datalake_admin: str = get_config_value(
+        config=config, key="env:azure:role:name:datalake_admin"
+    )
+
+    parsed_logger_path = parse_adls_path(backup_path)
+    container_name = parsed_logger_path[1]
+
+    role_assignments = get_role_assignments(
+        auth_client=auth_client,
+        resource_client=resource_client,
+        identity_name=datalake_admin,
+        subscription_id=sub_id,
+        resource_group=rg_name,
+    )
+
+    proper_scope = get_storage_container_scope(
+        sub_id, rg_name, storage_name, container_name
+    )
+
+    _, missing_data_actions = check_for_actions(
+        auth_client=auth_client,
+        role_assigments=role_assignments,
+        proper_scope=proper_scope,
+        required_actions=[],
+        required_data_actions=azure_data_required_data_actions,
+    )
+    if missing_data_actions:
+        fail(
+            AZURE_IDENTITY_MISSING_DATA_ACTIONS_FOR_LOCATION,
+            subjects=[
+                datalake_admin,
+                f"storageAccounts/{storage_name}/blobServices/default/containers/{container_name}",  # noqa: E501
+            ],
+            resources=missing_data_actions,
+        )
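
The six storage checks above repeat one flow: read the subscription, resource group, storage account, path and identity from config, resolve the container scope, and diff the identity's role assignments against the required (Data)Actions. A possible shared helper, shown only as an illustrative sketch that reuses the utilities already imported in this file, could look like:

    def _check_dladmin_storage_location(
        config, auth_client, resource_client, path_key,
        required_actions, required_data_actions,
    ):
        """Return (identity, container location, missing actions, missing data actions)."""
        sub_id = get_config_value(config=config, key="infra:azure:subscription_id")
        rg_name = get_config_value(config=config, key="infra:azure:metagroup:name")
        storage_name = get_config_value(config=config, key="env:azure:storage:name")
        path = get_config_value(config=config, key=path_key)
        datalake_admin = get_config_value(
            config=config, key="env:azure:role:name:datalake_admin"
        )

        container_name = parse_adls_path(path)[1]
        role_assignments = get_role_assignments(
            auth_client=auth_client,
            resource_client=resource_client,
            identity_name=datalake_admin,
            subscription_id=sub_id,
            resource_group=rg_name,
        )
        scope = get_storage_container_scope(
            sub_id, rg_name, storage_name, container_name
        )
        missing_actions, missing_data_actions = check_for_actions(
            auth_client=auth_client,
            role_assigments=role_assignments,
            proper_scope=scope,
            required_actions=required_actions,
            required_data_actions=required_data_actions,
        )
        location = (
            f"storageAccounts/{storage_name}/blobServices/default"
            f"/containers/{container_name}"
        )
        return datalake_admin, location, missing_actions, missing_data_actions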

+ 118 - 0
File/outbuf.py

@@ -0,0 +1,118 @@
+#!/usr/bin/python -u
+import sys
+import libxml2
+import StringIO
+
+
+def testSimpleBufferWrites():
+    f = StringIO.StringIO()
+    buf = libxml2.createOutputBuffer(f, "ISO-8859-1")
+    buf.write(3, "foo")
+    buf.writeString("bar")
+    buf.close()
+
+    if f.getvalue() != "foobar":
+        print("Failed to save to StringIO")
+        sys.exit(1)
+
+
+def testSaveDocToBuffer():
+    """
+    Regression test for bug #154294.
+    """
+    input = '<foo>Hello</foo>'
+    expected = '''\
+<?xml version="1.0" encoding="UTF-8"?>
+<foo>Hello</foo>
+'''
+    f = StringIO.StringIO()
+    buf = libxml2.createOutputBuffer(f, 'UTF-8')
+    doc = libxml2.parseDoc(input)
+    doc.saveFileTo(buf, 'UTF-8')
+    if f.getvalue() != expected:
+        print('xmlDoc.saveFileTo() call failed.')
+        print('     got: %s' % repr(f.getvalue()))
+        print('expected: %s' % repr(expected))
+        sys.exit(1)
+    doc.freeDoc()
+
+
+
+def testSaveFormattedDocToBuffer():
+    input = '<outer><inner>Some text</inner><inner/></outer>'
+    # The formatted and non-formatted versions of the output.
+    expected = ('''\
+<?xml version="1.0" encoding="UTF-8"?>
+<outer><inner>Some text</inner><inner/></outer>
+''', '''\
+<?xml version="1.0" encoding="UTF-8"?>
+<outer>
+  <inner>Some text</inner>
+  <inner/>
+</outer>
+''')
+    doc = libxml2.parseDoc(input)
+    for i in (0, 1):
+        f = StringIO.StringIO()
+        buf = libxml2.createOutputBuffer(f, 'UTF-8')
+        doc.saveFormatFileTo(buf, 'UTF-8', i)
+        if f.getvalue() != expected[i]:
+            print('xmlDoc.saveFormatFileTo() call failed.')
+            print('     got: %s' % repr(f.getvalue()))
+            print('expected: %s' % repr(expected[i]))
+            sys.exit(1)
+    doc.freeDoc()
+
+
+def testSaveIntoOutputBuffer():
+    """
+    Similar to the previous two tests, except this time we invoke the save
+    methods on the output buffer object and pass in an XML node object.
+    """
+    input = '<foo>Hello</foo>'
+    expected = '''\
+<?xml version="1.0" encoding="UTF-8"?>
+<foo>Hello</foo>
+'''
+    f = StringIO.StringIO()
+    doc = libxml2.parseDoc(input)
+    buf = libxml2.createOutputBuffer(f, 'UTF-8')
+    buf.saveFileTo(doc, 'UTF-8')
+    if f.getvalue() != expected:
+        print('outputBuffer.saveFileTo() call failed.')
+        print('     got: %s' % repr(f.getvalue()))
+        print('expected: %s' % repr(expected))
+        sys.exit(1)
+    f = StringIO.StringIO()
+    buf = libxml2.createOutputBuffer(f, 'UTF-8')
+    buf.saveFormatFileTo(doc, 'UTF-8', 1)
+    if f.getvalue() != expected:
+        print('outputBuffer.saveFormatFileTo() call failed.')
+        print('     got: %s' % repr(f.getvalue()))
+        print('expected: %s' % repr(expected))
+        sys.exit(1)
+    doc.freeDoc()
+
+
+if __name__ == '__main__':
+    # Memory debug specific
+    libxml2.debugMemory(1)
+
+    testSimpleBufferWrites()
+    testSaveDocToBuffer()
+    testSaveFormattedDocToBuffer()
+    testSaveIntoOutputBuffer()
+
+    libxml2.cleanupParser()
+    if libxml2.debugMemory(1) == 0:
+        print("OK")
+    else:
+        print("Memory leak %d bytes" % (libxml2.debugMemory(1)))
+        libxml2.dumpMemory()
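
This test script targets Python 2 (StringIO module, Python 2 libxml2 bindings). Under Python 3 the first check would presumably be written against a bytes buffer; a hedged sketch, assuming the Python 3 bindings accept a binary file object:

    import io
    import libxml2

    f = io.BytesIO()
    buf = libxml2.createOutputBuffer(f, "ISO-8859-1")
    buf.write(3, "foo")
    buf.writeString("bar")
    buf.close()
    assert f.getvalue() == b"foobar"  # output arrives as bytes under Python 3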

+ 891 - 1
File/utils.py

@@ -31,4 +31,894 @@ def save_task_checkpoint(file_path, task_num):
         task_num (int): Number of task increment.
     """
     save_path = os.path.join(file_path, 'checkpoint_task_' + str(task_num) + '.pth.tar')
-    shutil.copyfile(os.path.join(file_path, 'checkpoint.pth.tar'), save_path)
+    shutil.copyfile(os.path.join(file_path, 'checkpoint.pth.tar'), save_path)
+
+
+def pickle_dump(item, out_file):
+    with open(out_file, "wb") as opened_file:
+        pickle.dump(item, opened_file)
+
+
+def write_to_clf(clf_data, save_file):
+    # Save dataset for text classification to file.
+    """
+    clf_data: List[List[str]] [[text1, label1],[text2,label2]...]
+    file format: tsv, row: text + tab + label
+    """
+    with open(save_file, 'w', encoding='utf-8') as f:
+        f.writelines("\n".join(["\t".join(str(r) for r in row) for row in clf_data]))
+
+
+def write_to_seq2seq(seq_data, save_file):
+    """
+    clf_data: List[List[str]] [[src1, tgt1],[src2,tgt2]...]
+    file format: tsv, row: src + tab + tgt
+    """
+    with open(save_file, 'w', encoding='utf-8') as f:
+        f.writelines("\n".join(["\t".join([str(r) for r in row]) for row in seq_data]))
+
+
+def write_to_ner(cls, ner_data, save_file):
+    """
+    :param cls:
+    :param ner_data:
+    :param save_file:
+    :return:
+    """
+    with open(save_file, 'w', encoding='utf-8') as f:
+        f.writelines("\n".join(["\t".join(str(r) for r in row) for row in ner_data]))
+
+
+def quick_save(self, model, save_name, optimizer=None):
+    save_path = os.path.join(self.save_dir, save_name + '_weights.pth')
+    if optimizer:
+        opt_weights = optimizer.get_weights()
+        np.save(os.path.join(self.save_dir, save_name + '_opt_weights'), opt_weights)
+    model.save_weights(save_path, save_format='h5')
+
+
+def save(self, model, iter_nb, train_metrics_values, test_metrics_values, tasks_weights=[], optimizer=None):
+    self.logs_dict['train'][str(iter_nb)] = {}
+    self.logs_dict['val'][str(iter_nb)] = {}
+    for k in range(len(self.metrics)):
+        self.logs_dict['train'][str(iter_nb)][self.metrics[k]] = float(train_metrics_values[k])
+        self.logs_dict['val'][str(iter_nb)][self.metrics[k]] = float(test_metrics_values[k])
+
+    if len(tasks_weights) > 0:
+        for k in range(len(tasks_weights)):
+            self.logs_dict['val'][str(iter_nb)]['weight_' + str(k)] = tasks_weights[k]
+
+    with open(self.logs_file, 'w') as f:
+        json.dump(self.logs_dict, f)
+
+    ckpt = {
+        'model_state_dict': model.state_dict(),
+        'iter_nb': iter_nb,
+    }
+    if optimizer:
+        ckpt['optimizer_state_dict'] = optimizer.state_dict()
+
+    # Saves best miou score if reached
+    if 'MEAN_IOU' in self.metrics:
+        miou = float(test_metrics_values[self.metrics.index('MEAN_IOU')])
+        if miou > self.best_miou and iter_nb > 0:
+            print('Best miou. Saving it.')
+            torch.save(ckpt, self.best_miou_weights_file)
+            self.best_miou = miou
+            self.config_dict['best_miou'] = self.best_miou
+    # Saves best relative error if reached
+    if 'REL_ERR' in self.metrics:
+        rel_error = float(test_metrics_values[self.metrics.index('REL_ERR')])
+        if rel_error < self.best_rel_error and iter_nb > 0:
+            print('Best rel error. Saving it.')
+            torch.save(ckpt, self.best_rel_error_weights_file)
+            self.best_rel_error = rel_error
+            self.config_dict['best_rel_error'] = self.best_rel_error
+
+    # Saves last checkpoint
+    torch.save(ckpt, self.last_checkpoint_weights_file)
+    self.iter_nb = iter_nb
+    self.config_dict['iter'] = self.iter_nb
+    with open(self.config_file, 'w') as f:
+        json.dump(self.config_dict, f)
+
+
+def extract_spec(dataset='train'):
+    f = open(data_path + dataset + '_list.txt', 'r')
+
+    i = 0
+    for file_name in f:
+        i = i + 1
+        if not (i % 10):
+            print(i)
+
+        # load audio file
+        file_name = file_name.rstrip('\n')
+        file_path = data_path + file_name
+        # print file_path
+        y0, sr = librosa.load(file_path, sr=22050)
+        # use the first quarter of the signal
+        half = len(y0) / 4
+        y = y0[:round(half)]
+        # mfcc
+        mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=MFCC_DIM)
+        # delta mfcc and double delta
+        delta_mfcc = librosa.feature.delta(mfcc)
+        ddelta_mfcc = librosa.feature.delta(mfcc, order=2)
+
+        # STFT
+        D = np.abs(librosa.core.stft(y, hop_length=512, n_fft=1024, win_length=1024))
+        D_dB = librosa.amplitude_to_db(D, ref=np.max)
+
+        # mel spectrogram
+        mel_S = librosa.feature.melspectrogram(S=D, sr=sr, n_mels=128)
+        S_dB = librosa.power_to_db(mel_S, ref=np.max)  # log compression
+
+        # spectral centroid
+        spec_centroid = librosa.feature.spectral_centroid(S=D)
+
+        # concatenate all features
+        features = np.concatenate([mfcc, delta_mfcc, ddelta_mfcc, spec_centroid], axis=0)
+
+        # save mfcc as a file
+        file_name = file_name.replace('.wav', '.npy')
+        save_file = spec_path + file_name
+
+        if not os.path.exists(os.path.dirname(save_file)):
+            os.makedirs(os.path.dirname(save_file))
+        np.save(save_file, features)
+
+    f.close()
+
+
+def extract_codebook(dataset='train'):
+    f = open(data_path + dataset + '_list.txt', 'r')
+    i = 0
+    for file_name in f:
+        i = i + 1
+        if not (i % 10):
+            print(i)
+        # load audio file
+        file_name = file_name.rstrip('\n')
+        file_path = data_path + file_name
+        # #print file_path
+        y0, sr = librosa.load(file_path, sr=22050)
+        # use the first quarter of the signal
+        half = len(y0) / 4
+        y = y0[:round(half)]
+        # STFT
+        S_full, phase = librosa.magphase(librosa.stft(y, n_fft=1024, window='hann', hop_length=256, win_length=1024))
+        n = len(y)
+
+        # Check the shape of matrix: row must corresponds to the example index !!!
+        X = S_full.T
+
+        # codebook by using K-Means Clustering
+        K = 20
+        kmeans = KMeans(n_clusters=K, random_state=0).fit(X)
+        features_kmeans = np.zeros(X.shape[0])
+        # for each sample, summarize feature!!!
+        codebook = np.zeros(K)
+        for sample in range(X.shape[0]):
+            features_kmeans[sample] = kmeans.labels_[sample]
+
+        # codebook histogram!
+        unique, counts = np.unique(features_kmeans, return_counts=True)
+
+        for u in unique:
+            u = int(u)
+            codebook[u] = counts[u]
+        # save mfcc as a file
+        file_name = file_name.replace('.wav', '.npy')
+        save_file = codebook_path + file_name
+
+        if not os.path.exists(os.path.dirname(save_file)):
+            os.makedirs(os.path.dirname(save_file))
+        np.save(save_file, codebook)
+
+    f.close()
+
+
+def run(self):
+    file = QtCore.QFile(self.filePath)
+    if not file.open(QtCore.QIODevice.WriteOnly):
+        self.saveFileFinished.emit(SAVE_FILE_ERROR, self.urlStr, self.filePath)
+    file.write(self.fileData)
+    file.close()
+    self.saveFileFinished.emit(0, self.urlStr, self.filePath)
+
+
+def saveFile(self, fileName, data):
+    file = QtCore.QFile(fileName)
+    if not file.open(QtCore.QIODevice.WriteOnly):
+        return False
+    file.write(data.readAll())
+    file.close()
+    return True
+
+
+def serialize(self):
+    """Callback to serialize the array."""
+    string_file = io.BytesIO()
+    try:
+        numpy.save(string_file, self.array, allow_pickle=False)
+        serialized = string_file.getvalue()
+    finally:
+        string_file.close()
+    return serialized
+
+
+def train(self, save=False, save_dir=None):
+    train_img_list = glob.glob(self.path_train + "/*")
+    print(train_img_list)
+
+    train_features = []
+
+    for img_file in train_img_list:
+        img = io.imread(img_file)
+        img = color.rgb2lab(img)
+        img_features = self.extract_texton_feature(img, self.fb, self.nb_features)
+        train_features.extend(img_features)
+
+    train_features = np.array(train_features)
+    print(train_features.shape)
+
+    kmeans_cluster = MiniBatchKMeans(n_clusters=self.nb_clusters, verbose=1, max_iter=300)
+    kmeans_cluster.fit(train_features)
+    print(kmeans_cluster.cluster_centers_)
+    print(kmeans_cluster.cluster_centers_.shape)
+
+    self.cluster = kmeans_cluster
+
+    # save kmeans result
+    if save is True:
+        with open(save_dir, 'wb') as f:
+            pickle.dump(self.cluster, f)
+
+def save(self, event):
+    if not self.filename:
+        self.save_as(event)
+    else:
+        if self.writefile(self.filename):
+            self.set_saved(True)
+            try:
+                self.editwin.store_file_breaks()
+            except AttributeError:  # may be a PyShell
+                pass
+    self.text.focus_set()
+    return "break"
+
+
+def writefile(self, filename):
+    self.fixlastline()
+    chars = self.encode(self.text.get("1.0", "end-1c"))
+    if self.eol_convention != "\n":
+        chars = chars.replace("\n", self.eol_convention)
+    try:
+        f = open(filename, "wb")
+        f.write(chars)
+        f.flush()
+        f.close()
+        return True
+    except IOError as msg:
+        tkMessageBox.showerror("I/O Error", str(msg),
+                               master=self.text)
+        return False
+
+
+def save_response_content(response,
+                          destination,
+                          file_size=None,
+                          chunk_size=32768):
+    if file_size is not None:
+        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
+
+        readable_file_size = sizeof_fmt(file_size)
+    else:
+        pbar = None
+
+    with open(destination, 'wb') as f:
+        downloaded_size = 0
+        for chunk in response.iter_content(chunk_size):
+            downloaded_size += len(chunk)
+            if pbar is not None:
+                pbar.update(1)
+                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
+                                     f'/ {readable_file_size}')
+            if chunk:  # filter out keep-alive new chunks
+                f.write(chunk)
+        if pbar is not None:
+            pbar.close()
+
+
+def generateHuman(cloth_list, person_id, sex):
+    haveAcc = 0
+    # load acc
+    hair = open('modeleTxt/hair.txt', 'r').readlines()
+    shoe = open('modeleTxt/shoe.txt', 'r').readlines()
+    pifu = open('modeleTxt/skin.txt', 'r').readlines()
+
+    if not os.path.exists(person_save_Folder):
+        os.makedirs(person_save_Folder)
+
+    if sex > 0:
+        Gender1 = 1000000
+    else:
+        Gender1 = 0
+    #     setting
+    Gender = '%.6f' % (Gender1 / 1000000)
+    Muscle = '%.6f' % (random.randint(0, 1000000) / 1000000)
+    African_1 = random.randint(0, 1000000)
+    African = '%.6f' % (African_1 / 1000000)
+    Asian_1 = random.randint(0, 1000000 - African_1)
+    Asian = '%.6f' % (Asian_1 / 1000000)
+    Caucasian = '%.6f' % ((1000000 - Asian_1 - African_1) / 1000000)
+    if Gender1 > 1000000 / 2:
+        m_height = random.gauss(170, 5.7) / 200
+        while m_height > 1:
+            m_height = random.gauss(170, 5.7) / 200
+        Height = '%.6f' % (m_height)
+    else:
+        m_height = random.gauss(160, 5.2) / 200
+        while m_height > 1:
+            m_height = random.gauss(160, 5.2) / 200
+        Height = '%.6f' % (m_height)
+    BreastSize = '%.6f' % (random.randint(0, 70) / 100)
+    Age = '%.6f' % (random.randint(20, 90) / 100)
+    BreastFirmness = '%.6f' % (random.randint(30, 100) / 100)
+    Weight = '%.6f' % (random.randint(0, 1000000) / 1000000)
+
+    file_name = 'B' + str(person_id)
+    # creating person file
+    f = open(person_save_Folder + file_name + ".mhm", 'a')
+    f.write('# Written by MakeHuman 1.1.1\n')
+    f.write('version v1.1.1\n')
+    f.write('tags ' + file_name + '\n')
+    f.write('camera 0.0 0.0 0.0 0.0 0.0 1.0\n')
+    f.write('modifier macrodetails-universal/Muscle ' + Muscle + '\n')
+    f.write('modifier macrodetails/African ' + African + '\n')
+    f.write('modifier macrodetails-proportions/BodyProportions 0.500000\n')
+    f.write('modifier macrodetails/Gender ' + Gender + '\n')
+    f.write('modifier macrodetails-height/Height ' + Height + '\n')
+    f.write('modifier breast/BreastSize ' + BreastSize + '\n')
+    f.write('modifier macrodetails/Age ' + Age + '\n')
+    f.write('modifier breast/BreastFirmness ' + BreastFirmness + '\n')
+    f.write('modifier macrodetails/Asian ' + Asian + '\n')
+    f.write('modifier macrodetails/Caucasian ' + Caucasian + '\n')
+    f.write('modifier macrodetails-universal/Weight ' + Weight + '\n')
+    f.write('skeleton cmu_mb.mhskel\n')
+    f.write('eyes HighPolyEyes 2c12f43b-1303-432c-b7ce-d78346baf2e6\n')
+
+    # adding clothes (the same hair pool is used for both genders here)
+    f.write(hair[random.randint(0, len(hair) - 1)])
+    f.write(shoe[random.randint(0, len(shoe) - 1)])
+    for i in range(0, len(cloth_list)):
+        f.write(cloth_list[i] + '\n')
+    f.write('clothesHideFaces True\n')
+    f.write(pifu[random.randint(0, len(pifu) - 1)])
+    f.write('material Braid01 eead6f99-d6c6-4f6b-b6c2-210459d7a62e braid01.mhmat\n')
+    f.write('material HighPolyEyes 2c12f43b-1303-432c-b7ce-d78346baf2e6 eyes/materials/brown.mhmat\n')
+    f.write('subdivide False\n')
+    f.close()
+
+
+def notice_write(request):
+    if request.method == 'POST':
+        form = ContentForm(request.POST)
+        form_file = FileForm(request.POST, request.FILES)
+        if form.is_valid():
+            question = form.save(commit=False)
+            question.author = request.user
+            question.create_date = timezone.now()
+            question.boardname_id = 7
+            question.save()
+            if form_file.is_valid():
+                file_save = form_file.save(commit=False)
+                file_save.author = request.user
+                file_save.postcontent = question
+                file_save.boardname_id = 7
+                file_save.file = request.FILES.get("file")
+                file_save.save()
+            return redirect('notice_view')
+    return render(request, 'notice_write.html')
+
+
+def test_write(request):
+    if request.method == 'POST':
+        form = ContentForm(request.POST)
+        form_file = FileForm(request.POST, request.FILES)
+        if form.is_valid():
+            question = form.save(commit=False)
+            question.author = request.user
+            question.create_date = timezone.now()
+            question.boardname_id = 14
+            question.save()
+            if form_file.is_valid():
+                file_save = form_file.save(commit=False)
+                file_save.author = request.user
+                file_save.postcontent = question
+                file_save.boardname_id = 14
+                file_save.file = request.FILES.get("file")
+                file_save.save()
+            return redirect('test_list')
+    return render(request, 'test_write.html')
+
+
+def down_file(url, name, path):
+    if os.path.exists(path):
+        return
+
+    print("开始下载:" + name + ".mp3")
+    headers = {'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8',
+               "Accept-Encoding": "gzip, deflate, br",
+               "Accept-Language": "zh-CN,zh;q=0.9",
+               "Upgrade-Insecure-Requests": "1",
+               'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.132 Safari/537.36'}
+
+    count = 0
+    while count < 3:
+        try:
+
+            r = requests.get(url, headers=headers, stream=True, timeout=60)
+            # print(r.status_code)
+            if (r.status_code == 200):
+                with open(path, "wb+") as f:
+                    for chunk in r.iter_content(1024):
+                        f.write(chunk)
+                print("完成下载:" + name + ".mp3")
+                break
+        except Exception as e:
+            print(e)
+            print("下载出错:" + name + ".mp3,3秒后重试")
+            if os.path.exists(path):
+                os.remove(path)
+
+            time.sleep(3)
+        count += 1
+
+    pass
+
+
+def save_as():
+    global file_name
+    content = content_text.get(1.0, 'end')
+    with open(file_name, 'w') as save:
+        save.write(content)
+
+
+def export_save(data_player, data_kick, guild_id, save_name=""):
+    if save_name: save_name = "_" + save_name
+    print(" - Partie enregistrée -")
+
+    with open(f"saves/save{save_name}.json", "w") as file:
+        file.write(json.dumps(
+            {
+                "players": [data_player[player_id].export() for player_id in data_player],
+                "kicks": data_kick,
+                "guild_id": guild_id
+            }, indent=4))
+
+
+def conv(heic_path, save_dir, filetype, quality):
+    # build the destination directory and file name
+    extension = "." + filetype
+    save_path = save_dir / filetype / pathlib.Path(*heic_path.parts[1:]).with_suffix(extension)
+    # create the destination folder
+    save_path.parent.mkdir(parents=True, exist_ok=True)
+    # read the HEIC file with pyheif
+    heif_file = pyheif.read(heic_path)
+    # load the decoded bytes into a PIL image
+    data = Image.frombytes(
+        heif_file.mode,
+        heif_file.size,
+        heif_file.data,
+        "raw",
+        heif_file.mode,
+        heif_file.stride,
+    )
+    # save in the requested format (e.g. JPEG)
+    data.save(save_path, quality=quality)
+    print("Saved:", save_path)
+
+
+def parsing_sravni_ru(soup):
+    names = soup.find_all('span', class_='_106rrj0')  # scraping names
+
+    # scraping age childrens
+    age_divs = soup.find_all('div', {'style': 'grid-area:firstCell-1', 'class': '_pjql8'})
+    ages = []
+    for i in age_divs:
+        age_span = i.find('span')
+        ages.append(age_span)
+
+    # scraping course duration
+    duration_divs = soup.find_all('div', {'style': 'grid-area:secondCell-1', 'class': '_pjql8'})
+    durations = []
+    for i in duration_divs:
+        duration_span = i.find('span')
+        durations.append(duration_span)
+
+    # scraping price
+    prices = soup.find_all('span', class_='_e9qrci _k8dl2y')
+
+    items = []
+    for (n, l, i, p) in zip(names, ages, durations, prices):
+        name = n.text.strip()
+        age = l.text.strip()
+        duration = i.text.strip()
+        price = p.text.strip().replace('\xa0', '')
+        items.append(
+            {
+                'name': name,
+                'age': age,
+                'duration': duration,
+                'price': price,
+            }
+        )
+
+    # save json file
+    with open("./data/items.json", "w", encoding="utf-8") as f:
+        json.dump(items, f, indent=4, ensure_ascii=False)
+
+    with open("./data/items.csv", 'a', encoding="utf-8") as file:
+        for i in items:
+            writer = csv.writer(file)
+            writer.writerow(
+                (
+                    i['name'],
+                    i['age'],
+                    i['duration'],
+                    i['price']
+                )
+            )
+
+
+def save_to_file(self, path):
+    with open(path, "w") as f:
+        f.write(self.cert_pem())
+        f.write(self.key_pem())
+
+
+def save_cert_to_file(self, path):
+    with open(path, "w") as f:
+        f.write(self.cert_pem())
+
+
+def _save_large_file(self, os_path, content, format):
+    """Save content of a generic file."""
+    if format not in {'text', 'base64'}:
+        raise web.HTTPError(
+            400,
+            "Must specify format of file contents as 'text' or 'base64'",
+        )
+    try:
+        if format == 'text':
+            bcontent = content.encode('utf8')
+        else:
+            b64_bytes = content.encode('ascii')
+            bcontent = base64.b64decode(b64_bytes)
+    except Exception as e:
+        raise web.HTTPError(
+            400, u'Encoding error saving %s: %s' % (os_path, e)
+        )
+
+    with self.perm_to_403(os_path):
+        if os.path.islink(os_path):
+            os_path = os.path.join(os.path.dirname(os_path), os.readlink(os_path))
+        with io.open(os_path, 'ab') as f:
+            f.write(bcontent)
+
+
+def get_unzip_hdfs_file(hdfs_file_url, save_dir):
+    # create the save directory if it does not exist yet
+    if not os.path.isdir(save_dir):
+        os.mkdir(save_dir)
+
+    # HDFS file name
+    filename = hdfs_file_url.split("/").pop()
+
+    # local file name to save to (timestamped); keep the .gz suffix if present
+    if filename.endswith(".gz"):
+        save_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) + ".gz"
+    else:
+        save_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
+
+    # make sure the save path ends with exactly one /
+    if save_dir.endswith("/"):
+        save_file = save_dir + save_filename
+    else:
+        save_file = save_dir + "/" + save_filename
+
+    # build the command that downloads the HDFS file
+    hadoop_get = 'hadoop fs -get %s %s' % (hdfs_file_url, save_file)
+    logger.info("download hdfs file command: " + hadoop_get)
+    # run the generated hadoop command in a shell
+    try:
+        os.system(hadoop_get)
+    except Exception as e:
+        logger.error(e)
+        return False
+
+    # if the downloaded file is gzip-compressed, decompress it
+    if save_file.endswith(".gz"):
+        try:
+            # name of the decompressed file
+            f_name = save_file.replace(".gz", "")
+            # decompress and write the content out (binary mode, gzip yields bytes)
+            g_file = gzip.GzipFile(save_file)
+            open(f_name, "wb+").write(g_file.read())
+            # close the gzip stream
+            g_file.close()
+
+            return f_name
+        except Exception as e:
+            logger.error(e)
+            return False
+    else:
+        return save_file
+
+
+"""
+根据HDFS文件目录下载此目录下所有的文件
+参数说明:
+hdfs_dir:HDFS文件目录
+save_dir:要保存的目录
+返回结果说明:执行成功返回True,执行失败返回False
+"""
+
+
+def get_unzip_hdfs_file_from_dir(hdfs_dir, save_dir):
+    # command: list the files under the HDFS directory
+    hadoop_ls = "hadoop fs -ls %s | grep -i '^-'" % hdfs_dir
+
+    # list of files after decompression
+    save_file_list = []
+    # run the shell command
+    hdfs_result = exec_sh(hadoop_ls, None)
+
+    # grab the command output
+    hdfs_stdout = hdfs_result["stdout"]
+    # print("hdfs_stdout = " + hdfs_stdout)
+
+    # list of HDFS files to download
+    hdfs_list = []
+
+    # check whether the command produced any output
+    if hdfs_stdout:
+        # split on newlines; each line describes one file
+        hdfs_lines = hdfs_stdout.split("\n")
+
+        # process every line
+        for line in hdfs_lines:
+
+            # split on whitespace to extract the HDFS file name
+            line_list = re.split(r"\s+", line)
+
+            # -rw-r--r--   2 caoweidong supergroup      42815 2017-01-23 14:20 /user/000000_0.gz
+            if len(line_list) == 8:
+                # print("line_list[7] = " + line_list[7])
+
+                # add the HDFS file to the download list
+                hdfs_list.append(line_list[7])
+            else:
+                pass
+        # download the files
+        for file in hdfs_list:
+            save_filename = get_unzip_hdfs_file(file, save_dir)
+            save_file_list.append(save_filename)
+        return save_file_list
+    else:
+        return False
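+
+
+# Hedged usage sketch for the two HDFS helpers above; the directory paths and
+# the helper name are illustrative assumptions, not values from the original code.
+def example_download_hdfs_dir():
+    hdfs_dir = "/user/example/logs"      # hypothetical HDFS directory
+    local_dir = "/tmp/hdfs_downloads"    # hypothetical local directory
+
+    local_files = get_unzip_hdfs_file_from_dir(hdfs_dir, local_dir)
+    if local_files is False:
+        print("no files listed or download failed")
+        return []
+    # each entry is a local path (decompressed if the source was .gz) or False
+    return [f for f in local_files if f]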
+
+
+def save_game(self):
+    save_file = open("saves/main_save.xml", "w+")
+
+    level = self.save_level()
+    self.tree.append(level)
+
+    team = self.save_team()
+    self.tree.append(team)
+
+    # Store XML tree in file
+    save_file.write(etree.tostring(self.tree, pretty_print=True, encoding="unicode"))
+
+    save_file.close()
+
+    def save_upload_file(
+            self,
+            file: UploadFile,
+            save_dir_path: pathlib.Path,
+            job_id: str,
+            dt_string: str,
+    ) -> pathlib.Path:
+        """Save `file` under `save_dir_path`.
+        Args:
+            file (UploadFile): The file to save.
+            save_dir_path (pathlib.Path): A path to the directory where the file will be saved.
+            job_id (str): A job id. This will be used as part of the filename.
+            dt_string (str): A datetime string. This will be used as part of the filename.
+        Return:
+            pathlib.Path: A path where file is saved.
+        """
+        if not save_dir_path.exists():
+            save_dir_path.mkdir(parents=True, exist_ok=True)
+
+        save_path: Final = save_dir_path / f"{dt_string}_{job_id}_{file.filename}"
+
+        try:
+            with save_path.open("wb") as f:
+                shutil.copyfileobj(file.file, f)
+        finally:
+            file.file.close()
+
+        return save_path
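+
+
+# Hedged usage sketch for `save_upload_file` above. `saver` stands for whatever
+# object exposes the method; the target directory, timestamp format and helper
+# name are assumptions for illustration only.
+def save_one_upload(saver, upload: UploadFile) -> pathlib.Path:
+    import datetime
+    import uuid
+
+    job_id = uuid.uuid4().hex
+    dt_string = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    return saver.save_upload_file(
+        file=upload,
+        save_dir_path=pathlib.Path("uploads"),
+        job_id=job_id,
+        dt_string=dt_string,
+    )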
+
+
+def save_output(output, list_to_save):
+    if output:
+        with open(output, "w") as f:
+            for item in list_to_save:
+                f.write("%s\n" % item)
+        print(f"Output file: {output}")
+
+
+def _saveTestWavFile(self, filename, wav_data):
+    with open(filename, "wb") as f:
+        f.write(wav_data)
+
+
+def _post_save_script(model, os_path, contents_manager, **kwargs):
+    """convert notebooks to Python script after save with nbconvert
+    replaces `jupyter notebook --script`
+    """
+    from nbconvert.exporters.script import ScriptExporter
+    warnings.warn("`_post_save_script` is deprecated and will be removed in Notebook 5.0", DeprecationWarning)
+
+    if model['type'] != 'notebook':
+        return
+
+    global _script_exporter
+    if _script_exporter is None:
+        _script_exporter = ScriptExporter(parent=contents_manager)
+    log = contents_manager.log
+
+    base, ext = os.path.splitext(os_path)
+    script, resources = _script_exporter.from_filename(os_path)
+    script_fname = base + resources.get('output_extension', '.txt')
+    log.info("Saving script /%s", to_api_path(script_fname, contents_manager.root_dir))
+    with io.open(script_fname, 'w', encoding='utf-8') as f:
+        f.write(script)
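+
+
+# Hedged usage note: a hook like `_post_save_script` is typically registered in
+# `jupyter_notebook_config.py` through the contents manager, e.g.:
+#
+#     c.FileContentsManager.post_save_hook = _post_save_script
+#
+# (shown as a comment only; `c` is the config object available in that file.)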
+
+
+def _save_data(filename, data):
+    """
+    Save formatted skeleton data to a pickle file
+    """
+    if filename[-2:] != ".p":
+        filename = str(filename + ".p")
+
+    with open(filename, 'wb') as fp:
+        pickle.dump(data, fp, protocol=pickle.HIGHEST_PROTOCOL)
+    print("Saved data to file: " + filename)
+
+
+def download_unknowns(url: str) -> None:
+    """."""
+    page_content: bytes = get_none_soup(url)
+    page_string: bytes = page_content[0:100]
+    """parse section of page bytes and use as name. If unknown encoding
+    convert to number string (exclude first few bytes that state filetype) """
+    try:
+        page_unicode = page_string.decode("ISO-8859-1").replace(R'%', '_')
+        page_parsed = [char for char in page_unicode if char.isalnum() or char == '_']
+        unknown_file_name = "".join(page_parsed)[10:30]
+    except UnicodeDecodeError:
+        try:
+            page_unicode = page_string.decode('utf-8').replace(R'%', '_')
+            page_parsed = [char for char in page_unicode if char.isalnum() or char == '_']
+            unknown_file_name = "".join(page_parsed)[10:30]
+        except UnicodeDecodeError:
+            unknown_file_name = "unk_"
+            for char in page_content[10:30]:
+                if char != b'\\':
+                    unknown_file_name += str(char)
+    print(unknown_file_name)
+    """check beginning of page bytes for a filetype"""
+    if b'%PDF' in page_string:  # ;
+        extension = '.pdf'
+    else:
+        extension = '.unk.txt'
+
+    with open(save_file_dir + '/' + unknown_file_name + extension, 'wb') as file:
+        file.write(page_content)  # ; print(save_file_dir)
+
+
+def download_images(start_url: str, filetypes: List[str]) -> None:
+    """.."""
+    base_url = get_base_url(start_url)
+    # print(start_url)
+    soup = get_soup(start_url)  # ;print(soup)
+    if soup is not None:
+        for index, image in enumerate(soup.select('img')):  # print(image)
+            # image_raw = str(image)
+            src_raw = str(image.get('src'))  # print(image.attrs['src'])
+            if src_raw.startswith('http'):
+                image_url = src_raw
+            elif src_raw.startswith('/'):
+                image_url = base_url + src_raw
+            else:
+                image_url = src_raw
+            # print(image_url)
+            for image_type in filter(lambda x: x in src_raw, filetypes):  # print(image)
+                image_response = requests.get(image_url, stream=True)
+                if image_response.status_code == 200:
+                    image_name = re.sub(r'.*/', '', src_raw).replace(R'.', '_')
+                    # print(image_name, index)
+                    fp: BinaryIO = open(save_image_dir + '/' + image_name + str(index) + image_type, 'wb')
+                    fp.write(image_response.content)
+                    fp.close()
+                    # i = Image.open(BytesIO(image_response.content))
+                    # i.save(image_name)
+
+
+def _unicode_save(self, temp_file):
+    im = pygame.Surface((10, 10), 0, 32)
+    try:
+        with open(temp_file, "w") as f:
+            pass
+        os.remove(temp_file)
+    except IOError:
+        raise unittest.SkipTest("the path cannot be opened")
+
+    self.assertFalse(os.path.exists(temp_file))
+
+    try:
+        imageext.save_extended(im, temp_file)
+
+        self.assertGreater(os.path.getsize(temp_file), 10)
+    finally:
+        try:
+            os.remove(temp_file)
+        except EnvironmentError:
+            pass

+ 21 - 0
Target/Azure/AddUp/Azure-blob-storage_4.py

@@ -0,0 +1,21 @@
+def create_blob_from_url(storage_connection_string,container_name):
+    try:
+        # urls to fetch into blob storage
+        url_list = get_random_images()
+
+        # Instantiate a new BlobServiceClient and a new ContainerClient
+        blob_service_client = BlobServiceClient.from_connection_string(storage_connection_string)
+        container_client = blob_service_client.get_container_client(container_name)
+
+        for u in url_list:
+            # Download file from url then upload blob file
+            r = requests.get(u, stream = True)
+            if r.status_code == 200:
+                r.raw.decode_content = True
+                blob_client = container_client.get_blob_client(get_filename_from_url(u))
+                blob_client.upload_blob(r.raw,overwrite=True)
+        return True
+        
+    except Exception as e:
+        print(e, e.args)
+        return False

+ 15 - 0
Target/Azure/AddUp/Azure-blob-storage_5.py

@@ -0,0 +1,15 @@
+def create_blob_from_path(storage_connection_string,container_name):
+    try:
+        # Instantiate a new BlobServiceClient and a new ContainerClient
+        blob_service_client = BlobServiceClient.from_connection_string(storage_connection_string)
+        container_client = blob_service_client.get_container_client(container_name)
+
+        for f in list_files():
+            with open(f["local_path"], "rb") as data:
+                blob_client = container_client.get_blob_client(f["file_name"])
+                blob_client.upload_blob(data,overwrite=True)
+        return True
+
+    except Exception as e:
+        print(e, e.args)
+        return False

+ 29 - 0
Target/Azure/AddUp/blob-upload-1_1.py

@@ -0,0 +1,29 @@
+def upload_file():
+    if request.method == 'POST':
+        file = request.files['file']
+        filename = secure_filename(file.filename)
+        fileextension = filename.rsplit('.',1)[1]
+        Randomfilename = id_generator()
+        filename = Randomfilename + '.' + fileextension
+        try:
+            blob_service.create_blob_from_stream(container, filename, file)
+        except Exception as e:
+            print('Exception=' + str(e))
+            pass
+        ref =  'http://'+ account + '.blob.core.windows.net/' + container + '/' + filename
+        return '''
+	    <!doctype html>
+	    <title>File Link</title>
+	    <h1>Uploaded File Link</h1>
+	    <p>''' + ref + '''</p>
+	    <img src="'''+ ref +'''">
+	    '''
+    return '''
+    <!doctype html>
+    <title>Upload new File</title>
+    <h1>Upload new File</h1>
+    <form action="" method=post enctype=multipart/form-data>
+      <p><input type=file name=file>
+         <input type=submit value=Upload>
+    </form>
+    '''

+ 9 - 0
Target/Azure/AddUp/blob-upload-2_3.py

@@ -0,0 +1,9 @@
+def _get_service(self):
+        if not hasattr(self, '_blob_service'):
+            self._blob_service = BlobService(
+                account_name=self.account_name,
+                account_key=self.account_key,
+                protocol='https' if self.use_ssl else 'http'
+            )
+
+        return self._blob_service

+ 5 - 0
Target/Azure/AddUp/blob-upload-2_4.py

@@ -0,0 +1,5 @@
+def _get_properties(self, name):
+        return self._get_service().get_blob_properties(
+            container_name=self.container,
+            blob_name=name
+        )

+ 13 - 0
Target/Azure/AddUp/blob-upload-2_5.py

@@ -0,0 +1,13 @@
+def _open(self, name, mode='rb'):
+        """
+        Return the AzureStorageFile.
+        """
+
+        from django.core.files.base import ContentFile
+
+        contents = self._get_service().get_blob_to_bytes(
+            container_name=self.container,
+            blob_name=name
+        )
+
+        return ContentFile(contents)

+ 34 - 0
Target/Azure/AddUp/blob-upload-2_6.py

@@ -0,0 +1,34 @@
+def _save(self, name, content):
+        """
+        Use the Azure Storage service to write ``content`` to a remote file
+        (called ``name``).
+        """
+        
+
+        content.open()
+
+        content_type = None
+
+        if hasattr(content.file, 'content_type'):
+            content_type = content.file.content_type
+        else:
+            content_type = mimetypes.guess_type(name)[0]
+
+        cache_control = self.get_cache_control(
+            self.container,
+            name,
+            content_type
+        )
+
+        self._get_service().put_block_blob_from_file(
+            container_name=self.container,
+            blob_name=name,
+            stream=content,
+            x_ms_blob_content_type=content_type,
+            cache_control=cache_control,
+            x_ms_blob_cache_control=cache_control
+        )
+
+        content.close()
+
+        return name

+ 22 - 0
Target/Azure/AddUp/blob-upload-2_7.py

@@ -0,0 +1,22 @@
+def listdir(self, path):
+        """
+        Lists the contents of the specified path, returning a 2-tuple of lists;
+        the first item being directories, the second item being files.
+        """
+
+        files = []
+
+        if path and not path.endswith('/'):
+            path = '%s/' % path
+
+        path_len = len(path)
+
+        if not path:
+            path = None
+
+        blob_list = self._get_service().list_blobs(self.container, prefix=path)
+
+        for name in blob_list:
+            files.append(name[path_len:])
+
+        return ([], files)

+ 9 - 0
Target/Azure/AddUp/blob-upload-2_9.py

@@ -0,0 +1,9 @@
+def delete(self, name):
+        """
+        Deletes the file referenced by name.
+        """
+
+        try:
+            self._get_service().delete_blob(self.container, name)
+        except AzureMissingResourceHttpError:
+            pass

+ 58 - 0
Target/Azure/AddUp/blob-upload_1.py

@@ -0,0 +1,58 @@
+def run_sample():
+    try:
+        # Create the BlockBlobService that is used to call the Blob service for the storage account
+        blob_service_client = BlockBlobService(
+            account_name='accountname', account_key='accountkey')
+
+        # Create a container called 'quickstartblobs'.
+        container_name = 'quickstartblobs'
+        blob_service_client.create_container(container_name)
+
+        # Set the permission so the blobs are public.
+        blob_service_client.set_container_acl(
+            container_name, public_access=PublicAccess.Container)
+
+        # Create Sample folder if it not exists, and create a file in folder Sample to test the upload and download.
+        local_path = os.path.expanduser("~/Sample")
+        if not os.path.exists(local_path):
+            os.makedirs(os.path.expanduser("~/Sample"))
+        local_file_name = "QuickStart_" + str(uuid.uuid4()) + ".txt"
+        full_path_to_file = os.path.join(local_path, local_file_name)
+
+        # Write text to the file.
+        file = open(full_path_to_file,  'w')
+        file.write("Hello, World!")
+        file.close()
+
+        print("Temp file = " + full_path_to_file)
+        print("\nUploading to Blob storage as blob" + local_file_name)
+
+        # Upload the created file, use local_file_name for the blob name
+        blob_service_client.create_blob_from_path(
+            container_name, local_file_name, full_path_to_file)
+
+        # List the blobs in the container
+        print("\nList blobs in the container")
+        generator = blob_service_client.list_blobs(container_name)
+        for blob in generator:
+            print("\t Blob name: " + blob.name)
+
+        # Download the blob(s).
+        # Add '_DOWNLOADED' as prefix to '.txt' so you can see both files in Documents.
+        full_path_to_file2 = os.path.join(local_path, str.replace(
+            local_file_name ,'.txt', '_DOWNLOADED.txt'))
+        print("\nDownloading blob to " + full_path_to_file2)
+        blob_service_client.get_blob_to_path(
+            container_name, local_file_name, full_path_to_file2)
+
+        sys.stdout.write("Sample finished running. When you hit <any key>, the sample will be deleted and the sample "
+                         "application will exit.")
+        sys.stdout.flush()
+        input()
+
+        # Clean up resources. This includes the container and the temp files
+        blob_service_client.delete_container(container_name)
+        os.remove(full_path_to_file)
+        os.remove(full_path_to_file2)
+    except Exception as e:
+        print(e)

+ 103 - 0
Target/Azure/AddUp/circuitbreaker_1.py

@@ -0,0 +1,103 @@
+def run_circuit_breaker():
+    # Name of image to use for testing.
+    image_to_upload = "HelloWorld.png"
+
+    global blob_client
+    global container_name
+    try:
+
+        # Create a reference to the blob client and container using the storage account name and key
+        blob_client = BlockBlobService(account_name, account_key)
+
+        # Make the container unique by using a UUID in the name.
+        container_name = "democontainer" + str(uuid.uuid4())
+        blob_client.create_container(container_name)
+
+    except Exception as ex:
+        print("Please make sure you have put the correct storage account name and key.")
+        print(ex)
+
+    # Define a reference to the actual blob and upload the block_blob to the newly created container
+    full_path_to_file = os.path.join(os.path.dirname(__file__), image_to_upload)
+    blob_client.create_blob_from_path(container_name, image_to_upload, full_path_to_file)
+
+    # Set the location mode to secondary, so you can check just the secondary data center.
+    blob_client.location_mode = LocationMode.SECONDARY
+    blob_client.retry = LinearRetry(backoff=0).retry
+
+    # Before proceeding, wait until the blob has been replicated to the secondary data center.
+    # Loop and check for the presence of the blob once in a second until it hits 60 seconds
+    # or until it finds it
+    counter = 0
+    while counter < 60:
+        counter += 1
+        sys.stdout.write("\nAttempt {0} to see if the blob has replicated to the secondary storage yet.".format(counter))
+        sys.stdout.flush()
+        if blob_client.exists(container_name, image_to_upload):
+            break
+
+        # Wait a second, then loop around and try again
+        # When it's finished replicating to the secondary, continue.
+        time.sleep(1)
+
+    # Set the starting LocationMode to Primary, then Secondary.
+    # Here we use the linear retry by default, but allow it to retry to secondary if
+    # the initial request to primary fails.
+    # Note that the default is Primary. You must have RA-GRS enabled to use this
+    blob_client.location_mode = LocationMode.PRIMARY
+    blob_client.retry = LinearRetry(max_attempts=retry_threshold, backoff=1).retry
+
+    ''' 
+        ************INSTRUCTIONS**************
+        To perform the test, first replace the 'accountname' and 'accountkey' with your storage account name and key.
+        Every time it calls get_blob_to_path it will hit the response_callback function.
+
+        Next, run this app. While this loop is running, pause the program by pressing any key, and
+        put the intercept code in Fiddler (that will intercept and return a 503).
+
+        For instructions on modifying Fiddler, look at the Fiddler_script.text file in this project.
+        There are also full instructions in the ReadMe_Instructions.txt file included in this project.
+
+        After adding the custom script to Fiddler, calls to primary storage will fail with a retryable
+        error which will trigger the Retrying event (above).
+        Then it will switch over and read the secondary. It will do that 20 times, then try to
+        switch back to the primary.
+        After seeing that happen, pause this again and remove the intercepting Fiddler code
+        Then you'll see it return to the primary and finish.
+        '''
+
+    print("\n\nThe application will pause at 200 unit interval")
+
+    for i in range(0, 1000):
+        if blob_client.location_mode == LocationMode.SECONDARY:
+            sys.stdout.write("S{0} ".format(str(i)))
+        else:
+            sys.stdout.write("P{0} ".format(str(i)))
+        sys.stdout.flush()
+
+        try:
+
+            # This callback is called immediately after retry evaluation is performed.
+            # It is used to trigger the change from primary to secondary and back
+            blob_client.retry_callback = retry_callback
+
+            # Download the file
+            blob_client.get_blob_to_path(container_name, image_to_upload,
+                                                str.replace(full_path_to_file, ".png", "Copy.png"))
+
+            # Set the application to pause at 200 unit intervals to implement simulated failures
+            if i == 200 or i == 400 or i == 600 or i == 800:
+                sys.stdout.write("\nPress the Enter key to resume")
+                sys.stdout.flush()
+                if sys.version_info[0] < 3:
+                    raw_input()
+                else:
+                    input()
+        except Exception as ex:
+            print(ex)
+        finally:
+            # Force an exists call to succeed by resetting the status
+            blob_client.response_callback = response_callback
+
+    # Clean up resources
+    blob_client.delete_container(container_name)

+ 97 - 0
Target/Azure/AddUp/datafactory_4.py

@@ -0,0 +1,97 @@
+def main():
+
+    # Azure subscription ID
+    subscription_id = '<Azure subscription ID>'
+
+    # This program creates this resource group. If it's an existing resource group, comment out the code that creates the resource group
+    rg_name = '<Azure resource group name>'
+
+    # The data factory name. It must be globally unique.
+    df_name = '<Data factory name>'        
+
+    # Specify your Active Directory client ID, client secret, and tenant ID
+    credentials = ServicePrincipalCredentials(client_id='<AAD application ID>', secret='<AAD app authentication key>', tenant='<AAD tenant ID>')
+    resource_client = ResourceManagementClient(credentials, subscription_id)
+    adf_client = DataFactoryManagementClient(credentials, subscription_id)
+
+    rg_params = {'location':'eastus'}
+    df_params = {'location':'eastus'}
+
+    # create the resource group
+    # comment out if the resource group already exits
+    resource_client.resource_groups.create_or_update(rg_name, rg_params)
+
+    # Create a data factory
+    df_resource = Factory(location='eastus')
+    df = adf_client.factories.create_or_update(rg_name, df_name, df_resource)
+    print_item(df)
+    while df.provisioning_state != 'Succeeded':
+        df = adf_client.factories.get(rg_name, df_name)
+        time.sleep(1)
+
+    # Create an Azure Storage linked service
+    ls_name = 'storageLinkedService'
+
+    # Specify the name and key of your Azure Storage account
+    storage_string = SecureString('DefaultEndpointsProtocol=https;AccountName=<Azure storage account>;AccountKey=<Azure storage authentication key>')
+
+    ls_azure_storage = AzureStorageLinkedService(connection_string=storage_string)
+    ls = adf_client.linked_services.create_or_update(rg_name, df_name, ls_name, ls_azure_storage)
+    print_item(ls)
+
+    # Create an Azure blob dataset (input)
+    ds_name = 'ds_in'
+    ds_ls = LinkedServiceReference(ls_name)
+    blob_path= 'adftutorial/inputpy'
+    blob_filename = 'input.txt'
+    ds_azure_blob= AzureBlobDataset(ds_ls, folder_path=blob_path, file_name = blob_filename)
+    ds = adf_client.datasets.create_or_update(rg_name, df_name, ds_name, ds_azure_blob)
+    print_item(ds)
+
+    # Create an Azure blob dataset (output)
+    dsOut_name = 'ds_out'
+    output_blobpath = 'adftutorial/outputpy'
+    dsOut_azure_blob = AzureBlobDataset(ds_ls, folder_path=output_blobpath)
+    dsOut = adf_client.datasets.create_or_update(rg_name, df_name, dsOut_name, dsOut_azure_blob)
+    print_item(dsOut)
+
+    # Create a copy activity
+    act_name =  'copyBlobtoBlob'
+    blob_source = BlobSource()
+    blob_sink = BlobSink()
+    dsin_ref = DatasetReference(ds_name)
+    dsOut_ref = DatasetReference(dsOut_name)
+    copy_activity = CopyActivity(act_name,inputs=[dsin_ref], outputs=[dsOut_ref], source=blob_source, sink=blob_sink)
+
+    # Create a pipeline with the copy activity
+    p_name =  'copyPipeline'
+    params_for_pipeline = {}
+    p_obj = PipelineResource(activities=[copy_activity], parameters=params_for_pipeline)
+    p = adf_client.pipelines.create_or_update(rg_name, df_name, p_name, p_obj)
+    print_item(p)
+
+    # Create a pipeline run
+    run_response = adf_client.pipelines.create_run(rg_name, df_name, p_name,
+        {
+        }
+    )
+
+    # Monitor the pipeline run
+    time.sleep(30)
+    pipeline_run = adf_client.pipeline_runs.get(rg_name, df_name, run_response.run_id)
+    print("\n\tPipeline run status: {}".format(pipeline_run.status))
+    activity_runs_paged = list(adf_client.activity_runs.list_by_pipeline_run(rg_name, df_name, pipeline_run.run_id, datetime.now() - timedelta(1),  datetime.now() + timedelta(1)))
+    print_activity_run_details(activity_runs_paged[0])
+
+    # Create a trigger
+    tr_name = 'mytrigger'
+    scheduler_recurrence = ScheduleTriggerRecurrence(frequency='Minute', interval='15',start_time=datetime.now(), end_time=datetime.now() + timedelta(1), time_zone='UTC') 
+    pipeline_parameters = {'inputPath':'adftutorial/inputpy', 'outputPath':'adftutorial/outputpy'}
+    pipelines_to_run = []
+    pipeline_reference = PipelineReference('copyPipeline')
+    pipelines_to_run.append(TriggerPipelineReference(pipeline_reference, pipeline_parameters))
+    tr_properties = ScheduleTrigger(description='My scheduler trigger', pipelines = pipelines_to_run, recurrence=scheduler_recurrence)    
+    adf_client.triggers.create_or_update(rg_name, df_name, tr_name, tr_properties)
+
+    # start the trigger
+    adf_client.triggers.start(rg_name, df_name, tr_name)

+ 28 - 0
Target/Azure/AddUp/file_advanced_samples_2.py

@@ -0,0 +1,28 @@
+def run_all_samples(self, connection_string):
+        print('Azure Storage File Advanced samples - Starting.')
+        
+        try:
+            # Create an instance of ShareServiceClient
+            service = ShareServiceClient.from_connection_string(conn_str=connection_string)
+
+            # List shares
+            print('\n\n* List shares *\n')
+            self.list_shares(service)
+
+            # Set Cors
+            print('\n\n* Set cors rules *\n')
+            self.set_cors_rules(service)
+
+            # Set Service Properties
+            print('\n\n* Set service properties *\n')
+            self.set_service_properties(service)
+
+            # Share, directory and file properties and metadata
+            print('\n\n* Metadata and properties *\n')
+            self.metadata_and_properties(service)
+
+        except Exception as e:
+            print('Error occurred in the sample.', e) 
+
+        finally:
+            print('\nAzure Storage File Advanced samples - Completed.\n')

+ 20 - 0
Target/Azure/AddUp/file_advanced_samples_3.py

@@ -0,0 +1,20 @@
+def list_shares(self, service):
+        share_prefix = 'sharesample' + self.random_data.get_random_name(6)
+
+        try:        
+            print('1. Create multiple shares with prefix: ', share_prefix)
+            for i in range(5):
+                service.create_share(share_name=share_prefix + str(i))
+            
+            print('2. List shares')
+            shares = service.list_shares()
+            for share in shares:
+                print('  Share name:' + share.name)
+
+        except Exception as e:
+            print(e) 
+
+        finally:
+            print('3. Delete shares with prefix:' + share_prefix) 
+            for i in range(5):
+                service.delete_share(share_prefix + str(i))

+ 22 - 0
Target/Azure/AddUp/file_advanced_samples_4.py

@@ -0,0 +1,22 @@
+def set_cors_rules(self, service):
+        print('1. Get Cors Rules')
+        original_cors_rules = service.get_service_properties()['cors']
+
+        print('2. Overwrite Cors Rules')
+        cors_rule = CorsRule(
+            allowed_origins=['*'], 
+            allowed_methods=['POST', 'GET'],
+            allowed_headers=['*'],
+            exposed_headers=['*'],
+            max_age_in_seconds=3600)
+
+        try:
+            service.set_service_properties(cors=[cors_rule])
+        except Exception as e:
+            print(e)
+        finally:
+            #reverting cors rules back to the original ones
+            print('3. Revert Cors Rules back to the original ones')
+            service.set_service_properties(cors=original_cors_rules)
+        
+        print("CORS sample completed")

+ 65 - 0
Target/Azure/AddUp/file_advanced_samples_6.py

@@ -0,0 +1,65 @@
+def metadata_and_properties(self, service):
+        share_name = 'sharename' + self.random_data.get_random_name(6)
+
+        try:
+            # All directories and share must be created in a parent share.
+            # Max capacity: 5TB per share
+            print('1. Create sample share with name ' + share_name)
+            quota = 1 # in GB
+            metadata = { "foo": "bar", "baz": "foo" }
+            share_client = service.create_share(share_name=share_name)
+            print('Sample share "'+ share_name +'" created.')
+
+            print('2. Get share properties.')
+            properties = share_client.get_share_properties()
+
+            print('3. Get share metadata.')
+            get_metadata = properties['metadata']
+            for k, v in get_metadata.items():
+                print("\t" + k + ": " + v)
+
+            dir_name = 'dirname' + self.random_data.get_random_name(6)
+
+            print('4. Create sample directory with name ' + dir_name)
+            metadata = { "abc": "def", "jkl": "mno" }
+            directory_client = share_client.create_directory(dir_name, metadata=metadata)
+            print('Sample directory "'+ dir_name +'" created.')
+
+            print('5. Get directory properties.')
+            properties = directory_client.get_directory_properties()
+            
+            print('6. Get directory metadata.')
+            get_metadata = properties['metadata']
+            for k, v in get_metadata.items():
+                print("\t" + k + ": " + v)
+
+            file_name = 'sample.txt'
+            # Uploading text to share_name/dir_name/sample.txt in Azure Files account.
+            # Max capacity: 1TB per file
+            print('7. Upload sample file from text to directory.')
+            metadata = { "prop1": "val1", "prop2": "val2" }
+            file_client = directory_client.get_file_client(file_name)
+            file_client.upload_file('Hello World! - from text sample', metadata=metadata)
+            print('Sample file "' + file_name + '" created and uploaded to: ' + share_name + '/' + dir_name)        
+
+            print('8. Get file properties.')
+            properties = file_client.get_file_properties()
+
+            print('9. Get file metadata.')
+            get_metadata = properties['metadata']
+            for k, v in get_metadata.items():
+                print("\t" + k + ": " + v)
+
+            # This is for demo purposes, all files will be deleted when share is deleted
+            print('10. Delete file.')
+            file_client.delete_file()
+
+            # This is for demo purposes, all directories will be deleted when share is deleted
+            print('11. Delete directory.')
+            directory_client.delete_directory()
+
+        finally:
+            print('12. Delete share.')
+            share_client.delete_share()
+
+        print("Metadata and properties sample completed")

+ 22 - 0
Target/Azure/AddUp/file_basic_samples_2.py

@@ -0,0 +1,22 @@
+def run_all_samples(self, connection_string):
+        print('Azure Storage File Basic samples - Starting.')
+        
+        #declare variables
+        filename = 'filesample' + self.random_data.get_random_name(6)
+        sharename = 'sharesample' + self.random_data.get_random_name(6)
+        
+        try:
+            # Create an instance of ShareServiceClient
+            service = ShareServiceClient.from_connection_string(conn_str=connection_string)
+
+            print('\n\n* Basic file operations *\n')
+            self.basic_file_operations(sharename, filename, service)
+
+        except Exception as e:
+            print('error:', e)
+
+        finally:
+            # Delete all Azure Files created in this sample
+            self.file_delete_samples(sharename, filename, service)
+
+        print('\nAzure Storage File Basic samples - Completed.\n')

+ 105 - 0
Target/Azure/AddUp/file_basic_samples_3.py

@@ -0,0 +1,105 @@
+def basic_file_operations(self, sharename, filename, service):
+        # Creating an SMB file share in your Azure Files account.
+        print('\nAttempting to create a sample file from text for upload demonstration.')   
+        # All directories and share must be created in a parent share.
+        # Max capacity: 5TB per share
+
+        print('Creating sample share.')
+        share_client = service.create_share(share_name=sharename)
+        print('Sample share "'+ sharename +'" created.')
+
+
+        # Creating an optional file directory in your Azure Files account.
+        print('Creating a sample directory.')    
+        # Get the directory client
+        directory_client = share_client.create_directory("mydirectory")
+        print('Sample directory "mydirectory" created.')
+
+
+        # Uploading text to sharename/mydirectory/my_text_file in Azure Files account.
+        # Max capacity: 1TB per file
+        print('Uploading a sample file from text.')   
+        # create_file_client
+        file_client = directory_client.get_file_client(filename)
+        # Upload a file
+        file_client.upload_file('Hello World! - from text sample')
+        print('Sample file "' + filename + '" created and uploaded to: ' + sharename + '/mydirectory')
+  
+
+        # Demonstrate how to copy a file
+        print('\nCopying file ' + filename)
+        # Create another file client which will copy the file from url
+        destination_file_client = share_client.get_file_client('file1copy')
+
+        # Copy the sample source file from the url to the destination file
+        copy_resp = destination_file_client.start_copy_from_url(source_url=file_client.url)
+        if copy_resp['copy_status'] == 'pending':
+            # Demonstrate how to abort a copy operation (just for demo, probably will never get here)
+            print('Abort copy operation')
+            destination_file_client.abort_copy(copy_resp['copy_id'])
+        else:
+            print('Copy was a ' + copy_resp['copy_status'])
+        
+
+        # Demonstrate how to create a share and upload a file from a local temporary file path
+        print('\nAttempting to upload a sample file from path for upload demonstration.')  
+        # Creating a temporary file to upload to Azure Files
+        print('Creating a temporary file from text.') 
+        with tempfile.NamedTemporaryFile(delete=False) as my_temp_file:
+            my_temp_file.file.write(b"Hello world!")
+        print('Sample temporary file created.') 
+
+        # Uploading my_temp_file to sharename folder in Azure Files
+        # Max capacity: 1TB per file
+        print('Uploading a sample file from local path.')
+        # Create file_client
+        file_client = share_client.get_file_client(filename)
+
+        # Upload a file
+        with open(my_temp_file.name, "rb") as source_file:
+            file_client.upload_file(source_file)
+
+        print('Sample file "' + filename + '" uploaded from path to share: ' + sharename)
+
+        # Close the temp file
+        my_temp_file.close()
+
+        # Get the list of valid ranges and write to the specified range
+        print('\nGet list of valid ranges of the file.') 
+        file_ranges = file_client.get_ranges()
+
+        data = b'abcdefghijkl'
+        print('Put a range of data to the file.')
+        
+        file_client.upload_range(data=data, offset=file_ranges[0]['start'], length=len(data))
+
+
+        # Demonstrate how to download a file from Azure Files
+        # The following example download the file that was previously uploaded to Azure Files
+        print('\nAttempting to download a sample file from Azure files for demonstration.')
+
+        destination_file = os.path.join(tempfile.gettempdir(), 'mypathfile.txt')
+
+        with open(destination_file, "wb") as file_handle:
+            data = file_client.download_file()
+            data.readinto(file_handle)
+
+        print('Sample file downloaded to: ' + destination_file)
+
+
+        # Demonstrate how to list files and directories contains under Azure File share
+        print('\nAttempting to list files and directories under share "' + sharename + '":')
+
+        # Create a generator to list directories and files under share
+        # This is not a recursive listing operation
+        generator = share_client.list_directories_and_files()
+
+        # Prints the directories and files under the share
+        for file_or_dir in generator:
+            print(file_or_dir['name'])
+        
+        # remove temp file
+        os.remove(my_temp_file.name)
+
+        print('Files and directories under share "' + sharename + '" listed.')
+        print('\nCompleted successfully - Azure basic Files operations.')

+ 29 - 0
Target/Azure/AddUp/file_basic_samples_4.py

@@ -0,0 +1,29 @@
+def file_delete_samples(self, sharename, filename, service):
+        print('\nDeleting all samples created for this demonstration.')
+
+        try:
+            # Deleting file: 'sharename/mydirectory/filename'
+            # This is for demo purposes only, it's unnecessary, as we're deleting the share later
+            print('Deleting a sample file.')
+
+            share_client = service.get_share_client(sharename)
+            directory_client = share_client.get_directory_client('mydirectory')
+            
+            directory_client.delete_file(file_name=filename)
+            print('Sample file "' + filename + '" deleted from: ' + sharename + '/mydirectory' )
+
+            # Deleting directory: 'sharename/mydirectory'
+            print('Deleting sample directory and all files and directories under it.')
+            share_client.delete_directory('mydirectory')
+            print('Sample directory "/mydirectory" deleted from: ' + sharename)
+
+            # Deleting share: 'sharename'
+            print('Deleting sample share ' + sharename + ' and all files and directories under it.')
+            share_client.delete_share(sharename)
+            print('Sample share "' + sharename + '" deleted.')
+
+            print('\nCompleted successfully - Azure Files samples deleted.')
+
+        except Exception as e:
+            print('********ErrorDelete***********')
+            print(e)

+ 40 - 0
Target/Azure/AddUp/python-quick-start_3.py

@@ -0,0 +1,40 @@
+def upload_file_to_container(blob_storage_service_client: BlobServiceClient,
+                             container_name: str, file_path: str) -> batchmodels.ResourceFile:
+    """
+    Uploads a local file to an Azure Blob storage container.
+
+    :param blob_storage_service_client: A blob service client.
+    :param str container_name: The name of the Azure Blob storage container.
+    :param str file_path: The local path to the file.
+    :return: A ResourceFile initialized with a SAS URL appropriate for Batch
+    tasks.
+    """
+    blob_name = os.path.basename(file_path)
+    blob_client = blob_storage_service_client.get_blob_client(container_name, blob_name)
+
+    print(f'Uploading file {file_path} to container [{container_name}]...')
+
+    with open(file_path, "rb") as data:
+        blob_client.upload_blob(data, overwrite=True)
+
+    sas_token = generate_blob_sas(
+        config.STORAGE_ACCOUNT_NAME,
+        container_name,
+        blob_name,
+        account_key=config.STORAGE_ACCOUNT_KEY,
+        permission=BlobSasPermissions(read=True),
+        expiry=datetime.datetime.utcnow() + datetime.timedelta(hours=2)
+    )
+
+    sas_url = generate_sas_url(
+        config.STORAGE_ACCOUNT_NAME,
+        config.STORAGE_ACCOUNT_DOMAIN,
+        container_name,
+        blob_name,
+        sas_token
+    )
+
+    return batchmodels.ResourceFile(
+        http_url=sas_url,
+        file_path=blob_name
+    )
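+
+
+# Hedged usage sketch: build a BlobServiceClient from the same config module the
+# function above relies on, then upload one local file. The container name and
+# file path below are illustrative assumptions.
+def upload_one_input_file() -> batchmodels.ResourceFile:
+    blob_service_client = BlobServiceClient(
+        account_url=f"https://{config.STORAGE_ACCOUNT_NAME}.{config.STORAGE_ACCOUNT_DOMAIN}/",
+        credential=config.STORAGE_ACCOUNT_KEY,
+    )
+    return upload_file_to_container(
+        blob_service_client, "inputcontainer", "./task_input.txt"
+    )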

+ 24 - 0
Target/Azure/AddUp/table_advanced_samples_2.py

@@ -0,0 +1,24 @@
+def run_all_samples(self, account):
+        table_service = account.create_table_service()
+        print('Azure Storage Advanced Table samples - Starting.')
+        
+        print('\n\n* List tables *\n')
+        self.list_tables(table_service)
+        
+        if not account.is_azure_cosmosdb_table():
+           print('\n\n* Set service properties *\n')
+           self.set_service_properties(table_service)
+        
+           print('\n\n* Set Cors rules *\n')
+           self.set_cors_rules(table_service)
+        
+           print('\n\n* ACL operations *\n')
+           self.table_acl_operations(table_service)
+        
+        if (config.IS_EMULATED):
+            print('\n\n* Shared Access Signature is not supported in emulator *\n')
+        else:
+            print('\n\n* SAS operations *\n')
+            self.table_operations_with_sas(account)
+
+        print('\nAzure Storage Advanced Table samples - Completed.\n')

+ 18 - 0
Target/Azure/AddUp/table_advanced_samples_4.py

@@ -0,0 +1,18 @@
+def set_service_properties(self, table_service):
+        print('1. Get Table service properties')
+        props = table_service.get_table_service_properties()
+
+        retention = RetentionPolicy(enabled=True, days=5)
+        logging = Logging(delete=True, read=False, write=True, retention_policy=retention)
+        hour_metrics = Metrics(enabled=True, include_apis=True, retention_policy=retention)
+        minute_metrics = Metrics(enabled=False)
+
+        try:
+            print('2. Overwrite Table service properties')
+            table_service.set_table_service_properties(logging=logging, hour_metrics=hour_metrics, minute_metrics=minute_metrics)
+
+        finally:
+            print('3. Revert Table service properties back to the original ones')
+            table_service.set_table_service_properties(logging=props.logging, hour_metrics=props.hour_metrics, minute_metrics=props.minute_metrics)
+
+        print('4. Set Table service properties completed')

+ 21 - 0
Target/Azure/AddUp/table_advanced_samples_5.py

@@ -0,0 +1,21 @@
+def set_cors_rules(self, table_service):
+        cors_rule = CorsRule(
+            allowed_origins=['*'], 
+            allowed_methods=['POST', 'GET'],
+            allowed_headers=['*'],
+            exposed_headers=['*'],
+            max_age_in_seconds=3600)
+        
+        print('1. Get Cors Rules')
+        original_cors_rules = table_service.get_table_service_properties().cors
+
+        try:        
+            print('2. Overwrite Cors Rules')
+            table_service.set_table_service_properties(cors=[cors_rule])
+
+        finally:
+            # Revert CORS rules back to the original ones
+            print('3. Revert Cors Rules back to the original ones')
+            table_service.set_table_service_properties(cors=original_cors_rules)
+        
+        print("CORS sample completed")

+ 50 - 0
Target/Azure/AddUp/table_advanced_samples_7.py

@@ -0,0 +1,50 @@
+def table_operations_with_sas(self, account):
+        table_name = 'sastable' + self.random_data.get_random_name(6)
+        
+        try:
+            # Create a Table Service object
+            table_service = account.create_table_service()
+            
+            print('1. Create table with name - ' + table_name)
+            table_service.create_table(table_name)
+            
+            # Create a Shared Access Signature for the table
+            print('2. Get sas for table')
+            
+            table_sas = table_service.generate_table_shared_access_signature(
+                table_name, 
+                TablePermissions.QUERY + TablePermissions.ADD + TablePermissions.UPDATE + TablePermissions.DELETE, 
+                datetime.datetime.utcnow() + datetime.timedelta(hours=1))
+
+            shared_account = TableStorageAccount(account_name=account.account_name, sas_token=table_sas, endpoint_suffix=account.endpoint_suffix)
+            shared_table_service = shared_account.create_table_service()
+
+            # Create a sample entity to insert into the table
+            customer = {'PartitionKey': 'Harp', 'RowKey': '1', 'email' : 'harp@contoso.com', 'phone' : '555-555-5555'}
+
+            # Insert the entity into the table
+            print('3. Insert new entity into table with sas - ' + table_name)
+            shared_table_service.insert_entity(table_name, customer)
+            
+            # Demonstrate how to query the entity
+            print('4. Read the inserted entity with sas.')
+            entity = shared_table_service.get_entity(table_name, 'Harp', '1')
+            
+            print(entity['email'])
+            print(entity['phone'])
+
+            # Demonstrate how to update the entity by changing the phone number
+            print('5. Update an existing entity by changing the phone number with sas')
+            customer = {'PartitionKey': 'Harp', 'RowKey': '1', 'email' : 'harp@contoso.com', 'phone' : '425-123-1234'}
+            shared_table_service.update_entity(table_name, customer)
+
+            # Demonstrate how to delete an entity
+            print('6. Delete the entity with sas')
+            shared_table_service.delete_entity(table_name, 'Harp', '1')
+
+        finally:
+            print('7. Delete table')
+            if table_service.exists(table_name):
+                table_service.delete_table(table_name)
+            
+        print("Table operations with sas completed")

+ 58 - 0
Target/Azure/AddUp/table_basic_samples_2.py

@@ -0,0 +1,58 @@
+def run_all_samples(self, account):
+        print('Azure Storage Basic Table samples - Starting.')
+        table_name = 'tablebasics' + self.random_data.get_random_name(6)
+        table_service = None
+        try:
+            table_service = account.create_table_service()
+
+            # Create a new table
+            print('Create a table with name - ' + table_name)
+
+            try:
+                table_service.create_table(table_name)
+            except Exception as err:
+                print('Error creating table, ' + table_name + '. Check if it already exists.', err)
+ 
+            # Create a sample entity to insert into the table
+            customer = {'PartitionKey': 'Harp', 'RowKey': '1', 'email' : 'harp@contoso.com', 'phone' : '555-555-5555'}
+
+            # Insert the entity into the table
+            print('Inserting a new entity into table - ' + table_name)
+            table_service.insert_entity(table_name, customer)
+            print('Successfully inserted the new entity')
+
+            # Demonstrate how to query the entity
+            print('Read the inserted entity.')
+            entity = table_service.get_entity(table_name, 'Harp', '1')
+            print(entity['email'])
+            print(entity['phone'])
+
+            # Demonstrate how to update the entity by changing the phone number
+            print('Update an existing entity by changing the phone number')
+            customer = {'PartitionKey': 'Harp', 'RowKey': '1', 'email' : 'harp@contoso.com', 'phone' : '425-123-1234'}
+            table_service.update_entity(table_name, customer)
+
+            # Demonstrate how to query the updated entity, filter the results with a filter query and select only the value in the phone column
+            print('Read the updated entity with a filter query')
+            entities = table_service.query_entities(table_name, filter="PartitionKey eq 'Harp'", select='phone')
+            for entity in entities:
+                print(entity['phone'])
+
+            # Demonstrate how to delete an entity
+            print('Delete the entity')
+            table_service.delete_entity(table_name, 'Harp', '1')
+            print('Successfully deleted the entity')
+
+        except Exception as e:
+            if (config.IS_EMULATED):
+                print('Error occurred in the sample. If you are using the emulator, please make sure the emulator is running.', e)
+            else: 
+                print('Error occurred in the sample. Please make sure the account name and key are correct.', e)
+        finally:
+            # Demonstrate deleting the table. If you don't want the table deleted, comment out the block below.
+            print('Deleting the table.')
+            if table_service and table_service.exists(table_name):
+                table_service.delete_table(table_name)
+            print('Successfully deleted the table')
+
+        print('\nAzure Storage Basic Table samples - Completed.\n')

+ 32 - 0
Target/Azure/DLfile_6.py

@@ -0,0 +1,32 @@
+def upload_files():
+    adl = core.AzureDLFileSystem(adlCreds, store_name=config.store_name)
+    uploadedFolders = adl.ls(adls_upload_folder_path)
+
+    uploadedFolders = set([folder.replace(adls_upload_folder_path[1:], "") + "/" for folder in uploadedFolders])
+
+    local_folders = glob.glob(local_upload_folder_path + "*")  # * matches everything; use e.g. *.csv for a specific format
+    local_folders = set([d.replace(local_upload_folder_path, "") + "/" for d in local_folders])
+
+    to_upload_folders = local_folders.difference(uploadedFolders)
+
+    folder_names = sorted([d.replace(local_upload_folder_path, "") for d in to_upload_folders])
+
+    files = []
+    for folder in folder_names:
+        path = local_upload_folder_path + folder
+        for f in listdir(path):
+            if isfile(join(path, f)):
+                files.append(folder + f)
+
+    print("Uploading the following folders:<br>{}<br>Total number of files to upload:<br>{}".format(", ".join(folder_names), len(files)))
+
+    for f in files:
+        adl.put(local_upload_folder_path + f, adls_upload_folder_path + f)
+
+    print("Upload finished.")
+    time.sleep(2)
+    global uploaded_files
+    uploaded_files = True

+ 12 - 0
Target/Azure/add_azure_account_1.py

@@ -0,0 +1,12 @@
+def create_azure_account(env, admin_api_key, account_name, azure_ad_id, azure_app_id, azure_api_access_key, azure_subscription_id):
+	"""
+	Creates an Azure account in CloudCheckr and populates it with the provided Azure subscription credentials.
+	"""
+
+	api_url = env + "/api/account.json/add_azure_inventory_account"
+
+	add_azure_account_info = json.dumps({"account_name": account_name, "azure_ad_id": azure_ad_id, "azure_app_id": azure_app_id, "azure_api_access_key": azure_api_access_key, "azure_subscription_id": azure_subscription_id})
+
+	r7 = requests.post(api_url, headers = {"Content-Type": "application/json", "access_key": admin_api_key}, data = add_azure_account_info)
+
+	print(r7.json())
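
A brief invocation sketch; every value below is a placeholder (the real script reads them from the command line, and the CloudCheckr base URL varies per environment).

    create_azure_account(
        env="https://api.cloudcheckr.com",            # placeholder environment URL
        admin_api_key="<cloudcheckr-admin-api-key>",
        account_name="my-azure-subscription",
        azure_ad_id="<azure-directory-id>",
        azure_app_id="<azure-application-id>",
        azure_api_access_key="<azure-application-secret>",
        azure_subscription_id="<azure-subscription-id>",
    )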

+ 16 - 0
Target/Azure/add_azure_account_and_set_role_assignment_1.py

@@ -0,0 +1,16 @@
+def create_azure_account(env, CloudCheckrApiKey, account_name, AzureDirectoryId, AzureCloudCheckrApplicationId,
+                         AzureCloudCheckrApplicationSecret, AzureSubscriptionId):
+    """
+    Creates an Azure account in CloudCheckr and populates it with the provided Azure subscription credentials.
+    """
+
+    api_url = env + "/api/account.json/add_azure_inventory_account"
+
+    add_azure_account_info = json.dumps(
+        {"account_name": account_name, "azure_ad_id": AzureDirectoryId, "azure_app_id": AzureCloudCheckrApplicationId,
+         "azure_api_access_key": AzureCloudCheckrApplicationSecret, "azure_subscription_id": AzureSubscriptionId})
+
+    r7 = requests.post(api_url, headers={"Content-Type": "application/json", "access_key": CloudCheckrApiKey},
+                       data=add_azure_account_info)
+
+    print(r7.json())

+ 18 - 0
Target/Azure/add_azure_account_and_set_role_assignment_2.py

@@ -0,0 +1,18 @@
+def get_azure_reader_role_id(AzureApiBearerToken, AzureSubscriptionId):
+    """
+    Gets the id of the reader role for this subscription.
+
+    https://docs.microsoft.com/en-us/rest/api/authorization/roleassignments/list
+    """
+
+    api_url = "https://management.azure.com/subscriptions/" + AzureSubscriptionId + "/providers/Microsoft.Authorization/roleDefinitions?api-version=2015-07-01&$filter=roleName eq 'Reader'"
+    authorization_value = "Bearer " + AzureApiBearerToken
+
+    response = requests.get(api_url, headers={"Authorization": authorization_value})
+
+    if "value" in response.json():
+        value = (response.json()["value"])[0]
+        if "id" in value:
+            return value["id"]
+    print("Failed to get the Azure Reader Role Id")
+    return None

+ 20 - 0
Target/Azure/add_azure_account_and_set_role_assignment_3.py

@@ -0,0 +1,20 @@
+def get_azure_cloudcheckr_service_principal_id(AzureGraphApiBearerToken, AzureCloudCheckrApplicationName):
+    """
+    Gets the service principal id of the Azure Application that was created specifically for CloudCheckr.
+    Note: This is not the application id. The service principal id is required for the role assignment.
+    This uses the Microsoft Graph API.
+
+    https://docs.microsoft.com/en-us/graph/api/serviceprincipal-list?view=graph-rest-1.0&tabs=http
+    """
+
+    api_url = "https://graph.microsoft.com/v1.0/servicePrincipals?$filter=displayName eq '" + AzureCloudCheckrApplicationName + "'"
+    authorization_value = "Bearer " + AzureGraphApiBearerToken
+
+    response = requests.get(api_url, headers={"Authorization": authorization_value})
+
+    if "value" in response.json():
+        value = (response.json()["value"])[0]
+        if ("id" in value) and ("appId" in value):
+            return value["id"], value["appId"]
+    print("Failed to get the Azure CloudCheckr Application Service principal Id")
+    return None

+ 26 - 0
Target/Azure/add_azure_account_and_set_role_assignment_4.py

@@ -0,0 +1,26 @@
+def set_azure_cloudcheckr_application_service_assignment(AzureApiBearerToken, AzureReaderRoleId,
+                                                         AzureCloudCheckrApplicationServicePrincipalId,
+                                                         AzureSubscriptionId):
+    """
+    Sets the previously created CloudCheckr application to have a reader role assignment.
+
+    https://docs.microsoft.com/en-us/azure/role-based-access-control/role-assignments-rest
+    """
+
+    RoleAssignmentId = str(uuid.uuid1())
+
+    api_url = "https://management.azure.com/subscriptions/" + AzureSubscriptionId + "/providers/Microsoft.Authorization/roleAssignments/" + RoleAssignmentId + "?api-version=2015-07-01"
+    authorization_value = "Bearer " + AzureApiBearerToken
+    role_assignment_data = json.dumps({"properties": {"principalId": AzureCloudCheckrApplicationServicePrincipalId,
+                                                      "roleDefinitionId": AzureReaderRoleId}})
+
+    response = requests.put(api_url, headers={"Authorization": authorization_value, "Content-Type": "application/json"},
+                            data=role_assignment_data)
+    print(response.json())
+
+    if "properties" in response.json():
+        properties = response.json()["properties"]
+        if "roleDefinitionId" in properties:
+            return properties["roleDefinitionId"]
+    print("Failed to set role assignment for the CloudCheckr Application to the specified subscription")
+    return None

+ 20 - 0
Target/Azure/add_azure_account_and_set_role_assignment_5.py

@@ -0,0 +1,20 @@
+def get_azure_bearer_token(resource_url, azure_directory_id, azure_admin_application_id,
+                           azure_admin_application_secret):
+    """
+    Uses OAuth 2.0 to get the bearer token based on the client id and client secret.
+    """
+
+    api_url = "https://login.microsoftonline.com/" + azure_directory_id + "/oauth2/token"
+
+    client = {'grant_type': 'client_credentials',
+              'client_id': azure_admin_application_id,
+              'client_secret': azure_admin_application_secret,
+              'resource': resource_url,
+              }
+
+    response = requests.post(api_url, data=client)
+
+    if "access_token" in response.json():
+        return response.json()["access_token"]
+    print("Could not get Bearer token")
+    return None
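
To show how the token is consumed, a small illustrative follow-up (the variable names are placeholders, and the ARM api-version shown is an assumption, one commonly used for listing resource groups):

    import requests

    token = get_azure_bearer_token("https://management.azure.com/", azure_directory_id,
                                   azure_admin_application_id, azure_admin_application_secret)
    if token:
        url = ("https://management.azure.com/subscriptions/" + azure_subscription_id
               + "/resourcegroups?api-version=2021-04-01")
        # Every ARM call in this script passes the token as a Bearer Authorization header.
        print(requests.get(url, headers={"Authorization": "Bearer " + token}).json())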

+ 63 - 0
Target/Azure/add_azure_account_and_set_role_assignment_6.py

@@ -0,0 +1,63 @@
+def main():
+    try:
+        CloudCheckrApiKey = str(sys.argv[1])
+    except IndexError:
+        print("Must include an admin api key in the command line")
+        return
+
+    try:
+        NameOfCloudCheckrAccount = str(sys.argv[2])
+    except IndexError:
+        print("Must include a cloudcheckr account name")
+        return
+
+    try:
+        AzureDirectoryId = str(sys.argv[3])
+    except IndexError:
+        print("Must include an Azure Directory Id")
+        return
+
+    try:
+        AzureSubscriptionId = str(sys.argv[4])
+    except IndexError:
+        print("Must include an Azure Subscription Id")
+        return
+
+    try:
+        AzureAdminApplicationId = str(sys.argv[5])
+    except IndexError:
+        print("Must include an Azure Admin ApplictApi Id")
+        return
+
+    try:
+        AzureAdminApplicationSecret = str(sys.argv[6])
+    except IndexError:
+        print("Must include an Azure Admin Application Secret")
+        return
+
+    try:
+        AzureCloudCheckrApplicationName = str(sys.argv[7])
+    except IndexError:
+        print("Must include an Azure CloudCheckr Application Name")
+        return
+
+    try:
+        AzureCloudCheckrApplicationSecret = str(sys.argv[8])
+    except IndexError:
+        print("Must include an Azure CloudCheckr Application Secret")
+        return
+
+    env = "https://glacier.cloudcheckr.com"
+
+    AzureApiBearerToken = get_azure_bearer_token("https://management.azure.com/", AzureDirectoryId,
+                                                 AzureAdminApplicationId, AzureAdminApplicationSecret)
+    AzureGraphApiBearerToken = get_azure_bearer_token("https://graph.microsoft.com/", AzureDirectoryId,
+                                                      AzureAdminApplicationId, AzureAdminApplicationSecret)
+    AzureReaderRoleId = get_azure_reader_role_id(AzureApiBearerToken, AzureSubscriptionId)
+    AzureCloudCheckrApplicationServicePrincipalId, AzureCloudCheckrApplicationId = get_azure_cloudcheckr_service_principal_id(
+        AzureGraphApiBearerToken, AzureCloudCheckrApplicationName)
+    set_azure_cloudcheckr_application_service_assignment(AzureApiBearerToken, AzureReaderRoleId,
+                                                         AzureCloudCheckrApplicationServicePrincipalId,
+                                                         AzureSubscriptionId)
+    create_azure_account(env, CloudCheckrApiKey, NameOfCloudCheckrAccount, AzureDirectoryId,
+                         AzureCloudCheckrApplicationId, AzureCloudCheckrApplicationSecret, AzureSubscriptionId)

+ 3 - 0
Target/Azure/adls_2.py

@@ -0,0 +1,3 @@
+def execute(self, context: "Context") -> Any:
+        hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id)
+        return hook.remove(path=self.path, recursive=self.recursive, ignore_not_found=self.ignore_not_found)

+ 4 - 0
Target/Azure/adls_4.py

@@ -0,0 +1,4 @@
+def execute(self, context: "Context") -> list:
+        hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id)
+        self.log.info('Getting list of ADLS files in path: %s', self.path)
+        return hook.list(path=self.path)

+ 10 - 0
Target/Azure/azure_clients_1.py

@@ -0,0 +1,10 @@
+def get_resourcegroup_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    subscription_id = parameters.get('azure_subscription_id')
+
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    resourcegroup_client = ResourceManagementClient(token_credential, subscription_id)
+    return resourcegroup_client
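
A possible way to drive this factory (and the sibling factories below), assuming a `parameters` dict carrying the four service-principal values; listing resource groups is a standard ResourceManagementClient operation.

    # Placeholder values; in practice these come from configuration or a vault.
    parameters = {
        'azure_tenant_id': '<tenant-id>',
        'azure_client_id': '<client-id>',
        'azure_client_secret': '<client-secret>',
        'azure_subscription_id': '<subscription-id>',
    }

    client = get_resourcegroup_client(parameters)
    for group in client.resource_groups.list():
        print(group.name, group.location)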

+ 11 - 0
Target/Azure/azure_clients_2.py

@@ -0,0 +1,11 @@
+def get_compute_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    subscription_id = parameters.get('azure_subscription_id')
+
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    compute_client = ComputeManagementClient(token_credential,
+                                             subscription_id)
+    return compute_client

+ 11 - 0
Target/Azure/azure_clients_3.py

@@ -0,0 +1,11 @@
+def get_network_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    subscription_id = parameters.get('azure_subscription_id')
+
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    network_client = NetworkManagementClient(token_credential,
+                                             subscription_id)
+    return network_client

+ 11 - 0
Target/Azure/azure_clients_4.py

@@ -0,0 +1,11 @@
+def get_dns_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    subscription_id = parameters.get('azure_subscription_id')
+
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    dns_client = PrivateDnsManagementClient(token_credential,
+                                            subscription_id)
+    return dns_client

+ 11 - 0
Target/Azure/azure_clients_5.py

@@ -0,0 +1,11 @@
+def get_dns_ops_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    subscription_id = parameters.get('azure_subscription_id')
+
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    dns_ops_client = DnsManagementClient(token_credential,
+                                            subscription_id)
+    return dns_ops_client

+ 11 - 0
Target/Azure/azure_clients_6.py

@@ -0,0 +1,11 @@
+def get_blob_service_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    account_name = parameters.get('storage_account_name')
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    blob_service_client = BlobServiceClient(
+        account_url="https://%s.blob.core.windows.net" % account_name,
+        credential=token_credential)
+    return blob_service_client

+ 11 - 0
Target/Azure/azure_clients_7.py

@@ -0,0 +1,11 @@
+def get_queue_service_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    account_name = parameters.get('storage_account_name')
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    queue_service_client = QueueServiceClient(
+        account_url="https://%s.queue.core.windows.net" % account_name,
+        credential=token_credential)
+    return queue_service_client

+ 13 - 0
Target/Azure/azure_clients_8.py

@@ -0,0 +1,13 @@
+def get_datalake_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    subscription_id = parameters.get('azure_subscription_id')
+    credentials = ServicePrincipalCredentials(
+        client_id=client_id,
+        secret=secret,
+        tenant=tenant_id)
+
+    datalake_client = DataLakeStoreAccountManagementClient(credentials,
+                                                           subscription_id)
+    return datalake_client

+ 11 - 0
Target/Azure/azure_clients_9.py

@@ -0,0 +1,11 @@
+def get_storage_client(parameters):
+    tenant_id = parameters.get('azure_tenant_id')
+    client_id = parameters.get('azure_client_id')
+    secret = parameters.get('azure_client_secret')
+    subscription_id = parameters.get('azure_subscription_id')
+
+    token_credential = ClientSecretCredential(
+        tenant_id, client_id, secret)
+    storage_client = StorageManagementClient(token_credential,
+                                             subscription_id)
+    return storage_client

+ 25 - 0
Target/Azure/azure_rm_14.py

@@ -0,0 +1,25 @@
+def get_inventory(self):
+        if len(self.resource_groups) > 0:
+            # get VMs for requested resource groups
+            for resource_group in self.resource_groups:
+                try:
+                    virtual_machines = self._compute_client.virtual_machines.list(resource_group)
+                except Exception as exc:
+                    sys.exit("Error: fetching virtual machines for resource group {0} - {1}".format(resource_group, str(exc)))
+                if self._args.host or self.tags:
+                    selected_machines = self._selected_machines(virtual_machines)
+                    self._load_machines(selected_machines)
+                else:
+                    self._load_machines(virtual_machines)
+        else:
+            # get all VMs within the subscription
+            try:
+                virtual_machines = self._compute_client.virtual_machines.list_all()
+            except Exception as exc:
+                sys.exit("Error: fetching virtual machines - {0}".format(str(exc)))
+
+            if self._args.host or self.tags or self.locations:
+                selected_machines = self._selected_machines(virtual_machines)
+                self._load_machines(selected_machines)
+            else:
+                self._load_machines(virtual_machines)

+ 98 - 0
Target/Azure/azure_rm_15.py

@@ -0,0 +1,98 @@
+def _load_machines(self, machines):
+        for machine in machines:
+            id_dict = azure_id_to_dict(machine.id)
+
+            # TODO - The API is returning an ID value containing resource group name in ALL CAPS. If/when it gets
+            #       fixed, we should remove the .lower(). Opened Issue
+            #       #574: https://github.com/Azure/azure-sdk-for-python/issues/574
+            resource_group = id_dict['resourceGroups'].lower()
+
+            if self.group_by_security_group:
+                self._get_security_groups(resource_group)
+
+            host_vars = dict(
+                ansible_host=None,
+                private_ip=None,
+                private_ip_alloc_method=None,
+                public_ip=None,
+                public_ip_name=None,
+                public_ip_id=None,
+                public_ip_alloc_method=None,
+                fqdn=None,
+                location=machine.location,
+                name=machine.name,
+                type=machine.type,
+                id=machine.id,
+                tags=machine.tags,
+                network_interface_id=None,
+                network_interface=None,
+                resource_group=resource_group,
+                mac_address=None,
+                plan=(machine.plan.name if machine.plan else None),
+                virtual_machine_size=machine.hardware_profile.vm_size,
+                computer_name=(machine.os_profile.computer_name if machine.os_profile else None),
+                provisioning_state=machine.provisioning_state,
+            )
+
+            host_vars['os_disk'] = dict(
+                name=machine.storage_profile.os_disk.name,
+                operating_system_type=machine.storage_profile.os_disk.os_type.value
+            )
+
+            if self.include_powerstate:
+                host_vars['powerstate'] = self._get_powerstate(resource_group, machine.name)
+
+            if machine.storage_profile.image_reference:
+                host_vars['image'] = dict(
+                    offer=machine.storage_profile.image_reference.offer,
+                    publisher=machine.storage_profile.image_reference.publisher,
+                    sku=machine.storage_profile.image_reference.sku,
+                    version=machine.storage_profile.image_reference.version
+                )
+
+            # Add windows details
+            if machine.os_profile is not None and machine.os_profile.windows_configuration is not None:
+                host_vars['windows_auto_updates_enabled'] = \
+                    machine.os_profile.windows_configuration.enable_automatic_updates
+                host_vars['windows_timezone'] = machine.os_profile.windows_configuration.time_zone
+                host_vars['windows_rm'] = None
+                if machine.os_profile.windows_configuration.win_rm is not None:
+                    host_vars['windows_rm'] = dict(listeners=None)
+                    if machine.os_profile.windows_configuration.win_rm.listeners is not None:
+                        host_vars['windows_rm']['listeners'] = []
+                        for listener in machine.os_profile.windows_configuration.win_rm.listeners:
+                            host_vars['windows_rm']['listeners'].append(dict(protocol=listener.protocol,
+                                                                             certificate_url=listener.certificate_url))
+
+            for interface in machine.network_profile.network_interfaces:
+                interface_reference = self._parse_ref_id(interface.id)
+                network_interface = self._network_client.network_interfaces.get(
+                    interface_reference['resourceGroups'],
+                    interface_reference['networkInterfaces'])
+                if network_interface.primary:
+                    if self.group_by_security_group and \
+                       self._security_groups[resource_group].get(network_interface.id, None):
+                        host_vars['security_group'] = \
+                            self._security_groups[resource_group][network_interface.id]['name']
+                        host_vars['security_group_id'] = \
+                            self._security_groups[resource_group][network_interface.id]['id']
+                    host_vars['network_interface'] = network_interface.name
+                    host_vars['network_interface_id'] = network_interface.id
+                    host_vars['mac_address'] = network_interface.mac_address
+                    for ip_config in network_interface.ip_configurations:
+                        host_vars['private_ip'] = ip_config.private_ip_address
+                        host_vars['private_ip_alloc_method'] = ip_config.private_ip_allocation_method
+                        if ip_config.public_ip_address:
+                            public_ip_reference = self._parse_ref_id(ip_config.public_ip_address.id)
+                            public_ip_address = self._network_client.public_ip_addresses.get(
+                                public_ip_reference['resourceGroups'],
+                                public_ip_reference['publicIPAddresses'])
+                            host_vars['ansible_host'] = public_ip_address.ip_address
+                            host_vars['public_ip'] = public_ip_address.ip_address
+                            host_vars['public_ip_name'] = public_ip_address.name
+                            host_vars['public_ip_alloc_method'] = public_ip_address.public_ip_allocation_method
+                            host_vars['public_ip_id'] = public_ip_address.id
+                            if public_ip_address.dns_settings:
+                                host_vars['fqdn'] = public_ip_address.dns_settings.fqdn
+
+            self._add_host(host_vars)

+ 59 - 0
Target/Azure/azure_rm_2.py

@@ -0,0 +1,59 @@
+def __init__(self, args):
+        self._args = args
+        self._cloud_environment = None
+        self._compute_client = None
+        self._resource_client = None
+        self._network_client = None
+
+        self.debug = False
+        if args.debug:
+            self.debug = True
+
+        self.credentials = self._get_credentials(args)
+        if not self.credentials:
+            self.fail("Failed to get credentials. Either pass as parameters, set environment variables, "
+                      "or define a profile in ~/.azure/credentials.")
+
+        # if cloud_environment specified, look up/build Cloud object
+        raw_cloud_env = self.credentials.get('cloud_environment')
+        if not raw_cloud_env:
+            self._cloud_environment = azure_cloud.AZURE_PUBLIC_CLOUD  # SDK default
+        else:
+            # try to look up "well-known" values via the name attribute on azure_cloud members
+            all_clouds = [x[1] for x in inspect.getmembers(azure_cloud) if isinstance(x[1], azure_cloud.Cloud)]
+            matched_clouds = [x for x in all_clouds if x.name == raw_cloud_env]
+            if len(matched_clouds) == 1:
+                self._cloud_environment = matched_clouds[0]
+            elif len(matched_clouds) > 1:
+                self.fail("Azure SDK failure: more than one cloud matched for cloud_environment name '{0}'".format(raw_cloud_env))
+            else:
+                if not urlparse.urlparse(raw_cloud_env).scheme:
+                    self.fail("cloud_environment must be an endpoint discovery URL or one of {0}".format([x.name for x in all_clouds]))
+                try:
+                    self._cloud_environment = azure_cloud.get_cloud_from_metadata_endpoint(raw_cloud_env)
+                except Exception as e:
+                    self.fail("cloud_environment {0} could not be resolved: {1}".format(raw_cloud_env, e.message))
+
+        if self.credentials.get('subscription_id', None) is None:
+            self.fail("Credentials did not include a subscription_id value.")
+        self.log("setting subscription_id")
+        self.subscription_id = self.credentials['subscription_id']
+
+        if self.credentials.get('client_id') is not None and \
+           self.credentials.get('secret') is not None and \
+           self.credentials.get('tenant') is not None:
+            self.azure_credentials = ServicePrincipalCredentials(client_id=self.credentials['client_id'],
+                                                                 secret=self.credentials['secret'],
+                                                                 tenant=self.credentials['tenant'],
+                                                                 cloud_environment=self._cloud_environment)
+        elif self.credentials.get('ad_user') is not None and self.credentials.get('password') is not None:
+            tenant = self.credentials.get('tenant')
+            if not tenant:
+                tenant = 'common'
+            self.azure_credentials = UserPassCredentials(self.credentials['ad_user'],
+                                                         self.credentials['password'],
+                                                         tenant=tenant,
+                                                         cloud_environment=self._cloud_environment)
+        else:
+            self.fail("Failed to authenticate with provided credentials. Some attributes were missing. "
+                      "Credentials must include client_id, secret and tenant or ad_user and password.")

+ 11 - 0
Target/Azure/azure_rm_9.py

@@ -0,0 +1,11 @@
+def network_client(self):
+        self.log('Getting network client')
+        if not self._network_client:
+            self._network_client = NetworkManagementClient(
+                self.azure_credentials,
+                self.subscription_id,
+                base_url=self._cloud_environment.endpoints.resource_manager,
+                api_version='2017-06-01'
+            )
+            self._register('Microsoft.Network')
+        return self._network_client

+ 17 - 0
Target/Azure/azure_rm_aks_facts_4.py

@@ -0,0 +1,17 @@
+def list_items(self):
+        """Get all Azure Kubernetes Services"""
+
+        self.log('List all Azure Kubernetes Services')
+
+        try:
+            response = self.containerservice_client.managed_clusters.list(
+                self.resource_group)
+        except AzureHttpError as exc:
+            self.fail('Failed to list all items - {0}'.format(str(exc)))
+
+        results = []
+        for item in response:
+            if self.has_tags(item.tags, self.tags):
+                results.append(self.serialize_obj(item, AZURE_OBJECT_CLASS))
+
+        return results

+ 32 - 0
Target/Azure/azure_service_principal_attribute_1.py

@@ -0,0 +1,32 @@
+def run(self, terms, variables, **kwargs):
+
+        self.set_options(direct=kwargs)
+
+        credentials = {}
+        credentials['azure_client_id'] = self.get_option('azure_client_id', None)
+        credentials['azure_secret'] = self.get_option('azure_secret', None)
+        credentials['azure_tenant'] = self.get_option('azure_tenant', 'common')
+
+        if credentials['azure_client_id'] is None or credentials['azure_secret'] is None:
+            raise AnsibleError("Must specify azure_client_id and azure_secret")
+
+        _cloud_environment = azure_cloud.AZURE_PUBLIC_CLOUD
+        if self.get_option('azure_cloud_environment', None) is not None:
+            # Use the configured option directly; the credentials dict never receives this key.
+            _cloud_environment = azure_cloud.get_cloud_from_metadata_endpoint(self.get_option('azure_cloud_environment'))
+
+        try:
+            azure_credentials = ServicePrincipalCredentials(client_id=credentials['azure_client_id'],
+                                                            secret=credentials['azure_secret'],
+                                                            tenant=credentials['azure_tenant'],
+                                                            resource=_cloud_environment.endpoints.active_directory_graph_resource_id)
+
+            client = GraphRbacManagementClient(azure_credentials, credentials['azure_tenant'],
+                                               base_url=_cloud_environment.endpoints.active_directory_graph_resource_id)
+
+            response = list(client.service_principals.list(filter="appId eq '{0}'".format(credentials['azure_client_id'])))
+            sp = response[0]
+
+            return sp.object_id.split(',')
+        except CloudError as ex:
+            raise AnsibleError("Failed to get service principal object id: %s" % to_native(ex))
+        return False

+ 7 - 0
Target/Azure/azure_storage_11.py

@@ -0,0 +1,7 @@
+def size(self, name):
+        """
+        :param name:
+        :rtype: int
+        """
+        blob = self.connection.get_blob_properties(self.azure_container, name)
+        return blob.properties.content_length

+ 28 - 0
Target/Azure/azure_storage_12.py

@@ -0,0 +1,28 @@
+def _save(self, name, content):
+        """
+        :param name:
+        :param File content:
+        :return:
+        """
+        original_name = name.get("original_name")
+        blob_file_name = datetime.now().strftime("%Y%m%d-%H:%M:%S.%f_") + original_name
+        # blob_name = "{}.{}".format(name.get("uuid"), original_name.partition(".")[-1])
+
+        if hasattr(content.file, 'content_type'):
+            content_type = content.file.content_type
+        else:
+            content_type = mimetypes.guess_type(original_name)[0]  # guess_type returns a (type, encoding) tuple
+
+        if hasattr(content, 'chunks'):
+            content_data = b''.join(chunk for chunk in content.chunks())
+        else:
+            content_data = content.read()
+
+        print(f'Saving blob: container={self.azure_container}, blob={blob_file_name}')
+        blob_client = self.connection.get_blob_client(container=self.azure_container, blob=blob_file_name)
+        obj = blob_client.upload_blob(content_data)
+        # create_blob_from_bytes(self.azure_container, name, content_data,
+        #
+        #                                        content_settings=ContentSettings(content_type=content_type))
+        af = AttachedFile(original_name, self.azure_container, blob_file_name)
+        return af

+ 15 - 0
Target/Azure/azure_storage_5.py

@@ -0,0 +1,15 @@
+def connection(self):
+
+        if self._connection is None:
+            connect_str = setting("AZURE_STORAGE_CONNECTION_STRING")
+
+            # Create the BlobServiceClient object which will be used to create a container client
+            blob_service_client = BlobServiceClient.from_connection_string(connect_str)
+
+            # Create a unique name for the container
+            container_name = "pac-files"
+
+            # Create a blob client using the local file name as the name for the blob
+            self._connection = blob_service_client
+
+        return self._connection

+ 10 - 0
Target/Azure/azure_storage_8.py

@@ -0,0 +1,10 @@
+def _open(self, container, name, mode="rb"):
+        """
+        :param str container: Container name
+        :param str name: Filename
+        :param str mode: File mode (binary read by default)
+        :rtype: ContentFile
+        """
+        print(f'Retrieving blob: container={self.azure_container}, blob={name}')
+        blob_client = self.connection.get_blob_client(container=container, blob=name)
+        contents = blob_client.download_blob().readall()
+        return ContentFile(contents)

+ 25 - 0
Target/Azure/azure_system_helpers_2.py

@@ -0,0 +1,25 @@
+def provide_azure_data_lake_default_connection(key_file_path: str):
+    """
+    Context manager to provide a temporary value for azure_data_lake_default connection
+    :param key_file_path: Path to file with azure_data_lake_default credentials .json file.
+    """
+    required_fields = {'login', 'password', 'extra'}
+
+    if not key_file_path.endswith(".json"):
+        raise AirflowException("Use a JSON key file.")
+    with open(key_file_path) as credentials:
+        creds = json.load(credentials)
+    missing_keys = required_fields - creds.keys()
+    if missing_keys:
+        message = f"{missing_keys} fields are missing"
+        raise AirflowException(message)
+    conn = Connection(
+        conn_id=DATA_LAKE_CONNECTION_ID,
+        conn_type=DATA_LAKE_CONNECTION_TYPE,
+        host=creds.get("host", None),
+        login=creds.get("login", None),
+        password=creds.get("password", None),
+        extra=json.dumps(creds.get('extra', None)),
+    )
+    with patch_environ({f"AIRFLOW_CONN_{conn.conn_id.upper()}": conn.get_uri()}):
+        yield
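
A usage sketch, assuming the helper is wrapped with contextlib.contextmanager (the `yield` suggests it) and that DATA_LAKE_CONNECTION_ID is the usual `azure_data_lake_default`; the key-file path and test body are placeholders.

    # Expected key-file shape (assumption): {"login": "...", "password": "...", "extra": {...}, "host": "..."}
    with provide_azure_data_lake_default_connection("/files/azure/data_lake.json"):
        # Inside the block, AIRFLOW_CONN_AZURE_DATA_LAKE_DEFAULT points at the temporary connection,
        # so any hook using azure_data_lake_default picks up these credentials.
        run_adls_system_test()  # hypothetical test body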

+ 9 - 0
Target/Azure/azure_system_helpers_3.py

@@ -0,0 +1,9 @@
+def provide_azure_fileshare(share_name: str, azure_fileshare_conn_id: str, file_name: str, directory: str):
+    AzureSystemTest.prepare_share(
+        share_name=share_name,
+        azure_fileshare_conn_id=azure_fileshare_conn_id,
+        file_name=file_name,
+        directory=directory,
+    )
+    yield
+    AzureSystemTest.delete_share(share_name=share_name, azure_fileshare_conn_id=azure_fileshare_conn_id)

+ 3 - 0
Target/Azure/azure_system_helpers_4.py

@@ -0,0 +1,3 @@
+def create_share(cls, share_name: str, azure_fileshare_conn_id: str):
+        hook = AzureFileShareHook(azure_fileshare_conn_id=azure_fileshare_conn_id)
+        hook.create_share(share_name)

+ 3 - 0
Target/Azure/azure_system_helpers_5.py

@@ -0,0 +1,3 @@
+def delete_share(cls, share_name: str, azure_fileshare_conn_id: str):
+        hook = AzureFileShareHook(azure_fileshare_conn_id=azure_fileshare_conn_id)
+        hook.delete_share(share_name)

+ 3 - 0
Target/Azure/azure_system_helpers_6.py

@@ -0,0 +1,3 @@
+def create_directory(cls, share_name: str, azure_fileshare_conn_id: str, directory: str):
+        hook = AzureFileShareHook(azure_fileshare_conn_id=azure_fileshare_conn_id)
+        hook.create_directory(share_name=share_name, directory_name=directory)

+ 15 - 0
Target/Azure/azure_system_helpers_7.py

@@ -0,0 +1,15 @@
+def upload_file_from_string(
+        cls,
+        string_data: str,
+        share_name: str,
+        azure_fileshare_conn_id: str,
+        file_name: str,
+        directory: str,
+    ):
+        hook = AzureFileShareHook(azure_fileshare_conn_id=azure_fileshare_conn_id)
+        hook.load_string(
+            string_data=string_data,
+            share_name=share_name,
+            directory_name=directory,
+            file_name=file_name,
+        )

+ 16 - 0
Target/Azure/azure_system_helpers_8.py

@@ -0,0 +1,16 @@
+def prepare_share(cls, share_name: str, azure_fileshare_conn_id: str, file_name: str, directory: str):
+        """
+        Create a share with a file in the given directory. If directory is None, the file is placed in the root directory.
+        """
+        cls.create_share(share_name=share_name, azure_fileshare_conn_id=azure_fileshare_conn_id)
+        cls.create_directory(
+            share_name=share_name, azure_fileshare_conn_id=azure_fileshare_conn_id, directory=directory
+        )
+        string_data = "".join(random.choice(string.ascii_letters) for _ in range(1024))
+        cls.upload_file_from_string(
+            string_data=string_data,
+            share_name=share_name,
+            azure_fileshare_conn_id=azure_fileshare_conn_id,
+            file_name=file_name,
+            directory=directory,
+        )

+ 22 - 0
Target/Azure/blob-adapter_2.py

@@ -0,0 +1,22 @@
+def upload(self, file_dict):
+        upload_response = {}
+        for key in file_dict:
+            print("File Dict Key: [{}] value is: {}".format(key, file_dict[key]))
+            print("\nUploading to Azure Storage as blob:\n\t" + key)
+
+            self.blob_client = self.blob_service_client.get_blob_client(container=self.get_config('container_name'), blob=key)
+            with open(file_dict[key], "rb") as data:
+                try:
+                    self.blob_client.upload_blob(data)
+                    print('File: Uploaded Successfully: {}'.format(key))
+                    upload_response[key] = 'Successfully Uploaded'
+                except ResourceExistsError:
+                    print('File: NOT Uploaded Successfully: {}'.format(key))
+                    upload_response[key] = 'This Resource already exists'
+                    upload_response['Partial'] = True
+                    print('This Resource already exists')
+                    # return 'This Resource already exists'
+        print("Before Returning Response:")
+        print(jsonify(upload_response))
+        print("---------------")
+        return upload_response

+ 4 - 0
Target/Azure/blob-adapter_3.py

@@ -0,0 +1,4 @@
+def get_blob_client(self, blob_name):
+        self.blob_client = self.blob_service_client.get_blob_client(
+            container=self.get_config('container_name'), blob=blob_name)
+        return self.blob_client

+ 10 - 0
Target/Azure/blob-adapter_4.py

@@ -0,0 +1,10 @@
+def list_blobs(self):
+        print("\nList blobs in the container")
+        self.container_client = self.blob_service_client.get_container_client(
+            container=self.get_config('container_name'))
+        blob_list = self.container_client.list_blobs()
+        blobs = []
+        for blob in blob_list:
+            # print("\t Blob name: " + blob.name)
+            blobs.append(blob.name)
+        return blobs

+ 17 - 0
Target/Azure/blob-permission_3.py

@@ -0,0 +1,17 @@
+def create_blob_link(self, blob_folder, blob_name) -> str:
+        if blob_folder:
+            full_path_blob = f"{blob_folder}/{blob_name}"
+        else:
+            full_path_blob = blob_name
+        url = f"https://{self.account_name}.blob.core.windows.net/{self.destination}/{full_path_blob}"
+        sas_token = generate_blob_sas(
+            account_name=self.account_name,
+            account_key=self.account_key,
+            container_name=self.destination,
+            blob_name=full_path_blob,
+            permission=BlobSasPermissions(read=True, delete_previous_version=False),
+            expiry=datetime.utcnow() + timedelta(days=self.expiry_download_links),
+        )
+
+        url_with_sas = f"{url}?{sas_token}"
+        return url_with_sas
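
A short consumption sketch, assuming an `uploader` instance already configured with account_name, account_key, destination and expiry_download_links; requests is used only to show that the SAS URL is directly downloadable until it expires.

    import requests

    url = uploader.create_blob_link(blob_folder="reports/2021", blob_name="summary.csv")

    response = requests.get(url)
    response.raise_for_status()
    with open("summary.csv", "wb") as f:
        f.write(response.content)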

+ 46 - 0
Target/Azure/blob-upload-1_3.py

@@ -0,0 +1,46 @@
+def upload_single(self):
+        blob_service_client = BlobServiceClient.from_connection_string(self.connection_string)
+        download_links = {}
+
+        for root, dirs, files in os.walk(self.folder):
+            for file in files:
+
+                full_path = os.path.join(root, file)
+
+                # ignore hidden files
+                if file.startswith("."):
+                    continue
+
+                # if list_files is given, only upload matched files
+                if self.list_files and file not in self.list_files:
+                    continue
+
+                # if extension is given only upload if extension is matched
+                if self.extension and os.path.isfile(full_path) and not file.lower().endswith(self.extension.lower()):
+                    continue
+
+                blob_folder = root.replace(self.folder, "").lstrip("/")
+
+                if self.blob_folder:
+                    # we only want to append blob_folder if it actually is a path or folder
+                    # blob_folder can be empty string ""
+                    if blob_folder:
+                        blob_folder = os.path.join(self.blob_folder, blob_folder)
+                    else:
+                        blob_folder = self.blob_folder
+
+                # if no folder is given, just upload to the container root path
+                if not blob_folder:
+                    container = self.destination
+                else:
+                    container = os.path.join(self.destination, blob_folder)
+                container_client = blob_service_client.get_container_client(container=container)
+
+                with open(full_path, "rb") as data:
+                    logging.debug(f"Uploading blob {full_path}")
+                    container_client.upload_blob(data=data, name=file, overwrite=self.overwrite)
+
+                if self.create_download_links:
+                    download_links[file] = self.create_blob_link(blob_folder=blob_folder, blob_name=file)
+
+        return download_links

+ 8 - 0
Target/Azure/blob-upload-1_4.py

@@ -0,0 +1,8 @@
+def upload(self):
+        self.checks()
+
+        logging.info(f"Uploading to container {self.destination} with method = '{self.method}'.")
+        if self.method == "batch":
+            return self.upload_batch()
+        else:
+            return self.upload_single()

+ 12 - 0
Target/Azure/blob-upload-2_4.py

@@ -0,0 +1,12 @@
+def upload_image(self, file_name):
+        # Create blob with same name as local file name
+        blob_client = self.blob_service_client.get_blob_client(container=MY_IMAGE_CONTAINER,
+                                                               blob=file_name)
+        # Get full path to the file
+        upload_file_path = os.path.join(LOCAL_IMAGE_PATH, file_name)
+        # Create blob on storage
+        # Overwrite if it already exists!
+        image_content_setting = ContentSettings(content_type='image/jpeg')
+        print(f"uploading file - {file_name}")
+        with open(upload_file_path, "rb") as data:
+            blob_client.upload_blob(data, overwrite=True, content_settings=image_content_setting)

Some files were not shown because too many files changed in this diff