# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Common config for QualityCheckStep and ClarifyCheckStep."""
from __future__ import absolute_import

import logging
from typing import Optional

from sagemaker import Session
from sagemaker.model_monitor import (
    ModelMonitor,
    DefaultModelMonitor,
    ModelQualityMonitor,
    ModelBiasMonitor,
    ModelExplainabilityMonitor,
)
from sagemaker.utils import format_tags


class CheckJobConfig:
    """Check job config for QualityCheckStep and ClarifyCheckStep.

    Holds the processing-job settings shared by all check steps and builds the
    matching ``ModelMonitor`` subclass on request (see ``_generate_model_monitor``).
    """

    def __init__(
        self,
        role,
        instance_count=1,
        instance_type="ml.m5.xlarge",
        volume_size_in_gb=30,
        volume_kms_key=None,
        output_kms_key=None,
        max_runtime_in_seconds=None,
        base_job_name=None,
        sagemaker_session=None,
        env=None,
        tags=None,
        network_config=None,
    ):
        """Constructs a CheckJobConfig instance.

        Args:
            role (str): An AWS IAM role. The Amazon SageMaker jobs use this role.
            instance_count (int): The number of instances to run the jobs with (default: 1).
            instance_type (str): Type of EC2 instance to use for the job
                (default: 'ml.m5.xlarge').
            volume_size_in_gb (int): Size in GB of the EBS volume
                to use for storing data during processing (default: 30).
            volume_kms_key (str): A KMS key for the processing volume (default: None).
            output_kms_key (str): The KMS key id for the job's outputs (default: None).
            max_runtime_in_seconds (int): Timeout in seconds. After this amount of
                time, Amazon SageMaker terminates the job regardless of its current status.
                This object stores and forwards None unchanged when not specified; the
                effective default (documented as 3600) is applied downstream — TODO confirm.
            base_job_name (str): Prefix for the job name. If not specified,
                a default name is generated based on the image name and
                current timestamp (default: None).
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed (default: None). If not specified, a new
                ``Session()`` is created immediately.
            env (dict): Environment variables to be passed to the job (default: None).
            tags (Optional[Tags]): List of tags to be passed to the job (default: None).
                Normalized with ``format_tags`` before being stored.
            network_config (sagemaker.network.NetworkConfig): A NetworkConfig
                object that configures network isolation, encryption of
                inter-container traffic, security group IDs, and subnets (default: None).

        """
        self.role = role
        self.instance_count = instance_count
        self.instance_type = instance_type
        self.volume_size_in_gb = volume_size_in_gb
        self.volume_kms_key = volume_kms_key
        self.output_kms_key = output_kms_key
        self.max_runtime_in_seconds = max_runtime_in_seconds
        self.base_job_name = base_job_name
        # A falsy session argument falls back to a freshly constructed default Session.
        self.sagemaker_session = sagemaker_session or Session()
        self.env = env
        self.tags = format_tags(tags)
        self.network_config = network_config

    def _generate_model_monitor(self, mm_type: str) -> Optional[ModelMonitor]:
        """Generates a ModelMonitor object.

        Builds the requested ModelMonitor subclass, passing it every job setting
        held by this config, for use by QualityCheckStep and ClarifyCheckStep.

        Args:
            mm_type (str): The name of the ModelMonitor subclass to construct.
                A valid mm_type is one of: "DefaultModelMonitor",
                "ModelQualityMonitor", "ModelBiasMonitor", "ModelExplainabilityMonitor".

        Return:
            sagemaker.model_monitor.ModelMonitor or None if the mm_type is not valid
            (a warning is logged in that case).

        """
        # Every supported monitor is constructed with the identical set of keyword
        # arguments below, so a name -> class table replaces four copy-pasted calls.
        monitor_classes = {
            "DefaultModelMonitor": DefaultModelMonitor,
            "ModelQualityMonitor": ModelQualityMonitor,
            "ModelBiasMonitor": ModelBiasMonitor,
            "ModelExplainabilityMonitor": ModelExplainabilityMonitor,
        }
        # Only look up strings: an unhashable argument would otherwise raise
        # TypeError from the dict, whereas it is simply an invalid type here.
        monitor_cls = monitor_classes.get(mm_type) if isinstance(mm_type, str) else None
        if monitor_cls is None:
            logging.warning(
                'Expected model monitor types: "DefaultModelMonitor", "ModelQualityMonitor", '
                '"ModelBiasMonitor", "ModelExplainabilityMonitor"; got: %r',
                mm_type,
            )
            return None
        return monitor_cls(
            role=self.role,
            instance_count=self.instance_count,
            instance_type=self.instance_type,
            volume_size_in_gb=self.volume_size_in_gb,
            volume_kms_key=self.volume_kms_key,
            output_kms_key=self.output_kms_key,
            max_runtime_in_seconds=self.max_runtime_in_seconds,
            base_job_name=self.base_job_name,
            sagemaker_session=self.sagemaker_session,
            env=self.env,
            tags=self.tags,
            network_config=self.network_config,
        )
