import argparse
import sys
import textwrap

from .light_ember import light_ember
from .generate_pvals import generate_pvals
from .plots import plot_partition_specificity, plot_block_specificity, plot_sample_counts, plot_psi_blocks
from .top_genes import highly_specific_to_partition, highly_specific_to_block, non_specific_to_partition

# Help strings shared by more than one sub-command, kept in one place so the
# wording cannot drift apart between commands.
_H5AD_HELP = (
    "Path to the `.h5ad` file to process. Data should be log1p and depth normalized "
    "before running ember. Remove genes with <100 reads before running ember."
)
_PARTITION_LABEL_HELP = (
    "Column in `.obs` used to partition cells for entropy calculations "
    "(e.g., 'celltype', 'Genotype', 'Age'). For interaction terms, create a new "
    "column concatenating multiple `.obs` columns with a colon (:)."
)
_SAMPLE_ID_COL_HELP = (
    "Column in `.obs` with unique identifiers for each sample or replicate "
    "(e.g., 'sample_id', 'mouse_id')."
)
_CATEGORY_COL_HELP = (
    "Column in `.obs` defining the primary group to balance across "
    "(e.g., 'disease_status', 'mouse_strain'). Interchangeable with condition_col. "
    "For >2 variables, create interaction terms by concatenating columns with `:`."
)
_SAVE_DIR_HELP = "Path to directory where results will be saved."
_FONTSIZE_HELP = "Base font size for plot labels and text (default: 18)."
_PSI_THRESH_HELP = "Threshold for Psi values. Genes must have Psi > psi_thresh (default: 0.5)."


def _q_thresh_help(metric):
    """Return the ``--q_thresh`` help text for Psi plus the second score *metric*."""
    return (
        f"Threshold for q-values ('Psi q-value' and '{metric} q-value'). "
        "Genes must have q-values <= q_thresh (default: 0.05)."
    )


def _add_subcommand(subparsers, name, summary, description, epilog):
    """Register sub-command *name* on *subparsers* and return its parser.

    ``description`` and ``epilog`` are written as indented triple-quoted
    strings in the source. They are dedented here because
    ``RawTextHelpFormatter`` prints text verbatim, so the source indentation
    would otherwise show up in ``-h`` output.
    """
    return subparsers.add_parser(
        name,
        help=summary,
        formatter_class=argparse.RawTextHelpFormatter,
        description=textwrap.dedent(description),
        epilog=textwrap.dedent(epilog),
    )


def _add_scatter_plot_options(sub, metric):
    """Add the optional arguments shared by the Psi-vs-*metric* scatter-plot commands.

    ``metric`` names the second plotted score (``"Zeta"`` or ``"psi_block"``).
    It only affects help text; the option names are identical for both commands.
    """
    sub.add_argument(
        "--highlight_genes",
        nargs="+",
        default=None,
        help="List of gene names to highlight and annotate on the plot (default: None).",
    )
    sub.add_argument(
        "--q_thresh",
        type=float,
        default=0.05,
        help=_q_thresh_help(metric),
    )
    sub.add_argument(
        "--fontsize",
        type=int,
        default=18,
        help=_FONTSIZE_HELP,
    )
    sub.add_argument(
        "--custom_palette",
        nargs="+",
        default=None,
        help=(
            "List of 7 hex color codes to customize the color scheme. Order:\n"
            f"['significant by psi', 'significant by {metric.lower()}', 'highlight genes', "
            "'significant by both', 'circle markers', 'circle housekeeping genes', "
            "'significant by neither']. Default: None (uses built-in palette)."
        ),
    )


def _add_threshold_options(sub, metric, metric_help):
    """Add the 'Threshold Parameters' group used by the gene-filtering commands.

    Creates ``--psi_thresh``, ``--<metric>_thresh`` (lower-cased metric name) and
    ``--q_thresh``. ``metric_help`` is the help text for the metric-specific
    threshold, because its comparison direction differs between commands.
    """
    group = sub.add_argument_group("Threshold Parameters")
    group.add_argument(
        "--psi_thresh",
        type=float,
        default=0.5,
        help=_PSI_THRESH_HELP,
    )
    group.add_argument(
        f"--{metric.lower()}_thresh",
        type=float,
        default=0.5,
        help=metric_help,
    )
    group.add_argument(
        "--q_thresh",
        type=float,
        default=0.05,
        help=_q_thresh_help(metric),
    )


