#############
### Setup ###
#############
"""
Steps:
Create and activate base environment (update_metapackage.yml)
Update config.yaml
Run `snakemake --cores 64 --use-conda`
"""

import pandas as pd
import os

configfile: "config.yaml"

# Settings read straight from the config file.
output_dir = config["output_dir"]
taxonomy_database_name = config["taxonomy_database_name"]
taxonomy_database_version = config["taxonomy_database_version"]
# Optional: not needed for R220 and beyond (actually R220 was missing some
# data, so it is used there too). None when the key is absent.
checkm2_output_dirs = config.get('checkm2_output_dirs')

# Paths derived from the settings above.
logs_dir = output_dir + "/logs"
singlem_bin = "../../singlem/main.py"

# Tab-separated table of HMMs, indexed by its "name" column; the column itself
# is kept (drop=False) so it stays available as regular data.
hmms_and_names = pd.read_csv("hmms_and_names", sep="\t")
hmms_and_names = hmms_and_names.set_index("name", drop=False)

# TODO: Discovering the genomes at startup is slow, and it is repeated every
# time snakemake is started up in a queued job. It would be better to do this
# once and then reuse cached results.
print("Finding genomes...")
# Protein FASTA files carry a .gz suffix when `compressed_genome_data` is set.
genome_suffix = ".faa.gz" if config.get('compressed_genome_data') else ".faa"
genome_root = config["gtdb_protein_faa_reps"]
# Walk the whole tree (following symlinks) and keep every matching file.
genome_paths = [
    os.path.join(dirpath, filename)
    for dirpath, _, filenames in os.walk(genome_root, followlinks=True)
    for filename in filenames
    if filename.endswith(genome_suffix)
]
# Genome IDs are the paths relative to the root directory.
genomes = [path.removeprefix(genome_root).removeprefix("/") for path in genome_paths]

# Number of transcript files listed in each chunk written by
# `SingleM_transcripts_genome_lists`.
size_genome_splits = 50
# Ceiling division. The previous `len(genomes) // size + 1` produced one chunk
# too many whenever the genome count was an exact multiple of the chunk size
# (or zero), so `collect_SingleM_transcripts` requested an OTU table for a
# chunk that has no list file.
# NOTE(review): this assumes the transcript directory holds one file per
# genome found in the protein directory - confirm.
n_genome_splits = (len(genomes) + size_genome_splits - 1) // size_genome_splits

zenodo_backpack = os.path.join(output_dir, config["new_metapackage"]+'.zb.tar.gz')
zenodo_backpack_version = config['new_version_number']

# genomes = genomes[:50]+genomes[(len(genomes)-50):] # debug, so have archaea and bacteria
print("Found {} genomes.".format(len(genomes)))

# Default thread budget for the multi-threaded rules when the config is silent.
config.setdefault("max_threads", 8)

# Restrict the `genome` wildcard to "<domain>/<accession>.<version>_protein.faa",
# optionally gzipped (genome discovery accepts ".faa.gz" when
# `compressed_genome_data` is set). The dots are escaped so they only match a
# literal "." rather than any character.
wildcard_constraints:
    genome = r"(archaea|bacteria)/\w+\.\d+_protein\.faa(\.gz)?"

def get_mem_mb(wildcards, threads, attempt):
    """Return the memory request in MB for a job.

    The request is 8000 MB per thread, multiplied by the attempt number so
    that each retry asks for proportionally more memory. `wildcards` is
    accepted to match the resource-callable signature but is not used.
    """
    mb_per_thread = 8 * 1000
    return mb_per_thread * threads * attempt

# Default target: requests the Zenodo backpack archive and the marker file
# that `check_tests` touches once all SRA test outputs exist.
rule all:
    input:
        zenodo_backpack,
        output_dir + "/sra/done"

