from typing import Literal, TypeVar
from uuid import UUID

from pydantic import BaseModel, Field


from olmoearth_run.shared.models.step_type import StepType


###############
# Task Args
###############


class _CommonInferenceTaskArgs(BaseModel):
    # Base model for the task args of the inference workflow. Subclasses in this
    # module narrow ``step_type`` to a single ``Literal`` member and override
    # ``get_name_for_id``.

    # all tasks in the inference workflow get the following fields:
    step_type: StepType
    scratch_path: str = Field(description="A path in a cloud blob storage where the worker can write data to")
    model_stage_id: UUID = Field(description="The ID of the model stage that is being executed")
    model_stage_root_path: str = Field(
        description="A path in cloud blob storage that contains the configuration files for the model stage that is being executed"
    )
    dataset_path: str = Field(description="Path in cloud blob storage to the dataset for this model stage")

    def get_name_for_id(self) -> str:
        """Return the name component of this task's deterministic ID.

        Task IDs are deterministically generated from their step ID plus a name; this
        method provides the name part. It should uniquely identify this piece of work
        within a single step, and is used to prevent duplicate tasks.

        Raises:
            NotImplementedError: always, on this base class; every concrete subclass
                must override this method.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(f"{type(self).__name__} must implement get_name_for_id()")


class PrepareLabeledWindowsTaskArgs(BaseModel):
    # Arguments for a PREPARE_LABELED_WINDOWS task. Unlike the inference-workflow
    # tasks, this model does not carry scratch_path / model_stage_id.
    step_type: Literal[StepType.PREPARE_LABELED_WINDOWS] = StepType.PREPARE_LABELED_WINDOWS
    model_stage_root_path: str = Field(
        description="Path to the directory containing model configuration files",
    )
    dataset_path: str = Field(
        description="Path where the dataset should be written",
    )

    def get_name_for_id(self) -> str:
        """Return the ID name for this task: ``str()`` of its step type.

        The name is a constant, so duplicate-task prevention allows only one such
        task per step.
        """
        step_name = str(StepType.PREPARE_LABELED_WINDOWS)
        return step_name


class FineTuneTaskArgs(BaseModel):
    # Arguments for a FINE_TUNE task; also records which step is running it.
    step_type: Literal[StepType.FINE_TUNE] = StepType.FINE_TUNE
    model_stage_root_path: str = Field(
        description="Path to the directory containing model configuration files",
    )
    dataset_path: str = Field(
        description="Path to the dataset for training",
    )
    step_id: UUID = Field(
        description="The ID of the fine-tuning step running this task",
    )

    def get_name_for_id(self) -> str:
        """Return the ID name for this task: ``str()`` of its step type.

        Because the name never varies, only one fine-tune task can exist per step.
        """
        step_name = str(StepType.FINE_TUNE)
        return step_name


class CreatePartitionsTaskArgs(_CommonInferenceTaskArgs):
    # CREATE_PARTITIONS task: only the common inference fields are needed.
    step_type: Literal[StepType.CREATE_PARTITIONS] = StepType.CREATE_PARTITIONS

    def get_name_for_id(self) -> str:
        """Return ``str()`` of the step type; one partition-creation task per step."""
        step_name = str(StepType.CREATE_PARTITIONS)
        return step_name


class DatasetBuildTaskArgs(_CommonInferenceTaskArgs):
    # DATASET_BUILD task covering the listed partitions.
    step_type: Literal[StepType.DATASET_BUILD] = StepType.DATASET_BUILD
    partition_ids: list[str]

    def get_name_for_id(self) -> str:
        """Return the partition IDs, sorted and joined with ``"::"``.

        Sorting makes the name independent of the order in which the partitions
        were listed.
        """
        ordered_ids = sorted(self.partition_ids)
        return "::".join(ordered_ids)


class DatasetBuildFromWindowsTaskArgs(BaseModel):
    # Arguments for one worker of a DATASET_BUILD_FROM_WINDOWS step. Per the field
    # descriptions, the work is split across ``total_workers`` parallel workers and
    # this task is the one at ``worker_index``.
    step_type: Literal[StepType.DATASET_BUILD_FROM_WINDOWS] = Field(default=StepType.DATASET_BUILD_FROM_WINDOWS)
    dataset_path: str = Field(description="Path in cloud blob storage to the dataset for this model stage")
    # ge=0 / ge=1 reject a negative index or a non-positive worker count, neither of
    # which can describe a real shard.
    # NOTE(review): worker_index < total_workers is not enforced here; confirm callers
    # guarantee it.
    worker_index: int = Field(default=0, ge=0, description="Index of this worker (0-based)")
    total_workers: int = Field(default=1, ge=1, description="Total number of parallel workers")

    def get_name_for_id(self) -> str:
        """Return ``"<step type>_<worker_index>"``, unique per worker within the step."""
        # NOTE: the f-string formats the enum via format(), while sibling classes
        # use str(). The two agree on Python 3.11+, but on 3.10 a str-mixin Enum
        # formats to its value. This is kept as-is so existing task IDs stay stable.
        return f"{StepType.DATASET_BUILD_FROM_WINDOWS}_{self.worker_index}"


class RunInferenceTaskArgs(_CommonInferenceTaskArgs):
    # RUN_INFERENCE task covering the listed partitions.
    step_type: Literal[StepType.RUN_INFERENCE] = StepType.RUN_INFERENCE
    partition_ids: list[str]

    def get_name_for_id(self) -> str:
        """Return the sorted partition IDs joined by ``"::"`` (order-insensitive)."""
        separator = "::"
        return separator.join(sorted(self.partition_ids))


class PostprocessPartitionTaskArgs(_CommonInferenceTaskArgs):
    # POSTPROCESS_PARTITION task covering the listed partitions.
    step_type: Literal[StepType.POSTPROCESS_PARTITION] = StepType.POSTPROCESS_PARTITION
    partition_ids: list[str]

    def get_name_for_id(self) -> str:
        """Return the partition IDs in sorted order, joined by ``"::"``."""
        ordered_ids = list(self.partition_ids)
        ordered_ids.sort()
        return "::".join(ordered_ids)


class CombinePartitionsTaskArgs(_CommonInferenceTaskArgs):
    # COMBINE_PARTITIONS task. It receives partition_ids, but its ID name does not
    # depend on them, so there is a single combine task per step.
    step_type: Literal[StepType.COMBINE_PARTITIONS] = StepType.COMBINE_PARTITIONS
    partition_ids: list[str]

    def get_name_for_id(self) -> str:
        """Return ``str()`` of the step type, independent of ``partition_ids``."""
        step_name = str(StepType.COMBINE_PARTITIONS)
        return step_name


# Union of every task-args model defined above. Member order is kept unchanged
# because pydantic's union validation may use it to break ties.
TaskArgs = (
    CombinePartitionsTaskArgs
    | CreatePartitionsTaskArgs
    | DatasetBuildTaskArgs
    | DatasetBuildFromWindowsTaskArgs
    | FineTuneTaskArgs
    | PostprocessPartitionTaskArgs
    | PrepareLabeledWindowsTaskArgs
    | RunInferenceTaskArgs
)
# Type variable for code that is generic over one specific task-args model.
TaskArgsType = TypeVar("TaskArgsType", bound=TaskArgs)
