import datetime
import math
from typing import Tuple

import numpy as np
import pandas as pd

from sampo.schemas.graph import EdgeType
from sampo.utilities.name_mapper import NameMapper


def get_all_connections(graph_df: pd.DataFrame,
                        use_mapper: bool = False,
                        mapper: NameMapper | None = None) \
        -> Tuple[dict[str, list], dict[str, list]]:
    """
    Enumerate every unordered pair of tasks of the graph

    Pairs are produced from the upper triangle of the task-by-task matrix, so each pair of rows
    appears exactly once and a row is never paired with itself.

    :param graph_df: tasks table; the 'activity_id' and 'granular_name' columns are read
    :param use_mapper: whether task names should be translated with ``mapper.get``
    :param mapper: object exposing ``get(name)``; must be provided when ``use_mapper`` is True
    :return: two dicts with keys 'ids' and 'names'; position ``k`` of the first dict and position ``k``
        of the second dict describe the two members of the k-th pair
    """
    task_name_column = 'granular_name'

    ids = graph_df['activity_id'].values
    names = graph_df[task_name_column].values

    if use_mapper:
        # Translate each task name once instead of once per pair.
        # Explicit otypes keeps the mapper's results untouched and lets vectorize accept an empty column
        # (without otypes it raises on size-0 input, i.e. for graphs with fewer than two tasks).
        names = np.vectorize(mapper.get, otypes=[object])(names)

    # Upper triangular indices (diagonal excluded) avoid duplicate and self pairs
    first_idx, second_idx = np.triu_indices(len(graph_df), k=1)

    return ({"ids": ids[first_idx], "names": names[first_idx]},
            {"ids": ids[second_idx], "names": names[second_idx]})


def get_delta_between_dates(first: str, second: str) -> int:
    """
    Number of days elapsed from ``second`` to ``first``, never less than 1

    :param first: later date written as 'year-month-day'
    :param second: earlier date written as 'year-month-day'
    :return: ``first - second`` in days; 1 is returned when the difference is zero or negative
    """

    def _to_date(text: str) -> datetime.date:
        # Only the first three dash-separated fields are used
        fields = text.split('-')
        return datetime.date(int(fields[0]), int(fields[1]), int(fields[2]))

    later = _to_date(first)
    earlier = _to_date(second)
    elapsed_days = (later - earlier).days
    return elapsed_days if elapsed_days > 1 else 1


def find_min_without_outliers(lst: list[float]) -> float:
    """
    Smallest value of the list after discarding abnormally low values

    A value is discarded when it lies more than three standard deviations below the mean.

    :param lst: non-empty list of numbers
    :return: the minimum of the remaining values, rounded to two decimal places
    :raises ValueError: if no value remains (e.g. the list is empty)
    """
    # Compute the cut-off once; it does not depend on the element being tested
    lower_bound = np.mean(lst) - 3 * np.std(lst)
    return round(min(x for x in lst if x >= lower_bound), 2)


def _lag_and_share(follower_start: str, leader_start: str, leader_finish: str) -> Tuple[int, float]:
    """Return the follower's start lag behind the leader in days and that lag divided by the leader's duration."""
    lag = get_delta_between_dates(follower_start, leader_start)
    duration = get_delta_between_dates(leader_finish, leader_start)
    # Guard kept from the original code so a zero duration can never cause a division error
    share = lag / duration if duration != 0 else 0
    return lag, share


