import lightning as pl
import torch
import torchmetrics
from lightning.pytorch.loggers import WandbLogger
from torch import nn

import stable_pretraining as spt
from stable_pretraining.forward import dinov2_forward
from stable_pretraining.data import transforms
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent))
from utils import get_data_dir

dinov2_transform = transforms.MultiViewTransform(
    {
        "global_1": transforms.Compose(
            transforms.RGB(),
            transforms.RandomResizedCrop((224, 224), scale=(0.4, 1.0)),
            transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.PILGaussianBlur(p=1.0),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToImage(**spt.data.static.ImageNet),
        ),
        "global_2": transforms.Compose(
            transforms.RGB(),
            transforms.RandomResizedCrop((224, 224), scale=(0.4, 1.0)),
            transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.PILGaussianBlur(p=0.1),
            transforms.RandomSolarize(threshold=0.5, p=0.2),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToImage(**spt.data.static.ImageNet),
        ),
        "local_1": transforms.Compose(
            transforms.RGB(),
            transforms.RandomResizedCrop((96, 96), scale=(0.05, 0.4)),
            transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.PILGaussianBlur(p=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToImage(**spt.data.static.ImageNet),
        ),
        "local_2": transforms.Compose(
            transforms.RGB(),
            transforms.RandomResizedCrop((96, 96), scale=(0.05, 0.4)),
            transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.PILGaussianBlur(p=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToImage(**spt.data.static.ImageNet),
        ),
        "local_3": transforms.Compose(
            transforms.RGB(),
            transforms.RandomResizedCrop((96, 96), scale=(0.05, 0.4)),
            transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.PILGaussianBlur(p=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToImage(**spt.data.static.ImageNet),
        ),
        "local_4": transforms.Compose(
            transforms.RGB(),
            transforms.RandomResizedCrop((96, 96), scale=(0.05, 0.4)),
            transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.PILGaussianBlur(p=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToImage(**spt.data.static.ImageNet),
        ),
        "local_5": transforms.Compose(
            transforms.RGB(),
            transforms.RandomResizedCrop((96, 96), scale=(0.05, 0.4)),
            transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.PILGaussianBlur(p=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToImage(**spt.data.static.ImageNet),
        ),
        "local_6": transforms.Compose(
            transforms.RGB(),
            transforms.RandomResizedCrop((96, 96), scale=(0.05, 0.4)),
            transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.PILGaussianBlur(p=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToImage(**spt.data.static.ImageNet),
        ),
    }
)

val_transform = transforms.Compose(
    transforms.RGB(),
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToImage(**spt.data.static.ImageNet),
)

data_dir = get_data_dir("imagenet100")

train_dataset = spt.data.HFDataset(
    "clane9/imagenet-100",
    split="train",
    cache_dir=str(data_dir),
    transform=dinov2_transform,
)
val_dataset = spt.data.HFDataset(
    "clane9/imagenet-100",
    split="validation",
    cache_dir=str(data_dir),
    transform=val_transform,
)

batch_size = 32  # Per GPU batch size (256 / 8 GPUs)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    num_workers=32,
    drop_last=True,
    persistent_workers=True,
    shuffle=True,
)
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=32,
    persistent_workers=True,
)

data = spt.data.DataModule(train=train_dataloader, val=val_dataloader)

backbone = spt.backbone.vit_hf(
    size="tiny",
    patch_size=16,
    image_size=224,
    pretrained=False,
    use_mask_token=True,  # Required for iBOT
)

wrapped_backbone = spt.TeacherStudentWrapper(
    backbone,
    warm_init=True,
    base_ema_coefficient=0.9995,
    final_ema_coefficient=1.0,
)

# CLS token projector (DINO component)
projector = nn.Sequential(
    nn.Linear(192, 2048),
    nn.BatchNorm1d(2048),
    nn.GELU(),
    nn.Linear(2048, 2048),
    nn.BatchNorm1d(2048),
    nn.GELU(),
    nn.Linear(2048, 256),
    spt.utils.nn_modules.L2Norm(),
    nn.Linear(256, 65536, bias=False),  # Prototypes layer
)

wrapped_projector = spt.TeacherStudentWrapper(
    projector,
    warm_init=True,
    base_ema_coefficient=0.9995,
    final_ema_coefficient=1.0,
)

# Patch projector (iBOT component)
patch_projector = nn.Sequential(
    nn.Linear(192, 2048),
    nn.BatchNorm1d(2048),
    nn.GELU(),
    nn.Linear(2048, 2048),
    nn.BatchNorm1d(2048),
    nn.GELU(),
    nn.Linear(2048, 256),
    spt.utils.nn_modules.L2Norm(),
    nn.Linear(256, 65536, bias=False),  # Prototypes layer
)

wrapped_patch_projector = spt.TeacherStudentWrapper(
    patch_projector,
    warm_init=True,
    base_ema_coefficient=0.9995,
    final_ema_coefficient=1.0,
)

module = spt.Module(
    backbone=wrapped_backbone,
    projector=wrapped_projector,
    patch_projector=wrapped_patch_projector,
    forward=dinov2_forward,
    dinov2_loss=spt.losses.DINOv2Loss(
        temperature_student=0.1,
        center_momentum=0.9,
        ibot_loss_weight=1.0,
    ),
    mask_ratio=0.3,
    optim={
        "optimizer": {
            "type": "AdamW",
            "lr": 0.005,
            "weight_decay": 1e-4,
        },
        "scheduler": {
            "type": "LinearWarmupCosineAnnealing",
        },
        "interval": "epoch",
    },
)

teacher_student_callback = spt.callbacks.TeacherStudentCallback(
    update_frequency=1,
    update_after_backward=False,
)

linear_probe = spt.callbacks.OnlineProbe(
    name="linear_probe",
    input="embedding",
    target="label",
    probe=nn.Linear(192, 100),
    loss_fn=nn.CrossEntropyLoss(),
    metrics={
        "top1": torchmetrics.classification.MulticlassAccuracy(100),
        "top5": torchmetrics.classification.MulticlassAccuracy(100, top_k=5),
    },
    optimizer={
        "type": "AdamW",
        "lr": 3e-3,
        "weight_decay": 1e-4,
    },
)

knn_probe = spt.callbacks.OnlineKNN(
    name="knn_probe",
    input="embedding",
    target="label",
    queue_length=20000,
    metrics={"accuracy": torchmetrics.classification.MulticlassAccuracy(100)},
    input_dim=192,
    k=20,
)

wandb_logger = WandbLogger(
    entity="stable-ssl",
    project="imagenet100-dinov2",
    name="dinov2-vit-tiny-8gpus",
    log_model=False,
)

trainer = pl.Trainer(
    max_epochs=400,
    num_sanity_val_steps=0,
    callbacks=[teacher_student_callback, linear_probe, knn_probe],
    precision="16-mixed",
    logger=wandb_logger,
    devices=8,  # Use 8 GPUs
    strategy="ddp_find_unused_parameters_true",  # DDP with unused parameters detection
    sync_batchnorm=True,
    accelerator="gpu",
)

manager = spt.Manager(trainer=trainer, module=module, data=data)
manager()
