
# generate_pheno_plink_fast.py  (explicit DO/NON-DO via panel_type)
from __future__ import annotations
import os
import io
import math
import logging
from typing import Dict, List, Union

import numpy as np
import pandas as pd

from plinkformatter.plink_utils import (
    generate_bed_bim_fam,
    calculate_kinship_from_pedmap,
    rewrite_pheno_ids_from_fam,
)
from plinkformatter.generate_pheno_plink import extract_pheno_measure


def _norm_id(x) -> str:
    """
    Normalize IDs used to join DO PED V1 with pheno['animal_id'].

    - Strip whitespace.
    - If numeric, collapse "123", "123.0", "123.000" → "123".
    - Leave non-numeric IDs (e.g. "DO-123") untouched apart from stripping.
    """
    s = str(x).strip()
    if s == "":
        return s

    # Try to canonicalize numeric IDs first
    try:
        f = float(s)
        if f.is_integer():
            return str(int(f))
        # Non-integer numeric IDs: avoid scientific notation
        return ("%.10g" % f).rstrip()
    except Exception:
        # Not numeric: fall back to simple cleanup
        if s.endswith(".0"):
            s = s[:-2]
        return s


# ----------------------------- NON-DO PATH ----------------------------- #
def _coerce_float(x) -> float:
    """Return ``float(x)``, or NaN when *x* cannot be converted."""
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        return np.nan


def _ped_strain_offsets(ped_file: str) -> Dict[str, int]:
    """Map each cleaned first-column ID (strain) of *ped_file* to the byte offset of its first line."""
    offsets: Dict[str, int] = {}
    with open(ped_file, "rb") as f:
        while True:
            pos = f.tell()
            line = f.readline()
            if not line:
                break
            first_tab = line.find(b"\t")
            if first_tab > 0:
                fid_bytes = line[:first_tab]
            else:
                tokens = line.split()
                if not tokens:
                    # Blank / whitespace-only line (e.g. a trailing empty line): nothing to index.
                    continue
                fid_bytes = tokens[0]
            name = fid_bytes.decode(errors="replace").replace("?", "").replace(" ", "")
            if name and name not in offsets:
                offsets[name] = pos
    return offsets


def _ped_genotype_tail(raw: str) -> str:
    """
    Render the genotype columns (7th onward) of one PED line as " a1 a2 a1 a2 ...".

    Tab-separated lines are expected to hold one "a1 a2" pair per column. Lines without a
    usable tab layout are split on any whitespace, in which case each token is one allele.

    Raises:
      ValueError: fewer than 7 columns, an odd allele count, or an unsplittable pair.
    """
    parts = raw.split("\t")
    one_allele_per_token = False
    if len(parts) <= 6:
        parts = raw.split()
        # Whitespace splitting can never yield "a1 a2" tokens, so each token is a single allele.
        one_allele_per_token = True
    if len(parts) < 7:
        raise ValueError("Malformed PED: need >=7 columns (6 meta + genotypes)")

    genotypes = parts[6:]
    if one_allele_per_token:
        if len(genotypes) % 2:
            raise ValueError(f"Malformed PED: odd number of allele columns ({len(genotypes)})")
        return "".join(f" {allele}" for allele in genotypes)

    chunks = []
    for gp in genotypes:
        a_b = gp.split(" ")
        if len(a_b) != 2:
            a_b = gp.split()
            if len(a_b) != 2:
                raise ValueError(f"Genotype pair not splitable into two alleles: {gp!r}")
        chunks.append(f" {a_b[0]} {a_b[1]}")
    return "".join(chunks)