# Packages the final metapackage directory into an archive with
# `zenodo_backpack create`. `{zenodo_backpack_version}` in the shell command is
# filled from the module-level variable of that name (config
# `new_version_number`).
rule zenodo_backpack:
    input:
        metapackage = output_dir + "/metapackage/" + config["new_metapackage"],
    output:
        zenodo_backpack=zenodo_backpack,
        # Marker file, touched by Snakemake when the command succeeds.
        done=touch(output_dir + "/zenodo_backpack.done")
    conda:
        "envs/zenodo_backpack.yml"
    shell:
        "zenodo_backpack create --input-directory {input.metapackage} --output-file {output.zenodo_backpack} --data-version {zenodo_backpack_version}"

####################
### HMM searches ###
####################
# Runs scripts/pfam_search.py once for the whole genome list (passed through
# params) instead of once per genome; the commented-out paths are the earlier
# per-genome layout. The rule declares no inputs.
rule pfam_search:
    input:
        # config["gtdb_protein_faa_reps"] + "/{genome}"
    output:
        # Not wrapped in touch(), so presumably the script creates this marker
        # itself - verify against scripts/pfam_search.py.
        done=output_dir + "/pfam_search.done" #"/hmmsearch/pfam/{genome}.tsv"
    params:
        pfams = config["pfams"],
        genome_ids = genomes,
        output_dir = output_dir,
    threads:
        config["max_threads"]
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    log:
        logs_dir + "/pfam_search.log" #"/hmmsearch/pfam/{genome}.log"
    conda:
        "envs/hmm_tools.yml"
    script:
        "scripts/pfam_search.py"

# Runs scripts/tigrfam_search.py once for the whole genome list (passed through
# params); the commented-out paths are the earlier per-genome layout. Unlike
# `pfam_search` there is no `log:` directive here: a log directory is handed to
# the script through params instead.
rule tigrfam_search:
    input:
        # config["gtdb_protein_faa_reps"] + "/{genome}"
    output:
        # Not wrapped in touch(), so presumably the script creates this marker
        # itself - verify against scripts/tigrfam_search.py.
        done=output_dir + "/tigrfam_search.done" #output_dir + "/hmmsearch/tigrfam/{genome}.tsv"
    params:
        tigrfams = config["tigrfams"],
        log_dir=logs_dir + "/hmmsearch/tigrfam/",
        output_dir=output_dir + "/hmmsearch/tigrfam/",
        genome_ids = genomes,
    threads:
        config["max_threads"]
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    conda:
        "envs/tigrfam_search.yml"
    script:
        "scripts/tigrfam_search.py"

# Runs scripts/get_matches_no_dup_all.py after both HMM searches have finished.
# The search result directories are passed through params rather than declared
# as inputs; ordering is enforced only by the two .done markers.
rule get_matches_no_dup:
    input:
        output_dir + "/pfam_search.done",
        output_dir + "/tigrfam_search.done",
    output:
        done=output_dir + "/get_matches_no_dup.done"
    params:
        pfam_search_directory = output_dir + "/hmmsearch/pfam/",
        tigrfam_search_directory = output_dir + "/hmmsearch/tigrfam/",
        # Path of the HMM table, relative to the working directory.
        hmms_and_names = "hmms_and_names",
        output_dir = output_dir + "/hmmsearch/matches",
        logs_dir = logs_dir + "/hmmsearch/matches",
        genome_ids = genomes,
    threads:
        config["max_threads"]
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    conda:
        "envs/get_matches_no_dup.yml"
    script:
        "scripts/get_matches_no_dup_all.py"

# Runs scripts/mfqe_all.py over the whole genome list once the match files are
# complete. The commented-out entries are the earlier per-genome layout.
rule mfqe:
    input:
        done=output_dir + "/get_matches_no_dup.done",
    params:
        # fasta = config["gtdb_protein_faa_reps"] + "/{genome}",
        # output_tsv = os.path.join(output_dir, f"{genome}.fam") + "/hmmsearch/matches",
        # fam_directory and output_dir below point at the same directory.
        fam_directory = output_dir + "/hmmsearch/matches",
        # seqnames = output_dir + "/hmmsearch/matches/{genome}.tmp"
        output_dir = output_dir + "/hmmsearch/matches",
        logs_dir = logs_dir + "/mfqe",
        genome_ids = genomes,
    output:
        done=output_dir + "/mfqe.done",
    threads: 1
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    # log:
    #     logs_dir + "/hmmsearch/matches/{genome}_mfqe.log"
    conda:
        "../../singlem.yml"
    script:
        "scripts/mfqe_all.py"

