from typing import Union

import torch

from esp_ppq.core import (
    PASSIVE_OPERATIONS,
    OperationQuantizationConfig,
    QuantizationPolicy,
    QuantizationProperty,
    QuantizationStates,
    RoundingPolicy,
    TargetPlatform,
)
from esp_ppq.IR import BaseGraph, GraphCommandProcessor, Operation

from .base import BaseQuantizer


class MNNQuantizer(BaseQuantizer):
    """Quantizer whose target platform is ``TargetPlatform.MNN_INT8``.

    Uses 8-bit symmetrical, linear, per-tensor quantization with an integer
    range of [-127, 127] by default. Conv weights that are graph parameters
    are switched to per-channel quantization (see ``init_quantize_config``).
    """

    def __init__(self, graph: Union[BaseGraph, GraphCommandProcessor]) -> None:
        """Initialise the quantizer for ``graph`` and set the 8-bit range.

        Args:
            graph: Graph (or graph command processor) forwarded to
                ``BaseQuantizer.__init__``.
        """
        super().__init__(graph=graph)
        self._num_of_bits = 8
        # Symmetric range: -128 is deliberately not used.
        self._quant_min = -127
        self._quant_max = +127

    def init_quantize_config(self, operation: Operation) -> OperationQuantizationConfig:
        """Build the quantization config for a single operation.

        Starts from the default config (this quantizer's policy, rounding and
        bit width, with the 'percentile' observer) and then specialises it:

        * ``Conv``: if input 1 is a parameter, it is quantized per-channel
          along axis 0 with the 'minmax' observer; if the operation has more
          than two inputs, the last input's config is put in the FP32 state.
        * Operations listed in ``PASSIVE_OPERATIONS`` are marked as not being
          active quantization operations.

        Args:
            operation: The operation to create a config for.

        Returns:
            The (possibly specialised) quantization config for ``operation``.

        Raises:
            AssertionError: If a ``Conv`` operation has fewer than two inputs,
                i.e. there is no input 1 to configure.
        """
        base_quant_config = self.create_default_quant_config(
            policy=self.quantize_policy,
            rounding=self.rounding_policy,
            op=operation,
            num_of_bits=self._num_of_bits,
            exponent_bits=0,
            quant_max=self._quant_max,
            quant_min=self._quant_min,
            observer_algorithm='percentile',
        )

        if operation.type == 'Conv':
            # inputs[1] is accessed below, so at least two inputs are needed.
            # Checking only "> 0" let a one-input Conv through to a bare
            # IndexError instead of this message.
            assert operation.num_of_input > 1, 'Seems you got a Conv layer with no parameters.'

            if operation.inputs[1].is_parameter:
                # Per-channel, symmetric quantization along axis 0 of input 1
                # (presumably the output-channel axis of the weight — confirm).
                conv_weight_config = base_quant_config.input_quantization_config[1]
                conv_weight_config.policy = QuantizationPolicy(
                    QuantizationProperty.SYMMETRICAL + QuantizationProperty.LINEAR + QuantizationProperty.PER_CHANNEL
                )
                conv_weight_config.channel_axis = 0
                conv_weight_config.observer_algorithm = 'minmax'

            if operation.num_of_input > 2:
                # The last input (presumably the bias) is kept in FP32 state.
                bias_config = base_quant_config.input_quantization_config[-1]
                bias_config.state = QuantizationStates.FP32

        if operation.type in PASSIVE_OPERATIONS:
            # Those op are not active op.
            base_quant_config.is_active_quant_op = False
        return base_quant_config

    @property
    def target_platform(self) -> TargetPlatform:
        """Platform this quantizer targets: ``TargetPlatform.MNN_INT8``."""
        return TargetPlatform.MNN_INT8

    @property
    def default_platform(self) -> TargetPlatform:
        """Default platform: ``TargetPlatform.FP32``.

        NOTE(review): presumably applied to operations this quantizer does not
        quantize — confirm against ``BaseQuantizer``.
        """
        return TargetPlatform.FP32

    @property
    def quant_operation_types(self) -> set:
        """Operation type names this quantizer declares as quantizable."""
        return {'Conv', 'Add', 'Gemm'}

    @property
    def quantize_policy(self) -> QuantizationPolicy:
        """Default policy: symmetrical, linear, per-tensor."""
        return QuantizationPolicy(
            QuantizationProperty.SYMMETRICAL + QuantizationProperty.LINEAR + QuantizationProperty.PER_TENSOR
        )

    @property
    def rounding_policy(self) -> RoundingPolicy:
        """Rounding policy passed to the default quantization config."""
        return RoundingPolicy.ROUND_HALF_FAR_FORM_ZERO

    @property
    def activation_fusion_types(self) -> set:
        """Activation operation type names declared for fusion.

        NOTE(review): how these are consumed is defined outside this class —
        confirm against ``BaseQuantizer`` / the fusion pass.
        """
        return {'Relu', 'Clip', 'Swish', 'SoftPlus', 'Sigmoid'}
