#
# This file is part of pyspex
#
# https://github.com/rmvanhees/pyspex.git
#
# Copyright (c) 2019-2025 SRON
#    All Rights Reserved
#
# License:  BSD-3-Clause
"""Collection of routines to access EGSE data generated by ITOS."""

from __future__ import annotations

__all__ = ["add_egse_data", "create_egse_db"]

from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING

import h5py
import numpy as np

# pylint: disable=no-name-in-module
from netCDF4 import Dataset

if TYPE_CHECKING:
    import argparse

# - global parameters ------------------------------
# name of the netCDF4 file (inside the EGSE directory) holding the EGSE database
DB_EGSE: str = "egse_db_itos.nc"

# light-source status strings (as bytes) mapped onto a one-byte code;
# 255 flags a missing or unrecognized status
LDLS_DICT: dict[bytes, int] = {
    b"UNPLUGGED": 0,
    b"Controller Fault": 1,
    b"Idle": 2,
    b"Laser ON": 3,
    b"Lamp ON": 4,
    b"MISSING": 255,
}

# shutter-position strings (as bytes) mapped onto a one-byte code
SHUTTER_DICT: dict[bytes, int] = {
    b"CLOSED": 0,
    b"OPEN": 1,
    b"PARTIAL": 255,
}


# - local functions --------------------------------
def init_gse_data(fid: Dataset) -> Dataset:
    """Initialize the netCDF4 group 'gse_data' for EGSE/OGSE information.

    The measurement set-up is derived from the global attribute
    ``product_name`` of `fid`, which is split on underscores.  When the name
    starts with ``SPX1`` the first two tokens are skipped; the next token is
    the measurement type (e.g. ``POLARIZED-BKG``).  Tokens ``act<deg>``,
    ``alt<deg>``, ``pol<deg>`` (except ``polcal``) and ``glass<deg>`` provide
    the rotation angles.

    Parameters
    ----------
    fid : netCDF4.Dataset
        Level-1A product opened for writing.

    Returns
    -------
    netCDF4.Group
        The newly created group '/gse_data' (netCDF4.Group derives from
        netCDF4.Dataset).

    Raises
    ------
    ValueError
        If an angle token does not contain a valid floating-point number.
    """
    # investigate filename
    parts = fid.product_name.split("_")
    if len(parts) > 2 and parts[0] == "SPX1":
        parts = parts[2:]
    msmt_fields = parts[0].split("-")
    background = "BKG" in msmt_fields

    def angles(prefix: str, exclude: tuple[str, ...] = ()) -> list[float]:
        """Return the angles encoded in all tokens starting with `prefix`."""
        return [
            float(x.replace(prefix, ""))
            for x in parts
            if x.startswith(prefix) and x not in exclude
        ]

    act_angle = angles("act")
    alt_angle = angles("alt")
    pol_angle = angles("pol", exclude=("polcal",))
    gp1_angle = angles("glass")
    gp1_offs = 5.634375  # offset added to the 'glass' angle to obtain GP1_angle
    # gp2_offs = 5.09625

    # determine viewport: default 0, when all viewports are illuminated;
    # otherwise the flag 2**i of the viewport within 6 deg of the ALT angle
    if alt_angle:
        vp_angle = np.array([-50.0, -20, 0, 20, 50])
        vp_diff = np.abs(vp_angle - alt_angle[0])
        viewport = 0 if vp_diff.min() > 6 else 2 ** np.argmin(vp_diff)
    else:
        viewport = 0

    gid = fid.createGroup("/gse_data")
    dset = gid.createVariable("viewport", "u1")
    dset.long_name = "viewport status"
    dset.standard_name = "status_flag"
    dset.valid_range = np.array([0, 16], dtype="u1")
    dset.flag_values = np.array([0, 1, 2, 4, 8, 16], dtype="u1")
    dset.flag_meanings = "ALL -50deg -20deg 0deg +20deg +50deg"
    dset[:] = viewport

    # gid.FOV_begin = np.nan
    # gid.FOV_end = np.nan
    gid.ACT_rotationAngle = act_angle[0] if act_angle else np.nan
    gid.ALT_rotationAngle = alt_angle[0] if alt_angle else np.nan
    # gid.ACT_illumination = np.nan
    # gid.ALT_illumination = np.nan
    gid.AoLP = pol_angle[0] if pol_angle else 0.0
    if not background and msmt_fields[0] in ("POLARIZED", "POLARIMETRIC"):
        gid.DoLP = 1.0
        if gp1_angle:
            gid.GP1_angle = gp1_angle[0] + gp1_offs
    else:
        gid.DoLP = 0.0

    return gid


