#!/usr/bin/env python

"""
Copyright (c) 2015 Kwangbom Choi, The Jackson Laboratory
This software was developed by Kwangbom "KB" Choi in Gary Churchill's Lab.
This is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This software is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this software. If not, see <http://www.gnu.org/licenses/>.
"""

import os
import sys
import pysam
import numpy as np
from scipy.sparse import lil_matrix
from numpy.random import multinomial, randint, poisson, choice
import getopt


# Text printed (to stderr) for -h/--help.  The long option names listed here
# must match the ones registered with getopt in main().
help_message = '''
Usage:
    simulate-reads -R <reference_dir> -H <hybrid_dir> -p <param_file> \\
                   -m <simulation_model> -N <total_count> -X <num_sims> -o <out_base> \\
                   [ -r <read_len> -e <error_rate> -L ]

Input:
    -R, --ref-dir <reference_dir>     : Directory where EMASE reference data resides
    -H, --hybrid-dir <hybrid_dir>     : Directory where EMASE hybrid data resides
    -p, --parameter-file <param_file> : Expression abundance we want to simulate (Usually a result from EMASE)
    -m <simulation_model>             : Read sampling model, one of 1, 2, 3 or 4 (default: 1)
    -N, --total-count <total_count>   : Total read count we want to simulate (default: 10,000,000)
    -X, --num-sims <num_sims>         : The number of simulations (default: 1)
    -r <read_len>                     : The length of the simulated reads (default: 100)
    -e, --error-rate <error_rate>     : Mutation rate given by Poisson model (default: 0.1)
    -o, --outfile-basename <out_base> : The base name for output files (default: './emase.simulated')

Parameters:
    -h, --help                        : shows this help message
    -L, --long-read-name              : outputs more informative read name (origin & mutation)
'''


class Usage(Exception):
    """Raised for command-line problems; ``msg`` holds the text to show the user."""

    def __init__(self, msg):
        super(Usage, self).__init__(msg)
        self.msg = msg


def simulate_total_counts(acount, N):
    """Distribute ``N`` reads over the rows of ``acount`` with one multinomial draw.

    Each row's probability is its row sum divided by the sum of the whole
    matrix.

    Args:
        acount: 2-D array-like of non-negative abundances (rows x columns).
        N: Total number of reads to distribute.

    Returns:
        1-D integer array with one simulated count per row; it sums to ``N``.

    Raises:
        ValueError: If the abundances do not sum to a positive number, since
            no probability vector can be formed in that case.
    """
    # Work in floating point so that integer input is not floor-divided
    # (which would silently produce all-zero probabilities on Python 2).
    rowsum = np.asarray(acount, dtype=float).sum(axis=1)
    grand_total = rowsum.sum()
    if not grand_total > 0:  # also rejects NaN
        raise ValueError('Cannot simulate total counts: the abundances do not sum to a positive value.')
    theta = rowsum / grand_total
    return multinomial(N, theta, size=1)[0]


def simulate_allele_counts(acount, tcount_sim):
    """Split each row's simulated total across its columns (haplotypes).

    For every row ``k`` the total ``tcount_sim[k]`` is distributed with a
    multinomial whose probabilities are that row of ``acount`` divided by its
    row sum.  Rows whose abundances sum to zero use equal probabilities.

    Args:
        acount: 2-D array-like of non-negative abundances (rows x haplotypes).
        tcount_sim: Sequence with one simulated total per row of ``acount``.

    Returns:
        Integer array with the same shape as ``acount``; row ``k`` sums to
        ``tcount_sim[k]``.
    """
    # Floating point avoids floor division when integer abundances are given.
    acount = np.asarray(acount, dtype=float)
    num_rows, num_haps = acount.shape
    rowsum = acount.sum(axis=1)
    expressed = rowsum > 0
    # Start from the uniform split and overwrite the rows that carry signal.
    phi = np.full(acount.shape, 1.0 / num_haps if num_haps else 0.0)
    phi[expressed, :] = acount[expressed, :] / rowsum[expressed, np.newaxis]
    acount_sim = np.zeros(acount.shape, dtype=int)
    for k in range(num_rows):
        acount_sim[k] = multinomial(tcount_sim[k], phi[k])
    return acount_sim


