#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2010,2011 Marcel Martin <marcel.martin@tu-dortmund.de>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

"""%prog [options] <FASTA/FASTQ FILE> [<QUALITY FILE>]

Reads a FASTA or FASTQ file, finds and removes adapters,
and writes the changed sequence to standard output.
When finished, statistics are printed to standard error.

If two file names are given, the first must be a .fasta or .csfasta
file and the second must be a .qual file. This is the file format
used by some 454 software and by the SOLiD sequencer.
If you have color space data, you still need to provide the -c option
to correctly deal with color space!

If the name of any input or output file ends with '.gz', it is
assumed to be gzip-compressed.

If you want to search for the reverse complement of an adapter, you must
provide an additional adapter sequence using two '-a' parameters.

If the input sequences are in color space, the adapter must
also be provided in color space (using a string of digits 0123).

EXAMPLE

Assuming your sequencing data is available as a FASTQ file, use this
command line:
$ cutadapt -e ERROR-RATE -a ADAPTER-SEQUENCE input.fastq > output.fastq

See the README file for more help and examples."""

from __future__ import print_function, division

import sys
import re
import gzip
import time
from os.path import dirname, join, isfile, realpath
from string import maketrans
from optparse import OptionParser, OptionGroup
from contextlib import closing
from collections import defaultdict

# Support running this script straight from a source checkout:
# if a 'lib/cutadapt/__init__.py' exists next to the (symlink-resolved)
# location of this file, put that 'lib' directory first on sys.path so that
# the 'cutadapt' imports that follow resolve to the in-tree package.
_libdir = join(dirname(realpath(__file__)), 'lib')
if isfile(join(_libdir, 'cutadapt', '__init__.py')):
	sys.path.insert(0, _libdir)

from cutadapt import align, seqio
from cutadapt.xopen import xopen
from cutadapt.qualtrim import quality_trim_index
from cutadapt import __version__

# Constants for the find_best_alignment function. Each adapter is stored as a
# (where, adapter) tuple and 'where' is handed to align.globalalign() as-is.
#
# BACK is used for adapters given with -a/--adapter. It is the bitwise OR of
# three flags from the align module.
# NOTE(review): the exact meaning of each flag is defined in cutadapt.align,
# which is not visible here -- confirm there before changing the combination.
BACK = align.START_WITHIN_SEQ2 | align.STOP_WITHIN_SEQ2 | align.STOP_WITHIN_SEQ1
# ANYWHERE is used for adapters given with -b/--anywhere.
ANYWHERE = align.SEMIGLOBAL

# Translation table for double-encoding colorspace sequences:
# the characters '0', '1', '2', '3', '.' become 'A', 'C', 'G', 'T', 'N'.
DOUBLE_ENCODE_TRANS = maketrans('0123.', 'ACGTN')


class FormatError(Exception):
	"""Signals that an input file could not be parsed."""


class HelpfulOptionParser(OptionParser):
	"""OptionParser variant whose error() shows the complete help text."""

	def error(self, msg):
		"""
		Write the full help to standard error, followed by a
		'<prog>: error: <msg>' line, then exit with status 2.
		"""
		self.print_help(sys.stderr)
		message = "\n%s: error: %s\n" % (self.get_prog_name(), msg)
		self.exit(2, message)


def print_histogram(d):
	"""
	Print a two-column, tab-separated table to standard output.

	d -- dictionary mapping a value (here: a length) to its frequency

	A "length<TAB>count" header comes first, then one row per key in
	ascending key order, then an empty line.
	"""
	rows = [("length", "count")]
	rows.extend((value, d[value]) for value in sorted(d))
	for left, right in rows:
		print(left, right, sep="\t")
	print()


