"Base for Hyper-V driver"

import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any

import psrp

from ...exceptions import (
    VirtualMachineError,
    VirtualMachineNotFoundError,
    VirtualMachineStateError,
    VirtWrapperError,
)

# Absolute directory holding this module; realpath() resolves symlinks first.
SCRIPT_PATH = os.path.split(os.path.realpath(__file__))[0]
# The "utils" directory located next to this module.
TEMPLATES_DIR = os.path.join(SCRIPT_PATH, "utils")

# Translation table from PowerShell error records to driver exceptions.
# Outer key: the part of ``FullyQualifiedErrorId`` after the comma (the
# command that produced the error). Inner key: the part before the comma
# (the error id). A "Default" entry, where present, names the exception
# intended for any other error id coming from that command.
EXCEPTIONS_MAP = {
    "Microsoft.HyperV.PowerShell.Commands.GetVM": {
        "ObjectNotFound": VirtualMachineNotFoundError,
        "Default": VirtualMachineError,
    },
    # State-changing commands share one shape; every command still gets its
    # own dict object so the entries stay independent of each other.
    **{
        state_command: {
            "InvalidState": VirtualMachineStateError,
            "Default": VirtualMachineError,
        }
        for state_command in (
            "Microsoft.HyperV.PowerShell.Commands.SaveVM",
            "Microsoft.HyperV.PowerShell.Commands.SuspendVM",
            "Microsoft.HyperV.PowerShell.Commands.ResumeVM",
        )
    },
    # No "Default" entry for this command.
    "Microsoft.HyperV.PowerShell.Commands.ImportVM": {
        "ObjectNotFound": VirtualMachineNotFoundError,
    },
}


class HyperVirtualDriver:
    """Common class for connecting to Hyper-V server.

    Keeps the PSRP connection info and offers helpers to obtain a PowerShell
    pipeline and to run it with error records translated into exceptions.
    """

    # Connection settings handed to every runspace pool opened by get_ps().
    conn: psrp.ConnectionInfo

    def __init__(self, conn: psrp.ConnectionInfo) -> None:
        """Store the connection info used by this driver."""
        self.conn = conn

    @classmethod
    async def connect(cls, host: str, auth: tuple[str, str], *args, **kwargs):
        """Get driver object.

        Args:
            host: Host name or address of the Hyper-V server.
            auth: ``(username, password)`` pair.
            *args: Extra positional arguments forwarded to the constructor.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            An instance of ``cls`` built around the connection info for
            ``host``.
        """
        conn = cls._connect(host=host, auth=auth)
        # Positional extras are bound first and ``conn`` is always passed by
        # keyword; written in this order so the call reads the way Python
        # actually evaluates it.
        return cls(*args, conn=conn, **kwargs)

    @classmethod
    def _connect(cls, host: str, auth: tuple[str, str]) -> psrp.WSManInfo:
        """Build the WSMan connection info for the hypervisor.

        Args:
            host: Host name or address placed into the WSMan URL.
            auth: ``(username, password)`` pair.

        Returns:
            Connection info targeting ``https://<host>:5986/wsman`` with
            basic authentication.
        """
        username, password = auth[0], auth[1]
        return psrp.WSManInfo(
            server=f"https://{host}:5986/wsman",
            auth="basic",
            username=username,
            password=password,
        )

    @asynccontextmanager
    async def get_ps(self) -> AsyncGenerator[psrp.AsyncPowerShell, None]:
        """Get powershell.

        Opens a runspace pool on ``self.conn`` and yields a PowerShell
        pipeline bound to it. The pool's context is exited when the caller's
        ``async with`` block ends.
        """
        async with psrp.AsyncRunspacePool(self.conn) as rp:
            yield psrp.AsyncPowerShell(rp)

    async def exec_ps(self, ps: psrp.AsyncPowerShell) -> list[Any]:
        """Execute powershell command.

        Args:
            ps: Prepared pipeline to invoke.

        Returns:
            The output of ``ps.invoke()`` when the pipeline reports no errors.

        Raises:
            VirtWrapperError: The first error record has no usable
                ``FullyQualifiedErrorId``, or its command has no matching
                entry in ``EXCEPTIONS_MAP``.
            Exception: The class registered in ``EXCEPTIONS_MAP`` for the
                command and error id, or the command's ``"Default"`` class.
        """
        result = await ps.invoke()
        if not ps.had_errors:
            return result

        # Only the first error record decides which exception is raised.
        error = ps.streams.error[0]
        message = str(error)
        try:
            # Expected shape is "<error id>,<command>". A missing id or any
            # other shape fails the two-way unpack and cannot be mapped.
            error_id, source = (error.FullyQualifiedErrorId or "").split(",")
        except ValueError as exc:
            raise VirtWrapperError(message) from exc

        by_error_id = EXCEPTIONS_MAP.get(source)
        if by_error_id is None:
            raise VirtWrapperError(message)

        # Resolve the fallback through the "Default" *key*; passing the
        # literal "Default" as dict.get()'s default would hand back a string
        # and make the raise below fail with TypeError.
        exception = by_error_id.get(error_id)
        if exception is None:
            exception = by_error_id.get("Default", VirtWrapperError)
        raise exception(message)
