import hashlib
import requests
import base64

import dateutil.parser
from pyasn1.codec.der import encoder, decoder
from pyasn1_modules import rfc2459
from pyasn1.type import univ
from pyasn1.error import PyAsn1Error
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import hashes

import rfc3161ng

# Names exported by a star-import of this module.
__all__ = (
    'RemoteTimestamper', 'check_timestamp', 'get_hash_oid',
    'TimestampingError', 'get_timestamp'
)

# OID 1.2.840.113549.1.9.4: the attribute type that check_timestamp looks for
# among a signer's authenticated attributes to find the signed content digest.
id_attribute_messageDigest = univ.ObjectIdentifier((1, 2, 840, 113549, 1, 9, 4))


def get_hash_oid(hashname):
    """Return the object the ``rfc3161ng`` package exposes as ``id_<hashname>``.

    Raises ``KeyError`` when the package namespace has no attribute of that
    name.
    """
    attribute_name = 'id_' + hashname
    return vars(rfc3161ng)[attribute_name]


def get_hash_from_oid(oid):
    """Return the hash name registered for *oid* in ``rfc3161ng.oid_to_hash``.

    Raises ``ValueError`` when the mapping has no entry (or a ``None`` entry)
    for *oid*.
    """
    hash_name = rfc3161ng.oid_to_hash.get(oid)
    if hash_name is not None:
        return hash_name
    raise ValueError('unsupported hash algorithm', oid)


def get_hash_class_from_oid(oid):
    """Return the ``hashlib`` attribute named after the hash registered for *oid*.

    The name is resolved with ``get_hash_from_oid``, so an unknown OID raises
    ``ValueError``.
    """
    return getattr(hashlib, get_hash_from_oid(oid))


class TimestampingError(RuntimeError):
    """Error raised when the request to the timestamping server fails."""


def get_timestamp(tst):
    """Return the ``genTime`` of a timestamp token, parsed by ``dateutil``.

    *tst* may be an ``rfc3161ng.TimeStampToken`` instance or an encoded token,
    which is decoded with the DER decoder first.

    Raises ``ValueError`` when any decoding step leaves trailing data, or when
    pyasn1 rejects the input (the original ``PyAsn1Error`` is passed as the
    second exception argument).
    """
    try:
        token = tst
        if not isinstance(token, rfc3161ng.TimeStampToken):
            token, leftover = decoder.decode(token, asn1Spec=rfc3161ng.TimeStampToken())
            if leftover:
                raise ValueError("extra data after tst")

        # presumably the encapsulated TSTInfo content of the signed data --
        # TODO confirm against the rfc3161ng type definitions
        payload = token.getComponentByName('content').getComponentByPosition(2).getComponentByPosition(1)
        # Two decoding passes over the same value: first unwrap it as an
        # OCTET STRING, then decode the result as a TSTInfo.
        for spec_class in (univ.OctetString, rfc3161ng.TSTInfo):
            payload, leftover = decoder.decode(payload, asn1Spec=spec_class())
            if leftover:
                raise ValueError("extra data after tst")
        return dateutil.parser.parse(str(payload.getComponentByName('genTime')))
    except PyAsn1Error as exc:
        raise ValueError('not a valid TimeStampToken', exc)


def load_certificate(signed_data, certificate=b""):
    """Load the X.509 certificate used to verify a timestamp signature.

    Parameters:
        signed_data: the signed-data structure of the token; only consulted
            when no explicit certificate is given, in which case
            ``signed_data['certificates'][0][0]`` is used.
        certificate: PEM or DER encoded certificate bytes. Any empty value
            (``b""``, ``None``) means "not supplied". ``None`` is accepted
            because ``RemoteTimestamper`` defaults its certificate to
            ``None`` and forwards it unchanged.

    Returns:
        The object returned by the ``cryptography`` x509 loader.

    Raises:
        AttributeError: no certificate was supplied and none can be read
            from *signed_data*.

    NOTE(review): nothing in this function validates the embedded
    certificate against a trust anchor -- confirm callers do not rely on
    that.
    """
    if not certificate:
        # No explicit certificate: fall back to the first one shipped in the
        # token itself.
        try:
            embedded = signed_data['certificates'][0][0]
        except (KeyError, IndexError, TypeError):
            raise AttributeError("missing certificate")
        return x509.load_der_x509_certificate(encoder.encode(embedded), default_backend())

    # The PEM armour line distinguishes PEM input from raw DER.
    if b'-----BEGIN CERTIFICATE-----' in certificate:
        return x509.load_pem_x509_certificate(certificate, default_backend())
    return x509.load_der_x509_certificate(certificate, default_backend())


