"""Report vulnerability Mixin.

Defines the main method that processes an entry from the Knowledge Base (KB)
and emits a vulnerability message.
"""

import dataclasses
import enum
from typing import Optional, List, Any, Union, Dict

from ostorlab.agent.kb import kb
from ostorlab.agent.mixins.protocols import emit
from ostorlab.assets import asset as os_asset


# Plain-dict shapes produced by the `to_proto_dict` helpers of this module.
FrameProtoDictType = Dict[str, Optional[str]]  # a single call-trace frame
CallTraceProtoDictType = Dict[str, List[FrameProtoDictType]]  # {"frames": [...]}
VulnerabilityLocationMetadataProtoDictType = Dict[str, Union[str, CallTraceProtoDictType]]


class MetadataType(enum.Enum):
    """Vulnerability location metadata type.

    `VulnerabilityLocationMetadata.to_proto_dict` serializes a member by its
    `.name` (e.g. "FILE_PATH"), so the member names are what ends up in the
    emitted message. Values are assigned by `enum.auto()` and therefore depend
    on declaration order.
    """

    FILE_PATH = enum.auto()
    CODE_LOCATION = enum.auto()
    URL = enum.auto()
    PORT = enum.auto()
    LOG = enum.auto()
    PACKAGE_NAME = enum.auto()
    VERSION = enum.auto()
    CLASS_NAME = enum.auto()
    METHOD_NAME = enum.auto()
    # Presumably paired with a `CallTrace` value rather than a plain string — confirm with callers.
    CALL_TRACE = enum.auto()
    INSERTION_POINT = enum.auto()
    DNS_RECORD_TYPE = enum.auto()


@dataclasses.dataclass
class Frame:
    """A single frame of a captured call trace.

    Attributes:
        function_name: function name of the frame; always serialized.
        function_signature: optional function signature (currently not serialized).
        class_name: optional class name of the frame.
        package_name: optional package name of the frame.
    """

    function_name: str
    function_signature: Optional[str] = None
    class_name: Optional[str] = None
    package_name: Optional[str] = None

    def to_proto_dict(self) -> FrameProtoDictType:
        """Serialize this frame into the dict shape of the `Frame` protobuf field of the callTrace message.

        `function_name` is always present; `class_name` and `package_name` are
        included only when they are not None.
        """
        # NOTE(review): `function_signature` is carried by the dataclass but never
        # serialized here — confirm whether the protobuf `Frame` defines a matching
        # field and, if so, emit it as well.
        optional_fields = (
            ("class_name", self.class_name),
            ("package_name", self.package_name),
        )
        serialized: FrameProtoDictType = {"function_name": self.function_name}
        serialized.update(
            (key, field_value)
            for key, field_value in optional_fields
            if field_value is not None
        )
        return serialized


@dataclasses.dataclass
class CallTrace:
    """Represents the captured call trace.

    Attributes:
        frames: the trace frames; each instance gets its own fresh empty list by default.
    """

    frames: List[Frame] = dataclasses.field(default_factory=list)

    def to_proto_dict(self) -> CallTraceProtoDictType:
        """Return a dictionary structured same as the corresponding `callTrace` protobuf message.

        Each frame is serialized with `Frame.to_proto_dict`, keeping the list order.
        """
        return {"frames": [frame.to_proto_dict() for frame in self.frames]}


@dataclasses.dataclass
class VulnerabilityLocationMetadata:
    """Vulnerability location metadata holding the type
    and the value of 'where' the vulnerability was detected in the asset.

    Attributes:
        metadata_type: kind of location information held in `value`.
        value: either a plain string or a `CallTrace`.
    """

    metadata_type: MetadataType
    value: Union[str, CallTrace]

    def to_proto_dict(self) -> VulnerabilityLocationMetadataProtoDictType:
        """Return a dictionary structured like the corresponding protobuf metadata message.

        The metadata type is emitted by enum name under "type". A string value is
        emitted under "value", and a `CallTrace` value under "calltrace". A value of
        any other type is silently omitted.
        """
        proto_dict_value: VulnerabilityLocationMetadataProtoDictType = {
            "type": self.metadata_type.name
        }
        # The two cases are mutually exclusive, so a single if/elif chain expresses it.
        if isinstance(self.value, str):
            proto_dict_value["value"] = self.value
        elif isinstance(self.value, CallTrace):
            proto_dict_value["calltrace"] = self.value.to_proto_dict()
        return proto_dict_value


