"""
EmberFrame: Semantic DataFrame for Financial Data

A pandas-backed data structure with semantic understanding of financial concepts,
column types, and business rules.
"""

from __future__ import annotations

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import pandas as pd
from pydantic import BaseModel, Field, field_validator


class ColumnType(str, Enum):
    """Semantic column types for financial data.

    Each member names the role a column can play in a financial dataset.
    Because the enum also subclasses ``str``, every member compares equal to
    its lowercase string value (for example ``ColumnType.AMOUNT == "amount"``).
    """

    ACCOUNT = "account"
    AMOUNT = "amount"
    DATE = "date"
    CATEGORY = "category"
    ENTITY = "entity"
    DESCRIPTION = "description"
    TRANSACTION_ID = "transaction_id"
    # NOTE(review): no inference rule visible in this file ever yields
    # REFERENCE; presumably it is assigned explicitly by callers — confirm.
    REFERENCE = "reference"
    DEBIT = "debit"
    CREDIT = "credit"
    BALANCE = "balance"
    # Fallback used when no semantic type could be (or was asked to be) inferred.
    UNKNOWN = "unknown"


class ColumnMetadata(BaseModel):
    """Metadata for a column in an EmberFrame.

    Pairs a column label with the semantic type assigned to it and records how
    confident that assignment is.
    """

    # Column label; EmberFrame validation compares it against the DataFrame's
    # column labels, so it must match the label used in the data.
    name: str
    # Semantic role assigned to the column.
    semantic_type: ColumnType
    # Label the column had in the source data. EmberFrame.from_dataframe sets
    # it to the same value as `name`.
    original_name: Optional[str] = None
    # Confidence in `semantic_type`, constrained to the closed range [0.0, 1.0];
    # defaults to 1.0 when not supplied.
    confidence: float = Field(ge=0.0, le=1.0, default=1.0)
    # Optional free-text description of the column.
    description: Optional[str] = None
    # Optional formatting hint. Nothing in this file reads it; presumably it is
    # consumed by downstream code — TODO confirm the expected contents.
    format_hint: Optional[str] = None


class FrameMetadata(BaseModel):
    """Metadata for an entire EmberFrame.

    Records where the data came from, its dimensions, the per-column semantic
    metadata, and optional business context (entity, reporting period,
    currency, free-form notes).
    """

    # Source identifier (e.g., file path, API endpoint).
    source: str
    source_type: str  # e.g., "quickbooks", "csv", "xero"
    # Defaults to datetime.now() at model creation, i.e. a naive local-time
    # timestamp with no tzinfo attached.
    ingestion_timestamp: datetime = Field(default_factory=datetime.now)
    # Declared dimensions of the data. EmberFrame overwrites both with the
    # actual DataFrame dimensions during validation if they disagree.
    row_count: int
    column_count: int
    # One entry per column; the set of entry names must equal the set of
    # DataFrame column labels or EmberFrame validation raises ValueError.
    columns: List[ColumnMetadata]
    entity: Optional[str] = None  # Company/entity name
    # Bounds of the period the data covers. EmberFrame.filter_by_date_range
    # sets them to the requested start/end dates on the frame it returns.
    period_start: Optional[datetime] = None
    period_end: Optional[datetime] = None
    # Defaults to "USD"; presumably an ISO 4217 currency code — TODO confirm.
    currency: str = "USD"
    # Optional free-form notes.
    notes: Optional[str] = None