def check_timestamp(tst, certificate, data=None, digest=None, hashname=None, nonce=None):
    """Verify a timestamp token against a digest and a signing certificate.

    Parameters:
        tst: an ``rfc3161ng.TimeStampToken`` or an encoded token (decoded
            with the DER decoder).
        certificate: passed to ``load_certificate`` together with the
            token's signed data.
        data: the timestamped data; hashed with *hashname* when *digest*
            is not given.
        digest: precomputed digest of the data; takes precedence over
            *data*.
        hashname: ``hashlib`` algorithm name, ``'sha1'`` when omitted.
        nonce: when given, must equal the nonce stored in the token.

    Returns:
        ``True`` when every check passes.

    Raises:
        ValueError: neither *data* nor *digest* was given, trailing bytes
            follow the token, the nonce or message imprint does not match,
            the token carries no signature, the signed content type is
            wrong, the signed digest is missing or different, or the signer
            uses an unsupported digest algorithm.
        Signature failures propagate from ``public_key.verify``
        (``cryptography.exceptions.InvalidSignature`` according to the
        cryptography documentation).
    """
    hashname = hashname or 'sha1'
    hashobj = hashlib.new(hashname)
    if digest is None:
        if not data:
            raise ValueError("check_timestamp requires data or digest argument")
        hashobj.update(data)
        digest = hashobj.digest()

    if not isinstance(tst, rfc3161ng.TimeStampToken):
        tst, substrate = decoder.decode(tst, asn1Spec=rfc3161ng.TimeStampToken())
        if substrate:
            raise ValueError("extra data after tst")
    signed_data = tst.content
    certificate = load_certificate(signed_data, certificate)
    if nonce is not None and int(tst.tst_info['nonce']) != int(nonce):
        raise ValueError('nonce is different or missing')
    # check message imprint with respect to locally computed digest
    message_imprint = tst.tst_info.message_imprint
    if message_imprint.hash_algorithm[0] != get_hash_oid(hashname) or bytes(message_imprint.hashed_message) != digest:
        raise ValueError('Message imprint mismatch')
    if not len(signed_data['signerInfos']):
        raise ValueError('No signature')
    # We validate only one signature
    signer_info = signed_data['signerInfos'][0]
    # check content type
    content_info = signed_data['contentInfo']
    if content_info['contentType'] != rfc3161ng.id_ct_TSTInfo:
        raise ValueError("Signed content type is wrong: %s != %s" % (
            content_info['contentType'], rfc3161ng.id_ct_TSTInfo
        ))

    content = bytes(decoder.decode(bytes(content_info['content']), asn1Spec=univ.OctetString())[0])

    # The signer's digest algorithm is needed for the signature check whether
    # or not authenticated attributes are present, so resolve it before
    # branching (it used to be bound only inside the attributes branch, which
    # made attribute-less tokens fail with NameError).
    signer_digest_algorithm = signer_info['digestAlgorithm']['algorithm']
    signer_hash_name = get_hash_from_oid(signer_digest_algorithm)

    # If there are authenticated attributes, they must contain the message
    # digest and they are what was signed; otherwise the content itself is
    # what was signed.
    authenticated_attributes = signer_info['authenticatedAttributes']
    if len(authenticated_attributes):
        content_digest = get_hash_class_from_oid(signer_digest_algorithm)(content).digest()
        for authenticated_attribute in authenticated_attributes:
            if authenticated_attribute[0] == id_attribute_messageDigest:
                signed_digest = bytes(decoder.decode(
                    bytes(authenticated_attribute[1][0]), asn1Spec=univ.OctetString()
                )[0])
                if signed_digest != content_digest:
                    raise ValueError('Content digest != signed digest')
                break
        else:
            raise ValueError('No signed digest')
        # The signature covers the attributes re-encoded as a plain SET OF.
        attribute_set = univ.SetOf()
        for position, attribute in enumerate(authenticated_attributes):
            attribute_set.setComponentByPosition(position, attribute)
        signed_bytes = encoder.encode(attribute_set)
    else:
        signed_bytes = content

    # check signature
    # NOTE(review): PKCS#1 v1.5 padding is hard-coded, so this presumably only
    # works for RSA keys -- confirm before relying on other key types.
    hash_family = getattr(hashes, signer_hash_name.upper())
    certificate.public_key().verify(
        bytes(signer_info['encryptedDigest']),
        signed_bytes,
        padding.PKCS1v15(),
        hash_family(),
    )
    return True