@dataclasses.dataclass
class VulnerabilityLocation:
    """Vulnerability location used to attach a vulnerability to a specific asset.

    Attributes:
        metadata: location metadata entries describing where the vulnerability was detected.
        asset: optional asset the vulnerability is attached to.
    """

    metadata: List[VulnerabilityLocationMetadata]
    asset: Optional[os_asset.Asset] = None

    def to_dict(
        self,
    ) -> Dict[str, Union[Any, List[VulnerabilityLocationMetadataProtoDictType]]]:
        """Convert data class to a dict matching what is expected from protobuf.

        Returns:
            A dict with a "metadata" list (one `to_proto_dict()` entry per metadata
            item). When an asset is set, the dict also holds a copy of the asset's
            attributes under the key given by the asset's `proto_field`.
        """
        location: Dict[
            str, Union[Any, List[VulnerabilityLocationMetadataProtoDictType]]
        ] = {"metadata": [meta.to_proto_dict() for meta in self.metadata]}
        if self.asset is not None:
            # Shallow copy: handing out the live `__dict__` would let any mutation of
            # the returned dict (by callers or downstream serialization) alter the
            # asset's own attributes.
            location[self.asset.proto_field] = dict(self.asset.__dict__)
        return location


class RiskRating(enum.Enum):
    """Risk ratings enumeration.

    `AgentReportVulnMixin.report_vulnerability` emits the member's `.name`
    (e.g. "HIGH") under "risk_rating", so the member names are what reaches the
    emitted message. Values are assigned by `enum.auto()` and therefore depend on
    declaration order.
    """

    CRITICAL = enum.auto()
    HIGH = enum.auto()
    MEDIUM = enum.auto()
    LOW = enum.auto()
    POTENTIALLY = enum.auto()
    HARDENING = enum.auto()
    SECURE = enum.auto()
    IMPORTANT = enum.auto()
    INFO = enum.auto()


class AgentReportVulnMixin(emit.EmitProtocol):
    """Report Vulnerability class implementing logic of fetching entries from the knowledge base,
    and emitting vulnerability messages."""

    def report_vulnerability(
        self,
        entry: kb.Entry,
        technical_detail: str,
        risk_rating: RiskRating,
        dna: Optional[str] = None,
        vulnerability_location: Optional[VulnerabilityLocation] = None,
        exploitation_detail: Optional[str] = None,
        post_exploitation_detail: Optional[str] = None,
        risk: Optional[str] = None,
    ) -> None:
        """Build a vulnerability message from a knowledge base entry and emit it.

        The message is emitted on the "v3.report.vulnerability" selector.

        Args:
            entry: knowledge base entry providing title, descriptions, references, flags and CVSS vectors.
            technical_detail: markdown of the scan results.
            risk_rating: the risk rating assigned to the result of the scan; emitted by name.
            dna: unique identifier for duplicate vulnerabilities; omitted from the message when None.
            vulnerability_location: vulnerability location where the vulnerability was detected;
                serialized with `VulnerabilityLocation.to_dict` when provided.
            exploitation_detail: steps taken to exploit the vulnerability.
            post_exploitation_detail: impact or aftermath of the exploitation.
            risk: risk linked to the vulnerability.

        Returns:
            None
        """
        # Each reference of the entry maps a title to a URL.
        references = [
            {"title": title, "url": url} for title, url in entry.references.items()
        ]
        data: Dict[str, Any] = {
            "title": entry.title,
            "technical_detail": technical_detail,
            "risk_rating": risk_rating.name,
            "short_description": entry.short_description,
            "description": entry.description,
            "recommendation": entry.recommendation,
            "references": references,
            "security_issue": entry.security_issue,
            "privacy_issue": entry.privacy_issue,
            "has_public_exploit": entry.has_public_exploit,
            "targeted_by_malware": entry.targeted_by_malware,
            "targeted_by_ransomware": entry.targeted_by_ransomware,
            "targeted_by_nation_state": entry.targeted_by_nation_state,
            "cvss_v3_vector": entry.cvss_v3_vector,
            "cvss_v4_vector": entry.cvss_v4_vector,
            "category_groups": entry.category_groups,
            "exploitation_detail": exploitation_detail,
            "post_exploitation_detail": post_exploitation_detail,
            "risk": risk,
        }
        # If dna is not specified, it should not be provided to the portal, otherwise
        # it will cause a problem — so the key is only added when a value is given.
        if dna is not None:
            data["dna"] = dna
        if vulnerability_location is not None:
            data["vulnerability_location"] = vulnerability_location.to_dict()

        selector = "v3.report.vulnerability"
        self.emit(selector, data)
