#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011, A. Murat Eren
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free
# Software Foundation; either version 2 of the License, or (at your option)
# any later version.
#
# Please read the COPYING file.


# This script takes an .ini file and analyzes completely overlapping V6 read pairs to extract
# perfect V6 reads (i.e., after the primers are trimmed, the reverse complement of the second
# read matches the first read perfectly).
#
# This is what the expected .ini file looks like:
#
#-------8<-------8<-------8<-------8<-------8<-------8<-------8<-------8<-------8<------------
#    [general]
#    project_name = (project name)
#    researcher_email = (your e-mail address)
#    input_directory = (directory from which the input files will be read from)
#    output_directory = (directory to store output)
#    
#    
#    [files]
#    pair_1 = (pair 1 files, comma separated)
#    pair_2 = (pair 2 files, comma separated. must be ordered based on pair 1 files)
#-------8<-------8<-------8<-------8<-------8<-------8<-------8<-------8<-------8<------------
#
# Any questions? Send an e-mail to meren@mbl.edu
#

import os
import re
import sys
import copy
import numpy
import ConfigParser


# Short module-level aliases for os.path.exists and os.path.join.
E = os.path.exists
J = os.path.join


import IlluminaUtils.lib.fastqlib as u
from IlluminaUtils.utils.runconfiguration import RunConfiguration
from IlluminaUtils.utils.helperfunctions import ReadIDTracker 
from IlluminaUtils.utils.helperfunctions import QualityScoresHandler
from IlluminaUtils.utils.helperfunctions import visualize_qual_stats_dict
from IlluminaUtils.utils.helperfunctions import store_cPickle_obj
from IlluminaUtils.utils.helperfunctions import load_cPickle_obj


class ConfigError(Exception):
    """Error type the command-line entry point catches while building the run configuration."""


def RepresentsInt(s):
    """Tell whether `s` can be converted with int().

    Returns True when int(s) succeeds and False otherwise. Values that int()
    rejects with TypeError (e.g. None, lists) or OverflowError (infinite
    floats) are reported as False as well, instead of propagating the
    exception to the caller.
    """
    try:
        int(s)
    except (TypeError, ValueError, OverflowError):
        return False
    return True


def get_rp_match(seq, reverse_primer, offset=70):
    """Return the right-most match of `reverse_primer` in `seq`, or None.

    seq            -- sequence string to scan.
    reverse_primer -- compiled regular expression for the primer.
    offset         -- position in `seq` at which scanning starts; matches that
                      begin before it are ignored (default: 70).

    The search is repeated one position past the start of every hit, so
    overlapping occurrences are considered too and the match object with the
    largest start position is the one returned.
    """
    rp_match = None

    while True:
        m = reverse_primer.search(seq, offset)

        if m is None:
            break

        rp_match = m
        # step just one position forward so that overlapping hits are not skipped
        offset = m.start() + 1

    return rp_match


def _get_percent(x, y):
    """Return `x` as a percentage of `y`; 0.0 when `y` is zero."""
    if y == 0:
        return 0.0
    return x * 100.0 / y


# complements for the only characters reverse complementing accepts
_COMPLEMENTS = {'A': 'T',
                'T': 'A',
                'C': 'G',
                'G': 'C',
                'N': 'N'}


def _reverse_complement(seq):
    """Return the reverse complement of `seq`; raises KeyError for characters other than A, T, C, G, N."""
    return ''.join([_COMPLEMENTS[n] for n in reversed(seq)])


def _is_gzipped(path):
    """Tell whether the file name of `path` ends with '.gz'."""
    return os.path.basename(path).endswith('.gz')


def _get_v6_primers(archaea):
    """Return the compiled (forward_primer, reverse_primer) patterns for archaeal or bacterial V6."""
    if archaea:
        # archaeal V6 primers.

        # full forward primer: 'AATTGGA.TCAACGCCGG'
        # last 12 nts of the forward primer:
        forward_primer = re.compile('A.TCAACGCCGG')

        # full reverse primer: 'CG[AG]C[AG]GCCATG[CT]ACC[AT]C'
        # 6 nts of reverse primer (the character classes carry no commas: a comma
        # inside [...] would be matched literally):
        reverse_primer = re.compile('G[AT]GGT[GA]')
    else:
        # bacterial V6 primers.

        # first 10 characters of 1046R-rc: 'AGGTG.TGCA'
        # 6 nts of reverse primer:
        reverse_primer = re.compile('AGGTG.')

        # forward primers:
        #
        # 967F-AQ    C    T A A    C    CG A    .    G    AACCT [CT] ACC
        # 967F-UC3   A    T A C    G    CG A    [AG] G    AACCT T    ACC
        # 967F-PP    C    . A C    G    CG A    A    G    AACCT T    A.C
        # 967F-UC12* C    A A C    G    CG [AC] A    [AG] AACCT T    ACC
        #  COMBINED: [CA] . A [AC] [CG] CG [AC] .    [AG] AACCT [CT] A.C
        #
        # full combined pattern: 'A[AC][CG]CG[AC].[AG]AACCT[CT]A.C'
        # last 10 nts of the forward primer:
        forward_primer = re.compile('[AG]AACCT[CT]A.C')

    return forward_primer, reverse_primer