def byte_to_timestamp(str_date: str) -> float:
    """Convert a byte-string to a timestamp."""
    return datetime.strptime(
        str_date.strip() + "00+00:00", "%Y%m%dT%H%M%S.%f%z"
    ).timestamp()


def egse_dtype() -> np.dtype:
    """Return the numpy structured dtype used to store one EGSE record.

    Layout, in order: two float64 time stamps, fifteen uint8 status flags,
    two uint32 housekeeping fields and twenty-one float32 readings.  The
    dtype is created without ``align=True``, so no padding is inserted.
    """
    status_flags = (
        "LDLS_STATUS",
        "POLARIZER_MOVING",
        "SHUTTER_STAGE_MOVING",
        "STST_POLARIZER_MOVING",
        "ALT_STAGE_MOVING",
        "ACT_STAGE_MOVING",
        "GP_0_MOVING",
        "GP_1_MOVING",
        "C_FLEX-405nm",
        "C_FLEX-457nm",
        "C_FLEX-515nm",
        "C_FLEX-561nm",
        "C_FLEX-660nm",
        "COBOLT-785nm",
        "CRYSTA_STATUS",  # should be C_FLEX-732nm!
    )
    readings = (
        "SURV_TST1",
        "SURV_TST2",
        *(f"AI_{num:02d}" for num in range(2, 10)),
        "V_OUT_ICU",
        "I_OUT_ICU",
        "V_OUT_HTR",
        "I_OUT_HTR",
        "POLARIZER",
        "SHUTTER_STAGE",
        "STST_POLARIZER",
        "ALT_ANGLE",
        "ACT_ANGLE",
        "GP_0_ANGLE",
        "GP_1_ANGLE",
    )

    layout = [("ITOS_time", "f8"), ("NOMHK_packets_time", "f8")]
    layout.extend((flag, "u1") for flag in status_flags)
    layout.extend([("HK_TS1", "u4"), ("HK_TS2", "u4")])
    layout.extend((reading, "f4") for reading in readings)
    return np.dtype(layout)


def egse_units() -> tuple[str, ...]:
    """Return the units of the EGSE record fields.

    The tuple follows the field order of ``egse_dtype()`` (40 entries) and is
    written as the ``units`` attribute next to the list of field names.
    """
    return (
        2 * ("s",)  # ITOS_time, NOMHK_packets_time
        + 15 * ("1",)  # LDLS_STATUS ... CRYSTA_STATUS (dimensionless flags)
        + 2 * ("Ohm",)  # HK_TS1, HK_TS2
        + 11 * ("V",)  # SURV_TST1, SURV_TST2, AI_02 ... AI_09, V_OUT_ICU
        + ("A", "V", "A")  # I_OUT_ICU, V_OUT_HTR, I_OUT_HTR
        + 7 * ("deg",)  # POLARIZER ... GP_1_ANGLE
    )


