"""
Auditor Agent: Financial analysis and anomaly detection.

The Auditor is responsible for:
- Benford's law analysis
- Variance detection
- Policy compliance checking
- Anomaly flagging
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
from pydantic import BaseModel

from emberquant.agents.base import AgentConfig, BaseAgent
from emberquant.core.emberframe import ColumnType, EmberFrame


class AnomalyFinding(BaseModel):
    """
    Represents a single anomaly or finding raised by an audit check.

    Findings are collected into ``AuditReport.findings``.
    """

    # Free-form string; intended values are "low", "medium", "high", "critical"
    # (the model itself does not enforce this set).
    severity: str  # "low", "medium", "high", "critical"
    # Short machine-readable label for the check that produced the finding
    # (this module uses "benford_violation" and "variance_outliers").
    category: str
    # Human-readable explanation of the finding.
    description: str
    # Rows implicated by the finding, if any. NOTE(review): the variance check
    # in this module fills this with DataFrame index labels, not positional
    # offsets — confirm consumers expect labels.
    affected_rows: Optional[List[int]] = None
    # Raw result dictionary from the check that produced the finding, if any.
    evidence: Optional[Dict[str, Any]] = None


class AuditReport(BaseModel):
    """Complete audit report produced by ``AuditorAgent.execute``."""

    # Number of rows in the audited frame.
    total_rows: int
    # Sum of the analyzed amount column; 0.0 when no column was analyzed.
    total_amount: float
    # All anomalies raised by the enabled checks (may be empty).
    findings: List[AnomalyFinding]
    # Result dict of the Benford's Law check, or None when it did not run.
    benford_analysis: Optional[Dict[str, Any]] = None
    # Result dict of the standard-deviation outlier check, or None when it
    # did not run.
    variance_analysis: Optional[Dict[str, Any]] = None
    # Human-readable summary of the audit and its findings.
    summary: str


class AuditorConfig(AgentConfig):
    """Configuration for the Auditor agent."""

    # Whether to run the Benford's Law first-digit check.
    run_benford: bool = True
    # Whether to run the standard-deviation outlier check.
    run_variance: bool = True
    # Chi-square threshold expressed divided by 100: the agent flags a Benford
    # violation when the statistic exceeds benford_threshold * 100 (default
    # 15.0, close to the ~15.507 critical value for 8 degrees of freedom at
    # 95% confidence).
    benford_threshold: float = 0.15  # Chi-square threshold (multiplied by 100)
    # Values farther than this many standard deviations from the column mean
    # are reported as outliers.
    variance_std_threshold: float = 3.0  # Standard deviations for outliers


class AuditorAgent(BaseAgent):
    """
    The Auditor Agent performs financial analysis and anomaly detection.

    Uses statistical methods like Benford's Law to detect potential fraud
    or data quality issues. One amount column is audited per run: the first
    AMOUNT-typed column of the EmberFrame, falling back to the first numeric
    column.
    """

    # Expected first digit distribution according to Benford's Law,
    # P(d) = log10(1 + 1/d), rounded to three decimals.
    BENFORD_DISTRIBUTION = {
        1: 0.301,
        2: 0.176,
        3: 0.125,
        4: 0.097,
        5: 0.079,
        6: 0.067,
        7: 0.058,
        8: 0.051,
        9: 0.046,
    }

    def __init__(self, config: Optional[AuditorConfig] = None) -> None:
        """
        Initialize the Auditor agent.

        Args:
            config: Optional configuration; ``AuditorConfig()`` when omitted.
        """
        super().__init__(config or AuditorConfig())
        # Re-annotate so type checkers see the narrower config type.
        self.config: AuditorConfig = self.config  # type: ignore

    @property
    def name(self) -> str:
        """Return the agent name."""
        return "Auditor"

    def execute(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute the Auditor's analysis tasks.

        Expected inputs:
            - emberframe: EmberFrame to audit

        Returns:
            Dictionary with 'audit_report' containing the AuditReport and
            'status' set to "success".

        Raises:
            ValueError: If 'emberframe' is missing or is not an EmberFrame.
        """
        self._log("Starting audit analysis...")

        emberframe = inputs.get("emberframe")
        if emberframe is None:
            raise ValueError("No EmberFrame provided to Auditor agent")

        if not isinstance(emberframe, EmberFrame):
            raise ValueError(f"Expected EmberFrame, got {type(emberframe)}")

        findings: List[AnomalyFinding] = []
        benford_result: Optional[Dict[str, Any]] = None
        variance_result: Optional[Dict[str, Any]] = None

        amount_column = self._select_amount_column(emberframe)
        # Compare against None explicitly: a valid column label can be falsy
        # (e.g. the integer label 0 coming from the numeric fallback).
        has_column = amount_column is not None

        # Run Benford's Law analysis
        if self.config.run_benford and has_column:
            benford_result = self._benford_analysis(emberframe.data, amount_column)
            if benford_result["anomaly_detected"]:
                findings.append(
                    AnomalyFinding(
                        severity="medium",
                        category="benford_violation",
                        description=(
                            f"Benford's Law violation detected in {amount_column}. "
                            f"Chi-square statistic: {benford_result['chi_square']:.4f}"
                        ),
                        evidence=benford_result,
                    )
                )

        # Run variance analysis
        if self.config.run_variance and has_column:
            variance_result = self._variance_analysis(emberframe.data, amount_column)
            if variance_result["outliers_found"] > 0:
                findings.append(
                    AnomalyFinding(
                        severity="low",
                        category="variance_outliers",
                        description=(
                            f"Found {variance_result['outliers_found']} outliers in {amount_column}"
                        ),
                        affected_rows=variance_result["outlier_indices"],
                        evidence=variance_result,
                    )
                )

        total_amount = (
            float(emberframe.data[amount_column].sum()) if has_column else 0.0
        )

        summary = self._generate_summary(len(emberframe), findings)

        report = AuditReport(
            total_rows=len(emberframe),
            total_amount=total_amount,
            findings=findings,
            benford_analysis=benford_result,
            variance_analysis=variance_result,
            summary=summary,
        )

        self._log(f"Audit complete. Found {len(findings)} findings.")

        return {"audit_report": report, "status": "success"}

    def _select_amount_column(self, emberframe: EmberFrame) -> Optional[Any]:
        """
        Pick the column to audit and log the choice.

        Returns the first AMOUNT-typed column, else the first numeric column,
        else None (after logging a warning).
        """
        candidates = emberframe.get_all_columns_by_type(ColumnType.AMOUNT)
        if not candidates:
            # Fallback to numeric columns
            candidates = list(
                emberframe.data.select_dtypes(include=[np.number]).columns
            )

        if not candidates:
            self._log("Warning: No numeric columns found for analysis")
            return None

        column = candidates[0]
        self._log(f"Analyzing column: {column}")
        return column

    def _benford_analysis(self, df: pd.DataFrame, column: str) -> Dict[str, Any]:
        """
        Perform a Benford's Law first-digit test on a numeric column.

        The leading significant digit of every finite, non-zero absolute value
        is tallied and compared with BENFORD_DISTRIBUTION using Pearson's
        chi-square statistic (9 categories, 8 degrees of freedom).

        Args:
            df: DataFrame to analyze
            column: Column name to analyze

        Returns:
            Dictionary with analysis results. With fewer than 30 usable values
            it only contains 'error' and 'anomaly_detected' (False).
        """
        # Absolute values as floats (also turns nullable NA into NaN); drop
        # zeros, NaN and +/-inf, none of which has a leading significant digit.
        magnitudes = df[column].astype(float).abs()
        values = magnitudes[np.isfinite(magnitudes) & (magnitudes > 0)]

        if len(values) < 30:
            return {
                "error": "Insufficient data for Benford analysis (need at least 30 values)",
                "anomaly_detected": False,
            }

        # Leading significant digit via scientific notation ("3.450000e-02" -> 3).
        # Stripping the decimal point from str(x) is wrong for values below 1:
        # 0.0345 -> "00345" would give the digit 0.
        first_digits = values.map("{:e}".format).str[0].astype(int)

        sample_size = len(values)
        counts = first_digits.value_counts()
        digits = range(1, 10)
        observed = [float(counts.get(d, 0)) / sample_size for d in digits]
        expected = [self.BENFORD_DISTRIBUTION.get(d, 0) for d in digits]

        # Pearson chi-square on counts: sum((O - E)^2 / E) with O = N * p_obs and
        # E = N * p_exp, i.e. N * sum((p_obs - p_exp)^2 / p_exp). Without the N
        # factor the value is not a chi-square statistic and cannot be compared
        # with a chi-square critical value.
        chi_square = sample_size * sum(
            ((obs - exp) ** 2) / exp for obs, exp in zip(observed, expected) if exp > 0
        )

        # Critical value for 8 degrees of freedom at 95% confidence is ~15.507;
        # the default benford_threshold of 0.15 maps to 15.0.
        critical_value = self.config.benford_threshold * 100
        anomaly_detected = chi_square > critical_value

        return {
            "chi_square": float(chi_square),
            "anomaly_detected": bool(anomaly_detected),
            "critical_value": critical_value,
            "observed_distribution": dict(zip(digits, observed)),
            "expected_distribution": dict(zip(digits, expected)),
            "sample_size": sample_size,
        }

    def _variance_analysis(self, df: pd.DataFrame, column: str) -> Dict[str, Any]:
        """
        Perform variance analysis to detect outliers.

        A value is an outlier when it lies more than
        ``variance_std_threshold`` sample standard deviations from the mean.

        Args:
            df: DataFrame to analyze
            column: Column name to analyze

        Returns:
            Dictionary with analysis results. With fewer than 10 non-null
            values it only contains 'error' and 'outliers_found' (0).
        """
        values = df[column].dropna()

        if len(values) < 10:
            return {
                "error": "Insufficient data for variance analysis",
                "outliers_found": 0,
            }

        mean = values.mean()
        std = values.std()  # pandas default: sample standard deviation (ddof=1)

        # Identify outliers (beyond N standard deviations)
        threshold = self.config.variance_std_threshold
        outliers = values[(values - mean).abs() > threshold * std]
        has_outliers = len(outliers) > 0

        return {
            "mean": float(mean),
            "std": float(std),
            "outliers_found": len(outliers),
            # Index labels of the outlying rows (not positional offsets).
            "outlier_indices": outliers.index.tolist(),
            "threshold_std": threshold,
            "min_outlier": float(outliers.min()) if has_outliers else None,
            "max_outlier": float(outliers.max()) if has_outliers else None,
        }

    def _generate_summary(self, total_rows: int, findings: List[AnomalyFinding]) -> str:
        """
        Generate a human-readable summary of the audit.

        Args:
            total_rows: Total number of rows analyzed
            findings: List of findings

        Returns:
            Summary string listing issue counts from most to least severe.
        """
        header = f"Audit complete. Analyzed {total_rows} transactions."
        if not findings:
            return f"{header} No anomalies detected."

        # Severity is a free-form str on AnomalyFinding, so count with .get()
        # rather than indexing a fixed-key dict (which raised KeyError).
        counts: Dict[str, int] = {}
        for finding in findings:
            counts[finding.severity] = counts.get(finding.severity, 0) + 1

        labels = (
            ("critical", "critical issue(s)"),
            ("high", "high severity issue(s)"),
            ("medium", "medium severity issue(s)"),
            ("low", "low severity issue(s)"),
        )
        summary_parts = [header]
        for level, label in labels:
            count = counts.pop(level, 0)
            if count > 0:
                summary_parts.append(f"{count} {label} found.")

        unrecognized = sum(counts.values())
        if unrecognized > 0:
            summary_parts.append(
                f"{unrecognized} issue(s) with unrecognized severity found."
            )

        return " ".join(summary_parts)
