# veriskgo/llm.py
import time
import traceback
import functools
import inspect
from typing import Any, Dict
from typing import   Optional
from veriskgo.trace_manager import TraceManager
from veriskgo.trace_manager import serialize_value


def normalize_bedrock_usage(response: dict) -> Dict[str, int]:
    """Normalize Bedrock token usage into input/output/total counts.

    Args:
        response: Provider response dict. Its optional ``"usage"`` entry is
            read for the ``"inputTokens"`` and ``"outputTokens"`` counters.

    Returns:
        A dict with the keys ``"input"``, ``"output"`` and ``"total"``.
        A missing or ``None`` usage block, or a missing or ``None`` counter,
        is counted as zero.
    """
    # `.get("usage", {})` only covers a missing key; an explicit None value
    # would otherwise fail on the `.get` calls below.
    usage = response.get("usage") or {}

    # Same reasoning per counter: None + int would raise TypeError.
    input_tokens = usage.get("inputTokens") or 0
    output_tokens = usage.get("outputTokens") or 0

    return {
        "input": input_tokens,
        "output": output_tokens,
        "total": input_tokens + output_tokens,
    }


def extract_bedrock_text(response: dict) -> str:
    """Return the text of the first content item in a Bedrock response.

    Follows ``response["output"]["message"]["content"][0]["text"]`` one step
    at a time. If any step fails for any reason, an empty string is returned
    instead of raising.
    """
    lookup_path = ("output", "message", "content", 0, "text")
    node = response
    try:
        for step in lookup_path:
            node = node[step]
    except Exception:
        return ""
    return node

# ============================================================
# Decorator: track_llm_call (Generation Span)
# ============================================================

def track_llm_call(
    name: Optional[str] = None,
    *,
    tags: Optional[Dict[str, Any]] = None,
):
    """
    Specialized decorator for LLM calls.
    Produces Langfuse-compatible `generation` spans.

    EXPECTS THE DECORATED FUNCTION TO RETURN RAW PROVIDER RESPONSE:
    {
        "model": "...",
        "output": {
            "message": { "content": [ { "text": "hello" } ] }
        },
        "usage": { "inputTokens": 12, "outputTokens": 20 }
        ...
    }

    Args:
        name: Span name. Defaults to the decorated function's ``__name__``.
        tags: Extra tags merged into the span tags and echoed in the span
            input. A key named ``span_type`` overrides the default
            ``"generation"`` value.

    Returns:
        A decorator that works on both plain and ``async def`` functions.
        When ``TraceManager.has_active_trace()`` is false the wrapped
        function is called directly and no span is recorded.

    Raises:
        Whatever the decorated function raises. The span is closed with an
        error block first and the exception is re-raised unchanged.
    """

    def decorator(func):
        span_name = name or func.__name__
        is_async = inspect.iscoroutinefunction(func)

        # Resolve the signature once so the prompt can be looked up by
        # parameter name; not every callable exposes a signature.
        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):
            signature = None

        def _extract_prompt(args, kwargs):
            """Return the value bound to ``prompt``, else the first positional."""
            if signature is not None and "prompt" in signature.parameters:
                try:
                    bound = signature.bind_partial(*args, **kwargs)
                    bound.apply_defaults()
                except TypeError:
                    # Invalid call arguments: the real call below will raise;
                    # fall through to the positional heuristic for the span.
                    bound = None
                if bound is not None and "prompt" in bound.arguments:
                    return bound.arguments["prompt"]
            # Fallback for functions without a parameter named ``prompt``.
            if args:
                return args[0]
            return kwargs.get("prompt", "")

        def _elapsed_ms(started):
            """Whole milliseconds elapsed since ``started``."""
            return int((time.perf_counter() - started) * 1000)

        def _open_span(prompt):
            """Start the generation span and return its id."""
            return TraceManager.start_span(
                span_name,
                input_data={
                    "prompt": prompt,
                    # Copy so the caller's dict is never shared with the tracer.
                    "tags": dict(tags or {}),
                },
                tags={"span_type": "generation", **(tags or {})},
            )

        def _close_with_error(span_id, started, exc):
            """Close the span with an error block; call from an ``except``."""
            error_block = {
                "status": "error",
                "latency_ms": _elapsed_ms(started),
                "error": str(exc),
                "stacktrace": traceback.format_exc(),
            }
            TraceManager.end_span(span_id, error_block)

        def _close_with_success(span_id, started, prompt, response):
            """Close the span with the processed response payload."""
            duration = _elapsed_ms(started)
            try:
                output_block = _process_llm_response(prompt, response, duration)
            except Exception:
                # The LLM call itself succeeded; a failure while shaping the
                # telemetry payload must not turn it into a caller-visible
                # error. Record a minimal block and keep the details.
                output_block = {
                    "status": "success",
                    "type": "generation",
                    "latency_ms": duration,
                    "processing_error": traceback.format_exc(),
                }
            TraceManager.end_span(span_id, output_block)

        # ------------------------------ async wrapper ------------------------------
        async def async_wrapper(*args, **kwargs):

            if not TraceManager.has_active_trace():
                return await func(*args, **kwargs)

            prompt = _extract_prompt(args, kwargs)
            span_id = _open_span(prompt)

            # Monotonic clock: wall-clock adjustments cannot skew the latency.
            started = time.perf_counter()

            # Only the wrapped call is guarded, so instrumentation problems
            # are never reported as failures of the call itself.
            try:
                response = await func(*args, **kwargs)
            except Exception as exc:
                _close_with_error(span_id, started, exc)
                raise

            _close_with_success(span_id, started, prompt, response)
            return response

        # ------------------------------ sync wrapper ------------------------------
        def sync_wrapper(*args, **kwargs):

            if not TraceManager.has_active_trace():
                return func(*args, **kwargs)

            prompt = _extract_prompt(args, kwargs)
            span_id = _open_span(prompt)

            # Monotonic clock: wall-clock adjustments cannot skew the latency.
            started = time.perf_counter()

            # Only the wrapped call is guarded, so instrumentation problems
            # are never reported as failures of the call itself.
            try:
                response = func(*args, **kwargs)
            except Exception as exc:
                _close_with_error(span_id, started, exc)
                raise

            _close_with_success(span_id, started, prompt, response)
            return response

        return functools.wraps(func)(async_wrapper if is_async else sync_wrapper)

    return decorator