class EmberFrame:
    """
    Semantic DataFrame for financial data.

    EmberFrame wraps a pandas DataFrame with rich metadata about column semantics,
    data provenance, and financial context. It provides methods for semantic
    operations like account classification, period analysis, and rule validation.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        metadata: FrameMetadata,
    ) -> None:
        """
        Initialize an EmberFrame.

        Args:
            data: The underlying pandas DataFrame
            metadata: Rich metadata about the frame's structure and provenance
        """
        self._data = data.copy()
        self._metadata = metadata
        self._validate()

    def _validate(self) -> None:
        """Validate that metadata matches the actual DataFrame structure."""
        if len(self._data) != self._metadata.row_count:
            self._metadata.row_count = len(self._data)

        if len(self._data.columns) != self._metadata.column_count:
            self._metadata.column_count = len(self._data.columns)

        # Ensure all metadata columns exist in the DataFrame
        metadata_cols = {col.name for col in self._metadata.columns}
        df_cols = set(self._data.columns)

        if metadata_cols != df_cols:
            missing_in_df = metadata_cols - df_cols
            missing_in_metadata = df_cols - metadata_cols
            if missing_in_df:
                raise ValueError(f"Columns in metadata but not in DataFrame: {missing_in_df}")
            if missing_in_metadata:
                raise ValueError(
                    f"Columns in DataFrame but not in metadata: {missing_in_metadata}"
                )

    @classmethod
    def from_dataframe(
        cls,
        df: pd.DataFrame,
        source: str,
        source_type: str,
        infer_types: bool = True,
        **metadata_kwargs: Any,
    ) -> EmberFrame:
        """
        Create an EmberFrame from a pandas DataFrame.

        Args:
            df: Source DataFrame
            source: Source identifier (e.g., file path, API endpoint)
            source_type: Type of source (e.g., "csv", "quickbooks")
            infer_types: Whether to automatically infer semantic column types
            **metadata_kwargs: Additional metadata fields

        Returns:
            New EmberFrame instance
        """
        columns_metadata: List[ColumnMetadata] = []

        for col in df.columns:
            semantic_type = ColumnType.UNKNOWN
            confidence = 0.5

            if infer_types:
                semantic_type, confidence = cls._infer_column_type(col, df[col])

            columns_metadata.append(
                ColumnMetadata(
                    name=col,
                    semantic_type=semantic_type,
                    original_name=col,
                    confidence=confidence,
                )
            )

        metadata = FrameMetadata(
            source=source,
            source_type=source_type,
            row_count=len(df),
            column_count=len(df.columns),
            columns=columns_metadata,
            **metadata_kwargs,
        )

        return cls(df, metadata)

    @staticmethod
    def _infer_column_type(column_name: str, series: pd.Series) -> tuple[ColumnType, float]:
        """
        Infer the semantic type of a column based on name and content.

        Name-based rules are tried first in a fixed priority order and the first
        match wins; the column's dtype is consulted only when no name rule
        matches.

        Most keywords are matched as substrings of the lowercased name. The
        two-letter keyword "id" is instead matched as a whole word (words are
        split on non-alphanumeric characters and on lower-to-upper camelCase
        boundaries), so names such as "Amount Paid" or "Dividend" are not
        mistaken for identifiers.

        Args:
            column_name: Name of the column; non-string labels are converted
                with ``str()`` before matching
            series: The column data

        Returns:
            Tuple of (semantic_type, confidence_score)
        """
        # DataFrame labels are not guaranteed to be strings (e.g. integer
        # labels), so coerce before applying string heuristics.
        raw_name = str(column_name)
        col_lower = raw_name.lower()

        # Break the name into lowercase words: "TransactionID", "txn-id" and
        # "Txn Id" all yield a standalone "id" word, while "paid" does not.
        spaced = []
        previous = ""
        for char in raw_name:
            if not char.isalnum():
                spaced.append(" ")
            else:
                if char.isupper() and previous.islower():
                    spaced.append(" ")
                spaced.append(char.lower())
            previous = char
        words = set("".join(spaced).split())

        def mentions(*keywords: str) -> bool:
            """Return True if any keyword occurs as a substring of the name."""
            return any(keyword in col_lower for keyword in keywords)

        # High confidence inference based on name
        if mentions("account", "acct"):
            return ColumnType.ACCOUNT, 0.9
        if mentions("date", "datetime", "timestamp"):
            return ColumnType.DATE, 0.9
        if mentions("debit") or col_lower == "dr":
            return ColumnType.DEBIT, 0.95
        if mentions("credit") or col_lower == "cr":
            return ColumnType.CREDIT, 0.95
        if mentions("balance", "bal"):
            return ColumnType.BALANCE, 0.9
        if mentions("description", "desc", "memo", "narrative"):
            return ColumnType.DESCRIPTION, 0.9
        if mentions("category", "type", "class"):
            return ColumnType.CATEGORY, 0.8
        if mentions("entity", "company", "vendor", "customer"):
            return ColumnType.ENTITY, 0.8
        # "id" must be its own word; the run-together spellings are listed
        # explicitly because they cannot be split into words reliably.
        if words & {"id", "ids"} or mentions("txn", "transactionid", "transid"):
            return ColumnType.TRANSACTION_ID, 0.85
        if mentions("amount", "value", "total"):
            return ColumnType.AMOUNT, 0.8

        # Content-based inference
        # pandas treats bool as a numeric dtype, but True/False flags are not
        # monetary amounts, so exclude them here.
        if pd.api.types.is_numeric_dtype(series) and not pd.api.types.is_bool_dtype(series):
            return ColumnType.AMOUNT, 0.6
        if pd.api.types.is_datetime64_any_dtype(series):
            return ColumnType.DATE, 0.8

        return ColumnType.UNKNOWN, 0.3

    @property
    def data(self) -> pd.DataFrame:
        """Access the underlying DataFrame (returns a copy)."""
        return self._data.copy()

    @property
    def metadata(self) -> FrameMetadata:
        """Access the frame metadata."""
        return self._metadata

    def get_column_by_type(self, semantic_type: ColumnType) -> Optional[str]:
        """
        Get the first column matching a semantic type.

        Args:
            semantic_type: The semantic type to search for

        Returns:
            Column name if found, None otherwise
        """
        for col in self._metadata.columns:
            if col.semantic_type == semantic_type:
                return col.name
        return None

    def get_all_columns_by_type(self, semantic_type: ColumnType) -> List[str]:
        """
        Get all columns matching a semantic type.

        Args:
            semantic_type: The semantic type to search for

        Returns:
            List of column names
        """
        return [col.name for col in self._metadata.columns if col.semantic_type == semantic_type]

    def filter_by_date_range(
        self, start_date: datetime, end_date: datetime, date_column: Optional[str] = None
    ) -> EmberFrame:
        """
        Filter the frame by a date range.

        Args:
            start_date: Start of the range (inclusive)
            end_date: End of the range (inclusive)
            date_column: Optional specific date column; if None, uses inferred date column

        Returns:
            New filtered EmberFrame
        """
        if date_column is None:
            date_column = self.get_column_by_type(ColumnType.DATE)
            if date_column is None:
                raise ValueError("No date column found in frame")

        # Ensure column is datetime
        date_series = pd.to_datetime(self._data[date_column])

        mask = (date_series >= start_date) & (date_series <= end_date)
        filtered_df = self._data[mask].copy()

        # Update metadata
        new_metadata = self._metadata.model_copy(deep=True)
        new_metadata.row_count = len(filtered_df)
        new_metadata.period_start = start_date
        new_metadata.period_end = end_date

        return EmberFrame(filtered_df, new_metadata)

    def aggregate_by_category(
        self, category_column: Optional[str] = None, amount_column: Optional[str] = None
    ) -> pd.DataFrame:
        """
        Aggregate amounts by category.

        Args:
            category_column: Column to group by; if None, uses inferred category column
            amount_column: Column to sum; if None, uses inferred amount column

        Returns:
            DataFrame with aggregated results
        """
        if category_column is None:
            category_column = self.get_column_by_type(ColumnType.CATEGORY)
            if category_column is None:
                raise ValueError("No category column found in frame")

        if amount_column is None:
            amount_column = self.get_column_by_type(ColumnType.AMOUNT)
            if amount_column is None:
                raise ValueError("No amount column found in frame")

        return self._data.groupby(category_column)[amount_column].sum().reset_index()

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary representation."""
        return {
            "data": self._data.to_dict(orient="records"),
            "metadata": self._metadata.model_dump(),
        }

    def __repr__(self) -> str:
        """String representation of the EmberFrame."""
        return (
            f"EmberFrame(source={self._metadata.source}, "
            f"rows={self._metadata.row_count}, "
            f"cols={self._metadata.column_count}, "
            f"type={self._metadata.source_type})"
        )

    def __len__(self) -> int:
        """Return the number of rows."""
        return len(self._data)

    def head(self, n: int = 5) -> pd.DataFrame:
        """Return the first n rows."""
        return self._data.head(n)

    def tail(self, n: int = 5) -> pd.DataFrame:
        """Return the last n rows."""
        return self._data.tail(n)

    def describe(self) -> pd.DataFrame:
        """Generate descriptive statistics."""
        return self._data.describe()