class Statistics(object):
	"""
	Store statistics about reads and adapters.

	Counters (n, reads_changed, too_short, too_long) and the per-adapter
	length histograms (lengths_front, lengths_back) are updated by other
	code; this class only initializes them and prints the final report.
	"""

	def __init__(self, adapters):
		"""
		adapters -- list of (where, adapter) tuples; one pair of histograms
			(front and back) is created per adapter, in the same order.

		The timer used for the report is started here.
		"""
		self.reads_changed = 0
		self.too_short = 0
		self.too_long = 0
		self.n = 0
		self._start_time = time.clock()
		self.time = None
		self.lengths_front = []
		self.lengths_back = []
		self.adapters = adapters
		for a in adapters:
			self.lengths_front.append(defaultdict(int))
			self.lengths_back.append(defaultdict(int))

	@staticmethod
	def _percent(part, total):
		"""Return part as a percentage of total, or 0.0 if total is zero."""
		if not total:
			return 0.0
		return 100. * part / total

	def stop_clock(self):
		"""Stop the timer that was automatically started when the class was instantiated."""
		self.time = time.clock() - self._start_time

	def print_statistics(self, error_rate):
		"""
		Print summary to stdout.

		error_rate -- the maximum error rate that was used (a fraction; it is
			printed as a percentage)

		If the timer has not been stopped yet, it is stopped first. When no
		reads were processed (n == 0), all percentages and the time per read
		are reported as zero instead of raising ZeroDivisionError.
		"""
		if self.time is None:
			self.stop_clock()
		n = self.n
		# guard against an empty input file: no reads means nothing to divide by
		time_per_read = 1000. * self.time / n if n else 0.0
		print("cutadapt version", __version__)
		print("Command line parameters:", " ".join(sys.argv[1:]))
		print("Maximum error rate: %.2f%%" % (error_rate * 100.))
		print("   Processed reads:", n)
		print("     Trimmed reads:", self.reads_changed, "(%5.1f%%)" % self._percent(self.reads_changed, n))
		print("   Too short reads:", self.too_short, "(%5.1f%% of processed reads)" % self._percent(self.too_short, n))
		print("    Too long reads:", self.too_long, "(%5.1f%% of processed reads)" % self._percent(self.too_long, n))
		print("        Total time: %9.2f s" % self.time)
		print("     Time per read: %9.2f ms" % time_per_read)
		print()
		for index, (where, adapter) in enumerate(self.adapters):
			total_front = sum(self.lengths_front[index].values())
			total_back = sum(self.lengths_back[index].values())
			total = total_front + total_back
			# a BACK adapter can never have been counted as a 5' match
			assert where == ANYWHERE or (where == BACK and total_front == 0)

			print("=" * 3, "Adapter", index+1, "=" * 3)
			print()
			print("Adapter '%s'," % adapter.upper(), "length %d," % len(adapter), "was trimmed", total, "times.")
			if where == ANYWHERE:
				print(total_front, "times, it overlapped the 5' end of a read")
				print(total_back, "times, it overlapped the 3' end or was within the read")
				print()
				print("Histogram of adapter lengths (5')")
				print_histogram(self.lengths_front[index])
				print()
				print("Histogram of adapter lengths (3' or within)")
				print_histogram(self.lengths_back[index])
			else:
				print()
				print("Histogram of adapter lengths")
				print_histogram(self.lengths_back[index])


def find_best_alignment(adapters, seq, max_error_rate, minimum_overlap):
	"""
	Find the best matching adapter.

	adapters -- List of tuples (where, adapter). 'where' is either BACK or
		ANYWHERE and is passed on to align.globalalign(). The adapter
		itself is a string.
	seq -- The sequence to which each adapter will be aligned
	max_error_rate -- Maximum allowed error rate. The error rate is
		the number of errors in the alignment divided by the length
		of the part of the alignment that matches the adapter.
	minimum_overlap -- Alignments in which fewer than this many adapter
		characters take part are ignored.

	Return tuple (best_alignment, best_index).

	best_alignment is a tuple
	(r1, r2, astart, astop, rstart, rstop, errors) as returned by
	align.globalalign (r1 and r2 are None for an exact match found
	without running the aligner), or None if no adapter matches.
	best_index is the index of the best adapter into the adapters list,
	or None if no adapter matches.

	If several adapters reach the same score, the first one in the list wins.
	"""
	best_score = 0
	best_alignment = None
	best_index = None
	for index, (where, adapter) in enumerate(adapters):
		# try to find an exact match first
		pos = seq.find(adapter)
		if pos >= 0:
			alignment = (None, None, 0, len(adapter), pos, pos + len(adapter), 0)
		else:
			alignment = align.globalalign(adapter, seq, where)
		astart, astop, errors = alignment[2], alignment[3], alignment[6]
		length = astop - astart
		# An empty overlap can never be the best match (its score would be 0)
		# and must be skipped before dividing by its length.
		if length <= 0 or length < minimum_overlap:
			continue
		if errors / length > max_error_rate:
			continue

		# the length of the matching part of the adapter minus errors
		# determines which adapter fits best
		score = length - errors
		if score > best_score:
			best_alignment = alignment
			best_score = score
			best_index = index
	return (best_alignment, best_index)


