"""Flask service exposing uma-ase workflows for web clients."""

from __future__ import annotations

import io
import math
import shutil
import subprocess
import sys
import tempfile
import threading
import uuid
import zipfile
from collections import Counter
from contextlib import suppress
from dataclasses import dataclass
from datetime import datetime
from importlib import resources
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, Dict, Iterable, List, Optional
from xml.sax.saxutils import escape

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

try:
    from docx import Document

    DRIVER_DOCX_AVAILABLE = True
except ModuleNotFoundError:
    DRIVER_DOCX_AVAILABLE = False

from flask import Flask, Response, abort, after_this_request, jsonify, request, send_file
from werkzeug.utils import secure_filename

from ase.io import read, write, Trajectory
from ase import units
from ase.geometry import cellpar_to_cell
from ase.md.langevin import Langevin
from ase.md.verlet import VelocityVerlet
from ase.md.nvtberendsen import NVTBerendsen
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary, ZeroRotation

try:  # Optional engines that may not be present in older ASE releases.
    from ase.md.nose_hoover_chain import NoseHooverChainNVT, IsotropicMTKNPT, MTKNPT
except ModuleNotFoundError:  # pragma: no cover - depends on ASE version
    NoseHooverChainNVT = None
    IsotropicMTKNPT = None
    MTKNPT = None

try:
    from ase.md.bussi import Bussi
except ModuleNotFoundError:  # pragma: no cover
    Bussi = None

try:
    from ase.md.andersen import Andersen
except ModuleNotFoundError:  # pragma: no cover
    Andersen = None

try:
    from ase.md.nptberendsen import NPTBerendsen
except ModuleNotFoundError:  # pragma: no cover
    NPTBerendsen = None

try:
    from ase.md.langevinbaoab import LangevinBAOAB
except ModuleNotFoundError:  # pragma: no cover
    LangevinBAOAB = None

try:
    from ase.md.melchionna import MelchionnaNPT
except ModuleNotFoundError:  # pragma: no cover
    MelchionnaNPT = None

from .styled_rmsd_report import generate_report, _write_basic_docx as _write_basic_report_docx
from .utils import extract_xyz_metadata
from .workflows import (
    build_output_paths,
    select_device,
    TorchUnavailable,
    setup_calculated_atoms,
    configure_logging,
)

# Display names for the supported MD engines, keyed by the same identifiers
# that MD_ENGINE_BUILDERS uses. The keys are also what _create_md_dynamics
# lists in its "Unsupported MD engine" error message.
MD_ENGINE_LABELS = {
    "langevin": "Langevin (NVT)",
    "velocity_verlet": "VelocityVerlet (NVE)",
    "nvt_berendsen": "NVT Berendsen",
    "nose_hoover_chain": "Nosé-Hoover Chain (NVT)",
    "bussi": "Bussi (NVT)",
    "andersen": "Andersen (NVT)",
    "langevin_baoab_nvt": "Langevin-Hoover BAOAB (NVT)",
    "npt_berendsen": "NPT Berendsen",
    "isotropic_mtk_npt": "Isotropic MTK (NPT)",
    "mtk_npt": "Full MTK (NPT)",
    "langevin_baoab_npt": "Langevin-Hoover BAOAB (NPT)",
    "melchionna_npt": "Melchionna NPT",
}


def _parse_float_field(
    form,
    field: str,
    default: Optional[float],
    *,
    label: str,
    min_value: Optional[float] = None,
    min_inclusive: bool = True,
    max_value: Optional[float] = None,
    max_inclusive: bool = True,
    required: bool = False,
) -> Optional[float]:
    raw = form.get(field)
    if raw is None or raw.strip() == "":
        if required:
            raise ValueError(f"{label} is required.")
        return float(default) if default is not None else None
    try:
        value = float(raw)
    except (TypeError, ValueError):
        raise ValueError(f"{label} must be a number.") from None
    if min_value is not None:
        if value < min_value or (not min_inclusive and value == min_value):
            comparator = "greater than" if not min_inclusive else "at least"
            raise ValueError(f"{label} must be {comparator} {min_value}.")
    if max_value is not None:
        if value > max_value or (not max_inclusive and value == max_value):
            comparator = "less than" if not max_inclusive else "at most"
            raise ValueError(f"{label} must be {comparator} {max_value}.")
    return value


def _parse_int_field(
    form,
    field: str,
    default: Optional[int],
    *,
    label: str,
    min_value: Optional[int] = None,
    min_inclusive: bool = True,
    required: bool = False,
) -> Optional[int]:
    raw = form.get(field)
    if raw is None or raw.strip() == "":
        if required:
            raise ValueError(f"{label} is required.")
        return int(default) if default is not None else None
    try:
        value = int(raw)
    except (TypeError, ValueError):
        raise ValueError(f"{label} must be an integer.") from None
    if min_value is not None:
        if value < min_value or (not min_inclusive and value == min_value):
            comparator = "greater than" if not min_inclusive else "at least"
            raise ValueError(f"{label} must be {comparator} {min_value}.")
    return value


def _parse_bool_field(form, field: str, default: bool = False) -> bool:
    raw = form.get(field)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")


def _missing_engine_builder(engine_key: str, module_hint: str):
    def _builder(*_):
        raise RuntimeError(
            f"The '{engine_key}' MD engine requires {module_hint}, but it is not available in the installed ASE. "
            "Update ASE or choose a different engine."
        )

    return _builder


def _build_md_builder_params(md_options: Dict[str, Any]) -> SimpleNamespace:
    """Convert raw MD options into the namespace consumed by MD_ENGINE_BUILDERS.

    Options whose key ends in ``_fs`` are multiplied by ``units.fs``; most of
    them are first clamped to at least ``1e-12``. The two Melchionna times are
    clamped at zero instead and reported as ``None`` when not positive.
    Missing keys fall back to the defaults written below.
    """
    option = md_options.get
    fs = units.fs

    def clamped_fs(key: str, fallback: float, floor: float = 1e-12):
        # Clamp the raw femtosecond value, then convert to ASE time units.
        return max(floor, option(key, fallback)) * fs

    # Kept in femtoseconds because it doubles as the default for two
    # thermostat-specific damping times.
    relax_fs = max(1e-12, option("relax_fs", 100.0))
    pressure_bar = option("pressure_bar", 1.0)

    compressibility_bar = option("compressibility_bar", 0.0)
    if compressibility_bar and compressibility_bar > 0:
        compressibility_au = compressibility_bar / units.bar
    else:
        compressibility_au = None

    melchionna_ttime = clamped_fs("melchionna_ttime_fs", 25.0, floor=0.0)
    melchionna_ptime = clamped_fs("melchionna_ptime_fs", 75.0, floor=0.0)

    return SimpleNamespace(
        temperature=option("temperature", 300.0),
        timestep=option("timestep_fs", 0.5) * fs,
        friction=max(0.0, option("friction", 0.0)) / fs,
        relax=relax_fs * fs,
        pressure_bar=pressure_bar,
        pressure_au=pressure_bar * units.bar,
        compressibility_au=compressibility_au,
        barostat_relax=clamped_fs("barostat_relax_fs", 1000.0),
        nose_tdamp=clamped_fs("nose_tdamp_fs", relax_fs),
        nose_tchain=int(option("nose_tchain", 3)),
        nose_tloop=int(option("nose_tloop", 1)),
        bussi_taut=clamped_fs("bussi_taut_fs", relax_fs),
        andersen_prob=float(option("andersen_prob", 0.0)),
        mtk_tdamp=clamped_fs("mtk_tdamp_fs", 100.0),
        mtk_pdamp=clamped_fs("mtk_pdamp_fs", 1000.0),
        mtk_tchain=int(option("mtk_tchain", 3)),
        mtk_pchain=int(option("mtk_pchain", 3)),
        mtk_tloop=int(option("mtk_tloop", 1)),
        mtk_ploop=int(option("mtk_ploop", 1)),
        baoab_ttau=clamped_fs("baoab_ttau_fs", 50.0),
        baoab_ptau=clamped_fs("baoab_ptau_fs", 1000.0),
        baoab_pmass=option("baoab_pmass"),
        baoab_pmass_factor=option("baoab_pmass_factor", 1.0),
        baoab_hydrostatic=bool(option("baoab_hydrostatic")),
        melchionna_ttime=melchionna_ttime if melchionna_ttime > 0 else None,
        melchionna_ptime=melchionna_ptime if melchionna_ptime > 0 else None,
        melchionna_bulk=option("melchionna_bulk_gpa", 0.0) * units.GPa,
        melchionna_mask=option("melchionna_mask"),
    )


# Factories that construct the ASE dynamics object for each engine key. Each
# entry is invoked as ``builder(atoms, params)`` (see _create_md_dynamics),
# where ``params`` is the namespace returned by _build_md_builder_params.
# The three engines below are imported unconditionally at the top of the
# module, so they are registered directly; optional engines are added after.
MD_ENGINE_BUILDERS: Dict[str, Callable[[Any, SimpleNamespace], Any]] = {
    "langevin": lambda atoms, p: Langevin(
        atoms,
        p.timestep,
        temperature_K=p.temperature,
        friction=p.friction,
    ),
    "velocity_verlet": lambda atoms, p: VelocityVerlet(atoms, p.timestep),
    "nvt_berendsen": lambda atoms, p: NVTBerendsen(
        atoms,
        p.timestep,
        temperature_K=p.temperature,
        taut=p.relax,
    ),
}

# Optional engines: register the real factory when the import at the top of
# the module succeeded, otherwise a stub that raises a descriptive error.
if Bussi is None:
    MD_ENGINE_BUILDERS["bussi"] = _missing_engine_builder("bussi", "ase.md.bussi.Bussi")
else:
    MD_ENGINE_BUILDERS["bussi"] = lambda atoms, cfg: Bussi(
        atoms,
        cfg.timestep,
        temperature_K=cfg.temperature,
        taut=cfg.bussi_taut,
    )

if Andersen is None:
    MD_ENGINE_BUILDERS["andersen"] = _missing_engine_builder("andersen", "ase.md.andersen.Andersen")
