#!/usr/bin/env python

"""
Copyright (c) 2015 Kwangbom Choi, The Jackson Laboratory
This software was developed by Kwangbom "KB" Choi in Gary Churchill's Lab.
This is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This software is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this software. If not, see <http://www.gnu.org/licenses/>.
"""

import os
import sys
import getopt
import string
import subprocess
from collections import defaultdict
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.Alphabet import generic_dna
from itertools import dropwhile


# Usage text for the command line interface. main() raises Usage(help_message)
# when -h/--help is given, and the Usage handler in main() prints it to stderr.
help_message = '''
Usage:
    prepare-emase -G <genome_files> [ -g <gtf_files> -s <hap_list> -o <out_dir> -m -x ]

Input:
    -G <genome_files> : List of Genome files (comma delimited)
    -g <gtf_files>    : List of gene annotation files (comma delimited, in the order of genomes)
    -s <hap_list>     : Names of haplotypes to be used instead (comma delimited, in the order of genomes)
    -o <out_dir>      : Output folder to store results (default: the current working directory)

Parameters:
    -h, --help            : shows this help message
    -m, --save-g2tmap     : saves gene id to transcript id list in a tab-delimited text file
    -x, --no-bowtie-index : skips building bowtie index

Note:
    Does not work if the input gtf file is older than Ensembl Release 75.
'''


class Usage(Exception):
    """Raised to abort command line processing with a message for the user.

    The message is kept in ``msg`` (read by ``main``) and is also passed to
    the base ``Exception`` so that ``str(exc)``, ``exc.args`` and tracebacks
    show it as well.
    """

    def __init__(self, msg):
        # Initialise the base class explicitly; on Python 2 skipping this
        # leaves ``args`` empty and ``str(exc)`` blank.
        Exception.__init__(self, msg)
        self.msg = msg


def is_comment(s):
    """Return True when the line *s* begins with the '#' comment marker."""
    return s[:1] == '#'


def _new_transcript_record(chro, strand):
    """Return a transcript record with every per-feature list initialised."""
    return {'chr': chro,
            'strand': strand,
            'eid': [],
            'exon': [],
            'exon_number': [],  # Do we need this?
            'UTR': [],
            'Selenocysteine': [],
            'start_codon': [],
            'stop_codon': [],
            'start_codon_frame': [],
            'stop_codon_frame': []}


def parse_gtf(gtffile):
    """Parse a GTF annotation file into gene and transcript lookup tables.

    Comment lines (starting with '#') and blank lines are ignored wherever
    they occur in the file.

    Args:
        gtffile: Path to a tab-delimited GTF file whose ninth column holds
            ``key "value";`` attribute pairs.

    Returns:
        A tuple ``(gdb, tdb)``:

        * ``gdb`` maps gene id to a dict with 'chr', 'isoform' (set of
          transcript ids) and, when a 'gene' line exists, 'strand', 'start',
          'end' plus the remaining attributes of that line.
        * ``tdb`` maps transcript id to a dict with 'chr', 'strand', 'exon'
          (list of ``(start, end)`` tuples in file order) and the other
          per-feature lists; 'start', 'end' and the line's attributes are
          added when a 'transcript' line exists.

        An attribute that occurs once is stored as a string; one that
        occurs several times on a line is stored as a list.

    Raises:
        IOError/OSError: if the file cannot be opened.
        IndexError, KeyError, ValueError: on malformed records (fewer than
            nine columns, missing gene_id/transcript_id, non-integer
            coordinates or codon frame).
    """
    gdb = {}
    tdb = {}
    with open(gtffile) as fh:
        for curline in fh:
            # Comments and blank lines may appear anywhere, not only in the header.
            if curline.startswith('#') or not curline.strip():
                continue
            item = curline.rstrip().split("\t")
            attribute = defaultdict(list)
            for field in item[8].split("; "):
                # Split on the first whitespace only so that values containing
                # spaces (e.g. descriptive gene names) stay intact.
                pair = field.replace('"', '').replace(';', '').strip().split(None, 1)
                if len(pair) != 2:
                    continue  # Empty or value-less attribute
                attribute[pair[0]].append(pair[1])
            attribute = dict(attribute)
            feature = item[2]
            start = int(item[3])
            end = int(item[4])
            chro = item[0]  # Leave the chromosome names as is in gtf file
            strand = item[6]
            if feature == 'gene':  # Seqnature currently has some issue processing this entry
                gid = attribute.pop('gene_id')[0]
                if gid in gdb:
                    sys.stdout.write("[Error] Duplicate entry: %s\n" % gid)
                else:
                    record = {'chr': chro,
                              'strand': strand,
                              'start': start,
                              'end': end,
                              'isoform': set()}
                    for key, values in attribute.items():
                        record[key] = values[0] if len(values) == 1 else values
                    gdb[gid] = record
            elif feature == 'transcript':
                gid = attribute['gene_id'][0]
                tid = attribute.pop('transcript_id')[0]
                if tid in tdb:
                    sys.stdout.write("[Error] Duplicate entry: %s\n" % tid)
                else:
                    record = _new_transcript_record(chro, strand)
                    record['start'] = start
                    record['end'] = end
                    for key, values in attribute.items():
                        record[key] = values[0] if len(values) == 1 else values
                    tdb[tid] = record
                    if gid in gdb:
                        gdb[gid]['isoform'].add(tid)
                    else:
                        # Non-standard case: the input gtf has no gene-level
                        # annotation. Note set([tid]), not set(tid), which
                        # would split the id into single characters.
                        gdb[gid] = {'chr': chro, 'isoform': set([tid])}
            else:
                tid = attribute['transcript_id'][0]
                if feature in ('exon', 'UTR', 'Selenocysteine', 'start_codon', 'stop_codon'):
                    record = tdb.get(tid)
                    if record is None:
                        # Non-standard case: the input gtf has no
                        # transcript-level annotation for this id.
                        record = tdb[tid] = _new_transcript_record(chro, strand)
                    if feature == 'exon':
                        if 'exon_number' in attribute:
                            try:
                                record['exon_number'].append(int(attribute['exon_number'][0]))
                            except ValueError:
                                pass  # exon_number is optional; ignore non-numeric values
                        if 'exon_id' in attribute:
                            record['eid'].append(attribute['exon_id'][0])
                        record['exon'].append((start, end))
                    elif feature in ('UTR', 'Selenocysteine'):
                        record[feature].append((start, end))
                    else:  # start_codon / stop_codon
                        record[feature].append((start, end))
                        record[feature + '_frame'].append(int(item[7]))
                elif feature == 'CDS':
                    pass
                else:
                    # Mirrors the former print statement: curline keeps its own
                    # newline and one more is appended.
                    sys.stderr.write("[Unknown feature: %s]\n%s\n" % (feature, curline))

    return (gdb, tdb)


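#
# Extract the sequence of a genomic interval (1-based, inclusive coordinates),
# reverse-complemented when the interval lies on the '-' strand
#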
def get_fragment(start, end, chro, strand, genome):
    """Return the sequence of ``genome[chro]`` between *start* and *end*.

    *start* is converted from 1-based to 0-based for slicing, so both ends
    are included. For strand '-' the slice's ``reverse_complement()`` is
    returned; any other strand value yields the slice unchanged.
    """
    region = genome[chro].seq[start - 1:end]
    return region.reverse_complement() if strand == '-' else region


def main(argv=None):
    """Build the (pooled) EMASE transcriptome and, optionally, its bowtie index.

    For every genome given with -G, the matching GTF is parsed, the exons of
    each transcript are concatenated into one sequence, and the result is
    written to a FASTA file together with a ``<transcript id>\\t<length>``
    info file. With more than one genome the transcript ids get a
    ``_<haplotype>`` suffix and the output is named ``emase.pooled.*``.

    Args:
        argv: Full argument vector including the program name; defaults to
            ``sys.argv``.

    Returns:
        None on success (exit status 0 via ``sys.exit``); 2 for usage errors,
        a missing genome list or mismatching list lengths; the non-zero
        status of ``bowtie-build`` if indexing fails, or 1 if it cannot be
        started at all.
    """
    if argv is None:
        argv = sys.argv

    def log(message):
        # All progress and diagnostic output goes to stderr.
        sys.stderr.write(message + "\n")

    try:
        try:
            opts, _ = getopt.getopt(argv[1:], "hG:g:s:mxo:", ["help", "save-g2tmap", "no-bowtie-index"])
        except getopt.error as msg:
            raise Usage(msg)

        # Default values of vars
        genomelist = []
        gtflist = None
        haplist = None
        outdir = '.'
        save_g2tmap = False
        build_bowtie_index = True

        # option processing (change this later with optparse)
        for option, value in opts:
            if option in ("-h", "--help"):
                raise Usage(help_message)
            if option == "-G":
                genomelist = value.split(',')
            if option == "-g":
                gtflist = value.split(',')
            if option == "-s":
                haplist = value.split(',')
            if option in ("-m", "--save-g2tmap"):
                save_g2tmap = True
            if option in ("-x", "--no-bowtie-index"):
                build_bowtie_index = False
            if option == "-o":
                outdir = value
        num_haps = len(genomelist)

        # Check if the required options are given
        if num_haps < 1:
            log("No genome file is given.")
            return 2
        if haplist is None:
            # ascii_uppercase is locale independent and exists on Python 2 and 3.
            haplist = list(string.ascii_uppercase[:num_haps])
            if num_haps == 1:
                log("Assuming single genome analysis. No suffix will be added to ID's")
            else:
                log("Default haplotype names will be used: %s" % ', '.join(haplist))
        if gtflist is None:
            gtflist = [os.path.splitext(genomefile)[0] + '.gtf' for genomefile in genomelist]
            log("Assuming there exist the following GTF files:\n%s" % '\n'.join(gtflist))
        if len(haplist) != num_haps or len(gtflist) != num_haps:
            log("The number of gtf files or specified haplotypes is not matching to the number of genomes.")
            # Continuing would only fail later with an IndexError.
            return 2

        if not os.path.exists(outdir):
            os.makedirs(outdir)  # also creates missing parent directories

        #
        # Get pooled transcriptome
        #
        prefix = 'emase.pooled.transcripts' if num_haps > 1 else 'emase.transcripts'
        transcriptomefile = os.path.join(outdir, prefix + '.fa')
        lenfile = os.path.join(outdir, prefix + '.info')
        # Nested 'with' blocks close both outputs even if a genome fails to load.
        with open(transcriptomefile, 'w') as seqout:
            with open(lenfile, 'w') as lenout:
                for hapname, genomefile, gtffile in zip(haplist, genomelist, gtflist):
                    genomename = os.path.splitext(os.path.basename(genomefile))[0]
                    log("Loading %s genome..." % genomename)
                    with open(genomefile) as genome_fh:
                        genome = SeqIO.to_dict(SeqIO.parse(genome_fh, "fasta"))
                    log("Parsing %s..." % os.path.basename(gtffile))
                    gdb, tdb = parse_gtf(gtffile)
                    if num_haps == 1:
                        log("Building %s transcriptome (Note: No suffix added to ID's)..." % genomename)
                    else:
                        log("Building %s transcriptome using suffix '_%s'..." % (genomename, hapname))
                    for tid in sorted(tdb):
                        tinfo = tdb[tid]
                        # Filter out transcripts from chromosome that the input genome does not contain
                        if tinfo['chr'] not in genome:
                            log("Skipping Transcript %s of Chromosome %s..." % (tid, tinfo['chr']))
                            continue
                        # No need to add suffix if we deal with a single genome
                        out_id = tid + "_%s" % hapname if num_haps > 1 else tid
                        fragment = Seq("", generic_dna)
                        for exon in tinfo['exon']:
                            fragment += get_fragment(exon[0], exon[1], tinfo['chr'], tinfo['strand'], genome)
                        if len(fragment) > 0:
                            SeqIO.write(SeqRecord(fragment, out_id, '', ''), seqout, 'fasta')
                        # The length is recorded even for empty transcripts.
                        lenout.write("%s\t%d\n" % (out_id, len(fragment)))

        if save_g2tmap:
            # NOTE: gdb and genome here are those of the LAST genome processed
            # in the loop above; the ids are written without haplotype suffix.
            with open(os.path.join(outdir, 'emase.gene2transcripts.tsv'), 'w') as fhout:
                log("Recording mapping of gene id to transcript id's...")
                for gid in sorted(gdb):
                    if gdb[gid]['chr'] in genome:
                        # Sort the isoforms so the file is reproducible.
                        item = [gid] + sorted(gdb[gid]['isoform'])
                        fhout.write("\t".join(item) + "\n")

        #
        # Build bowtie index for the pooled transcriptome
        #
        if build_bowtie_index:
            out_index = os.path.join(os.path.dirname(transcriptomefile), 'bowtie.transcripts')
            log("Building bowtie index...")
            try:
                # Argument list instead of a shell string: paths containing
                # spaces or shell metacharacters are passed through verbatim.
                status = subprocess.call(["bowtie-build", transcriptomefile, out_index])
            except OSError as err:
                log("Could not run bowtie-build: %s" % err)
                return 1
            if status != 0:
                log("bowtie-build failed with exit status %d" % status)
                return status

    except Usage as err:
        log("%s: %s" % (sys.argv[0].split("/")[-1], err.msg))
        return 2


# Run as a script: main()'s return value becomes the process exit status
# (a return value of None makes sys.exit() report success).
if __name__ == "__main__":
    sys.exit(main())
