test_azure_batch.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. #
  19. import json
  20. import unittest
  21. from unittest import mock
  22. from azure.batch import BatchServiceClient, models as batch_models
  23. from airflow.models import Connection
  24. from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
  25. from airflow.utils import db
  26. class TestAzureBatchHook(unittest.TestCase):
  27. # set up the test environment
  28. def setUp(self):
  29. # set up the test variable
  30. self.test_vm_conn_id = "test_azure_batch_vm"
  31. self.test_cloud_conn_id = "test_azure_batch_cloud"
  32. self.test_account_name = "test_account_name"
  33. self.test_account_key = "test_account_key"
  34. self.test_account_url = "http://test-endpoint:29000"
  35. self.test_vm_size = "test-vm-size"
  36. self.test_vm_publisher = "test.vm.publisher"
  37. self.test_vm_offer = "test.vm.offer"
  38. self.test_vm_sku = "test-sku"
  39. self.test_cloud_os_family = "test-family"
  40. self.test_cloud_os_version = "test-version"
  41. self.test_node_agent_sku = "test-node-agent-sku"
  42. # connect with vm configuration
  43. db.merge_conn(
  44. Connection(
  45. conn_id=self.test_vm_conn_id,
  46. conn_type="azure_batch",
  47. extra=json.dumps({"extra__azure_batch__account_url": self.test_account_url}),
  48. )
  49. )
  50. # connect with cloud service
  51. db.merge_conn(
  52. Connection(
  53. conn_id=self.test_cloud_conn_id,
  54. conn_type="azure_batch",
  55. extra=json.dumps({"extra__azure_batch__account_url": self.test_account_url}),
  56. )
  57. )
  58. def test_connection_and_client(self):
  59. hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
  60. assert isinstance(hook._connection(), Connection)
  61. assert isinstance(hook.get_conn(), BatchServiceClient)
  62. def test_configure_pool_with_vm_config(self):
  63. hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
  64. pool = hook.configure_pool(
  65. pool_id='mypool',
  66. vm_size="test_vm_size",
  67. target_dedicated_nodes=1,
  68. vm_publisher="test.vm.publisher",
  69. vm_offer="test.vm.offer",
  70. sku_starts_with="test-sku",
  71. )
  72. assert isinstance(pool, batch_models.PoolAddParameter)
  73. def test_configure_pool_with_cloud_config(self):
  74. hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
  75. pool = hook.configure_pool(
  76. pool_id='mypool',
  77. vm_size="test_vm_size",
  78. target_dedicated_nodes=1,
  79. vm_publisher="test.vm.publisher",
  80. vm_offer="test.vm.offer",
  81. sku_starts_with="test-sku",
  82. )
  83. assert isinstance(pool, batch_models.PoolAddParameter)
  84. def test_configure_pool_with_latest_vm(self):
  85. with mock.patch(
  86. "airflow.providers.microsoft.azure.hooks."
  87. "batch.AzureBatchHook._get_latest_verified_image_vm_and_sku"
  88. ) as mock_getvm:
  89. hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
  90. getvm_instance = mock_getvm
  91. getvm_instance.return_value = ['test-image', 'test-sku']
  92. pool = hook.configure_pool(
  93. pool_id='mypool',
  94. vm_size="test_vm_size",
  95. use_latest_image_and_sku=True,
  96. vm_publisher="test.vm.publisher",
  97. vm_offer="test.vm.offer",
  98. sku_starts_with="test-sku",
  99. )
  100. assert isinstance(pool, batch_models.PoolAddParameter)
  101. @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
  102. def test_create_pool_with_vm_config(self, mock_batch):
  103. hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
  104. mock_instance = mock_batch.return_value.pool.add
  105. pool = hook.configure_pool(
  106. pool_id='mypool',
  107. vm_size="test_vm_size",
  108. target_dedicated_nodes=1,
  109. vm_publisher="test.vm.publisher",
  110. vm_offer="test.vm.offer",
  111. sku_starts_with="test-sku",
  112. )
  113. hook.create_pool(pool=pool)
  114. mock_instance.assert_called_once_with(pool)
  115. @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
  116. def test_create_pool_with_cloud_config(self, mock_batch):
  117. hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id)
  118. mock_instance = mock_batch.return_value.pool.add
  119. pool = hook.configure_pool(
  120. pool_id='mypool',
  121. vm_size="test_vm_size",
  122. target_dedicated_nodes=1,
  123. vm_publisher="test.vm.publisher",
  124. vm_offer="test.vm.offer",
  125. sku_starts_with="test-sku",
  126. )
  127. hook.create_pool(pool=pool)
  128. mock_instance.assert_called_once_with(pool)
  129. @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
  130. def test_wait_for_all_nodes(self, mock_batch):
  131. # TODO: Add test
  132. pass
  133. @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient")
  134. def test_job_configuration_and_create_job(self, mock_batch):
  135. hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
  136. mock_instance = mock_batch.return_value.job.add
  137. job = hook.configure_job(job_id='myjob', pool_id='mypool')
  138. hook.create_job(job)
  139. assert isinstance(job, batch_models.JobAddParameter)
  140. mock_instance.assert_called_once_with(job)
  141. @mock.patch('airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient')
  142. def test_add_single_task_to_job(self, mock_batch):
  143. hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id)
  144. mock_instance = mock_batch.return_value.task.add
  145. task = hook.configure_task(task_id="mytask", command_line="echo hello")
  146. hook.add_single_task_to_job(job_id='myjob', task=task)
  147. assert isinstance(task, batch_models.TaskAddParameter)
  148. mock_instance.assert_called_once_with(job_id="myjob", task=task)
  149. @mock.patch('airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient')
  150. def test_wait_for_all_task_to_complete(self, mock_batch):
  151. # TODO: Add test
  152. pass