import os, shutil, hashlib, random, json
import numpy as np, pandas as pd, torch, torch.nn as nn
from tqdm.notebook import tqdm
from PIL import Image, ImageFile
from torchvision import models, transforms
from transformers import AutoImageProcessor, AutoModel
from timm import create_model
import open_clip
from sklearn.cluster import KMeans, DBSCAN
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import imagehash
import matplotlib.pyplot as plt
from IPython import get_ipython
from tabulate import tabulate
from datetime import datetime

# Optional dependency, only needed by cluster_embeddings(method="hdbscan").
# The module-level name stays None when the package is not installed.
hdbscan = None
try:
    import hdbscan
except ImportError:
    pass

# Ask PIL to load truncated image files instead of raising on them.
setattr(ImageFile, "LOAD_TRUNCATED_IMAGES", True)


class CleanFrame:
    """
    =============================================================
    CleanFrame: Professional Image Deduplication & Clustering Tool
    =============================================================
    Performs:
      1. MD5 exact-duplicate removal
      2. Perceptual-hash near-duplicate removal
      3. Deep-embedding semantic cleaning
      4. Embedding clustering (KMeans, DBSCAN, HDBSCAN)
      5. TSNE/PCA visualization
      6. Smart caching + Summary reports

    Results (embeddings, cluster labels, 2-D projections) are cached as
    compressed ``.npz`` files under ``.cleanframe_cache/`` in the current
    working directory. They are keyed by the absolute folder path and the
    settings that produced them. The cache is NOT invalidated when images
    are added to or removed from a folder; delete the directory to force
    recomputation.
    """

    # Lower-case file extensions treated as images in every folder scan.
    _IMAGE_EXTS = ('.jpg', '.jpeg', '.png')
    # Directory (relative to the CWD) that holds all cached .npz files.
    _CACHE_DIR = ".cleanframe_cache"
    # Public ``embeddings=`` names -> extractor method names.
    _EXTRACTORS = {
        "clip": "CLIPEmbedding",
        "swin": "SwinEmbedding",
        "resnet": "ResNetEmbedding",
        "dino": "DINOEmbedding",
    }
    # Supported ``method=`` values for cluster_embeddings().
    _CLUSTER_METHODS = ("kmeans", "dbscan", "hdbscan")
    # Normalisation statistics used by the ResNet and Swin transforms.
    _IMAGENET_MEAN = [0.485, 0.456, 0.406]
    _IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(self, device: str = "cpu"):
        """
        Args:
            device: torch device string such as "cpu", "cuda" or "cuda:0".
                A CUDA device falls back to the CPU (with a warning) when
                CUDA is not available.
        """
        self.device = self._resolve_device(device)
        # None outside IPython/Jupyter; decides whether plots are shown or closed.
        self._ipython_env = get_ipython()
        os.makedirs(self._CACHE_DIR, exist_ok=True)

    # ==========================================================
    # Utility and Cache Helpers
    # ==========================================================

    @staticmethod
    def _resolve_device(device):
        """Return ``torch.device(device)``, or the CPU if CUDA was requested but is missing."""
        requested = torch.device(device)
        if requested.type == "cuda" and not torch.cuda.is_available():
            print(f"[Warning] CUDA device '{device}' requested but CUDA is not available; using CPU.")
            return torch.device("cpu")
        return requested

    def _list_images(self, folder):
        """Return the paths of *folder*'s image files (non-recursive), sorted by name."""
        return [os.path.join(folder, name) for name in sorted(os.listdir(folder))
                if name.lower().endswith(self._IMAGE_EXTS)]

    def _safe_open(self, path):
        """Open *path* as an RGB image, or print a warning and return None on failure."""
        try:
            # The context manager releases the file handle; convert() returns a
            # fully loaded image that no longer depends on it.
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception as e:
            print(f"[Warning] Skipping {os.path.basename(path)} ({e})")
            return None

    def _cache_path(self, folder, key):
        """Return the .npz cache file used for (*folder*, *key*).

        The folder is normalised with ``os.path.abspath`` and identified by a
        short digest of that absolute path. As a result "frames", "frames/"
        and "./frames" share one cache, while distinct folders never do.
        """
        abs_folder = os.path.abspath(folder)
        digest = hashlib.md5(os.fsencode(abs_folder)).hexdigest()[:10]
        label = os.path.basename(abs_folder) or "root"
        # Keep user-supplied keys (e.g. visualisation cache keys) inside the cache dir.
        safe_key = str(key).replace("/", "_").replace("\\", "_")
        return os.path.join(self._CACHE_DIR, f"{label}_{digest}_{safe_key}.npz")

    def _save_cache(self, folder, key, **kwargs):
        """Best-effort: store the given arrays for (*folder*, *key*); warn on failure."""
        try:
            np.savez_compressed(self._cache_path(folder, key), **kwargs)
        except Exception as e:
            print(f"[Cache Warning] Could not save cache for {key}: {e}")

    def _load_cache(self, folder, key):
        """Return ``{name: array}`` from the cache for (*folder*, *key*), or None.

        A missing or unreadable cache file yields None, so the caller recomputes.
        """
        path = self._cache_path(folder, key)
        if not os.path.exists(path):
            return None
        try:
            # allow_pickle=False: caches only hold numeric / fixed-width string
            # arrays, so never unpickle arbitrary objects from disk. The `with`
            # closes the underlying zip file handle.
            with np.load(path, allow_pickle=False) as data:
                loaded = {name: data[name] for name in data.files}
        except Exception as e:
            print(f"[Cache Warning] Failed to load cache for {key}: {e}")
            return None
        print(f"[Cache] Loaded {key} from cache.")
        return loaded

    # ==========================================================
    # Embedding Extractors
    # ==========================================================

    @classmethod
    def _imagenet_transform(cls):
        """Resize to 224x224, convert to a tensor and normalise with ImageNet statistics."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=cls._IMAGENET_MEAN, std=cls._IMAGENET_STD),
        ])

    def _extract(self, folder, key, desc_tag, err_tag, build_encoder):
        """Embed every image in *folder* with a lazily built encoder, using the cache.

        Args:
            folder: directory scanned (non-recursively) for image files.
            key: cache key, e.g. "clip".
            desc_tag: label shown in the progress bar.
            err_tag: label used in per-image error messages.
            build_encoder: zero-argument callable, invoked only on a cache miss.
                It returns a function that maps a PIL RGB image to a 1-D,
                L2-normalised numpy vector.

        Returns:
            ``(embeddings, paths)`` where ``embeddings[i]`` is the vector for
            ``paths[i]``. Images that cannot be opened or embedded are omitted
            from both.
        """
        cache = self._load_cache(folder, key)
        if cache is not None and "embeddings" in cache and "paths" in cache:
            return cache["embeddings"], list(cache["paths"])

        encode = build_encoder()
        embs, paths = [], []
        # Inference only: no autograd graph is needed for any backbone.
        with torch.no_grad():
            for p in tqdm(self._list_images(folder), desc=f"[{desc_tag}] Embeddings"):
                img = self._safe_open(p)
                if img is None:
                    continue
                try:
                    embs.append(encode(img))
                except Exception as e:
                    print(f"[{err_tag} Error] {os.path.basename(p)}: {e}")
                    continue
                paths.append(p)

        embs = np.array(embs)
        self._save_cache(folder, key, embeddings=embs, paths=np.array(paths))
        return embs, paths

    def ResNetEmbedding(self, folder):
        """Return ``(embeddings, paths)`` from ImageNet ResNet-50 with its final fc layer removed."""
        def build():
            weights_enum = getattr(models, "ResNet50_Weights", None)
            if weights_enum is not None:
                # IMAGENET1K_V1 is the checkpoint the legacy `pretrained=True` loaded.
                backbone = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
            else:  # torchvision < 0.13 only understands `pretrained=`
                backbone = models.resnet50(pretrained=True)
            # Drop the last child (the classifier) so the pooled features are returned.
            model = nn.Sequential(*list(backbone.children())[:-1]).to(self.device).eval()
            tf = self._imagenet_transform()

            def encode(img):
                x = tf(img).unsqueeze(0).to(self.device)
                feat = model(x).squeeze().cpu().numpy()
                return feat / np.linalg.norm(feat)
            return encode

        return self._extract(folder, "resnet", "ResNet", "ResNet", build)

    def SwinEmbedding(self, folder):
        """Return ``(embeddings, paths)`` from timm's Swin-Tiny with the classification head removed."""
        def build():
            # num_classes=0 asks timm for the model without its classifier head.
            model = create_model('swin_tiny_patch4_window7_224', pretrained=True, num_classes=0)
            model.to(self.device).eval()
            tf = self._imagenet_transform()

            def encode(img):
                x = tf(img).unsqueeze(0).to(self.device)
                feat = model(x).squeeze().cpu().numpy()
                return feat / np.linalg.norm(feat)
            return encode

        return self._extract(folder, "swin", "Swin", "Swin", build)

    def CLIPEmbedding(self, folder):
        """Return ``(embeddings, paths)`` from the OpenAI CLIP ViT-B/32 image encoder."""
        def build():
            model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
            model.to(self.device).eval()

            def encode(img):
                x = preprocess(img).unsqueeze(0).to(self.device)
                feat = model.encode_image(x)
                feat = feat / feat.norm(dim=-1, keepdim=True)
                return feat.cpu().numpy().flatten()
            return encode

        return self._extract(folder, "clip", "CLIP", "CLIP", build)

    def DINOEmbedding(self, folder):
        """Return ``(embeddings, paths)`` from DINOv2-base, mean-pooled over the token axis."""
        def build():
            model_id = "facebook/dinov2-base"
            processor = AutoImageProcessor.from_pretrained(model_id)
            model = AutoModel.from_pretrained(model_id).to(self.device).eval()

            def encode(img):
                x = processor(images=img, return_tensors="pt").to(self.device)
                feat = model(**x).last_hidden_state.mean(dim=1)
                feat = feat / feat.norm(dim=-1, keepdim=True)
                return feat.cpu().numpy().flatten()
            return encode

        return self._extract(folder, "dino", "DINOv2", "DINO", build)

    def _embed(self, folder, embeddings):
        """Dispatch to the extractor named *embeddings*; return ``(embeddings, paths)``.

        Raises:
            ValueError: *embeddings* is not one of the supported names.
        """
        method_name = self._EXTRACTORS.get(embeddings)
        if method_name is None:
            raise ValueError("Embedding must be one of ['clip','swin','resnet','dino'].")
        return getattr(self, method_name)(folder)

    # ==========================================================
    # Clustering (cached)
    # ==========================================================

    @staticmethod
    def _cluster_cache_key(embeddings, method, n_clusters, eps, min_samples):
        """Cache key that includes every hyper-parameter the chosen *method* uses."""
        if method == "kmeans":
            params = f"k{n_clusters}"
        elif method == "dbscan":
            params = f"eps{eps}_ms{min_samples}"
        else:  # hdbscan uses min_samples as min_cluster_size
            params = f"mcs{min_samples}"
        return f"{embeddings}_{method}_{params}"

    def cluster_embeddings(self, path, embeddings="clip", method="kmeans",
                           n_clusters=5, eps=0.5, min_samples=5):
        """Cluster the images in *path* by their embeddings.

        Args:
            path: folder of images.
            embeddings: one of 'clip', 'swin', 'resnet', 'dino'.
            method: 'kmeans' (uses n_clusters), 'dbscan' (uses eps, min_samples)
                or 'hdbscan' (uses min_samples as min_cluster_size).

        Returns:
            ``(labels, embeddings, paths)``, all aligned row by row. DBSCAN and
            HDBSCAN label noise points -1.

        Raises:
            ValueError: unknown *embeddings* or *method*.
            ImportError: 'hdbscan' requested but the package is not installed.
            RuntimeError: no image in *path* could be embedded.
        """
        # Validate cheaply before any expensive embedding work.
        if embeddings not in self._EXTRACTORS:
            raise ValueError("Embedding must be one of ['clip','swin','resnet','dino'].")
        if method not in self._CLUSTER_METHODS:
            raise ValueError("Method must be one of ['kmeans','dbscan','hdbscan'].")

        cache_key = self._cluster_cache_key(embeddings, method, n_clusters, eps, min_samples)
        cache = self._load_cache(path, cache_key)
        if cache is not None and all(k in cache for k in ("labels", "embeddings", "paths")):
            return cache["labels"], cache["embeddings"], list(cache["paths"])

        if method == "hdbscan" and hdbscan is None:
            raise ImportError("Install hdbscan first: pip install hdbscan")

        print("================================")
        print(f"       {embeddings.upper()} + {method.upper()}")
        print("================================")

        emb, paths = self._embed(path, embeddings)
        if len(emb) == 0:
            raise RuntimeError("No embeddings generated (empty folder or all images failed).")

        if method == "kmeans":
            labels = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto').fit_predict(emb)
        elif method == "dbscan":
            labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(emb)
        else:
            labels = hdbscan.HDBSCAN(min_cluster_size=min_samples).fit_predict(emb)

        self._save_cache(path, cache_key, labels=labels, embeddings=emb, paths=np.array(paths))
        return labels, emb, paths

    # ==========================================================
    # Visualization (cached)
    # ==========================================================

    @staticmethod
    def _tsne_perplexity(n_samples):
        """Return 30, lowered to ``n_samples - 1`` for small inputs.

        scikit-learn's TSNE requires perplexity < n_samples.
        """
        return min(30, max(n_samples - 1, 1))

    def _reduce_2d(self, embeddings, method):
        """Project *embeddings* to 2-D with t-SNE (method == "tsne") or PCA (anything else)."""
        if method == "tsne":
            reducer = TSNE(n_components=2, perplexity=self._tsne_perplexity(len(embeddings)),
                           random_state=42)
        else:
            reducer = PCA(n_components=2)
        return reducer.fit_transform(embeddings)

    def visualize_clusters(self, embeddings, labels, paths,
                           title="Embedding Clusters", method="tsne", save=False, cache_key=None):
        """Scatter-plot a 2-D projection of *embeddings*, coloured by *labels*.

        When *cache_key* is given, the projection is cached next to the other
        caches for the folder of ``paths[0]``, separately per *method*. A cached
        projection is reused only if its row count matches *embeddings*. With
        ``save=True`` the figure is written to ``cluster_results/<title>.png``.
        """
        embeddings = np.asarray(embeddings)
        emb2d = None
        cache_folder = os.path.dirname(paths[0]) if (cache_key and len(paths) > 0) else None
        vis_key = f"vis_{cache_key}_{method}"

        if cache_folder is not None:
            cache = self._load_cache(cache_folder, vis_key)
            if cache is not None and "emb2d" in cache and len(cache["emb2d"]) == len(embeddings):
                emb2d = cache["emb2d"]
                print(f"[Cache] Loaded visualization for {cache_key}")
            else:
                print(f"[Compute] Generating visualization for {cache_key}")

        if emb2d is None:
            emb2d = self._reduce_2d(embeddings, method)
            if cache_folder is not None:
                self._save_cache(cache_folder, vis_key, emb2d=emb2d)

        plt.figure(figsize=(7, 6))
        plt.scatter(emb2d[:, 0], emb2d[:, 1], c=labels, cmap='tab10', s=10)
        plt.title(title)
        plt.axis('off')
        if save:
            os.makedirs("cluster_results", exist_ok=True)
            out = os.path.join("cluster_results", f"{title.replace(' ', '_')}.png")
            plt.savefig(out, bbox_inches='tight')
            print(f"Saved visualization → {out}")
        if self._ipython_env:
            plt.show()
        else:
            plt.close()

    # ==========================================================
    # Cleaning + Summary Report
    # ==========================================================

    def _dedupe_group(self, indices, paths, embs, threshold, phash_threshold):
        """Return the indices (into *paths* / *embs*) that survive all three stages."""
        # Stage 1: exact byte-level duplicates.
        seen_digests, stage1 = set(), []
        for i in indices:
            try:
                with open(paths[i], "rb") as fh:
                    digest = hashlib.md5(fh.read()).hexdigest()
            except OSError as e:
                print(f"[Warning] Skipping {os.path.basename(paths[i])} ({e})")
                continue
            if digest not in seen_digests:
                seen_digests.add(digest)
                stage1.append(i)

        # Stage 2: perceptual near-duplicates. Subtracting two imagehash values
        # gives their bit difference; keep an image only if it is far from all
        # previously kept ones.
        kept_hashes, stage2 = [], []
        for i in stage1:
            try:
                with Image.open(paths[i]) as im:
                    h = imagehash.phash(im)
            except Exception as e:  # PIL / imagehash raise a variety of error types
                print(f"[Warning] Skipping {os.path.basename(paths[i])} ({e})")
                continue
            if all(h - other >= phash_threshold for other in kept_hashes):
                kept_hashes.append(h)
                stage2.append(i)

        # Stage 3: semantic near-duplicates via cosine similarity of embeddings.
        if not stage2:
            return []
        kept = [stage2[0]]
        for i in stage2[1:]:
            sims = cosine_similarity(embs[i].reshape(1, -1), embs[kept])[0]
            if np.max(sims) < threshold:
                kept.append(i)
        return kept

    @staticmethod
    def _print_report(report_rows, total_kept, total_removed, n_noise, output_root, elapsed):
        """Print the per-cluster table and the overall summary."""
        print("\n========== CLEANFRAME REPORT ==========")
        headers = ["Cluster ID", "Images", "Kept", "Removed", "Kept %", "Removed %"]
        print(tabulate(report_rows, headers=headers, tablefmt="fancy_grid"))
        total_imgs = total_kept + total_removed
        # Guard: no cluster may have been processed (empty folder / all noise).
        kept_pct = (total_kept / total_imgs) * 100 if total_imgs else 0.0
        removed_pct = (total_removed / total_imgs) * 100 if total_imgs else 0.0
        print("\n--------------------------------------")
        print(f" Total Images      : {total_imgs}")
        print(f" Kept              : {total_kept} ({kept_pct:.2f}%)")
        print(f" Removed           : {total_removed} ({removed_pct:.2f}%)")
        if n_noise:
            print(f" Noise (skipped)   : {n_noise}")
        print(f" Output Directory  : {output_root}")
        print(f" Time Elapsed      : {elapsed}")
        print("--------------------------------------\n")

    def cleanframe(self, path, embeddings='clip', cluster=None,
                   threshold=0.95, output_root="frames_cleaned", phash_threshold=5):
        """Deduplicate the images in *path* and copy the survivors to *output_root*.

        Each cluster is cleaned independently, in three stages:
          1. exact duplicates (identical MD5 of the file bytes) are dropped;
          2. near duplicates whose pHash distance to an already kept image is
             below ``phash_threshold`` are dropped;
          3. an image is kept only if the cosine similarity of its embedding to
             every image kept so far in the cluster is below ``threshold``.

        Args:
            path: folder of images.
            embeddings: one of 'clip', 'swin', 'resnet', 'dino'.
            cluster: optional per-image labels aligned with the path order
                returned by the chosen extractor, e.g. the ``labels`` from
                ``cluster_embeddings(path, embeddings)``. Label -1 marks noise;
                those images are skipped (neither copied nor counted as
                removed). ``None`` treats all images as one cluster.
            threshold: cosine-similarity cut-off for stage 3.
            output_root: destination folder (created if missing).
            phash_threshold: perceptual-hash distance cut-off for stage 2.

        Returns:
            dict with:
              - "kept": copied source paths, in processing order;
              - "removed": number of removed images;
              - "clusters": number of non-noise clusters processed;
              - "noise": number of images labelled -1;
              - "report": the per-cluster rows.

        Raises:
            ValueError: unknown *embeddings*, or *cluster* length differs from
                the number of embedded images.
        """
        start = datetime.now()
        print(f"Found {len(self._list_images(path))} images in {path}")

        # Use the extractor's own path list, so row i of `embs` always describes
        # paths[i]. Re-listing the folder would misalign rows whenever the order
        # differs or an image failed to embed.
        embs, paths = self._embed(path, embeddings)
        paths = [str(p) for p in paths]

        if cluster is None:
            labels = np.zeros(len(paths), dtype=int)
        else:
            labels = np.asarray(cluster)
            if len(labels) != len(paths):
                raise ValueError(
                    f"`cluster` has {len(labels)} labels but {len(paths)} images were embedded; "
                    "pass the labels returned by cluster_embeddings() for the same folder "
                    "and embedding type."
                )

        os.makedirs(output_root, exist_ok=True)
        cluster_ids = [c for c in np.unique(labels) if c != -1]
        n_noise = int(np.sum(labels == -1))
        kept_global, kept_seen = [], set()
        total_kept, total_removed = 0, 0
        report_rows = []

        for cid in cluster_ids:
            idx = np.flatnonzero(labels == cid)
            if len(idx) == 0:
                continue
            kept_idx = self._dedupe_group(idx, paths, embs, threshold, phash_threshold)

            for i in kept_idx:
                p = paths[i]
                if p not in kept_seen:
                    shutil.copy(p, os.path.join(output_root, os.path.basename(p)))
                    kept_seen.add(p)
                    kept_global.append(p)

            kept_count, total_count = len(kept_idx), len(idx)
            removed_count = total_count - kept_count
            kept_pct = (kept_count / total_count) * 100
            removed_pct = (removed_count / total_count) * 100
            report_rows.append([cid, total_count, kept_count, removed_count,
                                f"{kept_pct:.1f}%", f"{removed_pct:.1f}%"])
            total_kept += kept_count
            total_removed += removed_count

        self._print_report(report_rows, total_kept, total_removed, n_noise,
                           output_root, datetime.now() - start)

        return {
            "kept": kept_global,
            "removed": total_removed,
            "clusters": len(cluster_ids),
            "noise": n_noise,
            "report": report_rows,
        }