import json
from datetime import datetime
from typing import Iterator, Optional

import requests
from pymisp import MISPAttribute, MISPEvent, MISPObject


class VulnerabilityMapping:
    """Source-field -> MISP object relation mappings shared by vulnerability parsers.

    Every mapping is exposed through a classmethod returning the dictionary.
    """

    # VARIoT fields whose value is wrapped in a {"data": ...} dictionary.
    __variot_data_mapping = {
        "credits": "credit",
        "description": "description",
        "title": "summary",
    }
    # VARIoT fields holding a plain value.
    __variot_flat_mapping = {"cve": "id", "id": "id"}
    # NOTE(review): these two mappings were read by the accessors below but
    # never defined, so calling the accessors always raised AttributeError
    # (and, because of name mangling, subclasses could not provide them).
    # They are defined empty so iterating over them is a safe no-op; populate
    # them if an exploit parser needs actual field mappings.
    __exploit_mapping = {}
    __exploit_multiple_mapping = {}

    @classmethod
    def exploit_mapping(cls) -> dict:
        """Return the exploit field mapping (currently empty)."""
        return cls.__exploit_mapping

    @classmethod
    def exploit_multiple_mapping(cls) -> dict:
        """Return the multi-valued exploit field mapping (currently empty)."""
        return cls.__exploit_multiple_mapping

    @classmethod
    def variot_data_mapping(cls) -> dict:
        """Return the mapping for VARIoT fields stored under a ``data`` key."""
        return cls.__variot_data_mapping

    @classmethod
    def variot_flat_mapping(cls) -> dict:
        """Return the mapping for VARIoT fields holding plain values."""
        return cls.__variot_flat_mapping


class VulnerabilityParser:
    """Base parser turning vulnerability records into MISP attributes/objects.

    The queried attribute is wrapped into a standalone ``MISPEvent``; every
    object created while parsing is added to that event, and
    :meth:`get_results` exports its attributes and objects.
    """

    def __init__(self, attribute: dict):
        """Build the working event around the queried attribute.

        :param attribute: MISP attribute as a dict, passed as keyword
            arguments to ``MISPAttribute.from_dict``.
        """
        misp_attribute = MISPAttribute()
        misp_attribute.from_dict(**attribute)
        misp_event = MISPEvent()
        # The attribute's fields are unpacked into the event; the original
        # MISPAttribute instance is kept for its value and uuid.
        misp_event.add_attribute(**misp_attribute)
        self.__misp_attribute = misp_attribute
        self.__misp_event = misp_event

    @property
    def misp_attribute(self):
        """The attribute being enriched."""
        return self.__misp_attribute

    @property
    def misp_event(self):
        """Event collecting every attribute/object produced by the parser."""
        return self.__misp_event

    @property
    def mapping(self) -> type:
        """Mapping class used by the parsing helpers.

        Subclasses may override this with a richer mapping class; the base
        implementation exposes ``VulnerabilityMapping`` so that
        :meth:`_parse_variot_description` also works on the base parser
        (it previously failed because ``mapping`` did not exist here).
        """
        return VulnerabilityMapping

    def get_results(self) -> dict:
        """Return the event's attributes and objects in module result format."""
        event = json.loads(self.misp_event.to_json())
        return {"results": {key: value for key, value in event.items() if key in ("Attribute", "Object")}}

    def _get_nvd_date(self, query_results: dict, field: str):
        """Return the date of the first NVD entry of ``field`` matching the attribute value.

        :param query_results: VARIoT record.
        :param field: ``sources_release_date`` or ``sources_update_date``.
        :return: the matching entry's ``date`` value, or None if there is none.
        """
        for entry in query_results.get(field, {}).get("data") or []:
            if entry["db"] == "NVD" and entry["id"] == self.misp_attribute.value:
                return entry["date"]
        return None

    def _parse_variot_description(self, query_results: dict) -> str:
        """Convert a VARIoT record into a ``vulnerability`` object.

        :param query_results: VARIoT record. Flat fields (``cve``, ``id``)
            hold plain values; the other fields read here are wrapped in a
            ``{"data": ...}`` dictionary.
        :return: uuid of the created object, so callers (e.g. alias
            resolution in subclasses) can reference it like any other source.
        """
        vulnerability_object = MISPObject("vulnerability")
        for field, relation in self.mapping.variot_flat_mapping().items():
            if query_results.get(field):
                vulnerability_object.add_attribute(relation, query_results[field])
        for field, relation in self.mapping.variot_data_mapping().items():
            if query_results.get(field, {}).get("data"):
                vulnerability_object.add_attribute(relation, query_results[field]["data"])
        if query_results.get("configurations", {}).get("data"):
            for configuration in query_results["configurations"]["data"]:
                for node in configuration["nodes"]:
                    for cpe_match in node["cpe_match"]:
                        if cpe_match["vulnerable"]:
                            vulnerability_object.add_attribute("vulnerable-configuration", cpe_match["cpe23Uri"])
        if query_results.get("cvss", {}).get("data"):
            # Index CVSS entries by trust level and keep the most trusted one;
            # on equal trust the entry seen last wins (v3 after v2).
            cvss = {}
            for cvss_data in query_results["cvss"]["data"]:
                for version in ("cvssV2", "cvssV3"):
                    for cvss_entry in cvss_data.get(version, []):
                        cvss[float(cvss_entry["trust"])] = cvss_entry
            if cvss:
                best = cvss[max(cvss)]
                vulnerability_object.add_attribute("cvss-score", best["baseScore"])
                vulnerability_object.add_attribute("cvss-string", best["vectorString"])
        if query_results.get("references", {}).get("data"):
            for reference in query_results["references"]["data"]:
                vulnerability_object.add_attribute("references", reference["url"])
        published = self._get_nvd_date(query_results, "sources_release_date")
        if published is not None:
            vulnerability_object.add_attribute("published", published)
        modified = self._get_nvd_date(query_results, "sources_update_date")
        if modified is not None:
            vulnerability_object.add_attribute("modified", modified)
        vulnerability_object.add_reference(self.misp_attribute.uuid, "related-to")
        self.misp_event.add_object(vulnerability_object)
        return vulnerability_object.uuid


