import logging
import hashlib
import json
import datetime
import base64
import os
import copy

from botocore.vendored import six
from botocore.exceptions import (ConnectionError,
                                 EndpointConnectionError)

import kmsauth.services
# Try to import the more efficient lru-dict package, and fall back to the
# slower pure-python LRU implementation if it's not available.
try:
    from lru import LRU
except ImportError:
    from kmsauth.utils.lru import LRUCache as LRU

# Clock-skew allowance, in minutes. Token generation backdates not_before by
# this amount, and a cached token is treated as expired this many minutes
# before its recorded not_after.
TOKEN_SKEW = 3
# strftime/strptime format of the not_before/not_after timestamps carried in
# the token payload (built from utcnow(); the trailing "Z" is a literal).
TIME_FORMAT = "%Y%m%dT%H%M%SZ"


def ensure_text(str_or_bytes, encoding='utf-8'):
    """Return the input as text.

    Text input is passed through untouched; any other value is decoded
    using ``encoding`` (so it must expose a ``decode`` method, e.g. bytes).
    """
    if isinstance(str_or_bytes, six.text_type):
        return str_or_bytes
    return str_or_bytes.decode(encoding)


def ensure_bytes(str_or_bytes, encoding='utf-8', errors='strict'):
    """Return the input as bytes.

    Text input is encoded with ``encoding``/``errors``; any other value is
    returned exactly as it was given.
    """
    is_text = isinstance(str_or_bytes, six.text_type)
    return str_or_bytes.encode(encoding, errors) if is_text else str_or_bytes