########################
### Package creation ###
########################
# Runs scripts/transpose_hmms_with_sequences_all.py over the whole genome list
# once `mfqe` has finished. The script writes into the hmmseq directory given
# in params; only the .done marker is a declared output.
rule transpose_hmms_with_sequences:
    input:
        output_dir + "/mfqe.done",
    output:
        done = output_dir + "/transpose_hmms_with_sequences.done",
    params:
        # Plain path: `directory()` is a flag for declared outputs and has no
        # meaning inside params, so it is not used here.
        output_dir = output_dir + "/hmmseq/", # output is file for each spkg
        # Path of the HMM table, relative to the working directory.
        hmms_and_names = "hmms_and_names",
        matches_dir = output_dir + "/hmmsearch/matches",
        logs_dir = logs_dir + "/transpose_hmms_with_sequences",
        genome_ids = genomes,
    threads:
        config["max_threads"]
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    conda:
        "envs/transpose_hmms_with_sequences.yml"
    script:
        "scripts/transpose_hmms_with_sequences_all.py"

# For one package ({spkg}), concatenates every matching sequence file and every
# matching taxonomy file found under the hmmseq directory into a single .faa
# and a single _taxonomy.tsv. The concatenated files are addressed through
# params; only the .done marker is a declared output.
rule concatenate_seqs_and_taxonomies:
    input:
        done = output_dir + "/transpose_hmms_with_sequences.done",
    output:
        done = output_dir + "/concatenate_seqs_and_taxonomies.{spkg}.done",
    params:
        hmmseq_dir = output_dir + "/hmmseq/",
        concat_dir = output_dir + "/hmmseq_concat/",
        spkg_seq = output_dir + "/hmmseq_concat/{spkg}.faa",
        spkg_tax = output_dir + "/hmmseq_concat/{spkg}_taxonomy.tsv",
    threads: 1
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    shell:
        # The suffix patterns are quoted and their dots escaped, so that only a
        # literal ".faa" / "_taxonomy.tsv" ending matches and the shell never
        # interprets the pattern. `grep -F` matches the package name as a fixed
        # substring of the path.
        "mkdir -p {params.concat_dir} && "
        "find {params.hmmseq_dir} |grep -F {wildcards.spkg} |grep '\\.faa$' |parallel -j1 --ungroup cat {{}} >{params.spkg_seq} && "
        "find {params.hmmseq_dir} |grep -F {wildcards.spkg} |grep '_taxonomy\\.tsv$' |parallel -j1 --ungroup cat {{}} >{params.spkg_tax} && "
        "touch {output.done}"


# Rebuilds one SingleM package with `singlem regenerate`, starting from the
# package of the same name in the existing metapackage (config `metapackage`)
# and the concatenated sequences/taxonomy for {spkg}, plus the UniProt
# sequences and taxonomy given as the eukaryote inputs.
rule create_SingleM_packages:
    input:
        output_dir + "/concatenate_seqs_and_taxonomies.{spkg}.done",
    output:
        directory(output_dir + "/packages/{spkg}.spkg")
    params:
        singlem = singlem_bin,
        hmms_and_names = "hmms_and_names",
        uniprot_seq = config["uniprot_seq"],
        uniprot_tax = config["uniprot_tax"],
        # Existing package used as the template for regeneration.
        spkg = config["metapackage"] + "/{spkg}.spkg",
        # Files written by `concatenate_seqs_and_taxonomies`; they are covered
        # by the .done marker in input rather than declared individually.
        spkg_seq = output_dir + "/hmmseq_concat/{spkg}.faa",
        spkg_tax = output_dir + "/hmmseq_concat/{spkg}_taxonomy.tsv",
        # Looked up in the HMM table by package name; used as the sequence prefix.
        spkg_name = lambda wildcards: hmms_and_names.loc[wildcards.spkg, "name_without_number"]
    threads: 1
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    log:
        logs_dir + "/packages/{spkg}.log"
    conda:
        "../../singlem.yml"
    shell:
        "{params.singlem} regenerate "
        "--input-singlem-package {params.spkg} "
        "--input-sequences {params.spkg_seq} "
        "--input-taxonomy {params.spkg_tax} "
        "--euk-sequences {params.uniprot_seq} "
        "--euk-taxonomy {params.uniprot_tax} "
        "--output-singlem-package {output} "
        "--sequence-prefix {params.spkg_name}~ "
        "&> {log}"

