import logging
import time
from typing import Optional

from opentelemetry._events import EventLogger
from .config import Config
from .event_emitter import (
    emit_streaming_response_events,
)
from .span_utils import (
    set_streaming_response_attributes,
)
from .utils import (
    count_prompt_tokens_from_request,
    dont_throw,
    error_metrics_attributes,
    set_span_attribute,
    shared_metrics_attributes,
    should_emit_events,
)
from opentelemetry.metrics import Counter, Histogram
from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import (
    GEN_AI_RESPONSE_ID,
)
from opentelemetry.semconv_ai import SpanAttributes
from opentelemetry.trace.status import Status, StatusCode

# Module-level logger; the stream builders below use it to report token-usage failures.
logger = logging.getLogger(__name__)


@dont_throw
def _process_response_item(item, complete_response):
    """Fold one streamed Anthropic event into the ``complete_response`` accumulator.

    ``complete_response`` is mutated in place; the builders in this module
    initialise it with the keys ``events`` (list), ``model``, ``usage`` (dict)
    and ``id``. Handled event types:

    * ``message_start``: record the model, the message id and the initial
      usage dict.
    * ``content_block_start``: append a record for the new content block
      (``tool_use`` blocks additionally get ``id``, ``name`` and an empty
      ``input`` string).
    * ``content_block_delta``: append thinking/text deltas to the block's
      ``text`` and partial tool-input JSON to its ``input``.
    * ``message_delta``: stamp the stop reason on every block and update the
      output token count.

    Any other event type is ignored.
    """
    events = complete_response.setdefault("events", [])

    if item.type == "message_start":
        complete_response["model"] = item.message.model
        complete_response["usage"] = dict(item.message.usage)
        complete_response["id"] = item.message.id

    elif item.type == "content_block_start":
        index = item.index
        if len(events) <= index:
            block_type = item.content_block.type
            # Build the record completely before appending it so the tool_use
            # fields always land on this block, never on a neighbour.
            block = {"index": index, "text": "", "type": block_type}
            if block_type == "tool_use":
                block["id"] = item.content_block.id
                block["name"] = item.content_block.name
                block["input"] = ""
            events.append(block)

    elif item.type == "content_block_delta":
        index = item.index
        if not 0 <= index < len(events):
            # Delta for a block whose start was never seen; nothing to extend.
            return
        block = events[index]
        delta = item.delta
        if delta.type == "thinking_delta":
            block["text"] += delta.thinking or ""
        elif delta.type == "text_delta":
            block["text"] += delta.text or ""
        elif delta.type == "input_json_delta":
            block["input"] = block.get("input", "") + (delta.partial_json or "")

    elif item.type == "message_delta":
        stop_reason = item.delta.stop_reason
        for block in events:
            block["finish_reason"] = stop_reason
        if item.usage:
            delta_usage = dict(item.usage)
            usage = complete_response.get("usage")
            if isinstance(usage, dict):
                # Usage reported on message_delta is cumulative for the whole
                # message, so it replaces (rather than adds to) the count that
                # message_start reported.
                output_tokens = delta_usage.get("output_tokens")
                if output_tokens is not None:
                    usage["output_tokens"] = output_tokens
            else:
                complete_response["usage"] = delta_usage