def create_parser():
    """Create and return the ArgumentParser object for the ember toolkit.

    The top-level parser requires a sub-command. The chosen sub-command's name is
    stored in the ``command`` attribute of the parsed namespace. Each sub-command's
    arguments are registered by one ``_add_*_parser`` helper, in the order the
    sub-commands are listed in ``ember -h``.

    Returns:
        argparse.ArgumentParser: the fully configured parser.
    """
    parser = argparse.ArgumentParser(
        prog="ember",
        description="A command-line toolkit for ember: Entropy Metrics for Biological ExploRation.",
    )
    subparsers = parser.add_subparsers(
        dest="command", required=True, help="Available sub-commands"
    )

    for register in (
        _add_light_ember_parser,
        _add_generate_pvals_parser,
        _add_plot_partition_specificity_parser,
        _add_plot_block_specificity_parser,
        _add_plot_sample_counts_parser,
        _add_plot_psi_blocks_parser,
        _add_highly_specific_to_partition_parser,
        _add_highly_specific_to_block_parser,
        _add_non_specific_to_partition_parser,
    ):
        register(subparsers)

    return parser


def _add_light_ember_parser(subparsers):
    """Register ``light_ember``: the full entropy-metrics and p-value workflow."""
    sub = _add_subcommand(
        subparsers,
        "light_ember",
        summary="Runs the ember entropy metrics and p-value generation workflow on an AnnData object.",
        description="""\
        Runs the ember entropy metrics and p-value generation workflow on an AnnData object.

        This function loads an AnnData `.h5ad` file, optionally performs balanced sampling
        across replicates, computes entropy metrics for the specified partition,
        and generates p-values for Psi and Zeta and optionally Psi_block for a block of choice.

        Entropy metrics generated:
            - Psi : Fraction of information explained by partition of choice
            - Psi_block : Specificity of information to a block
            - Zeta : Specificity to a partition / distance of Psi_blocks distribution from uniform

        Notes:
            - Results are saved to `save_dir` as CSV files.
            - One CSV file with all entropy metrics.
            - One CSV file in a new Psi_block_df folder with Psi_block values for all blocks in a partition.
            - Separate file for p-values.
            - Separate files for each partition.
            - Alternate file names depending on sampling on or off.
        """,
        epilog="""\
        Example:
          ember light_ember ~/ember_test/test_adata_cwc22.h5ad Genotype ~/ember_test/ --sample_id_col Mouse_ID --category_col Genotype --condition_col Sex --num_draws 50 --no_partition_pvals --n_cpus 4
        """,
    )

    # --- Required positional arguments ---
    sub.add_argument("h5ad_dir", help=_H5AD_HELP)
    sub.add_argument("partition_label", help=_PARTITION_LABEL_HELP)
    sub.add_argument("save_dir", help=_SAVE_DIR_HELP)

    # --- Sampling arguments ---
    sampling = sub.add_argument_group("Sampling Parameters")
    sampling.add_argument(
        "--no_sampling",
        action="store_false",
        dest="sampling",
        help="Disable balanced sampling (sampling is enabled by default). "
             "Note: If partition_pvals or block_pvals are enabled, sampling will be re-enabled.",
    )
    sampling.add_argument(
        "--sample_id_col",
        type=str,
        default=None,
        help=_SAMPLE_ID_COL_HELP,
    )
    sampling.add_argument(
        "--category_col",
        type=str,
        default=None,
        help=_CATEGORY_COL_HELP,
    )
    sampling.add_argument(
        "--condition_col",
        type=str,
        default=None,
        help="Secondary column in `.obs` to balance sampling across (e.g., 'sex', 'treatment'). "
             "Interchangeable with category_col. Supports interaction terms.",
    )
    sampling.add_argument(
        "--num_draws",
        type=int,
        default=100,
        help="Number of balanced subsets to generate (default: 100).",
    )
    sampling.add_argument(
        "--save_draws",
        action="store_true",
        help="Save intermediate sampled draws to save_dir (default: False).",
    )
    sampling.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducible draws (default: 42).",
    )

    # --- P-value arguments ---
    pvals = sub.add_argument_group("P-value Parameters")
    pvals.add_argument(
        "--no_partition_pvals",
        action="store_false",
        dest="partition_pvals",
        help="Disable permutation p-value calculation for the main partition (enabled by default).",
    )
    pvals.add_argument(
        "--block_pvals",
        action="store_true",
        help="Enable permutation p-value calculation for a specific block (disabled by default).",
    )
    pvals.add_argument(
        "--block_label",
        type=str,
        default=None,
        help="Specific value in 'partition_label' for block p-values. Required if --block_pvals is set.",
    )
    pvals.add_argument(
        "--n_pval_iterations",
        type=int,
        default=1000,
        help="Number of permutations for p-value calculation (default: 1000).",
    )

    # --- Performance arguments ---
    perf = sub.add_argument_group("Performance Parameters")
    perf.add_argument(
        "--n_cpus",
        type=int,
        default=1,
        help="Number of CPU cores to use for parallel processing (default: 1). "
             "Performance is I/O-bound and may not improve beyond 4–8 cores.",
    )


