import threading
import itertools
import traceback
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm

from synthegrator.code_problems import CodeProblem, CodeSolution
from synthegrator.code_solver import BaseCodeSolver
from synthegrator.solution_eval import SolutionEvaluation, SolutionEvaluator


class SolveThread(threading.Thread):
    """Unit of work that runs one solver against one problem.

    The class derives from ``threading.Thread``, but ``run`` also returns its
    result, so it can be invoked directly or handed to an executor as a
    plain callable.
    """

    def __init__(
        self,
        solver: BaseCodeSolver,
        problem: CodeProblem,
    ) -> None:
        """Store the solver and the problem that ``run`` will pass to it."""
        super().__init__()
        self.problem = problem
        self.solver = solver

    def run(self) -> CodeSolution | None:
        """Return ``solver.solve(problem)``, or None if the solver raised.

        Any ``Exception`` from the solver is reported on stdout (a header
        line, the exception itself, then the formatted traceback) and is not
        propagated.
        """
        try:
            return self.solver.solve(self.problem)
        except Exception as error:
            # We don't want one thread failure to hang the experiment
            print("Solver exception", error, traceback.format_exc(), sep="\n")
        return None


class EvaluateThread(threading.Thread):
    """Unit of work that evaluates one solution.

    The class derives from ``threading.Thread``, but ``run`` also returns its
    result, so it can be invoked directly or handed to an executor as a
    plain callable.
    """

    def __init__(
        self,
        solution: CodeSolution,
        evaluator: SolutionEvaluator = None,
    ) -> None:
        """Store the solution and choose the evaluator.

        When no evaluator is given and a solution is present, the evaluator
        is taken from ``solution.problem.preferred_solution_evaluator``.
        """
        super().__init__()
        self.solution = solution
        use_preferred = evaluator is None and solution is not None
        self.evaluator = (
            solution.problem.preferred_solution_evaluator
            if use_preferred
            else evaluator
        )

    def run(self) -> SolutionEvaluation:
        """Evaluate the stored solution.

        Returns whatever ``evaluator.evaluate(solution)`` returns. If that
        call raises, or if the solution is None, the failure is packaged
        into a ``SolutionEvaluation`` carrying the exception and its
        traceback. The one exception is an error whose text contains
        "Too many open files", which is re-raised as a plain ``Exception``
        with an explanatory message.
        """
        try:
            if self.solution is None:
                msg = "Solution is None"
                raise ValueError(msg)
            return self.evaluator.evaluate(self.solution)
        except Exception as error:
            # We don't want one thread failure to hang the experiment
            error_text = str(error)
            if "Too many open files" not in error_text:
                return SolutionEvaluation(
                    solution=self.solution,
                    test_results=None,
                    extra_metrics=[],
                    main_metric=None,
                    exception=error,
                    exception_traceback=traceback.format_exc(),
                )
            # Known failure mode: escalate with a custom error message
            # rather than recording it as an ordinary failed evaluation.
            custom_message = (
                f"{error_text}\n\nThis is a previously observed known issue. "
                "It might either be a bug/leak in synthegrator, or if you are "
                "actually running on a massive "
                "number of threads it might be a ulimit problem where it can be fixed by "
                "increasing limits. The exact recommendation is not clear yet though. "
                "Please make a GitHub issue if you observe."
            )
            raise Exception(custom_message) from error


def solve_all_problems(
    solver: BaseCodeSolver,
    problems: Iterable[CodeProblem],
    max_threads: int = 1,
    cuda_clear_freq: int | None = 100,
) -> Iterable[CodeSolution]:
    """
    Solve all problems in an iterable.

    Only the ``problems is None`` check runs at call time. Everything else
    (probing ``problems`` for emptiness, importing torch, validating the
    thread/CUDA settings and the solving itself) happens lazily, once the
    returned iterator is consumed.

    Args:
    ----
        solver (BaseCodeSolver): The solver to use.
        problems (Iterable[CodeProblem]): The problems to solve. You might
            want to slice it beforehand if you want to limit the number of
            problems solved.
        max_threads (int, optional): The maximum number of threads to use.
            Defaults to 1 (no multi-threading). You should also check whether
            solver.allows_multithreading is True.
        cuda_clear_freq (int, optional): The frequency at which to clear the
            CUDA cache. Defaults to 100 assuming cuda is available (else will
            not do anything).

    Returns:
    -------
        A lazy iterator over the solutions.

    Raises:
    ------
        ValueError: Immediately, if ``problems`` is None. During iteration,
            if ``max_threads > 1`` but the solver does not allow
            multithreading, or if CUDA cache clearing is active together
            with ``max_threads > 1``.

    """
    # This guard has to live outside the generator below: inside a generator
    # function it would not execute until the first next() call, so callers
    # would only see the error at some later, unrelated point.
    if problems is None:
        raise ValueError("problems parameter cannot be None")
    return _iter_solutions(solver, problems, max_threads, cuda_clear_freq)


