import shutil

def _optional_cli_arg(flag, key):
    """Return "<flag> <value>" when `key` is set in the workflow config, else "".

    The value is formatted with str() so that config entries which YAML parses
    as numbers (e.g. a numeric database version) do not raise TypeError, as
    plain string concatenation would.
    """
    if key in config:
        return f"{flag} {config[key]}"
    return ''


# Required: path of the metapackage to generate.
metapackage = config['metapackage_path']
# Optional: mapping of spkg name -> definition dict (taxonomy_file, aa_sequences_file, ...).
spkgs = config.get('singlem_packages', {})
# Optional: sequence file used to test the finished metapackage (None disables testing).
test_file = config.get('test_file')

# Optional command-line fragments passed through to `singlem metapackage`.
# Each is an empty string when the corresponding config key is absent.
taxonomy_database_name = _optional_cli_arg('--taxonomy-database-name', 'taxonomy_database_name')
taxonomy_database_version = _optional_cli_arg('--taxonomy-database-version', 'taxonomy_database_version')
diamond_prefilter_performance_params = _optional_cli_arg('--diamond-prefilter-performance-parameters', 'diamond_prefilter_performance_params')
diamond_assign_taxonomy_performance_params = _optional_cli_arg('--diamond-taxonomy-assignment-performance-parameters', 'diamond_assign_taxonomy_performance_params')
makeidx_sensitivity_params = _optional_cli_arg('--makeidx-sensitivity-params', 'makeidx_sensitivity_params')

# Directory holding all intermediate files, logs, benchmarks and done-markers.
output_directory = config.get('singlem_new_metapackage_working_directory', "singlem_new_metapackage_working_directory")

# If spkgs are already defined in the config, copy them into the output
# directory so spkg creation steps can be skipped. Assume they have not already
# been chainsaw'd.
if 'pregenerated_singlem_packages' in config:
    done_directory = os.path.join(output_directory, "done")
    for spkg in config['pregenerated_singlem_packages']:
        # normpath strips any trailing slash; without it basename() of
        # "path/to/x.spkg/" is the empty string.
        spkg_key = os.path.basename(os.path.normpath(spkg))
        # Check for name clashes before touching the filesystem, so a bad
        # config does not leave a copied package and done-marker behind.
        if spkg_key in spkgs:
            raise Exception(f"Duplicate spkg name: {spkg_key}")
        final_path = os.path.join(output_directory, "spkgs", f"{spkg_key}.spkg")
        if not os.path.exists(final_path):
            print(f"Copying {spkg} to {output_directory} ..")
            # Copy the spkg directory to where the singlem_create rule would
            # have written it.
            shutil.copytree(spkg, final_path)
            # Create the matching done-marker (mkdir -p its directory first)
            # so the copied package is seen as already created.
            os.makedirs(done_directory, exist_ok=True)
            with open(os.path.join(done_directory, f"{spkg_key}.singlem_create.done"), 'w'):
                pass
        # Register the package so downstream rules are generated for it. The
        # creation-time fields are unknown for a pregenerated package, hence
        # the None placeholders. Hack of course.
        spkgs[spkg_key] = {
            'taxonomy_file': None,
            'gene_description': None,
            'target_domains': None,
            'aa_sequences_file': None,
        }

# Report how many packages will go into the metapackage, and refuse to
# continue when there are none.
num_spkgs = len(spkgs)
print(f"Generating metapackage from {num_spkgs} singlem packages")
if num_spkgs == 0:
    raise Exception("No singlem packages defined, cannot create metapackage")

# Default target: requests the finished metapackage, its done-marker, the
# chainsaw'd version of every singlem package and, when a test file is
# configured, the verification done-marker.
rule all:
    input:
        metapackage,
        os.path.join(output_directory, "done", "create_metapackage.done"),
        # Verifying/testing is optional
        os.path.join(output_directory, "done", "verify_tested_metapackage.done") if test_file else [],
        [os.path.join(output_directory, "chainsaw", f"{name}.spkg") for name in spkgs.keys()],

# Definition of HMM path is optional
def graftm_create_hmm_arg(spkg_def):
    """Build the optional ``--hmm`` argument for ``graftM create``.

    Returns "--hmm <path>" when the package definition has an 'hmm_file'
    key, otherwise an empty string.
    """
    if 'hmm_file' not in spkg_def:
        return ""
    hmm_path = spkg_def['hmm_file']
    return f"--hmm {hmm_path}"