def _recover_ambiguous_base(target, source):
    """Fill the single 'N' of `target` with the base `source` has at that position.

    Returns (sequence, recovered). `target` is returned untouched, with
    recovered = False, when it does not hold exactly one 'N' or when `source`
    has an 'N' at the very same position (there is nothing to recover from).
    Both sequences must have the same length.
    """
    if target.count('N') != 1:
        return target, False

    pos = target.find('N')
    if source[pos] == 'N':
        return target, False

    return target[:pos] + source[pos] + target[pos + 1:], True


def main(config, archaea = False, visualize_quality_curves = False, recover_ambiguous_bases = False):
    """Extract perfectly overlapping V6 read pairs from the FASTQ files listed in `config`.

    For every pair of reads, the second read is reverse complemented, the
    reverse primer (right-most hit) and the forward primer (within the first 40
    positions) are located and trimmed from both reads, and the remaining
    sequences are compared. Pairs with identical sequences are written to the
    FASTA output; every pair is tagged as 'PASSED', 'FAILED_RP', 'FAILED_FP' or
    'FAILED_MISMATCH' in the read ID tracker.

    config                   -- run configuration; `output_directory`, `project_name`,
                                `pair_1` and `pair_2` are read from it.
    archaea                  -- use archaeal instead of bacterial V6 primers.
    visualize_quality_curves -- also collect, store and visualize quality scores
                                for each group of pairs.
    recover_ambiguous_bases  -- when both trimmed reads have the same length, replace
                                a lone 'N' in one read with the base of the other read.

    Output files are created as <output_directory>/<project_name>-<suffix> with the
    suffixes 'STATS.txt', 'PERFECT_reads.fa', 'READ_IDs.cPickle.z' and, when quality
    curves are requested, 'Q_DICT.cPickle.z' plus one visualization per entry type.

    Returns None. The process is terminated with exit status 1 if a FASTQ source
    cannot be set up.
    """

    forward_primer, reverse_primer = _get_v6_primers(archaea)

    def get_file_path(suffix):
        return os.path.join(config.output_directory, config.project_name + '-' + suffix)

    id_tracker_output_path = get_file_path('READ_IDs.cPickle.z')

    if visualize_quality_curves:
        qual_dict_output_path = get_file_path('Q_DICT.cPickle.z')
        qual_dict = QualityScoresHandler()

    read_id_tracker = ReadIDTracker()

    def tag_pair(pair_1, pair_2, fate):
        # remember what happened to the current entries of both sources
        read_id_tracker.update(pair_1, pair_2, fate)
        if visualize_quality_curves:
            qual_dict.update(pair_1, pair_2, fate)

    #####################################################################################
    # counters for the final report
    #####################################################################################

    number_of_pairs = 0
    number_of_perfect_pairs = 0
    number_of_perfect_pairs_with_Ns = 0
    rp_failed_in_both = 0
    rp_failed_in_pair_1 = 0
    rp_failed_in_pair_2 = 0
    fp_failed_in_both = 0
    fp_failed_in_pair_1 = 0
    fp_failed_in_pair_2 = 0
    number_of_recovered_ambiguous_bases_in_pair_1 = 0
    number_of_recovered_ambiguous_bases_in_pair_2 = 0

    # the context manager closes both output files even if processing is aborted
    with open(get_file_path('STATS.txt'), 'w') as errors_fp, \
         open(get_file_path('PERFECT_reads.fa'), 'w') as perfect_reads_fasta_fp:

        #####################################################################################
        # main loop per file listed in config:
        #####################################################################################

        for index in range(len(config.pair_1)):
            path_1, path_2 = config.pair_1[index], config.pair_2[index]

            try:
                # compression is decided for each file from its own name
                pair_1 = u.FastQSource(path_1, compressed=_is_gzipped(path_1))
                pair_2 = u.FastQSource(path_2, compressed=_is_gzipped(path_2))
            except u.FastQLibError as e:
                sys.stderr.write("FastQLib is not happy.\n\n\t%s\n\n" % e)
                sys.exit(1)

            #####################################################################################
            # main loop per read:
            #####################################################################################

            # NOTE(review): the loop ends as soon as either source stops yielding, so
            # surplus reads in the longer file are presumably ignored -- confirm.
            while pair_1.next() and pair_2.next():
                number_of_pairs += 1
                h1 = pair_1.entry.header_line
                s1 = pair_1.entry.sequence
                s2 = _reverse_complement(pair_2.entry.sequence)

                if pair_1.p_available:
                    pair_1.print_percentage()

                #####################################################################################
                # looking for reverse primer:
                #####################################################################################

                rp1 = get_rp_match(s1, reverse_primer)
                rp2 = get_rp_match(s2, reverse_primer)

                if (not rp1) or (not rp2):
                    # reverse primer wasn't there for at least one read
                    if (not rp1) and (not rp2):
                        rp_failed_in_both += 1
                    elif not rp1:
                        rp_failed_in_pair_1 += 1
                    else:
                        rp_failed_in_pair_2 += 1

                    tag_pair(pair_1, pair_2, 'FAILED_RP')
                    continue

                # strip reverse primers (and everything that follows them)
                s1 = s1[:rp1.start()]
                s2 = s2[:rp2.start()]

                #####################################################################################
                # looking for forward primer
                #####################################################################################

                fp1 = forward_primer.search(s1, 0, 40)
                fp2 = forward_primer.search(s2, 0, 40)

                if (not fp1) or (not fp2):
                    # forward primer wasn't there for at least one read
                    if (not fp1) and (not fp2):
                        fp_failed_in_both += 1
                    elif not fp1:
                        fp_failed_in_pair_1 += 1
                    else:
                        fp_failed_in_pair_2 += 1

                    tag_pair(pair_1, pair_2, 'FAILED_FP')
                    continue

                # strip forward primers (and everything that precedes them)
                s1 = s1[fp1.end():]
                s2 = s2[fp2.end():]

                #####################################################################################
                # compare reads
                #####################################################################################

                if recover_ambiguous_bases and len(s1) == len(s2):
                    s1, recovered = _recover_ambiguous_base(s1, s2)
                    if recovered:
                        number_of_recovered_ambiguous_bases_in_pair_1 += 1

                    s2, recovered = _recover_ambiguous_base(s2, s1)
                    if recovered:
                        number_of_recovered_ambiguous_bases_in_pair_2 += 1

                if s1 == s2:
                    perfect_reads_fasta_fp.write('>%s\n' % h1)
                    perfect_reads_fasta_fp.write('%s\n' % s1)
                    number_of_perfect_pairs += 1

                    tag_pair(pair_1, pair_2, 'PASSED')

                    if 'N' in s1:
                        number_of_perfect_pairs_with_Ns += 1
                else:
                    tag_pair(pair_1, pair_2, 'FAILED_MISMATCH')

        #####################################################################################
        # report
        #####################################################################################

        total_pairs_failed = number_of_pairs - number_of_perfect_pairs
        errors_fp.write('number of pairs                 : %d\n' % (number_of_pairs))
        errors_fp.write('total pairs passed              : %d (%%%.2f of all pairs)\n' % (number_of_perfect_pairs, _get_percent(number_of_perfect_pairs, number_of_pairs)))
        errors_fp.write('  perfect pairs with Ns         : %d (%%%.2f of perfect pairs)\n' % (number_of_perfect_pairs_with_Ns, _get_percent(number_of_perfect_pairs_with_Ns, number_of_perfect_pairs)))
        errors_fp.write('  recovered ambiguous bases (p1): %d (%%%.2f of perfect pairs)\n' % (number_of_recovered_ambiguous_bases_in_pair_1, _get_percent(number_of_recovered_ambiguous_bases_in_pair_1, number_of_perfect_pairs)))
        errors_fp.write('  recovered ambiguous bases (p2): %d (%%%.2f of perfect pairs)\n' % (number_of_recovered_ambiguous_bases_in_pair_2, _get_percent(number_of_recovered_ambiguous_bases_in_pair_2, number_of_perfect_pairs)))
        errors_fp.write('total pairs failed              : %d (%%%.2f of all pairs)\n' % (total_pairs_failed, _get_percent(total_pairs_failed, number_of_pairs)))
        errors_fp.write('  FP failed in both pairs       : %d (%%%.2f of all failed pairs)\n' % (fp_failed_in_both, _get_percent(fp_failed_in_both, total_pairs_failed)))
        errors_fp.write('  FP failed only in pair 1      : %d (%%%.2f of all failed pairs)\n' % (fp_failed_in_pair_1, _get_percent(fp_failed_in_pair_1, total_pairs_failed)))
        errors_fp.write('  FP failed only in pair 2      : %d (%%%.2f of all failed pairs)\n' % (fp_failed_in_pair_2, _get_percent(fp_failed_in_pair_2, total_pairs_failed)))
        errors_fp.write('  RP failed in both pairs       : %d (%%%.2f of all failed pairs)\n' % (rp_failed_in_both, _get_percent(rp_failed_in_both, total_pairs_failed)))
        errors_fp.write('  RP failed only in pair 1      : %d (%%%.2f of all failed pairs)\n' % (rp_failed_in_pair_1, _get_percent(rp_failed_in_pair_1, total_pairs_failed)))
        errors_fp.write('  RP failed only in pair 2      : %d (%%%.2f of all failed pairs)\n' % (rp_failed_in_pair_2, _get_percent(rp_failed_in_pair_2, total_pairs_failed)))
        for fate in [f for f in read_id_tracker.fates if f.startswith('FAILED')]:
            errors_fp.write('  %s%s: %d (%%%.2f of all failed pairs)\n' % (fate,
                                                                           (' ' * (30 - len(fate))),
                                                                           len(read_id_tracker.ids[fate]),
                                                                           _get_percent(len(read_id_tracker.ids[fate]), total_pairs_failed)))

    # blank line on stdout before the status messages below
    sys.stdout.write('\n')

    sys.stderr.write('\rRead ID tracker dict is being stored ...')
    sys.stderr.flush()
    read_id_tracker.store(id_tracker_output_path)
    sys.stderr.write('\n')

    if visualize_quality_curves:
        sys.stderr.write('\rQuality scores dict is being stored ...')
        sys.stderr.flush()
        qual_dict.store_dict(qual_dict_output_path)

        for entry_type in qual_dict.entry_types:
            sys.stderr.write('\rQuality scores visualization in progress: %s%s' % (entry_type, ' ' * 20))
            sys.stderr.flush()
            visualize_qual_stats_dict(qual_dict.data[entry_type], get_file_path(entry_type),
                                      title='Mean PHRED scores for pairs tagged as "%s"' % entry_type)
        sys.stderr.write('\n')

