"""
Modeler agent for financial forecasting and scenario analysis.
"""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
from pydantic import BaseModel, Field

from emberquant.agents.base import BaseAgent
from emberquant.core.emberframe import EmberFrame


class ForecastModel(BaseModel):
    """Result of a forecasting run.

    Carries the model identifier, the horizon length, the projected values
    for each forecasted series, optional confidence intervals, and summary
    metrics describing the run.
    """

    model_type: str = Field(..., description="Type of forecasting model used")
    horizon_days: int = Field(..., description="Number of days forecasted")
    forecast_data: Dict[str, List[float]] = Field(
        ...,
        description="Forecasted values",
    )
    confidence_intervals: Optional[Dict[str, Any]] = Field(
        None,
        description="Confidence intervals for forecasts",
    )
    metrics: Dict[str, Any] = Field(..., description="Model performance metrics")


class ScenarioAnalysis(BaseModel):
    """Three-way scenario projection.

    Bundles the base, best, and worst case projections together with the
    assumptions they were derived from.
    """

    base_case: Dict[str, Any] = Field(..., description="Base case projections")
    best_case: Dict[str, Any] = Field(..., description="Best case projections")
    worst_case: Dict[str, Any] = Field(
        ...,
        description="Worst case projections",
    )
    assumptions: Dict[str, Any] = Field(..., description="Scenario assumptions")


class CostStructure(BaseModel):
    """Split of costs into fixed and variable portions.

    Holds the estimated totals for each portion, the item names assigned
    to each, and a label describing the period the analysis covered.
    """

    fixed_costs: float = Field(..., description="Estimated fixed costs")
    variable_costs: float = Field(..., description="Estimated variable costs")
    fixed_cost_items: List[str] = Field(
        ...,
        description="Items classified as fixed costs",
    )
    variable_cost_items: List[str] = Field(
        ...,
        description="Items classified as variable",
    )
    analysis_period: str = Field(..., description="Period analyzed")


@dataclass
class ModelingConfig:
    """Configuration for Modeler agent.

    Attributes:
        forecast_horizon_days: Number of days to project forward; must be
            at least 1.
        confidence_level: Confidence level; must lie strictly between 0
            and 1.
        seasonality_detection: Toggle for seasonality detection.
        trend_detection: Toggle for trend detection.
        cost_classification_threshold: Coefficient-of-variation cut-off
            separating fixed (below) from variable (at or above) costs;
            must be non-negative.
        verbose: Toggle for verbose output.

    Raises:
        ValueError: If any numeric setting is outside its valid range.
    """

    forecast_horizon_days: int = 90
    confidence_level: float = 0.95
    seasonality_detection: bool = True
    trend_detection: bool = True
    cost_classification_threshold: float = 0.3  # CV threshold for fixed vs variable
    verbose: bool = False

    def __post_init__(self) -> None:
        """Reject settings that would make downstream modeling meaningless."""
        # A horizon below one day yields an empty forecast whose end date
        # precedes its start date.
        if self.forecast_horizon_days < 1:
            raise ValueError(
                "forecast_horizon_days must be >= 1, "
                f"got {self.forecast_horizon_days}"
            )
        if not 0.0 < self.confidence_level < 1.0:
            raise ValueError(
                f"confidence_level must be in (0, 1), got {self.confidence_level}"
            )
        # A CV is never negative, so a negative threshold would classify
        # every cost as variable.
        if self.cost_classification_threshold < 0:
            raise ValueError(
                "cost_classification_threshold must be >= 0, "
                f"got {self.cost_classification_threshold}"
            )


