123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- #
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- #
- import json
- import logging
- import unittest
- import uuid
- from unittest import mock
- import pytest
- from azure.cosmos.cosmos_client import CosmosClient
- from airflow.exceptions import AirflowException
- from airflow.models import Connection
- from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook
- from airflow.utils import db
class TestAzureCosmosDbHook(unittest.TestCase):
    """Unit tests for ``AzureCosmosDBHook`` against a mocked ``CosmosClient``.

    A ``mock.call()`` chain such as
    ``call().get_database_client('a').create_container('b')`` only compares the
    arguments of its *final* call. Checking it against the top-level mock
    therefore says nothing about which database or collection was targeted.
    The tests below assert the names passed to ``get_database_client`` /
    ``get_container_client`` directly on the child mocks, so a wrong (or
    wrongly defaulted) name is caught.
    """

    # Patch ``CosmosClient`` where the hook module looks it up. Patching the
    # defining module (``azure.cosmos.cosmos_client``) would leave the hook's
    # already-imported reference untouched.
    _cosmos_client_target = 'airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient'

    def setUp(self):
        """Register a Cosmos connection whose extras carry default db/collection names."""
        self.conn_id = 'azure_cosmos_test_key_id'
        self.test_end_point = 'https://test_endpoint:443'
        self.test_master_key = 'magic_test_key'
        self.test_database_name = 'test_database_name'
        self.test_collection_name = 'test_collection_name'
        # Names the hook should fall back to when none is passed explicitly.
        self.test_database_default = 'test_database_default'
        self.test_collection_default = 'test_collection_default'
        db.merge_conn(
            Connection(
                conn_id=self.conn_id,
                conn_type='azure_cosmos',
                login=self.test_end_point,
                password=self.test_master_key,
                extra=json.dumps(
                    {
                        'database_name': self.test_database_default,
                        'collection_name': self.test_collection_default,
                    }
                ),
            )
        )

    def _assert_client_created(self, mock_cosmos):
        """Assert the client was built from the connection's login (endpoint) and password (key)."""
        mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})

    @mock.patch(_cosmos_client_target, autospec=True)
    def test_client(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
        # The client is created lazily on the first get_conn() call.
        assert hook._conn is None
        assert isinstance(hook.get_conn(), CosmosClient)

    @mock.patch(_cosmos_client_target)
    def test_create_database(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
        hook.create_database(self.test_database_name)

        self._assert_client_created(mock_cosmos)
        mock_cosmos.return_value.create_database.assert_any_call(self.test_database_name)

    @mock.patch(_cosmos_client_target)
    def test_create_database_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
        with pytest.raises(AirflowException):
            hook.create_database(None)

    @mock.patch(_cosmos_client_target)
    def test_create_container_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
        with pytest.raises(AirflowException):
            hook.create_collection(None)

    @mock.patch(_cosmos_client_target)
    def test_create_container(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
        hook.create_collection(self.test_collection_name, self.test_database_name)

        self._assert_client_created(mock_cosmos)
        get_database_client = mock_cosmos.return_value.get_database_client
        get_database_client.assert_any_call(self.test_database_name)
        get_database_client.return_value.create_container.assert_any_call(self.test_collection_name)

    @mock.patch(_cosmos_client_target)
    def test_create_container_default(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
        # No database given: the connection's default database must be used.
        hook.create_collection(self.test_collection_name)

        self._assert_client_created(mock_cosmos)
        get_database_client = mock_cosmos.return_value.get_database_client
        get_database_client.assert_any_call(self.test_database_default)
        get_database_client.return_value.create_container.assert_any_call(self.test_collection_name)

    @mock.patch(_cosmos_client_target)
    def test_upsert_document_default(self, mock_cosmos):
        test_id = str(uuid.uuid4())
        get_database_client = mock_cosmos.return_value.get_database_client
        get_container_client = get_database_client.return_value.get_container_client
        mock_container = get_container_client.return_value
        mock_container.upsert_item.return_value = {'id': test_id}

        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
        # No database/collection given: both must come from the connection extras.
        returned_item = hook.upsert_document({'id': test_id})

        self._assert_client_created(mock_cosmos)
        get_database_client.assert_any_call(self.test_database_default)
        get_container_client.assert_any_call(self.test_collection_default)
        mock_container.upsert_item.assert_any_call({'id': test_id})
        logging.getLogger().info(returned_item)
        assert returned_item['id'] == test_id

    @mock.patch(_cosmos_client_target)
    def test_upsert_document(self, mock_cosmos):
        test_id = str(uuid.uuid4())
        get_database_client = mock_cosmos.return_value.get_database_client
        get_container_client = get_database_client.return_value.get_container_client
        mock_container = get_container_client.return_value
        mock_container.upsert_item.return_value = {'id': test_id}

        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
        returned_item = hook.upsert_document(
            {'data1': 'somedata'},
            database_name=self.test_database_name,
            collection_name=self.test_collection_name,
            document_id=test_id,
        )

        self._assert_client_created(mock_cosmos)
        get_database_client.assert_any_call(self.test_database_name)
        get_container_client.assert_any_call(self.test_collection_name)
        # document_id must be merged into the upserted body as its 'id'.
        mock_container.upsert_item.assert_any_call({'data1': 'somedata', 'id': test_id})
        logging.getLogger().info(returned_item)
        assert returned_item['id'] == test_id

    @mock.patch(_cosmos_client_target)
    def test_insert_documents(self, mock_cosmos):
        documents = [
            {'id': str(uuid.uuid4()), 'data': f'data{index}'}
            for index in range(1, 4)
        ]
        # Snapshot the expected payloads before the call in case the hook mutates them.
        expected_items = [mock.call(dict(document)) for document in documents]
        get_database_client = mock_cosmos.return_value.get_database_client
        get_container_client = get_database_client.return_value.get_container_client
        mock_container = get_container_client.return_value

        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
        # No database/collection given: both must come from the connection extras.
        returned_items = hook.insert_documents(documents)
        logging.getLogger().info(returned_items)

        self._assert_client_created(mock_cosmos)
        get_database_client.assert_any_call(self.test_database_default)
        get_container_client.assert_any_call(self.test_collection_default)
        mock_container.create_item.assert_has_calls(expected_items, any_order=True)

    @mock.patch(_cosmos_client_target)
    def test_delete_database(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
        hook.delete_database(self.test_database_name)

        self._assert_client_created(mock_cosmos)
        mock_cosmos.return_value.delete_database.assert_any_call(self.test_database_name)

    @mock.patch(_cosmos_client_target)
    def test_delete_database_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
        with pytest.raises(AirflowException):
            hook.delete_database(None)

    @mock.patch(_cosmos_client_target)
    def test_delete_container_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
        with pytest.raises(AirflowException):
            hook.delete_collection(None)

    @mock.patch(_cosmos_client_target)
    def test_delete_container(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
        hook.delete_collection(self.test_collection_name, self.test_database_name)

        self._assert_client_created(mock_cosmos)
        get_database_client = mock_cosmos.return_value.get_database_client
        get_database_client.assert_any_call(self.test_database_name)
        get_database_client.return_value.delete_container.assert_any_call(self.test_collection_name)

    @mock.patch(_cosmos_client_target)
    def test_delete_container_default(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
        # No database given: the connection's default database must be used.
        hook.delete_collection(self.test_collection_name)

        self._assert_client_created(mock_cosmos)
        get_database_client = mock_cosmos.return_value.get_database_client
        get_database_client.assert_any_call(self.test_database_default)
        get_database_client.return_value.delete_container.assert_any_call(self.test_collection_name)
|