#!/usr/bin/env python
import os
import sys
import getopt
import string
import subprocess
from collections import defaultdict
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.Alphabet import generic_dna
from itertools import dropwhile

# Module author metadata.
__author__ = 'Kwangbom "KB" Choi, Ph. D.'


# Usage text for the command-line interface. main() raises Usage(help_message)
# when -h/--help is given, and the Usage handler in main() prints it to stderr.
help_message = '''
    Usage:
        prepare-emase -G <genome_files> [ -g <gtf_files> -s <hap_list> -o <out_dir> -m -x ]

    Input:
        <genome_files> : List of Genome files (comma delimited)
        <gtf_files>    : List of gene annotation files (comma delimited, in the order of genomes)
        <hap_list>     : Names of haplotypes to be used instead (comma delimited, in the order of genomes)
        <out_dir>      : Output folder to store results (default: the current working directory)

    Parameters:
        -h, --help: shows this help message
        -m, --save-g2tmap: saves gene id to transcript id list in a tab-delimited text file
        -x, --no-bowtie-index: skips building bowtie index

    Note:
        Does not work if the input gtf file is older than Ensembl Release 75.
'''


class Usage(Exception):
    """Raised for command-line usage problems; carries the text to show the user.

    Attributes:
        msg: the message (or underlying getopt error) to print.
    """

    def __init__(self, msg):
        # Initialise the base class explicitly so that str(err) and err.args
        # reflect the message however the exception was constructed
        # (positionally or with msg=... as a keyword).
        Exception.__init__(self, msg)
        self.msg = msg


def is_comment(s):
    """Return True when the line *s* begins with the comment marker '#'."""
    comment_prefix = '#'
    return s.startswith(comment_prefix)


def _new_transcript_record(chro, strand):
    """Return an empty transcript record holding every per-transcript list."""
    return {'chr': chro,
            'strand': strand,
            'eid': [],
            'exon': [],
            'exon_number': [],  # Do we need this?
            'UTR': [],
            'Selenocysteine': [],
            'start_codon': [],
            'stop_codon': [],
            'start_codon_frame': [],
            'stop_codon_frame': []}


def parse_gtf(gtffile):
    """Parse a GTF annotation file into gene and transcript dictionaries.

    Comment lines (starting with '#') and blank lines are skipped wherever
    they occur in the file, not only at the top.

    Args:
        gtffile: path of the GTF file to read.

    Returns:
        A tuple (gdb, tdb):
          gdb maps gene id -> dict with 'chr' and 'isoform' (set of transcript
              ids); genes read from a 'gene' line also carry 'strand', 'start',
              'end' and the remaining attributes of that line.
          tdb maps transcript id -> dict with 'chr', 'strand', 'exon' (list of
              (start, end) tuples in file order) and the other per-transcript
              lists; transcripts read from a 'transcript' line also carry
              'start', 'end' and the attributes of that line.
        Attributes that occur once on a line are stored as a string, repeated
        attributes as a list of strings.

    Raises:
        IndexError / ValueError: on a line with fewer than nine tab-separated
            columns or non-integer coordinates.
        KeyError: when a required gene_id / transcript_id attribute is missing,
            or a UTR / codon feature refers to a transcript not seen before.
    """
    gdb = {}
    tdb = {}
    with open(gtffile) as fh:
        for curline in fh:
            # Comments may appear anywhere in the file, so test every line.
            if curline.startswith('#') or not curline.strip():
                continue
            item = curline.rstrip().split("\t")
            attribute = defaultdict(list)
            for entry in item[8].split("; "):
                cleaned = entry.replace('"', '').replace(';', '').strip()
                # Split on the first whitespace only so that values which
                # contain spaces (e.g. gene names) stay intact.
                fields = cleaned.split(None, 1)
                if len(fields) == 2:
                    attribute[fields[0]].append(fields[1])
            attribute = dict(attribute)
            feature = item[2]
            s = int(item[3])
            e = int(item[4])
            chro = item[0]  # Leave the chromosome names as is in gtf file
            strand = item[6]
            if feature == 'gene':  # Seqnature currently has some issue processing this entry
                gid = attribute.pop('gene_id')[0]
                if gid in gdb:
                    sys.stdout.write("[Error] Duplicate entry: %s\n" % gid)
                else:
                    record = {'chr': chro,
                              'strand': strand,
                              'start': s,
                              'end': e,
                              'isoform': set()}
                    for k, v in attribute.items():
                        record[k] = v[0] if len(v) == 1 else v
                    gdb[gid] = record
            elif feature == 'transcript':
                gid = attribute['gene_id'][0]
                tid = attribute.pop('transcript_id')[0]
                if tid in tdb:
                    sys.stdout.write("[Error] Duplicate entry: %s\n" % tid)
                else:
                    record = _new_transcript_record(chro, strand)
                    record['start'] = s
                    record['end'] = e
                    for k, v in attribute.items():
                        record[k] = v[0] if len(v) == 1 else v
                    tdb[tid] = record
                    if gid in gdb:
                        gdb[gid]['isoform'].add(tid)
                    else:
                        # Non-standard case: the input gtf has no gene-level
                        # annotation.  Note set([tid]), not set(tid): the
                        # latter would yield the id's individual characters.
                        gdb[gid] = {'chr': chro, 'isoform': set([tid])}
            else:
                tid = attribute['transcript_id'][0]
                if feature == 'exon':
                    if tid not in tdb:
                        # Non-standard case: the input gtf has no
                        # transcript-level annotation.  Create a complete
                        # record so later exons are handled uniformly.
                        tdb[tid] = _new_transcript_record(chro, strand)
                    tinfo = tdb[tid]
                    if 'exon_number' in attribute:
                        try:
                            tinfo['exon_number'].append(int(attribute['exon_number'][0]))
                        except ValueError:
                            pass  # exon_number is optional; ignore non-numeric values
                    if 'exon_id' in attribute:
                        tinfo['eid'].append(attribute['exon_id'][0])
                    tinfo['exon'].append((s, e))
                elif feature == 'UTR' or feature == 'Selenocysteine':
                    tdb[tid][feature].append((s, e))
                elif feature in ('start_codon', 'stop_codon'):
                    tdb[tid][feature].append((s, e))
                    tdb[tid][feature + '_frame'].append(int(item[7]))
                elif feature == 'CDS':
                    pass
                else:
                    sys.stderr.write("[Unknown feature: %s]\n%s\n" % (feature, curline))

    return (gdb, tdb)