class KMSTokenValidator(object):

    """A class that represents a token validator for KMS auth.

    A token is a base64-encoded KMS ciphertext whose plaintext is a JSON
    payload carrying ``not_before``/``not_after`` timestamps. Validation
    decrypts the token with an encryption context derived from the username,
    verifies that an allowed KMS key produced it, enforces the validity
    window and maximum lifetime, and keeps successful results in an
    in-memory LRU cache.
    """

    def __init__(
            self,
            auth_key,
            user_auth_key,
            to_auth_context,
            region,
            scoped_auth_keys=None,
            minimum_token_version=1,
            maximum_token_version=2,
            auth_token_max_lifetime=60,
            aws_creds=None,
            extra_context=None,
            endpoint_url=None,
            token_cache_size=4096,
            stats=None,
            max_pool_connections=None,
            connect_timeout=None,
            read_timeout=None,
            ):
        """Create a KMSTokenValidator object.

        Args:
            auth_key: A list of KMS key ARNs or aliases to use for service
                authentication. Required.
            user_auth_key: A list of KMS key ARNs or aliases to use for user
                authentication. Required.
            to_auth_context: The KMS encryption context to use for the to
                context for authentication. Required.
            region: AWS region to connect to. Required.
            scoped_auth_keys: A dict of KMS key to account mappings. These keys
                are for the 'service' role to support multiple AWS accounts. If
                services are scoped to accounts, kmsauth will ensure the
                service authentication KMS auth used the mapped key.
                Example: {"sandbox-auth-key":"sandbox","primary-auth-key":"primary"}
            minimum_token_version: The minimum version of the authentication
                token accepted (1 or 2).
            maximum_token_version: The maximum version of the authentication
                token accepted (1 or 2).
            auth_token_max_lifetime: The maximum lifetime of an authentication
                token in minutes.
            aws_creds: A dict of AccessKeyId, SecretAccessKey, SessionToken.
                Useful if you wish to pass in assumed role credentials or MFA
                credentials. Default: None
            extra_context: Extra KMS encryption context added to every
                decrypt call. 'from', 'to' (and 'user_type' for version 2
                tokens) are always overridden. Default: None
            endpoint_url: A URL to override the default endpoint used to access
                the KMS service. Default: None
            token_cache_size: Size of the in-memory LRU cache for auth tokens.
            stats: A statsd client instance, to be used to track stats.
                Default: None
            max_pool_connections, connect_timeout, read_timeout: Passed
                through unchanged to kmsauth.services.get_boto_client.
                Default: None

        Raises:
            ConfigurationError: if the token version bounds are invalid or an
                auth key argument has an unsupported type.
        """
        self.auth_key = auth_key
        self.user_auth_key = user_auth_key
        self.to_auth_context = to_auth_context
        self.region = region
        if scoped_auth_keys is None:
            self.scoped_auth_keys = {}
        else:
            self.scoped_auth_keys = scoped_auth_keys
        self.minimum_token_version = minimum_token_version
        self.maximum_token_version = maximum_token_version
        self.auth_token_max_lifetime = auth_token_max_lifetime
        self.aws_creds = aws_creds
        client_kwargs = {
            'region': self.region,
            'endpoint_url': endpoint_url,
            'max_pool_connections': max_pool_connections,
            'connect_timeout': connect_timeout,
            'read_timeout': read_timeout,
        }
        if aws_creds:
            # Only pass explicit credentials when they were supplied.
            client_kwargs.update(
                aws_access_key_id=self.aws_creds['AccessKeyId'],
                aws_secret_access_key=self.aws_creds['SecretAccessKey'],
                aws_session_token=self.aws_creds['SessionToken'],
            )
        self.kms_client = kmsauth.services.get_boto_client(
            'kms',
            **client_kwargs
        )
        if extra_context is None:
            self.extra_context = {}
        else:
            self.extra_context = extra_context
        self.TOKENS = LRU(token_cache_size)
        self.token_cache_size = token_cache_size
        # Maps a configured key name (alias or ARN) to its describe_key-style
        # metadata; only ['KeyMetadata']['Arn'] is read.
        self.KEY_METADATA = {}
        self.stats = stats
        self._validate()

    def _validate(self):
        """Check constructor arguments and normalize the auth key lists.

        Raises:
            ConfigurationError: if a token version bound is outside 1-2, the
                minimum exceeds the maximum, or an auth key argument is not a
                string, list, or None.
        """
        for key in ['from', 'to', 'user_type']:
            if key in self.extra_context:
                logging.warning(
                    '{0} in extra_context will be ignored.'.format(key)
                )
        if self.minimum_token_version < 1 or self.minimum_token_version > 2:
            raise ConfigurationError(
                'Invalid minimum_token_version provided.'
            )
        if self.maximum_token_version < 1 or self.maximum_token_version > 2:
            raise ConfigurationError(
                'Invalid maximum_token_version provided.'
            )
        if self.minimum_token_version > self.maximum_token_version:
            raise ConfigurationError(
                'minimum_token_version can not be greater than'
                ' maximum_token_version'
            )
        self.auth_key = self._format_auth_key(self.auth_key)
        self.user_auth_key = self._format_auth_key(self.user_auth_key)

    def _format_auth_key(self, keys):
        """Normalize an auth key argument to a list (or None).

        A single string is wrapped in a list (deprecated usage).

        Raises:
            ConfigurationError: for any type other than str, list, or None.
        """
        if isinstance(keys, six.string_types):
            logging.debug(
                'Passing auth key as string is deprecated, and will be removed'
                ' in 1.0.0'
            )
            return [keys]
        elif (keys is None or isinstance(keys, list)):
            return keys
        raise ConfigurationError(
            'auth_key and user_auth_key must be a string, list, or None'
        )

    def _get_key_arn(self, key):
        """Resolve a configured key name to its ARN, memoized in KEY_METADATA.

        Names that already start with 'arn:aws:kms:' are recorded as-is
        without a KMS call; anything else is looked up once via describe_key.
        """
        if key.startswith('arn:aws:kms:'):
            self.KEY_METADATA[key] = {
                'KeyMetadata': {'Arn': key}
            }
        if key not in self.KEY_METADATA:
            if self.stats:
                with self.stats.timer('kms_describe_key'):
                    self.KEY_METADATA[key] = self.kms_client.describe_key(
                        KeyId='{0}'.format(key)
                    )
            else:
                self.KEY_METADATA[key] = self.kms_client.describe_key(
                    KeyId='{0}'.format(key)
                )
        return self.KEY_METADATA[key]['KeyMetadata']['Arn']

    def _get_key_alias_from_cache(self, key_arn):
        '''
        Find a key's alias by looking up its key_arn in the KEY_METADATA
        cache. This function will only work after a key has been looked up by
        its alias and is meant as a convenience function for turning an ARN
        that's already been looked up back into its alias. Returns None when
        no cached entry maps to key_arn.
        '''
        for alias in self.KEY_METADATA:
            if self.KEY_METADATA[alias]['KeyMetadata']['Arn'] == key_arn:
                return alias
        return None

    def _valid_service_auth_key(self, key_arn):
        """Return True if key_arn belongs to a service (or scoped) auth key."""
        if self.auth_key is None:
            return False
        for key in self.auth_key:
            if key_arn == self._get_key_arn(key):
                return True
        for key in self.scoped_auth_keys:
            if key_arn == self._get_key_arn(key):
                return True
        return False

    def _valid_user_auth_key(self, key_arn):
        """Return True if key_arn belongs to one of the user auth keys."""
        if self.user_auth_key is None:
            return False
        for key in self.user_auth_key:
            if key_arn == self._get_key_arn(key):
                return True
        return False

    def _parse_username(self, username):
        """Split a username into (version, user_type, from).

        Accepts 'version/user_type/from' (v2) or a bare 'from' (v1, which is
        always a service).

        Raises:
            TokenValidationError: for any other shape, or a non-integer
                version component.
        """
        username_arr = username.split('/')
        if len(username_arr) == 3:
            # V2 token format: version/service/myservice or version/user/myuser
            try:
                version = int(username_arr[0])
            except ValueError:
                # Previously leaked a ValueError past callers that only
                # catch TokenValidationError.
                raise TokenValidationError('Unsupported username format.')
            user_type = username_arr[1]
            _from = username_arr[2]
        elif len(username_arr) == 1:
            # Old format, specific to services: myservice
            version = 1
            _from = username_arr[0]
            user_type = 'service'
        else:
            raise TokenValidationError('Unsupported username format.')
        return version, user_type, _from

    def extract_username_field(self, username, field):
        """Return one of 'from', 'user_type', or 'version' from a username.

        Returns None for any other field name.
        """
        version, user_type, _from = self._parse_username(username)
        if field == 'from':
            return _from
        elif field == 'user_type':
            return user_type
        elif field == 'version':
            return version
        return None

    @staticmethod
    def _elapsed_ms(start):
        """Milliseconds elapsed between ``start`` and utcnow()."""
        return (datetime.datetime.utcnow() - start).total_seconds() * 1000

    @staticmethod
    def _lifetime_minutes(not_before, not_after):
        """Length of the not_before..not_after window, in minutes.

        Uses total_seconds(): timedelta.seconds drops whole days, so a
        window of one day plus five minutes used to measure as five minutes
        and slipped past the max-lifetime check.
        """
        return (not_after - not_before).total_seconds() / 60

    def _checkpoint(self, name, time_start):
        """Emit timing metric ``name`` (ms since time_start) if stats is set."""
        if self.stats:
            self.stats.timing(name, self._elapsed_ms(time_start))

    def _get_cache_key(self, token, version, _from, user_type):
        """Build the TOKENS cache key for a token/username combination.

        Fields are '/'-separated. _parse_username splits on '/', so the
        username-derived fields cannot contain it, which makes the key
        unambiguous. Plain concatenation let e.g. from='ax',
        user_type='service' collide with from='a', user_type='xservice'
        when to_auth_context was 'x', so a cached token could be reused
        under a different identity.

        Raises:
            TokenValidationError: if the token cannot be hashed.
        """
        try:
            return '{0}/{1}/{2}/{3}/{4}'.format(
                hashlib.sha256(ensure_bytes(token)).hexdigest(),
                version,
                _from,
                self.to_auth_context,
                user_type
            )
        except Exception:
            raise TokenValidationError('Authentication error.')

    def _decrypt_and_verify(self, token, version, user_type, _from,
                            time_start):
        """Decrypt ``token`` with KMS and check an allowed key produced it.

        Returns:
            {'payload': <decoded JSON payload>, 'key_alias': <str or None>}

        Raises:
            TokenValidationError: for every failure; connection errors and
                unexpected exceptions are logged and wrapped.
        """
        try:
            ciphertext = base64.b64decode(token)
            self._checkpoint('checkpoint_3_7_after_base64_decode', time_start)

            # Ensure normal context fields override whatever is in
            # extra_context. For version 1 tokens 'user_type' is not set
            # here, so a 'user_type' in extra_context passes through.
            context = copy.deepcopy(self.extra_context)
            context['to'] = self.to_auth_context
            context['from'] = _from
            if version > 1:
                context['user_type'] = user_type
            self._checkpoint('checkpoint_3_9_after_context_setup', time_start)

            if self.stats:
                with self.stats.timer('kms_decrypt_token'):
                    data = self.kms_client.decrypt(
                        CiphertextBlob=ciphertext,
                        EncryptionContext=context
                    )
            else:
                data = self.kms_client.decrypt(
                    CiphertextBlob=ciphertext,
                    EncryptionContext=context
                )
            self._checkpoint('checkpoint_4_after_kms_decrypt', time_start)

            # Decrypt doesn't take KeyId as an argument. We need to verify
            # the correct key was used to do the decryption.
            # Annoyingly, the KeyId from the data is actually an arn.
            key_arn = data['KeyId']
            if user_type == 'service':
                key_ok = self._valid_service_auth_key(key_arn)
            elif user_type == 'user':
                key_ok = self._valid_user_auth_key(key_arn)
            else:
                raise TokenValidationError(
                    'Authentication error. Unsupported user_type.'
                )
            if not key_ok:
                raise TokenValidationError(
                    'Authentication error (wrong KMS key).'
                )
            self._checkpoint('checkpoint_5_after_key_validation', time_start)

            payload = json.loads(data['Plaintext'])
            key_alias = self._get_key_alias_from_cache(key_arn)
            self._checkpoint('checkpoint_6_after_json_processing', time_start)
            return {'payload': payload, 'key_alias': key_alias}
        except TokenValidationError:
            raise
        except (ConnectionError, EndpointConnectionError):
            logging.exception('Failure connecting to AWS endpoint.')
            raise TokenValidationError(
                'Authentication error. Failure connecting to AWS endpoint.'
            )
        # We don't care what exception is thrown. For paranoia's sake, fail
        # here.
        except Exception:
            logging.exception('Failed to validate token.')
            raise TokenValidationError(
                'Authentication error. General error.'
            )

    def _check_validity_window(self, ret, now):
        """Enforce the token's validity window and maximum lifetime.

        Raises:
            TokenValidationError: if not_before/not_after are missing or
                unparseable, the window exceeds auth_token_max_lifetime
                minutes, or ``now`` falls outside it.
        """
        try:
            not_before = datetime.datetime.strptime(
                ret['payload']['not_before'],
                TIME_FORMAT
            )
            not_after = datetime.datetime.strptime(
                ret['payload']['not_after'],
                TIME_FORMAT
            )
        except Exception:
            logging.exception(
                'Failed to get not_before and not_after from token payload.'
            )
            raise TokenValidationError(
                'Authentication error. Missing validity.'
            )
        lifetime = self._lifetime_minutes(not_before, not_after)
        if lifetime > self.auth_token_max_lifetime:
            logging.warning('Token used which exceeds max token lifetime.')
            raise TokenValidationError(
                'Authentication error. Token lifetime exceeded.'
            )
        if (now < not_before) or (now > not_after):
            logging.warning('Invalid time validity for token.')
            raise TokenValidationError(
                'Authentication error. Invalid time validity for token.'
            )

    def decrypt_token(self, username, token):
        '''
        Decrypt and validate an authentication token.

        Args:
            username: 'version/user_type/from' (version 2) or a bare 'from'
                (version 1, service only).
            token: The base64-encoded KMS ciphertext (str or bytes).

        Returns:
            A dict with 'payload' (the decrypted JSON payload) and
            'key_alias' (the configured key name whose ARN matched the
            decrypting key, or None).

        Raises:
            TokenValidationError: on any parse, decrypt, key, or validity
                failure. Time validity is re-checked even on cache hits.
        '''
        time_start = datetime.datetime.utcnow()
        version, user_type, _from = self._parse_username(username)
        if (version > self.maximum_token_version or
                version < self.minimum_token_version):
            raise TokenValidationError('Unacceptable token version.')
        if self.stats:
            self.stats.incr('token_version_{0}'.format(version))
            # Metric names are intentionally not built from _from/user_type:
            # both come from the unauthenticated request and would let any
            # caller create unbounded metric series.
            self.stats.incr('cache_key_to_{0}'.format(self.to_auth_context))
        self._checkpoint('checkpoint_1_after_parse', time_start)

        token_key = self._get_cache_key(token, version, _from, user_type)
        self._checkpoint('checkpoint_2_after_cache_key', time_start)

        cache_miss = token_key not in self.TOKENS
        self._checkpoint('checkpoint_3_after_cache_lookup', time_start)

        if cache_miss:
            if self.stats:
                self.stats.incr('token_cache_miss')
                self.stats.gauge('token_cache_size_at_miss', len(self.TOKENS))
                if len(self.TOKENS) >= self.token_cache_size:
                    self.stats.incr('token_cache_eviction')
            self._checkpoint(
                'checkpoint_3_5_after_cache_miss_stats', time_start
            )
            ret = self._decrypt_and_verify(
                token, version, user_type, _from, time_start
            )
        else:
            if self.stats:
                self.stats.incr('token_cache_hit')
            ret = self.TOKENS[token_key]
            self._checkpoint('checkpoint_7_after_cache_hit', time_start)

        now = datetime.datetime.utcnow()
        if self.stats:
            elapsed = (now - time_start).total_seconds() * 1000
            self.stats.timing('pre_time_validation_duration', elapsed)
            self.stats.timing('decrypt_token_validation_duration', elapsed)
        self._check_validity_window(ret, now)
        self._checkpoint('checkpoint_8_after_time_validation', time_start)

        cache_set_start = datetime.datetime.utcnow()
        self.TOKENS[token_key] = ret
        if self.stats:
            self.stats.timing(
                'cache_set_duration', self._elapsed_ms(cache_set_start)
            )
            self.stats.timing(
                'decrypt_token_duration_post_validation',
                self._elapsed_ms(now)
            )
            self.stats.incr('token_cache_set')
            self.stats.gauge('token_cache_size_at_set', len(self.TOKENS))
        return ret