def _add_generate_pvals_parser(subparsers):
    """Register ``generate_pvals``: permutation-based empirical p-values."""
    sub = _add_subcommand(
        subparsers,
        "generate_pvals",
        summary="Calculate empirical p-values for entropy metrics from permutation test results.",
        description="""\
        Calculate empirical p-values for entropy metrics from permutation test results.

        Entropy metrics generated:
            - Psi : Fraction of information explained by partition of choice
            - Psi_block : Specificity of information to a block
            - Zeta : Specificity to a partition / distance of Psi_blocks distribution from uniform
        """,
        epilog="""\
        Example:
          ember generate_pvals test_adata_cwc22.h5ad Genotype ~/ember_test/ ~/ember_test/output Mouse_ID Genotype Sex --block_label WSBJ --n_cpus 4
        """,
    )

    # --- Required positional arguments ---
    sub.add_argument("h5ad_dir", help=_H5AD_HELP)
    sub.add_argument("partition_label", help=_PARTITION_LABEL_HELP)
    sub.add_argument(
        "entropy_metrics_dir",
        help="Path to CSV with entropy metrics to use for generating p-values.",
    )
    sub.add_argument("save_dir", help=_SAVE_DIR_HELP)
    sub.add_argument("sample_id_col", help=_SAMPLE_ID_COL_HELP)
    sub.add_argument("category_col", help=_CATEGORY_COL_HELP)
    sub.add_argument(
        "condition_col",
        help="Column in `.obs` containing the conditions to balance within each category "
             "(e.g., 'sex', 'treatment'). Interchangeable with category_col. Supports interaction terms.",
    )

    # --- Block argument ---
    sub.add_argument(
        "--block_label",
        type=str,
        default=None,
        help="Block in partition to calculate p-values for. Default: None (Psi and Zeta only).",
    )

    # --- Performance & iterations ---
    perf = sub.add_argument_group("Performance Parameters")
    perf.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducible draws (default: 42).",
    )
    perf.add_argument(
        "--n_iterations",
        type=int,
        default=1000,
        help="Number of iterations to calculate p-values (default: 1000). "
             "Use fewer for quick runs, more for reliable results.",
    )
    perf.add_argument(
        "--n_cpus",
        type=int,
        default=1,
        help="Number of CPUs to use for p-value calculation (default: 1). "
             "Set to -1 to use all available cores but one.",
    )

    # --- Internal-use arguments (group title: used by light_ember) ---
    internal = sub.add_argument_group("Internal Arguments (used by light_ember)")
    internal.add_argument(
        "--Psi_real",
        type=str,
        default=None,
        help="Observed Psi values for each gene (pd.Series). Not required for user runs.",
    )
    internal.add_argument(
        "--Psi_block_df_real",
        type=str,
        default=None,
        help="Observed Psi_block values for all blocks in chosen partition (pd.DataFrame). "
             "Not required for user runs.",
    )
    internal.add_argument(
        "--Zeta_real",
        type=str,
        default=None,
        help="Observed Zeta values for each gene (pd.Series). Not required for user runs.",
    )


