import math
import random

from typing import TYPE_CHECKING

from pydantic import BaseModel, ConfigDict, Field

from .constants import (
    API_ERROR_INDICATORS,
    DEFAULT_MAX_RETRIES,
    DEFAULT_REQUEST_TIMEOUT,
    ENGINE_DEFAULT_BATCH_SIZE,
    ENGINE_DEFAULT_NUM_EXAMPLES,
    ENGINE_DEFAULT_TEMPERATURE,
    ERROR_CATEGORIES,
    ERROR_DATASET_FILENAME,
    INTERRUPTED_DATASET_FILENAME,
)
from .dataset import Dataset
from .exceptions import DataSetGeneratorError
from .llm import LLMClient
from .prompts import CONVERSATION_GENERATION_PROMPT
from .schemas import ChatTranscript
from .topic_model import TopicModel

# Handle circular import for type hints
if TYPE_CHECKING:
    from .topic_model import TopicModel


class DataSetGeneratorConfig(BaseModel):
    """Validated settings consumed by ``DataSetGenerator``.

    Each setting is declared with a pydantic ``Field`` carrying its default,
    its bounds and a human-readable description.  Only
    ``generation_system_prompt``, ``provider`` and ``model_name`` are required
    (declared with ``...``); every other field has a default.
    """

    # NOTE(review): presumably enabled so the project-local ``Dataset`` type can
    # be used for ``example_data`` below - confirm.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # --- Prompting -----------------------------------------------------------
    # Optional free-form instructions; an empty string means "none".
    instructions: str = Field(default="", description="Additional instructions for data generation")
    # Required, non-empty prompt used while generating content.
    generation_system_prompt: str = Field(
        ..., min_length=1, description="System prompt for content generation"
    )
    # Optional prompt stored in the produced dataset.
    dataset_system_prompt: str | None = Field(
        None,
        description="System prompt that goes into the final dataset (falls back to generation_system_prompt if not provided)",
    )

    # --- Model selection -----------------------------------------------------
    # Both values are required and must be non-empty.
    provider: str = Field(
        ..., min_length=1, description="LLM provider (openai, anthropic, gemini, ollama)"
    )
    model_name: str = Field(..., min_length=1, description="Name of the model to use")
    prompt_template: str | None = Field(default=None, description="Custom prompt template")
    example_data: Dataset | None = Field(
        default=None, description="Example dataset for few-shot learning"
    )

    # --- Generation tuning (all range-checked on validation) ------------------
    # Allowed range: 0.0 to 2.0 inclusive.
    temperature: float = Field(
        default=ENGINE_DEFAULT_TEMPERATURE,
        ge=0.0,
        le=2.0,
        description="Temperature for model generation",
    )
    # Allowed range: 1 to 10 inclusive.
    max_retries: int = Field(
        default=DEFAULT_MAX_RETRIES,
        ge=1,
        le=10,
        description="Maximum number of retries for failed requests",
    )
    # Allowed range: 1 to 100 inclusive.
    default_batch_size: int = Field(
        default=ENGINE_DEFAULT_BATCH_SIZE,
        ge=1,
        le=100,
        description="Default batch size for generation",
    )
    # Allowed range: 0 to 10 inclusive.
    default_num_examples: int = Field(
        default=ENGINE_DEFAULT_NUM_EXAMPLES,
        ge=0,
        le=10,
        description="Default number of examples to include",
    )
    # Allowed range: 5 to 300 seconds inclusive.
    request_timeout: int = Field(
        default=DEFAULT_REQUEST_TIMEOUT, ge=5, le=300, description="Request timeout in seconds"
    )
    # Default for whether generated samples get a leading system message.
    sys_msg: bool = Field(default=True, description="Whether to include system message in dataset")