else:
    MD_ENGINE_BUILDERS["andersen"] = lambda atoms, cfg: Andersen(
        atoms,
        cfg.timestep,
        temperature_K=cfg.temperature,
        andersen_prob=cfg.andersen_prob,
    )

# NPT Berendsen: thermostat and barostat relaxation times come from the
# generic ``relax`` / ``barostat_relax`` parameters.
if NPTBerendsen is None:
    MD_ENGINE_BUILDERS["npt_berendsen"] = _missing_engine_builder(
        "npt_berendsen", "ase.md.nptberendsen.NPTBerendsen"
    )
else:
    MD_ENGINE_BUILDERS["npt_berendsen"] = lambda atoms, cfg: NPTBerendsen(
        atoms,
        cfg.timestep,
        temperature_K=cfg.temperature,
        pressure_au=cfg.pressure_au,
        taut=cfg.relax,
        taup=cfg.barostat_relax,
        compressibility_au=cfg.compressibility_au,
    )

# Nosé-Hoover chain thermostat (NVT); falls back to an error-raising stub
# when the installed ASE does not provide the class.
if NoseHooverChainNVT is None:
    MD_ENGINE_BUILDERS["nose_hoover_chain"] = _missing_engine_builder(
        "nose_hoover_chain", "ase.md.nose_hoover_chain.NoseHooverChainNVT"
    )
else:
    MD_ENGINE_BUILDERS["nose_hoover_chain"] = lambda atoms, cfg: NoseHooverChainNVT(
        atoms,
        cfg.timestep,
        temperature_K=cfg.temperature,
        tdamp=cfg.nose_tdamp,
        tchain=cfg.nose_tchain,
        tloop=cfg.nose_tloop,
    )

# The two MTK (NPT) engines take the same set of keyword arguments and differ
# only in the ASE class that is instantiated.
if IsotropicMTKNPT is None:
    MD_ENGINE_BUILDERS["isotropic_mtk_npt"] = _missing_engine_builder(
        "isotropic_mtk_npt", "ase.md.nose_hoover_chain.IsotropicMTKNPT"
    )
else:
    MD_ENGINE_BUILDERS["isotropic_mtk_npt"] = lambda atoms, cfg: IsotropicMTKNPT(
        atoms,
        cfg.timestep,
        temperature_K=cfg.temperature,
        pressure_au=cfg.pressure_au,
        tdamp=cfg.mtk_tdamp,
        pdamp=cfg.mtk_pdamp,
        tchain=cfg.mtk_tchain,
        pchain=cfg.mtk_pchain,
        tloop=cfg.mtk_tloop,
        ploop=cfg.mtk_ploop,
    )

if MTKNPT is None:
    MD_ENGINE_BUILDERS["mtk_npt"] = _missing_engine_builder(
        "mtk_npt", "ase.md.nose_hoover_chain.MTKNPT"
    )
else:
    MD_ENGINE_BUILDERS["mtk_npt"] = lambda atoms, cfg: MTKNPT(
        atoms,
        cfg.timestep,
        temperature_K=cfg.temperature,
        pressure_au=cfg.pressure_au,
        tdamp=cfg.mtk_tdamp,
        pdamp=cfg.mtk_pdamp,
        tchain=cfg.mtk_tchain,
        pchain=cfg.mtk_pchain,
        tloop=cfg.mtk_tloop,
        ploop=cfg.mtk_ploop,
    )

# LangevinBAOAB backs two engine keys: the NVT variant passes only the
# thermostat time constant, the NPT variant additionally passes the stress
# and barostat settings.
if LangevinBAOAB is None:
    MD_ENGINE_BUILDERS["langevin_baoab_nvt"] = _missing_engine_builder(
        "langevin_baoab_nvt", "ase.md.langevinbaoab.LangevinBAOAB"
    )
else:
    MD_ENGINE_BUILDERS["langevin_baoab_nvt"] = lambda atoms, cfg: LangevinBAOAB(
        atoms,
        cfg.timestep,
        temperature_K=cfg.temperature,
        T_tau=cfg.baoab_ttau,
    )

# NOTE(review): the pressure is negated before being passed as a scalar
# ``externalstress`` -- confirm this matches the sign convention that
# LangevinBAOAB expects for scalar input in the installed ASE.
if LangevinBAOAB is None:
    MD_ENGINE_BUILDERS["langevin_baoab_npt"] = _missing_engine_builder(
        "langevin_baoab_npt", "ase.md.langevinbaoab.LangevinBAOAB"
    )
else:
    MD_ENGINE_BUILDERS["langevin_baoab_npt"] = lambda atoms, cfg: LangevinBAOAB(
        atoms,
        cfg.timestep,
        temperature_K=cfg.temperature,
        externalstress=-cfg.pressure_au,
        hydrostatic=cfg.baoab_hydrostatic,
        T_tau=cfg.baoab_ttau,
        P_tau=cfg.baoab_ptau,
        # A falsy mass (None, 0) is forwarded as None.
        P_mass=cfg.baoab_pmass if cfg.baoab_pmass else None,
        P_mass_factor=cfg.baoab_pmass_factor,
    )

# Melchionna NPT: ``pfactor`` is ptime**2 times the bulk modulus, or None when
# no barostat time is configured.
# NOTE(review): the pressure is negated before being passed as a scalar
# ``externalstress``. ASE's NPT documentation describes a scalar as a
# pressure (positive in compression) -- confirm the sign is intended.
if MelchionnaNPT is None:
    MD_ENGINE_BUILDERS["melchionna_npt"] = _missing_engine_builder(
        "melchionna_npt", "ase.md.melchionna.MelchionnaNPT"
    )
else:
    MD_ENGINE_BUILDERS["melchionna_npt"] = lambda atoms, cfg: MelchionnaNPT(
        atoms,
        cfg.timestep,
        temperature_K=cfg.temperature,
        externalstress=-cfg.pressure_au,
        ttime=cfg.melchionna_ttime,
        pfactor=(
            (cfg.melchionna_ptime**2 * (cfg.melchionna_bulk or 0.0))
            if cfg.melchionna_ptime
            else None
        ),
        mask=cfg.melchionna_mask,
    )

# File name of the single-page frontend that index() serves from the
# package's "static" directory.
STATIC_HTML = "uma-ase.html"
app = Flask(__name__)
# Default results location; setdefault() leaves an already-present value alone.
app.config.setdefault("UMA_RESULTS_DIR", Path.home() / ".uma_ase" / "results")
# NOTE: resolved once at import time, so changing UMA_RESULTS_DIR afterwards
# does not relocate the analyze-report directory. The directory is created as
# a side effect of importing this module.
ANALYZE_REPORT_ROOT = Path(app.config["UMA_RESULTS_DIR"]) / "analyze_reports"
ANALYZE_REPORT_ROOT.mkdir(parents=True, exist_ok=True)


@dataclass
class JobRecord:
    """Mutable bookkeeping for one submitted job, as stored in ``JOBS``."""

    # Identity and working directory of the job.
    job_id: str
    job_dir: Path
    # Numeric settings parsed from the submission form.
    charge: int
    spin: int
    grad: float
    iterations: int
    run_types: List[str]
    # Lifecycle state; a freshly created record starts as "running".
    status: str = "running"
    message: Optional[str] = None
    # Output artefacts and their download URLs; all unset on creation.
    # presumably filled in by the worker once outputs exist -- TODO confirm
    log_path: Optional[Path] = None
    traj_path: Optional[Path] = None
    opt_path: Optional[Path] = None
    log_url: Optional[str] = None
    traj_url: Optional[str] = None
    opt_url: Optional[str] = None
    # Sanitized client-supplied relative path (set by run_job for folder uploads).
    relative_path: Optional[Path] = None
    # Kind of job; run_job uses "workflow".
    job_kind: str = "workflow"
    # MD-specific settings and outputs; unused (None) for workflow jobs
    # created by run_job.
    md_options: Optional[Dict[str, Any]] = None
    md_multi_xyz: Optional[Path] = None
    md_multi_xyz_url: Optional[str] = None
    # Execution handles: cancellation flag, worker thread, and an optional
    # child process.
    cancel_event: Optional[threading.Event] = None
    worker: Optional[threading.Thread] = None
    process: Optional[subprocess.Popen] = None


# In-memory registry of submitted jobs keyed by job_id. Request handlers take
# JOB_LOCK around reads and writes of the dict (see _get_job and run_job).
JOBS: Dict[str, JobRecord] = {}
JOB_LOCK = threading.Lock()


class JobCancelled(Exception):
    """Raised internally when a running job is cancelled by the user.

    Carries no state beyond the standard ``Exception`` arguments.
    """


def _get_job(job_id: str) -> JobRecord:
    """Look up a job in ``JOBS`` under ``JOB_LOCK``; abort with 404 if unknown."""
    with JOB_LOCK:
        found = JOBS.get(job_id)
    if found is not None:
        return found
    abort(404)


def _build_cli_args(
    input_path: Path,
    run_types: Iterable[str],
    charge: str,
    spin: str,
    optimizer: str,
    grad: str,
    iterations: str,
    temperature: str,
    pressure: str,
    mlff_checkpoint: str | None,
    mlff_task: str | None,
) -> List[str]:
    args: List[str] = [
        "-input",
        str(input_path),
        "-chg",
        charge,
        "-spin",
        spin,
        "-optimizer",
        optimizer,
        "-grad",
        grad,
        "-iter",
        iterations,
        "-temp",
        temperature,
        "-press",
        pressure,
    ]
    if run_types:
        args.extend(["-run-type", *run_types])
    if mlff_checkpoint:
        args.extend(["-mlff-chk", mlff_checkpoint])
    if mlff_task:
        args.extend(["-mlff-task", mlff_task])
    return args