def simulate_isoform_counts(tcount, gcount_sim, groups):
    """Split each group's simulated total across the transcripts of that group.

    For group ``g`` the total ``gcount_sim[g]`` is distributed over the
    transcript indices ``groups[g]`` with a multinomial whose probabilities
    are the corresponding entries of ``tcount`` divided by their sum.  Groups
    whose abundances sum to zero use equal probabilities.

    Args:
        tcount: 1-D array-like of non-negative transcript abundances.
        gcount_sim: Sequence with one simulated total per group.
        groups: Sequence of lists of transcript indices, one list per group.

    Returns:
        1-D integer array of simulated counts, one per transcript.
    """
    # Floating point copy-free view for float input; makes integer input safe.
    tcount = np.asarray(tcount, dtype=float)
    tcount_sim = np.zeros(len(tcount), dtype=int)
    for g, tindex in enumerate(groups):
        num_isoforms = len(tindex)
        if num_isoforms == 0:
            # Nothing to distribute to; an empty probability vector would
            # make the multinomial draw fail.
            continue
        weights = tcount[tindex]
        weight_sum = weights.sum()
        if weight_sum > 0:
            delta = weights / weight_sum
        else:
            delta = np.ones(num_isoforms) / num_isoforms
        tcount_sim[tindex] = multinomial(gcount_sim[g], delta)
    return tcount_sim


def simulate_read_counts(total_simreads, model, param_count, groups, t2gmat):
    """Simulate a transcript-by-haplotype read count table.

    The sampling hierarchy depends on ``model``:

    * 1 -- gene totals, then alleles within a gene, then isoforms within
      each gene/allele.
    * 2 -- gene totals, then isoforms within a gene, then alleles within
      each isoform.
    * 3 -- gene totals, then one draw over all (isoform, allele) pairs of
      the gene.
    * any other value (intended: 4) -- a single draw over every
      (transcript, allele) pair.

    Args:
        total_simreads: Total number of reads to distribute.
        model: Integer selecting the sampling hierarchy (see above).
        param_count: Array of abundances, transcripts x haplotypes.
        groups: List of transcript-index lists, one per gene (used by
            models 1 and 2).
        t2gmat: Transcript x gene membership matrix; only its shape is used
            by model 4.

    Returns:
        Integer array shaped like ``param_count``, or the integer ``2`` after
        writing a message to stderr when ``param_count`` and ``t2gmat``
        disagree on the number of transcripts.
    """
    num_transcripts, num_haps = param_count.shape
    num_transcripts_chk, num_genes = t2gmat.shape
    if num_transcripts != num_transcripts_chk:
        sys.stderr.write(sys.argv[0].split("/")[-1] + ": " + '[Error] Dimension mismatch.' + "\n")
        return 2

    if model in (1, 2, 3):
        # All hierarchical models start from the same gene-level draw.
        gacount = t2gmat.transpose().dot(param_count)
        gcount_sim = simulate_total_counts(gacount, total_simreads)

    if model == 1:  # Gene->Alleles->Isoforms
        gacount_sim = simulate_allele_counts(gacount, gcount_sim)
        tacount_sim = np.zeros(param_count.shape, dtype=int)
        for h in range(num_haps):
            tacount_sim[:, h] = simulate_isoform_counts(param_count[:, h], gacount_sim[:, h], groups)
    elif model == 2:  # Gene->Isoforms->Alleles
        tcount_sim = simulate_isoform_counts(param_count.sum(axis=1), gcount_sim, groups)
        tacount_sim = simulate_allele_counts(param_count, tcount_sim)
    elif model == 3:  # Gene->(Alleles*Isoforms)
        tacount_sim = np.zeros(param_count.shape, dtype=int)
        for g in range(num_genes):
            tindex = t2gmat[:, g].nonzero()[0]
            num_isoforms = len(tindex)
            if num_isoforms == 0:
                continue  # a gene without transcripts has nothing to receive
            num_isoform_alleles = num_isoforms * num_haps
            weights = param_count[tindex]
            weight_sum = weights.sum()
            if weight_sum > 0:
                xi = weights / float(weight_sum)
            else:
                xi = np.ones((num_isoforms, num_haps)) / num_isoform_alleles
            draw = multinomial(gcount_sim[g], xi.reshape(num_isoform_alleles))
            tacount_sim[tindex] = draw.reshape((num_isoforms, num_haps))
    else:  # elif model == 4:
        # float() keeps integer abundances from being floor-divided.
        zeta = param_count / float(param_count.sum())
        draw = multinomial(total_simreads, zeta.reshape(num_transcripts * num_haps))
        tacount_sim = draw.reshape((num_transcripts, num_haps))
    return tacount_sim


def introduce_mutation(seq, pos, new_base):
    """Return ``seq`` with the character at index ``pos`` replaced by ``new_base``."""
    prefix = seq[:pos]
    suffix = seq[pos + 1:]
    return prefix + new_base + suffix