def gather_links_types_statistics(s1: str, f1: str, s2: str, f2: str) \
        -> Tuple[int, int, int, list, list, int, list, list, int, list, list, int, list, list]:
    """
    Count statistics on the occurrence of different mutual arrangement of tasks

    Exactly one arrangement is registered per call:

    * identical intervals -> 'ffs12' with a lag and a share of 0.01;
    * second work finishes no later than the first starts -> 'fs21';
    * second work starts no earlier than the first finishes -> 'fs12';
    * overlapping, second starts no earlier than first: 'ffs12' if it also finishes no earlier, else 'ss12';
    * overlapping, second starts earlier than first: 'ffs21' if it also finishes no later, else 'ss21'.

    Dates are compared as plain strings.
    NOTE(review): this presumably expects zero-padded 'YYYY-MM-DD' values - confirm against the data source.

    :param s1: start of first work
    :param f1: finish of first work
    :param s2: start of second work
    :param f2: finish of second work
    :return: Statistics on the occurrence of different mutual arrangement of tasks, in the order
        fs12, fs21,
        ss12, ss12_lags, ss12_percent_lags,
        ss21, ss21_lags, ss21_percent_lags,
        ffs12, ffs12_lags, ffs12_percent_lags,
        ffs21, ffs21_lags, ffs21_percent_lags
    """
    fs12 = fs21 = ss12 = ss21 = ffs12 = ffs21 = 0
    ss12_lags, ss12_percent_lags, ss21_lags, ss21_percent_lags = [], [], [], []
    ffs12_lags, ffs12_percent_lags, ffs21_lags, ffs21_percent_lags = [], [], [], []

    if s1 == s2 and f1 == f2:
        # Simultaneous works are one observation; previously they fell through to the generic
        # branches as well and were counted a second time.
        ffs12 += 1
        ffs12_lags.append(0.01)
        ffs12_percent_lags.append(0.01)
    elif f2 <= s1:
        fs21 += 1
    elif s2 >= f1:
        fs12 += 1
    elif s2 >= s1:
        # The first work leads: measure how far the second one starts behind it
        lag, share = _lag_and_share(s2, s1, f1)
        if f2 >= f1:
            ffs12 += 1
            ffs12_lags.append(lag)
            ffs12_percent_lags.append(share)
        else:
            ss12 += 1
            ss12_lags.append(lag)
            ss12_percent_lags.append(share)
    else:
        # The second work leads: measure how far the first one starts behind it
        lag, share = _lag_and_share(s1, s2, f2)
        if f2 <= f1:
            ffs21 += 1
            ffs21_lags.append(lag)
            ffs21_percent_lags.append(share)
        else:
            ss21 += 1
            ss21_lags.append(lag)
            ss21_percent_lags.append(share)

    return fs12, fs21, ss12, ss12_lags, ss12_percent_lags, ss21, ss21_lags, ss21_percent_lags, \
        ffs12, ffs12_lags, ffs12_percent_lags, ffs21, ffs21_lags, ffs21_percent_lags


def _collect_site_intervals(history_data: pd.DataFrame, column_name: str) -> dict:
    """Index history rows as {upper work: {work name: [(first_day, last_day), ...]}}, each list sorted by dates."""
    site_intervals = {}
    for site, site_rows in history_data.groupby('upper_works'):
        works_at_site = {}
        for name, rows in site_rows.groupby(column_name, sort=False):
            rows = rows.sort_values(by=['first_day', 'last_day'])
            works_at_site[name] = list(zip(rows['first_day'], rows['last_day']))
        site_intervals[site] = works_at_site
    return site_intervals


def _choose_link(fs: int, ss: int, ffs: int, ss_shares: list, ffs_shares: list):
    """Pick (link type, lag) from the counters of one direction; return None when nothing was observed."""
    if max(fs, ss, ffs) == 0:
        return None
    if fs > ss:
        if ffs > 0:
            return 'FFS', find_min_without_outliers(ffs_shares)
        return 'FS', -1
    if ss > ffs:
        return 'SS', find_min_without_outliers(ss_shares)
    return 'FFS', find_min_without_outliers(ffs_shares)