def _iter_solutions(solver, problems, max_threads, cuda_clear_freq):
    """Lazily validate settings and yield solutions for solve_all_problems."""
    is_empty, problems = is_generator_empty(problems)
    if is_empty:
        return

    # CUDA cache clearing only applies when torch is importable and a CUDA
    # device is available; otherwise it is switched off.
    try:
        import torch
    except ImportError:
        cuda_clear_freq = None
    else:
        if not torch.cuda.is_available():
            cuda_clear_freq = None

    if max_threads > 1 and not solver.allows_multithreading:
        msg = f"{solver.__class__.__name__} does not allow multithreading"
        raise ValueError(msg)
    if cuda_clear_freq and max_threads > 1:
        msg = "cuda_clear_freq and max_threads > 1 are incompatible"
        raise ValueError(msg)

    solve_threads = (SolveThread(solver, problem) for problem in problems)

    def pbar(data):
        return tqdm(
            data,
            total=len(problems) if hasattr(problems, "__len__") else None,
            desc="Solving problems",
        )

    if max_threads <= 1:
        solves_since_clear = 0
        for solve_thread in pbar(solve_threads):
            # The cache is emptied once more than cuda_clear_freq solves
            # have run since the previous clear.
            if cuda_clear_freq is not None and solves_since_clear > cuda_clear_freq:
                torch.cuda.empty_cache()
                solves_since_clear = 0
            solution = solve_thread.run()
            # Failed solves (None) are skipped in this branch.
            if solution is not None:
                yield solution
            solves_since_clear += 1
        return

    result_timeout_seconds = 259200  # 3 days
    # The with-block shuts the executor down on exit, so no explicit
    # shutdown() call is needed.
    with ThreadPoolExecutor(max_workers=max_threads) as service:
        futures = [service.submit(thread.run) for thread in solve_threads]
        for future in pbar(futures):
            # NOTE(review): unlike the single-threaded branch, a failed solve
            # is yielded here as None. Kept as-is since downstream code may
            # rely on it - confirm which behaviour is intended.
            yield future.result(timeout=result_timeout_seconds)


def is_generator_empty(generator: Iterable) -> tuple[bool, Iterable]:
    """Check if an iterable is empty and return a usable replacement.

    Objects that define ``__len__`` are checked by length and returned
    untouched. Anything else is probed by pulling its first item, which is
    then put back at the front of the returned iterable.

    Args:
        generator: The iterable (sized container, generator or iterator)
            to inspect.

    Returns:
        A tuple of (is_empty, usable_generator)
        - is_empty: True if the iterable is empty, False otherwise
        - usable_generator: An iterable that includes all original items,
          including the first one checked for emptiness. Use this instead
          of the argument, which may have been partially consumed.
    """
    # For sequences with a length, just check the length
    if hasattr(generator, "__len__"):
        return len(generator) == 0, generator

    # A unique sentinel marks exhaustion. Using None here would misreport
    # an iterator whose first item is legitimately None as empty.
    exhausted = object()
    try:
        iterator = iter(generator)
        first_item = next(iterator, exhausted)
    except Exception:
        # Best effort: if the input cannot be probed (not iterable, or it
        # raised while producing its first item), report "not empty" and
        # leave it to the caller's own iteration to deal with it.
        return False, generator

    if first_item is exhausted:
        return True, iterator
    # Put the probed item back in front of the remaining ones.
    return False, itertools.chain([first_item], iterator)