def _set_token_usage(
    span,
    complete_response,
    prompt_tokens,
    completion_tokens,
    metric_attributes: Optional[dict] = None,
    token_histogram: Optional[Histogram] = None,
    choice_counter: Optional[Counter] = None,
):
    """Record token-usage span attributes and metrics for a finished stream.

    Args:
        span: span receiving the usage, response-model and cache attributes.
        complete_response: accumulator built while streaming; its ``usage``
            dict supplies cache token counts, ``model`` the response model and
            ``events`` one choice per entry.
        prompt_tokens: prompt token count; cache read/creation tokens from
            ``usage`` are added on top to form the reported input count.
        completion_tokens: output token count. A negative value (the stream
            builders use -1) means "unknown": the completion and total
            attributes are then not written and no output sample is recorded.
        metric_attributes: base attributes copied onto every metric point.
        token_histogram: optional histogram for input/output token counts.
        choice_counter: optional counter incremented once per entry in
            ``complete_response["events"]``.
    """
    if metric_attributes is None:
        metric_attributes = {}
    usage = complete_response.get("usage") or {}
    cache_read_tokens = usage.get("cache_read_input_tokens", 0) or 0
    cache_creation_tokens = usage.get("cache_creation_input_tokens", 0) or 0

    input_tokens = prompt_tokens + cache_read_tokens + cache_creation_tokens
    completion_known = completion_tokens is not None and completion_tokens >= 0

    set_span_attribute(span, SpanAttributes.LLM_USAGE_PROMPT_TOKENS, input_tokens)
    if completion_known:
        # Only write completion/total when the completion count is real;
        # otherwise the -1 sentinel would leak in and skew the total.
        set_span_attribute(
            span, SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, completion_tokens
        )
        set_span_attribute(
            span,
            SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
            input_tokens + completion_tokens,
        )

    set_span_attribute(
        span, SpanAttributes.LLM_RESPONSE_MODEL, complete_response.get("model")
    )
    set_span_attribute(
        span, SpanAttributes.LLM_USAGE_CACHE_READ_INPUT_TOKENS, cache_read_tokens
    )
    set_span_attribute(
        span,
        SpanAttributes.LLM_USAGE_CACHE_CREATION_INPUT_TOKENS,
        cache_creation_tokens,
    )

    if token_histogram and type(input_tokens) is int and input_tokens >= 0:
        token_histogram.record(
            input_tokens,
            attributes={
                **metric_attributes,
                SpanAttributes.LLM_TOKEN_TYPE: "input",
            },
        )

    if token_histogram and type(completion_tokens) is int and completion_tokens >= 0:
        token_histogram.record(
            completion_tokens,
            attributes={
                **metric_attributes,
                SpanAttributes.LLM_TOKEN_TYPE: "output",
            },
        )

    events = complete_response.get("events")
    if choice_counter and isinstance(events, list):
        for event in events:
            choice_counter.add(
                1,
                attributes={
                    **metric_attributes,
                    SpanAttributes.LLM_RESPONSE_FINISH_REASON: event.get(
                        "finish_reason"
                    ),
                },
            )


def _handle_streaming_response(span, event_logger, complete_response):
    """Publish the accumulated stream content as events or as span attributes.

    With event emission enabled and an ``event_logger`` available, the
    content is handed to ``emit_streaming_response_events``. Otherwise the
    content blocks are written onto ``span``, provided it is still recording.
    """
    emit_as_events = should_emit_events() and event_logger
    if emit_as_events:
        emit_streaming_response_events(event_logger, complete_response)
        return

    if span.is_recording():
        set_streaming_response_attributes(span, complete_response.get("events"))


@dont_throw
def build_from_streaming_response(
    span,
    response,
    instance,
    start_time,
    token_histogram: Histogram = None,
    choice_counter: Counter = None,
    duration_histogram: Histogram = None,
    exception_counter: Counter = None,
    event_logger: Optional[EventLogger] = None,
    kwargs: Optional[dict] = None,
):
    """Pass a synchronous Anthropic event stream through while instrumenting it.

    Every item of ``response`` is yielded unchanged and folded into an
    accumulator with ``_process_response_item``. Once ``response`` is
    exhausted, the following happens in order:

    * the span gets the response id;
    * the duration histogram is recorded;
    * token and choice metrics are recorded, when
      ``Config.enrich_token_usage`` is set;
    * the accumulated content is published;
    * the span is ended with status OK.

    If iterating ``response`` raises, or an exception is thrown into this
    generator at a ``yield``, the exception is counted on
    ``exception_counter``, recorded on the span, the span is ended with
    status ERROR, and the exception is re-raised.

    Args:
        span: span started by the caller for this request.
        response: iterable of streamed Anthropic events.
        instance: the instrumented object; used to count prompt tokens and,
            if it has ``count_tokens``, completion tokens when the stream
            carried no usage block.
        start_time: timestamp comparable with ``time.time()``; the recorded
            duration is ``time.time() - start_time``.
        token_histogram: optional histogram for input/output token counts.
        choice_counter: optional counter of content blocks per finish reason.
        duration_histogram: optional histogram for the request duration.
        exception_counter: optional counter of stream failures.
        event_logger: optional logger used when content is emitted as events.
        kwargs: original request kwargs, used for prompt-token counting when
            the stream carried no usage block. ``None`` means no kwargs.
    """
    if kwargs is None:
        kwargs = {}
    complete_response = {"events": [], "model": "", "usage": {}, "id": ""}
    try:
        for item in response:
            yield item
            _process_response_item(item, complete_response)
    except Exception as e:
        # Covers failures of the underlying stream as well as exceptions
        # thrown into the generator by the consumer; without this the span
        # would never be ended.
        if exception_counter:
            exception_counter.add(1, attributes=error_metrics_attributes(e))
        if span.is_recording():
            span.record_exception(e)
            span.set_status(Status(StatusCode.ERROR, str(e)))
            span.end()
        raise

    metric_attributes = shared_metrics_attributes(complete_response)
    set_span_attribute(span, GEN_AI_RESPONSE_ID, complete_response.get("id"))
    if duration_histogram:
        duration_histogram.record(
            time.time() - start_time,
            attributes=metric_attributes,
        )

    if Config.enrich_token_usage:
        try:
            usage = complete_response.get("usage")
            if usage:
                prompt_tokens = usage.get("input_tokens", 0) or 0
                completion_tokens = usage.get("output_tokens", 0) or 0
            else:
                # No usage block in the stream: fall back to counting locally.
                prompt_tokens = count_prompt_tokens_from_request(instance, kwargs)
                # -1 marks the completion count as unknown for _set_token_usage.
                completion_tokens = -1
                events = complete_response.get("events")
                if (
                    events
                    and complete_response.get("model")
                    and hasattr(instance, "count_tokens")
                ):
                    completion_content = "".join(
                        event.get("text") or "" for event in events
                    )
                    completion_tokens = instance.count_tokens(completion_content)

            _set_token_usage(
                span,
                complete_response,
                prompt_tokens,
                completion_tokens,
                metric_attributes,
                token_histogram,
                choice_counter,
            )
        except Exception as e:
            logger.warning("Failed to set token usage, error: %s", e)

    _handle_streaming_response(span, event_logger, complete_response)

    if span.is_recording():
        span.set_status(Status(StatusCode.OK))
        span.end()


