# Copyright (C) 2025 Embedl AB

"""
Functions for quantizing TensorFlow Lite models.
"""

import functools
from pathlib import Path

import numpy as np
import tensorflow as tf
from ai_edge_quantizer.quantizer import Quantizer
from ai_edge_quantizer.recipe_manager import ModelQuantizationRecipe
from tensorflow.lite.python import schema_py_generated as schema_fb

from embedl_hub.core.utils.tflite_utils import instantiate_tflite_interpreter
from embedl_hub.thirdparty.ai_edge_quantizer import recipe
from embedl_hub.thirdparty.aimet.utils import compute_psnr


def _quantize_input(float_input, detail):
    """Quantize a float input tensor based on tensor details."""
    qp = detail.get('quantization_parameters', {})
    scales = np.asarray(qp.get('scales', []), dtype=np.float32)
    zps = np.asarray(qp.get('zero_points', []), dtype=np.int32)

    if scales.size == 0:
        # No quantization info, return original input
        return float_input

    # Perform quantization: (float / scale) + zero_point
    quantized = float_input / scales + zps

    # Clip to the dtype's min/max values
    dtype_info = np.iinfo(detail['dtype'])
    quantized = np.clip(quantized, dtype_info.min, dtype_info.max)

    return quantized.astype(detail['dtype'])


def _forward_tflite(
    interpreter: tf.lite.Interpreter,
    input_data: dict[str, np.ndarray],
):
    """Set every input tensor from ``input_data`` and run one inference.

    Each key of ``input_data`` is looked up among the interpreter's input
    tensors by substring match on the tensor name; the first match wins.
    Each sample goes through ``_quantize_input`` before it is set.

    Raises:
        ValueError: If a key of ``input_data`` matches no input tensor name.
    """
    details = interpreter.get_input_details()

    for input_name, sample in input_data.items():
        match = None
        for candidate in details:
            if input_name in candidate['name']:
                match = candidate
                break
        if match is None:
            raise ValueError(f"Input tensor {input_name} not found in model.")
        interpreter.set_tensor(match['index'], _quantize_input(sample, match))

    interpreter.invoke()


def _dequantize(arr, detail):
    """Dequantize an int8/uint8 tensor based on tensor details."""
    qp = detail.get('quantization_parameters', {})
    scales = np.asarray(qp.get('scales', []), dtype=np.float32)
    zps = np.asarray(qp.get('zero_points', []), dtype=np.float32)

    # No quantization info -> just cast
    if scales.size == 0:
        return arr.astype(np.float32)

    x = arr.astype(np.float32)

    # Per-tensor case
    if scales.size == 1:
        scale = scales.item()
        zp = zps.item() if zps.size else 0.0
        return scale * (x - zp)

    # Per-channel case
    axis = qp.get('quantized_dimension', x.ndim - 1)
    if axis < 0:
        axis += x.ndim

    # Reshape scales to broadcast along 'axis'
    shape = [1] * x.ndim
    shape[axis] = scales.size
    scales = scales.reshape(shape)

    # zero_points:
    #  - if 0 or 1 value -> broadcast scalar
    #  - if same length as scales -> reshape like scales
    #  - otherwise, fall back to scalar (common in TFLite: per-channel scales with zp==0)
    if zps.size in (0, 1):
        zp = zps.item() if zps.size else 0.0
        return scales * (x - zp)
    if zps.size == scales.size:
        zps = zps.reshape(shape)
        return scales * (x - zps)
    # Unusual mismatch: warn and treat as scalar
    zp = zps.flat[0]
    return scales * (x - zp)


def _get_builtin_op_name(builtin_code: int) -> str:
    """Get the operation name from builtin code using TFLite schema."""
    # Use the BuiltinOperator enum from the TFLite schema
    try:
        # Get all attributes that start with uppercase (enum values)
        builtin_ops = [
            attr
            for attr in dir(schema_fb.BuiltinOperator)
            if attr.isupper() and not attr.startswith('_')
        ]

        # Create a mapping from enum values to names
        builtin_op_map = {}
        for name in builtin_ops:
            value = getattr(schema_fb.BuiltinOperator, name)
            builtin_op_map[value] = name

        return builtin_op_map.get(builtin_code, f"BUILTIN_{builtin_code}")
    except Exception:
        # Fallback if schema access fails
        return f"BUILTIN_{builtin_code}"


