from __future__ import annotations
import uuid
import threading
import time
import traceback
import json
import inspect
from datetime import datetime, timezone
from typing import Optional, Dict, Any
import functools
from .sqs import send_to_sqs

# ============================================================
# Safe Serialization Helpers
# ============================================================

def serialize_value(value: Any) -> Any:
    """Return a JSON-compatible copy of *value*.

    The value is round-tripped through ``json``; anything ``json`` cannot
    encode natively is converted with ``str`` during encoding. If the round
    trip still fails (e.g. a circular reference), the ``str`` of the whole
    value is returned instead.
    """
    try:
        encoded = json.dumps(value, default=str)
        return json.loads(encoded)
    except Exception:
        # Last-resort fallback: never let serialization break tracing.
        return str(value)


def extract_locals(frame) -> Dict[str, Any]:
    """Return the frame's local variables, each passed through serialize_value.

    A falsy *frame* yields ``{}``. If reading or serializing any local
    raises, the partial result is discarded and ``{}`` is returned.
    """
    if not frame:
        return {}
    try:
        snapshot: Dict[str, Any] = {}
        for var_name, var_value in frame.f_locals.items():
            snapshot[var_name] = serialize_value(var_value)
        return snapshot
    except Exception:
        return {}


def extract_self_attrs(frame) -> Dict[str, Any]:
    """Return the public instance attributes of the frame's ``self`` local.

    Only entries of ``self.__dict__`` whose names do not start with an
    underscore are included, each passed through serialize_value. Returns
    ``{}`` when *frame* is falsy, when the frame has no ``self`` local (or it
    is None), or when anything along the way raises.
    """
    if not frame:
        return {}
    try:
        owner = frame.f_locals.get("self")
        if owner is None:
            return {}
        public_attrs: Dict[str, Any] = {}
        for attr_name, attr_value in owner.__dict__.items():
            if attr_name.startswith("_"):
                continue  # skip private/protected attributes
            public_attrs[attr_name] = serialize_value(attr_value)
        return public_attrs
    except Exception:
        return {}


def get_previous_frame():
    """Return the frame of the function that called ``get_previous_frame``.

    ``inspect.currentframe()`` may return None on interpreters without
    stack-frame support; in that case None is returned.
    """
    here = inspect.currentframe()
    if here is None:
        return None
    return here.f_back


# ============================================================
# Core Trace Manager
# ============================================================

