"Driver for libvirt virtual machine"

import os
from xml.etree import ElementTree as ET
from xml.sax.saxutils import escape

import libvirt

from ...exceptions import VirtWrapperError
from .base import KernelVirtualDriver, to_async
from .disk import KernelVirtualDiskDriver
from .exceptions import wrap_libvirt
from .snapshot import KernelVirtualSnapshotDriver

# Human-readable labels for the libvirt domain states returned by
# ``virDomain.state()`` / ``virDomain.info()``. Every state libvirt-python
# defines is listed so that a lookup by state never raises KeyError.
STATES = {
    libvirt.VIR_DOMAIN_RUNNING: "Running",
    libvirt.VIR_DOMAIN_BLOCKED: "Blocked",
    libvirt.VIR_DOMAIN_PAUSED: "Paused",
    libvirt.VIR_DOMAIN_SHUTDOWN: "Shutdown",
    libvirt.VIR_DOMAIN_SHUTOFF: "Shutoff",
    libvirt.VIR_DOMAIN_CRASHED: "Crashed",
    # Guest suspended by power management (e.g. suspend-to-RAM).
    libvirt.VIR_DOMAIN_PMSUSPENDED: "PM suspended",
    libvirt.VIR_DOMAIN_NOSTATE: "No state",
}


