# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import os
import random
import string
from contextlib import contextmanager

import pytest

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook
from airflow.utils.process_utils import patch_environ
from tests.test_utils import AIRFLOW_MAIN_FOLDER
from tests.test_utils.system_tests_class import SystemTest

AZURE_DAG_FOLDER = os.path.join(
    AIRFLOW_MAIN_FOLDER, "airflow", "providers", "microsoft", "azure", "example_dags"
)
WASB_CONNECTION_ID = os.environ.get("WASB_CONNECTION_ID", "wasb_default")
DATA_LAKE_CONNECTION_ID = os.environ.get("AZURE_DATA_LAKE_CONNECTION_ID", "azure_data_lake_default")
DATA_LAKE_CONNECTION_TYPE = os.environ.get("AZURE_DATA_LAKE_CONNECTION_TYPE", "azure_data_lake")


@contextmanager
def provide_wasb_default_connection(key_file_path: str):
    """
    Context manager providing a temporary value for the wasb_default connection.

    :param key_file_path: Path to the JSON file that holds the wasb_default credentials.
    """
    if not key_file_path.endswith(".json"):
        raise AirflowException("Use a JSON key file.")
    with open(key_file_path) as credentials:
        creds = json.load(credentials)
    conn = Connection(
        conn_id=WASB_CONNECTION_ID,
        conn_type="wasb",
        host=creds.get("host"),
        login=creds.get("login"),
        password=creds.get("password"),
        extra=json.dumps(creds.get("extra")),
    )
    with patch_environ({f"AIRFLOW_CONN_{conn.conn_id.upper()}": conn.get_uri()}):
        yield
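

# A minimal usage sketch for the context manager above (the key file path is a
# hypothetical placeholder, not something this module provides):
#
#     with provide_wasb_default_connection("/path/to/wasb_creds.json"):
#         # Inside the block, AIRFLOW_CONN_WASB_DEFAULT (with the default
#         # connection id) is set in the environment, so hooks and operators
#         # resolving WASB_CONNECTION_ID pick up the temporary connection
#         # instead of one stored in the metadata database.
#         ...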


@contextmanager
def provide_azure_data_lake_default_connection(key_file_path: str):
    """
    Context manager providing a temporary value for the azure_data_lake_default connection.

    :param key_file_path: Path to the JSON file that holds the azure_data_lake_default credentials.
    """
    required_fields = {"login", "password", "extra"}
    if not key_file_path.endswith(".json"):
        raise AirflowException("Use a JSON key file.")
    with open(key_file_path) as credentials:
        creds = json.load(credentials)
    missing_keys = required_fields - creds.keys()
    if missing_keys:
        raise AirflowException(f"Missing required fields in the key file: {missing_keys}")
    conn = Connection(
        conn_id=DATA_LAKE_CONNECTION_ID,
        conn_type=DATA_LAKE_CONNECTION_TYPE,
        host=creds.get("host"),
        login=creds.get("login"),
        password=creds.get("password"),
        extra=json.dumps(creds.get("extra")),
    )
    with patch_environ({f"AIRFLOW_CONN_{conn.conn_id.upper()}": conn.get_uri()}):
        yield
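

# For reference, a key file accepted by the context managers above could look
# like the following (illustrative values only; the exact keys required depend
# on how the Azure connection authenticates):
#
#     {
#         "login": "client-id",
#         "password": "secret",
#         "extra": {"tenant": "tenant-id", "account_name": "storage-account"}
#     }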


@contextmanager
def provide_azure_fileshare(share_name: str, azure_fileshare_conn_id: str, file_name: str, directory: str):
    """
    Context manager that creates an Azure file share holding a single randomly
    filled file in the given directory, and deletes the share again on exit.
    """
    AzureSystemTest.prepare_share(
        share_name=share_name,
        azure_fileshare_conn_id=azure_fileshare_conn_id,
        file_name=file_name,
        directory=directory,
    )
    yield
    AzureSystemTest.delete_share(share_name=share_name, azure_fileshare_conn_id=azure_fileshare_conn_id)
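

# Sketch of how a system test might use the fixture above; the share, file,
# directory, and connection names are illustrative only:
#
#     with provide_azure_fileshare(
#         share_name="myshare",
#         azure_fileshare_conn_id="azure_fileshare_default",
#         file_name="test_file.txt",
#         directory="mydir",
#     ):
#         ...  # exercise the operator or DAG; the share is removed on exit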


@pytest.mark.system("azure")
class AzureSystemTest(SystemTest):
    """Base class with Azure File Share helpers for Azure system tests."""

    @classmethod
    def create_share(cls, share_name: str, azure_fileshare_conn_id: str):
        hook = AzureFileShareHook(azure_fileshare_conn_id=azure_fileshare_conn_id)
        hook.create_share(share_name)

    @classmethod
    def delete_share(cls, share_name: str, azure_fileshare_conn_id: str):
        hook = AzureFileShareHook(azure_fileshare_conn_id=azure_fileshare_conn_id)
        hook.delete_share(share_name)

    @classmethod
    def create_directory(cls, share_name: str, azure_fileshare_conn_id: str, directory: str):
        hook = AzureFileShareHook(azure_fileshare_conn_id=azure_fileshare_conn_id)
        hook.create_directory(share_name=share_name, directory_name=directory)

    @classmethod
    def upload_file_from_string(
        cls,
        string_data: str,
        share_name: str,
        azure_fileshare_conn_id: str,
        file_name: str,
        directory: str,
    ):
        hook = AzureFileShareHook(azure_fileshare_conn_id=azure_fileshare_conn_id)
        hook.load_string(
            string_data=string_data,
            share_name=share_name,
            directory_name=directory,
            file_name=file_name,
        )

    @classmethod
    def prepare_share(cls, share_name: str, azure_fileshare_conn_id: str, file_name: str, directory: str):
        """
        Create a share with a single file in the given directory.
        If ``directory`` is None, the file is placed in the root directory of the share.
        """
        cls.create_share(share_name=share_name, azure_fileshare_conn_id=azure_fileshare_conn_id)
        cls.create_directory(
            share_name=share_name, azure_fileshare_conn_id=azure_fileshare_conn_id, directory=directory
        )
        string_data = "".join(random.choice(string.ascii_letters) for _ in range(1024))
        cls.upload_file_from_string(
            string_data=string_data,
            share_name=share_name,
            azure_fileshare_conn_id=azure_fileshare_conn_id,
            file_name=file_name,
            directory=directory,
        )
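

# A hedged sketch of how a concrete system test might build on this class.
# ``run_dag`` is inherited from SystemTest in this test-utils package; the DAG
# id below is illustrative, not a DAG shipped by this module:
#
#     class TestAzureFileShareExample(AzureSystemTest):
#         def test_run_example_dag(self):
#             self.run_dag("example_fileshare", AZURE_DAG_FOLDER)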