class VulnerabilityLookupMapping(VulnerabilityMapping):
    """Field mappings for the sources served by Vulnerability-Lookup.

    Each ``*_mapping`` dictionary maps a key of a source record to the MISP
    object relation receiving its value. ``source_mapping`` maps a
    vulnerability id prefix (lowercased) to the name of the parser method
    handling that source.
    """

    # CERT-FR advisories.
    __certfr_mapping = {
        "reference": "id",
        "title": "summary",
        "summary": "description",
    }
    # CNVD (China National Vulnerability Database) entries.
    __cnvd_mapping = {
        "number": "id",
        "title": "summary",
        "description": "description",
        "referenceLink": "references",
        "submitTime": "published",
        "openTime": "modified",
    }
    # CSAF advisories: fields read from ``document.tracking``.
    __csaf_mapping = {
        "id": "id",
        "initial_release_date": "published",
        "current_release_date": "modified",
    }
    # CVE records: fields read from ``cveMetadata``.
    __cve_mapping = {
        "cveId": "id",
        "datePublished": "published",
        "dateUpdated": "modified",
        "state": "state",
    }
    # CWE entries, mapped onto a ``weakness`` object.
    __cwe_mapping = {"cweId": "id", "description": "description", "name": "name"}
    # FKIE CVE records (optional date fields).
    __fkie_mapping = {
        "lastModified": "modified",
        "published": "published",
    }
    # GCVE records: fields read from ``cveMetadata``.
    __gcve_mapping = {
        "vulnId": "id",
        "datePublished": "published",
        "dateUpdated": "modified",
        "state": "state",
    }
    # GSD records: fields read from ``gsd.osvSchema``.
    __gsd_mapping = {"id": "id", "details": "description", "modified": "modified"}
    # JVN iPedia entries.
    __jvn_mapping = {
        "sec:identifier": "id",
        "description": "description",
        "title": "summary",
        "link": "references",
        "dcterms:issued": "published",
        "dcterms:modified": "modified",
    }
    # NVD records embedded in other sources.
    __nvd_mapping = {"id": "id", "published": "published", "lastModified": "modified"}
    # OpenSSF malicious-package (MAL-) records.
    __ossf_mapping = {
        "id": "id",
        "summary": "summary",
        "details": "description",
        "published": "published",
        "modified": "modified",
    }
    # Vulnerabilities listed inside a CSAF advisory.
    __related_vuln_mapping = {
        "cve": "id",
        "title": "summary",
        "discovery_date": "published",
    }
    # Id prefix -> parser method name. All CSAF publishers share one parser.
    __source_mapping = {
        "certfr": "_parse_certfr_description",
        "cnvd": "_parse_cnvd_description",
        "cve": "_parse_cve_description",
        "fkie_cve": "_parse_fkie_description",
        "gcve": "_parse_gcve_description",
        "ghsa": "_parse_standard_description",
        "gsd": "_parse_gsd_description",
        "jvndb": "_parse_jvn_description",
        "mal": "_parse_ossf_description",
        "pysec": "_parse_standard_description",
        "ts": "_parse_tailscale_description",
        "var": "_parse_variot_description",
        **dict.fromkeys(
            (
                "cisco", "icsa", "icsma", "msrc_cve", "ncsc", "nn",
                "opensuse", "oxas", "rhba", "rhea", "rhsa", "sca",
                "suse", "ssa", "va", "wid",
            ),
            "_parse_csaf_description",
        ),
    }
    # OSV-formatted records (GHSA, PYSEC).
    __standard_mapping = {
        "id": "id",
        "details": "description",
        "published": "published",
        "modified": "modified",
    }
    # Tailscale security bulletins.
    __tailscale_mapping = {
        "title": "id",
        "link": "references",
        "summary": "summary",
        "published": "published",
    }

    @classmethod
    def certfr_mapping(cls) -> dict:
        """Mapping for CERT-FR advisories."""
        return cls.__certfr_mapping

    @classmethod
    def cnvd_mapping(cls) -> dict:
        """Mapping for CNVD entries."""
        return cls.__cnvd_mapping

    @classmethod
    def csaf_mapping(cls) -> dict:
        """Mapping for CSAF tracking metadata."""
        return cls.__csaf_mapping

    @classmethod
    def cve_mapping(cls) -> dict:
        """Mapping for CVE metadata."""
        return cls.__cve_mapping

    @classmethod
    def cwe_mapping(cls) -> dict:
        """Mapping for CWE entries (weakness objects)."""
        return cls.__cwe_mapping

    @classmethod
    def fkie_mapping(cls) -> dict:
        """Mapping for FKIE CVE records."""
        return cls.__fkie_mapping

    @classmethod
    def gcve_mapping(cls) -> dict:
        """Mapping for GCVE metadata."""
        return cls.__gcve_mapping

    @classmethod
    def gsd_mapping(cls) -> dict:
        """Mapping for GSD OSV-schema records."""
        return cls.__gsd_mapping

    @classmethod
    def jvn_mapping(cls) -> dict:
        """Mapping for JVN iPedia entries."""
        return cls.__jvn_mapping

    @classmethod
    def nvd_mapping(cls) -> dict:
        """Mapping for embedded NVD records."""
        return cls.__nvd_mapping

    @classmethod
    def ossf_mapping(cls) -> dict:
        """Mapping for OpenSSF malicious-package records."""
        return cls.__ossf_mapping

    @classmethod
    def related_vuln_mapping(cls) -> dict:
        """Mapping for vulnerabilities listed in a CSAF advisory."""
        return cls.__related_vuln_mapping

    @classmethod
    def source_mapping(cls, field: str) -> str:
        """Return the parser method name for an id prefix, or None if unknown."""
        return cls.__source_mapping.get(field)

    @classmethod
    def standard_mapping(cls) -> dict:
        """Mapping for OSV-formatted records."""
        return cls.__standard_mapping

    @classmethod
    def tailscale_mapping(cls) -> dict:
        """Mapping for Tailscale security bulletins."""
        return cls.__tailscale_mapping