def get_all_seq_statistic(history_data: pd.DataFrame,
                          graph_df: pd.DataFrame,
                          use_model_name: bool = False,
                          mapper: NameMapper | None = None):
    """
    Derive predecessor links for every task of the graph from historical schedules

    For each pair of differently named tasks, the history rows of both works are collected per
    'upper_works' group, sorted by ('first_day', 'last_day') and compared rank by rank. The arrangement
    counters of all comparisons decide the direction of the link and its type.

    :param history_data: historical records; columns 'upper_works', 'first_day', 'last_day' and either
        'model_name' (when ``use_model_name``) or 'granular_name'/'activity_name' are read.
        The frame passed by the caller is not modified.
    :param graph_df: tasks table; 'activity_id' and 'granular_name' are read
    :param use_model_name: match works by 'model_name', translating graph names through ``mapper``
    :param mapper: name mapper forwarded to ``get_all_connections``
    :return: dict mapping each activity id to a list of ``[predecessor_id, link_type, lag, count]`` entries,
        where link_type is 'FS', 'SS' or 'FFS', lag is -1 for 'FS' and otherwise the smallest non-outlier
        relative lag, and count is the number of history comparisons made for the pair
    """
    if use_model_name:
        column_name = 'model_name'
    else:
        column_name = 'granular_name'
        if column_name not in history_data.columns:
            # Add the fallback column to a new frame *before* any grouping: grouping a copy taken
            # earlier failed with a missing-column error, and the caller's frame got mutated.
            history_data = history_data.assign(granular_name=history_data['activity_name'])

    # Build the per-site lookup once instead of filtering the whole history for every pair and site
    site_intervals = _collect_site_intervals(history_data, column_name)
    works1, works2 = get_all_connections(graph_df, use_model_name, mapper)

    predecessors_info_dict = {w_id: [] for w_id in graph_df['activity_id']}

    for w1, w2, w1_id, w2_id in zip(works1['names'], works2['names'], works1['ids'], works2['ids']):
        if w1 == w2:
            continue

        count = 0
        # Same layout as the tuple returned by gather_links_types_statistics
        totals = [0, 0, 0, [], [], 0, [], [], 0, [], [], 0, [], []]

        for works_at_site in site_intervals.values():
            # Only sites where both works occurred contribute to the statistics
            if w1 not in works_at_site or w2 not in works_at_site:
                continue

            for (s1, f1), (s2, f2) in zip(works_at_site[w1], works_at_site[w2]):
                # Missing dates are stored as float NaN; skip incomplete observations
                if any(isinstance(day, float) for day in (s1, f1, s2, f2)):
                    continue

                count += 1
                for position, value in enumerate(gather_links_types_statistics(s1, f1, s2, f2)):
                    if isinstance(value, list):
                        totals[position].extend(value)
                    else:
                        totals[position] += value

        (fs12, fs21, ss12, _, ss12_percent_lags, ss21, _, ss21_percent_lags,
         ffs12, _, ffs12_percent_lags, ffs21, _, ffs21_percent_lags) = totals

        # The direction observed more often wins; ties go to "first work precedes second"
        if fs12 + ffs12 + ss12 >= fs21 + ffs21 + ss21:
            predecessor_id, successor_id = w1_id, w2_id
            link = _choose_link(fs12, ss12, ffs12, ss12_percent_lags, ffs12_percent_lags)
        else:
            predecessor_id, successor_id = w2_id, w1_id
            link = _choose_link(fs21, ss21, ffs21, ss21_percent_lags, ffs21_percent_lags)

        if link is not None:
            link_type, lag = link
            predecessors_info_dict[successor_id].append([predecessor_id, link_type, lag, count])

    return predecessors_info_dict


def _edge_types_from_codes(codes: list) -> list:
    """Convert stored connection codes to EdgeType, treating the '-1' placeholder as FinishStart."""
    return [EdgeType(code) if code != '-1' else EdgeType.FinishStart for code in codes]


def _align_history_info(existing_ids: list, history_ids: list, history_types: list,
                        history_lags: list, history_counts: list) -> Tuple[list, list, list]:
    """Return history types, lags and counts ordered like ``existing_ids``; unknown links get ('-1', -1, 0)."""
    info_by_id = {pred_id: (link_type, lag, count)
                  for pred_id, link_type, lag, count
                  in zip(history_ids, history_types, history_lags, history_counts)}
    aligned = [info_by_id.get(pred_id, ('-1', -1, 0)) for pred_id in existing_ids]
    return [info[0] for info in aligned], [info[1] for info in aligned], [info[2] for info in aligned]


def _pad_to(values: list, size: int, fill) -> list:
    """Return a copy of ``values`` extended with ``fill`` up to ``size`` items; longer lists are left as is."""
    return list(values) + [fill] * (size - len(values))


