import os, shutil, random, numpy as np
from sklearn.cluster import KMeans, DBSCAN
from sklearn.metrics.pairwise import cosine_similarity
from .core import console, device_select, load_images, progress_bar
from .models import EmbeddingExtractor
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt


class CleanFrame:
    """Remove near-duplicate frames from a folder of images.

    The pipeline (see :meth:`run`) embeds every image, clusters the
    embeddings, and then, inside each cluster, greedily keeps a frame only if
    its cosine similarity to every frame already kept in that cluster is
    below ``threshold``. Kept frames are copied to ``<dataset>_cleaned`` and
    rejected frames to ``<dataset>_removed``, both created next to the input
    folder, and a plain-text report is written alongside them.
    """

    def __init__(self, path, model="clip", cluster="kmeans", n_clusters=5,
                 device="auto", batch_size=32, eps=0.5, min_samples=5,
                 threshold=0.95, cache=True):
        """Configure the cleaner and create the output directories.

        Args:
            path: Folder containing the images to clean.
            model: Embedding model name passed to the extractor (lower-cased).
            cluster: Clustering method, ``"kmeans"`` or ``"dbscan"``
                (lower-cased; validated when clustering runs).
            n_clusters: Number of clusters for KMeans.
            device: Device spec resolved through ``device_select``.
            batch_size: Batch size handed to the embedding extractor.
            eps: DBSCAN neighbourhood radius.
            min_samples: DBSCAN minimum samples per core point.
            threshold: Cosine similarity at or above which a frame counts as
                a duplicate of an already-kept frame.
            cache: Whether to read/write the embedding cache file.
        """
        self.path = path
        self.model = model.lower()
        self.cluster_method = cluster.lower()
        self.n_clusters = n_clusters
        self.eps = eps
        self.min_samples = min_samples
        self.batch_size = batch_size
        self.threshold = threshold
        self.cache = cache

        # Device
        self.device = device_select(device)
        console.print(f"Device: {self.device}")

        # Paths: all outputs are siblings of the dataset folder.
        self.dataset_name = os.path.basename(os.path.normpath(path))
        base_dir = os.path.dirname(os.path.normpath(path))
        self.cleaned_dir = os.path.join(base_dir, f"{self.dataset_name}_cleaned")
        self.removed_dir = os.path.join(base_dir, f"{self.dataset_name}_removed")
        self.emb_cache_path = os.path.join(base_dir, f"{self.dataset_name}_embeddings.npy")
        self.report_path = os.path.join(base_dir, f"{self.dataset_name}_report.txt")

        os.makedirs(self.cleaned_dir, exist_ok=True)
        os.makedirs(self.removed_dir, exist_ok=True)

        # Embedding extractor
        self.extractor = EmbeddingExtractor(self.device, self.batch_size)

        # Placeholders filled in by run()
        self.embeddings = None
        self.paths = None
        self.labels = None
        self.cleaned_paths = []
        self.removed_paths = []

    # ================================================================
    # Embedding Management (with cache)
    # ================================================================

    def _get_or_load_embeddings(self):
        """Return ``(embeddings, paths)``, reusing the cache when it is valid.

        The image list is always loaded, because the cache file stores only
        the embedding matrix and every later stage needs the matching paths.
        A cache whose row count differs from the current number of images is
        ignored and recomputed.

        Raises:
            FileNotFoundError: If no images are found under ``self.path``.
        """
        images, paths = load_images(self.path)
        if not paths:
            raise FileNotFoundError(f"No images found in {self.path}")
        self.paths = paths

        if self.cache and os.path.exists(self.emb_cache_path):
            cached = np.load(self.emb_cache_path)
            # NOTE(review): only the row count is validated. The cache name
            # does not encode the model, and row order is assumed to match
            # the order load_images returns -- confirm both hold for callers.
            if cached.ndim >= 1 and cached.shape[0] == len(paths):
                console.print(f"Loading cached embeddings from {self.emb_cache_path}")
                self.embeddings = cached
                return self.embeddings, self.paths
            console.print("Cached embeddings do not match the current images; recomputing...")
        else:
            console.print("Computing embeddings (no cache found)...")

        self.embeddings = self.extractor.get_embeddings(images, self.model)

        if self.cache:
            np.save(self.emb_cache_path, self.embeddings)
            console.print(f"Embeddings cached at: {self.emb_cache_path}")

        return self.embeddings, self.paths

    # ================================================================
    # Clustering and Cleaning
    # ================================================================

    def _cluster(self, embeddings):
        """Cluster ``embeddings`` and return one integer label per row.

        Raises:
            ValueError: If the configured cluster method is not supported.
        """
        console.print(f"Clustering using {self.cluster_method.upper()}...")
        if self.cluster_method == "kmeans":
            # KMeans rejects n_clusters > n_samples, so clamp for tiny datasets.
            n_clusters = min(self.n_clusters, len(embeddings))
            labels = KMeans(n_clusters=n_clusters, random_state=42, n_init="auto").fit_predict(embeddings)
        elif self.cluster_method == "dbscan":
            labels = DBSCAN(eps=self.eps, min_samples=self.min_samples).fit_predict(embeddings)
        else:
            raise ValueError("Cluster method must be kmeans or dbscan.")
        return labels

    @staticmethod
    def _select_distinct(embeddings, threshold):
        """Greedily split the rows of ``embeddings`` into distinct and duplicate.

        Rows are visited in order. A row is kept when its cosine similarity
        to every previously kept row is strictly below ``threshold``;
        otherwise it is marked as a duplicate. The first row is always kept.

        Args:
            embeddings: 2-D array-like, one embedding per row.
            threshold: Cosine-similarity cut-off.

        Returns:
            ``(keep, drop)``: two lists of row indices into ``embeddings``.
        """
        vectors = np.asarray(embeddings, dtype=np.float64)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        # A zero vector has similarity 0 to everything; avoid dividing by 0.
        norms[norms == 0.0] = 1.0
        unit = vectors / norms  # normalise once, not once per comparison

        keep, drop = [], []
        for i, row in enumerate(unit):
            if not keep or np.max(unit[keep] @ row) < threshold:
                keep.append(i)
            else:
                drop.append(i)
        return keep, drop

    def _clean(self, embeddings, paths, labels):
        """De-duplicate each cluster and copy frames to the output folders.

        Frames labelled ``-1`` (DBSCAN noise) belong to no cluster, so they
        are not compared against anything and are kept as-is. Results are
        stored in ``self.cleaned_paths`` and ``self.removed_paths``.
        """
        labels = np.asarray(labels)
        cluster_ids = [c for c in np.unique(labels) if c != -1]
        kept, removed = [], []

        with progress_bar("Cleaning frames", len(cluster_ids)) as p:
            task = p.add_task("Clean", total=len(cluster_ids))
            for cid in cluster_ids:
                members = np.flatnonzero(labels == cid)
                keep_local, drop_local = self._select_distinct(embeddings[members], self.threshold)
                kept.extend(paths[members[j]] for j in keep_local)
                removed.extend(paths[members[j]] for j in drop_local)
                p.update(task, advance=1)

        # Noise frames would otherwise vanish from both output sets.
        kept.extend(paths[i] for i in np.flatnonzero(labels == -1))

        # NOTE(review): files are copied by basename, so frames from different
        # sub-folders that share a name would overwrite each other -- confirm
        # whether load_images can return such paths.
        for f in removed:
            shutil.copy(f, os.path.join(self.removed_dir, os.path.basename(f)))
        for f in kept:
            shutil.copy(f, os.path.join(self.cleaned_dir, os.path.basename(f)))

        self.cleaned_paths = kept
        self.removed_paths = removed
        console.print(f"Kept: {len(kept)}   Removed: {len(removed)}")
        console.print(f"Cleaned output → {self.cleaned_dir}")
        console.print(f"Removed output → {self.removed_dir}")

    # ================================================================
    # Summary and Report
    # ================================================================

    def _summary(self):
        """Print the run summary and write it to ``self.report_path``."""
        total = len(self.paths) if self.paths is not None else 0
        kept = len(self.cleaned_paths)
        removed = len(self.removed_paths)

        def percent(count):
            # Guard the empty-dataset case instead of dividing by zero.
            return (count / total) * 100 if total else 0.0

        lines = [
            "CleanFrames Summary",
            "-" * 40,
            f"Model: {self.model}",
            f"Clustering: {self.cluster_method}",
            f"Device: {self.device}",
            f"Threshold: {self.threshold}",
            "-" * 40,
            f"Total frames: {total}",
            f"Kept: {kept} ({percent(kept):.1f}%)",
            f"Removed: {removed} ({percent(removed):.1f}%)",
            "-" * 40,
            f"Cleaned Output: {self.cleaned_dir}",
            f"Removed Output: {self.removed_dir}",
            f"Report: {self.report_path}",
            "-" * 40,
        ]

        console.print("\n".join(lines))
        # Explicit encoding: output paths may contain non-ASCII characters.
        with open(self.report_path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))

    # ================================================================
    # Main pipeline
    # ================================================================

    def run(self):
        """Run the full pipeline: embed, cluster, clean, then summarise."""
        console.print(f"\nRunning CleanFrame on '{self.dataset_name}'")
        self.embeddings, self.paths = self._get_or_load_embeddings()
        self.labels = self._cluster(self.embeddings)
        self._clean(self.embeddings, self.paths, self.labels)
        self._summary()

    # ================================================================
    # Visualization
    # ================================================================

    def visualize_clusters(self, method="pca", sample_size=2000):
        """Scatter-plot a 2-D projection of the embeddings, coloured by cluster.

        Args:
            method: ``"pca"`` or ``"tsne"`` (case-insensitive).
            sample_size: Maximum number of embeddings to plot; a random
                subset of this size is drawn without replacement.

        Raises:
            ValueError: If the pipeline has not been run, ``method`` is
                unknown, or fewer than two embeddings are available to plot.
        """
        if self.embeddings is None or self.labels is None:
            raise ValueError("Run the pipeline first (cf.run()) before visualizing clusters.")

        method = method.lower()
        if method not in ("pca", "tsne"):
            raise ValueError("method must be 'pca' or 'tsne'.")

        # Sample subset if large
        n = min(sample_size, len(self.embeddings))
        if n < 2:
            raise ValueError("Need at least 2 embeddings to visualize clusters.")
        idx = np.random.choice(len(self.embeddings), n, replace=False)
        emb = self.embeddings[idx]
        labels = np.asarray(self.labels)[idx]

        # Reduce dimensionality
        if method == "pca":
            reducer = PCA(n_components=2)
        else:
            # t-SNE requires perplexity < n_samples; cap it for small samples.
            reducer = TSNE(n_components=2, random_state=42, perplexity=min(30, n - 1))
        reduced = reducer.fit_transform(emb)

        # Plot
        plt.figure(figsize=(8, 6))
        plt.scatter(reduced[:, 0], reduced[:, 1], c=labels, cmap='tab10', s=15)
        plt.title(f"Embedding Visualization ({method.upper()})")
        plt.xlabel("Component 1")
        plt.ylabel("Component 2")
        plt.show()

    def report(self):
        """Show up to three pairs of kept frames with their cosine similarity.

        With six or more kept frames, six are sampled at random and paired;
        otherwise consecutive kept frames are paired in order.
        """
        console.print("\nVisual Report (CleanFrames)")
        console.print("-" * 40)

        if len(self.cleaned_paths) >= 6:
            chosen = random.sample(self.cleaned_paths, 6)
        else:
            chosen = list(self.cleaned_paths)
        # Pair (0, 1), (2, 3), ...; an unpaired trailing frame is ignored.
        pairs = list(zip(chosen[0::2], chosen[1::2]))

        if not pairs:
            console.print("Not enough kept frames to build a comparison pair.")
            return

        for i, (p1, p2) in enumerate(pairs[:3]):
            idx1, idx2 = self.paths.index(p1), self.paths.index(p2)
            sim = cosine_similarity(self.embeddings[idx1].reshape(1, -1),
                                    self.embeddings[idx2].reshape(1, -1))[0][0]
            fig, axes = plt.subplots(1, 2, figsize=(6, 3))
            # Close the image files once their pixels are handed to imshow.
            with Image.open(p1) as img1, Image.open(p2) as img2:
                axes[0].imshow(img1)
                axes[1].imshow(img2)
            for ax in axes:
                ax.axis("off")
            fig.suptitle(f"Set {i+1} — Similarity: {sim*100:.1f}%", fontsize=12)
            plt.show()