import polars as pl
from bird_tool_utils import iterable_chunks
import numpy as np

# Required configuration values. A missing key raises KeyError as soon as the
# workflow is parsed.
prodigal_runner_path = config['prodigal_runner_path']  # e.g. '~/git/prodigal-runner/bin/prodigal-runner'
# Path to a text file listing the genome FASTA files to process.
mag_paths = config['mag_paths']
gtdbtk_db_path = config['GTDBTK_DATA_PATH']
checkm2_db = config['CHECKM2DB']
# Thread count used by the galah rule.
num_threads = 8

# Optional configuration values.
# The output location was previously looked up with mismatched keys (the
# membership test used 'output_directory' while the value was read from
# 'output'), so supplying only one of them either raised KeyError or was
# silently ignored. Both spellings are accepted now; 'output' wins when both
# are given, which matches the only combination that used to work.
output_directory = config.get(
    'output', config.get('output_directory', 'supplement_preparation')
)
genome_fasta_extension = config.get('genome_fasta_extension', '.fna')

# Read in the list of files to process. ndmin=1 keeps the result
# one-dimensional even when the list contains a single path; without it
# loadtxt returns a 0-d array, which supports neither len() nor iteration.
genomes_input = np.loadtxt(mag_paths, dtype='str', ndmin=1)

# chunk into batches of 10
groups1 = list(iterable_chunks(genomes_input, 10))

# groups1 = list(iterable_chunks(genomes_input, 2))[:2] # debug
# output_directory = output_directory + '-debug'

# Drop None entries from each chunk (presumably padding added by
# iterable_chunks to fill the last chunk - confirm against bird_tool_utils).
groups = [[g for g in group if g is not None] for group in groups1]

print("Found {} genome files and {} groups".format(len(genomes_input), len(groups)))
# Integer group identifiers; each rule's {group} wildcard is an index into `groups`.
group_ids = list(range(len(groups)))

# Default target: requests the completion markers of the prodigal-runner,
# GTDB-Tk, CheckM2 and galah stages.
rule all:
    input:
        expand(
            '{outdir}/done/{target}.done',
            outdir=output_directory,
            target=[
                'all-prodigal-runner',
                'all-gtdbtk',
                'all-checkm2',
                'galah',
            ],
        )

# Run `galah cluster` on the genome list, using the merged CheckM2 quality
# report, and write the list of representative genomes.
rule galah:
    input:
        checkm2_report = output_directory + '/checkm2_quality_report.tsv',
    output:
        representatives = output_directory + '/galah_representatives.tsv',
        done = touch(output_directory + '/done/galah.done'),
    threads: num_threads
    conda:
        "envs/galah.yml"
    shell:
        "galah cluster --genome-fasta-list {mag_paths} "
        "--checkm2-quality-report {input.checkm2_report} "
        "--output-representative-list {output.representatives} "
        "-t {threads}"

# Merge the per-group CheckM2 reports into a single table once every group has
# finished: the header line comes from group 0, then the data rows (everything
# after the first line) of each quality_report.tsv found under checkm2/ are
# appended.
rule all_checkm2:
    input:
        expand(f'{output_directory}/done/checkm2-{{group}}.done', group=range(len(groups))),
    output:
        done = touch(f'{output_directory}/done/all-checkm2.done'),
        report = f'{output_directory}/checkm2_quality_report.tsv',
    shell:
        """
        head -1 {output_directory}/checkm2/0/quality_report.tsv > {output.report} && \
        find {output_directory}/checkm2/ |grep quality_report.tsv | parallel -j1 -k tail -n+2 {{}} >> {output.report}
        """

# Merge the per-group GTDB-Tk summaries into one table per domain
# (gtdbtk/gtdbtk.ar53.summary.tsv and gtdbtk/gtdbtk.bac120.summary.tsv) once
# every group has finished.
#
# The header is taken from the first per-group summary that is found rather
# than always from group 0: a group with no archaeal (or no bacterial) genomes
# may not have that domain's summary file, which made the merge fail. If no
# group has a summary for a domain, the combined file is left empty.
# -mindepth 2 keeps the combined files themselves (which live directly in
# gtdbtk/) out of the search, and `sort` makes the row order deterministic.
rule all_gtdbtk:
    input:
        expand(f'{output_directory}/done/gtdbtk-'+'{group}.done', group=range(len(groups))),
    output:
        done = touch(f'{output_directory}/done/all-gtdbtk.done')
    shell:
        """
        for domain in ar53 bac120; do
            combined="{output_directory}/gtdbtk/gtdbtk.$domain.summary.tsv"
            : > "$combined"
            find "{output_directory}/gtdbtk" -mindepth 2 -type f -name "gtdbtk.$domain.summary.tsv" | sort | while read -r summary; do
                if [ ! -s "$combined" ]; then
                    head -1 "$summary" > "$combined"
                fi
                tail -n+2 "$summary" >> "$combined"
            done
        done
        """