class ModelerAgent(BaseAgent):
    """
    Agent for financial modeling, forecasting, and scenario analysis.

    Capabilities:
    - Time series forecasting
    - Fixed vs variable cost classification
    - Scenario analysis (best/worst/base case)
    - Revenue and expense projections

    Column roles are detected by case-insensitive substring matching on
    column names (for example, the first column whose name contains "date"
    is used as the date column). The analyses only read the EmberFrame's
    data; all conversions happen on copies, so the caller's frame is left
    unchanged.
    """

    # Task names accepted by ``execute`` via ``inputs["task"]``.
    _TASKS = ("all", "forecast", "cost_structure", "scenario")

    # Multiplier that turns a per-row average into a "monthly" projection.
    _DAYS_PER_MONTH = 30

    def __init__(self, config: Optional[ModelingConfig] = None) -> None:
        """
        Initialize Modeler agent.

        Args:
            config: Optional configuration for modeling parameters
        """
        super().__init__()
        self.config = config or ModelingConfig()

    @property
    def name(self) -> str:
        """Return the agent name."""
        return "Modeler"

    def execute(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute modeling and forecasting tasks.

        Expected inputs:
            - emberframe: EmberFrame to model
            - task: Modeling task ("forecast", "cost_structure", "scenario",
              or "all"); defaults to "all"

        Returns:
            Dictionary with "status", "modeling_results" (result name ->
            model object) and "emberframe" (the input frame, unmodified).

        Raises:
            ValueError: If no EmberFrame is supplied, the value is not an
                EmberFrame, the task is unknown, or a selected analysis
                cannot run on the data.
        """
        self._log("Starting financial modeling...")

        emberframe = inputs.get("emberframe")
        if emberframe is None:
            raise ValueError("No EmberFrame provided to Modeler agent")

        if not isinstance(emberframe, EmberFrame):
            raise ValueError(f"Expected EmberFrame, got {type(emberframe)}")

        task = inputs.get("task", "all")
        if task not in self._TASKS:
            # Previously an unknown task returned an empty "success" result,
            # which hid typos in the caller; fail loudly instead.
            raise ValueError(
                f"Unknown modeling task {task!r}; expected one of {self._TASKS}"
            )

        results: Dict[str, Any] = {}

        if task in ("all", "forecast"):
            results["forecast"] = self._create_forecast(emberframe)

        if task in ("all", "cost_structure"):
            results["cost_structure"] = self._analyze_cost_structure(emberframe)

        if task in ("all", "scenario"):
            results["scenario_analysis"] = self._create_scenario_analysis(emberframe)

        self._log("Modeling complete")

        return {
            "status": "success",
            "modeling_results": results,
            "emberframe": emberframe,
        }

    @staticmethod
    def _find_column(df: pd.DataFrame, keywords: List[str]) -> Optional[Any]:
        """Return the first column whose lowercased name contains a keyword, else None."""
        for col in df.columns:
            # str() guards against non-string labels (e.g. integer columns).
            label = str(col).lower()
            if any(keyword in label for keyword in keywords):
                return col
        return None

    def _create_forecast(self, emberframe: EmberFrame) -> ForecastModel:
        """
        Create financial forecast.

        Each amount-like column ("amount", "debit" or "credit" in its name)
        is summed per calendar day. It is then projected forward from its
        last daily total, using the average day-over-day change over a
        trailing window (3-30 observations). Projected values are floored
        at zero.

        Args:
            emberframe: Data to forecast

        Returns:
            Forecast model with predictions

        Raises:
            ValueError: If there is no date column, no amount column, or no
                row with a parseable date.
        """
        # Work on a copy so the caller's frame is not mutated by the
        # conversions below.
        df = emberframe.data.copy()

        date_col = self._find_column(df, ["date"])
        if date_col is None:
            raise ValueError("No date column found for forecasting")

        # Exclude the date column itself in case its name also matches an
        # amount keyword (e.g. "credit_date").
        amount_cols = [
            col
            for col in df.columns
            if col != date_col
            and any(x in str(col).lower() for x in ("amount", "debit", "credit"))
        ]
        if not amount_cols:
            raise ValueError("No amount columns found for forecasting")

        # Normalize to midnight so intraday timestamps aggregate into a single
        # row per calendar day.
        df[date_col] = pd.to_datetime(df[date_col], errors="coerce").dt.normalize()
        for col in amount_cols:
            # Non-numeric entries become NaN, which sum() skips.
            df[col] = pd.to_numeric(df[col], errors="coerce")

        df = df.dropna(subset=[date_col])
        if df.empty:
            raise ValueError("No valid dates found for forecasting")

        # groupby sorts by date, so the last row is the most recent day.
        # NOTE(review): days with no rows are not filled in, so the "trend" is
        # per observed day rather than per calendar day.
        df_daily = df.groupby(date_col)[amount_cols].sum().reset_index()

        window = max(3, min(30, len(df_daily) // 3))
        horizon = self.config.forecast_horizon_days
        last_date = df_daily[date_col].max()

        forecast_data: Dict[str, List[float]] = {}
        for col in amount_cols:
            daily = df_daily[col]
            trend = daily.diff().rolling(window=window).mean()
            last_value = float(daily.iloc[-1])
            # With fewer observations than the window, the trend is NaN;
            # fall back to a flat projection.
            last_trend = float(trend.iloc[-1]) if pd.notna(trend.iloc[-1]) else 0.0
            forecast_data[str(col)] = [
                max(0.0, last_value + last_trend * step)  # No negative forecasts
                for step in range(1, horizon + 1)
            ]

        metrics = {
            "data_points": len(df_daily),
            "forecast_start": (last_date + timedelta(days=1)).strftime("%Y-%m-%d"),
            "forecast_end": (last_date + timedelta(days=horizon)).strftime("%Y-%m-%d"),
        }

        return ForecastModel(
            model_type="moving_average",
            horizon_days=horizon,
            forecast_data=forecast_data,
            metrics=metrics,
        )

    def _analyze_cost_structure(self, emberframe: EmberFrame) -> CostStructure:
        """
        Analyze and classify costs as fixed vs variable.

        Amounts are grouped by category. Each category's coefficient of
        variation (std / |mean|) is compared with the configured threshold:
        below it the category is fixed, otherwise variable. Categories whose
        CV is undefined (for example a single observation, or zero mean with
        zero spread) are left out of both lists and both totals.

        Args:
            emberframe: Data to analyze

        Returns:
            Cost structure analysis; an empty structure with period "unknown"
            when no category or amount column is found.
        """
        df = emberframe.data

        # NOTE: unlike the other column lookups, the LAST matching category
        # column wins here (preserved from the original behavior).
        category_col = None
        for col in df.columns:
            if any(x in str(col).lower() for x in ("category", "account")):
                category_col = col
        amount_col = self._find_column(df, ["amount", "debit"])

        if category_col is None or amount_col is None:
            return CostStructure(
                fixed_costs=0.0,
                variable_costs=0.0,
                fixed_cost_items=[],
                variable_cost_items=[],
                analysis_period="unknown",
            )

        amounts = pd.to_numeric(df[amount_col], errors="coerce")
        stats = amounts.groupby(df[category_col]).agg(["mean", "std", "count", "sum"])
        # Use |mean| so that categories recorded as negative amounts are not
        # forced into "fixed" by a negative CV.
        stats["cv"] = stats["std"] / stats["mean"].abs()
        stats = stats[stats["cv"].notna()]

        # A zero mean with non-zero spread gives cv = inf, i.e. variable.
        is_fixed = stats["cv"] < self.config.cost_classification_threshold
        fixed_items = [str(category) for category in stats.index[is_fixed]]
        variable_items = [str(category) for category in stats.index[~is_fixed]]
        fixed_total = float(stats.loc[is_fixed, "sum"].sum())
        variable_total = float(stats.loc[~is_fixed, "sum"].sum())

        period = "unknown"
        date_col = self._find_column(df, ["date"])
        if date_col is not None:
            dates = pd.to_datetime(df[date_col], errors="coerce")
            min_date, max_date = dates.min(), dates.max()
            # If every date fails to parse, min/max are NaT and cannot be
            # formatted.
            if pd.notna(min_date) and pd.notna(max_date):
                period = (
                    f"{min_date.strftime('%Y-%m-%d')} to {max_date.strftime('%Y-%m-%d')}"
                )

        return CostStructure(
            fixed_costs=fixed_total,
            variable_costs=variable_total,
            fixed_cost_items=fixed_items,
            variable_cost_items=variable_items,
            analysis_period=period,
        )

    @staticmethod
    def _default_scenarios() -> ScenarioAnalysis:
        """Return placeholder scenarios for when no usable amounts exist."""
        return ScenarioAnalysis(
            base_case={"revenue": 0.0, "growth_rate": 0.0},
            best_case={"revenue": 0.0, "growth_rate": 0.1},
            worst_case={"revenue": 0.0, "growth_rate": -0.1},
            assumptions={"model": "default"},
        )

    def _create_scenario_analysis(self, emberframe: EmberFrame) -> ScenarioAnalysis:
        """
        Create scenario analysis (base/best/worst case).

        Monthly revenue uses the per-row mean of the first amount-like column
        ("amount", "revenue" or "credit" in its name) times 30. The best case
        adds one standard deviation and the worst case subtracts one, floored
        at zero. NOTE(review): multiplying by 30 presumably assumes roughly
        one row per day; confirm with the data producers.

        Args:
            emberframe: Data to analyze

        Returns:
            Scenario analysis with three cases; default placeholder scenarios
            when no amount column or no numeric amounts are found.
        """
        df = emberframe.data

        amount_col = self._find_column(df, ["amount", "revenue", "credit"])
        if amount_col is None:
            return self._default_scenarios()

        amounts = pd.to_numeric(df[amount_col], errors="coerce").dropna()
        if amounts.empty:
            # No numeric data: the mean/std would be NaN and so would
            # every projection.
            return self._default_scenarios()

        mean = float(amounts.mean())
        std = float(amounts.std())
        if pd.isna(std):
            # The sample std of a single value is undefined; treat it as no spread.
            std = 0.0

        base_revenue = mean * self._DAYS_PER_MONTH
        best_revenue = (mean + std) * self._DAYS_PER_MONTH
        worst_revenue = max(0.0, (mean - std) * self._DAYS_PER_MONTH)

        has_base = base_revenue > 0
        return ScenarioAnalysis(
            base_case={
                "monthly_revenue": base_revenue,
                "growth_rate": 0.0,
                "confidence": "medium",
            },
            best_case={
                "monthly_revenue": best_revenue,
                "growth_rate": (
                    (best_revenue - base_revenue) / base_revenue if has_base else 0.0
                ),
                "confidence": "optimistic",
            },
            worst_case={
                "monthly_revenue": worst_revenue,
                "growth_rate": (
                    (worst_revenue - base_revenue) / base_revenue if has_base else -0.5
                ),
                "confidence": "pessimistic",
            },
            assumptions={
                "model": "mean_variance",
                "historical_mean": mean,
                "historical_std": std,
                "data_points": len(amounts),
            },
        )
