test_azure_key_vault.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. from unittest import TestCase, mock
  19. from azure.core.exceptions import ResourceNotFoundError
  20. from airflow.providers.microsoft.azure.secrets.key_vault import AzureKeyVaultBackend
  21. class TestAzureKeyVaultBackend(TestCase):
  22. @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.get_conn_value')
  23. def test_get_connections(self, mock_get_value):
  24. mock_get_value.return_value = 'scheme://user:pass@host:100'
  25. conn_list = AzureKeyVaultBackend().get_connections('fake_conn')
  26. conn = conn_list[0]
  27. assert conn.host == 'host'
  28. @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.DefaultAzureCredential')
  29. @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.SecretClient')
  30. def test_get_conn_uri(self, mock_secret_client, mock_azure_cred):
  31. mock_cred = mock.Mock()
  32. mock_sec_client = mock.Mock()
  33. mock_azure_cred.return_value = mock_cred
  34. mock_secret_client.return_value = mock_sec_client
  35. mock_sec_client.get_secret.return_value = mock.Mock(
  36. value='postgresql://airflow:airflow@host:5432/airflow'
  37. )
  38. backend = AzureKeyVaultBackend(vault_url="https://example-akv-resource-name.vault.azure.net/")
  39. returned_uri = backend.get_conn_uri(conn_id='hi')
  40. mock_secret_client.assert_called_once_with(
  41. credential=mock_cred, vault_url='https://example-akv-resource-name.vault.azure.net/'
  42. )
  43. assert returned_uri == 'postgresql://airflow:airflow@host:5432/airflow'
  44. @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client')
  45. def test_get_conn_uri_non_existent_key(self, mock_client):
  46. """
  47. Test that if the key with connection ID is not present,
  48. AzureKeyVaultBackend.get_connections should return None
  49. """
  50. conn_id = 'test_mysql'
  51. mock_client.get_secret.side_effect = ResourceNotFoundError
  52. backend = AzureKeyVaultBackend(vault_url="https://example-akv-resource-name.vault.azure.net/")
  53. assert backend.get_conn_uri(conn_id=conn_id) is None
  54. assert [] == backend.get_connections(conn_id=conn_id)
  55. @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client')
  56. def test_get_variable(self, mock_client):
  57. mock_client.get_secret.return_value = mock.Mock(value='world')
  58. backend = AzureKeyVaultBackend()
  59. returned_uri = backend.get_variable('hello')
  60. mock_client.get_secret.assert_called_with(name='airflow-variables-hello')
  61. assert 'world' == returned_uri
  62. @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client')
  63. def test_get_variable_non_existent_key(self, mock_client):
  64. """
  65. Test that if Variable key is not present,
  66. AzureKeyVaultBackend.get_variables should return None
  67. """
  68. mock_client.get_secret.side_effect = ResourceNotFoundError
  69. backend = AzureKeyVaultBackend()
  70. assert backend.get_variable('test_mysql') is None
  71. @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client')
  72. def test_get_secret_value_not_found(self, mock_client):
  73. """
  74. Test that if a non-existent secret returns None
  75. """
  76. mock_client.get_secret.side_effect = ResourceNotFoundError
  77. backend = AzureKeyVaultBackend()
  78. assert (
  79. backend._get_secret(path_prefix=backend.connections_prefix, secret_id='test_non_existent') is None
  80. )
  81. @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client')
  82. def test_get_secret_value(self, mock_client):
  83. """
  84. Test that get_secret returns the secret value
  85. """
  86. mock_client.get_secret.return_value = mock.Mock(value='super-secret')
  87. backend = AzureKeyVaultBackend()
  88. secret_val = backend._get_secret('af-secrets', 'test_mysql_password')
  89. mock_client.get_secret.assert_called_with(name='af-secrets-test-mysql-password')
  90. assert secret_val == 'super-secret'
  91. @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret')
  92. def test_connection_prefix_none_value(self, mock_get_secret):
  93. """
  94. Test that if Connections prefix is None,
  95. AzureKeyVaultBackend.get_connections should return None
  96. AzureKeyVaultBackend._get_secret should not be called
  97. """
  98. kwargs = {'connections_prefix': None}
  99. backend = AzureKeyVaultBackend(**kwargs)
  100. assert backend.get_conn_uri('test_mysql') is None
  101. mock_get_secret.assert_not_called()
  102. @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret')
  103. def test_variable_prefix_none_value(self, mock_get_secret):
  104. """
  105. Test that if Variables prefix is None,
  106. AzureKeyVaultBackend.get_variables should return None
  107. AzureKeyVaultBackend._get_secret should not be called
  108. """
  109. kwargs = {'variables_prefix': None}
  110. backend = AzureKeyVaultBackend(**kwargs)
  111. assert backend.get_variable('hello') is None
  112. mock_get_secret.assert_not_called()
  113. @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret')
  114. def test_config_prefix_none_value(self, mock_get_secret):
  115. """
  116. Test that if Config prefix is None,
  117. AzureKeyVaultBackend.get_config should return None
  118. AzureKeyVaultBackend._get_secret should not be called
  119. """
  120. kwargs = {'config_prefix': None}
  121. backend = AzureKeyVaultBackend(**kwargs)
  122. assert backend.get_config('test_mysql') is None
  123. mock_get_secret.assert_not_called()