def _initialise_md_velocities(atoms, temperature: float, logger):
    """Assign initial velocities to ``atoms`` for an MD run.

    Velocities are drawn with ``MaxwellBoltzmannDistribution`` at
    ``temperature`` (K), after which ``Stationary`` and ``ZeroRotation`` are
    applied. Nothing is done, apart from logging, when the temperature is not
    positive.
    """
    skip_sampling = temperature <= 0
    if skip_sampling:
        logger.info("Temperature %.2f K <= 0. Skipping velocity initialisation.", temperature)
        return

    MaxwellBoltzmannDistribution(atoms, temperature_K=temperature)
    for adjust in (Stationary, ZeroRotation):
        adjust(atoms)
    logger.info("Initial velocities sampled at %.2f K.", temperature)


def _create_md_dynamics(engine: str, atoms, *, md_options: Dict[str, Any]):
    """Build the dynamics object for ``engine`` from the registered builders.

    The builder parameters are derived from ``md_options`` before the engine
    key is looked up.

    Raises:
        ValueError: If ``engine`` has no entry in ``MD_ENGINE_BUILDERS``.
    """
    settings = _build_md_builder_params(md_options)
    factory = MD_ENGINE_BUILDERS.get(engine)
    if factory is not None:
        return factory(atoms, settings)

    available = ", ".join(MD_ENGINE_LABELS)
    raise ValueError(f"Unsupported MD engine '{engine}'. Available: {available}")


def _collect_log(temp_dir: Path) -> str:
    logs = sorted(temp_dir.glob("*.log"), key=lambda p: p.stat().st_mtime, reverse=True)
    if not logs:
        return "No log file generated."
    return logs[0].read_text(encoding="utf-8", errors="replace")


def _safe_save_upload(storage, base_dir: Path) -> Path:
    """Save an uploaded file beneath ``base_dir`` under a sanitized relative path.

    Every component of the client-supplied name is passed through
    ``secure_filename``; ``.``/``..`` components and components that sanitize
    to nothing are dropped, so the result cannot leave ``base_dir``.

    Args:
        storage: Upload object exposing ``filename`` (or ``name``) and
            ``save(path)``.
        base_dir: Directory that receives the file; intermediate directories
            are created as needed.

    Returns:
        The path the upload was written to.

    Raises:
        ValueError: If the upload has no name, or its final component is
            empty after sanitizing (previously this made the destination a
            directory and ``save`` failed with an uncaught ``OSError``).
    """
    filename = storage.filename or getattr(storage, "name", None)
    if not filename:
        raise ValueError("Uploaded file missing name.")
    sanitized = [
        secure_filename(part)
        for part in Path(filename).parts
        if part not in ("", ".", "..")
    ]
    # secure_filename() may return "" (e.g. for names made only of stripped
    # characters). The last component is the file name, so it must survive.
    if not sanitized or not sanitized[-1]:
        raise ValueError("Uploaded file name is not usable.")
    relative_parts = [part for part in sanitized if part]
    destination = base_dir.joinpath(*relative_parts)
    destination.parent.mkdir(parents=True, exist_ok=True)
    storage.save(destination)
    return destination


def _build_analyze_url(token: str, path: Path | None) -> str | None:
    if not path:
        return None
    return f"/api/uma-ase/analyze/{token}/{path.name}"


def _sanitize_relative_path(relpath: str | None) -> Path | None:
    """Turn a client-supplied relative path into a safe ``Path``.

    ``.``/``..`` components are skipped, the rest go through
    ``secure_filename``, and components that end up empty are discarded.
    Returns None when the input is empty or nothing usable remains.
    """
    if not relpath:
        return None

    kept = []
    for component in Path(relpath).parts:
        if component in ("", ".", ".."):
            continue
        safe = secure_filename(component)
        if safe:
            kept.append(safe)

    return Path(*kept) if kept else None


def _write_driver_pdf(text: str, pdf_path: Path) -> Path | None:
    """Render ``text`` as a monospace, A4-sized, multi-page PDF.

    The text is split into pages of 55 lines. This is a best-effort export:
    on any failure the partial file is removed and None is returned instead
    of raising.

    Args:
        text: Report text to render.
        pdf_path: Destination file.

    Returns:
        ``pdf_path`` on success, otherwise None.
    """
    try:
        lines = text.splitlines() or [""]
        lines_per_page = 55
        with PdfPages(pdf_path) as pdf:
            for start in range(0, len(lines), lines_per_page):
                chunk = lines[start : start + lines_per_page]
                fig = plt.figure(figsize=(8.27, 11.69))
                try:
                    fig.patch.set_facecolor("white")
                    plt.axis("off")
                    fig.text(
                        0.03,
                        0.97,
                        "\n".join(chunk),
                        family="monospace",
                        fontsize=8,
                        va="top",
                        ha="left",
                    )
                    pdf.savefig(fig, bbox_inches="tight")
                finally:
                    # pyplot keeps every open figure in a global registry;
                    # close it even when rendering/saving fails so repeated
                    # failures do not accumulate figures.
                    plt.close(fig)
        return pdf_path
    except Exception:
        # Deliberately broad: PDF output is optional and must not fail the
        # whole analysis.
        pdf_path.unlink(missing_ok=True)
        return None


def _write_driver_latex(text: str, tex_path: Path) -> Path | None:
    try:
        latex = "\n".join(
            [
                r"\documentclass{article}",
                r"\usepackage[margin=1in]{geometry}",
                r"\usepackage{fancyvrb}",
                r"\begin{document}",
                r"\section*{Compiled RMSD Results}",
                r"\begin{Verbatim}[fontsize=\small]",
                text,
                r"\end{Verbatim}",
                r"\end{document}",
                "",
            ]
        )
        tex_path.write_text(latex, encoding="utf-8")
        return tex_path
    except OSError:
        return None


def _write_driver_docx(text: str, docx_path: Path) -> Path | None:
    """Write ``text`` to ``docx_path`` as a Word document, one paragraph per line.

    python-docx is used when it was importable; if it is missing or fails for
    any reason, the basic writer from ``styled_rmsd_report`` is tried instead.
    Returns the written path, or None (after removing any leftover file) when
    both approaches fail.
    """
    if DRIVER_DOCX_AVAILABLE:
        # Any python-docx failure silently falls through to the basic writer.
        with suppress(Exception):
            document = Document()
            for paragraph in text.splitlines():
                document.add_paragraph(paragraph)
            document.save(docx_path)
            return docx_path

    written = _write_basic_report_docx(text.splitlines() or [""], docx_path)
    if not written:
        if docx_path.exists():
            docx_path.unlink(missing_ok=True)  # type: ignore[arg-type]
        return None
    return written


def _run_driver_analysis(xyz_root: Path, output_dir: Path) -> Dict[str, Path | str | int]:
    xyz_files = [
        path
        for path in xyz_root.rglob("*")
        if path.is_file() and path.suffix.lower() == ".xyz"
    ]
    if not xyz_files:
        raise ValueError("Upload at least one XYZ file in the selected folder.")

    def _is_opt_variant(path: Path) -> bool:
        stem = path.stem.lower()
        return "opt" in stem or "sp-opt" in stem

    def _matches_base(base: Path, candidate: Path) -> bool:
        base_key = base.stem.lower()
        cand_key = candidate.stem.lower()
        if cand_key == base_key:
            return False
        prefix = f"{base_key}-"
        if not cand_key.startswith(prefix):
            return False
        suffix = cand_key[len(prefix) :]
        return "opt" in suffix or "sp-opt" in suffix

    by_parent: Dict[Path, Dict[str, Path]] = {}
    for path in xyz_files:
        by_parent.setdefault(path.parent, {})[path.name.lower()] = path

    file_pairs: List[tuple[Path, Path]] = []
    for folder_files in by_parent.values():
        bases = {name: path for name, path in folder_files.items() if not _is_opt_variant(path)}
        variants = {name: path for name, path in folder_files.items() if _is_opt_variant(path)}
        for base_name, base_path in bases.items():
            prefix = f"{base_name}"
            matches = [
                variants[name]
                for name in variants
                if name.startswith(f"{base_name[:-4]}-") and _matches_base(base_path, variants[name])
            ]
            for match in matches:
                file_pairs.append((base_path, match))

    if not file_pairs:
        raise ValueError("No XYZ/-geoopt-OPT pairs found. Ensure optimized counterparts are present.")

    scripts_root = resources.files("uma_ase").joinpath("scripts-to-share_v2")
    with resources.as_file(scripts_root) as resolved_root:
        root_path = Path(resolved_root)
        rmsd_script = root_path / "rmsd.py"
        hetero_script = root_path / "rmsd_dist-angles_ranking_hetero-cutoff.py"
        if not rmsd_script.exists() or not hetero_script.exists():
            raise RuntimeError("Driver scripts are unavailable in this installation.")

        def _run_tool(tool_path: Path, file_a: Path, file_b: Path) -> str:
            result = subprocess.run(
                [sys.executable, str(tool_path), str(file_a), str(file_b)],
                capture_output=True,
                text=True,
                cwd=str(root_path),
            )
            if result.returncode != 0:
                stderr = (result.stderr or "").strip()
                stdout = (result.stdout or "").strip()
                details = stderr or stdout or f"Exited with status {result.returncode}"
                return f"Error running {tool_path.name}: {details}\n"
            return result.stdout

        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / "compiled_rmsd_results.txt"
        with output_path.open("w", encoding="utf-8", buffering=1024 * 1024) as handle:
            handle.write("=== RMSD Analysis Results ===\n\n")
            for index, (file_a, file_b) in enumerate(file_pairs, start=1):
                handle.write(f"[{index}] File pair: {file_a.name}  vs  {file_b.name}\n")
                handle.write("-" * 60 + "\n")
                handle.write("--- rmsd.py output ---\n")
                handle.write(_run_tool(rmsd_script, file_a, file_b))
                handle.write("\n--- rmsd_dist-angles_ranking_hetero-cutoff.py output ---\n")
                handle.write(_run_tool(hetero_script, file_a, file_b))
                handle.write("\n" + "=" * 80 + "\n\n")
            handle.flush()

    preview_text = output_path.read_text(encoding="utf-8", errors="replace")
    preview_limit = 200_000
    trimmed_preview = (
        preview_text if len(preview_text) <= preview_limit else f"{preview_text[:preview_limit]}\n...\n"
    )
    pdf_path = _write_driver_pdf(preview_text, output_dir / "compiled_rmsd_results.pdf")
    tex_path = _write_driver_latex(preview_text, output_dir / "compiled_rmsd_results.tex")
    docx_path = _write_driver_docx(preview_text, output_dir / "compiled_rmsd_results.docx")

    return {
        "text_path": output_path,
        "pdf_path": pdf_path,
        "latex_path": tex_path,
        "docx_path": docx_path,
        "pairs": len(file_pairs),
        "preview": trimmed_preview,
    }