if __name__ == '__main__':
    # Command-line entry point: parse the arguments, load the .ini file, build
    # the run configuration and hand everything over to main().
    import argparse

    parser = argparse.ArgumentParser(description='Dealing with V6 perfect overlaps')
    parser.add_argument('user_config', metavar='CONFIG_FILE',
                        help='User configuration to run. See the source code to '
                             'see an example.')
    parser.add_argument('--archaea', action='store_true', default=False,
                        help='When set, primers for archaea are used instead of '
                             'bacteria.')
    parser.add_argument('--visualize-quality-curves', action='store_true', default=False,
                        help='When set, mean quality score for individual bases will be '
                             'stored and visualized for each group of reads.')
    parser.add_argument('--recover-ambiguous-bases', action='store_true', default=False,
                        help='When set, the script would attempt to recover an ambiguous base '
                             'from the other pair (only if there is only one N in the read). The idea '
                             'here is to save more reads if there is a failed cycle during '
                             'the sequencing, hoping that reads with low quality would present '
                             'more than one ambiguous base, or would have mismatches at other positions even '
                             'after the ambiguous base is recovered. Nevertheless, when this argument is set, '
                             'the accuracy of the quality filtered reads would be expected to decrease. This '
                             'parameter does not recover ambiguous bases if they happen to occur in primer '
                             'regions.')

    args = parser.parse_args()
    user_config = ConfigParser.ConfigParser()

    # read() skips files it cannot open and returns the list of files it parsed,
    # so an empty result means the config file was not read at all
    if not user_config.read(args.user_config):
        sys.stderr.write("Config file could not be read: %s\n" % args.user_config)
        sys.exit(1)

    try:
        config = RunConfiguration(user_config)
    # NOTE(review): ConfigError here is the class defined in this file, while
    # RunConfiguration is imported from IlluminaUtils; if RunConfiguration raises
    # a ConfigError class of its own, this handler never fires -- confirm.
    except ConfigError as e:
        sys.stderr.write("There is something wrong with the config file. This is what we know: \n\n%s\n\n" % e)
        sys.exit(1)

    sys.exit(main(config,
                  archaea=args.archaea,
                  visualize_quality_curves=args.visualize_quality_curves,
                  recover_ambiguous_bases=args.recover_ambiguous_bases))