@wrap_libvirt()
class KernelVirtualMachineDriver(KernelVirtualDriver):
    """Driver for managing a KVM virtual machine through a libvirt domain.

    Libvirt calls that may block are wrapped in ``to_async`` so that the
    public coroutine methods can be awaited from an event loop.
    """

    def __init__(self, conn: libvirt.virConnect, domain: libvirt.virDomain) -> None:
        """Bind the driver to an open libvirt connection and one of its domains.

        Args:
            conn: Open libvirt connection, handed to the base driver.
            domain: Domain managed by this driver; its UUID string becomes ``self.id``.
        """
        super().__init__(conn=conn)
        self.domain = domain
        self.id = self.domain.UUIDString()

    @classmethod
    async def connect(cls, host: str, auth: tuple[str, str], *args, **kwargs):
        """Connect to ``host`` and return a driver bound to one domain.

        Args:
            host: Hypervisor host, forwarded to the base ``_connect``.
            auth: Credentials tuple, forwarded to the base ``_connect``.
            **kwargs: Must contain ``uuid``, the UUID string of the domain.

        Raises:
            KeyError: If ``uuid`` is missing (checked before connecting).
        """
        # Read the required argument first so no connection is opened in vain.
        uuid = kwargs["uuid"]
        conn = await cls._connect(host=host, auth=auth)
        domain = await to_async(conn.lookupByUUIDString, uuid)
        return cls(conn=conn, domain=domain)

    async def get_name(self) -> str:
        """Get the virtual machine name.

        Returns the domain title metadata. If reading it raises a libvirt
        error (e.g. no title is set), the libvirt domain name is returned.
        """
        try:
            return await to_async(self.domain.metadata, libvirt.VIR_DOMAIN_METADATA_TITLE, None)
        except libvirt.libvirtError:
            return await to_async(self.domain.name)

    async def get_state(self) -> str:
        """Get the virtual machine state as a label from ``STATES``.

        A state code missing from ``STATES`` is reported with the "no state"
        label instead of raising ``KeyError``.
        """
        state, _ = await to_async(self.domain.state)
        return STATES.get(state, STATES[libvirt.VIR_DOMAIN_NOSTATE])

    async def set_name(self, name: str) -> None:
        """Store ``name`` as the domain title metadata.

        The libvirt domain name itself is left unchanged.
        """
        await to_async(self.domain.setMetadata, libvirt.VIR_DOMAIN_METADATA_TITLE, name, None, None)

    async def get_description(self) -> str | None:
        """Get the domain description metadata, or None if it cannot be read."""
        try:
            return await to_async(self.domain.metadata, libvirt.VIR_DOMAIN_METADATA_DESCRIPTION, None)
        except libvirt.libvirtError:
            return None

    async def get_guest_os(self) -> str | None:
        """Get the guest operating system's pretty name.

        The value comes from ``guestInfo()`` (answered by the guest agent).
        Returns None if that call raises a libvirt error or the key is absent.
        """
        try:
            guest_info = await to_async(self.domain.guestInfo)
            return guest_info.get("os.pretty-name")
        except libvirt.libvirtError:
            return None

    async def get_memory_stat(self) -> dict[str, int]:
        """Get the memory statistics of the virtual machine, in bytes.

        libvirt reports these values in KiB; each is multiplied by 1024.

        Returns:
            A dict with:
            ``startup``: current memory from ``info()``;
            ``maximum``: value of ``maxMemory()``;
            ``demand``: assigned memory minus guest-reported unused memory,
            clamped at 0;
            ``assigned``: balloon ``actual`` value (0 while shut off).
        """
        # info() -> [state, maxMem, memory, nrVirtCpu, cpuTime]
        info = await to_async(self.domain.info)
        state, startup_kib = info[0], info[2]

        if state == libvirt.VIR_DOMAIN_SHUTOFF:
            actual = 0
            demand = 0
        else:
            memory_stats = await to_async(self.domain.memoryStats)
            # "actual" may be missing when libvirt has no balloon data; fall
            # back to the current memory instead of failing on None arithmetic.
            actual = memory_stats.get("actual", startup_kib)
            demand = actual - memory_stats.get("unused", actual)

        return {
            "startup": startup_kib * 1024,
            "maximum": (await to_async(self.domain.maxMemory)) * 1024,
            "demand": max(demand, 0) * 1024,
            "assigned": actual * 1024,
        }

    async def get_cpus(self) -> int:
        """Get the number of virtual CPUs reported by ``info()``."""
        return (await to_async(self.domain.info))[3]

    async def set_cpus(self, cpus: int) -> None:
        """Set the number of virtual CPUs in the persistent domain configuration.

        The maximum vCPU count is set first so that the current count can then
        be set to the same value. Both calls use ``VIR_DOMAIN_AFFECT_CONFIG``,
        so the running guest is not modified.
        """
        await to_async(
            self.domain.setVcpusFlags, cpus, libvirt.VIR_DOMAIN_VCPU_MAXIMUM | libvirt.VIR_DOMAIN_AFFECT_CONFIG
        )
        await to_async(self.domain.setVcpusFlags, cpus, libvirt.VIR_DOMAIN_AFFECT_CONFIG)

    async def get_snapshots(self) -> list:
        """Get the list of the virtual machine snapshots.

        Returns:
            One dict per snapshot with ``name``, ``description``,
            ``parent_name`` (None for a root snapshot), ``created_at_ts``,
            ``is_applied`` (True for the current snapshot), ``cpus``, ``ram``
            (bytes, from ``currentMemory`` multiplied by 1024) and ``driver``.
        """
        snapshots_list: list[libvirt.virDomainSnapshot] = await to_async(self.domain.listAllSnapshots)
        if not snapshots_list:
            return []

        # Query the current snapshot once. libvirt raises an error when none
        # is current, which must not make the whole listing fail.
        try:
            current_snapshot: libvirt.virDomainSnapshot = await to_async(self.domain.snapshotCurrent)
            current_name = current_snapshot.getName()
        except libvirt.libvirtError:
            current_name = None

        ret = []
        for snap in snapshots_list:
            snap_name = snap.getName()
            tree_snap = ET.fromstring(await to_async(snap.getXMLDesc))

            cpus = tree_snap.find("domain/vcpu")
            ram = tree_snap.find("domain/currentMemory")
            description = tree_snap.find("description")
            created_at_ts = tree_snap.find("creationTime")

            try:
                parent = await to_async(lambda s=snap: s.getParent().getName())
            except libvirt.libvirtError:
                # Root snapshots have no parent.
                parent = None

            ret.append(
                {
                    "name": snap_name,
                    "description": description.text if description is not None else None,
                    "parent_name": parent,
                    "created_at_ts": int(created_at_ts.text or 0) if created_at_ts is not None else 0,
                    "is_applied": snap_name == current_name,
                    "cpus": int(cpus.text or 0) if cpus is not None else 0,
                    "ram": int(ram.text or 0) * 1024 if ram is not None else 0,
                    "driver": KernelVirtualSnapshotDriver(domain=self.domain, snapshot=snap),
                }
            )

        return ret

    async def get_disks(self) -> list[dict]:
        """Get the list of the disks attached to the virtual machine.

        Disks are resolved to storage volumes either by pool/volume name or by
        their ``file``/``dev`` source path. Disks that cannot be resolved
        (libvirt error, or no local path such as network disks) are skipped.

        Returns:
            One dict per disk with ``driver``, ``name`` (name of the last
            volume in the backing chain), ``path``, ``storage`` (pool name),
            ``size`` and ``used`` (from ``volume.info()``).

        Raises:
            VirtWrapperError: If a backing-store entry has an empty path.
        """
        ret = []
        domain_xml = await to_async(self.domain.XMLDesc)
        for src in ET.fromstring(domain_xml).findall("devices/disk/source"):
            try:
                if src.get("pool"):
                    storage_pool = await to_async(self.conn.storagePoolLookupByName, src.get("pool"))
                    volume = await to_async(storage_pool.storageVolLookupByName, src.get("volume"))
                else:
                    # File-backed disks use "file"; block-device disks use "dev".
                    source_path = src.get("file") or src.get("dev")
                    if not source_path:
                        continue
                    volume = await to_async(self.conn.storageVolLookupByPath, source_path)
                    storage_pool = await to_async(volume.storagePoolLookupByVolume)
                _, size, used = await to_async(volume.info)

                # Follow the backing chain to its last volume; that name is reported.
                parent_volume = volume
                backing_path = ET.fromstring(await to_async(parent_volume.XMLDesc)).find("backingStore/path")
                while backing_path is not None:
                    if not backing_path.text:
                        raise VirtWrapperError("Unable to find parent volume path")
                    parent_volume = await to_async(self.conn.storageVolLookupByPath, backing_path.text)
                    backing_path = ET.fromstring(await to_async(parent_volume.XMLDesc)).find("backingStore/path")

                ret.append(
                    {
                        "driver": KernelVirtualDiskDriver(volume=volume, domain=self.domain),
                        "name": parent_volume.name(),
                        "path": volume.path(),
                        "storage": storage_pool.name(),
                        "size": size,
                        "used": used,
                    }
                )
            except libvirt.libvirtError:
                continue
        return ret

    async def get_networks(self) -> list:
        """Get the list of the virtual machine network adapters.

        Guest addresses are queried only while the domain is running: first
        from the guest agent, then, if that raises, from the host ARP table.

        Returns:
            One dict per interface with ``mac`` (upper-cased), ``switch``
            (bridge name or "") and ``addresses`` (list of address strings).
        """
        domain_xml = await to_async(self.domain.XMLDesc)
        interfaces = ET.fromstring(domain_xml).findall("devices/interface")
        if not interfaces:
            return []

        # Query state and addresses once for all interfaces.
        nets: dict = {}
        state, _ = await to_async(self.domain.state)
        if state == libvirt.VIR_DOMAIN_RUNNING:
            try:
                nets = await to_async(
                    self.domain.interfaceAddresses, libvirt.VIR_DOMAIN_INTERFACE_ADDRESSES_SRC_AGENT
                )
            except libvirt.libvirtError:
                nets = await to_async(
                    self.domain.interfaceAddresses, libvirt.VIR_DOMAIN_INTERFACE_ADDRESSES_SRC_ARP
                )

        ret = []
        for interface in interfaces:
            mac = interface.find("mac")
            mac_address = "" if mac is None else mac.get("address", "")

            source = interface.find("source")
            switch_name = "" if source is None else source.get("bridge", "")

            addresses = []
            for net in nets.values():
                hwaddr = net.get("hwaddr")
                # MAC addresses are case-insensitive; compare normalized forms.
                if hwaddr is not None and hwaddr.lower() == mac_address.lower():
                    addresses.extend(addr.get("addr") for addr in net.get("addrs") or [])
                    break

            ret.append({"mac": mac_address.upper(), "switch": switch_name, "addresses": addresses})
        return ret

    async def get_displays(self) -> list[dict]:
        """Get the virtual displays (graphics devices) of the domain.

        The domain XML is requested with ``VIR_DOMAIN_XML_SECURE`` so that
        the ``passwd`` attribute is included.
        """
        domain_xml = await to_async(self.domain.XMLDesc, libvirt.VIR_DOMAIN_XML_SECURE)
        return [
            {
                "Type": display.get("type"),
                "Port": display.get("port"),
                "Password": display.get("passwd"),
            }
            for display in ET.fromstring(domain_xml).findall("devices/graphics")
        ]

    async def run(self) -> None:
        """Start the virtual machine unless it is already running."""
        if (await to_async(self.domain.state))[0] != libvirt.VIR_DOMAIN_RUNNING:
            await to_async(self.domain.create)

    async def shutdown(self) -> None:
        """Ask the guest to shut down gracefully (``virDomain.shutdown``)."""
        await to_async(self.domain.shutdown)

    async def poweroff(self) -> None:
        """Force off the virtual machine immediately (``virDomain.destroy``)."""
        await to_async(self.domain.destroy)

    async def save(self) -> None:
        """Save the memory state to a libvirt-managed file and stop the VM.

        Uses ``managedSave``; libvirt restores the saved state the next time
        the domain is started.
        """
        await to_async(self.domain.managedSave)

    async def suspend(self) -> None:
        """Pause the virtual machine; its memory stays allocated on the host."""
        await to_async(self.domain.suspend)

    async def resume(self) -> None:
        """Unpause the suspended virtual machine."""
        await to_async(self.domain.resume)

    async def snapshot_create(self, name: str, description: str) -> None:
        """Create a new snapshot of the virtual machine.

        ``name`` and ``description`` are XML-escaped so characters such as
        ``&`` or ``<`` cannot break or inject into the snapshot XML.
        """
        snapshot_xml_template = f"""<domainsnapshot>
            <name>{escape(name)}</name>
            <description>{escape(description)}</description>
        </domainsnapshot>"""
        await to_async(self.domain.snapshotCreateXML, snapshot_xml_template, libvirt.VIR_DOMAIN_SNAPSHOT_CREATE_ATOMIC)

    async def export(self, storage: str) -> str:
        """Export the virtual machine to a storage destination.

        Creates a transient directory pool ``export_<vm name>`` at
        ``<target path of storage>/<vm name>`` and clones every disk from
        ``get_disks`` into it. The pool is then destroyed; the directory is
        kept. Finally the inactive, migratable domain XML is written to
        ``config.xml`` in that directory.

        Args:
            storage: Name of an existing storage pool that has a ``target/path``.

        Returns:
            The directory the virtual machine was exported to.

        Raises:
            ValueError: If the storage pool has no target path.
        """
        pool = await to_async(self.conn.storagePoolLookupByName, storage)
        pool_xml = await to_async(pool.XMLDesc)
        pool_path = ET.fromstring(pool_xml).find("target/path")
        if pool_path is None or pool_path.text is None:
            raise ValueError("Unable to find pool")

        vm_name = await self.get_name()
        # NOTE(review): the VM title is used verbatim as a path component, so a
        # title containing "/" would place the export outside pool_path.
        target_pool_path = os.path.join(pool_path.text, vm_name)
        target_pool_name = f"export_{vm_name}"

        xml_pool = f"""<pool type='dir'>
  <name>{escape(target_pool_name)}</name>
  <target>
    <path>{escape(target_pool_path)}</path>
    <permissions>
      <mode>0777</mode>
    </permissions>
  </target>
</pool>
"""
        target_pool = await to_async(
            self.conn.storagePoolCreateXML, xml_pool, libvirt.VIR_STORAGE_POOL_CREATE_WITH_BUILD
        )
        try:
            for disk in await self.get_disks():
                volume = await to_async(self.conn.storageVolLookupByPath, disk["path"])
                xml_vol = f"""<volume>
  <name>{escape(volume.name())}</name>
  <target>
    <permissions>
      <mode>0644</mode>
      <label>virt_image_t</label>
    </permissions>
  </target>
</volume>"""
                # Cloning copies the whole disk; keep it off the event loop.
                await to_async(target_pool.createXMLFrom, xml_vol, volume, 0)
        finally:
            await to_async(target_pool.destroy)

        # Fetch the XML before opening the file so a failure leaves no empty config.
        config_xml = await to_async(
            self.domain.XMLDesc,
            libvirt.VIR_DOMAIN_XML_INACTIVE | libvirt.VIR_DOMAIN_XML_UPDATE_CPU | libvirt.VIR_DOMAIN_XML_MIGRATABLE,
        )
        with open(os.path.join(target_pool_path, "config.xml"), "w", encoding="utf-8") as config:
            config.write(config_xml)
        return target_pool_path