class KMSTokenGenerator(object):

    """A class that represents a token generator for KMS auth."""

    def __init__(
            self,
            auth_key,
            auth_context,
            region,
            token_version=2,
            token_cache_file=None,
            token_lifetime=10,
            aws_creds=None,
            endpoint_url=None
            ):
        """Create a KMSTokenGenerator object.

        Args:
            auth_key: The KMS key ARN or alias to use for authentication.
                Required.
            auth_context: The KMS encryption context to use for authentication.
                Must contain 'from' and 'to', plus 'user_type' for token
                version 2. Required.
            region: AWS region to connect to. Required.
            token_version: The version of the authentication token (1 or 2).
                Default: 2
            token_cache_file: The location to use for caching the auth token.
                If set to empty string or None, no cache will be used.
                Default: None
            token_lifetime: Lifetime of the authentication token generated,
                in minutes. Default: 10
            aws_creds: A dict of AccessKeyId, SecretAccessKey, SessionToken.
                Useful if you wish to pass in assumed role credentials or MFA
                credentials. Default: None
            endpoint_url: A URL to override the default endpoint used to access
                the KMS service. Default: None

        Raises:
            ConfigurationError: if auth_context lacks required keys or
                token_version is unsupported.
        """
        self.auth_key = auth_key
        if auth_context is None:
            self.auth_context = {}
        else:
            self.auth_context = auth_context
        self.token_cache_file = token_cache_file
        self.token_lifetime = token_lifetime
        self.region = region
        self.token_version = token_version
        self.aws_creds = aws_creds
        client_kwargs = {'region': self.region, 'endpoint_url': endpoint_url}
        if aws_creds:
            # Only pass explicit credentials when they were supplied.
            client_kwargs.update(
                aws_access_key_id=self.aws_creds['AccessKeyId'],
                aws_secret_access_key=self.aws_creds['SecretAccessKey'],
                aws_session_token=self.aws_creds['SessionToken'],
            )
        self.kms_client = kmsauth.services.get_boto_client(
            'kms',
            **client_kwargs
        )
        self._validate()

    def _validate(self):
        """Check that auth_context and token_version are usable.

        Raises:
            ConfigurationError: if 'from'/'to' (or 'user_type' for version 2)
                are missing, or token_version is not 1 or 2.
        """
        for key in ['from', 'to']:
            if key not in self.auth_context:
                raise ConfigurationError(
                    '{0} missing from auth_context.'.format(key)
                )
        if self.token_version > 1:
            if 'user_type' not in self.auth_context:
                raise ConfigurationError(
                    'user_type missing from auth_context.'
                )
        # Versions below 1 used to be accepted, leaving get_username() to
        # return None.
        if self.token_version < 1 or self.token_version > 2:
            raise ConfigurationError(
                'Invalid token_version provided.'
            )

    def _get_cached_token(self):
        """Return the cached token if it is still usable, else None.

        A cached token is reused only if the cache file parses, was written
        for the same auth_context, and utcnow() is at least TOKEN_SKEW
        minutes before its not_after. Read errors are logged, not raised.
        """
        token = None
        if not self.token_cache_file:
            return token
        try:
            with open(self.token_cache_file, 'r') as f:
                token_data = json.load(f)
            _not_after = token_data['not_after']
            _auth_context = token_data['auth_context']
            _token = token_data['token']
            _not_after_cache = datetime.datetime.strptime(
                _not_after,
                TIME_FORMAT
            )
        except IOError as e:
            logging.debug(
                'Failed to read confidant auth token cache: {0}'.format(e)
            )
            return token
        except Exception:
            logging.exception('Failed to read confidant auth token cache.')
            return token
        skew_delta = datetime.timedelta(minutes=TOKEN_SKEW)
        _not_after_cache = _not_after_cache - skew_delta
        now = datetime.datetime.utcnow()
        if (now <= _not_after_cache and
                _auth_context == self.auth_context):
            logging.debug('Using confidant auth token cache.')
            token = _token
        return token

    def _cache_token(self, token, not_after):
        """Best-effort write of ``token`` to token_cache_file.

        Does nothing when caching is disabled; failures are logged, not
        raised.
        """
        if not self.token_cache_file:
            return
        try:
            # Serialize first so a failure here can't truncate an existing
            # cache file.
            data = json.dumps({
                'token': ensure_text(token),
                'not_after': not_after,
                'auth_context': self.auth_context
            })
            cachedir = os.path.dirname(self.token_cache_file)
            # dirname() is '' for a bare filename (cache in the working
            # directory); os.makedirs('') raised and the write was skipped.
            if cachedir and not os.path.exists(cachedir):
                os.makedirs(cachedir)
            # The file holds a usable auth token, so create it owner-only.
            # A file that already exists keeps its current mode.
            fd = os.open(
                self.token_cache_file,
                os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
                0o600
            )
            with os.fdopen(fd, 'w') as f:
                f.write(data)
        except Exception:
            logging.exception('Failed to write confidant auth token cache.')

    def get_username(self):
        """Get a username formatted for a specific token version.

        Version 1 returns the bare 'from'; version 2 returns
        'version/user_type/from'.
        """
        _from = self.auth_context['from']
        if self.token_version == 1:
            return '{0}'.format(_from)
        elif self.token_version == 2:
            _user_type = self.auth_context['user_type']
            return '{0}/{1}/{2}'.format(
                self.token_version,
                _user_type,
                _from
            )

    def get_token(self):
        """Get an authentication token.

        Returns a still-valid cached token when available; otherwise KMS
        encrypts a fresh payload, which is then cached.

        Note: a freshly generated token is base64 bytes, while a token read
        from the cache file is text (it round-trips through JSON).

        Raises:
            ServiceConnectionError: if AWS could not be reached.
            TokenGenerationError: for any other encryption failure.
        """
        token = self._get_cached_token()
        if token:
            return token
        # Generate string formatted timestamps for not_before and not_after,
        # for the lifetime specified in minutes.
        now = datetime.datetime.utcnow()
        # Start the not_before time x minutes in the past, to avoid clock skew
        # issues.
        _not_before = now - datetime.timedelta(minutes=TOKEN_SKEW)
        not_before = _not_before.strftime(TIME_FORMAT)
        # Set the not_after time in the future, by the lifetime, but ensure the
        # skew we applied to not_before is taken into account.
        _not_after = now + datetime.timedelta(
            minutes=self.token_lifetime - TOKEN_SKEW
        )
        not_after = _not_after.strftime(TIME_FORMAT)
        # Generate a json string for the encryption payload contents.
        payload = json.dumps({
            'not_before': not_before,
            'not_after': not_after
        })
        # Generate a base64 encoded KMS encrypted token to use for
        # authentication. We encrypt the token lifetime information as the
        # payload for verification in Confidant.
        try:
            token = self.kms_client.encrypt(
                KeyId=self.auth_key,
                Plaintext=payload,
                EncryptionContext=self.auth_context
            )['CiphertextBlob']
            token = base64.b64encode(ensure_bytes(token))
        except (ConnectionError, EndpointConnectionError) as e:
            logging.exception('Failure connecting to AWS: {}'.format(str(e)))
            raise ServiceConnectionError()
        except Exception:
            logging.exception('Failed to create auth token.')
            raise TokenGenerationError()
        self._cache_token(token, not_after)
        return token


class ServiceConnectionError(Exception):
    """Raised when a connection to the AWS endpoint could not be made."""


class ConfigurationError(Exception):

    """An exception raised when kmsauth is configured with invalid options.

    Raised while validating KMSTokenValidator / KMSTokenGenerator
    constructor arguments (token versions, auth keys, auth_context).
    """

    pass


class TokenValidationError(Exception):
    """Raised when an authentication token fails validation."""


class TokenGenerationError(Exception):
    """Raised when an authentication token could not be generated."""
