#!/usr/bin/env python3
"""
Command-line tool to extract parkeys from JWST CRDS .rmap files,
find .rmap files with specific parkeys, or find parkeys by column
name (e.g., as shown on jwst-crds.stsci.edu).

Usage examples:
  # Extract parkeys for all .rmap files in a .pmap
  crds_parkey_tool /path/mappings/jwst/jwst_1234.pmap

  # Extract parkeys for NIRCam .rmap files
  crds_parkey_tool /path/mappings/jwst/jwst_1234.pmap --mission jwst --instrument nircam

  # Find .rmap files with all specified parkeys
  crds_parkey_tool /path/mappings/jwst/jwst_1234.pmap --mission jwst --parkeys META.INSTRUMENT.GRATING,META.EXPOSURE.TYPE --parkey-mode all

  # Find .rmap files with any specified parkey
  crds_parkey_tool /path/mappings/jwst/jwst_1234.pmap --mission jwst --parkeys META.INSTRUMENT.GRATING,META.EXPOSURE.TYPE --parkey-mode any

  # Find .rmap files and parkey names corresponding to all specific column names
  crds_parkey_tool /path/mappings/jwst/jwst_1234.pmap --mission jwst --columns GRATING,TYPE --column-mode all

  # Find .rmap files and parkey names corresponding to any specific column names
  crds_parkey_tool /path/mappings/jwst/jwst_1234.pmap --mission jwst --columns GRATING,TYPE --column-mode any

Requirements:
  - CRDS package installed (`pip install crds`)
  - Environment variables: CRDS_PATH, CRDS_SERVER_URL
"""

import argparse
import os
import sys
import crds
from crds.core import rmap

# Script version; reported together with the CRDS version by --version.
__version__ = "1.0"


def get_rmaps_from_pmap(pmap_file, instrument=None, mission='jwst'):
    """
    Retrieve .rmap files referenced by a .pmap file, optionally filtered
    by instrument.

    The .pmap is loaded with CRDS. Each .imap it references is looked up in
    ``$CRDS_PATH/mappings/<mission>`` (``~/crds_cache`` if CRDS_PATH is
    unset). Every selector value of those .imap files that ends in
    ``.rmap`` is returned as a path in that same directory.

    Parameters:
        pmap_file (str): The context file (name or path).
        instrument (str, optional): Instrument name (e.g., 'nircam', 'miri').
            If None, include all.
        mission (str, optional): 'jwst' or 'hst'. If None or empty, it is
            inferred from the .pmap file-name prefix (e.g. 'hst_0001.pmap'),
            falling back to 'jwst'.

    Returns:
        list: List of paths to .rmap files. Empty if reading the mappings
        failed; the error is reported on stderr.
    """
    if not mission:
        # Callers may pass None (e.g. the CLI without --mission). Joining
        # None into the mapping path below would raise TypeError.
        prefix = os.path.basename(str(pmap_file)).split("_", 1)[0].lower()
        mission = prefix if prefix in ("jwst", "hst") else "jwst"

    try:
        pmap = rmap.Mapping.from_file(pmap_file)
        crds_path = os.environ.get("CRDS_PATH",
                                   os.path.expanduser("~/crds_cache"))
        mapping_dir = os.path.join(crds_path, "mappings", mission)
        imap_files = list(pmap.selector.values())

        # Keep only the .imap(s) of the selected instrument, if any.
        if instrument:
            imap_prefix = f"{mission}_{instrument.lower()}_"
            imap_files = [imap_name for imap_name in imap_files
                          if imap_name.lower().startswith(imap_prefix)]

        # Collect the .rmap names referenced by each available .imap.
        rmap_files = []
        for imap_name in imap_files:
            imap_path = os.path.join(mapping_dir, imap_name)
            if not os.path.exists(imap_path):
                # Report rather than skip silently, so missing cache content
                # is not mistaken for "no reference files".
                print(f"Warning: .imap file not found: {imap_path}",
                      file=sys.stderr)
                continue
            imap = rmap.Mapping.from_file(imap_path)
            rmap_files.extend(
                os.path.join(mapping_dir, rmap_name)
                for rmap_name in imap.selector.values()
                if rmap_name.endswith(".rmap")
            )

        return rmap_files

    except Exception as e:  # best-effort: report and return nothing
        print(f"Error reading .pmap file {pmap_file}: {e}", file=sys.stderr)
        return []


