from __future__ import annotations

from logging import getLogger
from typing import Optional

from fastapi import APIRouter, Depends
from prometheus_client import Counter, Gauge
from sqlmodel import select

import murfey.server.prometheus as prom
from murfey.server.api.auth import validate_instrument_token
from murfey.server.murfey_db import murfey_db
from murfey.util import sanitise
from murfey.util.db import RsyncInstance
from murfey.util.models import RsyncerInfo, RsyncerSkippedFiles

# Module logger, named explicitly rather than via __name__
logger = getLogger(name="murfey.server.api.prometheus")

# Every route registered on this router sits under /prometheus and requires a
# valid instrument token (checked by the validate_instrument_token dependency)
router = APIRouter(
    tags=["Prometheus"],
    prefix="/prometheus",
    dependencies=[Depends(validate_instrument_token)],
)


@router.post("/visits/{visit_name}/increment_rsync_file_count")
def increment_rsync_file_count(
    visit_name: str, rsyncer_info: RsyncerInfo, db=murfey_db
):
    """
    Record newly seen files for an rsyncer, both in its database row and in
    the Prometheus "seen" counters.

    The RsyncInstance is looked up by source, destination and session ID. If
    `.one()` fails (for example because no single row matches), the error is
    logged with traceback and the function returns None without touching the
    database or the counters. On success, `files_counted` is increased by
    `increment_count` and committed. The session is then closed, and the
    `seen_files` and `seen_data_files` counters for this source/visit are
    incremented.
    """
    try:
        matches = db.exec(
            select(RsyncInstance).where(
                RsyncInstance.source == rsyncer_info.source,
                RsyncInstance.destination == rsyncer_info.destination,
                RsyncInstance.session_id == rsyncer_info.session_id,
            )
        )
        rsync_instance = matches.one()
    except Exception:
        logger.error(
            f"Failed to find rsync instance for visit {sanitise(visit_name)} "
            "with the following properties: \n"
            f"{rsyncer_info.model_dump()}",
            exc_info=True,
        )
        return None

    newly_seen = rsyncer_info.increment_count
    source = rsyncer_info.source

    # Persist the running total before updating the in-memory metrics
    rsync_instance.files_counted += newly_seen
    db.add(rsync_instance)
    db.commit()
    db.close()

    prom.seen_files.labels(rsync_source=source, visit=visit_name).inc(newly_seen)
    prom.seen_data_files.labels(rsync_source=source, visit=visit_name).inc(
        rsyncer_info.increment_data_count
    )


@router.post("/visits/{visit_name}/increment_rsync_transferred_files")
def increment_rsync_transferred_files(
    visit_name: str, rsyncer_info: RsyncerInfo, db=murfey_db
):
    """
    Add the number of newly transferred files reported by an rsyncer to the
    `files_transferred` total stored on its RsyncInstance database row.

    The RsyncInstance is looked up by source, destination and session ID. If
    `.one()` fails (for example because no single row matches), the error is
    logged with traceback and None is returned without modifying the
    database. This matches the file-count endpoint and avoids an unhandled
    server error.
    """
    try:
        rsync_instance = db.exec(
            select(RsyncInstance).where(
                RsyncInstance.source == rsyncer_info.source,
                RsyncInstance.destination == rsyncer_info.destination,
                RsyncInstance.session_id == rsyncer_info.session_id,
            )
        ).one()
    except Exception:
        logger.error(
            f"Failed to find rsync instance for visit {sanitise(visit_name)} "
            "with the following properties: \n"
            f"{rsyncer_info.model_dump()}",
            exc_info=True,
        )
        return None
    rsync_instance.files_transferred += rsyncer_info.increment_count
    db.add(rsync_instance)
    db.commit()
    db.close()


@router.post("/visits/{visit_name}/increment_rsync_transferred_files_prometheus")
def increment_rsync_transferred_files_prometheus(
    visit_name: str, rsyncer_info: RsyncerInfo, db=murfey_db
):
    """
    Bump the Prometheus transfer counters for one rsyncer source and visit.

    Four counters are incremented, in this order: transferred files,
    transferred bytes, transferred data files and transferred data bytes. The
    injected database session is not used by this endpoint.
    """
    # All four counters share the same label set
    label_values = {"rsync_source": rsyncer_info.source, "visit": visit_name}
    counter_increments = (
        (prom.transferred_files, rsyncer_info.increment_count),
        (prom.transferred_files_bytes, rsyncer_info.bytes),
        (prom.transferred_data_files, rsyncer_info.increment_data_count),
        (prom.transferred_data_files_bytes, rsyncer_info.data_bytes),
    )
    for counter, amount in counter_increments:
        counter.labels(**label_values).inc(amount)


@router.post("/visits/{visit_name}/increment_rsync_skipped_files_prometheus")
def increment_rsync_skipped_files_prometheus(
    visit_name: str, rsyncer_skipped_files: RsyncerSkippedFiles, db=murfey_db
):
    """
    Increase the Prometheus `skipped_files` counter for the given rsyncer
    source and visit by the reported `increment_count`.

    The injected database session is not used by this endpoint.
    """
    skipped_counter = prom.skipped_files.labels(
        rsync_source=rsyncer_skipped_files.source,
        visit=visit_name,
    )
    skipped_counter.inc(rsyncer_skipped_files.increment_count)


@router.post("/visits/{visit_name}/monitoring/{on}")
def change_monitoring_status(visit_name: str, on: int):
    """
    Set the Prometheus `monitoring_switch` gauge for `visit_name` to `on`.
    """
    # `.labels()` creates the labelled child on first use, so one call both
    # registers the visit and sets its value. The previous standalone
    # `.labels()` call whose result was discarded was redundant.
    prom.monitoring_switch.labels(visit=visit_name).set(on)


@router.get("/metrics/{metric_name}")
def inspect_prometheus_metrics(
    metric_name: str,
):
    """
    A debugging endpoint that returns the current contents of any Prometheus
    gauges and counters that have been set up thus far.

    For a labelled metric, returns a dict keyed by 0, 1, 2, ... (one entry per
    labelled child). Each value maps the metric's label names to that child's
    label values, plus a "value" key holding the child's current value. For an
    unlabelled metric, returns {"value": <current value>}.

    Raises LookupError if `metric_name` is not an attribute of
    murfey.server.prometheus that is a Counter or Gauge.
    """

    # Extract the Prometheus metric defined in the Prometheus module
    metric: Optional[Counter | Gauge] = getattr(prom, metric_name, None)
    # isinstance() is False for None, so this also covers a missing attribute
    if not isinstance(metric, (Counter, Gauge)):
        raise LookupError("No matching metric was found")

    # Unlabelled metrics hold their value directly; only labelled (parent)
    # metrics carry a `_metrics` dict of children
    if not hasattr(metric, "_metrics"):
        return {"value": metric._value.get()}

    # Sync endpoints run in a threadpool, and other requests may call
    # `.labels()` concurrently, which inserts into `_metrics` under
    # `metric._lock`. Snapshot the children under that same lock so the loop
    # below cannot hit "dictionary changed size during iteration".
    with metric._lock:
        children = list(metric._metrics.items())

    # Package contents into dict and return
    results = {}
    for i, (label_values, child) in enumerate(children):
        entry = dict(zip(metric._labelnames, label_values))
        entry["value"] = child._value.get()
        results[i] = entry
    return results
