import os, torch, numpy as np
from tqdm import tqdm
from PIL import Image
from torchvision import models, transforms
from transformers import AutoImageProcessor, AutoModel
import open_clip
from timm import create_model

try:
    import onnxruntime as ort
except ImportError:
    ort = None


class CleanFrame_optimized:
    """
    =============================================================
    CleanFrame_optimized: High-Performance Cleaner for Large Datasets
    =============================================================
    Wraps a ``cleanframes.core.CleanFrame`` instance and adds faster
    embedding extractors:
      • Batched embedding extraction (``batch_size`` images per forward pass)
      • Mixed precision (FP16) when ``fast_mode`` is on and the device is CUDA/MPS
      • Optional ONNXRuntime inference for the ResNet50 extractor
      • Clustering, visualization and cleaning delegated to the base CleanFrame

    Every ``*Embedding(folder)`` method returns ``(embeddings, paths)``:
    a float32 array with one L2-normalised row per image, and the list of
    image paths (sorted) those rows belong to. Results are cached per folder
    through the base instance's ``_load_cache`` / ``_save_cache``.

    NOTE: the ResNet, Swin and CLIP extractors decode and preprocess every
    image of a folder into memory before running inference.
    """

    # Lower-case file extensions treated as images when scanning a folder.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
    # Directory (relative to the working directory) holding exported ONNX models.
    ONNX_CACHE_DIR = ".cleanframe_cache"

    def __init__(self, device="cuda", fast_mode=True, batch_size=64, use_onnx=False):
        """
        Args:
            device: Requested torch device ("cuda", "cuda:1", "mps", "cpu", ...).
                Falls back to CPU when the requested accelerator is unavailable.
            fast_mode: Enable FP16 inference when running on CUDA or MPS.
            batch_size: Number of images per forward pass (must be >= 1).
            use_onnx: Use ONNXRuntime for ResNet50; ignored with a warning
                when onnxruntime is not installed.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        from cleanframes.core import CleanFrame  # import your main class

        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.base = CleanFrame(device=device)
        self.device = self._resolve_device(device)
        self.fast_mode = fast_mode
        self.batch_size = batch_size
        self.use_onnx = use_onnx and ort is not None
        self._fp16 = self.fast_mode and (self.device.type in ["cuda", "mps"])

        if self.use_onnx:
            print("[INFO] ONNX acceleration enabled.")
        elif use_onnx and ort is None:
            print("[WARNING] onnxruntime not installed, falling back to PyTorch.")

        if self._fp16:
            print("[INFO] Using mixed precision (FP16) for embeddings.")
        else:
            print("[INFO] Using full precision (FP32).")

    # ==========================================================
    # Optimized Embedding Helpers
    # ==========================================================

    @staticmethod
    def _resolve_device(device):
        """Return ``torch.device(device)``, or CPU when that accelerator is unavailable."""
        requested = torch.device(device)
        if requested.type == "cuda" and not torch.cuda.is_available():
            return torch.device("cpu")
        if requested.type == "mps":
            # torch.backends.mps only exists on newer torch builds.
            mps_backend = getattr(torch.backends, "mps", None)
            if mps_backend is None or not mps_backend.is_available():
                return torch.device("cpu")
        return requested

    @classmethod
    def _list_images(cls, folder):
        """Return the sorted paths of files in *folder* with an image extension."""
        return sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    @staticmethod
    def _load_rgb(path):
        """Open *path*, convert it to RGB and close the underlying file."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image, so closing the source is safe.
            return img.convert("RGB")

    @staticmethod
    def _imagenet_transform():
        """Resize to 224x224, convert to a tensor, apply ImageNet mean/std normalisation."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _concat_embeddings(chunks):
        """Concatenate per-batch feature arrays row-wise.

        Raises:
            ValueError: if *chunks* is empty (no image could be embedded).
        """
        if not chunks:
            raise ValueError(
                "No images could be embedded: the folder is empty or every file failed to load."
            )
        return np.concatenate(chunks, axis=0)

    @staticmethod
    def _normalize_rows(embs, eps=1e-12):
        """Return *embs* as float32 with each row scaled to unit L2 norm.

        All-zero rows stay zero instead of becoming NaN (the norm is clamped to *eps*).
        """
        # Normalise in float32: FP16 outputs lose precision (and can overflow) here.
        embs = np.asarray(embs, dtype=np.float32)
        norms = np.linalg.norm(embs, axis=1, keepdims=True)
        return embs / np.maximum(norms, eps)

    def _batch_images(self, folder, tf, label="Warning"):
        """Load every image in *folder* and apply the transform *tf*.

        Files that fail to open or transform are skipped with a printed
        ``[label] Skipping ...`` message.

        Returns:
            (images, paths): transformed images and their source paths, in lockstep.
        """
        images, paths = [], []
        for path in self._list_images(folder):
            try:
                tensor = tf(self._load_rgb(path))
            except Exception as e:  # corrupt/unsupported files are skipped, not fatal
                print(f"[{label}] Skipping {path}: {e}")
            else:
                images.append(tensor)
                paths.append(path)
        return images, paths

    def _batched_forward(self, model, images, desc=None):
        """Run *model* over *images* in chunks of ``batch_size``.

        Args:
            model: torch module mapping a stacked batch to per-image features;
                all non-batch output dimensions are flattened.
            images: list of equally shaped tensors.
            desc: optional tqdm progress-bar label; no bar when None.

        Returns:
            float32 array with one L2-normalised row per image.

        Raises:
            ValueError: if *images* is empty.
        """
        starts = range(0, len(images), self.batch_size)
        if desc is not None:
            starts = tqdm(starts, desc=desc)

        chunks = []
        model.eval()
        with torch.no_grad():
            for start in starts:
                batch = torch.stack(images[start:start + self.batch_size]).to(self.device)
                if self._fp16:
                    batch = batch.half()
                feat = model(batch)
                # reshape (not squeeze) keeps the batch axis for any batch/feature size.
                chunks.append(feat.reshape(feat.shape[0], -1).float().cpu().numpy())
        return self._normalize_rows(self._concat_embeddings(chunks))

    # ==========================================================
    # ResNet (batched or ONNX)
    # ==========================================================

    def _resnet_onnx_session(self):
        """Export ResNet50 (classification head removed) to ONNX once, then open a session."""
        # Exported with a dynamic batch axis. The file name differs from the legacy
        # "resnet50.onnx" so an export traced with a fixed batch of 1 is never reused.
        model_path = os.path.join(self.ONNX_CACHE_DIR, "resnet50_dynbatch.onnx")
        if not os.path.exists(model_path):
            os.makedirs(self.ONNX_CACHE_DIR, exist_ok=True)
            torch_model = models.resnet50(pretrained=True)
            torch_model = torch.nn.Sequential(*list(torch_model.children())[:-1]).eval()
            dummy = torch.randn(1, 3, 224, 224)
            torch.onnx.export(
                torch_model,
                (dummy,),
                model_path,
                input_names=["input"],
                output_names=["output"],
                dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
                opset_version=12,
            )
            print("[ONNX] Exported ResNet50 to ONNX.")
        # Request only providers this onnxruntime build actually offers.
        available = ort.get_available_providers()
        providers = [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in available]
        return ort.InferenceSession(model_path, providers=providers)

    def _onnx_resnet_forward(self, images):
        """Run the ONNX ResNet50 over *images*; return L2-normalised float32 rows."""
        session = self._resnet_onnx_session()
        # Ask the graph for its input name rather than assuming one.
        input_name = session.get_inputs()[0].name
        chunks = []
        for start in tqdm(range(0, len(images), self.batch_size), desc="[ONNX] ResNet Embeddings"):
            batch = torch.stack(images[start:start + self.batch_size]).numpy()
            feat = session.run(None, {input_name: batch})[0]
            chunks.append(feat.reshape(feat.shape[0], -1))
        return self._normalize_rows(self._concat_embeddings(chunks))

    def ResNetEmbedding(self, folder):
        """Embed every image in *folder* with ResNet50 (classification head removed).

        Uses ONNXRuntime when ``use_onnx`` is active, batched PyTorch otherwise.

        Returns:
            (embeddings, paths): float32 L2-normalised rows and their image paths.

        Raises:
            ValueError: if no image in *folder* could be loaded.
        """
        cache = self.base._load_cache(folder, "resnet_opt")
        if cache:
            return cache["embeddings"], list(cache["paths"])

        images, paths = self._batch_images(folder, self._imagenet_transform())
        if self.use_onnx:
            embs = self._onnx_resnet_forward(images)
        else:
            model = models.resnet50(pretrained=True)
            # Drop the last child module (the classifier) to obtain feature vectors.
            model = torch.nn.Sequential(*list(model.children())[:-1]).to(self.device)
            if self._fp16:
                # Weights must match the FP16 batches produced by _batched_forward.
                model = model.half()
            embs = self._batched_forward(model, images, desc="[ResNet] Embeddings")

        self.base._save_cache(folder, "resnet_opt", embeddings=embs, paths=np.array(paths))
        return embs, paths

    # ==========================================================
    # CLIP (batched, mixed precision)
    # ==========================================================

    def CLIPEmbedding(self, folder):
        """Embed every image in *folder* with open_clip ViT-B-32 (OpenAI weights).

        Returns:
            (embeddings, paths): float32 L2-normalised rows and their image paths.

        Raises:
            ValueError: if no image in *folder* could be loaded.
        """
        cache = self.base._load_cache(folder, "clip_opt")
        if cache:
            return cache["embeddings"], list(cache["paths"])

        model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
        model.to(self.device).eval()
        if self._fp16:
            model = model.half()

        images, paths = self._batch_images(folder, preprocess, label="CLIP Warning")

        chunks = []
        with torch.no_grad():
            for start in tqdm(range(0, len(images), self.batch_size), desc="[CLIP] Embeddings"):
                batch = torch.stack(images[start:start + self.batch_size]).to(self.device)
                if self._fp16:
                    batch = batch.half()
                feat = model.encode_image(batch)
                chunks.append(feat.float().cpu().numpy())
        embs = self._normalize_rows(self._concat_embeddings(chunks))

        self.base._save_cache(folder, "clip_opt", embeddings=embs, paths=np.array(paths))
        return embs, paths

    # ==========================================================
    # Swin Transformer (batched)
    # ==========================================================

    def SwinEmbedding(self, folder):
        """Embed every image in *folder* with timm's swin_tiny_patch4_window7_224.

        ``num_classes=0`` asks timm for the model without a classification head.

        Returns:
            (embeddings, paths): float32 L2-normalised rows and their image paths.

        Raises:
            ValueError: if no image in *folder* could be loaded.
        """
        cache = self.base._load_cache(folder, "swin_opt")
        if cache:
            return cache["embeddings"], list(cache["paths"])

        model = create_model('swin_tiny_patch4_window7_224', pretrained=True, num_classes=0)
        model.to(self.device).eval()
        if self._fp16:
            model = model.half()

        images, paths = self._batch_images(folder, self._imagenet_transform())
        embs = self._batched_forward(model, images, desc="[Swin] Embeddings")

        self.base._save_cache(folder, "swin_opt", embeddings=embs, paths=np.array(paths))
        return embs, paths

    # ==========================================================
    # DINOv2 (batched)
    # ==========================================================

    def DINOEmbedding(self, folder):
        """Embed every image in *folder* with facebook/dinov2-base.

        The embedding is ``last_hidden_state`` averaged over dim 1 (the token axis).
        Images are decoded batch by batch; unreadable files are skipped with a warning.

        Returns:
            (embeddings, paths): float32 L2-normalised rows and their image paths.

        Raises:
            ValueError: if no image in *folder* could be loaded.
        """
        cache = self.base._load_cache(folder, "dino_opt")
        if cache:
            return cache["embeddings"], list(cache["paths"])

        model_id = "facebook/dinov2-base"
        processor = AutoImageProcessor.from_pretrained(model_id)
        model = AutoModel.from_pretrained(model_id).to(self.device).eval()
        if self._fp16:
            model = model.half()

        files = self._list_images(folder)
        chunks, paths = [], []
        with torch.no_grad():
            for start in tqdm(range(0, len(files), self.batch_size), desc="[DINOv2] Embeddings"):
                batch_imgs, batch_paths = [], []
                for path in files[start:start + self.batch_size]:
                    try:
                        batch_imgs.append(self._load_rgb(path))
                    except Exception as e:  # skip corrupt files like the other extractors
                        print(f"[DINOv2 Warning] Skipping {path}: {e}")
                    else:
                        batch_paths.append(path)
                if not batch_imgs:
                    continue

                inputs = processor(images=batch_imgs, return_tensors="pt").to(self.device)
                if self._fp16:
                    # Only floating tensors are cast; any integer tensors are kept as-is.
                    inputs = {k: v.half() if v.is_floating_point() else v for k, v in inputs.items()}
                feat = model(**inputs).last_hidden_state.mean(dim=1)
                chunks.append(feat.float().cpu().numpy())
                paths.extend(batch_paths)
        embs = self._normalize_rows(self._concat_embeddings(chunks))

        self.base._save_cache(folder, "dino_opt", embeddings=embs, paths=np.array(paths))
        return embs, paths

    # ==========================================================
    # Proxy Methods
    # ==========================================================

    def cluster_embeddings(self, *args, **kwargs):
        """Delegate to the base CleanFrame's ``cluster_embeddings``."""
        return self.base.cluster_embeddings(*args, **kwargs)

    def visualize_clusters(self, *args, **kwargs):
        """Delegate to the base CleanFrame's ``visualize_clusters``."""
        return self.base.visualize_clusters(*args, **kwargs)

    def cleanframe(self, *args, **kwargs):
        """Delegate to the base CleanFrame's ``cleanframe``."""
        return self.base.cleanframe(*args, **kwargs)