# -*- coding: utf-8 -*-
# *******************************************************
#   ____                     _               _
#  / ___|___  _ __ ___   ___| |_   _ __ ___ | |
# | |   / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | |  __/ |_ _| | | | | | |
#  \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
#  Sign up for free at http://www.comet.ml
#  Copyright (C) 2021-2025 Comet ML INC
#  This file can not be copied and/or distributed without the express
#  permission of Comet ML Inc.
# *******************************************************

import json
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd

from .client import Client
from .config import get_config
from .utils import custom_metrics_to_dataframe

# Type alias for query filters: a list of predicate dictionaries, each mapping
# "key" and "query" to strings (see convert_config_filters_to_predicates).
FilterType = List[Dict[str, str]]


def convert_config_filters_to_predicates(filters: List[str]) -> FilterType:
    """
    Convert a list of filter strings to predicate format.

    Each filter string is used both as the predicate's display key and as its
    query, e.g. ``["a > 1"]`` -> ``[{"key": "a > 1", "query": "a > 1"}]``.
    Input order is preserved.

    Args:
        filters: List of filter strings

    Returns:
        List[Dict[str, str]]: List of predicate dictionaries with 'key' and 'query' fields
    """
    # Named `expression` rather than `filter` to avoid shadowing the builtin.
    return [{"key": expression, "query": expression} for expression in filters]