def generate_pheno_plink_fast_non_do(
    ped_file: str,
    map_file: str,
    pheno: pd.DataFrame,
    outdir: str,
) -> pd.DataFrame:
    """
    NON-DO behavior (matches Hao; this is your validated path):
      - replicate-level rows, keyed by STRAIN
      - FID = IID = STRAIN
      - PID/MID = 0; SEX=2 if 'f' else 1; PHE=zscore or -9
      - PHENO: 'FID IID zscore value'
      - MAP: '.' rsids -> 'chr_bp'

    For each (measnum, sex) group it writes <measnum>.<sex>.ped/.map/.pheno into *outdir*.
    The reference PED is indexed once by byte offset, then each needed strain line is read on
    demand, so the PED is never fully loaded.

    Blank lines in the reference PED are ignored. Both PED layouts are accepted: tab-separated
    lines with "a1 a2" pair columns, and fully whitespace-separated lines with one allele per
    token.

    Returns:
      The filtered/deduplicated pheno rows that were written (empty if none).
    Raises:
      ValueError: missing required pheno columns or a malformed reference PED line.
    """
    os.makedirs(outdir, exist_ok=True)
    if pheno is None or pheno.empty:
        return pd.DataFrame()

    need = ("strain", "sex", "measnum", "value")
    missing = [c for c in need if c not in pheno.columns]
    if missing:
        raise ValueError(f"pheno missing required columns: {missing} (need {list(need)})")

    ph = pheno.copy()
    ph["strain"] = ph["strain"].astype(str).str.replace(" ", "", regex=False)
    ph = ph[ph["sex"].isin(["f", "m"])].copy()
    if ph.empty:
        return ph

    # MAP sanitize: unnamed ('.') variants get a 'chr_bp' identifier.
    map_df = pd.read_csv(map_file, header=None, sep=r"\s+", engine="python")
    map_df[1] = np.where(
        map_df[1].astype(str) == ".",
        map_df[0].astype(str) + "_" + map_df[3].astype(str),
        map_df[1].astype(str),
    )

    # Ensure zscore column
    if "zscore" not in ph.columns:
        logging.info("[NON-DO] 'zscore' missing; filling NaN (becomes -9).")
        ph["zscore"] = np.nan

    ped_offsets = _ped_strain_offsets(ped_file)
    ph = ph[ph["strain"].isin(set(ped_offsets))].reset_index(drop=True)
    if ph.empty:
        return ph

    # Conservative de-duplication (kept from the working version).
    # NOTE(review): strain/sex/measnum are required columns, so dedup_keys is never empty and
    # the signature-based fallback is unreachable; without an animal_id column this keeps only
    # the first replicate per (strain, sex, measnum). Confirm that is intended.
    dedup_keys = [c for c in ["strain", "sex", "measnum", "animal_id"] if c in ph.columns]
    if dedup_keys:
        ph = ph.drop_duplicates(subset=dedup_keys, keep="first")
    else:
        sig_cols = [c for c in ["strain", "sex", "measnum", "zscore", "value"] if c in ph.columns]
        ph = ph.drop_duplicates(subset=sig_cols, keep="first")

    # One read handle for the whole run; strain lines are fetched by seeking to their offsets.
    with open(ped_file, "rb") as ped_src:
        for (measnum, sex), df in ph.groupby(["measnum", "sex"], sort=False):
            measnum = int(measnum)
            sex = str(sex)
            sex_code = "2" if sex == "f" else "1"

            map_out = os.path.join(outdir, f"{measnum}.{sex}.map")
            map_df.to_csv(map_out, sep="\t", index=False, header=False)

            ped_out = os.path.join(outdir, f"{measnum}.{sex}.ped")
            phe_out = os.path.join(outdir, f"{measnum}.{sex}.pheno")

            df = df.sort_values(["strain"], kind="stable").reset_index(drop=True)

            with open(ped_out, "w", encoding="utf-8") as f_ped, open(phe_out, "w", encoding="utf-8") as f_ph:
                for strain, sdf in df.groupby("strain", sort=False):
                    ped_src.seek(ped_offsets[strain])
                    # Strip CR as well so CRLF files do not leak '\r' into the last allele.
                    raw = ped_src.readline().decode(errors="replace").rstrip("\r\n")
                    # Genotypes are identical for every replicate of a strain: render once.
                    geno_tail = _ped_genotype_tail(raw)

                    for z_raw, v_raw in zip(sdf["zscore"].tolist(), sdf["value"].tolist()):
                        z = _coerce_float(z_raw)
                        v = _coerce_float(v_raw)
                        phe = f"{z}" if math.isfinite(z) else "-9"
                        f_ped.write(f"{strain} {strain} 0 0 {sex_code} {phe}{geno_tail}\n")
                        f_ph.write(
                            f"{strain} {strain} "
                            f"{(z if math.isfinite(z) else -9)} "
                            f"{(v if math.isfinite(v) else -9)}\n"
                        )

            logging.info(
                "[generate_pheno_plink_fast:NON-DO] wrote %s, %s, %s", ped_out, map_out, phe_out
            )

    return ph


