# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import PretrainedConfig

from ....utils import logging
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyModelForCausalLMConfig
from .qwen2_architecture import QWEN2Wrapper


logger = logging.get_logger(__name__)


class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    """
    RBLN port of the Qwen2 transformer with a language-modeling head (a linear layer) on top.

    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]; see the superclass documentation
    for the generic methods the library implements for all of its models.

    This class converts a pre-trained transformers-based Qwen2ForCausalLM model into an RBLN
    transformer model and runs it on RBLN devices. Conversion works by:
    - transferring the checkpoint weights of the original model into an optimized RBLN graph, and
    - compiling the resulting graph with the RBLN compiler.

    **Configuration:**
    This model is configured through [`RBLNQwen2ForCausalLMConfig`]. When calling methods such as
    `from_pretrained` or `from_model`, pass `rbln_config` as either an instance of
    [`RBLNQwen2ForCausalLMConfig`] or a dictionary that follows its structure.

    Refer to the [`RBLNQwen2ForCausalLMConfig`] class for every available configuration option.

    Examples:
        ```python
        from optimum.rbln import RBLNQwen2ForCausalLM

        # Simple usage using rbln_* arguments
        # `max_seq_len` is automatically inferred from the model config
        model = RBLNQwen2ForCausalLM.from_pretrained(
            "Qwen/Qwen2-7B-Instruct",
            export=True,
            rbln_batch_size=1,
            rbln_tensor_parallel_size=4,
        )


        # Using a config dictionary
        rbln_config = {
            "batch_size": 1,
            "max_seq_len": 4096,
            "tensor_parallel_size": 4,
        }
        model = RBLNQwen2ForCausalLM.from_pretrained(
            "Qwen/Qwen2-7B-Instruct",
            export=True,
            rbln_config=rbln_config
        )


        # Using an RBLNQwen2ForCausalLMConfig instance (recommended for type checking)
        from optimum.rbln import RBLNQwen2ForCausalLMConfig

        config = RBLNQwen2ForCausalLMConfig(
            batch_size=1,
            max_seq_len=4096,
            tensor_parallel_size=4
        )
        model = RBLNQwen2ForCausalLM.from_pretrained(
            "Qwen/Qwen2-7B-Instruct",
            export=True,
            rbln_config=config
        )
        ```
    """

    # Wrapper class used to build the Qwen2 decoder graph for compilation.
    _decoder_wrapper_cls = QWEN2Wrapper

    @classmethod
    def _update_sliding_window_config(
        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
        """Switch `rbln_config` to sliding-window caching on every decoder layer.

        Only `model_config.sliding_window` and `model_config.num_hidden_layers` are read.
        No per-layer selection is taken from `model_config`: every layer index is marked
        as a sliding-window layer.

        Args:
            model_config: The HuggingFace model configuration to read from.
            rbln_config: The RBLN configuration; it is mutated in place.

        Returns:
            The same `rbln_config` object after mutation.
        """
        # Workaround for https://github.com/huggingface/transformers/issues/35896.
        # The original note reports a bug in transformers (v4.52.4) under which, as with
        # attn_implementation="eager", all layers use sliding-window attention in that
        # version. We mirror that by marking every layer as a sliding-window layer.
        # TODO: revisit once the upstream bug is fixed.
        rbln_config.cache_impl = "sliding_window"
        rbln_config.sliding_window = model_config.sliding_window

        layer_count = model_config.num_hidden_layers
        rbln_config.sliding_window_layers = [layer_idx for layer_idx in range(layer_count)]

        return rbln_config
