"Driver for Hyper-V virtual machine"

import ntpath
import os
from typing import Any

import psrp

from .base import HyperVirtualDriver
from .disk import HyperVirtualDiskDriver
from .snapshot import HyperVirtualSnapshotDriver


class HyperVirtualMachineDriver(HyperVirtualDriver):
    """Driver for managing a single Hyper-V virtual machine.

    Every operation runs a PowerShell pipeline (Hyper-V cmdlets or WMI/CIM
    queries) through the connection provided by ``HyperVirtualDriver`` and
    addresses the machine by its Hyper-V ``Id``.
    """

    def __init__(self, conn: psrp.ConnectionInfo, uuid: str) -> None:
        """Bind the driver to a host connection and a virtual machine id.

        :param conn: PSRP connection to the Hyper-V host.
        :param uuid: Hyper-V ``Id`` of the virtual machine (passed to ``Get-VM -Id``).
        """
        super().__init__(conn=conn)
        self.id = uuid

    @classmethod
    async def connect(cls, host: str, auth: tuple[str, str], *args, **kwargs):
        """Connect to ``host`` and return a driver for the VM given as ``kwargs["uuid"]``.

        The VM name is fetched once, which performs a ``Get-VM -Id`` round trip
        (presumably to validate the id up front — TODO confirm intent).

        :raises KeyError: if ``uuid`` is not supplied as a keyword argument.
        """
        conn = cls._connect(host=host, auth=auth)
        obj = cls(conn=conn, uuid=kwargs["uuid"])
        await obj.get_name()
        return obj

    async def __get_vm_property(self, prop: str | list[str]) -> Any:
        """Return the first object of ``Get-VM -Id <id> | Select -Property <prop>``.

        :param prop: One property name or a list of property names to select.
        :return: The selected object; requested properties are read as attributes.
        """
        async with self.get_ps() as ps:
            ps.add_command("Get-VM")
            ps.add_parameter("Id", self.id)
            ps.add_command("Select")
            ps.add_parameter("Property", prop)
            result = await self.exec_ps(ps=ps)
        # assumes exec_ps returns an indexable sequence of output objects
        return result[0]

    async def get_name(self) -> str:
        """Get the virtual machine name"""
        result = await self.__get_vm_property("Name")
        return result.Name

    async def set_name(self, name: str) -> None:
        """Rename the virtual machine with ``Rename-VM`` and record the new name locally."""
        async with self.get_ps() as ps:
            ps.add_command("Get-VM")
            ps.add_parameter("Id", self.id)
            ps.add_command("Rename-VM")
            ps.add_parameter("NewName", name)
            await self.exec_ps(ps=ps)
        # NOTE(review): object.__setattr__ suggests the base class restricts
        # normal attribute assignment — confirm against HyperVirtualDriver.
        object.__setattr__(self, "name", name)

    async def get_state(self) -> str:
        """Get the virtual machine state as a string"""
        result = await self.__get_vm_property("State")
        return str(result.State)

    async def get_description(self) -> str | None:
        """Get the virtual machine description (its Hyper-V ``Notes``)"""
        result = await self.__get_vm_property("Notes")
        return result.Notes

    async def get_guest_os(self) -> str | None:
        """Get the name of the virtual machine guest operating system.

        Reads ``GuestOperatingSystem`` from the ``Msvm_SummaryInformation`` WMI
        class whose ``Name`` equals the VM id.
        """
        async with self.get_ps() as ps:
            ps.add_command("Get-WmiObject")
            ps.add_parameter("Namespace", r"root\virtualization\v2")
            # NOTE(review): self.id is embedded in the WQL text; this is safe only
            # while it is a well-formed GUID string.
            ps.add_parameter("Query", f"Select * From Msvm_SummaryInformation Where Name='{self.id}'")
            ps.add_command("Select")
            ps.add_parameter("Property", "GuestOperatingSystem")
            result = await self.exec_ps(ps=ps)
        return result[0].GuestOperatingSystem

    async def get_memory_stat(self) -> dict[str, int]:
        """Get the memory statistics of the virtual machine.

        :return: ``startup``/``maximum``/``demand``/``assigned`` mapped from the
            ``MemoryStartup``/``MemoryMaximum``/``MemoryDemand``/``MemoryAssigned``
            properties of ``Get-VM``.
        """
        result = await self.__get_vm_property(["MemoryStartup", "MemoryMaximum", "MemoryDemand", "MemoryAssigned"])
        return {
            "startup": result.MemoryStartup,
            "maximum": result.MemoryMaximum,
            "demand": result.MemoryDemand,
            "assigned": result.MemoryAssigned,
        }

    async def get_cpus(self) -> int:
        """Get the number of virtual processors"""
        result = await self.__get_vm_property("ProcessorCount")
        return result.ProcessorCount

    async def set_cpus(self, cpus: int) -> None:
        """Set the number of virtual processors with ``Set-VMProcessor -Count``"""
        async with self.get_ps() as ps:
            ps.add_command("Get-VM")
            ps.add_parameter("Id", self.id)
            ps.add_command("Set-VMProcessor")
            ps.add_parameter("Count", cpus)
            await self.exec_ps(ps=ps)

    async def get_snapshots(self) -> list[dict]:
        """Get the list of the virtual machine snapshots (checkpoints).

        :return: One dict per snapshot with ``name``, ``description``,
            ``parent_name``, ``created_at_ts`` (POSIX timestamp), ``is_applied``,
            ``cpus``, ``ram`` and a ``driver`` bound to the snapshot id.
        """
        # Identify the applied snapshot by id: Hyper-V allows several checkpoints
        # with the same name, so comparing names could flag more than one.
        vm_info = await self.__get_vm_property("ParentSnapshotId")
        applied_id = None if vm_info.ParentSnapshotId is None else str(vm_info.ParentSnapshotId)
        async with self.get_ps() as ps:
            ps.add_command("Get-VM")
            ps.add_parameter("Id", self.id)
            ps.add_command("Get-VMSnapshot")
            ps.add_command("Select")
            ps.add_parameter(
                "Property",
                ["Id", "Name", "Notes", "ParentSnapshotName", "CreationTime", "ProcessorCount", "MemoryStartup"],
            )
            result = await self.exec_ps(ps)
        return [
            {
                "name": r.Name,
                "description": r.Notes,
                "parent_name": r.ParentSnapshotName,
                "created_at_ts": r.CreationTime.timestamp(),
                "is_applied": str(r.Id) == applied_id,
                "cpus": r.ProcessorCount,
                "ram": r.MemoryStartup,
                "driver": HyperVirtualSnapshotDriver(conn=self.conn, snapshot_id=str(r.Id)),
            }
            for r in result
        ]

    async def get_disks(self) -> list[dict]:
        """Get the list of the virtual hard disks attached to the virtual machine.

        :return: One dict per disk with ``name`` (file name), ``path``,
            ``storage`` (first character of the path, i.e. the drive letter for
            local paths), ``size`` (``Get-VHD`` Size), ``used`` (FileSize) and a
            ``driver`` bound to the disk path.
        """
        async with self.get_ps() as ps:
            ps.add_command("Get-VM")
            ps.add_parameter("Id", self.id)
            ps.add_command("Get-VMHardDiskDrive").add_command("Get-VHD")
            result = await self.exec_ps(ps)
        return [
            {
                "name": str(d.Path).rsplit("\\", maxsplit=1)[-1],
                "path": d.Path,
                "storage": d.Path[0],
                "size": d.Size,
                "used": d.FileSize,
                "driver": HyperVirtualDiskDriver(conn=self.conn, path=d.Path),
            }
            for d in result
        ]

    async def get_networks(self) -> list[dict]:
        """Get the list of the virtual machine network adapters.

        :return: One dict per adapter with ``mac`` (colon-separated lowercase
            hex), ``switch`` (switch name) and ``addresses`` (``IPAddresses``).
        """
        async with self.get_ps() as ps:
            ps.add_command("Get-VM")
            ps.add_parameter("Id", self.id)
            # Use the real cmdlet name instead of relying on PowerShell's
            # implicit "Get-" prefix fallback for unknown commands.
            ps.add_command("Get-VMNetworkAdapter")
            ps.add_command("Select")
            ps.add_parameter("Property", ["MacAddress", "SwitchName", "IPAddresses"])
            result = await self.exec_ps(ps)
        return [
            {"mac": bytes.fromhex(n.MacAddress).hex(":"), "switch": n.SwitchName, "addresses": n.IPAddresses}
            for n in result
        ]

    async def run(self) -> None:
        """Power on the virtual machine (``Start-VM``)"""
        async with self.get_ps() as ps:
            ps.add_command("Get-VM")
            ps.add_parameter("Id", self.id)
            ps.add_command("Start-VM")
            await self.exec_ps(ps)

    async def shutdown(self) -> None:
        """Shut down the virtual machine (``Stop-VM -Force``)"""
        async with self.get_ps() as ps:
            ps.add_command("Get-VM")
            ps.add_parameter("Id", self.id)
            ps.add_command("Stop-VM")
            ps.add_parameter("Force", True)
            await self.exec_ps(ps)

    async def poweroff(self) -> None:
        """Force off the virtual machine (``Stop-VM -TurnOff``)"""
        async with self.get_ps() as ps:
            ps.add_command("Get-VM")
            ps.add_parameter("Id", self.id)
            ps.add_command("Stop-VM")
            ps.add_parameter("TurnOff", True)
            await self.exec_ps(ps)

    async def save(self) -> None:
        """Save the virtual machine state to a file and stop it (``Save-VM``)"""
        async with self.get_ps() as ps:
            ps.add_command("Get-VM")
            ps.add_parameter("Id", self.id)
            ps.add_command("Save-VM")
            await self.exec_ps(ps)

    async def suspend(self) -> None:
        """Pause the running virtual machine (``Suspend-VM``)"""
        async with self.get_ps() as ps:
            ps.add_command("Get-VM")
            ps.add_parameter("Id", self.id)
            ps.add_command("Suspend-VM")
            await self.exec_ps(ps)

    async def resume(self) -> None:
        """Unpause the suspended virtual machine (``Resume-VM``)"""
        async with self.get_ps() as ps:
            ps.add_command("Get-VM")
            ps.add_parameter("Id", self.id)
            ps.add_command("Resume-VM")
            await self.exec_ps(ps)

    async def snapshot_create(self, name: str, description: str) -> None:
        """Create a new snapshot (checkpoint) of the virtual machine and set its notes.

        :param name: Snapshot name passed to ``Checkpoint-VM -SnapshotName``.
        :param description: Text written to the new snapshot's ``Notes``.
        """
        async with self.get_ps() as ps:
            ps.add_command("Get-VM")
            ps.add_parameter("Id", self.id)
            ps.add_command("Checkpoint-VM")
            ps.add_parameter("SnapshotName", name)
            ps.add_parameter("Passthru", True)
            ps.add_command("Select")
            ps.add_parameter("Property", "Id")
            result = await self.exec_ps(ps)

        snapshot_id = str(result[0].Id)

        # The notes are written in a second step through the WMI
        # ModifySystemSettings method. All values are bound as script parameters
        # rather than spliced into the script text, so quotes or "$(...)" in the
        # description cannot break or inject into the script. [CmdletBinding()]
        # makes $PSCmdlet available for the WriteError call on failure.
        async with self.get_ps() as ps:
            ps.add_script(r"""
                [CmdletBinding()]
                param(
                    [string]$VmId,
                    [string]$SnapshotId,
                    [string]$SnapshotName,
                    [string]$Description
                )

                $snapshot = Get-CimInstance -Namespace root\virtualization\v2 -Query "SELECT * FROM Msvm_VirtualSystemSettingData WHERE InstanceID LIKE '%$SnapshotId%'"

                $snapshot.Notes = $Description
                $serializer = [Microsoft.Management.Infrastructure.Serialization.CimSerializer]::Create()
                $snapshotBytes = $serializer.Serialize( $snapshot, [Microsoft.Management.Infrastructure.Serialization.InstanceSerializationOptions]::None )
                $snapshotStr   = [Text.Encoding]::Unicode.GetString( $snapshotBytes )
                $service = Get-CimInstance -Namespace 'root\virtualization\v2' -class 'Msvm_VirtualSystemManagementService'
                $result = $service | Invoke-CimMethod -MethodName ModifySystemSettings -Arguments @{ SystemSettings = $snapshotStr }

                if( $result.ReturnValue -notin 0, 4096 ) {
                    $PSCmdlet.WriteError( [Management.Automation.ErrorRecord]::new(
                        [Exception]::new("Failed to set VM notes for snapshot '$SnapshotName' of VM '$VmId'"), 'ModifySystemSettingsFailed', [Management.Automation.ErrorCategory]::InvalidResult, $result.ReturnValue ))
                    return
                }
            """)
            ps.add_parameter("VmId", str(self.id))
            ps.add_parameter("SnapshotId", snapshot_id)
            ps.add_parameter("SnapshotName", name)
            ps.add_parameter("Description", description)
            await self.exec_ps(ps)

    async def export(self, storage: str) -> str:
        """Export the virtual machine into the hyperv/export folder of drive ``storage``.

        :param storage: Drive letter of the destination volume, without the colon.
        :return: Destination folder joined with the VM name (assumes Export-VM
            places the export in a subfolder named after the VM).
        """
        # The path is interpreted by the Windows host, so build it with Windows
        # rules regardless of the client OS. os.path.join yields "C:/hyperv/..."
        # on POSIX and the drive-relative "C:hyperv\..." on Windows.
        destination_path = ntpath.join(storage + ":\\", "hyperv", "export")
        async with self.get_ps() as ps:
            ps.add_command("Get-VM")
            ps.add_parameter("Id", self.id)
            ps.add_command("Export-VM")
            ps.add_parameter("Path", destination_path)
            await self.exec_ps(ps)

        return ntpath.join(destination_path, await self.get_name())