@dont_throw
async def abuild_from_streaming_response(
    span,
    response,
    instance,
    start_time,
    token_histogram: Histogram = None,
    choice_counter: Counter = None,
    duration_histogram: Histogram = None,
    exception_counter: Counter = None,
    event_logger: Optional[EventLogger] = None,
    kwargs: Optional[dict] = None,
):
    """Pass an asynchronous Anthropic event stream through while instrumenting it.

    Every item of ``response`` is yielded unchanged and folded into an
    accumulator with ``_process_response_item``. Once ``response`` is
    exhausted, the following happens in order:

    * the span gets the response id;
    * the duration histogram is recorded;
    * token and choice metrics are recorded, when
      ``Config.enrich_token_usage`` is set;
    * the accumulated content is published;
    * the span is ended with status OK.

    If iterating ``response`` raises, or an exception is thrown into this
    async generator at a ``yield``, the exception is counted on
    ``exception_counter``, recorded on the span, the span is ended with
    status ERROR, and the exception is re-raised.

    Args:
        span: span started by the caller for this request.
        response: async iterable of streamed Anthropic events.
        instance: the instrumented object; used to count prompt tokens and,
            if it has ``count_tokens``, completion tokens when the stream
            carried no usage block.
        start_time: timestamp comparable with ``time.time()``; the recorded
            duration is ``time.time() - start_time``.
        token_histogram: optional histogram for input/output token counts.
        choice_counter: optional counter of content blocks per finish reason.
        duration_histogram: optional histogram for the request duration.
        exception_counter: optional counter of stream failures.
        event_logger: optional logger used when content is emitted as events.
        kwargs: original request kwargs, used for prompt-token counting when
            the stream carried no usage block. ``None`` means no kwargs.
    """
    if kwargs is None:
        kwargs = {}
    complete_response = {"events": [], "model": "", "usage": {}, "id": ""}
    try:
        async for item in response:
            yield item
            _process_response_item(item, complete_response)
    except Exception as e:
        # Covers failures of the underlying stream as well as exceptions
        # thrown into the generator by the consumer; without this the span
        # would never be ended.
        if exception_counter:
            exception_counter.add(1, attributes=error_metrics_attributes(e))
        if span.is_recording():
            span.record_exception(e)
            span.set_status(Status(StatusCode.ERROR, str(e)))
            span.end()
        raise

    set_span_attribute(span, GEN_AI_RESPONSE_ID, complete_response.get("id"))

    metric_attributes = shared_metrics_attributes(complete_response)

    if duration_histogram:
        duration_histogram.record(
            time.time() - start_time,
            attributes=metric_attributes,
        )

    if Config.enrich_token_usage:
        try:
            usage = complete_response.get("usage")
            if usage:
                # `or 0` guards against explicit None counts in the usage dict.
                prompt_tokens = usage.get("input_tokens", 0) or 0
                completion_tokens = usage.get("output_tokens", 0) or 0
            else:
                # No usage block in the stream: fall back to counting locally.
                prompt_tokens = count_prompt_tokens_from_request(instance, kwargs)
                # -1 marks the completion count as unknown; previously this
                # name could be left unbound here, silently dropping all usage.
                completion_tokens = -1
                events = complete_response.get("events")
                if (
                    events
                    and complete_response.get("model")
                    and hasattr(instance, "count_tokens")
                ):
                    completion_content = "".join(
                        event.get("text") or "" for event in events
                    )
                    completion_tokens = instance.count_tokens(completion_content)

            _set_token_usage(
                span,
                complete_response,
                prompt_tokens,
                completion_tokens,
                metric_attributes,
                token_histogram,
                choice_counter,
            )
        except Exception as e:
            logger.warning("Failed to set token usage, error: %s", str(e))

    _handle_streaming_response(span, event_logger, complete_response)

    if span.is_recording():
        span.set_status(Status(StatusCode.OK))
        span.end()