def parse_tflite_model(model_path: str | Path) -> dict:
    """Label every tensor of a TFLite model's first subgraph.

    Each tensor index of subgraph 0 is mapped to one of these labels:

    - The builtin operator name (via ``_get_builtin_op_name``) of the
      operator that outputs it.
    - ``"INPUT"`` if it is a subgraph input.
    - ``"CONSTANT"`` if an operator consumes it but no operator produces it
      and it is not a subgraph input.
    - ``"Unknown"`` otherwise.

    Only subgraph 0 is inspected.

    Args:
        model_path: Path to the ``.tflite`` file.

    Returns:
        Dict mapping tensor index to label.

    Raises:
        ValueError: If the model contains no subgraphs.
    """
    with open(model_path, 'rb') as f:
        buf = f.read()

    model = schema_fb.Model.GetRootAs(buf, 0)

    if model.SubgraphsLength() == 0:
        raise ValueError(f"TFLite model {model_path} has no subgraphs.")
    # Only the first subgraph is analysed.
    subgraph = model.Subgraphs(0)

    # Build the model-input set once instead of rescanning per tensor.
    model_inputs = {subgraph.Inputs(k) for k in range(subgraph.InputsLength())}

    tensor_to_op_type = {i: "Unknown" for i in range(subgraph.TensorsLength())}
    # Label inputs up front so that inputs no operator consumes are still
    # labelled "INPUT".
    for input_idx in model_inputs:
        tensor_to_op_type[input_idx] = "INPUT"

    for op_idx in range(subgraph.OperatorsLength()):
        operator = subgraph.Operators(op_idx)
        opcode = model.OperatorCodes(operator.OpcodeIndex())
        op_name = _get_builtin_op_name(opcode.BuiltinCode())

        # Map output tensors to this operation type
        for j in range(operator.OutputsLength()):
            tensor_to_op_type[operator.Outputs(j)] = op_name

        # Consumed tensors that nothing has produced so far are constants.
        for j in range(operator.InputsLength()):
            input_tensor_idx = operator.Inputs(j)
            if input_tensor_idx < 0:
                # TFLite uses -1 for an omitted optional input; it is not a
                # real tensor and must not be added to the mapping.
                continue
            if tensor_to_op_type.get(input_tensor_idx, "Unknown") == "Unknown":
                tensor_to_op_type[input_tensor_idx] = "CONSTANT"

    return tensor_to_op_type


def _truncate_name(name, max_len=50):
    """Truncate long tensor names for display."""
    if len(name) > max_len:
        return "..." + name[-(max_len - 3) :]
    return name


def _print_psnr_table(
    float_interpreter: tf.lite.Interpreter,
    int8_interpreter: tf.lite.Interpreter,
    float_tensor_details: list[dict],
    int8_tensor_details: list[dict],
    float_tensor_to_op_type: dict[int, str],
):
    """Print a PSNR comparison table between float and int8 TFLite models.

    Both interpreters should already have been invoked on the same input, so
    ``get_tensor`` returns the values to compare. Each float tensor is
    paired with the int8 tensor of the same name. If several share a name,
    the first one is used. The int8 values are dequantized with
    ``_dequantize`` before ``compute_psnr`` is called.

    A tensor is skipped when:

    - ``float_tensor_to_op_type`` labels it ``"CONSTANT"``;
    - it has no same-named int8 tensor (a warning is printed);
    - the two shapes differ (a warning is printed);
    - the PSNR equals 100.0.

    Args:
        float_interpreter: Interpreter of the float model.
        int8_interpreter: Interpreter of the quantized model.
        float_tensor_details: ``get_tensor_details()`` of the float model.
        int8_tensor_details: ``get_tensor_details()`` of the quantized model.
        float_tensor_to_op_type: Tensor index -> label, as produced by
            ``parse_tflite_model`` for the float model.
    """
    header = (
        f"{'Idx':<4} {'Name':<50} {'Shape':<20} {'Op Type':<15} {'PSNR (dB)':>10} "
        f"{'Float Mean':>12} {'Dequant Mean':>12}"
    )
    print(header)
    print("-" * len(header))

    # Index the int8 details by name once. The old per-tensor linear scan
    # was O(n^2). setdefault keeps the first detail for duplicate names,
    # as the old first-match lookup did.
    int8_by_name: dict[str, dict] = {}
    for int8_detail in int8_tensor_details:
        int8_by_name.setdefault(int8_detail['name'], int8_detail)

    for float_detail in float_tensor_details:
        idx = float_detail['index']
        tensor_name = float_detail['name']
        # Get operation type from parsed TFLite model
        op_type = float_tensor_to_op_type.get(idx, "Unknown")
        if op_type == "CONSTANT":
            continue
        float_output = float_interpreter.get_tensor(idx)

        matching_int8_detail = int8_by_name.get(tensor_name)
        if matching_int8_detail is None:
            print(f"Warning: No matching int8 tensor found for {tensor_name}")
            continue

        int8_output = int8_interpreter.get_tensor(
            matching_int8_detail['index']
        )
        dequantized_output = _dequantize(int8_output, matching_int8_detail)

        if float_output.shape != dequantized_output.shape:
            # Comparing differently-shaped tensors would either crash the
            # table half-way or silently broadcast, so report and skip.
            print(
                f"Warning: Shape mismatch for {tensor_name}: "
                f"float {list(float_output.shape)} vs "
                f"int8 {list(dequantized_output.shape)}"
            )
            continue

        psnr = compute_psnr(float_output, dequantized_output)
        # NOTE(review): 100.0 is presumably what compute_psnr reports for
        # identical tensors; those rows carry no information. Confirm
        # against compute_psnr.
        if psnr == 100.0:
            continue

        name = _truncate_name(tensor_name)
        shape_str = str(list(float_output.shape))
        float_mean = float_output.mean()
        dequant_mean = dequantized_output.mean()

        row = (
            f"{idx:<4} {name:<50} {shape_str:<20} {op_type:<15} {psnr:>10.2f} "
            f"{float_mean:>12.6g} {dequant_mean:>12.6g}"
        )
        print(row)


