"""
plot_templates_pdf.py – generate multipage PDF previews of rebinned template spectra
==========================================================================

This utility opens every HDF5 template-storage file (generated by
snid.template_fft_storage.TemplateFFTStorage) found in a directory, and
creates a multi-page PDF for each file.  Each page shows the rebinned
flux for one template, with all of its epochs plotted on the common
log-wavelength grid and vertically offset so the spectra do not
overlap.  A legend lists the epoch ages.

Usage (from project root, PowerShell example):

    python scripts/plot_templates_pdf.py -i templates -o plots

Arguments
---------
-i / --input_dir   Directory that contains the *.hdf5 (or *.h5)
                   storage files.  Defaults to "templates".
-o / --output_dir  Where to write the PDFs.  Defaults to current dir.
--max_templates    Limit number of templates plotted from each file
                   (handy for quick tests).

The script requires `h5py`, `numpy` and `matplotlib`, which are already
project requirements.
"""

from __future__ import annotations

import argparse
from pathlib import Path
import h5py
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import cm

# Names exported by a star-import of this module: only the CLI entry point.
__all__ = ["main"]

def _plot_template(ax, wave: np.ndarray, epoch_data: list[tuple[float, np.ndarray]]) -> None:
    """Plot all epochs for a single template with vertical offsets.

    Parameters
    ----------
    ax
        Axes object to draw on; its ``plot``, ``set_xlabel``, ``set_ylabel``
        and ``legend`` methods are used.
    wave
        X values shared by every epoch.
    epoch_data
        ``(age, flux)`` pairs, one per epoch.  A NaN age marks an epoch whose
        age is unknown.  The list is *not* modified.

    Notes
    -----
    Epochs are drawn in order of increasing age (NaN ages last).  Curve
    number ``i`` is shifted upwards by ``i * spacing``, where ``spacing`` is
    1.4 times the peak-to-peak range of the first (sorted) epoch, or 1.0 if
    that range is not positive.
    """

    # Sort a copy so the caller's list keeps its original order; the leading
    # boolean in the key pushes NaN ages to the end.
    epochs = sorted(epoch_data, key=lambda pair: (np.isnan(pair[0]), pair[0]))
    n_epochs = len(epochs)

    ax.set_xlabel("log λ grid (Å)")
    ax.set_ylabel("Flux + offset")

    if not n_epochs:
        # Nothing to draw: skip the zero-entry colormap request and the
        # "no artists with labels" warning an empty legend would raise.
        return

    # Vertical spacing is derived from the first epoch only; a flat or
    # NaN-containing first spectrum falls back to unit spacing.
    ptp = np.ptp(epochs[0][1])
    spacing = ptp * 1.4 if ptp > 0 else 1.0

    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9;
    # ``pyplot.get_cmap`` accepts the same (name, lut) arguments.
    cmap = plt.get_cmap("viridis", n_epochs)
    for idx, (age, flux) in enumerate(epochs):
        offset = idx * spacing
        label = f"age = {age:.1f}" if not np.isnan(age) else f"epoch {idx}"
        ax.plot(wave, flux + offset, color=cmap(idx), lw=0.8, label=label)

    ax.legend(fontsize="small", ncol=2)


def _collect_epochs(tgrp) -> list[tuple[float, np.ndarray]]:
    """Return the ``(age, flux)`` pairs stored in one template group.

    If the group has an ``"epochs"`` sub-group, every member of it yields one
    pair; otherwise the template group itself is read as a single epoch.
    A missing ``"age"`` attribute is reported as NaN.
    """
    sources = tgrp["epochs"].values() if "epochs" in tgrp else [tgrp]
    return [(float(src.attrs.get("age", np.nan)), src["flux"][:]) for src in sources]


def process_hdf5_file(h5_path: Path, output_dir: Path, max_templates: int | None = None) -> None:
    """Create a multi-page PDF for a single type-specific HDF5 storage file.

    Parameters
    ----------
    h5_path
        HDF5 file to read.  It must contain the dataset
        ``metadata/standard_wavelength`` and the group ``templates``.
    output_dir
        Directory in which ``<stem>_preview.pdf`` is written.
    max_templates
        If given, only the first ``max_templates`` templates are plotted.

    Raises
    ------
    KeyError
        If a required dataset or group is missing from the file.
    """

    print(f"📖 Processing {h5_path.name} …")
    pdf_path = output_dir / (h5_path.stem + "_preview.pdf")

    with h5py.File(h5_path, "r") as f:
        # Read everything that can fail up front, *before* the PDF is
        # created, so a malformed file does not leave an empty PDF behind.
        # Standard wavelength array (log grid, *not* linear Å), shape (NW,)
        wave = f["metadata/standard_wavelength"][:]

        templates_group = f["templates"]
        template_names = list(templates_group.keys())
        if max_templates is not None:
            template_names = template_names[:max_templates]

        if not template_names:
            # Avoid producing (and announcing) a PDF with no pages.
            print(f"⚠️  No templates to plot in {h5_path.name}")
            return

        with PdfPages(pdf_path) as pdf:
            for tname in template_names:
                epoch_data = _collect_epochs(templates_group[tname])

                # One page per template.
                fig, ax = plt.subplots(figsize=(11, 6))
                try:
                    _plot_template(ax, wave, epoch_data)
                    ax.set_title(f"{tname}  –  {h5_path.stem}")
                    fig.tight_layout()
                    pdf.savefig(fig)
                finally:
                    # Always release the figure, even when plotting fails;
                    # the caller keeps going with the next file.
                    plt.close(fig)

    print(f"✅ Wrote {pdf_path}")


def main():
    """Command-line entry point.

    Parses ``-i/--input_dir``, ``-o/--output_dir`` and ``--max_templates``,
    creates the output directory if needed, and calls
    ``process_hdf5_file`` for every ``*.hdf5`` / ``*.h5`` file found directly
    in the input directory.  A failure on one file is reported and does not
    stop the remaining files from being processed.
    """
    cli = argparse.ArgumentParser(
        description="Generate multipage PDFs of template spectra stored in HDF5 files."
    )
    cli.add_argument(
        "-i",
        "--input_dir",
        type=str,
        default="templates",
        help="Directory containing *.hdf5 files",
    )
    cli.add_argument(
        "-o",
        "--output_dir",
        type=str,
        default=".",
        help="Directory to write PDFs",
    )
    cli.add_argument(
        "--max_templates",
        type=int,
        default=None,
        help="Limit number of templates per file (for testing)",
    )
    opts = cli.parse_args()

    source_dir = Path(opts.input_dir).expanduser().resolve()
    target_dir = Path(opts.output_dir).expanduser().resolve()
    target_dir.mkdir(parents=True, exist_ok=True)

    # Gather both supported extensions into one alphabetically ordered list.
    storage_files = sorted([*source_dir.glob("*.hdf5"), *source_dir.glob("*.h5")])
    if len(storage_files) == 0:
        print(f"⚠️  No .h5/.hdf5 files found in {source_dir}")
        return

    for storage_file in storage_files:
        # Top-level boundary: report the problem and move on to the next file.
        try:
            process_hdf5_file(storage_file, target_dir, opts.max_templates)
        except Exception as exc:
            print(f"❌ Failed to process {storage_file.name}: {exc}")


# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()