############################
### Metapackage creation ###
############################
# Builds a draft metapackage from all regenerated packages, without a
# nucleotide sdb and without taxon genome lengths. The draft is what
# `SingleM_transcripts` runs against.
rule create_draft_SingleM_metapackage:
    input:
        # One package per row of the HMM table.
        packages = expand(output_dir + "/packages/{spkg}.spkg", spkg = hmms_and_names.index)
    output:
        directory(output_dir + "/draft_metapackage.smpkg")
    params:
        singlem = singlem_bin
    threads:
        config["max_threads"]
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    log:
        logs_dir + "/draft_metapackage.log"
    conda:
        "../../singlem.yml"
    shell:
        "{params.singlem} metapackage "
        "--singlem-packages {input.packages} "
        "--no-nucleotide-sdb "
        "--no-taxon-genome-lengths "
        "--metapackage {output} "
        "--threads {threads} "
        "&> {log}"

# Lists every *.fa.gz / *.fna.gz file under the transcript directory (following
# symlinks) and splits the list into files of `size_genome_splits` paths each,
# named by GNU parallel's job sequence number ({#}): 1.txt, 2.txt, ...
rule SingleM_transcripts_genome_lists:
    input:
        dir = config["gtdb_protein_fna_reps"],
    output:
        dir = directory(output_dir + "/transcript_lists"),
    threads:
        1
    resources:
        mem_mb = get_mem_mb,
        runtime = 5*60,
    params:
        size = size_genome_splits,
    shell:
        # `{{#}}` is a doubled brace so that Snakemake leaves `{#}` for parallel.
        "mkdir -p {output.dir} && "
        "find -L {input.dir} \\( -name '*.fa.gz' -o -name '*.fna.gz' \\) "
        "| parallel --pipe -N{params.size} 'cat > {output.dir}/{{#}}.txt' "

# Runs `singlem pipe` (without taxonomy assignment) against the draft
# metapackage on the transcript files named in chunk list {n}, producing one
# OTU table per chunk.
rule SingleM_transcripts:
    input:
        dir = config["gtdb_protein_fna_reps"],
        metapackage = output_dir + "/draft_metapackage.smpkg",
        # The list directory as a whole; the individual chunk file is given in
        # params because it is not a declared output of the listing rule.
        lists = output_dir + "/transcript_lists",
    output:
        otu_table = output_dir + "/transcripts/{n}.otu_table.tsv",
    threads: 1
    resources:
        mem_mb = get_mem_mb,
        runtime = 10*60,
    params:
        singlem = singlem_bin,
        input_files = output_dir + "/transcript_lists/{n}.txt"
    conda:
        "../../singlem.yml"
    log:
        logs_dir + "/transcripts/{n}.log"
    shell:
        # The list file holds one path per line; `tr` joins them with spaces.
        # This is not a raw string, so Python turns `\n` into a literal newline
        # inside the single-quoted `tr` argument.
        "{params.singlem} pipe "
        "--forward $(cat {params.input_files} | tr '\n' ' ') "
        "--metapackage {input.metapackage} "
        "--otu-table {output.otu_table} "
        "--no-assign-taxonomy "
        "&> {log} "

# Aggregation point: requests the OTU table of every chunk, numbered from 1 to
# n_genome_splits, and touches a marker file once they all exist.
rule collect_SingleM_transcripts:
    input:
        expand(
            output_dir + "/transcripts/{n}.otu_table.tsv",
            n = list(range(1, n_genome_splits + 1)),
        )
    output:
        touch(output_dir + "/transcripts/done")