def read_egse(egse_file: str, verbose: bool = False) -> tuple[np.ndarray, list[str]]:
    """Read EGSE data (tab separated values) to numpy compound array."""
    with open(egse_file, encoding="ascii") as fid:
        line = None
        names = []
        units = []
        while not line:
            line = fid.readline().strip()
            fields = line.split("\t")
            for field in fields:
                if field == "":
                    continue
                res = field.strip().split(" [")
                names.append(res[0].replace(" nm", "nm").replace(" ", "_"))
                if len(res) == 2:
                    units.append(res[1].replace("[", "").replace("]", ""))
                else:
                    units.append("1")

        if len(names) in (35, 36):
            # define dtype of the data
            formats = (
                ("f8",)
                + 14 * ("f4",)
                + ("u1",)
                + 2 * ("i4",)
                + (
                    "f4",
                    "u1",
                )
                + 2 * ("u1",)
                + 3
                * (
                    "f4",
                    "u1",
                )
                + 7 * ("u1",)
            )
        else:
            # define dtype of the data
            formats = (
                ("f8",)
                + 14 * ("f4",)
                + ("u1",)
                + 2 * ("i4",)
                + (
                    "f4",
                    "u1",
                )
                + 2 * ("u1",)
                + 5
                * (
                    "f4",
                    "u1",
                )
                + 7 * ("u1",)
            )
        usecols = list(range(len(names)))

        if "NOMHK_packets_time" in names:
            formats = ("f8", *formats)
            convertors = {
                0: byte_to_timestamp,
                1: byte_to_timestamp,
                16: lambda s: LDLS_DICT.get(s.strip(), 255),
                21: lambda s: SHUTTER_DICT.get(s.strip(), 255),
            }
        else:
            convertors = {
                0: byte_to_timestamp,
                15: lambda s: LDLS_DICT.get(s.strip(), 255),
                20: lambda s: SHUTTER_DICT.get(s.strip(), 255),
            }
        if verbose:
            print(len(names), names)
            print(len(units), units)
            print(len(formats), formats)

        if not len(names) == len(units) == len(formats):
            raise RuntimeError("Size of names, units or formats are not equal")

        data = np.loadtxt(
            fid,
            delimiter="\t",
            converters=convertors,
            usecols=usecols,
            dtype={"names": names, "formats": formats},
        )

    egse = np.empty(data.size, dtype=egse_dtype())
    egse["NOMHK_packets_time"][:] = np.nan
    egse["GP_0_ANGLE"][:] = np.nan
    egse["GP_1_ANGLE"][:] = np.nan
    egse["GP_0_MOVING"][:] = 255
    egse["GP_1_MOVING"][:] = 255
    for name in data.dtype.names:
        egse[name][:] = data[name][:]

    return egse, units


# ---------- CREATE EGSE DATABASE ----------
def create_egse_db(args: argparse.Namespace) -> None:
    """Write EGSE data to a netCDF4 (HDF5-based) database.

    All files in ``args.file_list`` are read with `read_egse`, combined and
    sorted on ITOS time, then written to ``args.egse_dir / DB_EGSE``.  When
    any file cannot be read, or no records are found, a message is printed
    and no database is written.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``file_list``, ``verbose`` and ``egse_dir`` (a Path).
    """
    records = []
    for egse_file in args.file_list:
        try:
            records.append(read_egse(egse_file, verbose=args.verbose)[0])
        except RuntimeError as exc:
            print(f"[WARNING] failed to read EGSE file {egse_file}: {exc}")
            return

    if not records or sum(rec.size for rec in records) == 0:
        print("[WARNING] no EGSE records found, database not written")
        return
    # concatenate once instead of growing the array file by file
    egse = np.concatenate(records)

    with Dataset(args.egse_dir / DB_EGSE, "w", format="NETCDF4") as fid:
        fid.input_files = [Path(x).name for x in args.file_list]
        fid.creation_date = datetime.now(UTC).isoformat(timespec="seconds")

        _ = fid.createEnumType(
            "u1",
            "ldls_t",
            {k.replace(b" ", b"_").upper(): v for k, v in LDLS_DICT.items()},
        )
        _ = fid.createEnumType(
            "u1", "shutter_t", {k.upper(): v for k, v in SHUTTER_DICT.items()}
        )
        _ = fid.createDimension("time", egse.size)
        # netCDF rejects chunks larger than a fixed dimension ("Bad chunk sizes")
        dset = fid.createVariable(
            "time", "f8", ("time",), chunksizes=(min(256, egse.size),)
        )
        time_key = "ITOS_time" if "ITOS_time" in egse.dtype.names else "time"
        indx = np.argsort(egse[time_key], kind="stable")
        dset[:] = egse[time_key][indx]

        egse_t = fid.createCompoundType(egse.dtype, "egse_dtype")
        dset = fid.createVariable(
            "egse", egse_t, ("time",), chunksizes=(min(64, egse.size),)
        )
        dset.long_name = "EGSE settings"
        dset.fields = np.bytes_(egse.dtype.names)
        dset.units = np.bytes_(egse_units())
        dset.comment = (
            "DIG_IN_00 is of enumType ldls_t; SHUTTER_STATUS is of enumType shutter_t"
        )
        dset[:] = egse[indx]


# ----- SELECT EGSE DATA FROM DATABASE AND ADD TO L1A PRODUCT -----
def _as_str(value: str | bytes) -> str:
    """Return an HDF5 string attribute value as str (bytes decoded as ASCII)."""
    return value.decode("ascii") if isinstance(value, bytes) else str(value)