def _column_name(parkey):
    """Return the column name displayed for a single parkey string."""
    fields = parkey.split(".")
    # 'META.INSTRUMENT.NAME' is shown as 'INSTRUMENT'. A bare 'NAME' has no
    # preceding field, so keep it as-is instead of raising IndexError.
    if fields[-1].upper() == "NAME" and len(fields) > 1:
        return fields[-2].upper()
    return fields[-1].upper()


def get_column_names(parkeys):
    """
    Derive column names from parkeys.

    The column name is the upper-cased last dot-separated field of the
    parkey (e.g. 'META.INSTRUMENT.FILTER' -> 'FILTER'). When that last
    field is 'NAME', the field before it is used instead
    (e.g. 'META.INSTRUMENT.NAME' -> 'INSTRUMENT'). The DATE/TIME parkeys
    used for the USEAFTER value are ignored.

    Parameters:
        parkeys (list, tuple or dict): Parkeys from an .rmap file.
            Lists/tuples may hold parkey strings or one level of nested
            lists/tuples of strings. Dict values may be a parkey string or
            a list/tuple of parkey strings. Non-string entries are skipped.

    Returns:
        dict: Mapping of each parkey to a (column_name, parkey) tuple.
        Empty for None or any other unsupported input type.
    """
    # Parkeys that feed USEAFTER rather than a selection column.
    ignored_parkeys = {"META.OBSERVATION.TIME",
                       "META.OBSERVATION.DATE",
                       "DATE-OBS",
                       "TIME-OBS"}

    if isinstance(parkeys, dict):
        groups = parkeys.values()
    elif isinstance(parkeys, (tuple, list)):
        groups = parkeys
    else:
        return {}

    column_mapping = {}
    for group in groups:
        # A group is either a single parkey or a list/tuple of parkeys.
        if isinstance(group, str):
            names = (group,)
        elif isinstance(group, (tuple, list)):
            names = group
        else:
            continue
        for name in names:
            if isinstance(name, str) and name not in ignored_parkeys:
                column_mapping[name] = (_column_name(name), name)
    return column_mapping


def extract_parkeys(rmap_file):
    """
    Extract parkeys from an .rmap file, handling different parkey formats.

    Dict parkeys are returned unchanged. Tuple/list parkeys, e.g.
    ``(('META.A', 'META.B'), ('META.OBSERVATION.DATE', ...))``, are
    flattened one level into a single list.

    Parameters:
        rmap_file (str): Path to the .rmap file.

    Returns:
        tuple: (parkeys, column_mapping), where column_mapping comes from
        get_column_names(). For a tuple/list parkey with no names the result
        is (None, {}). On any error, or for an unexpected parkey type, the
        result is (None, None) and a message is written to stderr.
    """
    try:
        mapping = rmap.Mapping.from_file(rmap_file)
        parkeys = mapping.parkey

        if isinstance(parkeys, dict):
            return parkeys, get_column_names(parkeys)

        if isinstance(parkeys, (tuple, list)):
            flatlist = []
            for item in parkeys:
                if isinstance(item, (tuple, list)):
                    flatlist.extend(item)
                elif isinstance(item, str):
                    flatlist.append(item)
            return (flatlist or None), get_column_names(flatlist)

        print(f"Unexpected parkey format in {rmap_file}: {type(parkeys)}",
              file=sys.stderr)
        return None, None

    except Exception as e:  # best-effort: one bad .rmap must not stop a scan
        print(f"Error reading .rmap file {rmap_file}: {e}", file=sys.stderr)
        return None, None


def _parkey_name_set(parkeys):
    """Return the upper-cased names of an extract_parkeys() result as a set."""
    if isinstance(parkeys, dict):
        return {str(name).upper()
                for names in parkeys.values() for name in names}
    return {str(name).upper() for name in parkeys}