class RemoteTimestamper(object):
    """Client that requests timestamp tokens from a server over HTTP POST.

    Calling the instance (or ``timestamp``) sends a timestamp request for the
    given data or digest, verifies the returned token with ``check_timestamp``
    and returns the token DER-encoded.
    """

    def __init__(self, url, certificate=None, capath=None, cafile=None, username=None, password=None, hashname=None, include_tsa_certificate=False, timeout=10):
        """Store the connection and verification settings.

        Parameters:
            url: address the request is POSTed to.
            certificate: forwarded to ``check_timestamp`` when verifying.
            capath, cafile: stored on the instance.
            username, password: when *username* is not ``None`` an HTTP
                Basic ``Authorization`` header is sent; a ``None`` password
                is sent as an empty password.
            hashname: ``hashlib`` algorithm name, ``'sha1'`` when omitted.
            include_tsa_certificate: default for the request flag of the
                same name; can be overridden per call.
            timeout: timeout passed to ``requests.post``.
        """
        self.url = url
        self.certificate = certificate
        # NOTE(review): capath and cafile are stored but not read by any
        # method of this class -- confirm whether they should configure TLS
        # verification of the request.
        self.capath = capath
        self.cafile = cafile
        self.username = username
        self.password = password
        self.hashname = hashname or 'sha1'
        self.include_tsa_certificate = include_tsa_certificate
        self.timeout = timeout

    def check_response(self, response, digest, nonce=None):
        '''
           Check validity of a TimeStampResponse
        '''
        tst = response.time_stamp_token
        return self.check(tst, digest=digest, nonce=nonce)

    def check(self, tst, data=None, digest=None, nonce=None):
        """Verify *tst* with ``check_timestamp`` using this instance's
        certificate and hash name."""
        return check_timestamp(
            tst,
            digest=digest,
            data=data,
            nonce=nonce,
            certificate=self.certificate,
            hashname=self.hashname,
        )

    def timestamp(self, data=None, digest=None, include_tsa_certificate=None, nonce=None):
        """Named alias for calling the instance; same arguments and result."""
        return self(
            data=data,
            digest=digest,
            include_tsa_certificate=include_tsa_certificate,
            nonce=nonce,
        )

    def _request_headers(self):
        """Build the HTTP headers for a timestamp query, including Basic
        authentication when a username is configured."""
        headers = {'Content-Type': 'application/timestamp-query'}
        if self.username is not None:
            username = self.username if isinstance(self.username, bytes) else self.username.encode()
            # A username without a password used to crash on None.encode();
            # send an empty password instead.
            password = b'' if self.password is None else self.password
            if not isinstance(password, bytes):
                password = password.encode()
            credentials = base64.standard_b64encode(b'%s:%s' % (username, password))
            headers['Authorization'] = "Basic %s" % credentials.decode()
        return headers

    def __call__(self, data=None, digest=None, include_tsa_certificate=None, nonce=None):
        """Request a timestamp for *data* (or a precomputed *digest*).

        Parameters:
            data: bytes, or text that is encoded with ``str.encode()``;
                hashed with the configured algorithm. Takes precedence over
                *digest* when non-empty.
            digest: precomputed digest; its length must match the configured
                algorithm's digest size.
            include_tsa_certificate: overrides the instance default when not
                ``None``.
            nonce: optional integer nonce sent with the request and checked
                in the response.

        Returns:
            The DER-encoded timestamp token from the response.

        Raises:
            ValueError: neither *data* nor *digest* was given, the digest has
                the wrong length, or the response has trailing bytes.
            TimestampingError: the HTTP request failed or returned an error
                status.
            Verification errors propagate from ``check_timestamp``.
        """
        algorithm_identifier = rfc2459.AlgorithmIdentifier()
        algorithm_identifier.setComponentByPosition(0, get_hash_oid(self.hashname))
        message_imprint = rfc3161ng.MessageImprint()
        message_imprint.setComponentByPosition(0, algorithm_identifier)
        hashobj = hashlib.new(self.hashname)
        if data:
            if not isinstance(data, bytes):
                data = data.encode()
            hashobj.update(data)
            digest = hashobj.digest()
        elif digest:
            # Explicit check instead of ``assert`` so it still runs under -O.
            if len(digest) != hashobj.digest_size:
                raise ValueError('digest length is wrong')
        else:
            raise ValueError('You must pass some data to digest, or the digest')
        message_imprint.setComponentByPosition(1, digest)
        request = rfc3161ng.TimeStampReq()
        request.setComponentByPosition(0, 'v1')
        request.setComponentByPosition(1, message_imprint)
        if nonce is not None:
            request.setComponentByPosition(3, int(nonce))
        request.setComponentByPosition(4, include_tsa_certificate if include_tsa_certificate is not None else self.include_tsa_certificate)
        binary_request = encoder.encode(request)
        headers = self._request_headers()
        try:
            response = requests.post(
                self.url,
                data=binary_request,
                timeout=self.timeout,
                headers=headers,
            )
            response.raise_for_status()
        except requests.RequestException as exc:
            raise TimestampingError('Unable to send the request to %r' % self.url, exc)
        tst_response, substrate = decoder.decode(response.content, asn1Spec=rfc3161ng.TimeStampResp())
        if substrate:
            raise ValueError('Extra data returned')
        self.check_response(tst_response, digest, nonce=nonce)
        return encoder.encode(tst_response.time_stamp_token)
