import dataclasses as dc
import shutil
import socket
import time
import typing as tp
from datetime import datetime, timedelta

import psutil
import speedtest

# Closed set of status labels; get_health_status maps a numeric score onto one of these.
HealthStatus = tp.Literal["healthy", "under_load", "unhealthy"]


@dc.dataclass(frozen=True, slots=True)
class DiskStats:
    total: int
    used: int
    free: int
    percent_used: float


@dc.dataclass(frozen=True, slots=True)
class NetworkStats:
    download_speed: float
    upload_speed: float
    ping: float


@dc.dataclass(frozen=True, slots=True)
class SystemStats:
    cpu_percent: float
    memory_percent: float
    uptime: float
    load_avg: list[float]


@dc.dataclass(frozen=True, slots=True)
class CachedValue:
    value: tp.Any
    timestamp: datetime
    ttl: timedelta

    def is_expired(self) -> bool:
        return datetime.now() - self.timestamp > self.ttl


@dc.dataclass(frozen=True, slots=True)
class HealthCheckResult:
    status: HealthStatus
    score: float
    disk: DiskStats
    network: NetworkStats
    system: SystemStats


def measure_disk_space(path: str = "/") -> DiskStats:
    """Measure disk usage for the filesystem containing ``path``.

    Args:
        path: Any path on the filesystem to inspect.

    Returns:
        A ``DiskStats`` with the raw figures from ``shutil.disk_usage`` and
        the used share on a 0-100 scale.
    """
    total, used, free = shutil.disk_usage(path)

    # Guard against a reported total of zero to avoid a division by zero.
    if total > 0:
        percent_used = used / total * 100
    else:
        percent_used = 0.0

    return DiskStats(total=total, used=used, free=free, percent_used=percent_used)


def check_disk_space(path: str = "/", force_refresh: bool = False) -> DiskStats:
    """Return disk usage for ``path``, cached for 30 minutes per path.

    Args:
        path: Any path on the filesystem to inspect.
        force_refresh: When True, drop the cached entry for this path first
            so a fresh measurement is taken.

    Returns:
        The cached or freshly measured ``DiskStats`` for ``path``.
    """
    # The cache key must depend on the path; otherwise a call for one path
    # would return the stats cached for another. The default path keeps the
    # plain "disk_space" key so callers that invalidate that key directly
    # continue to work.
    cache_key = "disk_space" if path == "/" else f"disk_space:{path}"
    _clear_cache_if_refresh(force_refresh, cache_key)
    return _get_cached(key=cache_key, default_factory=lambda: measure_disk_space(path), ttl=timedelta(minutes=30))


def measure_network_speed() -> NetworkStats:
    """Measure network throughput and latency, with a lightweight fallback.

    First tries a full speedtest run. If anything in that run fails, falls
    back to timing a TCP connection to 8.8.8.8:53 and reports placeholder
    throughput figures chosen from the measured latency.

    Returns:
        A ``NetworkStats``. In the fallback path ``ping`` is the connect time
        in milliseconds, or ``9999.0`` if the connection could not be made,
        and the speeds are fixed estimates rather than measurements.
    """
    try:
        st = speedtest.Speedtest()
        st.get_best_server()
        st.download()
        st.upload()

        return NetworkStats(
            download_speed=st.results.download / 1_000_000,
            upload_speed=st.results.upload / 1_000_000,
            ping=st.results.ping,
        )
    except Exception:
        # Deliberately broad: any speedtest failure degrades to the cheap probe
        # below instead of failing the whole health check.
        host = "8.8.8.8"
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # wall-clock adjustments during the probe.
        start_time = time.perf_counter()
        try:
            # The context manager closes the probe socket; it is only needed
            # to time the connection handshake.
            with socket.create_connection((host, 53), timeout=2):
                ping_time = (time.perf_counter() - start_time) * 1000
        except OSError:
            # TimeoutError is a subclass of OSError, so timeouts land here too.
            ping_time = 9999.0

        # Placeholder throughput: no real measurement is available here.
        download_speed = 100.0 if ping_time < 100 else 50.0
        upload_speed = 50.0 if ping_time < 100 else 25.0

        return NetworkStats(download_speed=download_speed, upload_speed=upload_speed, ping=ping_time)


# Module-level cache shared by the check_* helpers: maps a string key to the
# cached value together with its creation time and TTL.
_cache: dict[str, CachedValue] = {}


def _clear_cache_if_refresh(force_refresh: bool, *keys: str) -> None:
    """Evict the given cache keys when a refresh is requested.

    Args:
        force_refresh: When falsy, the cache is left untouched.
        *keys: Cache keys to remove; keys that are not cached are ignored.
    """
    if not force_refresh:
        return

    for stale_key in keys:
        _cache.pop(stale_key, None)


def _get_cached(key: str, default_factory: tp.Callable[[], tp.Any], ttl: timedelta) -> tp.Any:
    """Return the cached value for ``key``, computing and storing it if needed.

    Args:
        key: Cache key to look up.
        default_factory: Zero-argument callable that produces a fresh value
            when the entry is missing or expired.
        ttl: Lifetime assigned to a newly stored entry.

    Returns:
        The still-valid cached value, or the freshly computed one.
    """
    # Captured before the factory runs, so the stored timestamp marks the
    # start of the computation rather than its end.
    now = datetime.now()
    entry = _cache.get(key)

    if entry is not None and not entry.is_expired():
        return entry.value

    fresh = default_factory()
    _cache[key] = CachedValue(value=fresh, timestamp=now, ttl=ttl)
    return fresh