@app.route("/")
def index() -> Response:
    """Return the packaged single-page frontend as an HTML response."""
    package_root = resources.files("uma_ase")
    page_bytes = package_root.joinpath("static", STATIC_HTML).read_bytes()
    return Response(page_bytes, mimetype="text/html")


@app.route("/assets/<path:asset>")
def serve_static_asset(asset: str):
    """Serve packaged static assets (e.g. logo.svg) referenced from the frontend.

    ``asset`` comes straight from the request URL, so it is validated before
    being joined onto the package's ``static`` directory: empty and ``.``
    segments are dropped, and any ``..`` segment (or one containing ``:``,
    which could be a drive specifier) is answered with 404. Without this
    check a path such as ``../__init__.py`` would be served from outside
    ``static``.
    """
    segments = [
        part
        for part in asset.replace("\\", "/").split("/")
        if part not in ("", ".")
    ]
    if not segments or any(part == ".." or ":" in part for part in segments):
        abort(404)
    candidate = resources.files("uma_ase").joinpath("static", "/".join(segments))
    if not candidate.is_file():
        abort(404)
    with resources.as_file(candidate) as fs_path:
        return send_file(fs_path)


@app.route("/assets/")
def serve_static_root():
    """Answer probes of the bare asset root (e.g. from JSmol) with 204 No Content."""
    no_content = Response(status=204)
    return no_content


@app.route("/api/uma-ase/analyze", methods=["POST"])
def analyze_logs():
    """Generate a report from uploaded log files and return download URLs.

    Uploads from the ``files`` form field are saved into a temporary
    directory and passed to ``generate_report``, which writes into a fresh
    token-named directory under ``ANALYZE_REPORT_ROOT``. Responds with 400
    for missing/invalid uploads or a ``ValueError`` from the generator, and
    500 when no PDF was produced; the output directory is removed in both
    failure cases.
    """

    def _failure(message: str, status: int):
        return jsonify({"status": "error", "message": message}), status

    uploads = request.files.getlist("files")
    if not uploads:
        return _failure("Upload at least one log file or folder.", 400)

    with tempfile.TemporaryDirectory() as scratch:
        logs_root = Path(scratch) / "logs"
        logs_root.mkdir(parents=True, exist_ok=True)

        accepted = 0
        for upload in uploads:
            if upload and upload.filename:
                # Uploads with unusable names are skipped, not fatal.
                with suppress(ValueError):
                    _safe_save_upload(upload, logs_root)
                    accepted += 1

        if not accepted:
            return _failure("No valid files uploaded.", 400)

        token = secure_filename(uuid.uuid4().hex)
        output_dir = ANALYZE_REPORT_ROOT / token
        output_dir.mkdir(parents=True, exist_ok=True)
        try:
            outputs = generate_report(logs_root, output_dir)
        except ValueError as exc:
            shutil.rmtree(output_dir, ignore_errors=True)
            return _failure(str(exc), 400)

        pdf_path = outputs.get("pdf")
        if not pdf_path or not pdf_path.exists():
            shutil.rmtree(output_dir, ignore_errors=True)
            return _failure("Report generation failed.", 500)

        return jsonify(
            {
                "status": "ok",
                "token": token,
                "pdf_url": _build_analyze_url(token, pdf_path),
                "latex_url": _build_analyze_url(token, outputs.get("latex")),
                "docx_url": _build_analyze_url(token, outputs.get("docx")),
            }
        )


@app.route("/api/uma-ase/analyze/driver", methods=["POST"])
def analyze_xyz_pairs():
    """Run the RMSD driver analysis on uploaded XYZ files.

    Uploads from the ``files`` form field are saved into a temporary
    directory and handed to ``_run_driver_analysis``, which writes into a
    fresh ``drv-`` token directory under ``ANALYZE_REPORT_ROOT``. A
    ``ValueError`` from the analysis yields 400 and a ``RuntimeError`` 500;
    the output directory is removed in both cases. On success the response
    carries the pair count, a text preview and download URLs.
    """

    def _failure(message: str, status: int):
        return jsonify({"status": "error", "message": message}), status

    uploads = request.files.getlist("files")
    if not uploads:
        return _failure("Upload at least one XYZ file or folder.", 400)

    with tempfile.TemporaryDirectory() as scratch:
        xyz_root = Path(scratch) / "xyz"
        xyz_root.mkdir(parents=True, exist_ok=True)

        accepted = 0
        for upload in uploads:
            if upload and upload.filename:
                # Uploads with unusable names are skipped, not fatal.
                with suppress(ValueError):
                    _safe_save_upload(upload, xyz_root)
                    accepted += 1

        if not accepted:
            return _failure("No valid files uploaded.", 400)

        token = secure_filename(f"drv-{uuid.uuid4().hex}")
        output_dir = ANALYZE_REPORT_ROOT / token
        output_dir.mkdir(parents=True, exist_ok=True)
        try:
            result = _run_driver_analysis(xyz_root, output_dir)
        except (ValueError, RuntimeError) as exc:
            shutil.rmtree(output_dir, ignore_errors=True)
            status = 400 if isinstance(exc, ValueError) else 500
            return _failure(str(exc), status)

    if result.get("pairs"):
        message = f"Processed {result.get('pairs', 0)} file pairs."
    else:
        message = "Analysis complete."

    return jsonify(
        {
            "status": "ok",
            "token": token,
            "pairs": result.get("pairs", 0),
            "preview": result.get("preview"),
            "results_url": _build_analyze_url(token, result.get("text_path")),
            "pdf_url": _build_analyze_url(token, result.get("pdf_path")),
            "latex_url": _build_analyze_url(token, result.get("latex_path")),
            "docx_url": _build_analyze_url(token, result.get("docx_path")),
            "message": message,
        }
    )


@app.route("/api/uma-ase/analyze/<token>/<path:filename>")
def download_analyze_file(token: str, filename: str):
    """Send a generated report file belonging to ``token``.

    The token is sanitized and must name an existing directory under
    ``ANALYZE_REPORT_ROOT``; ``filename`` is resolved and must stay inside
    that directory. Every failure is answered with 404.
    """
    safe_token = secure_filename(token)
    if not safe_token:
        # secure_filename() can reduce a token to "", which would make the
        # report root itself the base directory.
        abort(404)
    base_dir = (ANALYZE_REPORT_ROOT / safe_token).resolve()
    if not base_dir.is_dir():
        abort(404)
    target = (base_dir / filename).resolve()
    try:
        # Raises ValueError when the resolved target escapes base_dir.
        target.relative_to(base_dir)
    except ValueError:
        abort(404)
    if not target.is_file():
        abort(404)
    return send_file(target)


def _claim_unique_job_dir(base_dir: Path) -> Path:
    """Create and return ``base_dir``, or the first free ``<name>_<n>`` sibling.

    The directory is claimed with a plain ``mkdir`` so that two concurrent
    requests can never be handed the same directory.
    """
    base_dir.parent.mkdir(parents=True, exist_ok=True)
    candidate = base_dir
    attempt = 1
    while True:
        try:
            candidate.mkdir()
        except FileExistsError:
            candidate = base_dir.parent / f"{base_dir.name}_{attempt}"
            attempt += 1
        else:
            return candidate


@app.route("/api/uma-ase/run", methods=["POST"])
def run_job():
    """Validate a workflow submission, store its geometry and start a worker.

    Form fields ``charge``, ``spin``, ``grad`` and ``iter`` are validated
    here; ``optimizer``, ``temperature``, ``pressure``, ``run_type``,
    ``mlff_checkpoint`` and ``mlff_task`` are forwarded to the worker as
    given (with defaults). The job directory is ``<results>/<job_id>`` for a
    plain upload, or a unique directory under ``<results>/multi_runs`` when
    ``relative_path``/``source_path`` or ``multi_root`` is supplied.

    Returns:
        JSON ``{"job_id": ...}`` on success, or a 400 error payload when
        validation fails.
    """

    def _bad_request(message: str):
        return jsonify({"status": "error", "message": message}), 400

    geometry = request.files.get("geometry")
    if geometry is None or geometry.filename == "":
        return _bad_request("Geometry file is required.")

    try:
        charge_val = int(request.form.get("charge", "0"))
    except (TypeError, ValueError):
        return _bad_request("Charge must be an integer.")

    try:
        spin_val = int(request.form.get("spin", "1"))
    except (TypeError, ValueError):
        return _bad_request("Spin multiplicity must be an integer.")

    try:
        grad_val = float(request.form.get("grad", "0.01"))
    except (TypeError, ValueError):
        return _bad_request("Grad must be a number.")
    # float() accepts "nan"/"inf", and NaN would pass the sign check below.
    if not math.isfinite(grad_val):
        return _bad_request("Grad must be a number.")
    if grad_val <= 0:
        return _bad_request("Grad must be positive.")

    try:
        iter_val = int(request.form.get("iter", "250"))
    except (TypeError, ValueError):
        return _bad_request("Max iterations must be an integer.")
    if iter_val <= 0:
        return _bad_request("Max iterations must be positive.")

    optimizer = request.form.get("optimizer", "LBFGS")
    temperature = request.form.get("temperature", "298.15")
    pressure = request.form.get("pressure", "101325.0")
    run_types_raw = request.form.get("run_type", "sp").split()
    run_types = [item.lower() for item in run_types_raw] or ["sp"]
    mlff_checkpoint_raw = request.form.get("mlff_checkpoint", "uma-s-1p1")
    mlff_checkpoint = mlff_checkpoint_raw.strip() or "uma-s-1p1"
    mlff_task_raw = request.form.get("mlff_task", "omol")
    mlff_task = mlff_task_raw.strip() or "omol"

    results_root = Path(app.config["UMA_RESULTS_DIR"])
    results_root.mkdir(parents=True, exist_ok=True)

    job_id = f"{datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid.uuid4().hex[:6]}"
    filename = secure_filename(geometry.filename) or "input.xyz"
    relative_field = request.form.get("relative_path") or request.form.get("source_path")
    sanitized_relative = _sanitize_relative_path(relative_field)
    folder_root_raw = request.form.get("multi_root")
    folder_root = secure_filename(folder_root_raw) if folder_root_raw else None
    multi_root_dir = results_root / "multi_runs"

    job_dir: Path
    if sanitized_relative or folder_root:
        # Folder uploads mirror the client's layout below multi_runs/.
        if sanitized_relative:
            rel_parent = sanitized_relative.parent
            base_name = sanitized_relative.stem or Path(filename).stem or "geometry"
        else:
            rel_parent = Path()
            base_name = Path(filename).stem or "geometry"
        if folder_root:
            rel_parent = Path(folder_root) / rel_parent
        job_dir = _claim_unique_job_dir(multi_root_dir.joinpath(rel_parent, base_name))
    else:
        job_dir = results_root / job_id
        job_dir.mkdir(parents=True, exist_ok=True)

    input_path = job_dir / filename
    geometry.save(input_path)

    record = JobRecord(
        job_id=job_id,
        job_dir=job_dir,
        charge=charge_val,
        spin=spin_val,
        grad=grad_val,
        iterations=iter_val,
        run_types=run_types,
        relative_path=sanitized_relative,
    )
    record.cancel_event = threading.Event()
    record.job_kind = "workflow"

    with JOB_LOCK:
        JOBS[job_id] = record

    worker = threading.Thread(
        target=_execute_job,
        args=(
            record,
            filename,
            optimizer,
            temperature,
            pressure,
            mlff_checkpoint,
            mlff_task,
            sanitized_relative,
        ),
        daemon=True,
    )
    record.worker = worker
    worker.start()

    return jsonify({"job_id": job_id})