# Build a GraftM package for one singlem package definition by running
# `graftM create` (with --no_tree) on its amino acid sequences and taxonomy.
# An HMM is passed only when the definition provides 'hmm_file'.
rule create_gpkg:
    output:
        gpkg = directory(os.path.join(output_directory, "gpkgs", "{name}.gpkg")),
        done = touch(os.path.join(output_directory, "done", "{name}.gpkg.done")),
    log:
        log = os.path.join(output_directory, "logs", "{name}.gpkg.log")
    threads: 4
    benchmark: os.path.join(output_directory, "benchmarks", "{name}.gpkg.benchmark")
    params:
        # Looked up per package from the spkgs mapping using the {name} wildcard.
        taxonomy_file = lambda wildcards: spkgs[wildcards.name]['taxonomy_file'],
        aa_sequences_file = lambda wildcards: spkgs[wildcards.name]['aa_sequences_file'],
        # Either "--hmm <path>" or an empty string.
        hmm_arg = lambda wildcards: graftm_create_hmm_arg(spkgs[wildcards.name]),
    conda:
        "envs/singlem.yml"
    shell:
        # Only stderr is captured in the log.
        "graftM create --min_aligned_percent 0 --force --sequences {params.aa_sequences_file} {params.hmm_arg} --output {output.gpkg} --threads {threads} --no_tree --taxonomy {params.taxonomy_file} 2> {log}"
    
def get_final_hmm_file(spkg_def, gpkg):
    """Return the HMM path to use for a singlem package.

    The 'hmm_file' entry of the package definition is returned when present.
    Otherwise the result is a glob pattern ("*_align.hmm") inside the given
    GraftM package directory, left for the shell to expand.
    """
    if 'hmm_file' not in spkg_def:
        return os.path.join(gpkg, '*_align.hmm')
    return spkg_def['hmm_file']

# Align a package's amino acid sequences to its HMM with hmmalign and convert
# the Stockholm output to FASTA with seqmagick.
rule create_alignment:
    input:
        # Only required when hmm_file is not defined, but unsure how to make this conditional
        gpkg = os.path.join(output_directory, "gpkgs", "{name}.gpkg"),
    output:
        alignment = os.path.join(output_directory, "alns", "{name}.aln"),
        done = touch(os.path.join(output_directory, "done", "{name}.alignment.done"))
    log:
        log = os.path.join(output_directory, "logs", "{name}.alignment.log")
    benchmark: os.path.join(output_directory, "benchmarks", "{name}.alignment.benchmark")
    params:
        aa_sequences_file = lambda wildcards: spkgs[wildcards.name]['aa_sequences_file'],
        # Either the configured hmm_file or a "*_align.hmm" glob inside the gpkg
        # (see get_final_hmm_file); the glob is left for the shell to expand.
        hmm_file = lambda wildcards, input: get_final_hmm_file(spkgs[wildcards.name], input.gpkg),
    conda:
        "envs/singlem.yml"
    shell:
        # The pipeline runs inside an explicit `bash -c`; stderr of that bash
        # invocation goes to the log.
        "bash -c 'hmmalign --amino {params.hmm_file} {params.aa_sequences_file} |seqmagick convert --input-format stockholm --output-format fasta - {output.alignment}' 2> {log}"

# Run `singlem seqs` on the amino acid alignment and HMM, saving its stdout
# as the window position file consumed by the singlem_create rule.
rule singlem_seqs:
    input:
        alignment = os.path.join(output_directory, "alns", "{name}.aln"),
        # Only required when hmm_file is not defined, but unsure how to make this conditional
        gpkg = os.path.join(output_directory, "gpkgs", "{name}.gpkg"),
    output:
        window_position_file = os.path.join(output_directory, "{name}.window_position.txt"),
        done = touch(os.path.join(output_directory, "done", "{name}.singlem_seqs.done"))
    log:
        log = os.path.join(output_directory, "logs", "{name}.singlem_seqs.log")
    benchmark: os.path.join(output_directory, "benchmarks", "{name}.singlem_seqs.benchmark")
    params:
        # NOTE(review): aa_sequences_file is not referenced by the shell command below.
        aa_sequences_file = lambda wildcards: spkgs[wildcards.name]['aa_sequences_file'],
        # Either the configured hmm_file or a "*_align.hmm" glob inside the gpkg.
        hmm_file = lambda wildcards, input: get_final_hmm_file(spkgs[wildcards.name], input.gpkg),
    conda:
        "envs/singlem.yml"
    shell:
        # There's an unreleased fix post 0.16.0
        # NOTE(review): this calls a singlem checkout under the user's home
        # directory rather than the one from the conda environment, so that
        # checkout must exist on the machine running the workflow.
        "~/git/singlem/bin/singlem seqs --alignment {input.alignment} --alignment-type aa --hmm {params.hmm_file} >{output.window_position_file} 2> {log}"

