# cleanframe/main.py
import os, shutil, random, numpy as np
from sklearn.cluster import KMeans, DBSCAN
from sklearn.metrics.pairwise import cosine_similarity
from .core import console, DeviceManager, CacheManager, load_images, progress_bar
from .models import EmbeddingExtractor
from PIL import Image
import matplotlib.pyplot as plt
from rich.table import Table

class CleanFrame:
    def __init__(self, path, model="clip", cluster="kmeans", n_clusters=5,
                 device="auto", batch_size=32, eps=0.5, min_samples=5, threshold=0.95):
        """Configure a cleaning run for the frames stored under ``path``.

        Creating an instance has a filesystem side effect: the output folder
        ``.cleanframe_cache/<dataset_name>/cleaned`` is created relative to
        the current working directory.

        Args:
            path: Folder containing the frames to clean.
            model: Embedding model name handed to the extractor (lower-cased).
            cluster: Clustering method, ``"kmeans"`` or ``"dbscan"``
                (lower-cased).
            n_clusters: Number of clusters used by KMeans.
            device: Device specification handed to ``DeviceManager``.
            batch_size: Batch size handed to ``EmbeddingExtractor``.
            eps: Neighbourhood radius used by DBSCAN.
            min_samples: ``min_samples`` used by DBSCAN.
            threshold: Cosine-similarity cut-off. Within a cluster, a frame is
                kept only while its highest similarity to the frames already
                kept stays below this value.
        """
        # accept folder path of frames directly
        self.path = path
        self.model = model.lower()
        self.cluster_method = cluster.lower()
        self.n_clusters = n_clusters
        self.eps = eps
        self.min_samples = min_samples
        self.batch_size = batch_size
        self.threshold = threshold

        # prepare directories
        self.device = DeviceManager(device).get()
        # Resolve to an absolute path before taking the last component, so
        # relative inputs such as "." or "frames/.." yield a real folder name.
        # With normpath alone, ".." would place the cache root outside
        # .cleanframe_cache. Filesystem roots have no basename, hence the
        # fallback name.
        self.dataset_name = os.path.basename(os.path.abspath(path)) or "dataset"
        self.root = os.path.join(".cleanframe_cache", self.dataset_name)
        self.cleaned_dir = os.path.join(self.root, "cleaned")
        os.makedirs(self.cleaned_dir, exist_ok=True)

        self.cache = CacheManager(self.root)
        self.extractor = EmbeddingExtractor(self.device, self.batch_size)

        # placeholders, filled in by run()
        self.embeddings = None
        self.paths = None
        self.labels = None
        self.cleaned_paths = []
        self.removed_paths = []

    # ================================================================
    # Clustering and Cleaning
    # ================================================================

    def _cluster(self, embeddings):
        """Assign a cluster label to every embedding.

        Args:
            embeddings: Sequence/array with one embedding per frame.

        Returns:
            Array of integer labels, one per embedding. With DBSCAN, the label
            ``-1`` marks noise points. An empty input yields an empty array.

        Raises:
            ValueError: If the configured cluster method is neither
                ``"kmeans"`` nor ``"dbscan"``.
        """
        # Validate first so a typo in the method name fails before any work.
        if self.cluster_method not in ("kmeans", "dbscan"):
            raise ValueError("Cluster method must be kmeans or dbscan.")

        n_samples = len(embeddings)
        if n_samples == 0:
            # Nothing to cluster; both estimators reject empty input.
            return np.empty(0, dtype=int)

        console.print(f"[cyan]Clustering using {self.cluster_method.upper()}...[/cyan]")
        if self.cluster_method == "kmeans":
            # KMeans raises when asked for more clusters than samples, which
            # happens on small datasets; clamp instead of crashing.
            n_clusters = min(self.n_clusters, n_samples)
            if n_clusters < self.n_clusters:
                console.print(
                    f"[yellow]Only {n_samples} frames available; using "
                    f"{n_clusters} clusters instead of {self.n_clusters}.[/yellow]"
                )
            estimator = KMeans(n_clusters=n_clusters, random_state=42, n_init="auto")
        else:
            estimator = DBSCAN(eps=self.eps, min_samples=self.min_samples)
        return estimator.fit_predict(embeddings)

    def _clean(self, embeddings, paths, labels):
        """Drop near-duplicate frames inside each cluster and copy the rest.

        Frames of a cluster are visited in order. A frame is kept when its
        highest cosine similarity to the frames already kept from that cluster
        is below ``self.threshold``; otherwise it is recorded as removed.
        Kept frames are copied into ``self.cleaned_dir``.

        Args:
            embeddings: Array of embeddings, indexable by an index array.
            paths: Frame paths, aligned with ``embeddings``.
            labels: Cluster label per frame; ``-1`` marks unclustered frames.

        Side effects:
            Sets ``self.cleaned_paths`` / ``self.removed_paths``, copies files
            and prints the kept/removed counts.
        """
        labels = np.asarray(labels)
        unique_clusters = [c for c in np.unique(labels) if c != -1]
        kept, removed = [], []

        with progress_bar("Cleaning frames", len(unique_clusters)) as p:
            task = p.add_task("Clean", total=len(unique_clusters))
            for cid in unique_clusters:
                idx = np.where(labels == cid)[0]
                cluster_paths = [paths[i] for i in idx]
                cluster_embs = embeddings[idx]

                # Positions (within this cluster) of the frames kept so far.
                selected_idx = []
                for i, e in enumerate(cluster_embs):
                    if not selected_idx:
                        # The first frame of a cluster is always kept.
                        selected_idx.append(i)
                        kept.append(cluster_paths[i])
                        continue
                    sims = cosine_similarity(e.reshape(1, -1), cluster_embs[selected_idx])[0]
                    if np.max(sims) < self.threshold:
                        selected_idx.append(i)
                        kept.append(cluster_paths[i])
                    else:
                        removed.append(cluster_paths[i])
                p.update(task, advance=1)

        # Frames labelled -1 (DBSCAN noise) are not part of any cluster and are
        # not copied. Record them as removed so kept + removed always equals
        # the number of input frames instead of silently losing them.
        noise = [paths[i] for i in np.where(labels == -1)[0]]
        removed.extend(noise)

        # Copy kept frames. Frames from different folders can share a
        # basename; give later ones a numeric suffix so they do not overwrite
        # each other in the flat output directory.
        used_names = set()
        for f in kept:
            name = os.path.basename(f)
            if name in used_names:
                stem, ext = os.path.splitext(name)
                n = 1
                while f"{stem}_{n}{ext}" in used_names:
                    n += 1
                name = f"{stem}_{n}{ext}"
            used_names.add(name)
            shutil.copy(f, os.path.join(self.cleaned_dir, name))

        self.cleaned_paths = kept
        self.removed_paths = removed
        console.print(f"[bold cyan]Kept:[/bold cyan] {len(kept)}   [bold red]Removed:[/bold red] {len(removed)}")
        if noise:
            console.print(f"[yellow]Unclustered (noise) frames counted as removed: {len(noise)}[/yellow]")
        console.print(f"[yellow]Output → {self.cleaned_dir}[/yellow]")

    # ================================================================
    # Summary
    # ================================================================

    def _summary(self):
        """Print a table with the run configuration and kept/removed counts.

        Reads ``self.paths``, ``self.cleaned_paths`` and ``self.removed_paths``.
        Percentages are reported as 0.0 when there are no frames.
        """
        total = len(self.paths) if self.paths is not None else 0
        kept = len(self.cleaned_paths)
        removed = len(self.removed_paths)

        def pct(count):
            # Avoid ZeroDivisionError on an empty dataset.
            return (count / total) * 100 if total else 0.0

        table = Table(title="CLEANFRAMES SUMMARY", show_header=False, show_lines=False)
        table.add_row("Model", self.model.capitalize())
        table.add_row("Clustering", self.cluster_method.upper())
        table.add_row("Device", str(self.device))
        table.add_row("Threshold", str(self.threshold))
        table.add_row("------------------------------------", "")
        table.add_row("Total frames", str(total))
        table.add_row("Kept", f"{kept} ({pct(kept):.1f}%)")
        table.add_row("Removed", f"{removed} ({pct(removed):.1f}%)")
        table.add_row("------------------------------------", "")
        table.add_row("Output", self.cleaned_dir)
        table.add_row("Report", "cf.report() — detailed visual report")
        console.print(table)

    # ================================================================
    # Main pipeline
    # ================================================================

    def run(self):
        """Run the full pipeline: load, embed, cluster, clean, summarise.

        Populates ``self.paths``, ``self.embeddings`` and ``self.labels``;
        the cleaning step fills ``self.cleaned_paths`` / ``self.removed_paths``
        and copies the kept frames to ``self.cleaned_dir``.

        Raises:
            ValueError: If no images are found under ``self.path``.
        """
        console.rule(f"[bold green]CLEANFRAME PIPELINE — {self.dataset_name}")
        images, paths = load_images(self.path)
        # Fail early with a clear message: an empty dataset otherwise surfaces
        # as an obscure error in a later stage (clustering or the summary).
        if len(paths) == 0:
            raise ValueError(f"No images found in {self.path!r}.")
        self.paths = paths
        self.embeddings = self.extractor.get_embeddings(images, self.model)
        self.labels = self._cluster(self.embeddings)
        self._clean(self.embeddings, self.paths, self.labels)
        self._summary()
        console.rule("[bold green]Done[/bold green]")

    # ================================================================
    # Report (Notebook Visual)
    # ================================================================

    def report(self):
        """Show a visual report: up to three pairs of kept frames and totals.

        For each pair the two images are plotted side by side with the cosine
        similarity of their embeddings, followed by a textual footer with the
        kept/removed percentages.

        Raises:
            RuntimeError: If called before ``run()`` has produced results.
        """
        # Without results, len(self.paths) below would fail with an unhelpful
        # TypeError after a partial header had already been printed.
        if self.paths is None or self.embeddings is None:
            raise RuntimeError("Nothing to report yet: call run() before report().")

        console.print("[bold magenta]CleanFrames Detailed Report[/bold magenta]")
        console.print("---------------------------")
        console.print(f"Model: {self.model.capitalize()}")
        console.print(f"Clustering: {self.cluster_method.upper()}")
        console.print(f"Device: {self.device}")
        console.print(f"Threshold: {self.threshold}")
        console.print("------------------------------------")

        # Pick the frames to pair up: six random kept frames when available,
        # otherwise the kept frames in order. Consecutive frames form a pair;
        # a trailing unpaired frame is dropped by zip().
        candidates = self.cleaned_paths
        if len(candidates) >= 6:
            candidates = random.sample(candidates, 6)
        pairs = list(zip(candidates[0::2], candidates[1::2]))

        # visualize each set
        for i, (p1, p2) in enumerate(pairs[:3]):
            idx1, idx2 = self.paths.index(p1), self.paths.index(p2)
            sim = cosine_similarity(self.embeddings[idx1].reshape(1, -1),
                                    self.embeddings[idx2].reshape(1, -1))[0][0]
            # Context managers make sure the image files are closed again.
            with Image.open(p1) as img1, Image.open(p2) as img2:
                fig, axes = plt.subplots(1, 2, figsize=(6, 3))
                axes[0].imshow(img1)
                axes[1].imshow(img2)
                for ax in axes:
                    ax.axis("off")
                fig.suptitle(f"Set {i+1} — Similarity: {sim*100:.1f}%", fontsize=12)
                plt.show()

            console.print(f"Set {i+1}: Similarity {sim*100:.1f}%")
            console.print(f"  Image 1: {os.path.basename(p1)}")
            console.print(f"  Image 2: {os.path.basename(p2)}\n")

        # textual summary footer
        total = len(self.paths)
        kept = len(self.cleaned_paths)
        # Guard the empty-dataset case instead of dividing by zero.
        kept_pct = (kept / total) * 100 if total else 0.0
        removed_pct = 100 - kept_pct if total else 0.0

        console.print("Summary:")
        console.print(f"Cleaned with {self.model} embedding and {self.cluster_method}.")
        console.print(f"{kept_pct:.1f}% kept")
        console.print(f"{removed_pct:.1f}% removed")
        console.rule()