from enum import StrEnum
from typing import Literal, TypeVar

from pydantic import BaseModel, Field

from olmoearth_run.shared.models.step_type import StepType


class WandbRunInfo(BaseModel):
    """Reference to a Weights & Biases run: its run ID and its URL."""

    run_id: str = Field(
        description="Weights & Biases run ID",
    )
    url: str = Field(
        description="URL to the Weights & Biases run",
    )


###############
# Task Results
###############
class CreatePartitionsTaskResults(BaseModel):
    """Results of a ``CREATE_PARTITIONS`` step: a list of partition identifiers."""

    # Literal tag pinned to CREATE_PARTITIONS; it is also the default, so
    # callers do not need to pass it explicitly.
    step_type: Literal[StepType.CREATE_PARTITIONS] = Field(
        default=StepType.CREATE_PARTITIONS,
    )
    partition_ids: list[str] = Field(
        description="The list of partition identifiers",
    )


class DatasetBuildTaskResults(BaseModel):
    """Results of a ``DATASET_BUILD`` step.

    Holds the root path of the built dataset and, optionally, its size in
    megabytes (``None`` when no size is supplied).
    """

    # Literal tag pinned to DATASET_BUILD; it is also the default value.
    step_type: Literal[StepType.DATASET_BUILD] = Field(
        default=StepType.DATASET_BUILD,
    )
    dataset_build_path: str = Field(
        description="The path of the root of the built dataset",
    )
    dataset_size_mb: float | None = Field(
        default=None,
        description="The size of the built dataset in megabytes",
    )


class DatasetBuildFromWindowsTaskResults(BaseModel):
    """Results of a ``DATASET_BUILD_FROM_WINDOWS`` step.

    Holds the root path of the dataset built from pre-created windows and,
    optionally, its size in megabytes (``None`` when no size is supplied).
    """

    # Literal tag pinned to DATASET_BUILD_FROM_WINDOWS; it is also the default.
    step_type: Literal[StepType.DATASET_BUILD_FROM_WINDOWS] = Field(
        default=StepType.DATASET_BUILD_FROM_WINDOWS,
    )
    dataset_build_path: str = Field(
        description="The path of the root of the built dataset from pre-created windows",
    )
    dataset_size_mb: float | None = Field(
        default=None,
        description="The size of the built dataset in megabytes",
    )


class InferenceResultsDataType(StrEnum):
    RASTER = "RASTER"
    VECTOR = "VECTOR"


class RunInferenceTaskResults(BaseModel):
    """Results of a ``RUN_INFERENCE`` step.

    The inference results data type is optional and defaults to ``None``.
    """

    # Literal tag pinned to RUN_INFERENCE; it is also the default value.
    step_type: Literal[StepType.RUN_INFERENCE] = Field(default=StepType.RUN_INFERENCE)
    # Description text matches the same-named field on
    # PostprocessPartitionTaskResults so generated schemas stay consistent.
    inference_results_data_type: InferenceResultsDataType | None = Field(
        default=None,
        description="The type of inference results data",
    )


class PostprocessPartitionTaskResults(BaseModel):
    """Results of a ``POSTPROCESS_PARTITION`` step.

    Holds the postprocessed partition IDs, the output files that were
    generated, and the type of inference results data. All three are required.
    """

    # Literal tag pinned to POSTPROCESS_PARTITION; it is also the default.
    step_type: Literal[StepType.POSTPROCESS_PARTITION] = Field(
        default=StepType.POSTPROCESS_PARTITION,
    )
    partition_ids: list[str] = Field(
        description="The partition IDs that were postprocessed",
    )
    output_files: list[str] = Field(
        description="The output files generated by the postprocessing step",
    )
    inference_results_data_type: InferenceResultsDataType = Field(
        description="The type of inference results data",
    )


class CombinePartitionsTaskResults(BaseModel):
    """Results of a ``COMBINE_PARTITIONS`` step: the combined file paths."""

    # Literal tag pinned to COMBINE_PARTITIONS; it is also the default value.
    step_type: Literal[StepType.COMBINE_PARTITIONS] = Field(
        default=StepType.COMBINE_PARTITIONS,
    )
    generated_file_paths: list[str] = Field(
        description="The combined geojson or geotiff filepaths",
    )


class PrepareLabeledWindowsTaskResults(BaseModel):
    """Results of a ``PREPARE_LABELED_WINDOWS`` step: the labeled window count."""

    # Literal tag pinned to PREPARE_LABELED_WINDOWS; it is also the default.
    step_type: Literal[StepType.PREPARE_LABELED_WINDOWS] = Field(
        default=StepType.PREPARE_LABELED_WINDOWS,
    )
    windows_count: int = Field(
        description="Number of labeled windows created",
    )


class FineTuneTaskResults(BaseModel):
    """Results of a ``FINE_TUNE`` step.

    Holds the path to the best checkpoint file and the Weights & Biases run
    information. Both are required.
    """

    # Literal tag pinned to FINE_TUNE; it is also the default value.
    step_type: Literal[StepType.FINE_TUNE] = Field(
        default=StepType.FINE_TUNE,
    )
    checkpoint_path: str = Field(
        description="Path to the best checkpoint file",
    )
    wandb_run_info: WandbRunInfo = Field(
        description="Weights & Biases run information",
    )


# Union of every step-specific results model defined above. ``None`` is a
# member too, so a value typed as TaskResults may carry no results at all.
# Member order is kept exactly as originally declared.
TaskResults = (
    PrepareLabeledWindowsTaskResults
    | CreatePartitionsTaskResults
    | DatasetBuildTaskResults
    | DatasetBuildFromWindowsTaskResults
    | FineTuneTaskResults
    | RunInferenceTaskResults
    | PostprocessPartitionTaskResults
    | CombinePartitionsTaskResults
    | None
)
# Type variable bounded by TaskResults, for generic code that is parameterized
# by one specific results type.
TaskResultsType = TypeVar("TaskResultsType", bound=TaskResults)