def _add_plot_partition_specificity_parser(subparsers):
    """Register ``plot_partition_specificity``: the Zeta-vs-Psi scatter plot."""
    sub = _add_subcommand(
        subparsers,
        "plot_partition_specificity",
        summary="Generate a Zeta vs. Psi scatter plot to visualize partition-specific genes.",
        description="""\
        Generate a Zeta vs. Psi scatter plot to visualize partition-specific genes.

        This function reads p-value data, colors genes based on their statistical
        significance for Psi and Zeta scores, and highlights top "marker" and
        "housekeeping" genes. Allows for custom highlighting of a user-provided
        gene list. Font size and color palette can be customized.
        """,
        epilog="""\
        Example:
          ember plot_partition_specificity Genotype pvals_entropy_metrics_Genotype_WSBJ.csv output/ --highlight_genes Cwc22 --fontsize 25
        """,
    )

    # --- Required positional arguments ---
    sub.add_argument(
        "partition_label",
        help="Label for the partition being plotted, used in the plot title.",
    )
    sub.add_argument(
        "pvals_dir",
        help="Path to input CSV containing p-values and scores (Psi, Zeta, FDRs). "
             "CSV must have gene names as its index.",
    )
    sub.add_argument(
        "save_dir",
        help="Path where the output plot image will be saved.",
    )

    # --- Optional arguments ---
    _add_scatter_plot_options(sub, "Zeta")


def _add_plot_block_specificity_parser(subparsers):
    """Register ``plot_block_specificity``: the psi_block-vs-Psi scatter plot."""
    sub = _add_subcommand(
        subparsers,
        "plot_block_specificity",
        summary="Generate a psi_block vs. Psi scatter plot to visualize block-specific genes.",
        description="""\
        Generate a psi_block vs. Psi scatter plot to visualize block-specific genes.

        This function reads p-value data, colors genes based on their statistical
        significance for Psi and psi_block scores, and highlights the top genes
        significant in both metrics. Allows for custom highlighting of a user-provided
        gene list. Font size and color palette can be customized.
        """,
        epilog="""\
        Example:
          ember plot_block_specificity Genotype WSBJ pvals_entropy_metrics_Genotype_WSBJ.csv output/ --highlight_genes Cwc22 --fontsize 25
        """,
    )

    # --- Required positional arguments ---
    sub.add_argument(
        "partition_label",
        help="Label for the partition, used in the plot title.",
    )
    sub.add_argument(
        "block_label",
        help="Label for the block variable (e.g., a cell type or condition).",
    )
    sub.add_argument(
        "pvals_dir",
        help="Path to input CSV containing p-values and scores. "
             "CSV must have gene names as its index.",
    )
    sub.add_argument(
        "save_dir",
        help="Path where the output plot image will be saved.",
    )

    # --- Optional arguments ---
    _add_scatter_plot_options(sub, "psi_block")


def _add_plot_sample_counts_parser(subparsers):
    """Register ``plot_sample_counts``: bar plot of individuals per category/condition."""
    sub = _add_subcommand(
        subparsers,
        "plot_sample_counts",
        summary="Generate a bar plot showing the number of unique individuals per category and condition.",
        description="""\
        Generate a bar plot showing the number of unique individuals per category and condition.

        This function reads an AnnData object from an .h5ad file in backed mode,
        calculates the number of unique individuals for each combination of a given
        category and condition, and visualizes these counts as a grouped bar plot.
        Font size can be customized.
        """,
        epilog="""\
        Example:
          ember plot_sample_counts test_adata_cwc22.h5ad ~/ember_test/output Mouse_ID Genotype Sex --fontsize 20
        """,
    )

    # --- Required positional arguments ---
    sub.add_argument("h5ad_dir", help="Path to the input AnnData (.h5ad) file.")
    sub.add_argument("save_dir", help="Path to directory to save the output plot image.")
    sub.add_argument("sample_id_col", help="Column name in `.obs` that contains unique sample IDs.")
    sub.add_argument("category_col", help="Column name to use for the primary categories on the x-axis.")
    sub.add_argument("condition_col", help="Column name to use for grouping the bars (hue).")

    # --- Optional arguments ---
    sub.add_argument("--fontsize", type=int, default=18, help=_FONTSIZE_HELP)


