test_azure_cosmos.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. #
  19. import json
  20. import logging
  21. import unittest
  22. import uuid
  23. from unittest import mock
  24. import pytest
  25. from azure.cosmos.cosmos_client import CosmosClient
  26. from airflow.exceptions import AirflowException
  27. from airflow.models import Connection
  28. from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook
  29. from airflow.utils import db
  class TestAzureCosmosDbHook(unittest.TestCase):
      """Unit tests for ``AzureCosmosDBHook`` with ``CosmosClient`` replaced by a mock.

      Every test patches the ``CosmosClient`` name looked up by the hook module and
      then inspects the calls the hook made on that mock.

      Chained expectations such as
      ``mock.call().get_database_client('db').create_container('coll')`` only
      compare the arguments of the *last* call in the chain: ``mock_calls`` does
      not record the parameters of ancestor calls, so ``'db'`` is never checked.
      Database and container names are therefore verified separately with
      ``_assert_client_lookups``.
      """

      def setUp(self):
          """Define the test names and register the ``azure_cosmos_test_key_id`` connection.

          The connection extras carry a default database and collection name. The
          ``*_default`` tests expect the hook to fall back to them when a call
          omits the names.
          """
          self.test_end_point = 'https://test_endpoint:443'
          self.test_master_key = 'magic_test_key'
          # Names passed explicitly by the non-default tests.
          self.test_database_name = 'test_database_name'
          self.test_collection_name = 'test_collection_name'
          # Names stored in the connection extras (used when a call omits them).
          self.test_database_default = 'test_database_default'
          self.test_collection_default = 'test_collection_default'
          db.merge_conn(
              Connection(
                  conn_id='azure_cosmos_test_key_id',
                  conn_type='azure_cosmos',
                  login=self.test_end_point,
                  password=self.test_master_key,
                  extra=json.dumps(
                      {
                          'database_name': self.test_database_default,
                          'collection_name': self.test_collection_default,
                      }
                  ),
              )
          )

      def _assert_client_lookups(self, mock_cosmos, database_name, collection_name=None):
          """Assert that every client lookup made through ``mock_cosmos`` used the given names.

          :param mock_cosmos: the mock patched in place of ``CosmosClient``.
          :param database_name: name every ``get_database_client`` call must receive.
          :param collection_name: if not ``None``, name every ``get_container_client``
              call on the returned database client must receive.
          """
          get_database_client = mock_cosmos.return_value.get_database_client
          database_calls = get_database_client.call_args_list
          assert database_calls, "get_database_client was never called"
          assert database_calls == [mock.call(database_name)] * len(database_calls)
          if collection_name is not None:
              container_calls = get_database_client.return_value.get_container_client.call_args_list
              assert container_calls, "get_container_client was never called"
              assert container_calls == [mock.call(collection_name)] * len(container_calls)

      @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient', autospec=True)
      def test_client(self, mock_cosmos):
          """No client exists before ``get_conn`` and ``get_conn`` yields a ``CosmosClient``."""
          hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
          assert hook._conn is None
          # autospec makes the mocked instance pass isinstance() against the real class.
          assert isinstance(hook.get_conn(), CosmosClient)

      @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
      def test_create_database(self, mock_cosmos):
          """``create_database`` calls ``create_database`` on the client with the given name."""
          hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
          hook.create_database(self.test_database_name)
          # Only one level below the client, so the argument here IS compared.
          expected_calls = [mock.call().create_database('test_database_name')]
          mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
          mock_cosmos.assert_has_calls(expected_calls)

      @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
      def test_create_database_exception(self, mock_cosmos):
          """A ``None`` database name is rejected with an ``AirflowException``."""
          hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
          with pytest.raises(AirflowException):
              hook.create_database(None)

      @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
      def test_create_container_exception(self, mock_cosmos):
          """A ``None`` collection name is rejected with an ``AirflowException``."""
          hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
          with pytest.raises(AirflowException):
              hook.create_collection(None)

      @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
      def test_create_container(self, mock_cosmos):
          """``create_collection`` creates the container in the explicitly named database."""
          hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
          hook.create_collection(self.test_collection_name, self.test_database_name)
          expected_calls = [
              mock.call()
              .get_database_client(self.test_database_name)
              .create_container(self.test_collection_name)
          ]
          mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
          mock_cosmos.assert_has_calls(expected_calls)
          # The chained expectation above cannot see the database name; check it directly.
          self._assert_client_lookups(mock_cosmos, self.test_database_name)

      @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
      def test_create_container_default(self, mock_cosmos):
          """Without a database name, ``create_collection`` uses the connection's default."""
          hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
          hook.create_collection(self.test_collection_name)
          expected_calls = [
              mock.call()
              .get_database_client(self.test_database_default)
              .create_container(self.test_collection_name)
          ]
          mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
          mock_cosmos.assert_has_calls(expected_calls)
          self._assert_client_lookups(mock_cosmos, self.test_database_default)

      @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
      def test_upsert_document_default(self, mock_cosmos):
          """Without names, ``upsert_document`` targets the default database and collection."""
          test_id = str(uuid.uuid4())
          # fmt: off
          (mock_cosmos
           .return_value
           .get_database_client
           .return_value
           .get_container_client
           .return_value
           .upsert_item
           .return_value) = {'id': test_id}
          # fmt: on
          hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
          returned_item = hook.upsert_document({'id': test_id})
          expected_calls = [
              mock.call()
              .get_database_client(self.test_database_default)
              .get_container_client(self.test_collection_default)
              .upsert_item({'id': test_id})
          ]
          mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
          mock_cosmos.assert_has_calls(expected_calls)
          self._assert_client_lookups(mock_cosmos, self.test_database_default, self.test_collection_default)
          logging.getLogger().info(returned_item)
          assert returned_item['id'] == test_id

      @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
      def test_upsert_document(self, mock_cosmos):
          """An explicit ``document_id`` is merged into the document before upserting."""
          test_id = str(uuid.uuid4())
          # fmt: off
          (mock_cosmos
           .return_value
           .get_database_client
           .return_value
           .get_container_client
           .return_value
           .upsert_item
           .return_value) = {'id': test_id}
          # fmt: on
          hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
          returned_item = hook.upsert_document(
              {'data1': 'somedata'},
              database_name=self.test_database_name,
              collection_name=self.test_collection_name,
              document_id=test_id,
          )
          expected_calls = [
              mock.call()
              .get_database_client(self.test_database_name)
              .get_container_client(self.test_collection_name)
              .upsert_item({'data1': 'somedata', 'id': test_id})
          ]
          mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
          mock_cosmos.assert_has_calls(expected_calls)
          self._assert_client_lookups(mock_cosmos, self.test_database_name, self.test_collection_name)
          logging.getLogger().info(returned_item)
          assert returned_item['id'] == test_id

      @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
      def test_insert_documents(self, mock_cosmos):
          """Each document is created in the default database and collection."""
          test_id1 = str(uuid.uuid4())
          test_id2 = str(uuid.uuid4())
          test_id3 = str(uuid.uuid4())
          documents = [
              {'id': test_id1, 'data': 'data1'},
              {'id': test_id2, 'data': 'data2'},
              {'id': test_id3, 'data': 'data3'},
          ]
          hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
          returned_items = hook.insert_documents(documents)
          expected_calls = [
              mock.call()
              .get_database_client(self.test_database_default)
              .get_container_client(self.test_collection_default)
              .create_item(document)
              for document in documents
          ]
          logging.getLogger().info(returned_items)
          mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
          mock_cosmos.assert_has_calls(expected_calls, any_order=True)
          self._assert_client_lookups(mock_cosmos, self.test_database_default, self.test_collection_default)

      @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
      def test_delete_database(self, mock_cosmos):
          """``delete_database`` calls ``delete_database`` on the client with the given name."""
          hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
          hook.delete_database(self.test_database_name)
          expected_calls = [mock.call().delete_database('test_database_name')]
          mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
          mock_cosmos.assert_has_calls(expected_calls)

      @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
      def test_delete_database_exception(self, mock_cosmos):
          """A ``None`` database name is rejected with an ``AirflowException``."""
          hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
          with pytest.raises(AirflowException):
              hook.delete_database(None)

      # Patch the name the hook module looks up (as every other test does); patching
      # azure.cosmos.cosmos_client.CosmosClient would not affect the hook's reference.
      @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
      def test_delete_container_exception(self, mock_cosmos):
          """A ``None`` collection name is rejected with an ``AirflowException``."""
          hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
          with pytest.raises(AirflowException):
              hook.delete_collection(None)

      @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
      def test_delete_container(self, mock_cosmos):
          """``delete_collection`` deletes the container from the explicitly named database."""
          hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
          hook.delete_collection(self.test_collection_name, self.test_database_name)
          expected_calls = [
              mock.call()
              .get_database_client(self.test_database_name)
              .delete_container(self.test_collection_name)
          ]
          mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
          mock_cosmos.assert_has_calls(expected_calls)
          self._assert_client_lookups(mock_cosmos, self.test_database_name)

      @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient')
      def test_delete_container_default(self, mock_cosmos):
          """Without a database name, ``delete_collection`` uses the connection's default."""
          hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
          hook.delete_collection(self.test_collection_name)
          expected_calls = [
              mock.call()
              .get_database_client(self.test_database_default)
              .delete_container(self.test_collection_name)
          ]
          mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
          mock_cosmos.assert_has_calls(expected_calls)
          self._assert_client_lookups(mock_cosmos, self.test_database_default)