#
# Extract the sequence of a genomic region (reverse-complemented on the '-' strand)
#
def get_fragment(start, end, chro, strand, genome):
    """Return the sequence of region [start, end] on chromosome *chro*.

    *start* and *end* are 1-based inclusive coordinates; *genome* maps a
    chromosome name to a record exposing a ``seq`` attribute.  When *strand*
    is '-', the reverse complement of the region is returned.
    """
    chrom_seq = genome[chro].seq
    piece = chrom_seq[start - 1:end]
    return piece.reverse_complement() if strand == '-' else piece


def _log(message):
    """Write one diagnostic line to stderr."""
    sys.stderr.write("%s\n" % message)


def _write_transcripts(tdb, genome, suffix, seqout, lenout):
    """Write the sequence and length of every transcript in *tdb*.

    Transcripts are visited in sorted id order.  Those on a chromosome absent
    from *genome* are skipped with a message.  *suffix* is appended to each
    transcript id.  A FASTA record is written to *seqout* only when the
    assembled sequence is non-empty; an "id<TAB>length" line is written to
    *lenout* for every transcript that was not skipped.
    """
    for tid in sorted(tdb):
        tinfo = tdb[tid]
        if tinfo['chr'] not in genome:
            # Filter out transcripts from chromosome that the input genome does not contain
            _log("Skipping Transcript %s of Chromosome %s..." % (tid, tinfo['chr']))
            continue
        out_id = tid + suffix
        fragment = Seq("", generic_dna)
        for exon_start, exon_end in tinfo['exon']:
            fragment += get_fragment(exon_start, exon_end, tinfo['chr'], tinfo['strand'], genome)
        if len(fragment) > 0:
            SeqIO.write(SeqRecord(fragment, out_id, '', ''), seqout, 'fasta')
        lenout.write("%s\t%d\n" % (out_id, len(fragment)))