class WrappedMessageStreamManager:
    """Wrapper for MessageStreamManager that handles instrumentation.

    ``__enter__`` enters the wrapped manager and returns the generator built
    by ``build_from_streaming_response`` around the stream it produced, so
    iterating it records telemetry on ``span``.

    ``__exit__`` delegates to the wrapped manager and then ends ``span`` if
    it is still recording. That is the case when the ``with`` body left
    before the stream was fully consumed, or raised.

    NOTE(review): callers receive a plain generator rather than the object
    the wrapped manager's ``__enter__`` returned, so any helper methods on
    that object are not reachable through this wrapper. Confirm that this is
    intended.
    """

    def __init__(
        self,
        stream_manager,
        span,
        instance,
        start_time,
        token_histogram,
        choice_counter,
        duration_histogram,
        exception_counter,
        event_logger,
        kwargs,
    ):
        self._stream_manager = stream_manager
        self._span = span
        self._instance = instance
        self._start_time = start_time
        self._token_histogram = token_histogram
        self._choice_counter = choice_counter
        self._duration_histogram = duration_histogram
        self._exception_counter = exception_counter
        self._event_logger = event_logger
        self._kwargs = kwargs

    def __enter__(self):
        try:
            # Enter the original manager to obtain the actual stream.
            stream = self._stream_manager.__enter__()
        except Exception as e:
            # No stream generator will exist to end the span, so end it here.
            self._end_span_if_open(e)
            raise
        return build_from_streaming_response(
            self._span,
            stream,
            self._instance,
            self._start_time,
            self._token_histogram,
            self._choice_counter,
            self._duration_histogram,
            self._exception_counter,
            self._event_logger,
            self._kwargs,
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            return self._stream_manager.__exit__(exc_type, exc_val, exc_tb)
        finally:
            self._end_span_if_open(exc_val)

    def _end_span_if_open(self, error=None):
        """End the span unless it has already been ended.

        The stream generator ends the span itself once the stream has been
        fully consumed, so a span still recording here was abandoned early or
        failed. If ``error`` is given it is recorded and the status is ERROR.
        """
        if not self._span.is_recording():
            return
        if error is not None:
            self._span.record_exception(error)
            self._span.set_status(Status(StatusCode.ERROR, str(error)))
        self._span.end()


class WrappedAsyncMessageStreamManager:
    """Wrapper for AsyncMessageStreamManager that handles instrumentation.

    ``__aenter__`` enters the wrapped manager and returns the async generator
    built by ``abuild_from_streaming_response`` around the stream it
    produced, so iterating it records telemetry on ``span``.

    ``__aexit__`` delegates to the wrapped manager and then ends ``span`` if
    it is still recording. That is the case when the ``async with`` body left
    before the stream was fully consumed, or raised.

    NOTE(review): callers receive a plain async generator rather than the
    object the wrapped manager's ``__aenter__`` returned, so any helper
    methods on that object are not reachable through this wrapper. Confirm
    that this is intended.
    """

    def __init__(
        self,
        stream_manager,
        span,
        instance,
        start_time,
        token_histogram,
        choice_counter,
        duration_histogram,
        exception_counter,
        event_logger,
        kwargs,
    ):
        self._stream_manager = stream_manager
        self._span = span
        self._instance = instance
        self._start_time = start_time
        self._token_histogram = token_histogram
        self._choice_counter = choice_counter
        self._duration_histogram = duration_histogram
        self._exception_counter = exception_counter
        self._event_logger = event_logger
        self._kwargs = kwargs

    async def __aenter__(self):
        try:
            # Enter the original manager to obtain the actual stream.
            stream = await self._stream_manager.__aenter__()
        except Exception as e:
            # No stream generator will exist to end the span, so end it here.
            self._end_span_if_open(e)
            raise
        return abuild_from_streaming_response(
            self._span,
            stream,
            self._instance,
            self._start_time,
            self._token_histogram,
            self._choice_counter,
            self._duration_histogram,
            self._exception_counter,
            self._event_logger,
            self._kwargs,
        )

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            return await self._stream_manager.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            self._end_span_if_open(exc_val)

    def _end_span_if_open(self, error=None):
        """End the span unless it has already been ended.

        The stream generator ends the span itself once the stream has been
        fully consumed, so a span still recording here was abandoned early or
        failed. If ``error`` is given it is recorded and the status is ERROR.
        """
        if not self._span.is_recording():
            return
        if error is not None:
            self._span.record_exception(error)
            self._span.set_status(Status(StatusCode.ERROR, str(error)))
        self._span.end()
