"""Models for OmnibusX SDK."""

from pydantic import BaseModel


class SequencingTechnology:
    """String identifiers for sequencing technologies.

    This is a plain class used as a namespace of constants (it is not an
    ``Enum``), so every member is an ordinary ``str``.
    """

    # Of these, only SC_RNA_SEQ and WELL_BASED_SPATIAL are accepted by the
    # validation in PreprocessDatasetParams.model_post_init.
    SC_RNA_SEQ = "sc_rna_seq"
    BULK_RNA_SEQ = "bulk_rna_seq"
    SC_ATAC_SEQ = "sc_atac_seq"
    BULK_ATAC_SEQ = "bulk_atac_seq"
    WELL_BASED_SPATIAL = "well_based_spatial"


class SequencingPlatform:
    """String identifiers for sequencing platforms, grouped by technology.

    Each nested class is a plain namespace of ``str`` constants (not an
    ``Enum``). No nested namespace for bulk ATAC-seq is defined here.
    """

    class ScRnaSeq:
        """Platform identifiers for single-cell RNA-seq."""

        CHROMIUM_10X = "10x"
        # Same string value as DataFormat.IMMUNE_10X. This platform is not in
        # the scRNA-seq platform list accepted by PreprocessDatasetParams.
        CHROMIUM_AND_IMMUNE_RECEPTOR_10X = "immune_10x"
        CITE_SEQ = "cite_seq"
        SMART_SEQ_2 = "smart_seq2"
        DROP_SEQ = "drop_seq"
        # Note: the constant is named OTHERS but its value is "unknown".
        OTHERS = "unknown"

    class BulkRnaSeq:
        """Platform identifiers for bulk RNA-seq."""

        ILLUMINA = "illumina"

    class ScAtacSeq:
        """Platform identifiers for single-cell ATAC-seq."""

        ATACSEQ_10X = "10x_atacseq"
        ATACSEQ_GEX_10X = "10x_atacseq_gex"

    class WellBasedSpatial:
        """Platform identifiers for well-based spatial sequencing."""

        # PreprocessDatasetParams accepts only XENIUM, VISIUM_10X and
        # VISIUM_HD_10X from this namespace.
        VISIUM_10X = "10x_visium"
        VISIUM_HD_10X = "10x_visium_hd"
        GEOMX_DSP = "geomx_dsp"
        SLIDE_SEQ = "slide_seq"
        XENIUM = "xenium"


class DataFormat:
    """String identifiers for input data formats.

    Plain namespace of ``str`` constants (not an ``Enum``). The validation in
    PreprocessDatasetParams accepts only SCANPY and SEURAT for scRNA-seq, and
    XENIUM, VISIUM and VISIUM_HD for well-based spatial data.
    """

    SCANPY = "scanpy"
    SEURAT = "seurat"
    TEXT = "text"
    # Tabular formats; the suffix presumably names the tool that produced the
    # table (HTSeq, kallisto, featureCounts, STAR) -- TODO confirm.
    TAB = "tab"
    TAB_HTSEQ = "tab_htseq"
    TAB_KALLISTO = "tab_kallisto"
    TAB_FEATURECOUNTS = "tab_featurecounts"
    TAB_STAR = "tab_star"
    H5_10X = "h5_10x"
    MTX_10X = "mtx_10x"
    # Same string value as SequencingPlatform.ScRnaSeq.CHROMIUM_AND_IMMUNE_RECEPTOR_10X.
    IMMUNE_10X = "immune_10x"
    VISIUM = "visium"
    VISIUM_HD = "visium_hd"
    VISIUM_AGGR = "visium_aggr"
    # Same string value as SequencingPlatform.WellBasedSpatial.GEOMX_DSP.
    GEOMX_DSP = "geomx_dsp"
    GEOMX_IPA = "geomx_ipa"
    ATAC_H5 = "atac_h5"
    ATAC_MTX = "atac_mtx"
    ATAC_GEX_H5 = "atac_gex_h5"
    ATAC_GEX_MTX = "atac_gex_mtx"
    # Same string value as SequencingPlatform.WellBasedSpatial.XENIUM.
    XENIUM = "xenium"


class GeneReferenceVersion:
    """Integer identifiers for gene reference versions.

    Plain namespace of ``int`` constants, matching the ``int`` type of
    ``PreprocessDatasetParams.gene_reference_version``.
    """

    # Presumably the Ensembl release number -- TODO confirm.
    ENSEMBL_111 = 111