@app.route("/api/uma-ase/md/run", methods=["POST"])
def run_md_job():
    """Validate an MD submission, register the job and start its worker thread.

    Reads the uploaded ``geometry`` file plus the ``md_*`` / ``mlff_*`` form
    fields, stores the upload in a fresh job directory under
    ``UMA_RESULTS_DIR``, records a ``JobRecord`` in ``JOBS`` and launches
    ``_execute_md_job`` on a daemon thread.

    Returns:
        ``{"job_id": ...}`` on success, or a ``{"status": "error", ...}``
        payload with HTTP 400 when the upload is missing or a form field
        fails validation.
    """
    geometry = request.files.get("geometry")
    if geometry is None or geometry.filename == "":
        return jsonify({"status": "error", "message": "Geometry file is required."}), 400

    form = request.form
    engine = (form.get("md_engine") or "langevin").lower()

    # All numeric fields are parsed up front so that any invalid value is
    # reported as a 400 before a job directory is created.
    try:
        charge_val = _parse_int_field(form, "charge", 0, label="Charge")
        spin_val = _parse_int_field(form, "spin", 1, label="Spin multiplicity", min_value=0, min_inclusive=False)
        steps_val = _parse_int_field(form, "md_steps", 500, label="MD steps", min_value=0, min_inclusive=False)
        timestep_fs = _parse_float_field(form, "md_timestep_fs", 0.5, label="Timestep", min_value=0.0, min_inclusive=False)
        temperature_val = _parse_float_field(form, "md_temperature", 300.0, label="Temperature", min_value=0.0, min_inclusive=False)
        friction_val = _parse_float_field(form, "md_friction", 0.002, label="Friction", min_value=0.0)
        traj_interval = _parse_int_field(form, "md_traj_interval", 10, label="Trajectory interval", min_value=0, min_inclusive=False)
        log_interval = _parse_int_field(form, "md_log_interval", 10, label="Log interval", min_value=0, min_inclusive=False)
        relax_fs = _parse_float_field(form, "md_relax_fs", 100.0, label="Thermostat relaxation time", min_value=0.0, min_inclusive=False)
        pressure_bar = _parse_float_field(form, "md_pressure_bar", 1.0, label="Pressure", min_value=0.0)
        compressibility_bar = _parse_float_field(form, "md_compressibility_bar", 4.57e-5, label="Compressibility", min_value=0.0)
        barostat_relax_fs = _parse_float_field(form, "md_barostat_relax_fs", 1000.0, label="Barostat relaxation time", min_value=0.0, min_inclusive=False)
        # The Nosé-Hoover and Bussi time constants default to the generic
        # thermostat relaxation time parsed above.
        nose_tdamp_fs = _parse_float_field(form, "md_nose_tdamp_fs", relax_fs, label="Nosé-Hoover damping", min_value=0.0, min_inclusive=False)
        nose_tchain = _parse_int_field(form, "md_nose_tchain", 3, label="Nosé-Hoover chain length", min_value=0, min_inclusive=False)
        nose_tloop = _parse_int_field(form, "md_nose_tloop", 1, label="Nosé-Hoover sub-steps", min_value=0, min_inclusive=False)
        bussi_taut_fs = _parse_float_field(form, "md_bussi_taut_fs", relax_fs, label="Bussi time constant", min_value=0.0, min_inclusive=False)
        andersen_prob = _parse_float_field(form, "md_andersen_prob", 0.001, label="Andersen collision probability", min_value=0.0)
        if andersen_prob is None or not 0 <= andersen_prob <= 1:
            raise ValueError("Andersen collision probability must be between 0 and 1.")
        mtk_tdamp_fs = _parse_float_field(form, "md_mtk_tdamp_fs", 100.0, label="MTK thermostat damping", min_value=0.0, min_inclusive=False)
        mtk_pdamp_fs = _parse_float_field(form, "md_mtk_pdamp_fs", 1000.0, label="MTK barostat damping", min_value=0.0, min_inclusive=False)
        mtk_tchain = _parse_int_field(form, "md_mtk_tchain", 3, label="MTK thermostat chain length", min_value=0, min_inclusive=False)
        mtk_pchain = _parse_int_field(form, "md_mtk_pchain", 3, label="MTK barostat chain length", min_value=0, min_inclusive=False)
        mtk_tloop = _parse_int_field(form, "md_mtk_tloop", 1, label="MTK thermostat sub-steps", min_value=0, min_inclusive=False)
        mtk_ploop = _parse_int_field(form, "md_mtk_ploop", 1, label="MTK barostat sub-steps", min_value=0, min_inclusive=False)
        baoab_ttau_fs = _parse_float_field(form, "md_baoab_ttau_fs", 50.0, label="BAOAB thermostat time", min_value=0.0, min_inclusive=False)
        baoab_ptau_fs = _parse_float_field(form, "md_baoab_ptau_fs", 1000.0, label="BAOAB barostat time", min_value=0.0, min_inclusive=False)
        baoab_pmass = _parse_float_field(form, "md_baoab_pmass", None, label="BAOAB cell mass", min_value=0.0, min_inclusive=False)
        baoab_pmass_factor = _parse_float_field(form, "md_baoab_pmass_factor", 1.0, label="BAOAB mass factor", min_value=0.0, min_inclusive=False)
        melchionna_ttime_fs = _parse_float_field(form, "md_melchionna_ttime_fs", 25.0, label="Melchionna thermostat time", min_value=0.0)
        melchionna_ptime_fs = _parse_float_field(form, "md_melchionna_ptime_fs", 75.0, label="Melchionna barostat time", min_value=0.0)
        melchionna_bulk_gpa = _parse_float_field(form, "md_melchionna_bulk_gpa", 100.0, label="Melchionna bulk modulus", min_value=0.0)
    except ValueError as exc:
        return jsonify({"status": "error", "message": str(exc)}), 400

    # Optional comma-separated mask: every entry is reduced to 0/1 and the
    # list is padded with 1s / truncated to exactly three values.
    melchionna_mask_raw = (form.get("md_melchionna_mask") or "").strip()
    melchionna_mask: Optional[List[int]] = None
    if melchionna_mask_raw:
        try:
            mask_values = []
            for entry in melchionna_mask_raw.split(","):
                entry_clean = entry.strip()
                if entry_clean == "":
                    continue
                value = 1 if int(round(float(entry_clean))) > 0 else 0
                mask_values.append(value)
            if mask_values:
                mask_values = (mask_values + [1, 1, 1])[:3]
                melchionna_mask = mask_values
        except (TypeError, ValueError, OverflowError):
            # OverflowError: round(float("inf")) cannot be converted to an
            # int; treat it like any other malformed entry instead of a 500.
            return jsonify({"status": "error", "message": "Melchionna mask must be comma separated integers."}), 400

    baoab_hydrostatic = _parse_bool_field(form, "md_baoab_hydrostatic", False)
    use_pbc = _parse_bool_field(form, "md_use_pbc", False)
    cell_parameters: Optional[List[float]] = None
    if use_pbc:
        # With periodic boundaries requested, all six cell parameters are
        # mandatory; angles must lie strictly between 0 and 180.
        try:
            cell_a = _parse_float_field(form, "md_cell_a", None, label="Cell a", min_value=0.0, min_inclusive=False, required=True)
            cell_b = _parse_float_field(form, "md_cell_b", None, label="Cell b", min_value=0.0, min_inclusive=False, required=True)
            cell_c = _parse_float_field(form, "md_cell_c", None, label="Cell c", min_value=0.0, min_inclusive=False, required=True)
            cell_alpha = _parse_float_field(
                form,
                "md_cell_alpha",
                None,
                label="Cell alpha",
                min_value=0.0,
                min_inclusive=False,
                max_value=180.0,
                max_inclusive=False,
                required=True,
            )
            cell_beta = _parse_float_field(
                form,
                "md_cell_beta",
                None,
                label="Cell beta",
                min_value=0.0,
                min_inclusive=False,
                max_value=180.0,
                max_inclusive=False,
                required=True,
            )
            cell_gamma = _parse_float_field(
                form,
                "md_cell_gamma",
                None,
                label="Cell gamma",
                min_value=0.0,
                min_inclusive=False,
                max_value=180.0,
                max_inclusive=False,
                required=True,
            )
            cell_parameters = [cell_a, cell_b, cell_c, cell_alpha, cell_beta, cell_gamma]  # type: ignore[list-item]
        except ValueError as exc:
            return jsonify({"status": "error", "message": str(exc)}), 400

    # Blank model selections fall back to the defaults.
    mlff_checkpoint_raw = form.get("mlff_checkpoint", "uma-s-1p1")
    mlff_checkpoint = mlff_checkpoint_raw.strip() or "uma-s-1p1"
    mlff_task_raw = form.get("mlff_task", "omol")
    mlff_task = mlff_task_raw.strip() or "omol"

    results_root = Path(app.config["UMA_RESULTS_DIR"])
    results_root.mkdir(parents=True, exist_ok=True)

    # Timestamp plus a short random suffix keeps job directories unique.
    job_id = f"{datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid.uuid4().hex[:6]}"
    filename = secure_filename(geometry.filename) or "input.xyz"
    job_dir = results_root / job_id
    job_dir.mkdir(parents=True, exist_ok=True)

    input_path = job_dir / filename
    geometry.save(input_path)

    # Output files are named after the uploaded file's stem.
    stem = Path(filename).stem
    log_path = job_dir / f"{stem}-MD.log"
    traj_path = job_dir / f"{stem}-MD.traj"
    snapshot_path = job_dir / f"{stem}-MD-final.xyz"

    md_options = {
        "engine": engine,
        "steps": steps_val,
        "timestep_fs": timestep_fs,
        "temperature": temperature_val,
        "friction": friction_val,
        "traj_interval": traj_interval,
        "log_interval": log_interval,
        "relax_fs": relax_fs,
        "pressure_bar": pressure_bar,
        "compressibility_bar": compressibility_bar,
        "barostat_relax_fs": barostat_relax_fs,
        "nose_tdamp_fs": nose_tdamp_fs,
        "nose_tchain": nose_tchain,
        "nose_tloop": nose_tloop,
        "bussi_taut_fs": bussi_taut_fs,
        "andersen_prob": andersen_prob,
        "mtk_tdamp_fs": mtk_tdamp_fs,
        "mtk_pdamp_fs": mtk_pdamp_fs,
        "mtk_tchain": mtk_tchain,
        "mtk_pchain": mtk_pchain,
        "mtk_tloop": mtk_tloop,
        "mtk_ploop": mtk_ploop,
        "baoab_ttau_fs": baoab_ttau_fs,
        "baoab_ptau_fs": baoab_ptau_fs,
        "baoab_pmass": baoab_pmass,
        "baoab_pmass_factor": baoab_pmass_factor,
        "baoab_hydrostatic": baoab_hydrostatic,
        "melchionna_ttime_fs": melchionna_ttime_fs,
        "melchionna_ptime_fs": melchionna_ptime_fs,
        "melchionna_bulk_gpa": melchionna_bulk_gpa,
        "melchionna_mask": melchionna_mask,
        "use_pbc": use_pbc,
        "cell_parameters": cell_parameters,
        "mlff_checkpoint": mlff_checkpoint,
        "mlff_task": mlff_task,
    }

    record = JobRecord(
        job_id=job_id,
        job_dir=job_dir,
        charge=charge_val,
        spin=spin_val,
        grad=0.0,
        iterations=steps_val,
        run_types=["md"],
        status="running",
        log_path=log_path,
        traj_path=traj_path,
        opt_path=snapshot_path,
        md_options=md_options,
    )
    record.cancel_event = threading.Event()
    record.job_kind = "md"

    with JOB_LOCK:
        JOBS[job_id] = record

    # The MD run happens on a daemon thread; the request returns immediately
    # with the job id that the status/cancel endpoints take.
    worker = threading.Thread(
        target=_execute_md_job,
        args=(record, filename, md_options),
        daemon=True,
    )
    record.worker = worker
    worker.start()

    return jsonify({"job_id": job_id})