# ------------------------------- DO PATH ------------------------------- #
def generate_pheno_plink_fast_do(
    ped_file: str,
    map_file: str,
    pheno: pd.DataFrame,
    outdir: str,
) -> pd.DataFrame:
    """
    DO behavior (mirror Hao's 'J:DO' branch, but streaming the PED):

      - Find animals where pheno.animal_id overlaps PED V1 (after ``_norm_id`` normalization).
      - For each overlapping PED row and each measnum recorded for that animal, attach
        sex/zscore/value from pheno. The first pheno row per (animal_id, measnum) wins, like
        R's match() applied per measure.
      - For each (measnum, sex), write:
          * PED: V1..last_col (meta + genotypes), with V5 = 2 (f) / 1 (m), V6 = zscore or -9
          * MAP: full map with '.' -> 'chr_bp'
          * PHENO: 'V1 V2 zscore value' (non-finite -> -9)

      If pheno contains no 'J:DO' strain, the call is delegated to
      generate_pheno_plink_fast_non_do with the original pheno.

      The PED is scanned once, in file order, and never loaded into memory. All output handles
      are closed even if an error occurs mid-stream.

    Returns:
      The pheno rows (with an added 'animal_id_norm' column) that were written to at least one
      output file. The result is empty when no PED row matched.
    Raises:
      ValueError: if pheno lacks a required column.
    """
    from contextlib import ExitStack

    os.makedirs(outdir, exist_ok=True)
    if pheno is None or pheno.empty:
        return pd.DataFrame()

    # Required columns
    need = ("strain", "sex", "measnum", "value", "animal_id")
    missing = [c for c in need if c not in pheno.columns]
    if missing:
        raise ValueError(f"[DO] pheno missing required columns: {missing} (need {list(need)})")

    ph = pheno.copy()
    # Match Hao: strip spaces from strain
    ph["strain"] = ph["strain"].astype(str).str.replace(" ", "", regex=False)
    # Only keep f/m
    ph = ph[ph["sex"].isin(["f", "m"])].copy()
    if ph.empty:
        return ph

    # Hao's DO branch is triggered when strain == "J:DO"
    if "J:DO" not in set(ph["strain"]):
        logging.warning("[DO] No 'J:DO' strain in pheno; falling back to NON-DO behavior.")
        # IMPORTANT: NON-DO path is already validated and should not be touched
        return generate_pheno_plink_fast_non_do(ped_file, map_file, pheno, outdir)

    # MAP sanitize (same as NON-DO / Hao)
    map_df = pd.read_csv(map_file, header=None, sep=r"\s+", engine="python")
    map_df[1] = np.where(
        map_df[1].astype(str) == ".",
        map_df[0].astype(str) + "_" + map_df[3].astype(str),
        map_df[1].astype(str),
    )

    # Ensure zscore column exists
    if "zscore" not in ph.columns:
        logging.info("[DO] 'zscore' missing; filling NaN (becomes -9).")
        ph["zscore"] = np.nan

    def to_float(x) -> float:
        # Non-convertible values become NaN (written as -9 downstream).
        try:
            return float(x)
        except (TypeError, ValueError, OverflowError):
            return np.nan

    # ------------------------------------------------------------------
    # Per-animal lookup, mimicking Hao's
    #   pheno %>% filter(animal_id %in% overlap.id) %>% slice(match(ped V1, animal_id))
    # but keyed per (animal_id, measnum). Keying only by animal_id would drop every
    # measure after the first when several measnums are requested together.
    # aid -> [(measnum, sex, zscore, value, row_position), ...]
    # ------------------------------------------------------------------
    ph["animal_id_norm"] = ph["animal_id"].map(_norm_id)

    per_id: Dict[str, List[tuple]] = {}
    seen_pairs: set = set()
    rows = zip(
        ph["animal_id_norm"].tolist(),
        ph["measnum"].tolist(),
        ph["sex"].tolist(),
        ph["zscore"].tolist(),
        ph["value"].tolist(),
    )
    for pos, (aid, meas_raw, sex_raw, z_raw, v_raw) in enumerate(rows):
        if not aid:
            continue
        try:
            meas = int(meas_raw)
        except (TypeError, ValueError, OverflowError):
            # Unusable measnum: skip this row.
            continue
        if (aid, meas) in seen_pairs:
            # First occurrence per (animal, measnum) wins (match-like behavior).
            continue
        seen_pairs.add((aid, meas))
        per_id.setdefault(aid, []).append(
            (meas, str(sex_raw), to_float(z_raw), to_float(v_raw), pos)
        )

    if not per_id:
        logging.info("[DO] No usable animal IDs (with a valid measnum) in pheno; nothing to do.")
        return ph.iloc[0:0].copy()

    # ------------------------------------------------------------------
    # Stream PED once, in file order. Each overlapping V1 row is routed into every
    # (measnum, sex) output recorded for that animal, with V5/V6 updated. The MAP is written
    # once, when an output pair is first opened.
    # ------------------------------------------------------------------
    handles: Dict[tuple, tuple] = {}  # (meas, sex) -> (ped_fh, phe_fh)
    used_positions: set = set()  # positions in ph of rows actually written

    with ExitStack() as stack:
        fped = stack.enter_context(open(ped_file, "r", encoding="utf-8", errors="replace"))
        for raw in fped:
            # Split on any whitespace, like PLINK and fread() do
            parts = raw.split()
            if len(parts) < 7:
                # blank or malformed PED line, ignore (Hao's fread would drop these)
                continue

            V1, V2 = parts[0], parts[1]
            targets = per_id.get(_norm_id(V1))
            if not targets:
                # This PED animal is not in the pheno set
                continue

            # Genotype tokens (V7+) are preserved verbatim.
            genotypes = " ".join(parts[6:])

            for measnum, sex_label, z, v, pos in targets:
                key = (measnum, sex_label)

                # Lazily open PED/PHENO and write MAP for this (measnum, sex)
                if key not in handles:
                    ped_out = os.path.join(outdir, f"{measnum}.{sex_label}.ped")
                    phe_out = os.path.join(outdir, f"{measnum}.{sex_label}.pheno")
                    map_out = os.path.join(outdir, f"{measnum}.{sex_label}.map")
                    handles[key] = (
                        stack.enter_context(open(ped_out, "w", encoding="utf-8")),
                        stack.enter_context(open(phe_out, "w", encoding="utf-8")),
                    )
                    map_df.to_csv(map_out, sep="\t", header=False, index=False)
                    logging.info("[generate_pheno_plink_fast:DO] wrote %s", map_out)

                ped_fh, phe_fh = handles[key]

                # V5 = 2 if f, 1 if m; V6 = zscore, with missing -> -9
                meta = parts[:6]
                if sex_label == "f":
                    meta[4] = "2"
                elif sex_label == "m":
                    meta[4] = "1"
                # else: leave whatever is there (sex was filtered to f/m above)

                if math.isfinite(z):
                    meta[5] = f"{z}"
                    z_out = z
                else:
                    meta[5] = "-9"
                    z_out = -9.0
                v_out = v if math.isfinite(v) else -9.0

                ped_fh.write(" ".join(meta) + " " + genotypes + "\n")
                # PHENO: FID = V1, IID = V2, zscore, value
                phe_fh.write(f"{V1} {V2} {z_out} {v_out}\n")
                used_positions.add(pos)

    # Log what we actually produced
    if not handles:
        logging.info("[generate_pheno_plink_fast:DO] No DO PED/PHENO files were written.")
    else:
        for (measnum, sex_label) in sorted(handles.keys()):
            logging.info(
                "[generate_pheno_plink_fast:DO] finished measnum=%s sex=%s", measnum, sex_label
            )

    return ph.iloc[sorted(used_positions)].copy()


