# cleanframe/models.py
import torch, numpy as np
from torchvision import models, transforms
from timm import create_model
import open_clip
from .core import console, progress_bar

class EmbeddingExtractor:
    """
    EmbeddingExtractor:
    --------------------
    Extracts L2-normalised feature embeddings from images using:
      • ResNet-50
      • Swin Transformer Tiny
      • CLIP ViT-B/32

    Every backend returns a 2-D ``numpy`` array with one row per input
    image. Pretrained weights are loaded on each call; nothing is cached
    on the instance.

    NOTE(review): ``images`` is assumed to be a sliceable sequence of
    PIL-style images (the torchvision ``Resize`` + ``ToTensor`` pipeline
    used here requires that) — confirm against callers.
    """

    # Embedding width produced by each backbone. Used to return a correctly
    # shaped (0, D) array for empty input without loading any weights.
    _EMBED_DIMS = {"resnet": 2048, "swin": 768, "clip": 512}

    def __init__(self, device, batch_size=32):
        """
        Args:
            device: Torch device (or device string) the models run on.
            batch_size: Number of images per forward pass; must be >= 1.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # A zero/negative step would otherwise surface much later as an
        # unrelated range()/np.concatenate error inside the batching loop.
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")
        self.device = device
        self.batch_size = batch_size

    # ============================================================
    # Public API
    # ============================================================

    def get_embeddings(self, images, model_name):
        """
        Embed ``images`` with the backbone selected by ``model_name``.

        Args:
            images: Sequence of images to embed (see class note).
            model_name: One of ``"clip"``, ``"resnet"``, ``"swin"``
                (case-insensitive).

        Returns:
            ``np.ndarray`` of shape ``(len(images), D)`` with unit-norm rows.
            For empty input a ``(0, D)`` float32 array is returned and no
            model is loaded.

        Raises:
            ValueError: If ``model_name`` is not a supported backbone.
        """
        key = model_name.lower()
        backends = {
            "resnet": self._resnet,
            "swin": self._swin,
            "clip": self._clip,
        }
        if key not in backends:
            raise ValueError("model must be one of ['clip','resnet','swin'].")

        # Nothing to embed: skip the (expensive) weight loading entirely.
        if len(images) == 0:
            return np.empty((0, self._EMBED_DIMS[key]), dtype=np.float32)
        return backends[key](images)

    # ============================================================
    # Shared helpers
    # ============================================================

    @staticmethod
    def _to_rgb(im):
        """Return ``im`` converted to RGB when it is a non-RGB PIL-style image."""
        # Objects without a ``mode`` attribute (e.g. tensors) pass through
        # untouched; only PIL-like images are converted.
        mode = getattr(im, "mode", None)
        if mode is not None and mode != "RGB" and hasattr(im, "convert"):
            return im.convert("RGB")
        return im

    @staticmethod
    def _unit_rows(out):
        """Flatten each sample of ``out`` to a vector and L2-normalise it."""
        # Explicit per-sample flatten keeps the batch axis even for a batch
        # of one, unlike a bare squeeze().
        flat = out.reshape(out.size(0), -1)
        # normalize() clamps the norm with a small eps, so an all-zero
        # feature vector stays zero instead of turning into NaN.
        return torch.nn.functional.normalize(flat, p=2, dim=-1)

    def _imagenet_preprocess(self):
        """Build the 224x224 ImageNet-normalised preprocessing callable."""
        tf = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        def preprocess(im):
            # Normalize above is 3-channel; grayscale/RGBA inputs must be
            # converted first or the pipeline fails.
            return tf(self._to_rgb(im))

        return preprocess

    def _embed(self, images, encode, preprocess, label, dim):
        """
        Run ``encode`` over ``images`` in batches and stack the results.

        Args:
            images: Sequence of images to embed.
            encode: Callable mapping a batch tensor to a feature tensor.
            preprocess: Callable mapping one image to a tensor.
            label: Text passed to ``progress_bar``.
            dim: Embedding width, used only to shape an empty result.

        Returns:
            ``np.ndarray`` of shape ``(len(images), D)``.
        """
        total = len(images)
        chunks = []
        with torch.inference_mode(), progress_bar(label, total) as p:
            task = p.add_task("Embedding", total=total)
            for start in range(0, total, self.batch_size):
                batch = torch.stack(
                    [preprocess(im) for im in images[start:start + self.batch_size]]
                ).to(self.device)
                out = self._unit_rows(encode(batch))
                chunks.append(out.cpu().numpy())
                p.update(task, advance=len(batch))
                # Drop device tensors before the next batch is allocated.
                del batch, out
        if not chunks:
            # np.concatenate rejects an empty list.
            return np.empty((0, dim), dtype=np.float32)
        return np.concatenate(chunks, axis=0)

    # ============================================================
    # ResNet-50
    # ============================================================

    def _resnet(self, images):
        """Embed ``images`` with ResNet-50 (final classification layer removed)."""
        console.print("[cyan]Loading ResNet-50...[/cyan]")
        # Newer torchvision deprecates ``pretrained=True`` in favour of the
        # ``weights`` enum; fall back to the old flag when the enum is absent.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            net = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            net = models.resnet50(pretrained=True)
        # Keep every child module except the last one (the classifier head).
        net = torch.nn.Sequential(*list(net.children())[:-1]).to(self.device).eval()
        return self._embed(
            images,
            net,
            self._imagenet_preprocess(),
            "Embedding (ResNet-50)",
            self._EMBED_DIMS["resnet"],
        )

    # ============================================================
    # Swin Transformer
    # ============================================================

    def _swin(self, images):
        """Embed ``images`` with Swin Transformer Tiny (``num_classes=0``)."""
        console.print("[cyan]Loading Swin Transformer Tiny...[/cyan]")
        net = create_model(
            "swin_tiny_patch4_window7_224", pretrained=True, num_classes=0
        ).to(self.device).eval()
        return self._embed(
            images,
            net,
            self._imagenet_preprocess(),
            "Embedding (Swin-Tiny)",
            self._EMBED_DIMS["swin"],
        )

    # ============================================================
    # CLIP ViT-B/32
    # ============================================================

    def _clip(self, images):
        """Embed ``images`` with the CLIP ViT-B/32 image encoder."""
        console.print("[cyan]Loading CLIP ViT-B/32...[/cyan]")
        net, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
        net.to(self.device).eval()
        # The preprocessing transform returned by open_clip is used as-is.
        return self._embed(
            images,
            net.encode_image,
            preprocess,
            "Embedding (CLIP ViT-B/32)",
            self._EMBED_DIMS["clip"],
        )