# Once every prodigal-runner group has finished, hand the output directory and
# the full genome list to bin/combine-prodigal-runner.py, which is declared to
# produce gene_definitions.tsv.
rule all_prodigal_runner:
    input:
        expand(f'{output_directory}/done/prodigal-runner-{{group}}.done', group=range(len(groups))),
    output:
        done = touch(f'{output_directory}/done/all-prodigal-runner.done'),
        gene_definitions = f'{output_directory}/gene_definitions.tsv',
    params:
        genomes_input = genomes_input,
        output_directory = output_directory,
    script:
        "bin/combine-prodigal-runner.py"

# Run `checkm2 predict` on one group of genomes. The group's genome files are
# first symlinked into checkm2-inputs/<group>, and that directory is given to
# CheckM2 as its input; results go to checkm2/<group>.
rule checkm2:
    output:
        touch(f'{output_directory}/done/checkm2-{{group}}.done'),
    log:
        f'{output_directory}/logs/checkm2/{{group}}.log'
    benchmark:
        f'{output_directory}/benchmarks/checkm2/{{group}}.txt'
    conda:
        "envs/checkm2.yml"
    threads: 4
    resources:
        mem_mb = 8000,
        runtime = "1h",
    params:
        # One genome path per line, for the group selected by the wildcard.
        genome_paths = lambda wildcards: '\n'.join(groups[int(wildcards.group)]),
        genome_fasta_extension = genome_fasta_extension,
        checkm2_db = checkm2_db,
        output_directory = output_directory
    shell:
        'mkdir -p {params.output_directory}/checkm2-inputs/{wildcards.group} && '
        'echo "{params.genome_paths}" |parallel -j1 --halt-on-error 2 ln -frs {{}} {params.output_directory}/checkm2-inputs/{wildcards.group}/{{/}} && '
        'mkdir -p {params.output_directory}/checkm2/{wildcards.group} && '
        'CHECKM2DB={params.checkm2_db} checkm2 predict --force --threads {threads} '
        '--input {params.output_directory}/checkm2-inputs/{wildcards.group} '
        '--output-directory {params.output_directory}/checkm2/{wildcards.group} '
        '-x {params.genome_fasta_extension} &> {log}'

# Run `gtdbtk classify_wf` on one group of genomes. The group's genome files
# are first symlinked into gtdbtk-inputs/<group>, and that directory is given
# to GTDB-Tk as --genome_dir; results go to gtdbtk/<group>.
rule gtdbtk:
    output:
        touch(f'{output_directory}/done/gtdbtk-'+'{group}.done'),
    conda:
        "envs/gtdbtk.yml"
    threads: 4
    resources:
        mem_mb = 100000,
        runtime = "1h",
    log:
        f'{output_directory}/'+'logs/gtdbtk/{group}.log'
    benchmark:
        f'{output_directory}/'+'benchmarks/gtdbtk/{group}.txt'
    params:
        output_directory = output_directory,
        # One genome path per line, for the group selected by the wildcard.
        genome_paths = lambda wildcards: '\n'.join(groups[int(wildcards.group)]),
        gtdbtk_db_path = gtdbtk_db_path,
        genome_fasta_extension = genome_fasta_extension,
        # Value passed to --mash_db. It can be overridden with the optional
        # 'gtdbtk_mash_db' config key; the default is the location that was
        # previously hard-coded in the command.
        mash_db = config.get('gtdbtk_mash_db', '~/m/db/gtdb/gtdb_release214/mash.db.msh')
    shell:
        'mkdir -p {params.output_directory}/gtdbtk-inputs/{wildcards.group} && '
        'echo "{params.genome_paths}" |parallel -j1 --halt-on-error 2 ln -frs {{}} {params.output_directory}/gtdbtk-inputs/{wildcards.group}/{{/}} && '
        'GTDBTK_DATA_PATH={params.gtdbtk_db_path} gtdbtk classify_wf --cpus {threads} --force '
        '--mash_db {params.mash_db} '
        '--genome_dir {params.output_directory}/gtdbtk-inputs/{wildcards.group} '
        '--out_dir {params.output_directory}/gtdbtk/{wildcards.group} '
        '--extension {params.genome_fasta_extension} &> {log}'

# Run prodigal-runner on one group of genomes; the genome paths are fed through
# GNU parallel (with --xargs) to `<prodigal_runner_path> run`, and results go
# to prodigal-runner/<group>.
rule prodigal_runner:
    output:
        touch(f'{output_directory}/done/prodigal-runner-{{group}}.done'),
    log:
        f'{output_directory}/logs/prodigal-runner/{{group}}.log'
    benchmark:
        f'{output_directory}/benchmarks/prodigal-runner/{{group}}.txt'
    conda:
        "envs/singlem.yml"
    threads: 1
    resources: # These are both overkill for 10 genomes each, but eh
        mem_mb = 8000,
        runtime = "1h",
    params:
        # One genome path per line, for the group selected by the wildcard.
        genome_paths = lambda wildcards: '\n'.join(groups[int(wildcards.group)]),
        prodigal_runner_path = prodigal_runner_path,
        output_directory = output_directory
    shell:
        'mkdir -p {params.output_directory}/prodigal-runner/{wildcards.group} && '
        'echo "{params.genome_paths}" |parallel -j1 --halt-on-error 2 --xargs --ungroup '
        '{params.prodigal_runner_path} run -i {{}} '
        '-o {params.output_directory}/prodigal-runner/{wildcards.group} &> {log}'