12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- import unittest
- from unittest.mock import Mock, patch
- from airflow.models import Connection
- from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
class TestBaseAzureHook(unittest.TestCase):
    """Check how ``AzureBaseHook.get_conn`` builds its SDK client from a connection.

    Every test replaces ``AzureBaseHook.get_connection`` with a stub that returns a
    canned ``Connection`` and replaces the helper that ``get_conn`` is expected to call
    for that connection, then verifies the helper's call arguments and the value
    handed back to the caller.
    """

    # NOTE: ``patch`` decorators are applied bottom-up, so in every test the
    # ``get_connection`` mock is the first injected argument and the helper mock the second.

    @patch('airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_auth_file')
    @patch(
        'airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection',
        return_value=Connection(conn_id='azure_default', extra='{ "key_path": "key_file.json" }'),
    )
    def test_get_conn_with_key_path(self, mock_connection, mock_get_client_from_auth_file):
        """A ``key_path`` extra is forwarded to ``get_client_from_auth_file`` as ``auth_path``."""
        sdk_client_class = Mock()
        hook = AzureBaseHook(sdk_client_class)

        client = hook.get_conn()

        # The expected path is read back from the stubbed connection rather than hard-coded.
        expected_auth_path = mock_connection.return_value.extra_dejson['key_path']
        mock_get_client_from_auth_file.assert_called_once_with(
            client_class=sdk_client_class,
            auth_path=expected_auth_path,
        )
        assert client == mock_get_client_from_auth_file.return_value

    @patch('airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_json_dict')
    @patch(
        'airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection',
        return_value=Connection(conn_id='azure_default', extra='{ "key_json": { "test": "test" } }'),
    )
    def test_get_conn_with_key_json(self, mock_connection, mock_get_client_from_json_dict):
        """A ``key_json`` extra is forwarded to ``get_client_from_json_dict`` as ``config_dict``."""
        sdk_client_class = Mock()
        hook = AzureBaseHook(sdk_client_class)

        client = hook.get_conn()

        # The expected dict is read back from the stubbed connection rather than hard-coded.
        expected_config = mock_connection.return_value.extra_dejson['key_json']
        mock_get_client_from_json_dict.assert_called_once_with(
            client_class=sdk_client_class,
            config_dict=expected_config,
        )
        assert client == mock_get_client_from_json_dict.return_value

    @patch('airflow.providers.microsoft.azure.hooks.base_azure.ServicePrincipalCredentials')
    @patch(
        'airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection',
        return_value=Connection(
            conn_id='azure_default',
            login='my_login',
            password='my_password',
            extra='{ "tenantId": "my_tenant", "subscriptionId": "my_subscription" }',
        ),
    )
    def test_get_conn_with_credentials(self, mock_connection, mock_spc):
        """Login, password and tenant build the credentials that are passed to the SDK client."""
        sdk_client_class = Mock()
        hook = AzureBaseHook(sdk_client_class)

        client = hook.get_conn()

        stub_conn = mock_connection.return_value
        # First the credentials object must be created from the connection's fields ...
        mock_spc.assert_called_once_with(
            client_id=stub_conn.login,
            secret=stub_conn.password,
            tenant=stub_conn.extra_dejson['tenantId'],
        )
        # ... then the SDK client class itself is called with those credentials.
        sdk_client_class.assert_called_once_with(
            credentials=mock_spc.return_value,
            subscription_id=stub_conn.extra_dejson['subscriptionId'],
        )
        assert client == sdk_client_class.return_value
|