# Runs ../assign_gtdb_taxonomy_to_gtdb_genomes.py over all per-chunk OTU tables
# with the GTDB bacterial and archaeal metadata, writing the combined table
# from the script's stdout.
rule assign_taxonomy:
    input:
        # Marker touched by `collect_SingleM_transcripts`; the OTU tables
        # themselves are discovered at run time with `find` (see shell).
        touch = output_dir + "/transcripts/done"
    output:
        output_dir + "/taxonomy/transcripts.otu_table.tsv"
    threads: 1
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    params:
        bac_metadata = config["gtdb_bac_metadata"],
        arc_metadata = config["gtdb_arc_metadata"],
        dir = output_dir + "/transcripts",
    conda:
        "../../singlem.yml"
    log:
        logs_dir + "/taxonomy.log"
    shell:
        # The list of OTU tables is passed via process substitution `<( ... )`.
        "../assign_gtdb_taxonomy_to_gtdb_genomes.py "
        "--otu-table-list <( find {params.dir} -name '*.otu_table.tsv' ) "
        "--gtdb-bac {params.bac_metadata} "
        "--gtdb-arc {params.arc_metadata} "
        "> {output} "
        "2> {log}"

# Builds a SingleM database directory with `singlem makedb` from the
# taxonomy-assigned transcript OTU table.
rule make_sdb:
    input:
        output_dir + "/taxonomy/transcripts.otu_table.tsv"
    output:
        directory(output_dir + "/taxonomy/transcripts.sdb")
    threads:
        config["max_threads"]
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    params:
        singlem = singlem_bin
    log:
        logs_dir + "/taxonomy/makedb.log"
    conda:
        "../../singlem.yml"
    shell:
        "{params.singlem} makedb "
        "--otu-table {input} "
        "--db {output} "
        "--threads {threads} "
        "&> {log}"

# Builds the final metapackage from the regenerated packages, the transcript
# sdb and the mean genome size table. `{taxonomy_database_name}` and
# `{taxonomy_database_version}` in the shell command are filled from the
# module-level variables of those names.
rule create_SingleM_metapackage:
    input:
        # One package per row of the HMM table.
        packages = expand(output_dir + "/packages/{spkg}.spkg", spkg = hmms_and_names.index),
        sdb = output_dir + "/taxonomy/transcripts.sdb",
        genome_sizes = output_dir + "/gtdb_mean_genome_sizes.tsv"
    output:
        directory(output_dir + "/metapackage/" + config["new_metapackage"])
    params:
        singlem = singlem_bin
    threads:
        config["max_threads"]
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    log:
        logs_dir + "/metapackage.log"
    conda:
        "../../singlem.yml"
    shell:
        "{params.singlem} metapackage "
        "--singlem-packages {input.packages} "
        "--nucleotide-sdb {input.sdb} "
        "--taxon-genome-lengths {input.genome_sizes} "
        "--metapackage {output} "
        "--threads {threads} "
        "--taxonomy-database-name '{taxonomy_database_name}' "
        "--taxonomy-database-version '{taxonomy_database_version}' "
        "&> {log}"

#######################################
### Mean genome size table creation ###
#######################################
# Collects the rows of every CheckM2 quality report under the configured
# directories into one file, dropping lines that contain "Completeness"
# (grep -v). Only requested by `calculate_mean_genome_size` when
# `checkm2_output_dirs` is configured.
rule checkm2_gather: # In future presumably this data will be available from the GTDB metadata
    input:
        # Empty when no CheckM2 directories are configured.
        checkm2_output_dirs = checkm2_output_dirs if checkm2_output_dirs else [],
    output:
        gtdb_checkm2_grepped = output_dir + "/gtdb_checkm2_grepped.tsv",
        done = touch(output_dir + "/checkm2_gather.done"),
    log:
        logs_dir + "/checkm2_gather.log"
    shell:
        # grep -hv Completeness {input.checkm2_output_dirs}/*.checkm2/quality_report.tsv
        # parallel hands grep up to 50 report files per invocation (-N50).
        "find {input.checkm2_output_dirs} -name '*quality_report.tsv' |parallel -N50 grep -hv Completeness {{}} > {output.gtdb_checkm2_grepped} 2> {log}"