# ------------------------------- WRAPPER ------------------------------- #
def generate_pheno_plink_fast(
    ped_file: str,
    map_file: str,
    pheno: pd.DataFrame,
    outdir: str,
    ncore: int = 1,
    *,
    panel_type: str = "NON_DO",   # <-- explicit control
) -> pd.DataFrame:
    """
    Dispatch to the DO or NON-DO implementation based on an explicit panel_type.

    Args:
      ped_file, map_file: reference PLINK PED/MAP paths.
      pheno: phenotype rows; see the DO / NON-DO writers for the required columns.
      outdir: output directory (created if missing).
      ncore: accepted for API compatibility; neither writer currently uses it.
      panel_type: "DO" or "NON_DO" (default NON_DO to preserve current behavior). Matching is
        case-insensitive, and '-' or ' ' are accepted in place of '_' (e.g. "non-do").
        None means "NON_DO".

    Returns:
      The DataFrame returned by the selected writer, or an empty DataFrame when pheno is
      None/empty.
    Raises:
      ValueError: if panel_type is not recognized. This is checked before anything else, so
        bad configuration is reported even for empty pheno.
    """
    pt = str(panel_type or "NON_DO").strip().upper().replace("-", "_").replace(" ", "_")
    if pt not in ("DO", "NON_DO"):
        raise ValueError(f"panel_type must be 'DO' or 'NON_DO', got {panel_type!r}")

    if pheno is None or pheno.empty:
        os.makedirs(outdir, exist_ok=True)
        return pd.DataFrame()

    if pt == "DO":
        logging.info("[generate_pheno_plink_fast] using DO panel_type")
        return generate_pheno_plink_fast_do(ped_file, map_file, pheno, outdir)

    logging.info("[generate_pheno_plink_fast] using NON-DO panel_type")
    return generate_pheno_plink_fast_non_do(ped_file, map_file, pheno, outdir)