def main(argv=None):
    """Build the pooled EMASE transcriptome (and optionally a bowtie index).

    Args:
        argv: full argument vector including the program name; defaults to
            sys.argv.

    Returns:
        None on success; 2 on a usage error (bad option, no genome file, or a
        mismatch between the numbers of genomes, GTF files and haplotype
        names); 1 if bowtie-build could not be started; otherwise the non-zero
        exit status of bowtie-build when it fails.
    """
    if argv is None:
        argv = sys.argv
    try:
        try:
            opts, args = getopt.getopt(argv[1:], "hG:g:s:mxo:",
                                       ["help", "save-g2tmap", "no-bowtie-index"])
        except getopt.error as msg:
            raise Usage(msg)

        # Default values of vars
        genomelist = list()
        gtflist = None
        haplist = None
        num_haps = 0
        outdir = '.'
        save_g2tmap = False
        build_bowtie_index = True

        # option processing (change this later with optparse)
        for option, value in opts:
            if option in ("-h", "--help"):
                raise Usage(help_message)
            if option == "-G":
                genomelist = value.split(',')
                num_haps = len(genomelist)
            if option == "-g":
                gtflist = value.split(',')
            if option == "-s":
                haplist = value.split(',')
            if option in ("-m", "--save-g2tmap"):
                save_g2tmap = True
            if option in ("-x", "--no-bowtie-index"):
                build_bowtie_index = False
            if option == "-o":
                outdir = value

        # Check if the required options are given
        if num_haps < 1:
            _log("No genome file is given.")
            return 2
        if haplist is None:
            # ascii_uppercase is locale-independent, unlike string.uppercase.
            haplist = list(string.ascii_uppercase[:num_haps])
            if num_haps == 1:
                _log("Assuming single genome analysis. No suffix will be added to ID's")
            else:
                _log("Default haplotype names will be used: %s" % ', '.join(haplist))
        if gtflist is None:
            gtflist = [os.path.splitext(genomefile)[0] + '.gtf' for genomefile in genomelist]
            _log("Assuming there exist the following GTF files:\n%s" % '\n'.join(gtflist))
        if len(haplist) != num_haps or len(gtflist) != num_haps:
            _log("The number of gtf files or specified haplotypes is not matching to the number of genomes.")
            # Stop here: continuing would index past the end of the shorter list.
            return 2

        if not os.path.exists(outdir):
            os.makedirs(outdir)  # also creates missing parent directories

        #
        # Main body
        #

        #
        # Get pooled transcriptome
        transcriptome_file = os.path.join(outdir, 'emase.pooled.transcriptome.fa')
        info_file = os.path.join(outdir, 'emase.pooled.transcriptome.info')
        gdb, genome = {}, {}
        with open(transcriptome_file, 'w') as seqout:
            with open(info_file, 'w') as lenout:
                for genomefile, hapname, gtffile in zip(genomelist, haplist, gtflist):
                    genomename = os.path.splitext(os.path.basename(genomefile))[0]
                    _log("Loading %s genome..." % genomename)
                    with open(genomefile) as genome_fh:
                        genome = SeqIO.to_dict(SeqIO.parse(genome_fh, "fasta"))
                    _log("Parsing %s..." % os.path.basename(gtffile))
                    gdb, tdb = parse_gtf(gtffile)
                    if num_haps == 1:
                        # No need to add suffix if we deal with a single genome
                        _log("Building %s transcriptome (Note: No suffix added to ID's)..." % genomename)
                        suffix = ''
                    else:
                        _log("Building %s transcriptome using suffix \'_%s\'..." % (genomename, hapname))
                        suffix = "_%s" % hapname
                    _write_transcripts(tdb, genome, suffix, seqout, lenout)

        if save_g2tmap:
            # The mapping is taken from the last genome/GTF pair processed above.
            with open(os.path.join(outdir, 'emase.gene2transcript.tsv'), 'w') as fhout:
                _log("Recording mapping of gene id to transcript id's...")
                for gid in sorted(gdb):
                    if gdb[gid]['chr'] in genome:
                        # Sort the isoforms so the output is reproducible.
                        item = [gid] + sorted(gdb[gid]['isoform'])
                        fhout.write("\t".join(item) + "\n")

        #
        # Build bowtie index for the pooled transcriptome
        if build_bowtie_index:
            out_index = os.path.join(os.path.dirname(transcriptome_file), 'bowtie.transcriptome')
            _log("Building bowtie index...")
            try:
                # Argument list instead of a shell string: paths containing
                # spaces or shell metacharacters are passed through verbatim.
                status = subprocess.call(["bowtie-build", transcriptome_file, out_index])
            except OSError as err:
                _log("Could not run bowtie-build: %s" % err)
                return 1
            if status != 0:
                _log("bowtie-build exited with status %d" % status)
                return status

        #
        # End of main body
        #

    except Usage as err:
        _log(sys.argv[0].split("/")[-1] + ": " + str(err.msg))
        return 2


# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
