import os
import re
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import tensorflow as tf
import tensorflow.keras.applications as tensorflow_models
import timm
import torch
import torchvision
from tensorflow.keras.layers import Lambda
from torch.hub import load_state_dict_from_url

from thingsvision.utils.checkpointing import get_torch_home
from thingsvision.utils.models.dino import vit_base, vit_small, vit_tiny
from thingsvision.utils.models.mae import (
    interpolate_pos_embed,
    vit_base_patch16,
    vit_huge_patch14,
    vit_large_patch16,
)

from .tensorflow import TensorFlowExtractor
from .torch import PyTorchExtractor

# Necessary to prevent GPU memory conflicts between torch and tf: switch every
# physical GPU that TensorFlow can see to memory-growth mode.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    try:
        # Memory growth currently has to be configured identically for all GPUs.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices("GPU")
        print(f"{len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
    except RuntimeError as err:
        # Memory growth must be set before the GPUs have been initialized;
        # report the problem instead of failing the import.
        print(err)

# Module-level aliases for the tensor / array types.
Tensor, Array = torch.Tensor, np.ndarray


class TorchvisionExtractor(PyTorchExtractor):
    """Feature extractor for models from ``torchvision.models``.

    Pretrained weights are selected by the ``"weights"`` entry of
    ``model_parameters`` (``"DEFAULT"`` when no parameters are given or the
    entry is missing).
    """

    def __init__(
        self,
        model_name: str,
        pretrained: bool,
        device: str,
        model_path: Optional[str] = None,
        model_parameters: Optional[Dict[str, Union[str, bool, List[str]]]] = None,
        preprocess: Optional[Callable] = None,
    ) -> None:
        # Fall back to torchvision's "DEFAULT" weights when no parameters are given.
        model_parameters = (
            model_parameters if model_parameters else {"weights": "DEFAULT"}
        )
        super().__init__(
            model_name=model_name,
            pretrained=pretrained,
            model_path=model_path,
            model_parameters=model_parameters,
            preprocess=preprocess,
            device=device,
        )

    def get_weights(self, model_name: str, suffix: str = "_weights") -> Any:
        """Return the torchvision weights member for ``model_name``.

        Looks up the attribute of ``torchvision.models`` whose lower-cased name
        equals ``model_name + suffix`` and returns its member named by
        ``model_parameters["weights"]`` (``"DEFAULT"`` if that key is absent).

        Args:
            model_name: name of the torchvision model builder.
            suffix: suffix appended to ``model_name`` to form the weights name.

        Raises:
            ValueError: if no matching weights attribute exists.
        """
        target = model_name + suffix
        weights_name = next(
            (m for m in dir(torchvision.models) if m.lower() == target), None
        )
        if not weights_name:
            raise ValueError(
                f"\nCould not find pretrained weights for {model_name} in <torchvision>. Choose a different model or change the source.\n"
            )
        # User-supplied model_parameters may omit "weights"; use the same default
        # as __init__ instead of failing with a KeyError.
        weights_key = self.model_parameters.get("weights", "DEFAULT")
        return getattr(getattr(torchvision.models, weights_name), weights_key)

    def load_model_from_source(self) -> None:
        """Load a (pretrained) neural network model from <torchvision>.

        Sets ``self.weights`` (``None`` when not pretrained) and ``self.model``.

        Raises:
            ValueError: if ``model_name`` is not an attribute of ``torchvision.models``.
        """
        if not hasattr(torchvision.models, self.model_name):
            raise ValueError(
                f"\nCould not find {self.model_name} in torchvision library.\nChoose a different model.\n"
            )
        model = getattr(torchvision.models, self.model_name)
        self.weights = self.get_weights(self.model_name) if self.pretrained else None
        self.model = model(weights=self.weights)

    def get_default_transformation(
        self,
        mean,
        std,
        resize_dim: int = 256,
        crop_dim: int = 224,
        apply_center_crop: bool = True,
    ) -> Any:
        """Return the preprocessing transforms for the loaded model.

        When pretrained weights were loaded, the transforms bundled with those
        weights are used and the arguments are ignored (a warning is emitted);
        otherwise the parent class builds the transforms from the arguments.
        """
        weights = getattr(self, "weights", None)
        if weights is not None:
            warnings.warn(
                message="\nInput arguments are ignored because transforms are automatically inferred from model weights.\n",
                category=UserWarning,
                stacklevel=2,
            )
            return weights.transforms()
        return super().get_default_transformation(
            mean, std, resize_dim, crop_dim, apply_center_crop
        )


class TimmExtractor(PyTorchExtractor):
    """Feature extractor for models provided by the <timm> library."""

    def __init__(
        self,
        model_name: str,
        pretrained: bool,
        device: str,
        model_path: str = None,
        model_parameters: Dict[str, Union[str, bool, List[str]]] = None,
        preprocess: Optional[Callable] = None,
    ) -> None:
        super().__init__(
            model_name=model_name,
            pretrained=pretrained,
            device=device,
            model_path=model_path,
            model_parameters=model_parameters,
            preprocess=preprocess,
        )

    def load_model_from_source(self) -> None:
        """Load a (pretrained) neural network model from <timm>.

        Only the part of ``model_name`` before the first ``"."`` is checked
        against ``timm.list_models()``. The full name, including any ``.tag``
        suffix, is then passed to ``timm.create_model``.

        Raises:
            ValueError: if the architecture is not known to timm.
        """
        architecture = self.model_name.split(".")[0]
        if architecture not in timm.list_models():
            raise ValueError(
                f"\nCould not find {self.model_name} in timm library.\nChoose a different model.\n"
            )
        self.model = timm.create_model(self.model_name, pretrained=self.pretrained)

    def get_default_transformation(
        self,
        mean,
        std,
        resize_dim: int = 256,
        crop_dim: int = 224,
        apply_center_crop: bool = True,
    ) -> Any:
        """Build evaluation transforms from the loaded model's timm data config.

        All arguments are ignored (a warning says so); the transforms are
        derived from ``self.model`` instead.
        """
        warnings.warn(
            message="\nInput arguments are ignored because <timm> automatically infers transforms from model config.\n",
            category=UserWarning,
            stacklevel=2,
        )
        config = timm.data.resolve_model_data_config(self.model)
        return timm.data.create_transform(**config, is_training=False)


class KerasExtractor(TensorFlowExtractor):
    """Feature extractor for models from ``tensorflow.keras.applications``."""

    # Pairs of (regex matched against the Keras model name, name of the
    # ``tensorflow.keras.applications`` submodule whose ``preprocess_input``
    # belongs to that model). Patterns must match the exact Keras class names.
    _PREPROCESSING_PATTERNS = (
        (r"^ConvNeXt(Base|Large|Small|Tiny|XLarge)$", "convnext"),
        (r"^DenseNet\d+$", "densenet"),
        (r"^EfficientNetB[0-7]$", "efficientnet"),
        (r"^EfficientNetV2(B[0-3]|[LMS])$", "efficientnet_v2"),
        (r"^InceptionResNetV2$", "inception_resnet_v2"),
        (r"^InceptionV3$", "inception_v3"),
        (r"^MobileNet$", "mobilenet"),
        (r"^MobileNetV2$", "mobilenet_v2"),
        (r"^MobileNetV3(Large|Small)$", "mobilenet_v3"),
        # Keras names these "NASNetLarge" / "NASNetMobile"; the previous
        # "NasNet" spelling never matched, so NASNet lost its preprocessing.
        (r"^NASNet(Large|Mobile)$", "nasnet"),
        (r"^ResNet\d+$", "resnet"),
        (r"^ResNet\d+V2$", "resnet_v2"),
        (r"^VGG16$", "vgg16"),
        (r"^VGG19$", "vgg19"),
        (r"^Xception$", "xception"),
    )

    def __init__(
        self,
        model_name: str,
        pretrained: bool,
        device: str,
        model_path: Optional[str] = None,
        model_parameters: Optional[Dict[str, Union[str, bool, List[str]]]] = None,
        preprocess: Optional[Callable] = None,
    ) -> None:
        # Default to ImageNet weights when no parameters are given.
        model_parameters = (
            model_parameters if model_parameters else {"weights": "imagenet"}
        )
        super().__init__(
            model_name=model_name,
            pretrained=pretrained,
            model_path=model_path,
            model_parameters=model_parameters,
            preprocess=preprocess,
            device=device,
        )

    def load_model_from_source(self) -> None:
        """Load a (pretrained) neural network model from <keras>.

        Weights are selected as follows:
        - ``model_parameters["weights"]`` (``"imagenet"`` if absent) when ``pretrained``;
        - otherwise ``model_path``, if given;
        - otherwise ``None`` (random initialization).

        If a model-specific preprocessing function is known, ``self.preprocess``
        is replaced by that function followed by a resize to the model's input size.

        Raises:
            ValueError: if ``model_name`` is not an attribute of
                ``tensorflow.keras.applications``.
        """
        if not hasattr(tensorflow_models, self.model_name):
            raise ValueError(
                f"\nCould not find {self.model_name} among TensorFlow models.\n"
            )
        model = getattr(tensorflow_models, self.model_name)
        if self.pretrained:
            # Custom model_parameters may omit "weights"; use the __init__ default.
            weights = self.model_parameters.get("weights", "imagenet")
        elif self.model_path:
            weights = self.model_path
        else:
            weights = None
        self.model = model(weights=weights)

        preproc_fun_name = self.get_keras_preprocessing(self.model_name)
        if isinstance(preproc_fun_name, str):
            # Get the preprocessing function for this specific model.
            preproc_fun = self.get_preproc_fun(preproc_fun_name)
            # Different models take differently sized inputs. Presumably the
            # first layer's input shape is (batch, H, W, C) (channels-last), so
            # index -2 is the width; it is used for both sides, which assumes
            # square inputs.
            resize_dim = self.model.layers[0].input_shape[0][-2]
            self.preprocess = tf.keras.Sequential(
                [
                    Lambda(preproc_fun),
                    tf.keras.layers.experimental.preprocessing.Resizing(
                        resize_dim, resize_dim
                    ),
                ]
            )

    @staticmethod
    def get_preproc_fun(preproc_fun_name: str) -> Callable:
        """Return ``preprocess_input`` from the named ``keras.applications`` submodule."""
        return getattr(getattr(tensorflow_models, preproc_fun_name), "preprocess_input")

    def get_keras_preprocessing(self, model_name: str) -> Union[str, None]:
        """Return the ``tensorflow.keras.applications`` submodule name holding the
        preprocessing function for ``model_name``.

        Returns ``None`` (and emits a warning) when no pattern matches.
        """
        for pattern, preproc_val in self._PREPROCESSING_PATTERNS:
            if re.search(pattern, model_name):
                return preproc_val

        # If no match is found, warn that default preprocessing will be used.
        warnings.warn(
            f"No preprocessing function found for model {model_name}, so falling back to default preprocessing.\nOften, models that come from Keras Applications have their own preprocessing functions.\nThus, this may create inaccurate results. If you need to manually specify a preprocessing function, please do so under the `transforms` argument when creating your Dataset"
        )
        return None


class SSLExtractor(PyTorchExtractor):
    """Feature extractor for self-supervised models.

    The models come from VISSL, torch hub, or plain checkpoint URLs. ``MODELS``
    maps each supported model name to its loading configuration:
    - ``"type"`` selects the loading strategy: ``"vissl"``, ``"hub"`` or
      ``"checkpoint_url"``;
    - ``"arch"`` is the architecture;
    - depending on the type, ``"url"``, ``"repository"`` and/or
      ``"checkpoint_url"`` point to the weights.
    """

    MODELS = {
        "simclr-rn50": {
            "url": "https://dl.fbaipublicfiles.com/vissl/model_zoo/simclr_rn50_800ep_simclr_8node_resnet_16_07_20.7e8feed1/model_final_checkpoint_phase799.torch",
            "arch": "resnet50",
            "type": "vissl",
        },
        "mocov2-rn50": {
            "url": "https://dl.fbaipublicfiles.com/vissl/model_zoo/moco_v2_1node_lr.03_step_b32_zero_init/model_final_checkpoint_phase199.torch",
            "arch": "resnet50",
            "type": "vissl",
        },
        "jigsaw-rn50": {
            "url": "https://dl.fbaipublicfiles.com/vissl/model_zoo/jigsaw_rn50_in1k_ep105_perm2k_jigsaw_8gpu_resnet_17_07_20.db174a43/model_final_checkpoint_phase104.torch",
            "arch": "resnet50",
            "type": "vissl",
        },
        "rotnet-rn50": {
            "url": "https://dl.fbaipublicfiles.com/vissl/model_zoo/rotnet_rn50_in1k_ep105_rotnet_8gpu_resnet_17_07_20.46bada9f/model_final_checkpoint_phase125.torch",
            "arch": "resnet50",
            "type": "vissl",
        },
        "swav-rn50": {
            "url": "https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_in1k_rn50_800ep_swav_8node_resnet_27_07_20.a0a6b676/model_final_checkpoint_phase799.torch",
            "arch": "resnet50",
            "type": "vissl",
        },
        "pirl-rn50": {
            "url": "https://dl.fbaipublicfiles.com/vissl/model_zoo/pirl_jigsaw_4node_pirl_jigsaw_4node_resnet_22_07_20.34377f59/model_final_checkpoint_phase799.torch",
            "arch": "resnet50",
            "type": "vissl",
        },
        "barlowtwins-rn50": {
            "arch": "resnet50",
            "type": "checkpoint_url",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/barlowtwins/ljng/resnet50.pth",
        },
        "vicreg-rn50": {
            "arch": "resnet50",
            "type": "checkpoint_url",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/vicreg/resnet50.pth",
        },
        "dino-vit-small-p16": {
            "repository": "facebookresearch/dino:main",
            "arch": "dino_vits16",
            "type": "hub",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
        },
        "dino-vit-small-p8": {
            "repository": "facebookresearch/dino:main",
            "arch": "dino_vits8",
            "type": "hub",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth",
        },
        "dino-vit-base-p16": {
            "repository": "facebookresearch/dino:main",
            "arch": "dino_vitb16",
            "type": "hub",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
        },
        "dino-vit-base-p8": {
            "repository": "facebookresearch/dino:main",
            "arch": "dino_vitb8",
            "type": "hub",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
        },
        "dino-xcit-small-12-p16": {
            "repository": "facebookresearch/dino:main",
            "arch": "dino_xcit_small_12_p16",
            "type": "hub",
        },
        "dino-xcit-small-12-p8": {
            "repository": "facebookresearch/dino:main",
            "arch": "dino_xcit_small_12_p8",
            "type": "hub",
        },
        "dino-xcit-medium-24-p16": {
            "repository": "facebookresearch/dino:main",
            "arch": "dino_xcit_medium_24_p16",
            "type": "hub",
        },
        "dino-xcit-medium-24-p8": {
            "repository": "facebookresearch/dino:main",
            "arch": "dino_xcit_medium_24_p8",
            "type": "hub",
        },
        "dino-rn50": {
            "repository": "facebookresearch/dino:main",
            "arch": "dino_resnet50",
            "type": "hub",
        },
        "dinov2-vit-small-p14": {
            "repository": "facebookresearch/dinov2",
            "arch": "dinov2_vits14",
            "type": "hub",
        },
        "dinov2-vit-base-p14": {
            "repository": "facebookresearch/dinov2",
            "arch": "dinov2_vitb14",
            "type": "hub",
        },
        "dinov2-vit-large-p14": {
            "repository": "facebookresearch/dinov2",
            "arch": "dinov2_vitl14",
            "type": "hub",
        },
        "dinov2-vit-giant-p14": {
            "repository": "facebookresearch/dinov2",
            "arch": "dinov2_vitg14",
            "type": "hub",
        },
        "mae-vit-base-p16": {
            "repository": "facebookresearch/mae",
            "arch": "mae_vit_base_patch16",
            "type": "hub",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth",
        },
        "mae-vit-large-p16": {
            "repository": "facebookresearch/mae",
            "arch": "mae_vit_large_patch16",
            "type": "hub",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth",
        },
        "mae-vit-huge-p14": {
            "repository": "facebookresearch/mae",
            "arch": "mae_vit_huge_patch14",
            "type": "hub",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_huge.pth",
        },
    }

    def __init__(
        self,
        model_name: str,
        pretrained: bool,
        device: str,
        model_path: Optional[str] = None,
        model_parameters: Optional[Dict[str, Union[str, bool, List[str]]]] = None,
        preprocess: Optional[Callable] = None,
    ) -> None:
        super().__init__(
            model_name=model_name,
            pretrained=pretrained,
            model_path=model_path,
            model_parameters=model_parameters,
            preprocess=preprocess,
            device=device,
        )

    def _load_vissl_state_dict(self, model_url: str, unique_model_filename: str):
        """
        Download a VISSL checkpoint and return its trunk as a state dict.

        The model is downloaded in VISSL format and converted to torchvision
        format. It is cached under ``unique_model_filename``, so this file name
        must be unique per URL; otherwise the wrong cached variant is loaded.
        """
        model = load_state_dict_from_url(
            model_url, map_location=torch.device("cpu"), file_name=unique_model_filename
        )

        # Locate the model trunk, whose keys need renaming.
        if "classy_state_dict" in model.keys():
            model_trunk = model["classy_state_dict"]["base_model"]["model"]["trunk"]
        elif "model_state_dict" in model.keys():
            model_trunk = model["model_state_dict"]
        else:
            model_trunk = model

        converted_model = self._replace_module_prefix(model_trunk, "_feature_blocks.")
        return converted_model

    def _replace_module_prefix(
        self, state_dict: Dict[str, Any], prefix: str, replace_with: str = ""
    ):
        """
        Return a copy of ``state_dict`` in which ``prefix`` is replaced by
        ``replace_with`` at the start of every key that begins with it.

        Keys that do not start with ``prefix`` are kept unchanged.
        """
        state_dict = {
            (
                key.replace(prefix, replace_with, 1) if key.startswith(prefix) else key
            ): val
            for (key, val) in state_dict.items()
        }
        return state_dict

    def load_model_from_source(self) -> None:
        """
        Load a (pretrained) neural network model from VISSL, torch hub or a
        checkpoint URL into ``self.model``.

        Weights are fetched via ``load_state_dict_from_url`` (or
        ``torch.hub.load``), which downloads them when they are not available
        and otherwise reads them from the cache directory.

        Raises:
            ValueError: in any of these cases:
                - ``model_name`` is not a key of ``SSLExtractor.MODELS``;
                - a DINO ViT or MAE name has no local constructor;
                - the configuration has an unknown ``"type"``.
        """
        if self.model_name not in SSLExtractor.MODELS:
            raise ValueError(
                f"\nCould not find {self.model_name} in the SSLExtractor.\nUse a different model.\n"
            )

        # Unique model id name for all models; it is used as the cache file name.
        unique_model_filename = f"thingsvision_ssl_v0_{self.model_name}.pth"

        # Defines how the model should be loaded. Named ``model_type`` rather
        # than ``type`` so the builtin is not shadowed.
        model_config = SSLExtractor.MODELS[self.model_name]
        model_type = model_config["type"]

        # VISSL MODELS
        if model_type == "vissl":
            model_state_dict = self._load_vissl_state_dict(
                model_url=model_config["url"],
                unique_model_filename=unique_model_filename,
            )
            self.model = getattr(torchvision.models, model_config["arch"])()
            if model_config["arch"] == "resnet50":
                # The VISSL trunk has no classification head.
                self.model.fc = torch.nn.Identity()
            self.model.load_state_dict(model_state_dict, strict=True)

        # HUB MODELS
        elif model_type == "hub":
            if self.model_name.startswith("dino-vit"):
                # DINO ViTs are built locally and loaded from the checkpoint URL.
                if self.model_name == "dino-vit-tiny-p8":
                    model = vit_tiny(patch_size=8)
                elif self.model_name == "dino-vit-tiny-p16":
                    model = vit_tiny(patch_size=16)
                elif self.model_name == "dino-vit-small-p8":
                    model = vit_small(patch_size=8)
                elif self.model_name == "dino-vit-small-p16":
                    model = vit_small(patch_size=16)
                elif self.model_name == "dino-vit-base-p8":
                    model = vit_base(patch_size=8)
                elif self.model_name == "dino-vit-base-p16":
                    model = vit_base(patch_size=16)
                else:
                    raise ValueError(f"\n{self.model_name} is not available.\n")
                state_dict = torch.hub.load_state_dict_from_url(
                    model_config["checkpoint_url"],
                    map_location=torch.device("cpu"),
                    # This is used to cache the file
                    file_name=unique_model_filename,
                )
                model.load_state_dict(state_dict, strict=True)
                self.model = model
            elif self.model_name.startswith("mae"):
                # MAE ViTs are built locally without a classification head.
                if self.model_name == "mae-vit-base-p16":
                    model = vit_base_patch16(num_classes=0, drop_rate=0.0)
                elif self.model_name == "mae-vit-large-p16":
                    model = vit_large_patch16(num_classes=0, drop_rate=0.0)
                elif self.model_name == "mae-vit-huge-p14":
                    model = vit_huge_patch14(num_classes=0, drop_rate=0.0)
                else:
                    raise ValueError(f"\n{self.model_name} is not available.\n")
                state_dict = torch.hub.load_state_dict_from_url(
                    model_config["checkpoint_url"],
                    map_location=torch.device("cpu"),
                    file_name=unique_model_filename,
                )
                checkpoint_model = state_dict["model"]
                # Interpolate the position embedding to the model's grid size.
                interpolate_pos_embed(model, checkpoint_model)
                # strict=False: missing/unexpected keys are tolerated when loading
                # the MAE checkpoint into this ViT.
                model.load_state_dict(checkpoint_model, strict=False)
                self.model = model
            else:
                self.model = torch.hub.load(
                    model_config["repository"], model_config["arch"]
                )
                # NOTE(review): no hub entry in MODELS currently has arch
                # "resnet50" (DINO's is "dino_resnet50"), so this never fires;
                # the hub entry presumably removes the head itself — confirm.
                if model_config["arch"] == "resnet50":
                    self.model.fc = torch.nn.Identity()

        # MODELS FROM CHECKPOINT URL
        elif model_type == "checkpoint_url":
            # Load the architecture.
            self.model = getattr(torchvision.models, model_config["arch"])()
            if model_config["arch"] == "resnet50":
                self.model.fc = torch.nn.Identity()

            # Load and cache the state_dict.
            state_dict = torch.hub.load_state_dict_from_url(
                model_config["checkpoint_url"],
                map_location=torch.device("cpu"),
                # IMPORTANT that this is unique as it will be used for caching
                file_name=unique_model_filename,
            )

            # Load the state dict into the model.
            self.model.load_state_dict(state_dict, strict=True)

        else:
            raise ValueError(f"\nUnknown model type: {model_type}.\n")