def _execute_job(
    record: JobRecord,
    filename: str,
    optimizer: str,
    temperature: str,
    pressure: str,
    mlff_checkpoint: Optional[str],
    mlff_task: Optional[str],
    relative_path: Optional[Path],
):
    """Run a uma-ase workflow in a subprocess and record the outcome on ``record``.

    Builds the CLI argument list, runs ``python -m uma_ase.cli`` inside the
    job directory, waits for it to exit and then updates the record's
    status, message and download URLs under ``JOB_LOCK``.

    Args:
        record: Job bookkeeping object; updated in place.
        filename: Name of the uploaded geometry inside ``record.job_dir``.
        optimizer: Forwarded to ``_build_cli_args``.
        temperature: Forwarded to ``_build_cli_args``.
        pressure: Forwarded to ``_build_cli_args``.
        mlff_checkpoint: Forwarded to ``_build_cli_args``.
        mlff_task: Forwarded to ``_build_cli_args``.
        relative_path: When truthy, stored on ``record.relative_path``.
    """
    job_dir = record.job_dir
    cancel_event = record.cancel_event
    if cancel_event is None:
        cancel_event = threading.Event()
        record.cancel_event = cancel_event
    if relative_path:
        record.relative_path = relative_path
    if cancel_event.is_set():
        # Cancelled before anything was launched.
        with JOB_LOCK:
            record.status = "cancelled"
            record.message = "Job cancelled by user."
            record.worker = None
        return

    input_path = job_dir / filename
    run_sequence = record.run_types or ["sp"]

    try:
        paths = build_output_paths(input_path, run_sequence)
        record.log_path = paths.log
        record.traj_path = paths.trajectory
        record.opt_path = paths.final_geometry

        argv = _build_cli_args(
            input_path,
            record.run_types,
            str(record.charge),
            str(record.spin),
            optimizer,
            str(record.grad),
            str(record.iterations),
            temperature,
            pressure,
            mlff_checkpoint,
            mlff_task,
        )

        cmd = [sys.executable, "-m", "uma_ase.cli", *argv]
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            cwd=str(job_dir),
        )
        with JOB_LOCK:
            record.process = process
            # A cancel request that arrived before the handle was published
            # saw ``record.process is None`` and could not stop the child,
            # so honour it here.
            cancel_pending = cancel_event.is_set()
        if cancel_pending and process.poll() is None:
            with suppress(OSError):
                process.terminate()

        stdout_data, _ = process.communicate()
        exit_code = process.returncode
        with JOB_LOCK:
            record.process = None

        cancelled = cancel_event.is_set()
        error_message: Optional[str] = None
        if cancelled:
            error_message = "Job cancelled by user."
        elif exit_code != 0:
            # Surface the last line of the captured output as the detail.
            tail = (stdout_data or "").strip().splitlines()
            detail = tail[-1] if tail else ""
            error_message = f"uma-ase exited with status {exit_code}."
            if detail:
                error_message = f"{error_message} {detail}"
        else:
            if record.opt_path and record.opt_path.exists():
                # Rewrite the final geometry as XYZ with a comment line that
                # carries the formula, charge and spin.
                try:
                    atoms_opt = read(str(record.opt_path))
                    formula_opt = atoms_opt.get_chemical_formula()
                    comment = " ".join(
                        part
                        for part in [
                            formula_opt,
                            f"charge={record.charge}",
                            f"spin={record.spin}",
                        ]
                        if part
                    )
                    write(str(record.opt_path), atoms_opt, format="xyz", comment=comment)
                except Exception as exc:
                    error_message = f"Optimized geometry rewrite failed: {exc}"

        with JOB_LOCK:
            if cancelled:
                record.status = "cancelled"
                record.message = "Job cancelled by user."
            elif error_message:
                record.status = "error"
                record.message = error_message
            else:
                record.status = "completed"

            # Only advertise downloads for files that were actually produced.
            if record.log_path and record.log_path.exists():
                record.log_url = f"/api/uma-ase/job/{record.job_id}/log"
            if record.traj_path and record.traj_path.exists():
                record.traj_url = f"/api/uma-ase/job/{record.job_id}/trajectory"
            if record.opt_path and record.opt_path.exists():
                record.opt_url = f"/api/uma-ase/job/{record.job_id}/optimized"
            record.worker = None

    except Exception as exc:
        with JOB_LOCK:
            record.status = "error" if not cancel_event.is_set() else "cancelled"
            record.message = "Job cancelled by user." if cancel_event.is_set() else str(exc)
            record.worker = None


