# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Contains the TrialComponent class."""
from __future__ import absolute_import

import time

from botocore.exceptions import ClientError

from sagemaker.apiutils import _base_types
from sagemaker.experiments import _api_types
from sagemaker.experiments._api_types import TrialComponentSearchResult
from sagemaker.utils import format_tags


class _TrialComponent(_base_types.Record):
    """This class represents a SageMaker trial component object.

    A trial component is a stage in a trial.
    Trial components are created automatically within the SageMaker runtime when a job is
    launched with an experiment config, which also associates them with a trial and
    experiment. They can additionally be created explicitly with ``create``.
    For example: https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html

    Attributes:
        trial_component_name (str): The name of the trial component. Generated by SageMaker
            from the name of the source job with a suffix specific to the type of source job.
        trial_component_arn (str): The ARN of the trial component.
        display_name (str): The name of the trial component that will appear in UI,
            such as SageMaker Studio.
        source (TrialComponentSource): A TrialComponentSource object with a source_arn attribute.
        status (str): Status of the source job.
        start_time (datetime): When the source job started.
        end_time (datetime): When the source job ended.
        creation_time (datetime): When the source job was created.
        created_by (obj): Contextual info on which account created the trial component.
        last_modified_time (datetime): When the trial component was last modified.
        last_modified_by (obj): Contextual info on which account last modified the trial component.
        parameters (dict): Dictionary of parameters to the source job.
        input_artifacts (dict): Dictionary of input artifacts.
        output_artifacts (dict): Dictionary of output artifacts.
        metrics (obj): Aggregated metrics for the job.
        parameters_to_remove (list): The hyperparameters to remove from the component.
        input_artifacts_to_remove (list): The input artifacts to remove from the component.
        output_artifacts_to_remove (list): The output artifacts to remove from the component.
        tags (List[Dict[str, str]]): A list of tags to associate with the trial component.
    """

    trial_component_name = None
    trial_component_arn = None
    display_name = None
    source = None
    status = None
    start_time = None
    end_time = None
    creation_time = None
    created_by = None
    last_modified_time = None
    last_modified_by = None
    parameters = None
    input_artifacts = None
    output_artifacts = None
    metrics = None
    parameters_to_remove = None
    input_artifacts_to_remove = None
    output_artifacts_to_remove = None
    tags = None

    # Names of the SageMaker client methods used for load/create/update/delete.
    _boto_load_method = "describe_trial_component"
    _boto_create_method = "create_trial_component"
    _boto_update_method = "update_trial_component"
    _boto_delete_method = "delete_trial_component"

    # Attribute name -> (api type, is_collection) used when (de)serializing boto payloads.
    _custom_boto_types = {
        "source": (_api_types.TrialComponentSource, False),
        "status": (_api_types.TrialComponentStatus, False),
        "parameters": (_api_types.TrialComponentParameters, False),
        "input_artifacts": (_api_types.TrialComponentArtifact, True),
        "output_artifacts": (_api_types.TrialComponentArtifact, True),
        "metrics": (_api_types.TrialComponentMetricSummary, True),
    }

    # Attributes sent with the update call issued by ``save``.
    _boto_update_members = [
        "trial_component_name",
        "display_name",
        "status",
        "start_time",
        "end_time",
        "parameters",
        "input_artifacts",
        "output_artifacts",
        "parameters_to_remove",
        "input_artifacts_to_remove",
        "output_artifacts_to_remove",
    ]
    # Attributes sent with the delete call issued by ``delete``.
    _boto_delete_members = ["trial_component_name"]

    def __init__(self, sagemaker_session=None, **kwargs):
        """Initialize a _TrialComponent record.

        Parameters and input/output artifacts default to empty dicts when they are
        not supplied (or are falsy), so callers can mutate them without None checks.
        """
        super().__init__(sagemaker_session, **kwargs)
        self.parameters = self.parameters or {}
        self.input_artifacts = self.input_artifacts or {}
        self.output_artifacts = self.output_artifacts or {}

    @classmethod
    def _boto_ignore(cls):
        """Response fields to ignore by default."""
        return super(_TrialComponent, cls)._boto_ignore() + ["CreatedBy"]

    def save(self):
        """Save the state of this TrialComponent to SageMaker.

        Returns:
            The response of the update trial component API call.
        """
        return self._invoke_api(self._boto_update_method, self._boto_update_members)

    def delete(self, force_disassociate=False):
        """Delete this TrialComponent from SageMaker.

        Args:
            force_disassociate (bool): Indicates whether to force disassociate the
                trial component from its trials before deletion (default: False).
                If set to True, every trial listed for this trial component (following
                pagination) is disassociated from it first, then the trial component
                is deleted.
                If it's not set or set to False, the trial component is deleted directly
                without disassociation.

        Returns:
            dict: Delete trial component response.
        """
        if force_disassociate:
            client = self.sagemaker_session.sagemaker_client
            list_kwargs = {"TrialComponentName": self.trial_component_name}

            while True:
                list_trials_response = client.list_trials(**list_kwargs)

                # Disassociate the trials and trial components
                for trial_summary in list_trials_response["TrialSummaries"]:
                    # to prevent DisassociateTrialComponent throttling
                    time.sleep(1.2)
                    client.disassociate_trial_component(
                        TrialName=trial_summary["TrialName"],
                        TrialComponentName=self.trial_component_name,
                    )

                # A missing (or empty) NextToken means there are no more pages.
                next_token = list_trials_response.get("NextToken")
                if not next_token:
                    break
                list_kwargs["NextToken"] = next_token

        return self._invoke_api(self._boto_delete_method, self._boto_delete_members)

    @classmethod
    def load(cls, trial_component_name, sagemaker_session=None):
        """Load an existing trial component and return a `_TrialComponent` object representing it.

        Args:
            trial_component_name (str): Name of the trial component
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using the
                default AWS configuration chain.

        Returns:
            experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object
        """
        return cls._construct(
            cls._boto_load_method,
            trial_component_name=trial_component_name,
            sagemaker_session=sagemaker_session,
        )

    @classmethod
    def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_session=None):
        """Create a trial component and return a `_TrialComponent` object representing it.

        Args:
            trial_component_name (str): The name of the trial component.
            display_name (str): Display name of the trial component used by Studio (default: None).
            tags (Optional[Tags]): Tags to add to the trial component (default: None).
                Tags are normalized with ``format_tags`` before being sent.
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using the
                default AWS configuration chain.

        Returns:
            experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object.
        """
        return super(_TrialComponent, cls)._construct(
            cls._boto_create_method,
            trial_component_name=trial_component_name,
            display_name=display_name,
            tags=format_tags(tags),
            sagemaker_session=sagemaker_session,
        )

    @classmethod
    def list(
        cls,
        source_arn=None,
        created_before=None,
        created_after=None,
        sort_by=None,
        sort_order=None,
        sagemaker_session=None,
        trial_name=None,
        experiment_name=None,
        max_results=None,
        next_token=None,
    ):
        """Return a list of trial component summaries.

        Args:
            source_arn (str): A SageMaker Training or Processing Job ARN (default: None).
            created_before (datetime.datetime): Return trial components created before this instant
                (default: None).
            created_after (datetime.datetime): Return trial components created after this instant
                (default: None).
            sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime'
                (default: None).
            sort_order (str): One of 'Ascending', or 'Descending' (default: None).
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using the
                default AWS configuration chain.
            trial_name (str): If provided only trial components related to the trial are returned
                (default: None).
            experiment_name (str): If provided only trial components related to the experiment are
                returned (default: None).
            max_results (int): maximum number of trial components to retrieve (default: None).
            next_token (str): token for next page of results (default: None).

        Returns:
            collections.Iterator[experiments._api_types.TrialComponentSummary]: An iterator
                over `TrialComponentSummary` objects.
        """
        return super(_TrialComponent, cls)._list(
            "list_trial_components",
            _api_types.TrialComponentSummary.from_boto,
            "TrialComponentSummaries",
            source_arn=source_arn,
            created_before=created_before,
            created_after=created_after,
            sort_by=sort_by,
            sort_order=sort_order,
            sagemaker_session=sagemaker_session,
            trial_name=trial_name,
            experiment_name=experiment_name,
            max_results=max_results,
            next_token=next_token,
        )

    @classmethod
    def search(
        cls,
        search_expression=None,
        sort_by=None,
        sort_order=None,
        max_results=None,
        sagemaker_session=None,
    ):
        """Search Experiment Trial Components.

        Returns SearchResults in the account matching the search criteria.

        Args:
            search_expression (SearchExpression): A Boolean conditional statement (default: None).
                Resource objects must satisfy this condition to be included in search results.
                You must provide at least one subexpression, filter, or nested filter.
            sort_by (str): The name of the resource property used to sort the SearchResults
                (default: None).
            sort_order (str): How SearchResults are ordered. Valid values are Ascending or
                Descending (default: None).
            max_results (int): The maximum number of results to return in a SearchResponse
                (default: None).
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using the
                default AWS configuration chain.

        Returns:
            collections.Iterator[SearchResult] : An iterator over search results matching the
            search criteria.
        """
        return super(_TrialComponent, cls)._search(
            search_resource="ExperimentTrialComponent",
            search_item_factory=TrialComponentSearchResult.from_boto,
            search_expression=None if search_expression is None else search_expression.to_boto(),
            sort_by=sort_by,
            sort_order=sort_order,
            max_results=max_results,
            sagemaker_session=sagemaker_session,
        )

    @classmethod
    def _load_or_create(
        cls, trial_component_name, display_name=None, tags=None, sagemaker_session=None
    ):
        """Load a trial component by name and create a new one if it does not exist.

        Creation is attempted first; if SageMaker rejects it with a ValidationException
        whose message contains "already exists", the existing trial component is loaded.
        Any other ClientError is re-raised.

        Args:
            trial_component_name (str): The name of the trial component.
            display_name (str): Display name of the trial component used by Studio (default: None).
                This is used only when the given `trial_component_name` does not
                exist and a new trial component has to be created.
            tags (Optional[Tags]): Tags to add to the trial component (default: None).
                This is used only when the given `trial_component_name` does not
                exist and a new trial component has to be created.
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using the
                default AWS configuration chain.

        Returns:
            tuple: A ``(trial_component, is_existed)`` pair where ``trial_component`` is a
            SageMaker `_TrialComponent` object and ``is_existed`` is a bool indicating
            whether the trial component already existed.
        """
        try:
            # ``create`` normalizes the tags itself, so pass them through unchanged.
            run_tc = cls.create(
                trial_component_name=trial_component_name,
                display_name=display_name,
                tags=tags,
                sagemaker_session=sagemaker_session,
            )
        except ClientError as ce:
            error = ce.response["Error"]
            already_exists = (
                error["Code"] == "ValidationException" and "already exists" in error["Message"]
            )
            if not already_exists:
                raise
            # already exists
            return cls.load(trial_component_name, sagemaker_session), True
        return run_tc, False

    @classmethod
    def _trial_component_is_associated_to_trial(
        cls, trial_component_name, trial_name=None, sagemaker_session=None
    ):
        """Returns a bool based on if trial_component is already associated with the trial.

        Args:
            trial_component_name (str): The name of the trial component.
            trial_name (str): The name of the trial.
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. Required: its ``sagemaker_client`` is used directly.

        Returns:
            bool: A boolean variable indicating whether the trial component is already
                  associated with the trial.
        """
        search_results = sagemaker_session.sagemaker_client.search(
            Resource="ExperimentTrialComponent",
            SearchExpression={
                "Filters": [
                    {
                        "Name": "TrialComponentName",
                        "Operator": "Equals",
                        "Value": str(trial_component_name),
                    },
                    {
                        "Name": "Parents.TrialName",
                        "Operator": "Equals",
                        "Value": str(trial_name),
                    },
                ]
            },
        )
        # Any hit means a trial component with this name has the trial as a parent.
        return bool(search_results["Results"])