def write_read(desc, seq, qualities, outfile):
	"""
	Write a single read to outfile.

	desc -- description (header text without the leading '>' or '@')
	seq -- sequence of the read
	qualities -- quality string, or None if there are no qualities
	outfile -- file-like object to print to

	The record is written as FASTQ when qualities are given and as
	FASTA when qualities is None.
	"""
	if qualities is not None:
		record = '@%s\n%s\n+\n%s' % (desc, seq, qualities)
	else:
		record = '>%s\n%s' % (desc, seq)
	print(record, file=outfile)


def read_sequences(seqfilename, qualityfilename, colorspace):
	"""
	Open the input file(s) and return a reader for them.

	Two kinds of input are supported:
	* a single FASTA or FASTQ file (qualityfilename is None); the format
	  is autodetected by seqio.SequenceReader
	* a .fasta/.csfasta file plus a separate .qual file (as used for
	  SOLiD color space data), read by seqio.FastaQualReader

	colorspace is passed through to the reader unchanged.

	The returned object is iterated by the caller as tuples
	(description, sequence, qualities); qualities is None if no qualities
	are available. According to the original notes, qualities are
	ASCII-encoded with an offset of 33.
	"""
	if qualityfilename is None:
		# one file only: FASTA or FASTQ (autodetected)
		return seqio.SequenceReader(seqfilename, colorspace)
	# sequence file plus quality file: .(CS)FASTA/.QUAL
	return seqio.FastaQualReader(seqfilename, qualityfilename, colorspace)


class ReadFilter(object):
	"""Decide, based on its length, whether a read is written to the output."""

	def __init__(self, minimum_length, maximum_length, too_short_outfile, statistics):
		"""
		minimum_length -- reads shorter than this are rejected
		maximum_length -- reads longer than this are rejected
		too_short_outfile -- file that receives rejected too-short reads,
			or None to drop them
		statistics -- object whose too_short and too_long counters are
			incremented for rejected reads
		"""
		self.statistics = statistics
		self.too_short_outfile = too_short_outfile
		self.minimum_length = minimum_length
		self.maximum_length = maximum_length

	def keep(self, desc, seq, qualities):
		"""
		Return True if the read passes the length filter, False otherwise.

		A rejected read is counted in the statistics. Too short reads are
		additionally written to the too-short output file if there is one.
		"""
		read_length = len(seq)
		if read_length >= self.minimum_length:
			if read_length <= self.maximum_length:
				return True
			self.statistics.too_long += 1
			return False

		self.statistics.too_short += 1
		sink = self.too_short_outfile
		if sink is not None:
			write_read(desc, seq, qualities, sink)
		return False


