import matplotlib.pyplot as plt
import numpy as np
from sbcluster import SpectralBridges
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import pairwise_distances

# Seed NumPy's global RNG; the jitter below draws from it, and the TSNE()
# calls are created without an explicit random_state.
np.random.seed(0)

# Fetch MNIST and compress the pixel features to 32 principal components.
mnist = fetch_openml("mnist_784", version=1)
pca = PCA(n_components=32, random_state=42)
X = pca.fit_transform(mnist.data)
y = mnist.target

# Fit SpectralBridges with clustering disabled; only its fitted embedding
# and centers are used below.
model = SpectralBridges(
    n_clusters=32,
    n_nodes=500,
    random_state=0,
    no_clustering=True,
)
model.fit(X)

# Reduce the SpectralBridges embedding to 2-D with t-SNE.
node_tsne = TSNE()
embedding = node_tsne.fit_transform(model.embedding_)

# Give every sample the 2-D position of its closest center, then add a
# small Gaussian jitter so samples sharing a center do not coincide.
nearest_center = pairwise_distances(X, model.cluster_centers_).argmin(axis=1)
proj = embedding[nearest_center]
jitter = np.random.normal(0, 0.1, proj.shape)
proj += jitter

# Baseline: t-SNE applied directly to the PCA-reduced samples.
sample_tsne = TSNE()
raw_tsne = sample_tsne.fit_transform(X)

# Draw both projections side by side, coloured by digit label.
plt.figure(figsize=(12, 6))
panels = (
    (1, proj, "SB 1", "SB 2"),
    (2, raw_tsne, "TSNE 1", "TSNE 2"),
)
for position, points, x_label, y_label in panels:
    plt.subplot(1, 2, position)
    plt.scatter(points[:, 0], points[:, 1], c=y.astype(int), s=1)
    plt.xlabel(x_label)
    plt.ylabel(y_label)

plt.suptitle("MNIST")
plt.show()
