from typing import List, Optional, Set

from pydantic import model_validator
from typing_extensions import Self, override

from pipelex import log
from pipelex.config import StaticValidationReaction, get_config
from pipelex.core.pipe_input_spec import PipeInputSpec
from pipelex.core.pipe_output import PipeOutput
from pipelex.core.pipe_run_params import PipeRunMode, PipeRunParams
from pipelex.core.working_memory import WorkingMemory
from pipelex.exceptions import PipeRunParamsError, StaticValidationError, StaticValidationErrorType
from pipelex.hub import get_required_pipe
from pipelex.pipe_controllers.pipe_controller import PipeController
from pipelex.pipe_controllers.sub_pipe import SubPipe
from pipelex.pipeline.job_metadata import JobMetadata


class PipeSequence(PipeController):
    sequential_sub_pipes: List[SubPipe]

    @override
    def needed_inputs(self) -> PipeInputSpec:
        """
        Calculate the inputs needed by this sequence.
        This is the inputs needed by all pipes in the sequence MINUS the outputs generated by previous steps.
        """

        needed_inputs = PipeInputSpec.make_empty()
        generated_outputs: Set[str] = set()

        for sequential_sub_pipe in self.sequential_sub_pipes:
            sub_pipe_needed_inputs = get_required_pipe(pipe_code=sequential_sub_pipe.pipe_code).needed_inputs()

            # Handle batching: if this sub_pipe has batch_params, exclude the batch_as input
            # since it's provided by the batching mechanism
            if sequential_sub_pipe.batch_params:
                batch_as_input = sequential_sub_pipe.batch_params.input_item_stuff_name
                # Create a new PipeInputSpec without the batch_as input
                filtered_needed_inputs = PipeInputSpec.make_empty()
                for var_name, concept_code in sub_pipe_needed_inputs.items:
                    if var_name != batch_as_input:
                        filtered_needed_inputs.add_requirement(variable_name=var_name, concept_code=concept_code)
                sub_pipe_needed_inputs = filtered_needed_inputs

            # Add inputs that haven't been generated by previous steps
            for var_name, concept_code in sub_pipe_needed_inputs.items:
                if var_name not in generated_outputs:
                    needed_inputs.add_requirement(variable_name=var_name, concept_code=concept_code)

            # Add this step's output to generated outputs
            if sequential_sub_pipe.output_name:
                generated_outputs.add(sequential_sub_pipe.output_name)

        return needed_inputs

    @override
    def required_variables(self) -> Set[str]:
        return set()

    @model_validator(mode="after")
    def validate_inputs(self) -> Self:
        if len(self.sequential_sub_pipes) == 0:
            raise ValueError(f"Pipe'{self.code}'(PipeSequence) must have at least 1 step")
        return self

    def _validate_output_multiplicity_support(self, pipe_run_params: PipeRunParams) -> None:
        """Validate that the pipe supports the requested output multiplicity."""
        if pipe_run_params.is_multiple_output_required:
            raise PipeRunParamsError(
                f"{self.__class__.__name__} does not support multiple outputs, got output_multiplicity = {pipe_run_params.output_multiplicity}"
            )

    def _validate_inputs(self):
        """
        Validate that the inputs declared for this PipeSequence match what is actually needed.
        """
        static_validation_config = get_config().pipelex.static_validation_config
        default_reaction = static_validation_config.default_reaction
        reactions = static_validation_config.reactions

        the_needed_inputs = self.needed_inputs()

        # Check all required variables are in the inputs
        for required_variable_name, _, _ in the_needed_inputs.detailed_requirements:
            if required_variable_name not in self.inputs.variables:
                missing_input_var_error = StaticValidationError(
                    error_type=StaticValidationErrorType.MISSING_INPUT_VARIABLE,
                    domain_code=self.domain,
                    pipe_code=self.code,
                    variable_names=[required_variable_name],
                )
                match reactions.get(StaticValidationErrorType.MISSING_INPUT_VARIABLE, default_reaction):
                    case StaticValidationReaction.IGNORE:
                        pass
                    case StaticValidationReaction.LOG:
                        log.error(missing_input_var_error.desc())
                    case StaticValidationReaction.RAISE:
                        raise missing_input_var_error

        # Check that all declared inputs are actually needed
        for input_name in self.inputs.variables:
            if input_name not in the_needed_inputs.required_names:
                extraneous_input_var_error = StaticValidationError(
                    error_type=StaticValidationErrorType.EXTRANEOUS_INPUT_VARIABLE,
                    domain_code=self.domain,
                    pipe_code=self.code,
                    variable_names=[input_name],
                )
                match reactions.get(StaticValidationErrorType.EXTRANEOUS_INPUT_VARIABLE, default_reaction):
                    case StaticValidationReaction.IGNORE:
                        pass
                    case StaticValidationReaction.LOG:
                        log.error(extraneous_input_var_error.desc())
                    case StaticValidationReaction.RAISE:
                        raise extraneous_input_var_error

    @override
    def validate_with_libraries(self):
        """
        Perform full validation after all libraries are loaded.
        This is called after all pipes and concepts are available.
        """
        self._validate_inputs()

    @override
    def pipe_dependencies(self) -> Set[str]:
        return set(sub_pipe.pipe_code for sub_pipe in self.sequential_sub_pipes)

    @override
    async def _run_controller_pipe(
        self,
        job_metadata: JobMetadata,
        working_memory: WorkingMemory,
        pipe_run_params: PipeRunParams,
        output_name: Optional[str] = None,
    ) -> PipeOutput:
        pipe_run_params.push_pipe_layer(pipe_code=self.code)
        self._validate_output_multiplicity_support(pipe_run_params)

        evolving_memory = working_memory

        for sub_pipe_index, sub_pipe in enumerate(self.sequential_sub_pipes):
            sub_pipe_run_params: PipeRunParams
            # only the last step should apply the final_stuff_code
            if sub_pipe_index == len(self.sequential_sub_pipes) - 1:
                sub_pipe_run_params = pipe_run_params.model_copy()
            else:
                sub_pipe_run_params = pipe_run_params.model_copy(update=({"final_stuff_code": None}))
            pipe_output = await sub_pipe.run_pipe(
                calling_pipe_code=self.code,
                working_memory=evolving_memory,
                job_metadata=job_metadata,
                sub_pipe_run_params=sub_pipe_run_params,
            )
            evolving_memory = pipe_output.working_memory
        return PipeOutput(
            working_memory=evolving_memory,
            pipeline_run_id=job_metadata.pipeline_run_id,
        )

    @override
    async def _dry_run_controller_pipe(
        self,
        job_metadata: JobMetadata,
        working_memory: WorkingMemory,
        pipe_run_params: PipeRunParams,
        output_name: Optional[str] = None,
    ) -> PipeOutput:
        if pipe_run_params.run_mode != PipeRunMode.DRY:
            raise PipeRunParamsError(f"PipeSequence._dry_run_controller_pipe() called with run_mode = {pipe_run_params.run_mode} in pipe {self.code}")
        return await self._run_controller_pipe(
            job_metadata=job_metadata,
            working_memory=working_memory,
            pipe_run_params=pipe_run_params,
            output_name=output_name,
        )