def evaluate_all_solutions(
    solutions: Iterable[CodeSolution],
    evaluator: SolutionEvaluator | None = None,
    max_threads: int = 1,
) -> Iterable[SolutionEvaluation]:
    """
    Evaluate all solutions in an iterable.

    Only the ``solutions is None`` check runs at call time. ``solutions`` is
    not touched until the returned iterator is consumed, so a lazy upstream
    producer is not started early.

    Args:
    ----
        solutions (Iterable[CodeSolution]): The solutions to evaluate.
        evaluator (SolutionEvaluator): The evaluator to use. If it is None,
            it will use the preferred evaluator of the problem.
        max_threads (int, optional): The maximum number of threads to use.
            Defaults to 1 (no multi-threading).

    Returns:
    -------
        A lazy iterator over the evaluations, in the order of ``solutions``.

    Raises:
    ------
        ValueError: Immediately, if ``solutions`` is None.

    """
    # This guard has to live outside the generator below: inside a generator
    # function it would not execute until the first next() call.
    if solutions is None:
        raise ValueError("solutions parameter cannot be None")
    return _iter_evaluations(solutions, evaluator, max_threads)


def _iter_evaluations(solutions, evaluator, max_threads):
    """Lazily yield one evaluation per solution for evaluate_all_solutions."""
    is_empty, solutions = is_generator_empty(solutions)
    if is_empty:
        return

    evaluate_threads = (EvaluateThread(solution, evaluator) for solution in solutions)

    def pbar(data):
        return tqdm(
            data,
            total=len(solutions) if hasattr(solutions, "__len__") else None,
            desc="Evaluating solutions",
        )

    if max_threads <= 1:
        for evaluate_thread in pbar(evaluate_threads):
            yield evaluate_thread.run()
        return

    result_timeout_seconds = 259200  # 3 days
    # The with-block shuts the executor down on exit, so no explicit
    # shutdown() call is needed.
    with ThreadPoolExecutor(max_workers=max_threads) as service:
        futures = [service.submit(thread.run) for thread in evaluate_threads]
        for future in pbar(futures):
            yield future.result(timeout=result_timeout_seconds)


def solve_and_evaluate_problems(
    solver: BaseCodeSolver,
    problems: Iterable[CodeProblem],
    max_threads_solve: int = 1,
    max_threads_eval: int = 1,
    finish_all_solves_first: bool = True,
    evaluator: SolutionEvaluator = None,
) -> Iterable[SolutionEvaluation]:
    """
    Solve every problem, then evaluate every resulting solution.

    Args:
    ----
        solver (BaseCodeSolver): The solver to use.
        problems (Iterable[CodeProblem]): The problems to solve.
        max_threads_solve (int, optional): Threads used for solving. Only
            honoured when solver.allows_multithreading is true; otherwise
            solving runs with 1 thread and a CUDA clear frequency of 100.
        max_threads_eval (int, optional): Threads used for evaluating.
        finish_all_solves_first (bool, optional): If True, collect all
            solutions into a list (and sanity-check them) before handing
            them to the evaluation step. If False, the solutions are passed
            on as produced by solve_all_problems.
        evaluator (SolutionEvaluator): Forwarded to evaluate_all_solutions.

    Raises:
    ------
        ValueError: If ``problems`` is None, if solve_all_problems returned
            None, or if collecting the solutions failed with a
            "not iterable" TypeError.
        RuntimeError: If a collected solution has ``problem`` set to None.

    """
    if problems is None:
        raise ValueError("problems parameter cannot be None")

    # A solver without multithreading support is forced onto one thread and
    # gets the periodic CUDA cache clearing.
    if solver.allows_multithreading:
        solve_options = {"max_threads": max_threads_solve, "cuda_clear_freq": None}
    else:
        solve_options = {"max_threads": 1, "cuda_clear_freq": 100}
    solves = solve_all_problems(solver, problems, **solve_options)

    if solves is None:
        raise ValueError("solve_all_problems returned None instead of an iterable")

    if not finish_all_solves_first:
        return evaluate_all_solutions(
            solves,
            evaluator=evaluator,
            max_threads=max_threads_eval,
        )

    try:
        solves = list(solves)
    except TypeError as error:
        if "not iterable" not in str(error):
            raise
        raise ValueError(f"solves is not iterable: {error}") from error

    # Missing solutions (None) are tolerated; a solution without a problem
    # is not.
    if any(solve is not None and solve.problem is None for solve in solves):
        raise RuntimeError("Problem is None")

    return evaluate_all_solutions(
        solves,
        evaluator=evaluator,
        max_threads=max_threads_eval,
    )