# Create a singlem package from the GraftM package with `singlem create`,
# using the HMM position read from the window position file.
rule singlem_create:
    input:
        gpkg = os.path.join(output_directory, "gpkgs", "{name}.gpkg"),
        window_position_file = os.path.join(output_directory, "{name}.window_position.txt"),
    output:
        spkg = directory(os.path.join(output_directory, "spkgs", "{name}.spkg")),
        done = touch(os.path.join(output_directory, "done", "{name}.singlem_create.done"))
    log:
        log = os.path.join(output_directory, "logs", "{name}.singlem_create.log")
    benchmark: os.path.join(output_directory, "benchmarks", "{name}.singlem_create.benchmark")
    params:
        # Looked up per package from the spkgs mapping using the {name} wildcard.
        taxonomy_file = lambda wildcards: spkgs[wildcards.name]['taxonomy_file'],
        gene_description = lambda wildcards: spkgs[wildcards.name]['gene_description'],
        target_domains = lambda wildcards: spkgs[wildcards.name]['target_domains'],
    conda:
        "envs/singlem.yml"
    shell:
        # There's an unreleased fix post 0.16.0
        # The backticked `cat` substitutes the file's contents as the value of
        # --hmm-position. The gene description is single-quoted, so it must
        # not itself contain a single quote.
        "~/git/singlem/bin/singlem create --input-graftm-package {input.gpkg} --output {output.spkg} --target-domains {params.target_domains} --hmm-position `cat {input.window_position_file}` --gene-description '{params.gene_description}' --force --input-taxonomy {params.taxonomy_file} 2> {log}"

def get_regenerate_decoy_args():
    """Return the decoy-related arguments for ``singlem regenerate``.

    When 'decoy_sequences' is set in the workflow config, the result passes
    it together with 'decoy_taxonomy' as --euk-sequences / --euk-taxonomy
    (a missing 'decoy_taxonomy' raises KeyError). Otherwise the result is
    "--no-further-euks".
    """
    if 'decoy_sequences' not in config:
        return "--no-further-euks"
    decoy_seqs = config['decoy_sequences']
    decoy_tax = config['decoy_taxonomy']
    return f"--euk-sequences {decoy_seqs} --euk-taxonomy {decoy_tax}"

# Run `singlem regenerate` on a created singlem package, re-supplying the
# package's own sequences and taxonomy plus either the configured decoy
# sequences or --no-further-euks.
rule add_decoy_sequences_via_regenerate:
    input:
        spkg = os.path.join(output_directory, "spkgs", "{name}.spkg"),
        done = os.path.join(output_directory, "done", "{name}.singlem_create.done")
    output:
        spkg = directory(os.path.join(output_directory, "regenerated", "{name}.spkg")),
        done = touch(os.path.join(output_directory, "done", "{name}.regenerated.done"))
    log:
        log = os.path.join(output_directory, "logs", "{name}.regenerated.log")
    conda:
        "envs/singlem.yml"
    params:
        # Called once when the Snakefile is parsed (not a per-job lambda); the
        # value depends only on the workflow config.
        decoy_sequences_arg = get_regenerate_decoy_args(),
        # NOTE(review): unused while --sequence-prefix stays commented out of
        # the shell command below.
        sequence_prefix = lambda w: w.name+'~',
        # For packages registered from 'pregenerated_singlem_packages' these two
        # entries are None placeholders.
        spkg_seq = lambda wildcards: spkgs[wildcards.name]['aa_sequences_file'],
        spkg_tax = lambda wildcards: spkgs[wildcards.name]['taxonomy_file'],
    shell: # "--sequence-prefix {params.sequence_prefix} "
        # Fix post 0.16.0
        # Adjacent string literals are concatenated into one command line;
        # `&>` sends both stdout and stderr to the log.
        "~/git/singlem/bin/singlem regenerate "
        "--input-singlem-package {input.spkg} "
        "--input-sequences {params.spkg_seq} "
        "--input-taxonomy {params.spkg_tax} "
        " {params.decoy_sequences_arg} "
        "--output-singlem-package {output.spkg} "
        "&> {log}"

# Run `singlem chainsaw` on each regenerated singlem package, producing the
# package that is finally combined into the metapackage.
rule chainsaw:
    input:
        spkg = os.path.join(output_directory, "regenerated", "{name}.spkg"),
        # Done-marker of the regenerate step, required alongside the directory.
        done = os.path.join(output_directory, "done", "{name}.regenerated.done")
    output:
        spkg = directory(os.path.join(output_directory, "chainsaw", "{name}.spkg")),
        done = touch(os.path.join(output_directory, "done", "{name}.chainsaw.done"))
    log:
        log = os.path.join(output_directory, "logs", "{name}.chainsaw.log")
    conda:
        "envs/singlem.yml"
    shell:
        # There's an unreleased fix post 0.16.0
        "~/git/singlem/bin/singlem chainsaw --input-singlem-package {input.spkg} --output-singlem-package {output.spkg} 2> {log}"

