from mteb.abstasks.task_metadata import TaskMetadata
from mteb.abstasks.zeroshot_classification import (
    AbsTaskZeroShotClassification,
)


class MNISTZeroShotClassification(AbsTaskZeroShotClassification):
    """Zero-shot classification task over the MNIST handwritten digit dataset.

    The candidate labels are text prompts derived from the class names of the
    label feature in the ``test`` split (see ``get_candidate_labels``).
    """

    metadata = TaskMetadata(
        name="MNISTZeroShot",
        description="Classifying handwritten digits.",
        reference="https://en.wikipedia.org/wiki/MNIST_database",
        dataset={
            "path": "ylecun/mnist",
            "revision": "77f3279092a1c1579b2250db8eafed0ad422088c",
        },
        type="ZeroShotClassification",
        category="i2t",
        eval_splits=["test"],
        eval_langs=["eng-Latn"],
        main_score="accuracy",
        # Estimated range; the cited database reference below is dated 2010.
        date=("2010-01-01", "2010-04-01"),
        domains=["Encyclopaedic"],
        task_subtypes=["Object recognition"],
        license="not specified",
        annotations_creators="derived",
        dialect=[],
        modalities=["image", "text"],
        sample_creation="created",
        bibtex_citation=r"""
@article{lecun2010mnist,
  author = {LeCun, Yann and Cortes, Corinna and Burges, CJ},
  journal = {ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
  title = {MNIST handwritten digit database},
  volume = {2},
  year = {2010},
}
""",
    )

    def get_candidate_labels(self) -> list[str]:
        """Return one text prompt per class name of the test split's label feature.

        Returns:
            Prompts of the form ``a photo of the number: '<name>'.``, in the
            same order as the ``names`` of the label feature.
        """
        # The label feature is expected to expose its class names via `.names`
        # (presumably a `datasets.ClassLabel` -- confirm against the dataset).
        label_feature = self.dataset["test"].features[self.label_column_name]

        prompts: list[str] = []
        for class_name in label_feature.names:
            prompts.append(f"a photo of the number: '{class_name}'.")
        return prompts