# Runs scripts/genome_size_from_gtdb.py on the GTDB bacterial and archaeal
# metadata, writing the mean genome size table from the script's stdout. The
# gathered CheckM2 table is added as input and as a `--checkm2-grep` argument
# only when `checkm2_output_dirs` is configured.
rule calculate_mean_genome_size:
    input:
        # Only need these inputs if checkm2 isn't in the metadata
        gtdb_checkm2_grepped = output_dir + "/gtdb_checkm2_grepped.tsv" if checkm2_output_dirs else [],
        done = output_dir + "/checkm2_gather.done" if checkm2_output_dirs else [],
    output:
        output_dir + "/gtdb_mean_genome_sizes.tsv"
    params:
        gtdb_bac_metadata = config["gtdb_bac_metadata"],
        gtdb_arc_metadata = config["gtdb_arc_metadata"],
        # The conditional applies to the whole concatenation: either the full
        # "--checkm2-grep <path>" argument or an empty string.
        checkm2_grep_arg = '--checkm2-grep ' + output_dir + "/gtdb_checkm2_grepped.tsv" if checkm2_output_dirs else ''
    threads: 8
    resources:
        mem_mb = get_mem_mb,
        # Given as a string here, unlike the other rules, which use minutes.
        runtime = '1h',
    log:
        logs_dir + "/calculate_mean_genome_size.log"
    conda:
        "envs/calculate_mean_genome_size.yml"
    shell:
        """
        scripts/genome_size_from_gtdb.py {params.checkm2_grep_arg} --gtdb-bac-metadata {params.gtdb_bac_metadata} --gtdb-ar-metadata {params.gtdb_arc_metadata} > {output} 2> {log}
        """

###########################
### Metapackage testing ###
###########################
# Runs `singlem pipe` with the newly built metapackage on one configured SRA
# read pair, producing the OTU table, archive OTU table, condensed taxonomic
# profile and Krona plot, all named "{sra}_new".
rule test_metapackage:
    input:
        metapackge = output_dir + "/metapackage/" + config["new_metapackage"],
        # Forward/reverse read files, looked up in config by SRA name.
        sra_1 = lambda wildcards: config["sra_seqs_1"][wildcards.sra],
        sra_2 = lambda wildcards: config["sra_seqs_2"][wildcards.sra],
    output:
        table = output_dir + "/sra/{sra}_new.otu_table.tsv",
        archive_table = output_dir + "/sra/{sra}_new.otu_table.json",
        profile = output_dir + "/sra/{sra}_new.otu_table.condensed.tsv",
        profile_krona = output_dir + "/sra/{sra}_new.otu_table.condensed.krona.html"
    params:
        singlem = singlem_bin
    threads: 8
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    log:
        logs_dir + "/sra/{sra}_new_test.log"
    conda:
        "../../singlem.yml"
    shell:
        "{params.singlem} pipe "
        "--forward {input.sra_1} "
        "--reverse {input.sra_2} "
        "--metapackage {input.metapackge} "
        "--otu-table {output.table} "
        "--archive-otu-table {output.archive_table} "
        "-p {output.profile} "
        "--taxonomic-profile-krona {output.profile_krona} "
        "--threads {threads} "
        "&> {log}"