def find_rmaps_with_parkeys(pmap_file, target_parkeys, mission='jwst', instrument=None, parkey_mode="all"):
    """
    Find .rmap files that use the given parkeys, optionally for a given
    instrument. Matching is case-insensitive and ignores whitespace around
    each entry of ``target_parkeys``.

    Parameters:
        pmap_file (str): Path to the context or .pmap file.
        target_parkeys (list): List of parkeys to match (e.g., ['META.EXPOSURE.TYPE']).
        mission (str): 'jwst' or 'hst'.
        instrument (str, optional): Instrument name. If None, include all.
        parkey_mode (str): 'all' (match all parkeys) or 'any' (match any parkey).

    Returns:
        list (mode='all'): .rmap files that include every target parkey.
        dict (mode='any'): Mapping of each target parkey to the list of
        .rmap files that include it.

    Raises:
        ValueError: If parkey_mode is neither 'all' nor 'any'.
    """
    if parkey_mode not in ("all", "any"):
        raise ValueError(
            f"parkey_mode must be 'all' or 'any', got {parkey_mode!r}")

    # Normalize the targets once rather than once per .rmap file.
    wanted = {parkey: str(parkey).strip().upper() for parkey in target_parkeys}
    rmap_files = get_rmaps_from_pmap(pmap_file, instrument, mission)

    if parkey_mode == "all":
        wanted_set = set(wanted.values())
        matching_rmaps = []
        for rmap_file in rmap_files:
            parkeys, _ = extract_parkeys(rmap_file)
            if parkeys and wanted_set <= _parkey_name_set(parkeys):
                matching_rmaps.append(rmap_file)
        return matching_rmaps

    # parkey_mode == "any": report every target, even without matches.
    matching_rmaps = {parkey: [] for parkey in wanted}
    for rmap_file in rmap_files:
        parkeys, _ = extract_parkeys(rmap_file)
        if not parkeys:
            continue
        present = _parkey_name_set(parkeys)
        for parkey, name in wanted.items():
            if name in present:
                matching_rmaps[parkey].append(rmap_file)
    return matching_rmaps


def find_parkeys_by_column(pmap_file, column_names, mission='jwst', instrument=None, column_mode="all"):
    """
    Find parkeys whose derived column name (last dot-separated field, or the
    field before a trailing 'NAME') matches the specified column names.
    Matching is case-insensitive and ignores whitespace around each name.

    Parameters:
        pmap_file (str): Path to the context file.
        column_names (list or str): Column names to match (e.g., ['GRATING']).
            A single string is treated as one column name.
        mission (str): 'jwst' or 'hst'.
        instrument (str, optional): Instrument name (e.g., 'nircam'). If None, include all.
        column_mode (str): 'all' (match all column names) or 'any' (match any column name).

    Returns:
        dict (mode='all'): Mapping of .rmap files to lists of matching parkeys,
        only for .rmap files that match every column name.
        dict (mode='any'): Mapping of column names to dicts of .rmap files to
        matching parkeys.

    Raises:
        ValueError: If column_mode is neither 'all' nor 'any'.
    """
    if column_mode not in ("all", "any"):
        raise ValueError(
            f"column_mode must be 'all' or 'any', got {column_mode!r}")
    if isinstance(column_names, str):
        # Avoid iterating a single name character by character.
        column_names = [column_names]

    # Normalize the requested names once, outside the per-.rmap loop.
    wanted = {name: name.strip().upper() for name in column_names}
    rmap_files = get_rmaps_from_pmap(pmap_file, instrument, mission)

    # Case 1: the .rmap must provide every requested column.
    if column_mode == "all":
        wanted_set = set(wanted.values())
        matching_parkeys = {}
        for rmap_file in rmap_files:
            parkeys, column_mapping = extract_parkeys(rmap_file)
            if not parkeys:
                continue
            matched = [parkey
                       for parkey, (column, _) in column_mapping.items()
                       if column in wanted_set]
            matched_columns = {column_mapping[p][0] for p in matched}
            if wanted_set <= matched_columns:
                matching_parkeys[rmap_file] = matched
        return matching_parkeys

    # Case 2: report matches per requested column.
    matching_parkeys = {name: {} for name in wanted}
    for rmap_file in rmap_files:
        parkeys, column_mapping = extract_parkeys(rmap_file)
        if not parkeys:
            continue
        for name, column in wanted.items():
            matched = [parkey
                       for parkey, (last, _) in column_mapping.items()
                       if last == column]
            if matched:
                matching_parkeys[name][rmap_file] = matched
    return matching_parkeys


def _infer_mission(pmap_file):
    """Guess 'jwst' or 'hst' from a context file name, defaulting to 'jwst'."""
    prefix = os.path.basename(pmap_file).split("_", 1)[0].lower()
    return prefix if prefix in ("jwst", "hst") else "jwst"


def _split_csv(text):
    """Split a comma-separated option value into stripped, non-empty items."""
    return [item.strip() for item in text.split(",") if item.strip()]