def _add_plot_psi_blocks_parser(subparsers):
    """Register ``plot_psi_blocks``: per-gene bar plot of mean psi_block values."""
    sub = _add_subcommand(
        subparsers,
        "plot_psi_blocks",
        summary="Generate a bar plot of mean psi block values with error bars for a given gene.",
        description="""\
        Generates and saves a bar plot of mean psi block values with error bars.

        This function reads two CSV files from a specified directory: one for mean
        psi block values and one for standard deviations. It plots the mean values
        for a specific gene as a bar plot with corresponding standard deviation
        error bars. Font size can be customized.
        """,
        epilog="""\
        Example:
          ember plot_psi_blocks Cwc22 Genotype ~/ember_test/output/Psi_block_df/ ~/ember_test/output/figs --fontsize 30
        """,
    )

    # --- Required positional arguments ---
    sub.add_argument(
        "gene_name",
        help="Name of the gene (row) to select and plot from the CSV files.",
    )
    sub.add_argument(
        "partition_label",
        help="Partition label used to find the correct files (e.g., 'Genotype').",
    )
    sub.add_argument(
        "psi_block_df_dir",
        help="Directory containing the mean and std CSV files. Files must be named "
             "'mean_Psi_block_df_{partition_label}.csv' and "
             "'std_Psi_block_df_{partition_label}.csv'.",
    )
    sub.add_argument("save_dir", help="Path to directory to save the output plot image.")

    # --- Optional arguments ---
    sub.add_argument("--fontsize", type=int, default=18, help=_FONTSIZE_HELP)


def _add_highly_specific_to_partition_parser(subparsers):
    """Register ``highly_specific_to_partition``: filter genes with high Psi and Zeta."""
    sub = _add_subcommand(
        subparsers,
        "highly_specific_to_partition",
        summary="Identify genes that are highly significant and specific to the partition (high Psi and Zeta).",
        description="""\
        Identifies significant and specific genes from an ember generated
        p-values/q-values CSV file based on thresholds for Psi, Zeta, and q-values.
        The resulting DataFrame is saved as "highly_specific_genes_to_{partition_label}.csv".
        """,
        epilog="""\
        Example:
          ember highly_specific_to_partition Genotype pvals_entropy_metrics_Genotype.csv output/ --psi_thresh 0.6 --zeta_thresh 0.7
        """,
    )

    # --- Required positional arguments ---
    sub.add_argument(
        "partition_label",
        help="Name of partition used to generate entropy metrics, used to label saved csv.",
    )
    sub.add_argument(
        "pvals_dir",
        help="Path to the input CSV file (must contain 'Psi q-value', 'Zeta q-value', 'Psi', and 'Zeta').",
    )
    sub.add_argument(
        "save_dir",
        help="Directory where the filtered results CSV will be saved.",
    )

    # --- Optional threshold arguments ---
    _add_threshold_options(
        sub,
        "Zeta",
        "Threshold for Zeta values. Genes must have Zeta > zeta_thresh (default: 0.5).",
    )


def _add_highly_specific_to_block_parser(subparsers):
    """Register ``highly_specific_to_block``: filter genes with high Psi and psi_block."""
    sub = _add_subcommand(
        subparsers,
        "highly_specific_to_block",
        summary="Identify genes that are highly significant and specific to a partition block (high Psi and psi_block).",
        description="""\
        Identifies significant and specific genes from an ember generated
        p-values/q-values CSV file based on thresholds for Psi, psi_block, and q-values.
        Resultant genes are potential marker genes.
        The resulting DataFrame is saved as "highly_specific_genes_by_{partition_label}_{block_label}.csv".
        """,
        epilog="""\
        Example:
          ember highly_specific_to_block Genotype WSBJ pvals_entropy_metrics_Genotype_WSBJ.csv output/ --psi_thresh 0.6 --psi_block_thresh 0.7
        """,
    )

    # --- Required positional arguments ---
    sub.add_argument(
        "partition_label",
        help="Name of partition used to generate entropy metrics.",
    )
    sub.add_argument(
        "block_label",
        help="Name of block in partition used to generate entropy metrics.",
    )
    sub.add_argument(
        "pvals_dir",
        help="Path to the input CSV file (must contain 'Psi q-value', 'psi_block q-value', 'Psi', and 'psi_block').",
    )
    sub.add_argument(
        "save_dir",
        help="Directory where the filtered results CSV will be saved.",
    )

    # --- Optional threshold arguments ---
    _add_threshold_options(
        sub,
        "psi_block",
        "Threshold for psi_block values. Genes must have psi_block > psi_block_thresh (default: 0.5).",
    )