def check_network_speed(force_refresh: bool = False) -> NetworkStats:
    """Return network statistics, cached for 60 minutes.

    Args:
        force_refresh: When True, discard the cached entry and measure again.

    Returns:
        The cached or freshly measured ``NetworkStats``.
    """
    cache_key = "network_speed"
    lifetime = timedelta(minutes=60)

    _clear_cache_if_refresh(force_refresh, cache_key)
    return _get_cached(key=cache_key, default_factory=measure_network_speed, ttl=lifetime)


def measure_system_stats() -> SystemStats:
    """Sample CPU, memory, uptime and load averages via psutil.

    The CPU sample is taken over a 0.5 second interval.

    Returns:
        A ``SystemStats`` built from the sampled values.
    """
    cpu = psutil.cpu_percent(interval=0.5)
    memory = psutil.virtual_memory().percent
    seconds_up = time.time() - psutil.boot_time()
    # Passed through unchanged; the ignore below covers the mismatch between
    # psutil's return type and the ``list[float]`` field annotation.
    load = psutil.getloadavg()

    return SystemStats(cpu_percent=cpu, memory_percent=memory, uptime=seconds_up, load_avg=load)  # type: ignore


def check_system_stats(force_refresh: bool = False) -> SystemStats:
    """Return system statistics, cached for one minute.

    Args:
        force_refresh: When True, discard the cached entry and sample again.

    Returns:
        The cached or freshly sampled ``SystemStats``.
    """
    cache_key = "system_stats"
    lifetime = timedelta(minutes=1)

    _clear_cache_if_refresh(force_refresh, cache_key)
    return _get_cached(key=cache_key, default_factory=measure_system_stats, ttl=lifetime)


def calculate_health_score(
    disk_stats: DiskStats,
    network_stats: NetworkStats,
    system_stats: SystemStats,
) -> float:
    """Combine disk, network and system measurements into one score.

    The score is the sum of three components, rounded to one decimal:

    * Disk (up to 40): 40 times the free fraction, scaled by a penalty of
      0.2 above 90% used, 0.5 above 80% used, and 1.0 otherwise.
    * Network (up to 30): 15 for latency (full marks at 50 ms or less, zero
      at 200 ms or more, linear in between) plus 15 for download speed
      (linear up to 100). Zero when ``ping`` is not below 9999, the value
      used to mark an unreachable network.
    * System (up to 30): 15 * (2 - (cpu_percent + memory_percent) / 100).

    When the disk is more than 95% full, only the disk component is returned.

    Args:
        disk_stats: Disk measurements; ``percent_used`` is read.
        network_stats: Network measurements; ``ping`` and ``download_speed`` are read.
        system_stats: System measurements; ``cpu_percent`` and ``memory_percent`` are read.

    Returns:
        The health score rounded to one decimal place.
    """
    disk_penalty = 0.2 if disk_stats.percent_used > 90 else (0.5 if disk_stats.percent_used > 80 else 1.0)
    disk_score = 40 * (1 - disk_stats.percent_used / 100) * disk_penalty

    if disk_stats.percent_used > 95:
        # Rounded like the normal path so callers always get one decimal.
        # NOTE(review): disk_score is already below 0.4 here, so the cap of 30
        # never binds; possibly the cap was meant for the total score - confirm.
        return round(min(30, disk_score), 1)

    network_score = 0
    if network_stats.ping < 9999:
        network_score = 15 * max(0, min(1, (200 - network_stats.ping) / 150)) + 15 * min(
            1, network_stats.download_speed / 100
        )

    system_score = 15 * (2 - (system_stats.cpu_percent + system_stats.memory_percent) / 100)
    return round(disk_score + network_score + system_score, 1)


def get_health_status(score: float) -> HealthStatus:
    """Translate a numeric health score into a status label.

    Args:
        score: The health score to classify.

    Returns:
        ``"healthy"`` for scores of at least 75, ``"under_load"`` for scores
        of at least 40, and ``"unhealthy"`` for anything else.
    """
    if score >= 75:
        return "healthy"
    return "under_load" if score >= 40 else "unhealthy"


def _calculate_health_result() -> HealthCheckResult:
    """Gather the individual checks and combine them into one result.

    Each measurement goes through its caching ``check_*`` helper, so values
    may come from the cache rather than a fresh measurement.

    Returns:
        A ``HealthCheckResult`` holding the score, its status label and the
        measurements they were computed from.
    """
    disk = check_disk_space()
    network = check_network_speed()
    system = check_system_stats()

    score = calculate_health_score(disk, network, system)

    return HealthCheckResult(
        status=get_health_status(score),
        score=score,
        disk=disk,
        network=network,
        system=system,
    )


def check_health(force_refresh: bool = False) -> HealthCheckResult:
    """Return the overall health result, cached for five minutes.

    Args:
        force_refresh: When True, discard the cached result and the cached
            disk, network and system measurements it is built from, so
            everything is measured again.

    Returns:
        The cached or freshly computed ``HealthCheckResult``.
    """
    # The aggregate entry plus every measurement entry it is derived from.
    dependent_keys = ("health_result", "disk_space", "network_speed", "system_stats")
    _clear_cache_if_refresh(force_refresh, *dependent_keys)

    return _get_cached(key="health_result", default_factory=_calculate_health_result, ttl=timedelta(minutes=5))
