from typing import Optional, Literal
import scipy.optimize as opt
from sklearn.metrics import mean_squared_error as mse
import numpy as np
import pandas as pd
import numpy.typing as npt
from .model import EpidemicModel


class ModelFitterError(Exception):
    """Raised by ModelFitter when its model, bounds or fit arguments are invalid."""

class ModelFitter:
    # method = 'Nelder-Mead'

    def __init__(self, model: EpidemicModel):
        if not isinstance(model, EpidemicModel):
            raise ModelFitterError('Model must be EpidemicModel')

        self._model = model
        self._population_size = model.population_size
        self._changeable_stages: dict[str, tuple[float, float]] = {}
        self._changeable_factors: dict[str, tuple[float, float]] = {}

        self._real_df: Optional[pd.DataFrame] = None
        self._real_flow_df: Optional[pd.DataFrame] = None
        self._interval: int = 1
        self._mse_history: list[float] = []
        self._duration: int = 0

    @staticmethod
    def _check_bounds_dict(bounds_dict: dict[str, tuple[float, float]], names: list[str]):
        for name, bounds in bounds_dict.items():
            if name not in names:
                raise ModelFitterError(f'Name {name} not found in the model')
            match bounds:
                case int(left) | float(left), int(right) | float(right):
                    pass
                case _:
                    raise ModelFitterError(f'Bounds must be a tuple of floats or ints, got {bounds}')

    def set_changeable_stages(self, changeable_stages: dict[str, tuple[float, float]] | Literal['all', 'none']):
        if changeable_stages == 'all':
            self._changeable_stages = {name: (0, self._population_size) for name in self._model.stage_names}
            return
        elif changeable_stages == 'none':
            self._changeable_stages = {}
            return
        elif isinstance(changeable_stages, dict):
            self._check_bounds_dict(changeable_stages, self._model.stage_names)
            self._changeable_stages = {k: changeable_stages[k] for k in sorted(changeable_stages.keys())}
        else:
            raise ModelFitterError(f'Changeable stages must be a dict or "all" or "none", got {changeable_stages}')

    def set_changeable_factors(self, changeable_factors: dict[str, tuple[float, float]] | Literal['all', 'none']):
        if changeable_factors == 'all':
            self._changeable_factors = {name: (0.0, 1.0) for name in self._model.factor_names}
            return
        elif changeable_factors == 'none':
            self._changeable_factors = {}
            return
        elif isinstance(changeable_factors, dict):
            self._check_bounds_dict(changeable_factors, self._model.factor_names)
            self._changeable_factors = {k: changeable_factors[k] for k in sorted(changeable_factors.keys())}
        else:
            raise ModelFitterError(f'Changeable factors must be a dict or "all" or "none", got {changeable_factors}')

    def fit(self, real_df: pd.DataFrame = None, real_flows_df: pd.DataFrame = None, interval: int = 1):
        if real_df is None and real_flows_df is None:
            raise ModelFitterError('Either real_df or real_flows_df must be provided')
        if not self._changeable_stages and not self._changeable_factors:
            raise ModelFitterError('No stages or factors are changeable')
        if not isinstance(interval, int) or interval < 1:
            raise ModelFitterError('Interval must be int > 1')

        duration1 = 0
        duration2 = 0

        if real_flows_df is not None:
            names = self._model.flow_names
            not_existing = set(real_flows_df.columns) - set(names)
            if not_existing:
                raise ModelFitterError(f'Flows {not_existing} not found in the model')
            self._real_flow_df = real_flows_df.copy()
            duration1 = len(self._real_flow_df) + 1

        if real_df is not None:
            names = self._model.stage_names
            not_existing = set(real_df.columns) - set(names)
            if not_existing:
                raise ModelFitterError(f'Stages {not_existing} not found in the model')
            self._real_df = real_df.copy()
            duration2 = len(self._real_df)

        self._duration = max(duration1, duration2)

        self._interval = interval

        param_start = []
        bounds = []

        for st in self._changeable_stages:
            param_start.append(self._model.stages_dict[st])
            bounds.append(self._changeable_stages[st])

        for fa in self._changeable_factors:
            param_start.append(self._model.factors_dict[fa])
            bounds.append(self._changeable_factors[fa])

        param_start = np.array(param_start, dtype=np.float64)
        self._mse_history = []
        result = opt.minimize(self._get_mse, param_start, method='Nelder-Mead', bounds=bounds)
        return result

    def _get_mse(self, parameters: npt.NDArray[np.float64]):
        start_stages = {stage_name: parameters[i] for i, stage_name in enumerate(self._changeable_stages)}
        factors = {factor_name: parameters[i] for i, factor_name in
                   enumerate(self._changeable_factors, start=len(self._changeable_stages))}

        self._model.set_start_stages(**start_stages)
        self._model.set_factors(**factors)

        self._model.start(self._duration, delta=self._interval, full_save=True)
        mse_full = 0
        if self._real_df is not None:
            real = self._real_df
            model = self._model.result_df.head(len(self._real_df))[self._real_df.columns]
            mse_full += mse(real, model)
        if self._real_flow_df is not None:
            real = self._real_flow_df
            model = self._model.flows_df.head(len(self._real_flow_df))[self._real_flow_df.columns]
            mse_full += mse(real, model)

        self._mse_history.append(mse_full)
        # print(f'MSE in fit: {mse_full}')
        return mse_full