import mlx.core as mx

from mlx_graphs.data.data import GraphData
from mlx_graphs.datasets.dataset import Dataset
from mlx_graphs.utils.transformations import to_undirected


class KarateClubDataset(Dataset):
    """
    Zachary's Karate Club network dataset from `An Information Flow Model for\
    Conflict and Fission in Small Groups\
    <https://www.jstor.org/stable/3629752>`_. This is a small
    dataset for node classification. The graph has 34 nodes and 78 undirected
    edges, stored in both directions (156 directed edges). Every node carries a
    constant scalar feature of 1 and belongs to one of 2 classes (0 or 1).
    """

    def __init__(self):
        # The graph is hard-coded in ``process``, so no base directory is passed.
        super().__init__(name="karate_club", base_dir=None)

    def download(self):
        # Nothing to fetch: all data lives in this module.
        pass

    def process(self):
        """Build the single Karate Club graph and store it in ``self.graphs``."""
        num_nodes = 34

        # Neighbour lists keyed by the lower-numbered endpoint. Each pair is
        # listed exactly once here; ``to_undirected`` adds the reverse edges.
        neighbours = {
            0: (1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 17, 19, 21, 31),
            1: (2, 3, 7, 13, 17, 19, 21, 30),
            2: (3, 7, 8, 9, 13, 27, 28, 32),
            3: (7, 12, 13),
            4: (6, 10),
            5: (6, 10, 16),
            6: (16,),
            8: (30, 32, 33),
            9: (33,),
            13: (33,),
            14: (32, 33),
            15: (32, 33),
            18: (32, 33),
            19: (33,),
            20: (32, 33),
            22: (32, 33),
            23: (25, 27, 29, 32, 33),
            24: (25, 27, 31),
            25: (31,),
            26: (29, 33),
            27: (33,),
            28: (31, 33),
            29: (32, 33),
            30: (32, 33),
            31: (32, 33),
            32: (33,),
        }
        pairs = [(src, dst) for src, dsts in neighbours.items() for dst in dsts]
        # (num_pairs, 2) -> (2, num_pairs): row 0 holds sources, row 1 targets.
        edge_index = to_undirected(mx.array(pairs).transpose())

        # Nodes labelled 1; every other node is labelled 0.
        class_one = {9, 14, 15, 18, 20, 22, *range(23, num_nodes)}
        labels = [1 if node in class_one else 0 for node in range(num_nodes)]

        self.graphs = [
            GraphData(
                edge_index=edge_index,
                node_features=mx.ones([num_nodes, 1]),
                # (1, num_nodes) -> (num_nodes, 1): one label per node row.
                node_labels=mx.array([labels]).transpose(),
            )
        ]