def _add_non_specific_to_partition_parser(subparsers):
    """Register ``non_specific_to_partition``: filter genes with high Psi but low Zeta."""
    sub = _add_subcommand(
        subparsers,
        "non_specific_to_partition",
        summary="Identify genes that are highly significant but non-specific to the partition (high Psi, low Zeta).",
        description="""\
        Identifies significant but non-specific genes (potential housekeeping genes) from an ember generated
        p-values/q-values CSV file based on thresholds for Psi, Zeta, and q-values.
        Note: The Zeta filter is reversed, keeping Zeta < zeta_thresh.
        The resulting DataFrame is saved as "non_specific_genes_to_{partition_label}.csv".
        """,
        epilog="""\
        Example:
          ember non_specific_to_partition Genotype pvals_entropy_metrics_Genotype.csv output/ --psi_thresh 0.6 --zeta_thresh 0.2
        """,
    )

    # --- Required positional arguments ---
    sub.add_argument(
        "partition_label",
        help="Name of partition used to generate entropy metrics, used to label saved csv.",
    )
    sub.add_argument(
        "pvals_dir",
        help="Path to the input CSV file (must contain 'Psi q-value', 'Zeta q-value', 'Psi', and 'Zeta').",
    )
    sub.add_argument(
        "save_dir",
        help="Directory where the filtered results CSV will be saved.",
    )

    # --- Optional threshold arguments (Zeta comparison is reversed here) ---
    _add_threshold_options(
        sub,
        "Zeta",
        "Threshold for Zeta values. Genes must have Zeta < zeta_thresh (default: 0.5) "
        "to be considered non-specific.",
    )

def run_command(parser, args):
    """Dispatch parsed command-line arguments to the matching ember function.

    Args:
        parser: The ``argparse.ArgumentParser`` built by ``create_parser``. It is
            used to report usage errors.
        args: The ``argparse.Namespace`` returned by ``parser.parse_args()``. Its
            ``command`` attribute selects the function to call, and every other
            attribute is forwarded as a keyword argument. ``args`` itself is
            left unmodified.

    Exits via ``parser.error`` (status 2) if ``light_ember`` is asked for block
    p-values without a ``--block_label``. Exits with status 1, after printing
    the help text to stderr, if the command is missing or unknown.
    """
    # vars() returns the Namespace's own __dict__; copy it so that popping
    # 'command' does not delete that attribute from the caller's args.
    kwargs = dict(vars(args))
    command_to_run = kwargs.pop("command", None)

    # The --block_label help text states it is required with --block_pvals;
    # enforce that here so the user gets an argparse usage error up front.
    if (
        command_to_run == "light_ember"
        and kwargs.get("block_pvals")
        and not kwargs.get("block_label")
    ):
        parser.error("--block_label is required when --block_pvals is set.")

    command_map = {
        "light_ember": light_ember,
        "generate_pvals": generate_pvals,
        "plot_partition_specificity": plot_partition_specificity,
        "plot_block_specificity": plot_block_specificity,
        "plot_sample_counts": plot_sample_counts,
        "plot_psi_blocks": plot_psi_blocks,
        "highly_specific_to_partition": highly_specific_to_partition,
        "highly_specific_to_block": highly_specific_to_block,
        "non_specific_to_partition": non_specific_to_partition,
    }

    func_to_run = command_map.get(command_to_run)
    if func_to_run is None:
        # Should not happen with create_parser's parser (a sub-command is
        # required there), but guards callers that build the Namespace themselves.
        print(f"Error: Unknown command '{command_to_run}'", file=sys.stderr)
        parser.print_help(sys.stderr)
        sys.exit(1)

    func_to_run(**kwargs)

if __name__ == "__main__":
    # Build the CLI, parse sys.argv, and hand off to the selected sub-command.
    cli_parser = create_parser()
    run_command(cli_parser, cli_parser.parse_args())