def set_connections_info(graph_df: pd.DataFrame,
                         history_data: pd.DataFrame,
                         use_model_name: bool = False,
                         mapper: NameMapper | None = None,
                         change_connections_info: bool = False,
                         expert_connections_info: bool = False) \
        -> pd.DataFrame:
    """
    Restore tasks' connection based on history data

    The two flags select one of four modes:

    * ``change=False, expert=True``  - nothing is changed; only 'counts' (zeros) is added;
    * ``change=True,  expert=False`` - links and their info are replaced by the history statistics;
    * ``change=False, expert=False`` - existing links are kept, their types/lags/counts are taken from history;
    * ``change=True,  expert=True``  - existing links are kept; their info is kept when 'lags' is known and
      taken from history when 'lags' equals ``['-1']``.

    In the last two modes a task whose 'predecessor_ids' equals ``['-1']`` receives the links found in history.

    :param graph_df: tasks table with 'activity_id', 'predecessor_ids', 'connection_types' and 'lags'
    :param history_data: historical records forwarded to ``get_all_seq_statistic``
    :param use_model_name: forwarded to ``get_all_seq_statistic``
    :param mapper: forwarded to ``get_all_seq_statistic``
    :param change_connections_info: whether existing connections' information should be modified based on history data
    :param expert_connections_info: whether existing connections should not be modified based on connection history data
    :return: repaired DataFrame indexed by 'activity_id', with 'connection_types' converted to EdgeType
        and a 'counts' column added
    """
    tasks_df = graph_df.copy().set_index('activity_id', drop=False)

    # | ------ no changes ------- |
    if not change_connections_info and expert_connections_info:
        # Iterate the column itself: the frame is indexed by activity id, so integer lookup is unreliable
        tasks_df['counts'] = [[0] * len(pred_ids) for pred_ids in tasks_df['predecessor_ids']]
        tasks_df['connection_types'] = tasks_df['connection_types'].apply(_edge_types_from_codes)
        return tasks_df

    # Statistics must come from the supplied history, not from a fixed debugging snapshot
    connections_dict = get_all_seq_statistic(history_data, graph_df, use_model_name, mapper)

    # After the early return above only (change, not expert) replaces the links themselves
    keep_existing_links = expert_connections_info == change_connections_info

    predecessors_ids_lst, predecessors_types_lst, predecessors_lags_lst, predecessors_counts_lst = [], [], [], []

    for task_id, pred_info_lst in connections_dict.items():
        row = tasks_df.loc[str(task_id)]

        # | ------ change links and info about them ------ |
        # by default, the links and their information change
        if len(pred_info_lst) > 0:
            pred_ids_lst, pred_types_lst, pred_lags_lst, pred_counts_lst = map(list, zip(*pred_info_lst))
        else:
            pred_ids_lst, pred_types_lst, pred_lags_lst, pred_counts_lst = ['-1'], ['-1'], [-1], [0]

        if keep_existing_links and row['predecessor_ids'] != ['-1']:
            existing_ids = list(row['predecessor_ids'])

            if expert_connections_info and row['lags'] != ['-1']:
                # | ------ change only omitted info -----|
                # the expert supplied the info, keep it; no history count is available for it
                pred_types_lst = row['connection_types']
                pred_lags_lst = row['lags']
                pred_counts_lst = [math.nan] * len(existing_ids)
            else:
                # | ------ change only info ------- |
                # if 'lags' is unknown, thus 'connection_type' is also unknown.
                # Look the info up per existing predecessor so every value stays attached to its own link.
                pred_types_lst, pred_lags_lst, pred_counts_lst = _align_history_info(
                    existing_ids, pred_ids_lst, pred_types_lst, pred_lags_lst, pred_counts_lst)

            pred_ids_lst = existing_ids

        # Guarantee one type/lag/count per link without ever looping on inconsistent input
        links_number = len(pred_ids_lst)
        predecessors_ids_lst.append(pred_ids_lst)
        predecessors_types_lst.append(_pad_to(pred_types_lst, links_number, 'FS'))
        predecessors_lags_lst.append(_pad_to(pred_lags_lst, links_number, -1))
        predecessors_counts_lst.append(_pad_to(pred_counts_lst, links_number, 0))

    # Rows of connections_dict follow the order of graph_df, hence the order of tasks_df
    tasks_df['predecessor_ids'] = predecessors_ids_lst
    tasks_df['connection_types'] = predecessors_types_lst
    tasks_df['lags'] = predecessors_lags_lst
    tasks_df['counts'] = predecessors_counts_lst

    tasks_df['connection_types'] = tasks_df['connection_types'].apply(_edge_types_from_codes)

    return tasks_df