# Combine all chainsaw'd singlem packages into the final metapackage with
# `singlem metapackage` (no nucleotide sdb, no taxon genome lengths).
rule create_metapackage:
    input:
        spkgs = [os.path.join(output_directory, "chainsaw", f"{name}.spkg") for name in spkgs.keys()],
        done = [os.path.join(output_directory, "done", f"{name}.chainsaw.done") for name in spkgs.keys()],
        # sdb = os.path.join(output_directory, "a.sdb") #TODO
    output:
        metapackage = directory(metapackage),
        done = touch(os.path.join(output_directory, "done", "create_metapackage.done"))
    log:
        log = os.path.join(output_directory, "logs", "create_metapackage.log")
    benchmark: os.path.join(output_directory, "benchmarks", "create_metapackage.benchmark")
    threads: workflow.cores
    params:
        # Optional command-line fragments ("--flag value", or "" when not
        # configured). They are not files, so they are declared as params:
        # under input: Snakemake would treat a configured value as an input
        # file path that must exist.
        taxonomy_database = taxonomy_database_name,
        taxonomy_version = taxonomy_database_version,
        diamond_prefilter_performance_params = diamond_prefilter_performance_params,
        diamond_assign_taxonomy_performance_params = diamond_assign_taxonomy_performance_params,
        makeidx_sensitivity_params = makeidx_sensitivity_params,
    conda:
        "envs/singlem.yml"
    shell:
        # {input.spkgs} expands to the space-separated list of package directories.
        "singlem metapackage --singlem-packages {input.spkgs} --no-nucleotide-sdb --no-taxon-genome-lengths --metapackage {output.metapackage} --threads {threads} {params.taxonomy_database} {params.taxonomy_version} {params.diamond_prefilter_performance_params} {params.diamond_assign_taxonomy_performance_params} {params.makeidx_sensitivity_params} 2> {log}"

# Run `singlem pipe` with the new metapackage on the configured test file,
# writing an OTU table with extra columns (--output-extras).
rule test_metapackage:
    input:
        metapackage = metapackage,
        # Empty when no test file is configured; `rule all` only requests the
        # downstream verification marker when test_file is set.
        test_file = test_file if test_file else []
    output:
        otu_table = os.path.join(output_directory, "otu_table.tsv"),
        done = touch(os.path.join(output_directory, "done", "test_metapackage.done"))
    log:
        log = os.path.join(output_directory, "logs", "test_metapackage.log")
    benchmark: os.path.join(output_directory, "benchmarks", "test_metapackage.benchmark")
    conda:
        "envs/singlem.yml"
    shell:
        # NOTE(review): {threads} is used but this rule has no `threads:`
        # directive, so it presumably falls back to Snakemake's default of a
        # single thread - confirm that is intended.
        "singlem pipe --forward {input.test_file} --metapackage {input.metapackage} --otu-table {output.otu_table} --threads {threads} --assignment-method diamond --debug --output-extras 2> {log}"

# Check that every sequence in the test file was assigned: the number of read
# names listed in the OTU table must equal the number of lines containing '>'
# in the test file.
rule verify_tested_metapackage:
    input:
        otu_table = os.path.join(output_directory, "otu_table.tsv"),
        done = os.path.join(output_directory, "done", "test_metapackage.done"),
        test_file = test_file if test_file else []
    output:
        done = touch(os.path.join(output_directory, "done", "verify_tested_metapackage.done"))
    run:
        # Expected count: lines of the test file that contain '>'.
        import extern
        grep_output = extern.run(f"grep -c '>' {input.test_file}")
        expected_count = int(grep_output.strip())

        # Observed count: the 'read_names' column holds space-separated read
        # names, one list per OTU table row.
        import polars as pl
        otu_df = pl.read_csv(input.otu_table, separator='\t')
        observed_names = [
            read_name
            for record in otu_df.rows(named=True)
            for read_name in record['read_names'].split(' ')
        ]
        observed_count = len(observed_names)

        print("Number of sequences in otu table:", observed_count)
        print("Number of sequences in test file:", expected_count)
        if observed_count != expected_count:
            raise Exception(f"Number of sequences in otu table ({observed_count}) does not match number of sequences in test file ({expected_count})")
