# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import wraps
from typing import Any, Callable, Optional, TypeVar

import lightning.pytorch as pl
from torch import nn

from nemo.utils import logging


class ModelTransform(pl.Callback):
    """
    A PyTorch Lightning callback that applies a model transformation function at the start of fitting or validation.

    This callback is designed to apply a transformation to the model when fitting or validation begins.
    This design allows for loading the original checkpoint first and then applying the transformation,
    which is particularly useful for techniques like Parameter-Efficient Fine-Tuning (PEFT).

    The transformation function is expected to be defined on the LightningModule
    as an attribute called 'model_transform'. If that attribute is missing or ``None``,
    the callback does nothing.

    Key Features:
    - Applies transformation at the start of the first training or validation epoch, not during initialization.
    - Allows loading of original checkpoints before transformation.
    - Supports PEFT and similar techniques that modify model structure.

    Example:
        >>> class MyLightningModule(pl.LightningModule):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.model = SomeModel()
        ...         self.model_transform = lambda m: SomePEFTMethod()(m)
        ...
        >>> model = MyLightningModule()
        >>> # Load original checkpoint here if needed
        >>> model.load_state_dict(torch.load('original_checkpoint.pth'))
        >>> trainer = pl.Trainer(callbacks=[ModelTransform()])
        >>> # The model will be transformed when trainer.fit() or trainer.validate() is called
        >>> trainer.fit(model)

    Note:
        The transformation is applied only once, at the start of the first training or validation
        epoch, whichever comes first. The call counter lives on the module's ``model_transform``
        itself, so a later ``setup`` for the same module (e.g. ``trainer.validate(model)`` after
        ``trainer.fit(model)``) does not apply it a second time. This ensures that the model
        structure is modified before any forward passes or parameter updates occur, but after
        the original weights have been loaded.
    """

    def __init__(self):
        super().__init__()
        # Call-counted transform shared with the LightningModule; stays None until setup() finds one.
        self.model_transform: Optional[Callable[[nn.Module], nn.Module]] = None

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """Attach a call-counted ``model_transform`` to both this callback and ``pl_module``.

        Args:
            trainer: The trainer running this callback (not used here).
            pl_module: The LightningModule whose ``model_transform`` attribute, if present and
                not ``None``, is wrapped with a call counter and written back onto the module.
            stage: Name of the stage being set up; only used for logging.
        """
        logging.info(f"Setting up ModelTransform for stage: {stage}")

        if not hasattr(pl_module, 'model_transform'):
            logging.info("No model_transform attribute found on pl_module")
            return

        transform = pl_module.model_transform
        if transform is None:
            # Wrapping None would look like a pending transform and then fail with TypeError
            # when it is called at the first epoch start; treat it as "nothing to apply".
            logging.info("pl_module.model_transform is None; no transform will be applied")
            return

        logging.info("Found model_transform attribute on pl_module")
        if not isinstance(getattr(transform, '__num_calls__', None), int):
            transform = _call_counter(transform)
        # Otherwise the transform was already wrapped by an earlier setup(); reusing the same
        # wrapper keeps its call count, so an already-transformed model is not transformed again.
        self.model_transform = transform
        pl_module.model_transform = transform
        logging.info(f"Set model_transform to: {self.model_transform}")

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Apply the transform, if still pending, when a training epoch starts."""
        self._maybe_apply_transform(trainer)

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Apply the transform, if still pending, when a validation epoch starts."""
        self._maybe_apply_transform(trainer)

    def _maybe_apply_transform(self, trainer):
        """Apply the transform only if one is configured and it has never been called."""
        if self._needs_to_call:
            self.apply_transform(trainer)

    def apply_transform(self, trainer):
        """Apply ``self.model_transform`` (e.g. PEFT) to ``trainer.model`` and log a model summary.

        The transform's return value is ignored, so it is expected to modify the model in place.
        """
        self.model_transform(trainer.model)
        from lightning.pytorch.utilities import model_summary

        # Summarize the LightningModule (not the possibly strategy-wrapped trainer.model).
        summary = str(model_summary.summarize(trainer.lightning_module, max_depth=1))
        summary = "\n\r".join(summary.split("\n"))
        logging.info(f"After applying model_transform:\n\n\r{summary}")

    @property
    def _needs_to_call(self) -> bool:
        """True when a transform is configured and its call counter is still zero."""
        return self.model_transform is not None and self.model_transform.__num_calls__ == 0


T = TypeVar('T', bound=Callable[..., Any])


def _call_counter(func: T) -> T:
    """
    A decorator that counts the number of times a function is called.

    This decorator wraps a function and adds a '__num_calls__' attribute to it,
    which is incremented each time the function is called.

    Args:
        func (Callable): The function to be wrapped.

    Returns:
        Callable: The wrapped function with a call counter.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        wrapper.__num_calls__ += 1
        return func(*args, **kwargs)

    wrapper.__num_calls__ = 0
    return wrapper  # type: ignore
