import numpy as np

def arp_tensor_result(tensor_data, dis_matrix, first_id):
    """Prioritize test cases with adaptive random prioritization (ARP).

    Starting from ``first_id``, repeatedly select the not-yet-selected test
    case whose minimum distance to the already-selected ones is largest
    (farthest-first traversal), until every test case has been selected.

    Args:
        tensor_data: Sequence of test inputs; only its length is used here.
        dis_matrix: Pairwise distance matrix indexed as ``dis_matrix[i, j]``,
            e.g. the output of ``get_distance_matrix_tensor``. Presumably an
            ``(n, n)`` NumPy array with ``n == len(tensor_data)`` -- confirm
            against callers.
        first_id: Index of the test case to place first.

    Returns:
        List of test-case indices in prioritized order, covering every index
        in ``range(len(tensor_data))``. Empty when ``tensor_data`` is empty.

    Raises:
        KeyError: If ``tensor_data`` is non-empty and ``first_id`` is not one
            of its indices.
    """
    # TODO: adapt to tensor format
    result_list_id = []
    if len(tensor_data) == 0:
        return result_list_id

    # Min distance between each unsorted test case and the test cases already
    # in result_list_id. Its keys double as the set of unsorted test cases;
    # insertion order (0..n-1) fixes tie-breaking in max_distance_to_unsorted.
    distance_list = {i: float('inf') for i in range(len(tensor_data))}

    current_testcase_id = first_id
    while True:
        result_list_id.append(current_testcase_id)
        distance_list.pop(current_testcase_id)
        if not distance_list:
            # Every test case, including the last one, has been placed.
            break
        # Only the newly selected test case can lower each remaining min distance.
        for temp_testcase_id in distance_list:
            distance = dis_matrix[current_testcase_id, temp_testcase_id]
            if distance < distance_list[temp_testcase_id]:
                distance_list[temp_testcase_id] = distance
        current_testcase_id = max_distance_to_unsorted(distance_list)
    return result_list_id


def max_distance_to_unsorted(distance_list):
    """Return the id of the entry with the largest distance.

    Args:
        distance_list: Mapping of test-case id -> distance.

    Returns:
        The first key (in iteration order) whose value is strictly greater
        than every earlier value, so ties go to the earliest key. If no value
        compares greater than ``-inf`` (e.g. all ``-inf`` or NaN), the first
        key is returned instead.

    Raises:
        IndexError: If ``distance_list`` is empty.
    """
    best = None
    best_distance = float('-inf')
    for candidate in distance_list:
        candidate_distance = distance_list[candidate]
        if not candidate_distance > best_distance:
            continue
        best, best_distance = candidate, candidate_distance
    return list(distance_list)[0] if best is None else best


def get_distance_matrix_tensor(tensor_list, option):
    """Build the symmetric pairwise distance matrix for ``tensor_list``.

    Args:
        tensor_list: Sequence of inputs to compare pairwise.
        option: Distance selector forwarded to ``compute_distance``
            (``"e"`` for Euclidean, ``"c"`` for cosine).

    Returns:
        ``(n, n)`` float64 NumPy array with zeros on the diagonal, where
        entry ``[i, j]`` equals ``[j, i]`` and holds the distance between
        items ``i`` and ``j``.
    """
    size = len(tensor_list)
    matrix = np.zeros((size, size))
    for row in range(size):
        # Only the upper triangle is computed; it is mirrored into the lower one.
        for col in range(row + 1, size):
            left, right = tensor_list[row], tensor_list[col]
            matrix[row, col] = compute_distance(option, left, right)
            matrix[col, row] = matrix[row, col]
    return matrix


def euclidean_distance(x, y):
    """Return the Euclidean (L2) distance between ``x`` and ``y``.

    Inputs may be NumPy arrays or array-likes (e.g. lists) of broadcast-
    compatible shapes; the squared differences are summed over all elements.
    Non-floating inputs are promoted to float first, so unsigned/integer data
    (e.g. uint8 images) cannot wrap around during subtraction or squaring.
    Floating inputs keep their dtype.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # uint8 - uint8 wraps modulo 256 and np.square overflows for ints;
    # promote the same way np.linalg.norm does for non-inexact dtypes.
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(float)
    if not np.issubdtype(y.dtype, np.inexact):
        y = y.astype(float)
    return np.sqrt(np.sum(np.square(x - y)))


def cosine_distance(x, y):
    """Return the cosine distance ``1 - cos(x, y)`` between ``x`` and ``y``.

    Inputs of any shape are flattened first, so multi-dimensional tensors are
    compared as vectors of the same total size. Non-floating inputs are
    promoted to float so the dot product cannot overflow.

    Returns:
        ``1`` when either input has zero norm (cosine undefined), otherwise
        ``1 - <x, y> / (|x| * |y|)``.
    """
    x = np.ravel(x)
    y = np.ravel(y)
    # np.dot preserves integer dtypes (uint8 wraps at 256); promote like
    # np.linalg.norm does internally.
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(float)
    if not np.issubdtype(y.dtype, np.inexact):
        y = y.astype(float)
    temp = np.linalg.norm(x) * np.linalg.norm(y)
    if temp == 0:
        return 1
    # After ravel this is an inner product; on raw 2-D inputs np.dot would be
    # a matrix product and return an array instead of a scalar.
    return 1 - np.dot(x, y) / temp


def compute_distance(distance_option, x, y):
    """Return the distance between ``x`` and ``y`` selected by ``distance_option``.

    Args:
        distance_option: ``"e"`` for Euclidean distance, ``"c"`` for cosine
            distance.
        x: First input.
        y: Second input.

    Returns:
        The value computed by ``euclidean_distance`` or ``cosine_distance``.

    Raises:
        ValueError: If ``distance_option`` is neither ``"e"`` nor ``"c"``.
    """
    if distance_option == "e":
        return euclidean_distance(x, y)
    if distance_option == "c":
        return cosine_distance(x, y)
    # Fail loudly: returning None would only surface later as an opaque error
    # (e.g. when stored into a float distance matrix).
    raise ValueError(
        f"unknown distance option {distance_option!r}; expected 'e' or 'c'"
    )