def _build_parser():
    """Create the argument parser for the command-line interface."""
    parser = argparse.ArgumentParser(
        description="Extract parkeys from JWST/HST CRDS .rmap files, find .rmap files with specific parkeys, or find parkeys by column name.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # argparse substitutes %(prog)s with the program name.
        epilog="""Examples:

  Extract parkeys for all .rmap files:
  %(prog)s jwst_1234.pmap

  Extract parkeys for NIRCam:
  %(prog)s jwst_1234.pmap --instrument nircam

  Find .rmap files with all specified parkeys:
  %(prog)s jwst_1234.pmap --parkeys META.INSTRUMENT.GRATING,META.EXPOSURE.TYPE --parkey-mode all

  Find .rmap files with any specified parkey:
  %(prog)s jwst_1234.pmap --parkeys META.INSTRUMENT.GRATING,META.EXPOSURE.TYPE --parkey-mode any

  Find .rmap files and parkeys for all specified column names:
  %(prog)s /path/mappings/jwst/jwst_1234.pmap --columns GRATING,TYPE --column-mode all

  Find .rmap files and parkeys corresponding to any specified column name:
  %(prog)s /path/mappings/jwst/jwst_1234.pmap --columns GRATING,TYPE --column-mode any

""",
    )

    parser.add_argument(
        "pmap_file",
        type=str,
        help="Path to the context file (e.g., "
             "'/path/mappings/jwst/jwst_1234.pmap')"
    )
    parser.add_argument(
        "--mission",
        type=str,
        choices=["jwst", "hst"],
        default=None,
        help="Mission of the context (e.g., 'jwst'). Default: inferred "
             "from the .pmap file name, otherwise 'jwst'"
    )
    parser.add_argument(
        "--instrument",
        type=str,
        choices=["miri", "nircam", "niriss", "fgs", "nirspec", "acs", "stis", "cos", "wfc3"],
        help="Filter .rmap files by instrument for a single mission (e.g., 'nircam', 'miri')"
    )
    parser.add_argument(
        "--parkeys",
        type=str,
        help="Comma-separated list of parkeys to match "
             "(e.g., 'META.EXPOSURE.TYPE,META.INSTRUMENT.FILTER' for jwst or "
             "'ATODGAIN,MODE' for hst) "
             "or a single parkey"
    )
    parser.add_argument(
        "--parkey-mode",
        type=str,
        choices=["all", "any"],
        default="all",
        help="Search mode for parkeys: 'all' (match all parkeys) "
             "or 'any' (match any parkey)"
    )
    parser.add_argument(
        "--columns",
        type=str,
        help="Comma-separated list of columns to match for a single mission "
             "(e.g., 'TYPE,FILTER') or a single column"
    )
    parser.add_argument(
        "--column-mode",
        type=str,
        choices=["all", "any"],
        default="all",
        help="Search mode for columns: 'all' (match all columns) or "
             "'any' (match any column)"
    )
    parser.add_argument(
        "--version",
        action="version",
        version=f"{__version__} (CRDS {crds.__version__})",
        help="Show script and CRDS version"
    )
    return parser


def _print_rmap_parkeys(rmap_file, parkeys):
    """Print an .rmap path followed by its matched parkeys, one per line."""
    print(f"  {rmap_file}:")
    for parkey in parkeys:
        print(f"    {parkey}")


def _print_column_matches(matching_parkeys, target_columns, column_mode):
    """Print the result of find_parkeys_by_column() for the given mode."""
    if column_mode == "all":
        if not matching_parkeys:
            print(f"\nNo parkeys found for column name '{target_columns}'")
            return
        print(
            f"\nFound parkeys for column name {target_columns}"
            f" in {len(matching_parkeys)} .rmap files:"
        )
        for rmap_file, parkeys in matching_parkeys.items():
            _print_rmap_parkeys(rmap_file, parkeys)
        return

    # column_mode == "any"
    any_matches = False
    for column_name, rmap_parkeys in matching_parkeys.items():
        if not rmap_parkeys:
            print(f"\nNo parkeys found for column name '{column_name}'")
            continue
        any_matches = True
        print(
            f"\nFound parkeys for column name '{column_name}'"
            f" in {len(rmap_parkeys)} .rmap files:"
        )
        for rmap_file, parkeys in rmap_parkeys.items():
            _print_rmap_parkeys(rmap_file, parkeys)
    if not any_matches:
        # Summarize with every requested name, not just the last one.
        print(f"\nNo parkeys found for any column names {target_columns}")


def _print_parkey_matches(matching_rmaps, target_parkeys, parkey_mode):
    """Print the result of find_rmaps_with_parkeys() for the given mode."""
    if parkey_mode == "all":
        if matching_rmaps:
            print(
                f"\nFound {len(matching_rmaps)} .rmap"
                f" files with all parkeys {target_parkeys}: ")
            for rmap_file in matching_rmaps:
                print(f"  {rmap_file}")
        else:
            print(f"\nNo .rmap files found with all parkeys {target_parkeys}")
        return

    # parkey_mode == "any"
    any_matches = False
    for parkey, rmap_files in matching_rmaps.items():
        if not rmap_files:
            print(f"\nNo .rmap files found with parkey '{parkey}'")
            continue
        any_matches = True
        print(
            f"\nFound {len(rmap_files)} .rmap files"
            f" with parkey '{parkey}': "
        )
        for rmap_file in rmap_files:
            print(f"  {rmap_file}")
    if not any_matches:
        print(f"\nNo .rmap files found with any parkeys {target_parkeys}")


