test_oracle_to_azure_data_lake.py 4.8 KB

  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. import os
  19. import unittest
  20. from tempfile import TemporaryDirectory
  21. from unittest import mock
  22. from unittest.mock import MagicMock
  23. import unicodecsv as csv
  24. from airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake import (
  25. OracleToAzureDataLakeOperator,
  26. )
  27. class TestOracleToAzureDataLakeTransfer(unittest.TestCase):
  28. mock_module_path = 'airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake'
  29. def test_write_temp_file(self):
  30. task_id = "some_test_id"
  31. sql = "some_sql"
  32. sql_params = {':p_data': "2018-01-01"}
  33. oracle_conn_id = "oracle_conn_id"
  34. filename = "some_filename"
  35. azure_data_lake_conn_id = 'azure_data_lake_conn_id'
  36. azure_data_lake_path = 'azure_data_lake_path'
  37. delimiter = '|'
  38. encoding = 'utf-8'
  39. cursor_description = [
  40. ('id', "<class 'cx_Oracle.NUMBER'>", 39, None, 38, 0, 0),
  41. ('description', "<class 'cx_Oracle.STRING'>", 60, 240, None, None, 1),
  42. ]
  43. cursor_rows = [[1, 'description 1'], [2, 'description 2']]
  44. mock_cursor = MagicMock()
  45. mock_cursor.description = cursor_description
  46. mock_cursor.__iter__.return_value = cursor_rows
  47. op = OracleToAzureDataLakeOperator(
  48. task_id=task_id,
  49. filename=filename,
  50. oracle_conn_id=oracle_conn_id,
  51. sql=sql,
  52. sql_params=sql_params,
  53. azure_data_lake_conn_id=azure_data_lake_conn_id,
  54. azure_data_lake_path=azure_data_lake_path,
  55. delimiter=delimiter,
  56. encoding=encoding,
  57. )
  58. with TemporaryDirectory(prefix='airflow_oracle_to_azure_op_') as temp:
  59. op._write_temp_file(mock_cursor, os.path.join(temp, filename))
  60. assert os.path.exists(os.path.join(temp, filename)) == 1
  61. with open(os.path.join(temp, filename), 'rb') as csvfile:
  62. temp_file = csv.reader(csvfile, delimiter=delimiter, encoding=encoding)
  63. rownum = 0
  64. for row in temp_file:
  65. if rownum == 0:
  66. assert row[0] == 'id'
  67. assert row[1] == 'description'
  68. else:
  69. assert row[0] == str(cursor_rows[rownum - 1][0])
  70. assert row[1] == cursor_rows[rownum - 1][1]
  71. rownum = rownum + 1
  72. @mock.patch(mock_module_path + '.OracleHook', autospec=True)
  73. @mock.patch(mock_module_path + '.AzureDataLakeHook', autospec=True)
  74. def test_execute(self, mock_data_lake_hook, mock_oracle_hook):
  75. task_id = "some_test_id"
  76. sql = "some_sql"
  77. sql_params = {':p_data': "2018-01-01"}
  78. oracle_conn_id = "oracle_conn_id"
  79. filename = "some_filename"
  80. azure_data_lake_conn_id = 'azure_data_lake_conn_id'
  81. azure_data_lake_path = 'azure_data_lake_path'
  82. delimiter = '|'
  83. encoding = 'latin-1'
  84. cursor_description = [
  85. ('id', "<class 'cx_Oracle.NUMBER'>", 39, None, 38, 0, 0),
  86. ('description', "<class 'cx_Oracle.STRING'>", 60, 240, None, None, 1),
  87. ]
  88. cursor_rows = [[1, 'description 1'], [2, 'description 2']]
  89. cursor_mock = MagicMock()
  90. cursor_mock.description.return_value = cursor_description
  91. cursor_mock.__iter__.return_value = cursor_rows
  92. mock_oracle_conn = MagicMock()
  93. mock_oracle_conn.cursor().return_value = cursor_mock
  94. mock_oracle_hook.get_conn().return_value = mock_oracle_conn
  95. op = OracleToAzureDataLakeOperator(
  96. task_id=task_id,
  97. filename=filename,
  98. oracle_conn_id=oracle_conn_id,
  99. sql=sql,
  100. sql_params=sql_params,
  101. azure_data_lake_conn_id=azure_data_lake_conn_id,
  102. azure_data_lake_path=azure_data_lake_path,
  103. delimiter=delimiter,
  104. encoding=encoding,
  105. )
  106. op.execute(None)
  107. mock_oracle_hook.assert_called_once_with(oracle_conn_id=oracle_conn_id)
  108. mock_data_lake_hook.assert_called_once_with(azure_data_lake_conn_id=azure_data_lake_conn_id)