# ----------------------------- Orchestrator ---------------------------- #
def fast_prepare_pylmm_inputs(
    ped_file: str,
    map_file: str,
    measure_id_directory: str,
    measure_ids: List,
    outdir: str,
    ncore: int,
    plink2_path: str,
    *,
    panel_type: str = "NON_DO",
    ped_pheno_field: str = "zscore",
    maf_threshold: Union[float, None] = None,
) -> None:
    """
    Orchestrate extraction and PLINK file generation (public API + panel_type):

      1) Extract phenotype rows for requested measure_ids.
      2) generate_pheno_plink_fast(..., panel_type=...) -> writes <meas>.<sex>.ped/.map/.pheno
      3) PLINK2 from --pedmap -> BED/BIM/FAM   (geno 0.1, mind 0.1)
      4) Rewrite PHENO IIDs from .fam (exact FID/IID order + suffixes)
      5) Kinship from --pedmap (square .rel)

    Steps 3-5 run once per (base measnum, sex) whose .ped and .map exist. The base measnum is
    the part of a measure_id before the first '_'. measure_ids sharing a base are processed
    only once.

    Args:
      ncore: forwarded to generate_pheno_plink_fast, which does not currently use it.
      ped_pheno_field: accepted for API compatibility; not used here. Both writers always put
        the zscore in PED column 6.
      maf_threshold: forwarded to generate_bed_bim_fam.
    """
    os.makedirs(outdir, exist_ok=True)

    pheno = extract_pheno_measure(measure_id_directory, measure_ids)
    if pheno is None or pheno.empty:
        logging.info("[fast_prepare_pylmm_inputs] no phenotype rows extracted; nothing to do.")
        return

    used = generate_pheno_plink_fast(
        ped_file=ped_file,
        map_file=map_file,
        pheno=pheno,
        outdir=outdir,
        ncore=ncore,
        panel_type=panel_type,
    )
    if used is None or used.empty:
        logging.info("[fast_prepare_pylmm_inputs] no usable phenotypes after PED/MAP intersection; nothing to do.")
        return

    # Several measure_ids may share a base measnum (e.g. "123_a", "123_b"). The writers emit one
    # <base>.<sex>.* set per measnum, so avoid re-running PLINK / PHENO rewrite / kinship on it.
    base_ids = list(dict.fromkeys(str(measure_id).split("_", 1)[0] for measure_id in measure_ids))

    for base_id in base_ids:
        for sex in ("f", "m"):
            ped_path   = os.path.join(outdir, f"{base_id}.{sex}.ped")
            map_path   = os.path.join(outdir, f"{base_id}.{sex}.map")
            out_prefix = os.path.join(outdir, f"{base_id}.{sex}")

            if not (os.path.exists(ped_path) and os.path.exists(map_path)):
                continue

            logging.info("[fast_prepare_pylmm_inputs] make BED/BIM/FAM for %s.%s", base_id, sex)
            generate_bed_bim_fam(
                plink2_path=plink2_path,
                ped_file=ped_path,
                map_file=map_path,
                output_prefix=out_prefix,
                relax_mind_threshold=False,
                maf_threshold=maf_threshold,
                sample_keep_path=None,
                autosomes_only=False,
            )

            # Align PHENO IIDs to FAM IIDs (strict 1:1; no dedup)
            fam_path   = f"{out_prefix}.fam"
            pheno_path = os.path.join(outdir, f"{base_id}.{sex}.pheno")
            rewrite_pheno_ids_from_fam(pheno_path, fam_path, pheno_path)

            logging.info(
                "[fast_prepare_pylmm_inputs] compute kinship for %s.%s (from --pedmap)", base_id, sex
            )
            calculate_kinship_from_pedmap(
                plink2_path=plink2_path,
                pedmap_prefix=out_prefix,
                kin_prefix=os.path.join(outdir, f"{base_id}.{sex}.kin"),
            )