def _print_all_parkeys(pmap_file, rmap_files, instrument, mission):
    """Print the parkeys and derived column names of every given .rmap file."""
    instrument_note = f" for instrument {instrument}" if instrument else ""
    if not rmap_files:
        print(f"No .rmap files found for {pmap_file}{instrument_note}")
        return

    print(f"Found {len(rmap_files)} .rmap files in {pmap_file}{instrument_note}\n")

    for rmap_file in rmap_files:
        print(f"Processing {rmap_file}")
        parkeys, column_mapping = extract_parkeys(rmap_file)
        if not parkeys:
            print(f"Failed to extract parkeys for {rmap_file}\n")
            continue
        print(f"Parameter keys for {rmap_file}:")
        if isinstance(parkeys, dict):
            for key, value in parkeys.items():
                print(f"  {key}: {value}")
        else:
            print(f"  {parkeys}")
        print(f"Derived column names (as on {mission}-crds.stsci.edu):")
        for parkey, (column, _) in column_mapping.items():
            print(f"  Parkey: {parkey} -  Column Name: {column}")
        print()


def main():
    """
    Command-line entry point.

    Validates the CRDS environment and the context file, then runs one of
    three actions:
      1. column-name search (--columns),
      2. parkey search (--parkeys), or
      3. by default, listing the parkeys of every selected .rmap file.

    Exits with status 1 on configuration or usage errors. Error messages
    go to stderr.
    """
    args = _build_parser().parse_args()

    # Check CRDS environment variables
    if not os.environ.get("CRDS_PATH"):
        print("Error: CRDS_PATH environment variable is not set. "
              "Set it to your CRDS cache directory (e.g., ~/crds_cache).",
              file=sys.stderr)
        sys.exit(1)

    if not os.environ.get("CRDS_SERVER_URL"):
        print("Error: CRDS_SERVER_URL environment variable is not set."
              " Set it to 'https://jwst-crds.stsci.edu' or 'https://hst-crds.stsci.edu'.",
              file=sys.stderr)
        sys.exit(1)

    # --mission is optional. Infer it so the path below never joins None.
    mission = args.mission or _infer_mission(args.pmap_file)

    # Validate pmap_file. os.path.join returns an absolute pmap_file
    # unchanged, so both bare names and full paths are accepted.
    pmap_path = os.path.join(os.environ["CRDS_PATH"], "mappings", mission,
                             args.pmap_file)
    if not os.path.exists(pmap_path):
        print(f"Error: .pmap file not found at {pmap_path}", file=sys.stderr)
        sys.exit(1)

    # Ensure only one of --parkeys or --columns is provided
    if args.parkeys and args.columns:
        print("Error: Cannot use both --parkeys and --columns together.",
              file=sys.stderr)
        sys.exit(1)

    # Case 1: Find .rmap files by column name
    if args.columns:
        target_columns = _split_csv(args.columns)
        if not target_columns:
            print("Error: --columns needs at least one column name.",
                  file=sys.stderr)
            sys.exit(1)
        matching_parkeys = find_parkeys_by_column(args.pmap_file,
                                                  target_columns,
                                                  mission,
                                                  args.instrument,
                                                  args.column_mode)
        _print_column_matches(matching_parkeys, target_columns,
                              args.column_mode)
        return

    # Case 2: Find .rmap files with specific parkeys
    if args.parkeys:
        target_parkeys = _split_csv(args.parkeys)
        if not target_parkeys:
            print("Error: --parkeys needs at least one parkey.",
                  file=sys.stderr)
            sys.exit(1)
        matching_rmaps = find_rmaps_with_parkeys(args.pmap_file,
                                                 target_parkeys,
                                                 mission,
                                                 args.instrument,
                                                 args.parkey_mode)
        _print_parkey_matches(matching_rmaps, target_parkeys,
                              args.parkey_mode)
        return

    # Case 3: Extract parkeys for .rmap files (optionally filtered
    # by instrument)
    rmap_files = get_rmaps_from_pmap(args.pmap_file, args.instrument, mission)
    _print_all_parkeys(args.pmap_file, rmap_files, args.instrument, mission)


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script,
    # not when this module is imported.
    main()