class Model:
    """
    A model instance for interacting with Comet MPM model-specific operations.

    This class provides high-level methods for querying model predictions, metrics,
    and feature analysis. It can be configured with panel options for default
    parameter values.

    Args:
        client: The Comet MPM client instance
        model_id: The ID of the model to work with
        panel_options: Optional dictionary containing default panel configuration
    """

    def __init__(
        self,
        client: Client,
        model_id: str,
        panel_options: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize a Model instance.

        Args:
            client: The Comet MPM client instance for making API calls
            model_id: The ID of the model to work with
            panel_options: Optional dictionary containing default panel configuration
                that will be used when parameters are not explicitly provided
        """
        self._client = client
        self.model_id = model_id
        self.panel_options = panel_options

    def _resolve_query_params(
        self,
        start_date: Optional[str],
        end_date: Optional[str],
        interval_type: Optional[str],
        filters: Optional[FilterType],
        model_version: Optional[str],
    ) -> Tuple[str, str, str, FilterType, Optional[str]]:
        """
        Fill unset query parameters from panel options, then validate them.

        When this model has panel options, each argument that is None is
        replaced by the matching panel option ("startDate", "endDate",
        "intervalType", "filters", "modelVersion"). Panel filters are
        converted with convert_config_filters_to_predicates.

        Returns:
            Tuple of (start_date, end_date, interval_type, filters, model_version).
            model_version may still be None.

        Raises:
            KeyError: If a needed key is missing from panel_options.
            ValueError: If start_date, end_date, interval_type or filters is
                still None after applying panel defaults.
        """
        options = self.panel_options
        if options is not None:
            if start_date is None:
                start_date = options["startDate"]
            if end_date is None:
                end_date = options["endDate"]
            if interval_type is None:
                interval_type = options["intervalType"]
            if filters is None:
                filters = convert_config_filters_to_predicates(options["filters"])
            if model_version is None:
                model_version = options["modelVersion"]

        # model_version is optional; the other four are required by the backend calls.
        if (
            start_date is None
            or end_date is None
            or interval_type is None
            or filters is None
        ):
            raise ValueError(
                "All parameters (start_date, end_date, interval_type, filters) must be provided"
            )
        return start_date, end_date, interval_type, filters, model_version

    @staticmethod
    def _records_to_frame(
        records: List[Dict[str, Any]],
        index_name: str,
        columns: Tuple[str, ...] = (),
    ) -> pd.DataFrame:
        """
        Build a DataFrame from row dictionaries, indexed by ``index_name``.

        An empty ``records`` list yields an empty DataFrame whose index is named
        ``index_name`` and whose columns are ``columns``. Without this,
        ``set_index`` would raise KeyError on a frame that has no columns.
        """
        if not records:
            return pd.DataFrame(
                columns=list(columns), index=pd.Index([], name=index_name)
            )
        return pd.DataFrame(records).set_index(index_name)

    def get_details(self) -> Dict[str, Any]:
        """
        Get the details of a model.

        Returns:
            Dict[str, Any]: Model details including metadata, configuration, and status
        """
        return self._client.get_model_details(self.model_id)

    def get_nb_predictions(
        self,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        interval_type: Optional[str] = None,
        filters: Optional[FilterType] = None,
        model_version: Optional[str] = None,
    ) -> pd.DataFrame:
        """
        Get the number of predictions for a model within a specified time range.

        Implemented as the custom SQL metric ``SELECT count(*) FROM model``, so
        parameter defaulting and validation follow get_custom_metric.

        Args:
            start_date: Start date for filtering predictions (ISO format)
            end_date: End date for filtering predictions (ISO format)
            interval_type: Type of interval for aggregation ("DAILY" or "HOURLY")
            filters: List of filters to apply to predictions
            model_version: Specific model version to query

        Returns:
            pd.DataFrame: DataFrame containing the number of predictions matching the criteria
        """
        # Use SQL for now:
        return self.get_custom_metric(
            "SELECT count(*) FROM model",
            start_date=start_date,
            end_date=end_date,
            interval_type=interval_type,
            filters=filters,
            model_version=model_version,
        )

    def get_custom_metric(
        self,
        sql: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        interval_type: Optional[str] = None,
        filters: Optional[FilterType] = None,
        model_version: Optional[str] = None,
    ) -> pd.DataFrame:
        """
        Execute a custom SQL query to retrieve model metrics.

        Args:
            sql: SQL query string to execute
            start_date: Start date for filtering results (ISO format)
            end_date: End date for filtering results (ISO format)
            interval_type: Type of interval for aggregation ("DAILY" or "HOURLY")
            filters: List of filters to apply to results
            model_version: Specific model version to query

        Returns:
            DataFrame: Results of the SQL query

        Raises:
            ValueError: If required parameters are neither given nor available
                from panel options.
        """
        start_date, end_date, interval_type, filters, model_version = (
            self._resolve_query_params(
                start_date, end_date, interval_type, filters, model_version
            )
        )
        data = self._client.get_custom_metrics(
            model_id=self.model_id,
            sql=sql,
            start_date=start_date,
            end_date=end_date,
            interval_type=interval_type,
            filters=filters,
            model_version=model_version,
        )
        return custom_metrics_to_dataframe(data, sql)

    def get_feature_drift(
        self,
        feature_name: str,
        algorithm: str = "EMD",
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        interval_type: Optional[str] = None,
        filters: Optional[FilterType] = None,
        model_version: Optional[str] = None,
    ) -> pd.DataFrame:
        """
        Calculate drift metrics for a specific feature.

        Args:
            feature_name: Name of the feature to calculate drift for
            algorithm: Drift calculation algorithm ("EMD", "PSI", or "KL")
            start_date: Start date for drift calculation (ISO format)
            end_date: End date for drift calculation (ISO format)
            interval_type: Type of interval for aggregation ("DAILY" or "HOURLY")
            filters: List of filters to apply to drift calculation
            model_version: Specific model version to query

        Returns:
            DataFrame: Drift metrics indexed by "timestamp", with one column per
            filter predicate key.

        Raises:
            ValueError: If required parameters are neither given nor available
                from panel options.
        """
        start_date, end_date, interval_type, filters, model_version = (
            self._resolve_query_params(
                start_date, end_date, interval_type, filters, model_version
            )
        )
        data = self._client.get_feature_drift(
            feature_name=feature_name,
            algorithm=algorithm,
            model_id=self.model_id,
            start_date=start_date,
            end_date=end_date,
            interval_type=interval_type,
            filters=filters,
            model_version=model_version,
        )
        # Pivot: one series per predicate -> one row per x (timestamp).
        cells: Dict[Any, Dict[str, Any]] = defaultdict(dict)
        for series in data["data"]:
            predicate = series["predicateKey"]
            for point in series["data"]:
                cells[point["x"]][predicate] = point["y"]
        records = [
            {"timestamp": timestamp, **values} for timestamp, values in cells.items()
        ]
        return self._records_to_frame(records, "timestamp")

    def get_feature_category_distribution(
        self,
        feature_name: str,
        normalize: bool = False,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        interval_type: Optional[str] = None,
        filters: Optional[FilterType] = None,
        model_version: Optional[str] = None,
    ) -> pd.DataFrame:
        """
        Get the distribution of categories for a categorical feature.

        Args:
            feature_name: Name of the categorical feature
            normalize: If True, returns percentages instead of counts
            start_date: Start date for distribution calculation (ISO format)
            end_date: End date for distribution calculation (ISO format)
            interval_type: Type of interval for aggregation ("DAILY" or "HOURLY")
            filters: List of filters to apply to distribution calculation
            model_version: Specific model version to query

        Returns:
            DataFrame: Distribution indexed by "timestamp", with a "value"
            (category) column and one column per filter predicate key.

        Raises:
            ValueError: If required parameters are neither given nor available
                from panel options.
        """
        start_date, end_date, interval_type, filters, model_version = (
            self._resolve_query_params(
                start_date, end_date, interval_type, filters, model_version
            )
        )
        data = self._client.get_feature_category_distribution(
            feature_name=feature_name,
            model_id=self.model_id,
            normalize=normalize,
            start_date=start_date,
            end_date=end_date,
            interval_type=interval_type,
            filters=filters,
            model_version=model_version,
        )
        # Group y-values by (category value, timestamp). The pair is JSON-encoded
        # into a string key, which also tolerates unhashable category values.
        cells: Dict[str, Dict[str, Any]] = defaultdict(dict)
        for series in data["data"]:
            predicate = series["predicateKey"]
            for chart_data in series["chartData"]:
                for point in chart_data["points"]:
                    key = json.dumps((chart_data["value"], point["x"]))
                    cells[key][predicate] = point["y"]
        records = []
        for key, values in cells.items():
            value, timestamp = json.loads(key)
            records.append({"timestamp": timestamp, "value": value, **values})
        return self._records_to_frame(records, "timestamp", ("value",))

    def get_feature_density(
        self,
        feature_name: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        filters: Optional[FilterType] = None,
        model_version: Optional[str] = None,
        interval_type: Optional[str] = None,
    ) -> pd.DataFrame:
        """
        Get the probability density function (PDF) of a numeric feature.

        Args:
            feature_name: Name of the numeric feature
            start_date: Start date for density calculation (ISO format)
            end_date: End date for density calculation (ISO format)
            filters: List of filters to apply to density calculation
            model_version: Specific model version to query
            interval_type: Type of interval for aggregation ("DAILY" or "HOURLY")

        Returns:
            DataFrame: PDF points from the server's "wholeTimeRangeDistribution",
            indexed by "x" (sorted ascending), with one column per filter
            predicate key.

        Raises:
            ValueError: If required parameters are neither given nor available
                from panel options.
        """
        start_date, end_date, interval_type, filters, model_version = (
            self._resolve_query_params(
                start_date, end_date, interval_type, filters, model_version
            )
        )
        data = self._client.get_feature_density(
            model_id=self.model_id,
            feature_name=feature_name,
            start_date=start_date,
            end_date=end_date,
            filters=filters,
            model_version=model_version,
            interval_type=interval_type,
        )
        cells: Dict[Any, Dict[str, Any]] = defaultdict(dict)
        for series in data["data"]:
            predicate = series["predicateKey"]
            graph = series["pdfDistributionGraphs"]["wholeTimeRangeDistribution"]
            for point in graph["graphPoints"]:
                cells[point["x"]][predicate] = point["y"]
        records = [{"x": x, **values} for x, values in cells.items()]
        return self._records_to_frame(records, "x").sort_index()

    def get_feature_percentiles(
        self,
        feature_name: str,
        percentiles: Optional[List[float]] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        filters: Optional[FilterType] = None,
        model_version: Optional[str] = None,
        interval_type: Optional[str] = None,
    ) -> pd.DataFrame:
        """
        Get the specified percentiles for a numeric feature.

        Args:
            feature_name: Name of the numeric feature
            percentiles: List of percentiles to calculate (default: [0, 0.1, 0.25, 0.5, 0.75, 0.9, 1])
                Only these specific percentile values are supported
            start_date: Start date for percentile calculation (ISO format)
            end_date: End date for percentile calculation (ISO format)
            filters: List of filters to apply to percentile calculation
            model_version: Specific model version to query
            interval_type: Type of interval for aggregation ("DAILY" or "HOURLY")

        Returns:
            DataFrame: Percentile values indexed by "timestamp", with a
            "percentile" column and one column per filter predicate key.

        Raises:
            ValueError: If required parameters are neither given nor available
                from panel options.
        """
        # Built per call so callers never share (and mutate) one default list.
        if percentiles is None:
            percentiles = [0, 0.1, 0.25, 0.5, 0.75, 0.9, 1]

        start_date, end_date, interval_type, filters, model_version = (
            self._resolve_query_params(
                start_date, end_date, interval_type, filters, model_version
            )
        )
        data = self._client.get_feature_percentiles(
            model_id=self.model_id,
            feature_name=feature_name,
            percentiles=percentiles,
            start_date=start_date,
            end_date=end_date,
            interval_type=interval_type,
            filters=filters,
            model_version=model_version,
        )
        # Group y-values by (timestamp, percentile); one column per predicate.
        cells: Dict[Tuple[Any, Any], Dict[str, Any]] = defaultdict(dict)
        for series in data["data"]:
            predicate = series["predicateKey"]
            for chart_data in series["chartData"]:
                percentile = chart_data["value"]
                for point in chart_data["points"]:
                    cells[(point["x"], percentile)][predicate] = point["y"]
        records = [
            {"timestamp": timestamp, "percentile": percentile, **values}
            for (timestamp, percentile), values in cells.items()
        ]
        return self._records_to_frame(records, "timestamp", ("percentile",))

    def get_numerical_features(self) -> List[str]:
        """
        Get the list of numerical features available for this model.

        Returns:
            List[str]: List of numerical feature names
        """
        return self._client.get_numerical_features(self.model_id)

    def get_categorical_features(self) -> List[str]:
        """
        Get the list of categorical features available for this model.

        Returns:
            List[str]: List of categorical feature names
        """
        return self._client.get_categorical_features(self.model_id)


class API:
    """
    Main entry point for interacting with the Comet MPM API.

    Provides high-level methods for working with models and workspaces.

    Args:
        api_key: The Comet API key for authentication
    """

    def __init__(self, api_key: Optional[str] = None) -> None:
        """
        Initialize the Comet MPM API client.

        The API key is taken from, in order: the ``api_key`` argument, the
        ``comet.api_key`` configuration value, then the ``COMET_API_KEY``
        environment variable.

        Args:
            api_key: The Comet API key for authentication

        Raises:
            ValueError: If no API key is found in any of those sources.
        """
        if api_key is None:
            api_key = get_config("comet.api_key")

        if api_key is None:
            api_key = os.environ.get("COMET_API_KEY")

        if api_key is None:
            # ValueError subclasses Exception, so existing `except Exception`
            # handlers keep working.
            raise ValueError(
                "COMET_API_KEY is not defined in environment, and api_key is not given"
            )

        self._client = Client(api_key)

    def get_model(
        self, workspace_name: Optional[str] = None, model_name: Optional[str] = None
    ) -> Model:
        """
        Get a model.

        Currently only the panel model is supported: call with no arguments to
        get the model configured by COMET_PANEL_OPTIONS (see get_panel_model).

        Args:
            workspace_name: Workspace name (lookup by name is not implemented yet)
            model_name: Model name (lookup by name is not implemented yet)

        Returns:
            Model: The panel model when both arguments are None.

        Raises:
            NotImplementedError: If workspace_name or model_name is given.
        """
        if workspace_name is None and model_name is None:
            return self.get_panel_model()
        raise NotImplementedError("Looking up models by name is not implemented yet")

    def get_panel_model(self) -> Model:
        """
        Get a Model instance configured with panel options from configuration.

        This method creates a Model instance using the panel configuration
        stored in the COMET_PANEL_OPTIONS configuration key.

        Returns:
            Model: A Model instance configured with panel options

        Raises:
            ValueError: If COMET_PANEL_OPTIONS is not configured
            KeyError: If required panel options (e.g. "modelId") are missing
        """
        panel_options = get_config("COMET_PANEL_OPTIONS")
        # get_config returns None for unset keys (as relied on in __init__);
        # fail clearly instead of with "'NoneType' object is not subscriptable".
        if panel_options is None:
            raise ValueError(
                "COMET_PANEL_OPTIONS is not configured; cannot create a panel model"
            )
        return Model(
            client=self._client,
            model_id=panel_options["modelId"],
            panel_options=panel_options,
        )