def _execute_md_job(
    record: JobRecord,
    filename: str,
    md_options: Dict[str, Any],
):
    """Run a molecular-dynamics job in-process and record the outcome.

    Sets up the calculator through ``setup_calculated_atoms``, builds the
    integrator with ``_create_md_dynamics``, integrates in short chunks so
    that ``record.cancel_event`` is polled regularly, and finally stores the
    status, message and download URLs on ``record`` under ``JOB_LOCK``.

    Args:
        record: Job bookkeeping object; updated in place.
        filename: Name of the uploaded geometry inside ``record.job_dir``.
        md_options: Options dictionary (engine, steps, intervals, ...); only
            keys read via ``md_options.get`` below are used directly here,
            the whole dict is also passed to ``_create_md_dynamics``.
    """
    job_dir = record.job_dir
    input_path = job_dir / filename
    stem = Path(filename).stem or "input"
    # Fall back to default output names when the record carries none.
    log_path = record.log_path or job_dir / f"{stem}-MD.log"
    traj_path = record.traj_path or job_dir / f"{stem}-MD.traj"
    final_path = record.opt_path or job_dir / f"{stem}-MD-final.xyz"
    multi_xyz_path = job_dir / f"{stem}-MD.xyz"
    md_steps = md_options.get("steps", 0)
    # Observer intervals are clamped to at least one step.
    traj_interval = max(1, int(md_options.get("traj_interval", 10)))
    log_interval = max(1, int(md_options.get("log_interval", 10)))
    engine = md_options.get("engine", "langevin")
    # Human-readable engine name; unknown engines are shown by their key.
    label = MD_ENGINE_LABELS.get(engine, engine)
    temperature = md_options.get("temperature", 300.0)
    timestep_fs = md_options.get("timestep_fs", 0.5)
    friction = md_options.get("friction", 0.0)
    relax_fs = md_options.get("relax_fs", 100.0)
    checkpoint = md_options.get("mlff_checkpoint")
    task = md_options.get("mlff_task")
    use_pbc = bool(md_options.get("use_pbc"))
    cell_parameters = md_options.get("cell_parameters")
    cancel_event = record.cancel_event
    if cancel_event is None:
        cancel_event = threading.Event()
        record.cancel_event = cancel_event

    success = False
    error_message: Optional[str] = None
    trajectory = None
    try:
        with configure_logging(log_path) as logger:
            # Argument namespace handed to setup_calculated_atoms.
            # NOTE(review): presumably mirrors the CLI's parsed arguments — confirm.
            args = SimpleNamespace(
                input=str(input_path),
                chg=record.charge,
                spin=record.spin,
                optimizer="",
                grad=0.0,
                iter=md_steps,
                run_type=["md"],
                mlff_chk=checkpoint,
                mlff_task=task,
                temp=temperature,
                press=0.0,
                cpu=False,
                visualize=False,
                _chg_explicit=True,
                _spin_explicit=True,
            )
            status, context = setup_calculated_atoms(args, logger)
            if status != 0 or context is None:
                raise RuntimeError("UMA calculator setup failed; see log for details.")

            if cancel_event.is_set():
                logger.info("Cancellation requested before MD start.")
                raise JobCancelled("MD job cancelled by user.")

            # Banner: general settings first, then the engine-specific ones.
            logger.info("*****************************************************************************")
            logger.info("*                        Molecular Dynamics Run                             *")
            logger.info("*****************************************************************************")
            logger.info("* Engine          : %s", label)
            logger.info("* Steps           : %s", md_steps)
            logger.info("* Timestep (fs)   : %.4f", timestep_fs)
            logger.info("* Temperature (K) : %.2f", temperature)
            logger.info("* Friction (1/fs) : %.4f", friction)
            logger.info("* Traj interval   : %d", traj_interval)
            logger.info("* Log interval    : %d", log_interval)
            logger.info("* Checkpoint      : %s", checkpoint)
            logger.info("* Task            : %s", task)
            logger.info("* Periodic BC     : %s", "enabled" if use_pbc else "disabled")
            if use_pbc and cell_parameters:
                logger.info(
                    "* Cell (Å | °)    : a=%.3f b=%.3f c=%.3f | α=%.2f β=%.2f γ=%.2f",
                    cell_parameters[0],
                    cell_parameters[1],
                    cell_parameters[2],
                    cell_parameters[3],
                    cell_parameters[4],
                    cell_parameters[5],
                )
            if engine in {"npt_berendsen", "isotropic_mtk_npt", "mtk_npt", "langevin_baoab_npt", "melchionna_npt"}:
                logger.info("* Pressure (bar)  : %.6f", md_options.get("pressure_bar", 0.0))
            if engine in {"nvt_berendsen", "npt_berendsen"}:
                logger.info("* Thermostat tau  : %.4f fs", relax_fs)
            if engine == "npt_berendsen":
                logger.info("* Barostat tau    : %.4f fs", md_options.get("barostat_relax_fs", 1000.0))
                logger.info("* Compressibility : %.6e 1/bar", md_options.get("compressibility_bar", 0.0))
            if engine == "nose_hoover_chain":
                logger.info(
                    "* Nosé-Hoover     : tdamp=%.4f fs | chain=%d | tloop=%d",
                    md_options.get("nose_tdamp_fs", relax_fs),
                    md_options.get("nose_tchain", 3),
                    md_options.get("nose_tloop", 1),
                )
            if engine == "bussi":
                logger.info("* Bussi taut      : %.4f fs", md_options.get("bussi_taut_fs", relax_fs))
            if engine == "andersen":
                logger.info("* Andersen prob   : %.6f", md_options.get("andersen_prob", 0.0))
            if engine in {"isotropic_mtk_npt", "mtk_npt"}:
                logger.info(
                    "* MTK params      : tdamp=%.4f fs | pdamp=%.4f fs | tchain=%d | pchain=%d",
                    md_options.get("mtk_tdamp_fs", 100.0),
                    md_options.get("mtk_pdamp_fs", 1000.0),
                    md_options.get("mtk_tchain", 3),
                    md_options.get("mtk_pchain", 3),
                )
            if engine.startswith("langevin_baoab"):
                logger.info(
                    "* BAOAB params    : T_tau=%.4f fs | P_tau=%.4f fs | hydrostatic=%s",
                    md_options.get("baoab_ttau_fs", 50.0),
                    md_options.get("baoab_ptau_fs", 1000.0),
                    md_options.get("baoab_hydrostatic", False),
                )
            if engine == "melchionna_npt":
                logger.info(
                    "* Melchionna      : ttime=%.4f fs | ptime=%.4f fs | bulk=%.3f GPa",
                    md_options.get("melchionna_ttime_fs", 25.0),
                    md_options.get("melchionna_ptime_fs", 75.0),
                    md_options.get("melchionna_bulk_gpa", 100.0),
                )
            logger.info("*****************************************************************************")

            atoms = context.atoms
            if use_pbc:
                if cell_parameters:
                    try:
                        cell_matrix = cellpar_to_cell(cell_parameters)
                    except Exception as exc:
                        raise RuntimeError(f"Invalid cell parameters: {exc}") from exc
                    # NOTE(review): scale_atoms=True also transforms the
                    # atomic positions together with the cell; if the input
                    # carries no cell of its own (e.g. plain XYZ), confirm
                    # that this rescaling is intended.
                    atoms.set_cell(cell_matrix, scale_atoms=True)
                atoms.set_pbc(True)
                cell_vectors = atoms.get_cell()
                logger.info("Cell matrix (Å):")
                axis_labels = ("a", "b", "c")
                for axis_label, vector in zip(axis_labels, cell_vectors):
                    logger.info("  %s = [%.6f %.6f %.6f]", axis_label, vector[0], vector[1], vector[2])
            else:
                atoms.set_pbc(False)
            _initialise_md_velocities(atoms, temperature, logger)
            dynamics = _create_md_dynamics(
                engine,
                atoms,
                md_options=md_options,
            )
            trajectory = Trajectory(str(traj_path), "w", atoms)

            def log_step():
                # Observer: log energies and instantaneous temperature.
                step = dynamics.nsteps
                potential = atoms.get_potential_energy()
                kinetic = atoms.get_kinetic_energy()
                total = potential + kinetic
                inst_temp = atoms.get_temperature()
                logger.info(
                    "Step %d/%d | E_pot=%.6f eV | E_kin=%.6f eV | E_tot=%.6f eV | T=%.2f K",
                    step,
                    md_steps,
                    potential,
                    kinetic,
                    total,
                    inst_temp,
                )

            dynamics.attach(log_step, interval=log_interval)
            dynamics.attach(trajectory.write, interval=traj_interval)
            # Frames are also appended to a multi-frame XYZ file at the same
            # interval as the binary trajectory.
            dynamics.attach(
                lambda: write(multi_xyz_path, atoms, format="xyz", append=True),
                interval=traj_interval,
            )

            logger.info(
                "Starting %s dynamics for %d steps (dt=%.4f fs).",
                label,
                md_steps,
                timestep_fs,
            )
            # Integrate in chunks of at most 100 steps so that the cancel
            # flag is polled between successive dynamics.run() calls.
            chunk = max(1, min(100, traj_interval, log_interval))
            completed = 0
            while completed < md_steps:
                if cancel_event.is_set():
                    logger.info("Cancellation requested. Halting MD integration.")
                    raise JobCancelled("MD job cancelled by user.")
                steps_this_round = min(chunk, md_steps - completed)
                dynamics.run(steps_this_round)
                completed += steps_this_round
            logger.info("Molecular dynamics finished successfully.")

            if not cancel_event.is_set():
                # A failure to write the final snapshot is logged as a
                # warning and does not fail the job.
                try:
                    write(
                        str(final_path),
                        atoms,
                        format="xyz",
                        comment=f"MD final snapshot | charge={record.charge} spin={record.spin}",
                    )
                    logger.info("Final files saved to %s", final_path)
                except Exception as exc:
                    logger.warning("Unable to write final files: %s", exc)

        success = True
    except JobCancelled as exc:
        error_message = str(exc)
        # Report through the Flask app logger rather than the per-job logger.
        logger = app.logger
        logger.info("MD job %s cancelled by user.", record.job_id)
    except Exception as exc:
        error_message = str(exc)
        app.logger.exception("MD job %s failed", record.job_id)
    finally:
        if trajectory is not None:
            with suppress(Exception):
                trajectory.close()

        with JOB_LOCK:
            # A set cancel flag takes precedence over success or failure.
            if cancel_event.is_set():
                record.status = "cancelled"
                record.message = error_message or "MD job cancelled by user."
            elif success:
                record.status = "completed"
                record.message = None
            else:
                record.status = "error"
                record.message = error_message or "MD job failed."

            # Only advertise downloads for files that exist on disk.
            if log_path.exists():
                record.log_url = f"/api/uma-ase/job/{record.job_id}/log"
            if traj_path.exists():
                record.traj_url = f"/api/uma-ase/job/{record.job_id}/trajectory"
            if final_path.exists():
                record.opt_url = f"/api/uma-ase/job/{record.job_id}/optimized"
            if multi_xyz_path.exists():
                record.md_multi_xyz = multi_xyz_path
                record.md_multi_xyz_url = f"/api/uma-ase/job/{record.job_id}/md_xyz"
            # Same condition as above; the path is also stored in md_options.
            if multi_xyz_path.exists():
                record.md_options["multi_xyz"] = multi_xyz_path
            record.worker = None