# ============================================================
# Helper for LLM Response Processing
# ============================================================

def _process_llm_response(
    prompt: str,
    response: Dict[str, Any],
    latency_ms: int,
    *,
    input_cost_per_token: float = 0.0000015,
    output_cost_per_token: float = 0.000005,
) -> Dict[str, Any]:
    """
    Converts Bedrock response → Langfuse-compliant `generation` span format.

    Args:
        prompt: Prompt recorded under ``input.prompt`` and as the single
            user message.
        response: Raw provider response. ``"model"``, ``"usage"`` and
            ``output.message.content[0].text`` are read when present.
        latency_ms: Call duration in milliseconds, copied into the block.
        input_cost_per_token: Price applied to each input token.
        output_cost_per_token: Price applied to each output token.

    Returns:
        A dict with ``status``, ``type``, ``latency_ms``, ``input``,
        ``output``, ``usage_details``, ``cost_details`` and ``raw_response``.
        Missing or malformed response fields fall back to an empty text,
        zero token counts and the model name ``"unknown"``.
    """
    from collections.abc import Mapping

    # Be as tolerant for usage/model as the text lookup below already is:
    # anything that is not a mapping is treated as an empty response.
    payload = response if isinstance(response, Mapping) else {}

    # Extract output text
    try:
        text = response["output"]["message"]["content"][0]["text"]
    except Exception:
        text = ""

    # Extract usage; a None usage block or None counters count as zero.
    usage = payload.get("usage")
    if not isinstance(usage, Mapping):
        usage = {}
    input_tokens = usage.get("inputTokens") or 0
    output_tokens = usage.get("outputTokens") or 0
    total_tokens = input_tokens + output_tokens

    # Optional cost calculation
    input_cost = input_tokens * input_cost_per_token
    output_cost = output_tokens * output_cost_per_token
    total_cost = input_cost + output_cost

    model = payload.get("model", "unknown")

    return {
        "status": "success",
        "type": "generation",
        "latency_ms": latency_ms,
        "input": {
            "prompt": prompt,
            "model": model,
            "messages": [
                {"role": "user", "content": prompt}
            ],
        },
        "output": {
            "text": text,
            # NOTE(review): always reported as "stop"; the provider's actual
            # stop reason is not read from the response here.
            "finish_reason": "stop"
        },
        "usage_details": {
            "input": input_tokens,
            "output": output_tokens,
            "total": total_tokens,
        },
        "cost_details": {
            "input": round(input_cost, 6),
            "output": round(output_cost, 6),
            "total": round(total_cost, 6),
        },
        # Serialize the original object, not the tolerant stand-in.
        "raw_response": serialize_value(response),
    }