class Species:
    """String identifiers for species.

    Plain namespace of ``str`` constants. These are the only values accepted
    for ``PreprocessDatasetParams.gene_reference_id``.
    """

    HUMAN = "Homo_sapiens"
    MOUSE = "Mus_musculus"


class FileLocation:
    """String identifiers for file locations.

    Plain namespace of ``str`` constants; only one location is defined.
    """

    # Presumably means the file already resides on the server -- TODO confirm.
    SERVER = "server"


class TaskType:
    """String identifiers for task types.

    Plain namespace of ``str`` constants; each value equals its constant name.
    Presumably these are the values expected in ``AddTaskParams.task_type``
    and ``TaskLog.task_type`` -- TODO confirm against the API.
    """

    IMPORT_OMNIBUSX_FILE = "IMPORT_OMNIBUSX_FILE"
    PREPROCESS_DATASET = "PREPROCESS_DATASET"


class TaskStatus:
    """String identifiers for task statuses.

    Plain namespace of ``str`` constants; each value equals its constant name.
    Presumably these are the values reported in ``TaskLog.status`` -- TODO
    confirm against the API.
    """

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    SUCCESS = "SUCCESS"
    FAILED = "FAILED"
    TERMINATED = "TERMINATED"


class SampleBatch(BaseModel):
    """Pydantic model pairing a file path with a batch name.

    Has the same two fields as ``BatchInfo``; note that
    ``PreprocessDatasetParams.batches`` is typed with ``BatchInfo``, not this
    class.
    """

    file_path: str  # path of the batch's data file
    batch_name: str  # name given to the batch


class UserGroup(BaseModel):
    """Pydantic model describing a user group."""

    user_group_id: str  # identifier of the group
    name: str  # group name
    description: str  # free-text description of the group


class ImportOmnibusXFileParams(BaseModel):
    """Pydantic model holding the parameters for importing an OmnibusX file."""

    omnibusx_file_path: str  # path of the OmnibusX file to import
    # Presumably a UserGroup.user_group_id identifying the target group --
    # TODO confirm.
    group_id: str


class AddTaskParams(BaseModel):
    """Pydantic model holding the parameters for adding a new task."""

    # Plain string; presumably one of the TaskType constants -- TODO confirm.
    task_type: str
    # Either a typed import-parameters model or an arbitrary dict. Parameters
    # for other task types have no dedicated member in this union, so they
    # would have to be supplied as a dict.
    params: ImportOmnibusXFileParams | dict


class AddTaskResponse(BaseModel):
    """Pydantic model for the response returned after adding a task."""

    # Identifier of the task; presumably matches TaskLog.id -- TODO confirm.
    task_id: str


class TaskLog(BaseModel):
    """Pydantic model describing one task record, including its log output.

    All fields are required; none declares a default.
    """

    id: str  # task identifier
    task_type: str  # presumably one of the TaskType constants -- TODO confirm
    dataset_id: str  # identifier of the dataset the task relates to
    # Integer timestamps; epoch and unit (seconds vs. milliseconds) are not
    # visible here -- TODO confirm.
    created_at: int
    finished_at: int
    status: str  # presumably one of the TaskStatus constants -- TODO confirm
    log: str  # log text of the task
    # Both are plain strings; presumably serialized (e.g. JSON) payloads --
    # TODO confirm.
    params: str
    result: str
    is_committed: bool


class AddFileUploadPayload(BaseModel):
    """Pydantic model holding the parameters for initiating a file upload."""

    original_filename: str  # name of the file being uploaded
    n_chunks: int  # number of chunks the upload is split into
    folder_id: str  # identifier of the folder the upload belongs to


class FileUploadChunk(BaseModel):
    """Pydantic model describing a single chunk of a file upload."""

    id: str  # identifier; presumably of the upload this chunk belongs to -- TODO confirm
    folder_id: str  # identifier of the folder the upload belongs to
    # Index of this chunk; whether it is 0- or 1-based is not visible here --
    # TODO confirm.
    chunk: int
    n_chunks: int  # total number of chunks in the upload



class UploadFilesResponse(BaseModel):
    """Pydantic model for the response returned when a file upload completes."""

    folder_id: str  # identifier of the folder holding the uploaded files
    folder_path: str  # path of that folder