class AdapterCutter(object):
	"""Cut adapters from reads."""

	def __init__(self, options, adapters):
		"""Initialize this cutter.
		options --
			The parsed command-line options. The attributes times,
			error_rate, overlap, rest_file, colorspace, length_tag,
			strip_f3, prefix, suffix and double_encode are used.
		adapters --
			List of tuples (where, adapter). 'where' is either BACK
			or ANYWHERE, and adapter is the adapter as a string.
			Adapters will be converted to uppercase before aligning
			them.
		"""
		self.options = options
		self.adapters = [ (where, adapter.upper()) for (where, adapter) in adapters ]
		self.stats = Statistics(adapters)
		# regular expression used to replace "length=..." strings in sequence descriptions
		self.lengthregex = self._build_length_regex(self.options.length_tag)

	@staticmethod
	def _build_length_regex(length_tag):
		"""
		Return a compiled pattern that matches length_tag followed by an
		optional decimal number, or None if length_tag is None.

		The tag is matched literally: it is escaped so that characters such
		as '(' or '.' in the tag are not interpreted as regular expression
		syntax.
		"""
		if length_tag is None:
			return None
		return re.compile(r"\b" + re.escape(length_tag) + r"[0-9]*\b")

	def cut(self, desc, seq, qualities):
		"""
		Cut adapters from a single read. The read will be converted to uppercase
		before it is compared to the adapter sequences.

		seq -- sequence of the read
		desc -- description of the read
		qualities -- quality values of the read (None if not available)

		Return a tuple (desc, seq, qualities, trimmed) with the modified read.
		trimmed is True when any adapter was found and trimmed.

		Besides trimming, the description is updated (length tag, _F3
		stripping, prefix and suffix) and the sequence is double-encoded
		if requested. The statistics in self.stats are updated as well.
		"""
		self.stats.n += 1

		if __debug__:
			old_length = len(seq)

		# try (possibly more than once) to remove an adapter
		any_adapter_matches = False
		# range() instead of xrange(): works on Python 2 and 3, and the
		# number of rounds is small
		for _ in range(self.options.times):
			alignment, index = find_best_alignment(self.adapters, seq.upper(), self.options.error_rate, self.options.overlap)
			if alignment is None:
				# nothing found
				break

			(r1, r2, astart, astop, rstart, rstop, errors) = alignment
			length = astop - astart
			assert length > 0
			assert errors/length <= self.options.error_rate
			assert length - errors > 0

			any_adapter_matches = True
			where = self.adapters[index][0]

			if where == BACK or (astart == 0 and rstart > 0):
				assert where == ANYWHERE or astart == 0
				# The adapter is at the end of the read or within the read
				if rstop < len(seq):
					# The adapter is within the read
					if self.options.rest_file is not None:
						print(seq[rstop:], file=self.options.rest_file)
				self.stats.lengths_back[index][length] += 1

				if self.options.colorspace:
					# trim one more color if long enough
					rstart = max(0, rstart - 1)
				seq = seq[:rstart]
				if qualities is not None:
					qualities = qualities[:rstart]
					assert len(qualities) == len(seq)
			elif where == ANYWHERE:
				# The adapter is in the beginning of the read (case 4)
				assert rstart == 0
				self.stats.lengths_front[index][length] += 1
				# TODO What should we do in color space?
				seq = seq[rstop:]
				if qualities is not None:
					qualities = qualities[rstop:]
			else:
				assert False

		if __debug__:
			# if an adapter was found, then the read should now be shorter
			assert (not any_adapter_matches) or (len(seq) < old_length)
		if any_adapter_matches: # TODO move to filter class
			self.stats.reads_changed += 1

		# other modifications to the sequence or its description
		# change any length=XX that may appear in the read description (for 454)
		if self.lengthregex is not None and desc.find(self.options.length_tag) >= 0:
			new_tag = self.options.length_tag + str(len(seq))
			# A function is used as replacement so that backslashes in the
			# tag are inserted literally instead of being treated as
			# escape sequences of a replacement template.
			desc = self.lengthregex.sub(lambda match: new_tag, desc)

		if self.options.strip_f3 and desc.endswith('_F3'):
			desc = desc[:-3]
		desc = self.options.prefix + desc + self.options.suffix
		if self.options.double_encode:
			# convert color space sequence to double-encoded colorspace (using
			# characters ACGTN to represent colors)
			seq = seq.translate(DOUBLE_ENCODE_TRANS)

		if self.options.colorspace:
			assert not qualities or len(seq) == len(qualities)
		return (desc, seq, qualities, any_adapter_matches)


def _close_output_files(files):
	"""Close each distinct file in files once, skipping None and the standard streams."""
	standard_streams = (sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__)
	closed = []
	for f in files:
		if f is None or any(f is stream for stream in standard_streams):
			continue
		if any(f is other for other in closed):
			# the same file object may serve more than one purpose
			continue
		f.close()
		closed.append(f)