def _make_random_calibration_data(
    input_names: list[str],
    input_details: list[dict],
) -> list[dict[str, np.ndarray]]:
    """Generate random calibration data for a TFLite model.

    Args:
        input_names: List of input tensor names.
        input_details: List of input tensor details from the interpreter.
    """

    calibration_data = {}
    for name in input_names:
        detail = next((d for d in input_details if name in d['name']), None)
        if detail is None:
            raise ValueError(f"Input tensor {name} not found in model.")
        shape = detail['shape']
        dtype = detail['dtype']
        # Use random data in the valid range for the dtype
        if np.issubdtype(dtype, np.integer):
            info = np.iinfo(dtype)
            data = np.random.randint(
                info.min, info.max + 1, size=shape, dtype=dtype
            )
        else:
            data = np.random.randn(*shape).astype(dtype)
        calibration_data[name] = data

    # Generate a few calibration samples
    return [calibration_data]


def quantize_tflite_model(
    float_model_path: Path,
    int8_model_path: Path,
    calibration_data: list[dict[str, np.ndarray]] | None = None,
    quantization_recipe: ModelQuantizationRecipe | None = None,
    report_psnr: bool = True,
):
    """Quantize a TFLite model to int8 using AI Edge Quantizer.

    Args:
        float_model_path: Path to the input float TFLite model.
        int8_model_path: Path to save the quantized int8 TFLite model. An
            existing file at this path is overwritten.
        calibration_data: Optional list of calibration samples. Each sample
            maps a signature input name to an array. When omitted or empty,
            a single random sample is generated from the model's input
            details.
        quantization_recipe: The quantization recipe to use. Defaults to
            ``recipe.static_wi8_ai8()``.
        report_psnr: Whether to print a per-tensor PSNR table comparing the
            float and int8 models. The table uses the first calibration
            sample.

    Raises:
        ValueError: If the float model exposes no signatures.
    """
    if quantization_recipe is None:
        quantization_recipe = recipe.static_wi8_ai8()

    float_interpreter = instantiate_tflite_interpreter(
        str(float_model_path), experimental_preserve_all_tensors=True
    )

    signatures: dict[str, dict[str, list[str]]] = (
        float_interpreter.get_signature_list()
    )
    if not signatures:
        raise ValueError(
            f"Model {float_model_path} has no signatures; cannot select a "
            "signature to calibrate."
        )
    # Only the first signature is calibrated and compared.
    signature_key = next(iter(signatures))

    tflite_quantizer = Quantizer(
        str(float_model_path), quantization_recipe=quantization_recipe
    )

    calibration_data = calibration_data or _make_random_calibration_data(
        input_names=signatures[signature_key]['inputs'],
        input_details=float_interpreter.get_input_details(),
    )
    calibration_result = tflite_quantizer.calibrate(
        {signature_key: calibration_data}
    )
    quantization_result = tflite_quantizer.quantize(calibration_result)
    # export_model is annotated to take a str path; convert explicitly, as
    # is done for the other paths in this function.
    quantization_result.export_model(str(int8_model_path), overwrite=True)

    if not report_psnr:
        # No need to compute PSNR
        return

    quant_interpreter = instantiate_tflite_interpreter(
        str(int8_model_path), experimental_preserve_all_tensors=True
    )

    # Run the same (first) sample through both models so their intermediate
    # tensors can be compared.
    _forward_tflite(
        float_interpreter,
        input_data=calibration_data[0],
    )
    _forward_tflite(
        quant_interpreter,
        input_data=calibration_data[0],
    )

    _print_psnr_table(
        float_interpreter=float_interpreter,
        int8_interpreter=quant_interpreter,
        float_tensor_details=float_interpreter.get_tensor_details(),
        int8_tensor_details=quant_interpreter.get_tensor_details(),
        float_tensor_to_op_type=parse_tflite_model(str(float_model_path)),
    )