# Same `singlem pipe` run as `test_metapackage`, but against the existing
# metapackage (config `metapackage`), so outputs are named "{sra}_old" and can
# be compared with the "{sra}_new" results.
rule test_old_metapackage:
    input:
        # Forward/reverse read files, looked up in config by SRA name.
        sra_1 = lambda wildcards: config["sra_seqs_1"][wildcards.sra],
        sra_2 = lambda wildcards: config["sra_seqs_2"][wildcards.sra],
    output:
        table = output_dir + "/sra/{sra}_old.otu_table.tsv",
        archive_table = output_dir + "/sra/{sra}_old.otu_table.json",
        profile = output_dir + "/sra/{sra}_old.otu_table.condensed.tsv",
        profile_krona = output_dir + "/sra/{sra}_old.otu_table.condensed.krona.html"
    params:
        singlem = singlem_bin,
        # Passed as a param, not an input: the old metapackage is not built by
        # this workflow.
        metapackge = config["metapackage"],
    threads: 8
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    log:
        logs_dir + "/sra/{sra}_old_test.log"
    conda:
        "../../singlem.yml"
    shell:
        "{params.singlem} pipe "
        "--forward {input.sra_1} "
        "--reverse {input.sra_2} "
        "--metapackage {params.metapackge} "
        "--otu-table {output.table} "
        "--archive-otu-table {output.archive_table} "
        "-p {output.profile} "
        "--taxonomic-profile-krona {output.profile_krona} "
        "--threads {threads} "
        "&> {log}"

# Runs `singlem summarise` on a condensed profile to write its taxonomic level
# coverage table. {version} is "new" or "old", as requested by `check_tests`.
# `{singlem_bin}` in the shell command is filled from the module-level variable
# rather than from params; this rule has no log file.
rule sra_fractions_assigned_to_each_level:
    input:
        profile = output_dir + "/sra/{sra}_{version}.otu_table.condensed.tsv",
    output:
        breakdown = output_dir + "/sra/{sra}_{version}.taxonomic_coverage.tsv"
    threads: 1
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    conda:
        "../../singlem.yml"
    shell:
        "{singlem_bin} summarise "
        "--input-taxonomic-profile {input.profile} "
        "--output-taxonomic-level-coverage {output.breakdown} "

# Runs `singlem prokaryotic_fraction` for one SRA read pair and one metapackage
# version, writing the result from the command's stdout.
rule sra_prokaryotic_fraction:
    input:
        profile = output_dir + "/sra/{sra}_{version}.otu_table.condensed.tsv",
        # Forward/reverse read files, looked up in config by SRA name.
        sra_1 = lambda wildcards: config["sra_seqs_1"][wildcards.sra],
        sra_2 = lambda wildcards: config["sra_seqs_2"][wildcards.sra],
    output:
        smf = output_dir + "/sra/{sra}_{version}.smf.tsv"
    threads: 1
    resources:
        mem_mb = get_mem_mb,
        runtime = 48*60,
    params:
        # The existing metapackage for version "old", otherwise the newly built
        # one. The new metapackage is not a declared input here; for "new" the
        # profile input comes from `test_metapackage`, which does depend on it.
        metapackge = lambda wildcards: config["metapackage"] if wildcards.version == "old" else output_dir + "/metapackage/" + config["new_metapackage"],
    conda:
        "../../singlem.yml"
    log:
        logs_dir + "/sra/{sra}_{version}_smf.log"
    shell:
        "{singlem_bin} prokaryotic_fraction "
        "--metapackage {params.metapackge} "
        "--forward {input.sra_1} "
        "--reverse {input.sra_2} "
        "-p {input.profile} "
        "> {output.smf} "
        "2> {log} "

# Aggregation point for the metapackage tests: for every SRA name in config
# `sra_seqs_1`, requests the new and old OTU tables, taxonomic coverage tables
# and prokaryotic fraction tables, then touches the marker that `all` asks for.
rule check_tests:
    input:
        expand(output_dir + "/sra/{sra}_new.otu_table.tsv", sra=config["sra_seqs_1"]),
        expand(output_dir + "/sra/{sra}_old.otu_table.tsv", sra=config["sra_seqs_1"]),
        expand(output_dir + "/sra/{sra}_{version}.taxonomic_coverage.tsv", sra=config["sra_seqs_1"], version=["new", "old"]),
        expand(output_dir + "/sra/{sra}_{version}.smf.tsv", sra=config["sra_seqs_1"], version=["new", "old"]),
    output:
        touch(output_dir + "/sra/done")