def main(argv=None):
    """Command-line entry point: simulate read counts and write simulated reads.

    Reads the EMASE reference/hybrid files and a parameter (abundance) file,
    then for each requested simulation writes a gene-level count table, a
    transcript-level count table and a FASTA file of simulated reads.

    Args:
        argv: Argument vector including the program name; defaults to
            ``sys.argv``.

    Returns:
        ``2`` after printing a message to stderr when the command line or the
        input files are unusable (also for ``-h``); ``None`` on success.
    """
    # repeat() gives a lazy fixed-count loop on both Python 2 and Python 3.
    from itertools import repeat

    if argv is None:
        argv = sys.argv
    prog = sys.argv[0].split("/")[-1]

    def fail(message):
        # Report an error on stderr and hand back the script's failure status.
        sys.stderr.write('%s: %s\n' % (prog, message))
        return 2

    try:
        try:
            # Long options that take a value need the trailing '=';
            # '--hyb-dir' is accepted as an alias of '--hybrid-dir'.
            opts, _unused = getopt.getopt(argv[1:], "hR:H:p:m:N:X:r:e:o:L",
                                          ["help", "ref-dir=", "hybrid-dir=", "hyb-dir=",
                                           "parameter-file=", "total-count=", "num-sims=",
                                           "error-rate=", "outfile-basename=", "long-read-name"])
        except getopt.error as msg:
            raise Usage(msg)

        # Default values of vars
        ref_dir = os.getenv('EMASE_TARGET_BASE', '.')
        hyb_dir = '.'
        param_file = None
        outbase = './emase.simulated'
        model = 1
        total_simreads = 10000000
        num_simulations = 1
        read_len = 100
        error_rate = 0.1
        use_long_id = False
        # For every base, the three bases it may be mutated into.
        mut = {'A': ('T', 'G', 'C'),
               'T': ('G', 'C', 'A'),
               'G': ('C', 'A', 'T'),
               'C': ('A', 'T', 'G')}

        # option processing (change this later with optparse)
        try:
            for option, value in opts:
                if option in ("-h", "--help"):
                    raise Usage(help_message)
                if option in ("-R", "--ref-dir"):
                    ref_dir = value
                if option in ("-H", "--hybrid-dir", "--hyb-dir"):
                    hyb_dir = value
                if option in ("-p", "--parameter-file"):
                    param_file = value
                if option == "-m":
                    model = int(value)
                if option in ("-N", "--total-count"):
                    total_simreads = int(value)
                if option in ("-X", "--num-sims"):
                    num_simulations = int(value)
                if option == "-r":
                    read_len = int(value)
                if option in ("-e", "--error-rate"):
                    error_rate = float(value)
                if option in ("-o", "--outfile-basename"):
                    outbase = value
                if option in ("-L", "--long-read-name"):
                    use_long_id = True
        except ValueError:
            raise Usage('Invalid value %r for option %s.' % (value, option))

        # Check if the required options are all given
        transcript_info_file = os.path.join(ref_dir, 'emase.transcripts.info')
        grouping_file = os.path.join(ref_dir, 'emase.gene2transcripts.tsv')
        target_file = os.path.join(hyb_dir, 'emase.pooled.transcripts.fa')
        target_info_file = os.path.join(hyb_dir, 'emase.pooled.transcripts.info')
        required_files = ((transcript_info_file, ref_dir),
                          (grouping_file, ref_dir),
                          (target_file, hyb_dir),
                          (target_info_file, hyb_dir))
        for path, where in required_files:
            if not os.path.exists(path):
                return fail('%s does not exist at %s.' % (path, where))
        if param_file is None:
            return fail('A parameter file must be given with -p/--parameter-file.')
        if not os.path.exists(param_file):
            return fail('%s does not exist.' % param_file)
        if model not in (1, 2, 3, 4):
            return fail('Simulation model should be either 1, 2, 3 or 4.')
        if read_len < 1:
            return fail('Read length should be a positive integer.')

        #
        # Main body
        #

        # Load parameters and other required information
        # atleast_1d: loadtxt returns a 0-d array when the file has one line.
        tname = np.atleast_1d(np.loadtxt(transcript_info_file, dtype=str, usecols=(0,)))
        num_transcripts = len(tname)
        tid = dict(zip(tname, range(num_transcripts)))

        gname = list()
        groups = list()
        with open(grouping_file) as fh:
            for curline in fh:
                item = curline.rstrip().split("\t")
                gname.append(item[0])
                groups.append([tid[t] for t in item[1:]])
        gname = np.array(gname)
        num_genes = len(gname)
        # Transcript x gene membership matrix.
        grp_conv_mat = lil_matrix((num_transcripts, num_genes))
        for i, members in enumerate(groups):
            if members:
                grp_conv_mat[members, i] = 1.0
        grp_conv_mat = grp_conv_mat.tocsc()

        with open(param_file) as fh:
            # Header: first column is the ID, the last column is not a haplotype.
            header = next(fh).rstrip().split('\t')
            hname = header[1:-1]
            num_haps = len(hname)
            hid = dict(zip(hname, range(num_haps)))
            tacount = np.zeros((num_transcripts, num_haps))
            for curline in fh:
                item = curline.rstrip().split('\t')
                tacount[tid[item[0]], :] = [float(x) for x in item[1:(1 + num_haps)]]

        # trange[t, h] = number of valid read start positions on that target.
        trange = np.zeros(tacount.shape, dtype=int)
        with open(target_info_file) as fh:
            for curline in fh:
                item = curline.rstrip().split('\t')
                # Split on the last underscore so IDs containing '_' still parse.
                t, h = item[0].rsplit('_', 1)
                trange[tid[t], hid[h]] = int(item[1])
        trange = trange - read_len + 1

        f = pysam.FastaFile(target_file)
        try:
            for k in range(num_simulations):
                # Simulate read counts
                tacount_sim = simulate_read_counts(total_simreads, model, tacount, groups, grp_conv_mat)
                if not isinstance(tacount_sim, np.ndarray):
                    # simulate_read_counts already reported the problem.
                    return tacount_sim
                gacount_sim = grp_conv_mat.transpose() * tacount_sim

                # Write the true read counts and simulated reads
                if num_simulations > 1:
                    true_gaout = outbase + '.%03d.genes.read_counts' % (k+1)
                    true_taout = outbase + '.%03d.transcripts.read_counts' % (k+1)
                    outfile = outbase + '.%03d.fa' % (k+1)
                else:
                    true_gaout = outbase + '.genes.read_counts'
                    true_taout = outbase + '.transcripts.read_counts'
                    outfile = outbase + '.fa'
                np.savetxt(true_gaout, np.column_stack((gname, np.char.mod('%d', gacount_sim))),
                           header='Gene_ID\t' + "\t".join(hname), fmt='%s', delimiter='\t')
                np.savetxt(true_taout, np.column_stack((tname, np.char.mod('%d', tacount_sim))),
                           header='Transcript_ID\t' + "\t".join(hname), fmt='%s', delimiter='\t')
                with open(outfile, 'w') as fhout:
                    num_digits = len(str(total_simreads))
                    runum = 1
                    for t in range(num_transcripts):
                        for h in range(num_haps):
                            trange_max = trange[t, h]
                            if trange_max <= 0:
                                # We cannot generate reads from a transcript shorter than read_len.
                                # NOTE: the count tables written above still include these counts.
                                continue
                            theader = '%s_%s' % (tname[t], hname[h])
                            for _ in repeat(None, int(tacount_sim[t, h])):
                                ruid = 'r{unum:{fill}{width}}'.format(unum=runum, fill='0', width=num_digits)
                                start = randint(0, trange_max)
                                end = start + read_len
                                seq = f.fetch(theader, start, end)
                                # Sampling without replacement cannot pick more
                                # positions than the read has.
                                nerr = min(poisson(error_rate), read_len)
                                errstr = ''
                                if nerr > 0:
                                    errstr = ':'
                                    for epos in sorted(choice(read_len, nerr, replace=False)):
                                        from_base = seq[epos]
                                        alternatives = mut.get(from_base.upper())
                                        if alternatives is None:
                                            # e.g. 'N': no defined substitution, leave as is.
                                            continue
                                        to_base = alternatives[randint(3)]
                                        seq = introduce_mutation(seq, epos, to_base)
                                        errstr += '[%d]%s->%s;' % (epos+1, from_base, to_base)
                                if use_long_id:
                                    fhout.write(">%s:%s:%d-%d%s\n%s\n" % (ruid, theader, start+1, end, errstr, seq))
                                else:
                                    fhout.write(">%s\n%s\n" % (ruid, seq))
                                runum += 1
        finally:
            f.close()

        #
        # End of Main body
        #

    except Usage as err:
        return fail(str(err.msg))


# Run the simulator only when executed as a script; the value returned by
# main() is passed to sys.exit() and becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