def _smart_average(data: np.ndarray, thres_range: float = 0.1) -> float:
    """Return a representative value of `data`.

    Constant data yields its first element.  Data spanning less than
    `thres_range` yields the mean; otherwise the median is used, which is
    less sensitive to outliers.
    """
    val_range = np.abs(data.max() - data.min())
    if val_range == 0:
        return float(data[0])

    if val_range < thres_range:
        return np.mean(data)

    return np.median(data)


def add_egse_data(args: argparse.Namespace) -> None:
    """Write EGSE records of a measurement to a level-1A product.

    The measurement start is taken from the timestamp at the end of the
    first input file name (attribute ``input_files``), truncated to whole
    seconds.  Its duration is taken from ``time_coverage_start`` and
    ``time_coverage_end``.  EGSE records within that window are copied from
    ``args.egse_dir / DB_EGSE`` into the group '/gse_data' of
    ``args.l1a_file``.  The group is created when it does not exist yet.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``l1a_file``, ``egse_dir`` (a Path) and ``verbose``.

    Raises
    ------
    RuntimeError
        If the EGSE database holds no records within the measurement window.
    """
    # determine duration of the measurement (ITOS clock)
    with h5py.File(args.l1a_file, "r") as fid:
        # pylint: disable=unsubscriptable-object
        res = fid.attrs["input_files"]
        # a scalar attribute (bytes or str) or an array of file names
        first_file = res if isinstance(res, (bytes, str)) else res[0]
        # removesuffix: rstrip("_hk") would strip any trailing '_', 'h', 'k'
        input_file = Path(_as_str(first_file)).stem.removesuffix("_hk")
        # pylint: disable=no-member
        coverage_start = _as_str(fid.attrs["time_coverage_start"])
        coverage_end = _as_str(fid.attrs["time_coverage_end"])
    if args.verbose:
        print("L1A time-coverage:", coverage_start, coverage_end)
    duration = np.ceil(
        (
            datetime.fromisoformat(coverage_end)
            - datetime.fromisoformat(coverage_start)
        ).total_seconds()
    )

    # use the timestamp in the filename to correct ICU time
    date_str = input_file.split("_")[-1] + "+00:00"
    msmt_start = datetime.strptime(date_str, "%Y%m%dT%H%M%S.%f%z")
    msmt_start = msmt_start.replace(microsecond=0)
    msmt_stop = msmt_start + timedelta(seconds=int(duration))
    if args.verbose:
        print("Corrected time-coverage:", msmt_start.timestamp(), msmt_stop.timestamp())

    # open EGSE database
    with Dataset(args.egse_dir / DB_EGSE, "r") as fid:
        # getdata works for both masked and plain arrays
        egse_time = np.ma.getdata(fid["time"][:])
        if args.verbose:
            print("EGSE time-coverage:", egse_time.min(), egse_time.max())
        mask = (egse_time >= msmt_start.timestamp()) & (
            egse_time <= msmt_stop.timestamp()
        )
        if not mask.any():
            raise RuntimeError("no EGSE data found")

        egse_time = egse_time[mask]
        egse_data = fid["egse"][mask]

    # update Level-1A product with EGSE information
    with Dataset(args.l1a_file, "r+") as fid:
        # netCDF4 keys Dataset.groups by group name, without a leading slash
        if "gse_data" in fid.groups:
            gid = fid.groups["gse_data"]
        else:
            gid = init_gse_data(fid)
        _ = gid.createEnumType(
            "u1",
            "ldls_t",
            {k.replace(b" ", b"_").upper(): v for k, v in LDLS_DICT.items()},
        )
        _ = gid.createEnumType(
            "u1", "shutter_t", {k.upper(): v for k, v in SHUTTER_DICT.items()}
        )
        _ = gid.createDimension("time", egse_data.size)
        dset = gid.createVariable("time", "f8", ("time",))
        dset[:] = egse_time

        egse_t = gid.createCompoundType(egse_data.dtype, "egse_dtype")
        dset = gid.createVariable("egse", egse_t, ("time",))
        dset.long_name = "EGSE settings"
        dset.fields = np.bytes_(egse_data.dtype.names)
        dset.units = np.bytes_(egse_units())
        dset.comment = (
            "DIG_IN_00 is of enumType ldls_t; SHUTTER_STATUS is of enumType shutter_t"
        )
        dset[:] = egse_data

        gid.ACT_rotationAngle = _smart_average(egse_data["ACT_ANGLE"])
        gid.ALT_rotationAngle = _smart_average(egse_data["ALT_ANGLE"])
        gid.AoLP = _smart_average(egse_data["POLARIZER"])