class VulnerabilityLookupParser(VulnerabilityParser):
    """Parse Vulnerability-Lookup API results into MISP vulnerability objects.

    The prefix of a vulnerability id (text before the first ``-``, lowercased)
    selects the parsing method through
    ``VulnerabilityLookupMapping.source_mapping``. Related ids (aliases, CVEs
    cited by advisories) are fetched from the same API and parsed too; each
    id is parsed at most once per parser instance.
    """

    # Seconds to wait for the API when resolving a related id. Without a
    # timeout, ``requests`` can wait indefinitely on an unresponsive server.
    _query_timeout = 60

    def __init__(self, attribute: dict, api_url: str):
        """
        :param attribute: MISP attribute (as a dict) holding the queried id.
        :param api_url: base URL of the Vulnerability-Lookup instance, used to
            query related ids and to build ``references`` links.
        """
        super().__init__(attribute)
        self.__api_url = api_url
        self.__mapping = VulnerabilityLookupMapping
        self.__errors = []
        # Vulnerability id -> uuid of the object created for it (None while
        # the id is being parsed, or if parsing failed). Stops infinite
        # recursion on alias cycles (e.g. GHSA <-> PYSEC) and avoids creating
        # duplicate objects for ids referenced several times.
        self.__parsed_ids = {}

    @property
    def api_url(self) -> str:
        """Base URL of the Vulnerability-Lookup instance."""
        return self.__api_url

    @property
    def errors(self) -> list:
        """Error messages collected while parsing."""
        return self.__errors

    @property
    def mapping(self) -> VulnerabilityLookupMapping:
        """Mapping class providing field mappings and source dispatch."""
        return self.__mapping

    def parse_lookup_result(self, lookup_result: dict):
        """Parse the API result for the queried attribute value.

        An unsupported id prefix is recorded in :attr:`errors` instead of
        raising.
        """
        vuln_id = self.misp_attribute.value
        feature = self._get_parsing_method(vuln_id)
        if feature is None:
            return
        # Register the queried id first so that an alias pointing back to it
        # does not trigger a new query.
        self.__parsed_ids[vuln_id] = None
        self.__parsed_ids[vuln_id] = getattr(self, feature)(lookup_result)

    def _get_parsing_method(self, vuln_id: str) -> Optional[str]:
        """Return the parser method name for ``vuln_id``; record an error if unknown."""
        feature = self.mapping.source_mapping(vuln_id.split("-")[0].lower())
        if feature is None:
            self.errors.append(f"Unsupported vulnerability id format: {vuln_id}")
        return feature

    @staticmethod
    def _add_related_reference(misp_object, referenced_uuid: Optional[str],
                               relationship: str = "related-to"):
        """Reference ``referenced_uuid`` from ``misp_object`` unless it is None."""
        if referenced_uuid is not None:
            misp_object.add_reference(referenced_uuid, relationship)

    def _create_vulnerability_object(self, vuln_id: str) -> MISPObject:
        """Create a ``vulnerability`` object linking to the id's lookup page."""
        misp_object = MISPObject("vulnerability")
        misp_object.add_attribute(
            "references", f"{self.api_url}/vuln/{vuln_id}"
        )
        return misp_object

    def _parse_alias(self, alias: str) -> Optional[str]:
        """Fetch and parse a related vulnerability id.

        :return: uuid of the object created for ``alias``, or None when the
            id could not be fetched/parsed or is currently being parsed
            (reference cycle). Failures are recorded in :attr:`errors`.
        """
        if alias in self.__parsed_ids:
            return self.__parsed_ids[alias]
        self.__parsed_ids[alias] = None
        try:
            query = requests.get(
                f"{self.api_url}/api/vulnerability/{alias}",
                timeout=self._query_timeout
            )
        except requests.RequestException as error:
            self.errors.append(
                f"Unable to query related vulnerability id {alias}: {error}"
            )
            return None
        if query.status_code != 200:
            self.errors.append(
                f"Unable to query related vulnerability id {alias}"
            )
            return None
        try:
            vulnerability = query.json()
        except ValueError:
            self.errors.append(
                f"Invalid response for related vulnerability id {alias}"
            )
            return None
        if not vulnerability:
            self.errors.append(
                f"No results for related vulnerability id {alias}"
            )
            return None
        feature = self._get_parsing_method(alias)
        if feature is None:
            return None
        vuln_uuid = getattr(self, feature)(vulnerability)
        self.__parsed_ids[alias] = vuln_uuid
        return vuln_uuid

    def _parse_aliases(self, *aliases: str) -> Iterator[Optional[str]]:
        """Yield the result of :meth:`_parse_alias` for each alias."""
        for alias in aliases:
            yield self._parse_alias(alias)

    def _parse_certfr_description(self, lookup_result: dict) -> str:
        """Parse a CERT-FR advisory and the CVEs it cites.

        :return: uuid of the created ``vulnerability`` object.
        """
        misp_object = self._create_vulnerability_object(
            lookup_result['reference']
        )
        for field, relation in self.mapping.certfr_mapping().items():
            misp_object.add_attribute(relation, lookup_result[field])
        # Earliest revision = publication, latest one = last modification.
        timestamps = {
            datetime.strptime(revision['revision_date'], '%Y-%m-%dT%H:%M:%S.%f')
            for revision in lookup_result.get('revisions', [])
        }
        if timestamps:
            misp_object.add_attribute('published', min(timestamps))
            if len(timestamps) > 1:
                misp_object.add_attribute('modified', max(timestamps))
        for link in lookup_result.get('links', []):
            misp_object.add_attribute(
                'references', link['url'], comment=link['title']
            )
        vulnerability_object = self.misp_event.add_object(misp_object)
        for cve in lookup_result.get('cves', []):
            self._add_related_reference(
                vulnerability_object, self._parse_alias(cve['name'])
            )
        return vulnerability_object.uuid

    def _parse_cnvd_description(self, lookup_result: dict) -> str:
        """Parse a CNVD entry and its related CVE, if any.

        :return: uuid of the created ``vulnerability`` object.
        """
        misp_object = self._create_vulnerability_object(lookup_result['number'])
        for field, relation in self.mapping.cnvd_mapping().items():
            misp_object.add_attribute(relation, lookup_result[field])
        vulnerability_object = self.misp_event.add_object(misp_object)
        cve = lookup_result.get('cves', {}).get('cve', {}).get('cveNumber')
        if cve is not None:
            self._add_related_reference(
                vulnerability_object, self._parse_alias(cve)
            )
        return vulnerability_object.uuid

    def _parse_csaf_branch(self, branch: list) -> Iterator[str]:
        """Recursively yield the CPEs of the leaf products of a CSAF product tree."""
        for sub_branch in branch:
            if sub_branch.get("branches"):
                yield from self._parse_csaf_branch(sub_branch["branches"])
            else:
                cpe = sub_branch.get("product", {}).get("product_identification_helper", {}).get("cpe")
                if cpe is not None:
                    yield cpe

    def _parse_csaf_description(self, lookup_result: dict) -> str:
        """Parse a CSAF advisory and the vulnerabilities it lists.

        :return: uuid of the ``vulnerability`` object describing the advisory.
        """
        description = lookup_result["document"]

        tracking = description["tracking"]
        misp_object = self._create_vulnerability_object(tracking['id'])
        for field, relation in self.mapping.csaf_mapping().items():
            misp_object.add_attribute(relation, tracking[field])
        misp_object.add_attribute("summary", description["title"])
        for reference in description.get("references", []):
            misp_object.add_attribute("references", reference["url"])
        misp_object.add_attribute("credit", description["publisher"]["name"])
        branches = lookup_result.get("product_tree", {}).get("branches", [])
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        for cpe in dict.fromkeys(self._parse_csaf_branch(branches)):
            misp_object.add_attribute("vulnerable-configuration", cpe)
        misp_object.add_reference(self.misp_attribute.uuid, "describes")
        vulnerability_object = self.misp_event.add_object(misp_object)

        for vulnerability in lookup_result["vulnerabilities"]:
            related = self._create_vulnerability_object(vulnerability['cve'])
            for field, relation in self.mapping.related_vuln_mapping().items():
                if vulnerability.get(field):
                    related.add_attribute(relation, vulnerability[field])
            for score in vulnerability.get("scores", []):
                # Scores may carry only a CVSS v2 vector; skip those.
                cvss_v3 = score.get("cvss_v3")
                if cvss_v3 is None:
                    continue
                related.add_attribute("cvss-score", cvss_v3["baseScore"])
                related.add_attribute("cvss-string", cvss_v3["vectorString"])
            for reference in vulnerability.get("references", []):
                related.add_attribute("references", reference["url"])
            related.add_reference(vulnerability_object.uuid, "related-to")
            related_vulnerability = self.misp_event.add_object(related)
            if vulnerability.get("cwe"):
                cwe = vulnerability["cwe"]
                weakness = MISPObject("weakness")
                for field, value in cwe.items():
                    weakness.add_attribute(field, value)
                self.misp_event.add_object(weakness)
                related_vulnerability.add_reference(weakness.uuid, "weakened-by")

        return vulnerability_object.uuid

    def _parse_cve_description(self, lookup_result: dict) -> str:
        """Parse a CVE record (``cveMetadata`` + ``containers``).

        :return: uuid of the created ``vulnerability`` object.
        """
        cve_metadata = lookup_result["cveMetadata"]
        misp_object = self._create_vulnerability_object(cve_metadata["cveId"])
        for field, relation in self.mapping.cve_mapping().items():
            misp_object.add_attribute(relation, cve_metadata[field])
        containers = lookup_result["containers"]
        for reference in containers.get("cna", {}).get("references", []):
            misp_object.add_attribute("references", reference["url"])
        for adp in containers.get("adp", []):
            for affected in adp.get("affected", []):
                for cpe in affected.get("cpes", []):
                    misp_object.add_attribute("vulnerable-configuration", cpe)
        misp_object.add_reference(self.misp_attribute.uuid, "related-to")
        vulnerability_object = self.misp_event.add_object(misp_object)
        return vulnerability_object.uuid

    def _parse_cve_related_description(self, cve_description: dict) -> str:
        """Parse the ``cve.org`` namespace of a GSD record.

        :return: uuid of the created ``vulnerability`` object.
        """
        cve_id = cve_description["CVE_data_meta"]["ID"]
        misp_object = self._create_vulnerability_object(cve_id)
        misp_object.add_attribute("id", cve_id)
        misp_object.add_attribute(
            'description',
            self._parse_description_value(
                *cve_description['description']['description_data']
            )
        )
        for cvss in cve_description.get("impact", {}).get("cvss", []):
            misp_object.add_attribute("cvss-score", cvss["baseScore"])
            misp_object.add_attribute("cvss-string", cvss["vectorString"])
        # References come from the CVE record (previously they were looked up
        # on the freshly created MISP object, so none were ever found).
        for reference in cve_description.get("references", {}).get("reference_data", []):
            misp_object.add_attribute("references", reference["url"])
        return self.misp_event.add_object(misp_object).uuid

    @staticmethod
    def _parse_description_value(*descriptions: dict) -> Optional[str]:
        """Return the English description value, else the first one.

        :param descriptions: dicts with a ``value`` and an optional ``lang``.
        :return: the selected value, or None when no description is given.
        """
        for description in descriptions:
            if description.get('lang') in ('en', 'eng'):
                return description['value']
        if not descriptions:
            return None
        return descriptions[0]['value']

    def _parse_fkie_description(self, lookup_result: dict) -> str:
        """Parse an FKIE CVE record and its weaknesses.

        :return: uuid of the created ``vulnerability`` object.
        """
        vuln_id = lookup_result['id']
        misp_object = self._create_vulnerability_object(vuln_id)
        misp_object.add_attribute('id', self.misp_attribute.value)
        misp_object.add_attribute('id', vuln_id)
        misp_object.add_attribute(
            'description',
            self._parse_description_value(*lookup_result.get('descriptions', []))
        )
        for field, relation in self.mapping.fkie_mapping().items():
            if lookup_result.get(field) is not None:
                misp_object.add_attribute(relation, lookup_result[field])
        for cvss in lookup_result.get('metrics', {}).get('cvssMetricV31', []):
            if cvss.get('cvssData') is None:
                continue
            cvss_data = cvss['cvssData']
            misp_object.add_attribute('cvss-score', cvss_data['baseScore'])
            misp_object.add_attribute('cvss-string', cvss_data['vectorString'])
        for configuration in lookup_result.get('configurations', []):
            for node in configuration.get('nodes', []):
                for cpe_match in node.get('cpeMatch', []):
                    if cpe_match.get('criteria') is not None:
                        misp_object.add_attribute(
                            'vulnerable-configuration', cpe_match['criteria']
                        )
        # De-duplicate reference URLs in a deterministic order.
        references = dict.fromkeys(
            reference['url'] for reference in lookup_result.get('references', [])
            if reference.get('url') is not None
        )
        for reference in references:
            misp_object.add_attribute('references', reference)
        vulnerability_object = self.misp_event.add_object(misp_object)
        for weakness in lookup_result.get('weaknesses', []):
            for value in weakness.get('description', []):
                attribute = self.misp_event.add_attribute(
                    'weakness', value['value']
                )
                vulnerability_object.add_reference(
                    attribute.uuid, 'weakened-by'
                )
        return vulnerability_object.uuid

    def _parse_gcve_description(self, lookup_result: dict) -> str:
        """Parse a GCVE record, its weaknesses and its related CVE.

        :return: uuid of the created ``vulnerability`` object.
        """
        metadata = lookup_result['cveMetadata']
        misp_object = self._create_vulnerability_object(metadata['vulnId'])
        for field, relation in self.mapping.gcve_mapping().items():
            misp_object.add_attribute(relation, metadata[field])
        vulnerability_object = self.misp_event.add_object(misp_object)
        container = lookup_result['containers'].get('cna')
        if container is not None:
            if container.get('title'):
                vulnerability_object.add_attribute(
                    'summary', container['title']
                )
            for description in container.get('descriptions', []):
                vulnerability_object.add_attribute(
                    'description', description['value']
                )
            for reference in container.get('references', []):
                vulnerability_object.add_attribute(
                    'references', reference['url']
                )
            for metric in container.get('metrics', []):
                for key, fields in metric.items():
                    if key.startswith('cvssV'):
                        vulnerability_object.add_attribute(
                            'cvss-score', fields['baseScore']
                        )
                        vulnerability_object.add_attribute(
                            'cvss-string', fields['vectorString']
                        )
            for credit in container.get('credits', []):
                vulnerability_object.add_attribute('credit', credit['value'])
            for weakness in container.get('problemTypes', []):
                for description in weakness.get('descriptions', []):
                    weakness_object = MISPObject('weakness')
                    weakness_object.add_attribute('id', description['cweId'])
                    weakness_object.add_attribute(
                        'description', description['description']
                    )
                    vulnerability_object.add_reference(
                        self.misp_event.add_object(weakness_object).uuid,
                        'weakened-by'
                    )
        if metadata.get('cveId') is not None:
            self._add_related_reference(
                vulnerability_object, self._parse_alias(metadata['cveId'])
            )
        return vulnerability_object.uuid

    def _parse_gsd_description(self, lookup_result: dict) -> str:
        """Parse a GSD record and its embedded cve.org / NVD namespaces.

        :return: uuid of the created ``vulnerability`` object.
        """
        gsd_details = lookup_result["gsd"]["osvSchema"]
        misp_object = self._create_vulnerability_object(gsd_details['id'])
        for field, relation in self.mapping.gsd_mapping().items():
            if gsd_details.get(field):
                misp_object.add_attribute(relation, gsd_details[field])
        misp_object.add_reference(self.misp_attribute.uuid, "related-to")
        vulnerability_object = self.misp_event.add_object(misp_object)

        for field, values in lookup_result["namespaces"].items():
            if field == "cve.org":
                vulnerability_object.add_reference(self._parse_cve_related_description(values), "related-to")
                continue
            if field == "nvd.nist.gov" and values.get("cve"):
                vulnerability_object.add_reference(self._parse_nvd_related_description(values["cve"]), "related-to")

        return vulnerability_object.uuid

    def _parse_jvn_description(self, lookup_result: dict) -> str:
        """Parse a JVN iPedia entry, its CWE references and related CVEs.

        :return: uuid of the created ``vulnerability`` object.
        """
        vulnerability = self._create_vulnerability_object(lookup_result['id'])
        for field, relation in self.mapping.jvn_mapping().items():
            vulnerability.add_attribute(relation, lookup_result[field])
        for cpe in lookup_result.get("sec:cpe", []):
            cpe_value = cpe.get("#text")
            if cpe_value is not None:
                vulnerability.add_attribute("vulnerable-configuration", cpe_value)
        misp_object = self.misp_event.add_object(vulnerability)
        for reference in lookup_result.get("sec:references", []):
            source = reference.get("@source")
            # Source-less references with a CWE id describe weaknesses.
            if source is None and reference.get("@id", "").startswith("CWE-"):
                title = reference.get("@title")
                if title is not None:
                    weakness = MISPObject("weakness")
                    weakness.add_attribute("id", reference["@id"])
                    weakness.add_attribute("description", title)
                    misp_object.add_reference(self.misp_event.add_object(weakness).uuid, "weakened-by")
                else:
                    misp_object.add_reference(
                        self.misp_event.add_attribute(type="weakness", value=reference["@id"]).uuid,
                        "weakened-by"
                    )
                continue
            if source == "JVN":
                misp_object.add_attribute("references", reference["#text"])
            elif source == "CVE":
                for referenced_uuid in self._parse_aliases(reference["@id"]):
                    self._add_related_reference(misp_object, referenced_uuid)
        return misp_object.uuid

    def _parse_nvd_related_description(self, nvd_description: dict) -> str:
        """Parse the NVD namespace of a GSD record.

        :return: uuid of the created ``vulnerability`` object.
        """
        misp_object = self._create_vulnerability_object(nvd_description['id'])
        for field, relation in self.mapping.nvd_mapping().items():
            misp_object.add_attribute(relation, nvd_description[field])
        misp_object.add_attribute(
            'description',
            self._parse_description_value(*nvd_description['descriptions'])
        )
        for cvss in nvd_description.get("metrics", {}).get("cvssMetricV31", []):
            misp_object.add_attribute("cvss-score", cvss["cvssData"]["baseScore"])
            misp_object.add_attribute("cvss-string", cvss["cvssData"]["vectorString"])
        for reference in nvd_description.get("references", []):
            misp_object.add_attribute("references", reference["url"])
        return self.misp_event.add_object(misp_object).uuid

    def _parse_ossf_description(self, lookup_result: dict) -> str:
        """Parse an OpenSSF malicious-package record, its CWEs and aliases.

        :return: uuid of the created ``vulnerability`` object.
        """
        misp_object = self._create_vulnerability_object(lookup_result['id'])
        for field, relation in self.mapping.ossf_mapping().items():
            misp_object.add_attribute(relation, lookup_result[field])
        for reference in lookup_result["references"]:
            misp_object.add_attribute("references", reference["url"])
        misp_object.add_reference(self.misp_attribute.uuid, "related-to")
        vulnerability_object = self.misp_event.add_object(misp_object)
        for affected in lookup_result.get("affected", []):
            for cwe in affected.get("database_specific", {}).get("cwes", []):
                cwe_id = cwe.get("cweId")
                if cwe_id is not None:
                    weakness = MISPObject("weakness")
                    for field, relation in self.mapping.cwe_mapping().items():
                        if cwe.get(field):
                            weakness.add_attribute(relation, cwe[field])
                    self.misp_event.add_object(weakness)
                    vulnerability_object.add_reference(weakness.uuid, "weakened-by")

        for vuln_uuid in self._parse_aliases(*lookup_result.get("aliases") or []):
            self._add_related_reference(vulnerability_object, vuln_uuid)

        return vulnerability_object.uuid

    def _parse_standard_description(self, lookup_result: dict) -> str:
        """Parse an OSV-formatted record (GHSA, PYSEC), its CWEs and aliases.

        :return: uuid of the created ``vulnerability`` object.
        """
        misp_object = self._create_vulnerability_object(lookup_result['id'])
        for field, relation in self.mapping.standard_mapping().items():
            misp_object.add_attribute(relation, lookup_result[field])
        for cvss in lookup_result.get("severity", []):
            misp_object.add_attribute("cvss-string", cvss["score"])
        for reference in lookup_result["references"]:
            misp_object.add_attribute("references", reference["url"])
        for cwe_id in lookup_result.get("database_specific", {}).get("cwe_ids", []):
            attribute = self.misp_event.add_attribute(type="weakness", value=cwe_id)
            misp_object.add_reference(attribute.uuid, "weakened-by")
        misp_object.add_reference(self.misp_attribute.uuid, "related-to")
        vulnerability_object = self.misp_event.add_object(misp_object)

        # Aliases may point back to this record (e.g. GHSA <-> PYSEC); the
        # per-instance cache in _parse_alias breaks such cycles.
        for vuln_uuid in self._parse_aliases(*lookup_result.get("aliases") or []):
            self._add_related_reference(vulnerability_object, vuln_uuid)

        return vulnerability_object.uuid

    def _parse_tailscale_description(self, lookup_result: dict) -> str:
        """Parse a Tailscale security bulletin.

        :return: uuid of the created ``vulnerability`` object.
        """
        misp_object = self._create_vulnerability_object(lookup_result['title'])
        for field, relation in self.mapping.tailscale_mapping().items():
            misp_object.add_attribute(relation, lookup_result[field])
        misp_object.add_reference(self.misp_attribute.uuid, "related-to")
        self.misp_event.add_object(misp_object)
        return misp_object.uuid
