12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- import unittest
- from unittest.mock import Mock, patch
- from airflow.models import Connection
- from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
- class TestBaseAzureHook(unittest.TestCase):
- @patch('airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_auth_file')
- @patch(
- 'airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection',
- return_value=Connection(conn_id='azure_default', extra='{ "key_path": "key_file.json" }'),
- )
- def test_get_conn_with_key_path(self, mock_connection, mock_get_client_from_auth_file):
- mock_sdk_client = Mock()
- auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn()
- mock_get_client_from_auth_file.assert_called_once_with(
- client_class=mock_sdk_client, auth_path=mock_connection.return_value.extra_dejson['key_path']
- )
- assert auth_sdk_client == mock_get_client_from_auth_file.return_value
- @patch('airflow.providers.microsoft.azure.hooks.base_azure.get_client_from_json_dict')
- @patch(
- 'airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection',
- return_value=Connection(conn_id='azure_default', extra='{ "key_json": { "test": "test" } }'),
- )
- def test_get_conn_with_key_json(self, mock_connection, mock_get_client_from_json_dict):
- mock_sdk_client = Mock()
- auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn()
- mock_get_client_from_json_dict.assert_called_once_with(
- client_class=mock_sdk_client, config_dict=mock_connection.return_value.extra_dejson['key_json']
- )
- assert auth_sdk_client == mock_get_client_from_json_dict.return_value
- @patch('airflow.providers.microsoft.azure.hooks.base_azure.ServicePrincipalCredentials')
- @patch(
- 'airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook.get_connection',
- return_value=Connection(
- conn_id='azure_default',
- login='my_login',
- password='my_password',
- extra='{ "tenantId": "my_tenant", "subscriptionId": "my_subscription" }',
- ),
- )
- def test_get_conn_with_credentials(self, mock_connection, mock_spc):
- mock_sdk_client = Mock()
- auth_sdk_client = AzureBaseHook(mock_sdk_client).get_conn()
- mock_spc.assert_called_once_with(
- client_id=mock_connection.return_value.login,
- secret=mock_connection.return_value.password,
- tenant=mock_connection.return_value.extra_dejson['tenantId'],
- )
- mock_sdk_client.assert_called_once_with(
- credentials=mock_spc.return_value,
- subscription_id=mock_connection.return_value.extra_dejson['subscriptionId'],
- )
- assert auth_sdk_client == mock_sdk_client.return_value
|