import os
import boto3
import botocore
import logging
import base64

import aws_encryption_sdk
from aws_encryption_sdk.key_providers.kms import StrictAwsKmsMasterKeyProvider
from aws_encryption_sdk.exceptions import AWSEncryptionSDKClientError

from mind_castle.secret_store_base import SecretStoreBase, MindCastleSecret
from mind_castle.exceptions import (
    RetrieveSecretException,
    CreateSecretException,
)

# Module-scoped logger so callers can tune this store's log level by module name.
logger = logging.getLogger(name=__name__)


class AWSKMSSecretStore(SecretStoreBase):
    """
    Uses AWS KMS to encrypt secrets.

    New secrets are envelope-encrypted with the AWS Encryption SDK under the
    KMS key named by ``MIND_CASTLE_AWS_KMS_KEY_ARN``. The base64-encoded
    ciphertext is carried in the returned ``MindCastleSecret`` itself.
    Secrets without a truthy ``envelope_encryption`` metadata flag are
    legacy values that were encrypted directly with KMS. They are decrypted
    through the raw KMS ``Decrypt`` API.

    Credentials come from ``MIND_CASTLE_AWS_ACCESS_KEY_ID`` /
    ``MIND_CASTLE_AWS_SECRET_ACCESS_KEY`` when the first config set is fully
    present. Otherwise they come from whatever the environment provides to
    boto3/botocore.
    """

    store_type = "awskms"
    required_config = [
        # Explicit access-key authentication.
        [
            "MIND_CASTLE_AWS_REGION",
            "MIND_CASTLE_AWS_ACCESS_KEY_ID",
            "MIND_CASTLE_AWS_SECRET_ACCESS_KEY",
            "MIND_CASTLE_AWS_KMS_KEY_ARN",
        ],
        # Ambient/environment authentication.
        [
            "MIND_CASTLE_AWS_REGION",
            "MIND_CASTLE_AWS_USE_ENV_AUTH",
            "MIND_CASTLE_AWS_KMS_KEY_ARN",
        ],
    ]
    optional_config = []

    # Algorithm passed to KMS Decrypt for legacy (non-envelope) secrets.
    KMS_ENCRYPTION_ALGORITHM = "RSAES_OAEP_SHA_256"

    def __init__(self):
        """
        Build the raw KMS client, the Encryption SDK key provider and SDK client.

        Attributes set:
            kms_client: boto3 KMS client, used only for legacy secrets.
            kms_key_arn: ARN of the KMS key used for all operations.
            kms_provider: Encryption SDK master-key provider for that key.
            client: ``aws_encryption_sdk.EncryptionSDKClient`` used for
                envelope encryption/decryption.

        Raises:
            KeyError: if ``MIND_CASTLE_AWS_REGION`` or
                ``MIND_CASTLE_AWS_KMS_KEY_ARN`` is not set.
        """
        session = botocore.session.get_session()

        if all(os.environ.get(req) for req in self.required_config[0]):
            # Configure with an explicit access key / secret key pair.
            access_key_id = os.environ["MIND_CASTLE_AWS_ACCESS_KEY_ID"]
            secret_access_key = os.environ["MIND_CASTLE_AWS_SECRET_ACCESS_KEY"]
            # The botocore session is what the Encryption SDK provider uses.
            session.set_credentials(access_key_id, secret_access_key)
            self.kms_client = boto3.client(
                "kms",
                region_name=os.environ["MIND_CASTLE_AWS_REGION"],
                aws_access_key_id=access_key_id,
                aws_secret_access_key=secret_access_key,
            )
        else:
            # Assume the environment is configured with the correct credentials.
            self.kms_client = boto3.client(
                "kms", region_name=os.environ["MIND_CASTLE_AWS_REGION"]
            )

        self.kms_key_arn = os.environ["MIND_CASTLE_AWS_KMS_KEY_ARN"]
        self.kms_provider = StrictAwsKmsMasterKeyProvider(
            key_ids=[self.kms_key_arn],
            botocore_session=session,
        )
        # High-level Encryption SDK client. The raw boto3 KMS client is kept
        # separately in ``self.kms_client`` because the legacy decrypt path
        # needs the low-level KMS ``Decrypt`` API, which this client lacks.
        self.client = aws_encryption_sdk.EncryptionSDKClient()

    def retrieve_secret(self, secret: MindCastleSecret) -> str:
        """
        Decrypt ``secret`` and return its plaintext decoded as UTF-8.

        Raises:
            RetrieveSecretException: if KMS (legacy path) or the Encryption
                SDK (envelope path) fails to decrypt the value.
        """
        encrypted_value = secret.encrypted_value
        # NOTE(review): tolerate a missing/None metadata mapping; confirm
        # whether MindCastleSecret can actually carry None here.
        metadata = secret.metadata or {}

        # For a brief period we encrypted directly. This is to support legacy secrets.
        if not metadata.get("envelope_encryption"):
            try:
                response = self.kms_client.decrypt(
                    KeyId=self.kms_key_arn,
                    EncryptionAlgorithm=self.KMS_ENCRYPTION_ALGORITHM,
                    CiphertextBlob=base64.b64decode(encrypted_value),
                )
            except botocore.exceptions.ClientError as e:
                logger.exception("Error retrieving secret %s:", secret.key)
                raise RetrieveSecretException(e, encrypted_value) from e
            plaintext = response["Plaintext"]
        else:
            try:
                plaintext, _header = self.client.decrypt(
                    source=base64.b64decode(encrypted_value),
                    key_provider=self.kms_provider,
                )
            # Deliberately broad: covers AWSEncryptionSDKClientError,
            # botocore ClientError and base64 decoding errors alike.
            except Exception as e:
                logger.exception("Error retrieving secret %s:", secret.key)
                raise RetrieveSecretException(e, encrypted_value) from e

        return plaintext.decode("utf-8")

    def create_secret(self, value: str) -> MindCastleSecret:
        """
        Envelope-encrypt ``value`` and return it as a new ``MindCastleSecret``.

        The ciphertext is base64-encoded and stored in ``encrypted_value``.
        ``metadata["envelope_encryption"]`` is set to True so that
        ``retrieve_secret`` takes the Encryption SDK path.

        Raises:
            CreateSecretException: if encoding or encryption fails.
        """
        try:
            ciphertext, _header = self.client.encrypt(
                source=value.encode("utf-8"),
                key_provider=self.kms_provider,
            )
        # Deliberately broad: covers AWSEncryptionSDKClientError and
        # botocore ClientError, and any other failure while encrypting.
        except Exception as e:
            logger.exception("Error creating secret:")
            raise CreateSecretException(e) from e

        return MindCastleSecret(
            mind_castle_secret_type=self.store_type,
            encrypted_value=base64.b64encode(ciphertext).decode("utf-8"),
            metadata={"envelope_encryption": True},
        )

    def update_secret(self, secret: MindCastleSecret, value: str) -> MindCastleSecret:
        """Encrypt ``value`` as a fresh secret that keeps ``secret.key``."""
        new_secret = self.create_secret(value)
        new_secret.key = secret.key
        return new_secret

    def delete_secret(self, secret: MindCastleSecret) -> None:
        """
        No-op. ``create_secret`` only encrypts and keeps no per-secret remote
        state in this store, so there is nothing to remove.
        """
        pass