class DataSetGenerator:
    def __init__(self, **kwargs):
        """Initialize DataSetGenerator with parameters."""
        try:
            self.config = DataSetGeneratorConfig.model_validate(kwargs)
        except Exception as e:
            raise DataSetGeneratorError(f"Invalid generator configuration: {str(e)}") from e  # noqa: TRY003

        # Initialize from config
        self.provider = self.config.provider
        self.model_name = self.config.model_name
        self.dataset = Dataset()
        self.failed_samples = []
        self.failure_analysis = {category: [] for category in ERROR_CATEGORIES}

        # Initialize LLM client
        self.llm_client = LLMClient(
            provider=self.provider,
            model_name=self.model_name,
        )

        # Store dataset system prompt for dataset inclusion (with fallback)
        self.dataset_system_prompt = (
            self.config.dataset_system_prompt or self.config.generation_system_prompt
        )
        # Store generation prompt for content generation
        self.generation_prompt = self.config.generation_system_prompt

    def _validate_create_data_params(
        self,
        num_steps: int,
        batch_size: int,
        topic_model: "TopicModel | None" = None,
    ) -> None:
        """Validate parameters for data creation."""
        if num_steps is None or num_steps <= 0:
            raise DataSetGeneratorError("positive")

        if batch_size <= 0:
            raise DataSetGeneratorError("positive")

        if topic_model and len(topic_model.get_all_paths()) == 0:
            raise DataSetGeneratorError("")

    def _prepare_topic_paths(
        self,
        num_steps: int,
        batch_size: int,
        topic_model: "TopicModel | None" = None,
    ) -> tuple[list | None, int]:
        """Prepare and validate topic paths for data generation."""
        topic_paths = None
        if topic_model is not None:
            topic_paths = topic_model.get_all_paths()
            total_paths = len(topic_paths)
            required_samples = num_steps * batch_size

            if required_samples > total_paths:
                # Provide detailed error with recommendations
                max_steps_for_batch = total_paths // batch_size
                max_batch_for_steps = total_paths // num_steps if num_steps > 0 else total_paths

                error_msg = (
                    f"Insufficient topic paths for dataset generation:\n"
                    f"  • Available paths: {total_paths}\n"
                    f"  • Requested samples: {required_samples} ({num_steps} steps × {batch_size} batch size)\n"
                    f"  • Shortfall: {required_samples - total_paths} samples\n\n"
                    f"Recommendations:\n"
                    f"  • Reduce --num-steps to {max_steps_for_batch} (with current batch size {batch_size})\n"
                    f"  • Reduce --batch-size to {max_batch_for_steps} (with current {num_steps} steps)\n"
                    f"  • Increase topic tree/graph depth or degree to generate more paths"
                )
                raise DataSetGeneratorError(error_msg)

            # Bandit: not a security function
            topic_paths = random.sample(topic_paths, required_samples)  # nosec
            num_steps = math.ceil(len(topic_paths) / batch_size)

        return topic_paths, num_steps

    def _generate_batch_prompts(
        self,
        batch_size: int,
        start_idx: int,
        topic_paths: list,
        data_creation_prompt: str,
        num_example_demonstrations: int,
    ) -> list[str]:
        """Generate prompts for a batch."""
        prompts = []
        for i in range(batch_size):
            path = None
            if topic_paths:
                current_idx = start_idx + i
                if current_idx < len(topic_paths):
                    path = topic_paths[current_idx]
                else:
                    break

            sample_prompt = self.build_prompt(
                data_creation_prompt=data_creation_prompt,
                num_example_demonstrations=num_example_demonstrations,
                subtopics_list=path,
            )
            prompts.append(sample_prompt)
        return prompts

    def _generate_structured_samples(
        self,
        prompts: list[str],
        include_sys_msg: bool,
    ) -> tuple[list, list]:
        """Generate structured samples using Outlines."""
        samples = []
        failed_responses = []

        for prompt in prompts:
            try:
                # Generate structured conversation using ChatTranscript schema
                conversation = self.llm_client.generate(
                    prompt=prompt,
                    schema=ChatTranscript,
                    max_retries=self.config.max_retries,
                    max_tokens=2000,
                    temperature=self.config.temperature,
                )

                # Convert Pydantic model to dict
                sample = conversation.model_dump()

                # Add system message at the start if sys_msg is True
                if include_sys_msg:
                    sample["messages"].insert(
                        0,
                        {
                            "role": "system",
                            "content": self.dataset_system_prompt,
                        },
                    )

                samples.append(sample)

            except Exception as e:
                error_msg = f"Generation failed: {str(e)}"
                failed_responses.append(error_msg)
                failure_type = self.analyze_failure(str(e), error=e)
                self.failure_analysis[failure_type].append(error_msg)

        return samples, failed_responses

    def analyze_failure(self, response_content: str, error: Exception | None = None) -> str:
        """Analyze the failure reason for a sample."""
        if error:
            error_str = str(error)
            if "schema" in error_str.lower():
                return "invalid_schema"
            if any(api_err in error_str.lower() for api_err in API_ERROR_INDICATORS):
                return "api_errors"
            return "other_errors"

        if not response_content or response_content.isspace():
            return "empty_responses"

        # Check if response seems to be attempting JSON but failing
        if any(char in response_content for char in "{}[]"):
            return "json_parsing_errors"
        return "malformed_responses"

    def summarize_failures(self) -> dict:
        """Generate a summary of all failures."""
        summary = {
            "total_failures": len(self.failed_samples),
            "failure_types": {k: len(v) for k, v in self.failure_analysis.items()},
            "failure_examples": {},
        }

        # Add example failures for each category
        for _category, failures in self.failure_analysis.items():
            if failures:
                # Get up to 3 examples for each category
                examples = failures[:3]
                summary["failure_examples"] = [
                    (
                        str(ex)[:200] + "..."
                        if len(str(ex)) > 200  # noqa: PLR2004
                        else str(ex)  # noqa: PLR2004
                    )  # noqa: PLR2004
                    for ex in examples
                ]
        return summary

    def create_data(
        self,
        num_steps: int | None = None,
        num_example_demonstrations: int = 3,
        batch_size: int = 10,
        topic_model: TopicModel | None = None,
        model_name: str | None = None,
        sys_msg: bool | None = None,
    ):
        # Set default value for num_steps if None
        if num_steps is None:
            num_steps = 1

        # Validate inputs
        self._validate_create_data_params(num_steps, batch_size, topic_model)

        # Use instance model_name as fallback if none provided
        if model_name:
            self.model_name = model_name.strip()

        if not self.model_name:
            raise DataSetGeneratorError("")

        # Use provided sys_msg or fall back to config.sys_msg
        include_sys_msg = sys_msg if sys_msg is not None else self.config.sys_msg

        # Prepare topic paths and adjust num_steps if necessary
        topic_paths, num_steps = self._prepare_topic_paths(num_steps, batch_size, topic_model)

        total_samples = num_steps * batch_size
        data_creation_prompt = CONVERSATION_GENERATION_PROMPT

        # Use generator pattern for progress events (no TUI dependencies)
        generator = self._run_generation_loop(
            num_steps=num_steps,
            batch_size=batch_size,
            total_samples=total_samples,
            topic_paths=topic_paths or [],
            data_creation_prompt=data_creation_prompt,
            num_example_demonstrations=num_example_demonstrations,
            include_sys_msg=include_sys_msg,
        )

        # Consume the generator and return the final dataset
        final_result = None
        for event in generator:
            final_result = event

        return final_result

    def create_data_with_events(
        self,
        num_steps: int | None = None,
        num_example_demonstrations: int = 3,
        batch_size: int = 10,
        topic_model: TopicModel | None = None,
        model_name: str | None = None,
        sys_msg: bool | None = None,
    ):
        """Create dataset yielding progress events (for TUI integration)."""
        # Set default value for num_steps if None
        if num_steps is None:
            num_steps = 1

        # Validate inputs
        self._validate_create_data_params(num_steps, batch_size, topic_model)

        # Use instance model_name as fallback if none provided
        if model_name:
            self.model_name = model_name.strip()

        if not self.model_name:
            raise DataSetGeneratorError("")

        # Use provided sys_msg or fall back to config.sys_msg
        include_sys_msg = sys_msg if sys_msg is not None else self.config.sys_msg

        # Prepare topic paths and adjust num_steps if necessary
        topic_paths, num_steps = self._prepare_topic_paths(num_steps, batch_size, topic_model)

        # Use instance model_name as fallback if none provided
        if model_name:
            self.model_name = model_name.strip()

        if not self.model_name:
            raise DataSetGeneratorError("")

        # Use provided sys_msg or fall back to config.sys_msg
        include_sys_msg = sys_msg if sys_msg is not None else self.config.sys_msg

        # Prepare topic paths and adjust num_steps if necessary
        topic_paths, num_steps = self._prepare_topic_paths(num_steps, batch_size, topic_model)

        total_samples = num_steps * batch_size
        data_creation_prompt = CONVERSATION_GENERATION_PROMPT

        # Yield from the generation loop
        yield from self._run_generation_loop(
            num_steps=num_steps,
            batch_size=batch_size,
            total_samples=total_samples,
            topic_paths=topic_paths or [],
            data_creation_prompt=data_creation_prompt,
            num_example_demonstrations=num_example_demonstrations,
            include_sys_msg=include_sys_msg,
        )

    def _run_generation_loop(  # noqa: PLR0912
        self,
        num_steps: int,
        batch_size: int,
        total_samples: int,
        topic_paths: list,
        data_creation_prompt: str,
        num_example_demonstrations: int,
        include_sys_msg: bool,
    ):
        """Run the main generation loop yielding progress events."""
        try:
            # Yield start event
            yield {
                "event": "generation_start",
                "model_name": self.model_name,
                "num_steps": num_steps,
                "batch_size": batch_size,
                "total_samples": total_samples,
            }

            for step in range(num_steps):
                yield {"event": "step_start", "step": step + 1, "total_steps": num_steps}

                start_idx = step * batch_size
                prompts = self._generate_batch_prompts(
                    batch_size,
                    start_idx,
                    topic_paths,
                    data_creation_prompt,
                    num_example_demonstrations,
                )

                success, samples_generated = self._process_batch_with_retries(
                    prompts, include_sys_msg
                )

                yield {
                    "event": "step_complete",
                    "step": step + 1,
                    "samples_generated": samples_generated,
                    "success": success,
                }

                if not success:
                    yield {
                        "event": "step_failed",
                        "step": step + 1,
                        "message": f"Failed to process batch {step + 1} after all retries",
                    }

            yield {
                "event": "generation_complete",
                "total_samples": len(self.dataset),
                "failed_samples": len(self.failed_samples),
            }

        except KeyboardInterrupt:
            yield {"event": "generation_interrupted", "message": "Generation interrupted by user."}
            self.print_failure_summary()
            self.save_dataset(INTERRUPTED_DATASET_FILENAME)

        except Exception as e:
            yield {"event": "generation_error", "error": str(e)}
            self.print_failure_summary()
            self.save_dataset(ERROR_DATASET_FILENAME)
            raise DataSetGeneratorError("failed") from e

        # Always return the dataset as the final result
        yield self.dataset

    def _process_batch_with_retries(
        self,
        prompts: list[str],
        include_sys_msg: bool,
    ) -> tuple[bool, int]:
        """Process a batch with retry logic."""
        for attempt in range(self.config.max_retries):
            try:
                samples, failed_responses = self._generate_structured_samples(
                    prompts, include_sys_msg
                )

                # Update failed samples
                self.failed_samples.extend(failed_responses)

                if samples:
                    failed_samples, failure_descriptions = self.dataset.add_samples(samples)
                    if failed_samples:
                        for sample, desc in zip(failed_samples, failure_descriptions, strict=True):
                            self.failed_samples.append(sample)
                            self.failure_analysis["invalid_schema"].append(desc)

                    successful_samples = len(samples) - len(failed_samples)

                    return True, successful_samples  # Success - exit retry loop
            except DataSetGeneratorError as e:
                # Authentication and API errors are now wrapped in DataSetGeneratorError
                error_str = str(e).lower()
                if any(
                    keyword in error_str
                    for keyword in ["api_key", "api key", "authentication", "unauthorized"]
                ):
                    error_msg = f"Authentication failed for provider '{self.provider}'. Please set the required API key environment variable."
                    self.failure_analysis["authentication_error"].append(error_msg)
                else:
                    error_msg = f"API error for provider '{self.provider}': {str(e)[:100]}..."
                    self.failure_analysis["api_errors"].append(error_msg)

                self.failed_samples.append(error_msg)

                print(f"Error: {error_msg}")

                return False, 0  # Don't retry authentication/API errors
            except Exception as e:
                if attempt == self.config.max_retries - 1:
                    self.failed_samples.append(str(e))
                    failure_type = self.analyze_failure(str(e), error=e)
                    self.failure_analysis[failure_type].append(str(e))
                    return False, 0

        return False, 0

    def print_failure_summary(self):
        """Print a detailed summary of all failures."""
        summary = self.summarize_failures()

        print("\n=== Failure Analysis Summary ===")
        print(f"Total Failed Samples: {summary['total_failures']}")
        print("\nFailure Types Breakdown:")
        for failure_type, count in summary["failure_types"].items():
            if count > 0:
                print(f"\n{failure_type.replace('_', ' ').title()}: {count}")
                if failure_type in summary["failure_examples"]:
                    print("Example failures:")
                    for i, example in enumerate(summary["failure_examples"][failure_type], 1):
                        print(f"  {i}. {example}")
        print("\n=============================")

    def build_prompt(
        self,
        data_creation_prompt: str,
        num_example_demonstrations: int,
        subtopics_list: list[str] | None = None,
    ) -> str:
        prompt = data_creation_prompt.replace("{{{{system_prompt}}}}", self.generation_prompt)
        prompt = prompt.replace("{{{{instructions}}}}", self.build_custom_instructions_text())
        prompt = prompt.replace(
            "{{{{examples}}}}", self.build_examples_text(num_example_demonstrations)
        )
        return prompt.replace("{{{{subtopics}}}}", self.build_subtopics_text(subtopics_list))

    def build_system_prompt(self):
        """Return the original system prompt for dataset inclusion."""
        return self.dataset_system_prompt

    def build_custom_instructions_text(self) -> str:
        if self.config.instructions is None or self.config.instructions == "":
            return ""
        return f"\nHere are additional instructions:\n<instructions>\n{self.config.instructions}\n</instructions>\n"

    def build_examples_text(self, num_example_demonstrations: int):
        if self.config.example_data is None or num_example_demonstrations == 0:
            return ""
        # Bandit: not a security function
        examples = random.sample(self.config.example_data.samples, num_example_demonstrations)  # nosec
        examples_text = "Here are output examples:\n\n"
        examples_text += "\n".join(f"Example {i + 1}: \n\n{ex}\n" for i, ex in enumerate(examples))
        return f"\nHere are output examples:\n<examples>\n{examples_text}\n</examples>\n"

    def build_subtopics_text(self, subtopic_list: list[str] | None):
        if subtopic_list is None:
            return ""
        return f"\nLastly, the topic of the training data should be related to the following subtopics: {' -> '.join(subtopic_list)}"

    def save_dataset(self, save_path: str):
        """Save the dataset to a file."""
        self.dataset.save(save_path)