class TraceManager:
    """Centralized trace + span lifecycle manager.

    All state is stored on the class itself (``_active``), so at most one
    trace is in flight per process and it is shared by every caller/thread.
    Mutations of that state are serialized with ``_lock``.
    """

    _lock = threading.Lock()

    # Shared trace state:
    #   trace_id -- id of the trace in flight, or None when no trace is active
    #   spans    -- every span dict recorded for the trace; index 0 is the root
    #   stack    -- currently open spans as {"span_id", "start"} entries; the
    #               last entry is the parent of the next span started
    _active: Dict[str, Any] = {
        "trace_id": None,
        "spans": [],
        "stack": [],
    }

    @classmethod
    def finalize_and_send(
        cls,
        *,
        user_id: str,
        session_id: str,
        trace_name: str,
        trace_input: dict,
        trace_output: dict,
        extra_spans: Optional[list] = None,
    ) -> bool:
        """End the active trace, attach metadata, and send it to SQS.

        Args:
            user_id: Stored on the bundle as ``user_id``.
            session_id: Stored on the bundle as ``session_id``.
            trace_name: Stored on the bundle as ``trace_name``.
            trace_input: Stored on the bundle as ``trace_input``.
            trace_output: Stored on the bundle as ``trace_output``.
            extra_spans: Extra span dicts appended to the bundle's spans
                before sending. ``None`` (the default) means none.

        Returns:
            True once the bundle has been handed to ``send_to_sqs``; False
            if there was no active trace to finalize.

        Raises:
            Whatever ``send_to_sqs`` raises; it is not caught here.
        """
        bundle = cls.end_trace()

        if not bundle:
            print("[VeriskGO] ERROR: No trace bundle was created.")
            return False

        # Attach extra auto-generated spans.
        if extra_spans:
            bundle["spans"].extend(extra_spans)

        # Attach metadata.
        bundle["user_id"] = user_id
        bundle["session_id"] = session_id
        bundle["trace_name"] = trace_name
        bundle["trace_input"] = trace_input
        bundle["trace_output"] = trace_output

        # Send to SQS.
        send_to_sqs(bundle)
        print("[VeriskGO] Trace sent.\n")

        return True

    @staticmethod
    def _now() -> str:
        """Current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _id() -> str:
        """New random 32-character hex identifier."""
        return uuid.uuid4().hex

    # --------------------------
    # Trace API
    # --------------------------
    @classmethod
    def has_active_trace(cls) -> bool:
        """True while a trace started by start_trace has not been ended."""
        return cls._active["trace_id"] is not None

    @classmethod
    def start_trace(
        cls,
        name: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Start a new trace with a root span and return the trace id.

        Any trace already in flight is discarded: its state is overwritten
        without being returned or sent.

        Args:
            name: Name of the root span.
            metadata: Stored as the root span's ``metadata`` (``{}`` if None).
        """
        with cls._lock:
            trace_id = cls._id()
            root_id = cls._id()

            root_span = {
                "span_id": root_id,
                "parent_span_id": None,
                "name": name,
                "type": "root",
                "timestamp": cls._now(),
                "input": None,
                "output": None,
                "metadata": metadata or {},
                "duration_ms": 0,
            }

            cls._active["trace_id"] = trace_id
            cls._active["spans"] = [root_span]
            cls._active["stack"] = [{"span_id": root_id, "start": time.time()}]

            return trace_id

    @classmethod
    def end_trace(cls, final_output: Optional[Any] = None) -> Optional[Dict[str, Any]]:
        """Close all open spans, reset the active state, and return the bundle.

        Args:
            final_output: If not None, stored as the root span's ``output``.
                Falsy values such as ``0``, ``""`` or ``{}`` are kept.

        Returns:
            ``{"trace_id": ..., "spans": [...]}`` for the finished trace, or
            None if no trace was active.
        """
        with cls._lock:
            if not cls._active["trace_id"]:
                return None

            # Close any spans the caller left open, innermost first.
            while cls._active["stack"]:
                cls._end_current_span()

            # `is not None` (not truthiness) so falsy outputs are not dropped.
            if final_output is not None:
                cls._active["spans"][0]["output"] = final_output

            bundle = {
                "trace_id": cls._active["trace_id"],
                "spans": cls._active["spans"].copy(),
            }

            cls._active["trace_id"] = None
            cls._active["spans"] = []
            cls._active["stack"] = []

            return bundle

    # --------------------------
    # Span API
    # --------------------------
    @classmethod
    def start_span(
        cls,
        name: str,
        input_data: Optional[Any] = None,
        tags: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """Open a child span under the innermost open span.

        Args:
            name: Span name.
            input_data: Stored as the span's ``input``.
            tags: Stored as the span's ``metadata`` (``{}`` if None).

        Returns:
            The new span id, or None if no trace is active.
        """
        with cls._lock:
            if not cls._active["trace_id"]:
                return None

            stack = cls._active["stack"]
            # Parent is the innermost open span. If the caller already closed
            # every span (including the root via end_span), attach to the
            # root span instead of raising IndexError on the empty stack.
            if stack:
                parent = stack[-1]["span_id"]
            else:
                parent = cls._active["spans"][0]["span_id"]
            sid = cls._id()

            span = {
                "span_id": sid,
                "parent_span_id": parent,
                "name": name,
                "type": "child",
                "timestamp": cls._now(),
                "input": input_data,
                "metadata": tags or {},
                "output": None,
                "duration_ms": 0,
            }

            cls._active["spans"].append(span)
            stack.append({"span_id": sid, "start": time.time()})

            return sid

    @classmethod
    def end_span(cls, span_id: Optional[str], output_data: Optional[Any] = None):
        """Close the open span *span_id*, recording its duration and output.

        Unknown ids (including None) and calls with nothing open are ignored.
        Only the matching stack entry is removed; spans opened after it stay
        open until ended individually or by end_trace.
        """
        with cls._lock:
            if not cls._active["stack"]:
                return

            # Search from the innermost span outwards.
            for i in reversed(range(len(cls._active["stack"]))):
                entry = cls._active["stack"][i]
                if entry["span_id"] == span_id:

                    duration = int((time.time() - entry["start"]) * 1000)
                    cls._active["stack"].pop(i)

                    for sp in cls._active["spans"]:
                        if sp["span_id"] == span_id:
                            sp["duration_ms"] = duration
                            sp["output"] = output_data
                            return

    @classmethod
    def _end_current_span(cls, output_data=None):
        """Pop the innermost open span and record its duration.

        Does not acquire ``_lock``; callers must already hold it (end_trace
        does). Raises IndexError if no span is open.

        Args:
            output_data: If not None, stored as the span's ``output``.
        """
        entry = cls._active["stack"].pop()
        sid = entry["span_id"]
        duration = int((time.time() - entry["start"]) * 1000)

        for sp in cls._active["spans"]:
            if sp["span_id"] == sid:
                sp["duration_ms"] = duration
                # `is not None` so falsy outputs are kept.
                if output_data is not None:
                    sp["output"] = output_data
                return

    @classmethod
    def add_span(cls, span_dict: Dict[str, Any]):
        """Append a pre-built span dict to the current span list as-is.

        NOTE(review): this appends even when no trace is active; such spans
        are discarded by the next start_trace. Confirm that is intended.
        """
        with cls._lock:
            cls._active["spans"].append(span_dict)


# ============================================================
# Decorator: track_llm_call
# ============================================================


def track_llm_call(
    name=None,
    tags=None,
    input_cost_per_token: float = 0.0000015,
    output_cost_per_token: float = 0.000005,
):
    """Decorator that records an LLM call as a ``generation`` span.

    The decorator is intended to produce Langfuse-compatible generation spans.
    The wrapped function is expected to return a Bedrock Converse-style
    response dict, with the text at ``output.message.content[0].text`` and
    token counts at ``usage.inputTokens`` / ``usage.outputTokens``. The
    decorated function returns the parsed summary dict (``text``,
    ``usage_details``, ``cost_details``) instead of the raw response.

    Spans are recorded through ``TraceManager``. When no trace is active,
    the call still runs but no span is recorded.

    Args:
        name: Span name; defaults to the wrapped function's ``__name__``.
        tags: Span metadata, also echoed in the span input.
        input_cost_per_token: Price applied to each input token in
            ``cost_details``. The default is the previously hard-coded rate.
        output_cost_per_token: Price applied to each output token in
            ``cost_details``. The default is the previously hard-coded rate.

    Raises:
        Any exception raised by the wrapped function, or by parsing its
        response (e.g. KeyError for a malformed response), after the error
        has been recorded on the span.
    """

    def decorator(func):
        span_name = name or func.__name__

        def parse_bedrock_response(resp):
            """Extract the text plus token usage and cost from a response dict."""
            text = resp["output"]["message"]["content"][0]["text"]

            # Usage tokens; tolerate a missing/None usage block or None counts.
            usage = resp.get("usage") or {}
            input_tokens = usage.get("inputTokens") or 0
            output_tokens = usage.get("outputTokens") or 0
            total_tokens = input_tokens + output_tokens

            # Simple linear cost model: tokens * per-token price.
            input_cost = input_tokens * input_cost_per_token
            output_cost = output_tokens * output_cost_per_token

            return {
                "text": text,
                "usage_details": {
                    "input": input_tokens,
                    "output": output_tokens,
                    "total": total_tokens,
                },
                "cost_details": {
                    "input": round(input_cost, 6),
                    "output": round(output_cost, 6),
                    "total": round(input_cost + output_cost, 6),
                },
            }

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The prompt is taken from the first positional argument. Fall back
            # to a ``prompt`` keyword so keyword-only calls don't raise
            # IndexError before the wrapped function even runs.
            prompt = args[0] if args else kwargs.get("prompt")

            span_id = TraceManager.start_span(
                span_name,
                input_data={"prompt": prompt, "tags": tags},
                tags=tags,
            )

            # start_span always creates type "child"; re-type this span as a
            # generation. Hold TraceManager's lock, because every other
            # mutation of the shared span list happens under that lock.
            if span_id is not None:
                with TraceManager._lock:
                    for sp in TraceManager._active["spans"]:
                        if sp["span_id"] == span_id:
                            sp["type"] = "generation"
                            break

            start = time.time()

            try:
                resp = func(*args, **kwargs)
                latency = int((time.time() - start) * 1000)

                parsed = parse_bedrock_response(resp)

                TraceManager.end_span(
                    span_id,
                    {
                        "status": "success",
                        "model": resp.get("model"),
                        "latency_ms": latency,
                        "input": {
                            "prompt": prompt,
                            "model": resp.get("model"),
                            "messages": resp.get("messages"),
                        },
                        "output": {
                            "text": parsed["text"],
                            # NOTE(review): hard-coded; the response's own
                            # stop reason (if any) is not read.
                            "finish_reason": "stop",
                        },
                        "usage_details": parsed["usage_details"],
                        "cost_details": parsed["cost_details"],
                    },
                )
                return parsed

            except Exception as e:
                latency = int((time.time() - start) * 1000)
                TraceManager.end_span(
                    span_id,
                    {
                        "status": "error",
                        "error": str(e),
                        "stacktrace": traceback.format_exc(),
                        "latency_ms": latency,
                    },
                )
                raise

        return wrapper

    return decorator