class UploadProgress(BaseModel):
    """Pydantic model reporting the progress of a file upload.

    Tracks totals and completed counts at both file and chunk granularity.
    """

    total_files: int  # number of files in the upload
    total_chunks: int  # number of chunks across the upload
    done_files: int  # files finished so far
    done_chunks: int  # chunks finished so far
    current_file: str  # file currently being processed


class BatchInfo(BaseModel):
    """Pydantic model describing one data batch for preprocessing.

    Used as the element type of ``PreprocessDatasetParams.batches``. Has the
    same two fields as ``SampleBatch``.
    """

    file_path: str  # path of the batch's data file
    batch_name: str  # name given to the batch


class PreprocessDatasetParams(BaseModel):
    """Request parameters for preprocessing a dataset.

    The constrained fields are plain ``str``/``int`` values; the allowed
    values and combinations are enforced in :meth:`model_post_init`.
    """

    name: str
    description: str
    batches: list[BatchInfo]
    gene_reference_version: int
    gene_reference_id: str  # Species.HUMAN or Species.MOUSE
    technology: str  # SequencingTechnology.SC_RNA_SEQ or WELL_BASED_SPATIAL
    platform: str  # a SequencingPlatform.ScRnaSeq or .WellBasedSpatial value
    data_format: str  # SCANPY/SEURAT for scRNA-seq; XENIUM/VISIUM/VISIUM_HD for spatial

    def model_post_init(self, __context) -> None:
        """Check field values and cross-field consistency after construction.

        Rules are checked in a fixed order and the first violation raises:
        species, technology, platform, data format, platform/format pairing
        (spatial only), and finally that at least one batch is present.

        Raises:
            ValueError: If any of the rules above is violated.
        """
        # --- species -------------------------------------------------------
        species_options = [Species.HUMAN, Species.MOUSE]
        if self.gene_reference_id not in species_options:
            raise ValueError(
                f"gene_reference_id must be one of: {species_options}. "
                f"Got: {self.gene_reference_id}"
            )

        # --- technology ----------------------------------------------------
        technology_options = [
            SequencingTechnology.SC_RNA_SEQ,
            SequencingTechnology.WELL_BASED_SPATIAL,
        ]
        if self.technology not in technology_options:
            raise ValueError(
                f"Technology must be one of: {technology_options}. "
                f"Got: {self.technology}"
            )

        # --- platform / data format, per technology --------------------------
        if self.technology == SequencingTechnology.WELL_BASED_SPATIAL:
            spatial = SequencingPlatform.WellBasedSpatial

            # Each supported spatial platform has exactly one matching data
            # format. The accepted platform and format lists are derived from
            # this mapping; its insertion order fixes the order in which they
            # appear in the error messages.
            required_format = {
                spatial.XENIUM: DataFormat.XENIUM,
                spatial.VISIUM_10X: DataFormat.VISIUM,
                spatial.VISIUM_HD_10X: DataFormat.VISIUM_HD,
            }
            platform_options = list(required_format)
            format_options = list(required_format.values())

            if self.platform not in platform_options:
                raise ValueError(
                    f"For spatial data, platform must be one of: {platform_options}. "
                    f"Got: {self.platform}"
                )
            if self.data_format not in format_options:
                raise ValueError(
                    f"For spatial data, data format must be one of: {format_options}. "
                    f"Got: {self.data_format}"
                )

            # Both values are individually valid here; now make sure they
            # belong together.
            wanted_format = required_format.get(self.platform)
            if self.data_format != wanted_format:
                raise ValueError(
                    f"Platform '{self.platform}' requires data format '{wanted_format}', "
                    f"but got '{self.data_format}'"
                )

        elif self.technology == SequencingTechnology.SC_RNA_SEQ:
            single_cell = SequencingPlatform.ScRnaSeq

            # Only the platforms that support scanpy/seurat input are listed.
            platform_options = [
                single_cell.OTHERS,
                single_cell.CHROMIUM_10X,
                single_cell.CITE_SEQ,
                single_cell.SMART_SEQ_2,
                single_cell.DROP_SEQ,
            ]
            if self.platform not in platform_options:
                raise ValueError(
                    f"Platform must be one of: {platform_options}. "
                    f"Got: {self.platform}"
                )

            format_options = [DataFormat.SCANPY, DataFormat.SEURAT]
            if self.data_format not in format_options:
                raise ValueError(
                    f"For scRNA-seq, data format must be one of: {format_options}. "
                    f"Got: {self.data_format}"
                )

        # --- batches -------------------------------------------------------
        if not self.batches:
            raise ValueError("At least one batch must be provided")
