import os, random, shutil, hashlib
import numpy as np
import torch
import torch.nn as nn
from torchvision import models, transforms
from timm import create_model
import open_clip
from sklearn.cluster import KMeans, DBSCAN
from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image, ImageFile
from rich.console import Console
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn
import matplotlib.pyplot as plt

# Tell PIL to load image files whose data is truncated instead of raising,
# so partially written frames can still be decoded by load_images().
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Shared rich console used for status messages, summary tables and progress bars.
console = Console()

# ============================================================
# Core utilities
# ============================================================

def progress_bar(task_name, total):
    """Build a transient rich ``Progress`` display labelled with *task_name*.

    The display shows a spinner, the label, a bar, the percentage complete and
    the elapsed time, and is erased when its ``with`` block exits.

    ``total`` is accepted but not used here; callers register the total
    themselves through ``Progress.add_task``.
    """
    columns = (
        SpinnerColumn(),
        TextColumn(f"[bold blue]{task_name}[/bold blue]"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
    )
    return Progress(*columns, console=console, transient=True)


def device_select(device="auto"):
    """Resolve *device* to a ``torch.device``.

    ``"auto"`` prefers CUDA, then Apple MPS, then falls back to CPU. Any other
    value is handed to ``torch.device`` unchanged.
    """
    if device != "auto":
        return torch.device(device)

    if torch.cuda.is_available():
        chosen = "cuda"
    elif torch.backends.mps.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    return torch.device(chosen)


def load_images(folder, exts=(".jpg", ".jpeg", ".png")):
    """Load every image directly inside *folder* whose name ends with one of *exts*.

    Files are visited in sorted name order. ``os.listdir`` order is arbitrary,
    and the greedy de-duplication done later keeps whichever similar frame it
    sees first, so a fixed order makes results reproducible across machines.

    Each image is fully decoded and converted to RGB, and its file is closed
    immediately. Files that cannot be opened or decoded are skipped with a
    ``RuntimeWarning``, so one bad file does not abort the whole load.

    Args:
        folder: Directory to scan (not recursive).
        exts: Filename suffix, or tuple of suffixes, to accept (case-insensitive).

    Returns:
        ``(images, paths)``: parallel lists of RGB ``PIL.Image.Image`` objects
        and their ``os.path.join(folder, name)`` paths.
    """
    import warnings

    if isinstance(exts, str):
        exts = (exts,)
    suffixes = tuple(ext.lower() for ext in exts)

    images, paths = [], []
    for name in sorted(os.listdir(folder)):
        if not name.lower().endswith(suffixes):
            continue
        path = os.path.join(folder, name)
        try:
            # convert() decodes into a new image, so the file can be closed here.
            with Image.open(path) as im:
                rgb = im.convert("RGB")
        except Exception as err:  # best effort: skip anything PIL cannot read
            warnings.warn(
                f"Skipping unreadable image {path!r}: {err}",
                RuntimeWarning,
                stacklevel=2,
            )
            continue
        images.append(rgb)
        paths.append(path)
    return images, paths


# ============================================================
# CleanFrame main class
# ============================================================

class CleanFrame:
    """Remove near-duplicate frames from a folder of images.

    Pipeline (see ``run``):
    1. Embed every frame with a pretrained backbone.
    2. Cluster the embeddings.
    3. Inside each cluster, keep a frame only if its cosine similarity to every
       frame already kept from that cluster is below ``threshold``.

    Args:
        path: Folder of ``.jpg``/``.jpeg``/``.png`` frames (read non-recursively).
        model: ``"swin"``, ``"clip"`` or ``"resnet"`` (case-insensitive).
        cluster: ``"kmeans"`` or ``"dbscan"`` (case-insensitive).
        n_clusters: KMeans cluster count; capped at the number of frames.
        device: ``"auto"`` or any value accepted by ``torch.device``.
        batch_size: Frames per forward pass.
        eps: DBSCAN ``eps`` (neighbourhood radius).
        min_samples: DBSCAN ``min_samples``.
        threshold: Cosine similarity at or above which a frame counts as a
            duplicate of an already-kept frame in the same group.

    Results are written below ``.cleanframe_cache/<folder name>/`` (relative to
    the current working directory):
    - kept frames are copied into ``cleaned/``;
    - removed frames are copied next to it as ``removed_<file name>``.
    """

    # Backbones accepted by ``embed_frames``.
    _MODEL_NAMES = ("swin", "clip", "resnet")

    def __init__(self, path, model="swin", cluster="kmeans", n_clusters=5,
                 device="auto", batch_size=32, eps=0.5, min_samples=5, threshold=0.95):

        self.path = path
        self.model_name = model.lower()
        self.cluster_method = cluster.lower()
        self.n_clusters = n_clusters
        self.batch_size = batch_size
        self.eps = eps
        self.min_samples = min_samples
        self.threshold = threshold

        self.device = device_select(device)
        self.embeddings = None
        self.paths = None
        self.labels = None
        self.cleaned_paths = []
        self.removed_paths = []

        # Outputs are namespaced by the input folder's name inside the cache.
        name = os.path.basename(os.path.normpath(path))
        self.root = os.path.join(".cleanframe_cache", name)
        self.cleaned_dir = os.path.join(self.root, "cleaned")
        os.makedirs(self.cleaned_dir, exist_ok=True)

        console.print(f"[bold cyan]Device:[/bold cyan] {self.device}")

    # ============================================================
    # Embedding extraction
    # ============================================================

    def _build_model(self):
        """Instantiate the backbone named by ``self.model_name``.

        Returns ``(model, preprocess)``. ``preprocess`` is CLIP's own image
        transform, or ``None`` for backbones that use the ImageNet transform.

        Raises:
            ValueError: if ``self.model_name`` is not a supported backbone.
        """
        if self.model_name == "resnet":
            # Prefer the weights-enum API (torchvision >= 0.13), where
            # ``pretrained=True`` is deprecated. IMAGENET1K_V1 is the checkpoint
            # that ``pretrained=True`` selected, so the features are unchanged.
            weights_enum = getattr(models, "ResNet50_Weights", None)
            if weights_enum is None:
                backbone = models.resnet50(pretrained=True)
            else:
                backbone = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
            # Drop the final fully connected layer to expose pooled features.
            return nn.Sequential(*list(backbone.children())[:-1]), None
        if self.model_name == "swin":
            # num_classes=0 asks timm for pooled features instead of logits.
            model = create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=0)
            return model, None
        if self.model_name == "clip":
            model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
            return model, preprocess
        raise ValueError("model must be one of ['swin','clip','resnet'].")

    def embed_frames(self):
        """Embed every frame in ``self.path``.

        Sets ``self.paths`` and ``self.embeddings``. ``self.embeddings`` is a
        numpy array with one L2-normalised row per entry of ``self.paths``.

        Returns:
            ``(embeddings, paths)``.

        Raises:
            ValueError: if the model name is unsupported, or no readable images
                are found in the folder.
        """
        # Validate before the (potentially slow) image load.
        if self.model_name not in self._MODEL_NAMES:
            raise ValueError("model must be one of ['swin','clip','resnet'].")

        images, paths = load_images(self.path)
        self.paths = paths
        if not images:
            # Otherwise np.concatenate([]) below fails with an opaque error.
            raise ValueError(f"No readable .jpg/.jpeg/.png images found in {self.path!r}.")

        model, preprocess = self._build_model()
        model.to(self.device).eval()
        if preprocess is None:
            preprocess = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        # CLIP exposes image features through encode_image; the others through forward().
        encode = model.encode_image if self.model_name == "clip" else model

        embs = []
        with torch.inference_mode(), progress_bar(f"Embedding ({self.model_name})", len(images)) as p:
            task = p.add_task("Embedding", total=len(images))
            for start in range(0, len(images), self.batch_size):
                batch_imgs = images[start:start + self.batch_size]
                batch = torch.stack([preprocess(im) for im in batch_imgs]).to(self.device)
                # flatten(1) turns ResNet's pooled (B, C, 1, 1) output into (B, C)
                # and is a no-op on (B, D). Unlike squeeze(), it never drops the
                # batch axis when B == 1.
                feat = encode(batch).flatten(1)
                feat = feat / feat.norm(dim=-1, keepdim=True)
                embs.append(feat.cpu().numpy())
                p.update(task, advance=len(batch_imgs))

        self.embeddings = np.concatenate(embs, axis=0)
        return self.embeddings, self.paths

    # ============================================================
    # Clustering
    # ============================================================

    def cluster_frames(self):
        """Assign a cluster label to every embedding; store and return the labels.

        KMeans labels are ``0..k-1``. DBSCAN may additionally label outliers ``-1``.

        Raises:
            ValueError: if ``self.cluster_method`` is not ``kmeans`` or ``dbscan``.
        """
        console.print(f"[cyan]Clustering using {self.cluster_method.upper()}...[/cyan]")
        emb = self.embeddings
        if self.cluster_method == "kmeans":
            # KMeans raises when asked for more clusters than there are samples.
            k = min(self.n_clusters, len(emb))
            if k < self.n_clusters:
                console.print(f"[yellow]Only {len(emb)} frames; using {k} clusters.[/yellow]")
            self.labels = KMeans(n_clusters=k, random_state=42, n_init="auto").fit_predict(emb)
        elif self.cluster_method == "dbscan":
            self.labels = DBSCAN(eps=self.eps, min_samples=self.min_samples).fit_predict(emb)
        else:
            raise ValueError("cluster must be 'kmeans' or 'dbscan'")
        return self.labels

    # ============================================================
    # Cleaning
    # ============================================================

    @staticmethod
    def _greedy_dedup(embs, threshold):
        """Greedily split the row indices of *embs* into ``(kept, dropped)``.

        Rows are visited in order. A row is kept when its cosine similarity to
        every previously kept row is below *threshold*; otherwise it is dropped.
        All-zero rows have similarity 0 to everything.
        """
        embs = np.asarray(embs, dtype=float)
        norms = np.linalg.norm(embs, axis=1, keepdims=True)
        unit = embs / np.where(norms == 0, 1.0, norms)
        kept, dropped = [], []
        for i, vec in enumerate(unit):
            if kept and np.max(unit[kept] @ vec) >= threshold:
                dropped.append(i)
            else:
                kept.append(i)
        return kept, dropped

    def _reset_outputs(self):
        """Delete outputs of a previous ``clean_frames`` call so results reflect this run only."""
        root = os.path.realpath(self.root)
        src = os.path.realpath(self.path)
        if src == root or src.startswith(root + os.sep):
            # The input folder lives inside the cache: never delete anything there.
            return
        shutil.rmtree(self.cleaned_dir, ignore_errors=True)
        os.makedirs(self.cleaned_dir, exist_ok=True)
        for name in os.listdir(self.root):
            stale = os.path.join(self.root, name)
            if name.startswith("removed_") and os.path.isfile(stale):
                os.remove(stale)

    def clean_frames(self):
        """Deduplicate frames within each cluster and write the results to disk.

        Within each cluster, frames are visited in order and kept only if their
        cosine similarity to every frame already kept from that cluster is
        below ``self.threshold``. DBSCAN noise frames (label ``-1``) are
        deduplicated as one extra group, so every frame ends up either kept or
        removed.

        Returns:
            ``(kept, removed)``: lists of source paths. They are also stored on
            ``self.cleaned_paths`` and ``self.removed_paths``.
        """
        console.print("[cyan]Cleaning frames...[/cyan]")
        emb, paths, labels = self.embeddings, self.paths, self.labels
        group_ids = [c for c in np.unique(labels) if c != -1]
        if np.any(labels == -1):
            group_ids.append(-1)  # noise last, after the real clusters

        self._reset_outputs()
        kept, removed = [], []
        with progress_bar("Cleaning", len(group_ids)) as p:
            task = p.add_task("clean", total=len(group_ids))
            for gid in group_ids:
                idx = np.flatnonzero(labels == gid)
                keep_local, drop_local = self._greedy_dedup(emb[idx], self.threshold)
                kept.extend(paths[idx[j]] for j in keep_local)
                removed.extend(paths[idx[j]] for j in drop_local)
                p.update(task, advance=1)

        for f in removed:
            shutil.copy(f, os.path.join(self.root, f"removed_{os.path.basename(f)}"))
        for f in kept:
            shutil.copy(f, os.path.join(self.cleaned_dir, os.path.basename(f)))

        self.cleaned_paths = kept
        self.removed_paths = removed
        return kept, removed

    # ============================================================
    # Summary
    # ============================================================

    @staticmethod
    def _pct(part, total):
        """Return ``part`` as a percentage of ``total``; 0.0 when ``total`` is 0."""
        return (part / total) * 100 if total else 0.0

    def _summary_table(self):
        """Print a rich table with the run configuration and kept/removed counts."""
        total = len(self.paths)
        kept = len(self.cleaned_paths)
        removed = len(self.removed_paths)
        model_label = "Swin Transformer" if self.model_name == "swin" else self.model_name.capitalize()
        separator = "------------------------------------"

        table = Table(title="CLEANFRAMES SUMMARY", show_header=False, show_lines=False)
        table.add_row("Model", model_label)
        table.add_row("Clustering", self.cluster_method.upper())
        table.add_row("Device", str(self.device))
        table.add_row("Threshold", str(self.threshold))
        table.add_row(separator, "")
        table.add_row("Total frames", str(total))
        table.add_row("Kept", f"{kept} ({self._pct(kept, total):.1f}%)")
        table.add_row("Removed", f"{removed} ({self._pct(removed, total):.1f}%)")
        table.add_row(separator, "")
        table.add_row("Output", self.cleaned_dir)
        table.add_row("Report", f"{self.root}/report (use cf.report())")
        console.print(table)

    def run(self):
        """Run the full pipeline: embed, cluster, clean, then print the summary."""
        console.rule("[bold green]CLEANFRAME PIPELINE")
        self.embed_frames()
        self.cluster_frames()
        self.clean_frames()
        self._summary_table()
        console.rule("[bold green]Done[/bold green]")

    # ============================================================
    # Detailed Report (Notebook visuals)
    # ============================================================

    @staticmethod
    def _sample_pairs(items, n_pairs=3):
        """Return up to *n_pairs* random, non-overlapping pairs drawn from *items*."""
        k = min(n_pairs, len(items) // 2) * 2
        picks = random.sample(list(items), k)
        return list(zip(picks[0::2], picks[1::2]))

    def report(self):
        """Show a notebook-oriented report of the last run.

        Prints the configuration, displays up to three random pairs of kept
        frames side by side with their cosine similarity (via ``plt.show``),
        and prints the kept/removed percentages.
        """
        console.print("[bold magenta]CleanFrames Detailed Report[/bold magenta]")
        console.print("---------------------------")
        console.print(f"Model: {self.model_name.capitalize()}")
        console.print(f"Clustering: {self.cluster_method.upper()}")
        console.print(f"Device: {self.device}")
        console.print(f"Threshold: {self.threshold}")
        console.print("------------------------------------")

        # Up to 3 random, non-overlapping pairs of *kept* frames. With fewer
        # than 6 kept frames we take as many pairs as fit instead of crashing.
        pairs = self._sample_pairs(self.cleaned_paths, 3)
        position = {p: i for i, p in enumerate(self.paths)}

        for i, (p1, p2) in enumerate(pairs):
            sim = cosine_similarity(self.embeddings[position[p1]].reshape(1, -1),
                                    self.embeddings[position[p2]].reshape(1, -1))[0][0]
            # Decode to RGB arrays now so the files are closed immediately and
            # grayscale frames are not rendered with a false-colour colormap.
            with Image.open(p1) as im1, Image.open(p2) as im2:
                arr1 = np.asarray(im1.convert("RGB"))
                arr2 = np.asarray(im2.convert("RGB"))
            fig, axes = plt.subplots(1, 2, figsize=(6, 3))
            axes[0].imshow(arr1)
            axes[1].imshow(arr2)
            for ax in axes:
                ax.axis("off")
            fig.suptitle(f"Set {i+1} — Similarity: {sim*100:.1f}%", fontsize=12)
            plt.show()
            plt.close(fig)
            console.print(f"Set {i+1}: Similarity {sim*100:.1f}%")
            console.print(f"  Image 1: {os.path.basename(p1)}")
            console.print(f"  Image 2: {os.path.basename(p2)}\n")

        # Use the recorded counts, so these agree with _summary_table.
        total = len(self.paths)
        kept_pct = self._pct(len(self.cleaned_paths), total)
        removed_pct = self._pct(len(self.removed_paths), total)
        console.print("Summary:")
        console.print(f"Cleaned with {self.model_name} embedding and {self.cluster_method}.")
        console.print(f"{kept_pct:.1f}% kept")
        console.print(f"{removed_pct:.1f}% removed")
        console.rule()