def main():
	"""
	Main function that evaluates command-line parameters and contains the main loop over all reads.

	Return the exit status: 0 on success, 1 if no adapter was given or a
	FormatError was caught while processing the input. Invalid command
	lines terminate the program via parser.error().

	Output files that were opened here are closed when processing ends,
	whether it succeeded or not. The statistics are printed to standard
	error afterwards.
	"""
	parser = HelpfulOptionParser(usage=__doc__, version=__version__)

	group = OptionGroup(parser, "Options that influence how the adapters are found")
	group.add_option("-a", "--adapter", action="append", dest="adapters",
		help="Sequence of an adapter that was ligated to the 3' end. The adapter itself and anything that follows is trimmed. If multiple -a or -b options are given, only the best matching adapter is trimmed.")
	group.add_option("-b", "--anywhere", action="append", metavar="ADAPTER",
		help="Sequence of an adapter that was ligated to the 5' or 3' end. If the adapter is found within the read or overlapping the 3' end of the read, the behavior is the same as for the -a option. If the adapter overlaps the 5' end (beginning of the read), the initial portion of the read matching the adapter is trimmed, but anything that follows is kept. If multiple -a or -b options are given, only the best matching adapter is trimmed.")
	group.add_option("-e", "--error-rate", type=float, default=0.1,
		help="Maximum allowed error rate (no. of errors divided by the length of the matching region) (default: %default)")
	group.add_option("-n", "--times", type=int, metavar="COUNT", default=1,
		help="Try to remove adapters at most COUNT times. Useful when an adapter gets appended multiple times.")
	group.add_option("-O", "--overlap", type=int, metavar="LENGTH", default=3,
		help="Minimum overlap length. If the overlap between the adapter and the sequence is shorter than LENGTH, the read is not modified.")
	parser.add_option_group(group)

	group = OptionGroup(parser, "Options for filtering of processed reads")
	group.add_option("--discard", "--discard-trimmed", action='store_true', default=False,
		help="Discard reads that contain the adapter instead of trimming them. Also use -O in order to avoid throwing away too many randomly matching reads!")
	group.add_option("-m", "--minimum-length", type=int, default=0, metavar="LENGTH",
		help="Discard trimmed reads that are shorter than LENGTH. Reads that are too short even before adapter removal are also discarded. In colorspace, an initial primer is not counted.")
	group.add_option("-M", "--maximum-length", type=int, default=sys.maxint, metavar="LENGTH",
		help="Discard trimmed reads that are longer than LENGTH. "
			"Reads that are too long even before adapter removal "
			"are also discarded. In colorspace, an initial primer "
			"is not counted.")
	parser.add_option_group(group)

	group = OptionGroup(parser, "Options that influence what gets output to where")
	group.add_option("-o", "--output", default=None, metavar="FILE",
		help="The modified sequences get written to this file. The format is FASTQ if qualities are available, FASTA otherwise. (default: standard output)")
	group.add_option("-r", "--rest-file", default=None, metavar="FILE",
		help="When the adapter matches in the middle of a read, write the rest (after the adapter) into a file. Use - for standard output.")
	group.add_option("--too-short-output", default=None, metavar="FILE",
		help="Write reads that are too short (according to length specified by -m) to FILE. (default: discard reads)")
	group.add_option("--untrimmed-output", default=None, metavar="FILE",
		help="Write reads that do not contain the adapter to FILE, instead "
			"of writing them to the regular output file. (default: output "
			"to same file as trimmed)")
	parser.add_option_group(group)

	group = OptionGroup(parser, "Additional modifications to the reads")
	group.add_option("-q", "--quality-cutoff", type=int, default=None, metavar="CUTOFF",
		help="Trim low-quality ends from reads before adapter removal. "
			"The algorithm is the same as the one used by BWA "
			"(Subtract CUTOFF from all qualities; "
			"compute partial sums from all indices to the end of the "
			"sequence; cut sequence at the index at which the sum "
			"is minimal) (default: %default)")
	group.add_option("-x", "--prefix", default='',
		help="Add this prefix to read names")
	group.add_option("-y", "--suffix", default='',
		help="Add this suffix to read names")
	group.add_option("-c", "--colorspace", action='store_true', default=False,
		help="Colorspace mode: Also trim the color that is adjacent to the found adapter.")
	# NOTE(review): DOUBLE_ENCODE_TRANS maps '.' (not '4') to N, so the
	# help text below may not describe the actual mapping -- confirm.
	group.add_option("-d", "--double-encode", action='store_true', default=False,
		help="When in color space, double-encode colors (map 0,1,2,3,4 to A,C,G,T,N).")
	group.add_option("-t", "--trim-primer", action='store_true', default=False,
		help="When in color space, trim primer base and the first color "
			"(which is the transition to the first nucleotide)")
	group.add_option("--strip-f3", action='store_true', default=False,
		help="For color space: Strip the _F3 suffix of read names")
	group.add_option("--maq", "--bwa", action='store_true', default=False,
		help="MAQ- and BWA-compatible color space output. This enables -c, -d, -t, --strip-f3 and -y '/1'.")
	group.add_option("--length-tag", default=None, metavar="TAG",
		help="Search for TAG followed by a decimal number in the name of the read "
			"(description/comment field of the FASTA or FASTQ file). Replace the "
			"decimal number with the correct length of the trimmed read. "
			"For example, use --length-tag 'length=' to search for fields "
			"like 'length=123'.")
	parser.add_option_group(group)

	options, args = parser.parse_args()

	if len(args) == 0:
		parser.error("At least one parameter needed: name of a FASTA or FASTQ file.")
	elif len(args) > 2:
		parser.error("Too many parameters.")

	input_filename = args[0]
	quality_filename = None
	if len(args) == 2:
		quality_filename = args[1]

	# default output files (overwritten below)
	trimmed_outfile = sys.stdout # reads with adapters go here
	too_short_outfile = None # too short reads go here
	#too_long_outfile = None # too long reads go here

	if options.output is not None:
		trimmed_outfile = xopen(options.output, 'w')
	untrimmed_outfile = trimmed_outfile # reads without adapters go here
	if options.untrimmed_output is not None:
		untrimmed_outfile = xopen(options.untrimmed_output, 'w')
	if options.too_short_output is not None:
		too_short_outfile = xopen(options.too_short_output, 'w')
	#if options.too_long_output is not None:
		#too_long_outfile = xopen(options.too_long_output, 'w')

	if options.maq:
		options.colorspace = True
		options.double_encode = True
		options.trim_primer = True
		options.strip_f3 = True
		options.suffix = "/1"
	if options.trim_primer and not options.colorspace:
		parser.error("Trimming the primer makes only sense in color space.")
	if options.double_encode and not options.colorspace:
		parser.error("Double-encoding makes only sense in color space.")
	if options.anywhere and options.colorspace:
		parser.error("Using --anywhere with color space reads is currently not supported  (if you think this may be useful, contact the author).")
	if not (0 <= options.error_rate <= 1.):
		parser.error("The maximum error rate must be between 0 and 1.")

	if options.rest_file is not None:
		options.rest_file = xopen(options.rest_file, 'w')

	# -a adapters come first, then -b adapters; the index into this list
	# is what the statistics report as the adapter number
	adapters = []
	if options.adapters:
		adapters = [ (BACK, adapter) for adapter in options.adapters ]
	if options.anywhere:
		adapters += [ (ANYWHERE, adapter) for adapter in options.anywhere ]
	del options.adapters
	del options.anywhere

	if not adapters:
		print("You need to provide at least one adapter sequence.", file=sys.stderr)
		return 1

	#total_bases = 0
	#total_quality_trimmed = 0
	cutter = AdapterCutter(options, adapters)
	readfilter = ReadFilter(options.minimum_length, options.maximum_length, too_short_outfile, cutter.stats) # TODO stats?
	try:
		reader = read_sequences(input_filename, quality_filename, colorspace=options.colorspace)
		for desc, seq, qualities in reader:
			# In colorspace, the first character is the last nucleotide of the primer base
			# and the second character encodes the transition from the primer base to the
			# first real base of the read.
			if options.trim_primer:
				seq = seq[2:]
				if qualities is not None:
					qualities = qualities[1:]
				initial = ''
			elif options.colorspace:
				initial = seq[0]
				seq = seq[1:]
			else:
				initial = ''

			#total_bases += len(qualities)
			# Quality trimming needs qualities: for reads without them
			# (FASTA input) the step is skipped instead of failing on None.
			if options.quality_cutoff is not None and qualities is not None:
				index = quality_trim_index(qualities, options.quality_cutoff)
				#total_quality_trimmed += len(qualities) - index
				qualities = qualities[:index]
				seq = seq[:index]

			desc, seq, qualities, trimmed = cutter.cut(desc, seq, qualities)
			if readfilter.keep(desc, seq, qualities):
				seq = initial + seq
				write_read(desc, seq, qualities, trimmed_outfile if trimmed else untrimmed_outfile)
	except FormatError as e:
		# NOTE(review): this catches the FormatError class defined in this
		# file; confirm that the readers in cutadapt.seqio raise this class.
		print(e, file=sys.stderr)
		return 1
	finally:
		# Make sure everything written so far reaches the disk; this
		# matters in particular for compressed output. Standard streams
		# (e.g. when '-' was given as a file name) are left open.
		_close_output_files([trimmed_outfile, untrimmed_outfile, too_short_outfile, options.rest_file])
	# print_statistics() prints to stdout; redirect it so that the
	# statistics end up on standard error
	sys.stdout = sys.stderr
	try:
		cutter.stats.print_statistics(options.error_rate)
	finally:
		sys.stdout = sys.__stdout__
	return 0


if __name__ == '__main__':
	# Script entry point. A leading '--profile' argument is removed from the
	# command line and main() is run under cProfile, with the profiling data
	# written to 'cutadapt.prof'. Otherwise main() runs normally and its
	# return value becomes the exit status.
	if len(sys.argv) < 2 or sys.argv[1] != '--profile':
		sys.exit(main())
	else:
		del sys.argv[1]
		import cProfile as profile
		profile.run('main()', 'cutadapt.prof')