def _send_job_file(path: Optional[Path], mimetype: str = "text/plain"):
    """Send ``path`` as an attachment named after the file, or abort with 404.

    The 404 is raised when no path was recorded (``None``) or when the file
    does not exist on disk.
    """
    available = path is not None and path.exists()
    if not available:
        abort(404)
    return send_file(path, mimetype=mimetype, as_attachment=True, download_name=path.name)


@app.route("/api/uma-ase/job/<job_id>", methods=["GET"])
def job_status(job_id: str):
    """Return the job's status, message, current log text and download URLs.

    The log file is read on every call; when it is missing or unreadable
    the ``log`` field is an empty string.
    """
    record = _get_job(job_id)
    log_text = ""
    if record.log_path and record.log_path.exists():
        try:
            # errors="replace": UnicodeDecodeError is a ValueError, not an
            # OSError, so undecodable bytes in the log would otherwise turn
            # a status poll into a 500.
            log_text = record.log_path.read_text(encoding="utf-8", errors="replace")
        except OSError:
            log_text = ""
    md_xyz_available = bool(record.md_multi_xyz and record.md_multi_xyz.exists())
    return jsonify(
        {
            "status": record.status,
            "message": record.message,
            "log": log_text,
            "log_download": record.log_url,
            "traj_download": record.traj_url,
            "opt_download": record.opt_url,
            "md_xyz_download": f"/api/uma-ase/job/{record.job_id}/md_xyz" if md_xyz_available else None,
        }
    )


@app.route("/api/uma-ase/job/<job_id>/cancel", methods=["POST"])
def cancel_job(job_id: str):
    """Request cancellation of a running job.

    Sets the job's cancel event and, for ``workflow`` jobs with a live
    subprocess, asks that process to terminate. Responds with HTTP 400 when
    the job's status is not ``"running"``; repeated requests for the same
    job return ``{"status": "ok"}`` without signalling the process again.
    """
    record = _get_job(job_id)
    with JOB_LOCK:
        if record.status != "running":
            return jsonify({"status": "error", "message": "Job is not running."}), 400
        cancel_event = record.cancel_event
        if cancel_event is None:
            cancel_event = threading.Event()
            record.cancel_event = cancel_event
        already_set = cancel_event.is_set()
        cancel_event.set()
        # Snapshot what is needed so the lock is not held while signalling.
        process = record.process
        job_kind = record.job_kind
        record.message = "Cancellation requested..."
    if already_set:
        return jsonify({"status": "ok"})
    if job_kind == "workflow" and process and process.poll() is None:
        try:
            process.terminate()
        except Exception:
            # Termination stays best-effort (the worker still observes the
            # cancel event), but the failure is no longer silent.
            app.logger.warning("Could not terminate process for job %s", job_id, exc_info=True)
    return jsonify({"status": "ok"})


@app.route("/api/uma-ase/job/<job_id>/log", methods=["GET"])
def download_job_log(job_id: str):
    """Serve the job's log file as a plain-text attachment."""
    return _send_job_file(_get_job(job_id).log_path, "text/plain")


@app.route("/api/uma-ase/job/<job_id>/trajectory", methods=["GET"])
def download_job_trajectory(job_id: str):
    """Serve the job's trajectory file as a binary attachment."""
    return _send_job_file(_get_job(job_id).traj_path, "application/octet-stream")


@app.route("/api/uma-ase/job/<job_id>/optimized", methods=["GET"])
def download_job_optimized(job_id: str):
    """Serve the file stored as the job's ``opt_path`` as a plain-text attachment."""
    return _send_job_file(_get_job(job_id).opt_path, "text/plain")


@app.route("/api/uma-ase/job/<job_id>/md_xyz", methods=["GET"])
def download_job_md_xyz(job_id: str):
    """Serve the job's multi-frame MD XYZ file as a plain-text attachment."""
    return _send_job_file(_get_job(job_id).md_multi_xyz, "text/plain")


@app.route("/api/uma-ase/clean", methods=["POST"])
def clean_results_root():
    """Delete ``~/.uma_ase`` and make sure the working directories exist again.

    Responds with ``{"status": "ok"}``, or with HTTP 500 and the error text
    when a filesystem operation raises ``OSError``.
    """
    uma_home = Path.home() / ".uma_ase"
    try:
        if uma_home.exists():
            shutil.rmtree(uma_home)
        # Ensure the results and analysis-report directories exist afterwards.
        for directory in (Path(app.config["UMA_RESULTS_DIR"]), ANALYZE_REPORT_ROOT):
            directory.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        return jsonify({"status": "error", "message": str(exc)}), 500
    return jsonify({"status": "ok"})


@app.route("/api/uma-ase/multi/logs/<path:folder>", methods=["GET"])
def download_multi_logs(folder: str):
    """Zip every file below a multi-run folder and send the archive.

    The folder name is sanitised with ``secure_filename`` and must resolve
    to a location inside ``<UMA_RESULTS_DIR>/multi_runs``. A 404 is returned
    when the name is empty after sanitising, the folder is outside that
    root, does not exist, or contains no files.
    """
    folder_name = secure_filename(folder)
    if not folder_name:
        abort(404)

    runs_root = (Path(app.config["UMA_RESULTS_DIR"]) / "multi_runs").resolve()
    run_dir = (runs_root / folder_name).resolve()
    # relative_to() raises ValueError for paths outside the multi-run root.
    try:
        run_dir.relative_to(runs_root)
    except ValueError:
        abort(404)
    if not run_dir.exists():
        abort(404)

    members = list(filter(Path.is_file, run_dir.rglob("*")))
    if not members:
        abort(404)

    archive_name = f"{folder_name}_files.zip"
    scratch_dir = Path(tempfile.mkdtemp(prefix="uma_logs_"))
    archive_path = scratch_dir / archive_name
    try:
        with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as bundle:
            for member in members:
                # Store entries relative to the run folder.
                bundle.write(member, member.relative_to(run_dir))
    except Exception:
        # Do not leave the scratch directory behind when zipping fails.
        shutil.rmtree(scratch_dir, ignore_errors=True)
        raise

    @after_this_request
    def _discard_archive(response):  # pragma: no cover
        shutil.rmtree(scratch_dir, ignore_errors=True)
        return response

    return send_file(
        archive_path,
        mimetype="application/zip",
        as_attachment=True,
        download_name=archive_name,
    )




@app.route("/api/uma-ase/preview", methods=["POST"])
def preview_structure():
    """Summarise an uploaded geometry without starting a job.

    The upload is saved to a temporary directory, its XYZ metadata and atoms
    are read, and a JSON summary (atom count, formula, element counts,
    charge, spin, device, comment and display lines) is returned. Explicit
    ``charge`` / ``spin`` form values take priority over values from the
    file metadata; invalid input yields HTTP 400.
    """
    upload = request.files.get("geometry")
    if upload is None or upload.filename == "":
        return jsonify({"status": "error", "message": "Geometry file is required."}), 400

    raw_charge = request.form.get("charge")
    raw_spin = request.form.get("spin")

    with tempfile.TemporaryDirectory(prefix="uma_preview_") as scratch:
        filename = secure_filename(upload.filename) or "input.xyz"
        input_path = Path(scratch) / filename
        upload.save(input_path)

        metadata = extract_xyz_metadata(input_path)

        # Charge: form value, else file metadata, else 0.
        if raw_charge is not None and raw_charge.strip() != "":
            try:
                charge_val = int(raw_charge)
            except (TypeError, ValueError):
                return jsonify({"status": "error", "message": "Charge must be an integer."}), 400
        elif metadata.charge is not None:
            charge_val = metadata.charge
        else:
            charge_val = 0

        # Spin: form value (must be positive), else positive metadata, else 1.
        if raw_spin is not None and raw_spin.strip() != "":
            try:
                spin_val = int(raw_spin)
            except (TypeError, ValueError):
                return jsonify({"status": "error", "message": "Spin multiplicity must be an integer."}), 400
            if spin_val <= 0:
                return jsonify({"status": "error", "message": "Spin multiplicity must be positive."}), 400
        elif metadata.spin is not None and metadata.spin > 0:
            spin_val = metadata.spin
        else:
            spin_val = 1

        try:
            atoms = read(str(input_path))
        except Exception as exc:  # pragma: no cover - depends on external IO
            return jsonify({"status": "error", "message": f"Unable to read geometry: {exc}"}), 400

        atoms.info["charge"] = charge_val
        atoms.info["spin"] = spin_val
        xyz_comment = metadata.comment
        source_url = metadata.url
        if xyz_comment:
            atoms.info.setdefault("uma_comment", xyz_comment)
        if source_url:
            atoms.info.setdefault("uma_comment_url", source_url)

        num_atoms = len(atoms)
        formula = atoms.get_chemical_formula()
        element_counts = dict(Counter(atoms.get_chemical_symbols()))

        # Report "cpu" when select_device() raises TorchUnavailable.
        try:
            device = select_device()
        except TorchUnavailable:
            device = "cpu"

    # Display lines: optional source URL and comment first, then the summary.
    summary_lines = []
    if source_url:
        summary_lines.append(f"Source URL: {source_url}")
    if xyz_comment:
        summary_lines.append(f"Comment: {xyz_comment}")
    summary_lines += [
        f"Charge: {charge_val}",
        f"Spin multiplicity: {spin_val}",
        f"Number of atoms: {num_atoms}",
        f"Formula: {formula}",
        f"Element counts: {element_counts}",
        f"Device: {device}",
    ]

    payload = {
        "status": "ok",
        "initial_geometry": filename,
        "num_atoms": num_atoms,
        "formula": formula,
        "element_counts": element_counts,
        "charge": charge_val,
        "spin": spin_val,
        "device": device,
        "comment": xyz_comment,
        "lines": summary_lines,
    }
    return jsonify(payload)
def create_app() -> Flask:
    """Return the module-level Flask application for an external WSGI server."""
    return app


def main() -> None:
    """Start Flask's built-in server on port 8000 with debug mode enabled."""
    app.run(port=8000, debug=True)


# Start the development server only when this file is executed as a script.
if __name__ == "__main__":  # pragma: no cover
    main()
