import asyncio
import logging
import os
import time
from typing import List

from litellm import acompletion
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall

from eval_protocol.dataset_logger import default_logger
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig

logger = logging.getLogger(__name__)


class SingleTurnRolloutProcessor(RolloutProcessor):
    """Single turn rollout processor for direct LLM calls.

    For every row, the row's messages are sent once to LiteLLM's
    ``acompletion`` and the assistant reply (text content plus any tool
    calls) is appended to ``row.messages``.
    """

    @staticmethod
    def _quiet_litellm_logging() -> None:
        """Default ``LITELLM_LOG`` to ``ERROR`` when unset and silence the ``LiteLLM`` logger."""
        try:
            # Only set the env var when the user has not chosen a level themselves.
            os.environ.setdefault("LITELLM_LOG", "ERROR")
            litellm_logger = logging.getLogger("LiteLLM")
            litellm_logger.setLevel(logging.CRITICAL)
            litellm_logger.propagate = False
            # Iterate over a copy: removeHandler mutates the handlers list.
            for handler in list(litellm_logger.handlers):
                litellm_logger.removeHandler(handler)
        except Exception:
            # Best-effort only: a failure to quiet logs must never abort a rollout,
            # but leave a trace instead of swallowing it silently.
            logging.getLogger(__name__).debug("Failed to quiet LiteLLM logging", exc_info=True)

    @staticmethod
    def _build_request_params(row: EvaluationRow, completion_params: dict) -> dict:
        """Build the keyword arguments for one ``acompletion`` call.

        Args:
            row: The row whose ``messages`` (role/content only) and optional
                ``tools`` are sent.
            completion_params: User-supplied completion parameters. This mapping
                and its nested ``extra_body`` are never modified.

        Returns:
            A new dict containing ``messages``, every entry of
            ``completion_params``, a per-request ``cache`` override and, when a
            reasoning effort is configured, an ``extra_body`` carrying it.
        """
        request_params = {
            "messages": [{"role": m.role, "content": m.content} for m in row.messages],
            **completion_params,
        }
        # Do not modify global LiteLLM cache. Disable caching per-request instead.
        request_params["cache"] = {"no-cache": True}

        configured_extra_body = completion_params.get("extra_body")

        # Single-level reasoning effort: the flat top-level key wins, otherwise
        # accept a value the user placed directly inside extra_body.
        raw_effort = None
        if "reasoning_effort" in completion_params:
            raw_effort = completion_params["reasoning_effort"]
        elif isinstance(configured_extra_body, dict) and "reasoning_effort" in configured_extra_body:
            raw_effort = configured_extra_body["reasoning_effort"]

        # An explicit None means "unset"; str(None) would otherwise send "None".
        effort_val = None if raw_effort is None else str(raw_effort)

        if effort_val:
            # Build a fresh extra_body: the spread above is a shallow copy, so
            # writing into the existing dict would leak into the shared config.
            extra_body = dict(configured_extra_body) if configured_extra_body else {}
            # Always under extra_body so LiteLLM forwards to provider-specific param set
            extra_body["reasoning_effort"] = effort_val
            request_params["extra_body"] = extra_body
            # Ensure the top-level key is not sent alongside the extra_body one.
            request_params.pop("reasoning_effort", None)

        if row.tools is not None:
            request_params["tools"] = row.tools

        return request_params

    @staticmethod
    def _convert_tool_calls(tool_calls):
        """Convert response tool calls to ``ChatCompletionMessageToolCall`` objects.

        Returns ``None`` when ``tool_calls`` is empty or ``None``.
        """
        if not tool_calls:
            return None
        return [
            ChatCompletionMessageToolCall(
                id=tool_call.id,
                type=tool_call.type,
                function={
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                },
            )
            for tool_call in tool_calls
        ]

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Generate single turn rollout tasks and return them for external handling.

        Args:
            rows: Rows to roll out; each must have at least one message.
            config: Supplies ``completion_params`` and, optionally,
                ``max_concurrent_rollouts`` (falls back to 8 when missing or falsy).

        Returns:
            One ``asyncio.Task`` per input row, in input order. Each task
            resolves to the same row object with the assistant message appended.
            A task raises ``ValueError`` if its row has no messages.

        Note:
            Tasks are created with ``asyncio.create_task``, so this must be
            called while an event loop is running.
        """
        # Quiet LiteLLM logs in test runs unless user overrode
        self._quiet_litellm_logging()

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Process a single row asynchronously."""
            if not row.messages:
                raise ValueError("Messages is empty. Please provide a non-empty dataset")

            request_params = self._build_request_params(row, config.completion_params)

            # Look up acompletion on the litellm module at call time rather than
            # relying on a name bound at import time.
            import importlib

            litellm_module = importlib.import_module("litellm")
            response = await litellm_module.acompletion(**request_params)

            reply = response.choices[0].message
            row.messages = list(row.messages) + [
                Message(
                    role="assistant",
                    content=reply.content or "",
                    tool_calls=self._convert_tool_calls(reply.tool_calls),
                )
            ]
            default_logger.log(row)
            return row

        # Process rows with bounded concurrency
        max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
        semaphore = asyncio.Semaphore(max_concurrent)

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            """Run ``process_row`` while holding a semaphore slot."""
            async with semaphore:
                return await process_row(r)

        # Create and return tasks for external handling
        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]
