#!/usr/bin/env python

# Copyright (C) 2018 Ian Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Plot the "Q-scan" time-frequency representation of single-detector strain data.

See https://iopscience.iop.org/article/10.1088/0264-9381/21/20/024 for
information on the Q-transform and its parameters.
"""

import sys
import argparse
import logging
import numpy

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm

import pycbc.strain
import pycbc.version
import pycbc.results

# https://stackoverflow.com/questions/9978880/python-argument-parser-list-of-list-or-tuple-of-tuples
def t_window(s):
    """Parse one ``START,END`` command-line token into a time window.

    Intended for use as an argparse ``type=`` callable.

    Parameters
    ----------
    s : str
        Two comma-separated numbers, e.g. ``"1.5,0.5"``.

    Returns
    -------
    list of float
        ``[start, end]``. A list (not a tuple) is returned so that callers
        may adjust the bounds in place.

    Raises
    ------
    argparse.ArgumentTypeError
        If ``s`` is not exactly two comma-separated numbers.
    """
    try:
        start, end = map(float, s.split(','))
    except (AttributeError, TypeError, ValueError):
        # ValueError: wrong number of fields, or a field that is not a number.
        # AttributeError/TypeError: ``s`` is not a string at all.
        # Deliberately not a bare ``except`` so that KeyboardInterrupt and
        # SystemExit are never disguised as a bad argument.
        raise argparse.ArgumentTypeError("Input must be start,end start,end")
    return [start, end]

# Command-line interface. The order of the add_argument calls below is the
# order in which the options are listed by --help, so keep it stable.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)

# --- Output and time selection ---
parser.add_argument('--output-file', required=True, help='Output plot')
parser.add_argument('--center-time', type=float,
                    help='Center plot on the given GPS time. If omitted, use '
                         'center of interval set via --gps-start-time and '
                         '--gps-end-time')
# Each START,END token is parsed by t_window into a two-element list, so
# opts.time_windows is a list of [start, end] lists (one per panel).
parser.add_argument('--time-windows', required=True, type=t_window,
                    nargs='+', metavar='START,END',
                    help='Use these set of times for time windows. Each '
                         'window produces a different panel. Should '
                         'be provided as start1,end1 start2,end2 ... '
                         'where start/end times are relative to --center-time '
                         'and both positive')
# NOTE(review): this option is not read anywhere else in this script;
# presumably it is consumed by pycbc.strain.from_cli -- confirm.
parser.add_argument("--low-frequency-cutoff", type=float,
                    help="The low frequency cutoff to use for fake strain generation")

# --- Q-transform parameters (passed to the strain's qtransform method) ---
parser.add_argument('--qtransform-delta-t', default=0.001, type=float,
                    help='The time resolution to interpolate to (optional)')
parser.add_argument('--qtransform-delta-f', type=float,
                    help='Frequency resolution to interpolate to (optional)')
parser.add_argument('--qtransform-logfsteps', type=int, default=200,
                    help='Do a log interpolation (incompatible with '
                         '--qtransform-delta-f option) and set the number '
                         'of steps to take')
# The two frequency-range options must be given together or not at all;
# this is enforced after parsing, not by argparse itself.
parser.add_argument('--qtransform-frange-lower', type=float,
                    help='Lower frequency (in Hz) at which to compute '
                         'qtransform. Optional, default=20 Hz')
parser.add_argument('--qtransform-frange-upper', type=float,
                    help='Upper frequency (in Hz) at which to compute '
                         'qtransform. Optional, default=Half of Nyquist')
parser.add_argument('--qtransform-qrange-lower', default=4, type=float,
                    help='Lower limit of the range of q to consider, '
                         'default=%(default)f')
parser.add_argument('--qtransform-qrange-upper', default=64, type=float,
                    help='Upper limit of the range of q to consider, '
                         'default=%(default)f')
parser.add_argument('--qtransform-mismatch', default=0.2, type=float,
                    help='Mismatch between frequency tiles, default=%(default)f')

# --- Plot appearance ---
# Both flags below are "store_false": the log scale is on by default and the
# flag switches it off (stored as opts.log_y / opts.log_colorbar).
parser.add_argument('--linear-y-axis', dest='log_y', default=True,
                    action='store_false',
                    help='Use a linear y-axis. By default a log axis is used.')
parser.add_argument('--linear-colorbar', dest='log_colorbar', default=True,
                    action='store_false',
                    help='Use a linear colorbar scale.')
parser.add_argument('--plot-title',
                    help="If given, use this as the plot title")
parser.add_argument('--plot-caption',
                    help="If given, use this as the plot caption")
parser.add_argument('--colormap', choices=plt.colormaps(), default='viridis',
                    help='Colormap to use (default %(default)s)')

# --- Optional inspiral track overlay ---
# --mass1 and --mass2 must be given together or not at all; this is enforced
# after parsing. The spins are only used when the masses are provided.
parser.add_argument('--mass1', type=float,
                    help='Provide masses to plot the waveform track on top of '
                         'the qscan.')
parser.add_argument('--mass2', type=float,
                    help='Provide masses to plot the waveform track on top of '
                         'the qscan.')
parser.add_argument('--spin1z', type=float, default=0,
                    help='If masses provided, provide spins to plot the '
                         'waveform track on top of the qscan (default=0).')
parser.add_argument('--spin2z', type=float, default=0,
                    help='If masses provided, provide spins to plot the '
                         'waveform track on top of the qscan (default=0).')

# Adds PyCBC's standard strain options. The rest of this script reads
# opts.gps_start_time, opts.gps_end_time, opts.sample_rate and
# opts.channel_name, none of which are defined above, so they presumably
# come from this option group.
pycbc.strain.insert_strain_option_group(parser)
opts = parser.parse_args()

logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

# Resolve the GPS time on which every panel is centered. Everything below
# must use ``center_time`` and never ``opts.center_time``: the latter is None
# whenever the user relies on the midpoint of the requested GPS interval.
if opts.center_time is not None:
    center_time = opts.center_time
else:
    if opts.gps_start_time is None or opts.gps_end_time is None:
        parser.error('Must provide --center-time, or both --gps-start-time '
                     'and --gps-end-time')
    center_time = (opts.gps_start_time + opts.gps_end_time) / 2.

# A center time of exactly -1 produces an empty placeholder figure and exits
# before any strain is read.
# NOTE(review): presumably -1 is a sentinel supplied by the calling workflow
# when the detector has no data -- confirm against the caller.
if center_time == -1.0:
    fig, axes = plt.subplots(1,1)
    fig.add_subplot(111, frameon=False)
    plt.tick_params(labelcolor='none', top=False, bottom=False, left=False,
                    right=False)
    plt.grid(False)
    plt.xlabel('Time from {:.3f} (s)'.format(center_time))
    plt.ylabel('Frequency (Hz)')
    title = 'No data for this detector'
    pycbc.results.save_fig_with_metadata(
            fig, opts.output_file, cmd=' '.join(sys.argv),
            fig_kwds={'dpi': 150}, title=title, caption=opts.plot_caption)
    sys.exit(0)

# The inspiral track is drawn only when both component masses are supplied.
plot_chirp = False
if opts.mass1 is not None or opts.mass2 is not None:
    if None in [opts.mass1, opts.mass2]:
        parser.error('Must provide either both --mass1 and --mass2 or neither')
    plot_chirp = True

strain = pycbc.strain.from_cli(opts, pycbc.DYN_RANGE_FAC)

# Frequency range of the Q-transform: either both bounds come from the
# command line, or the default of 20 Hz up to a quarter of the sample rate
# (i.e. half of Nyquist) is used.
if opts.qtransform_frange_upper is None and \
        opts.qtransform_frange_lower is None:
    curr_frange = (20, opts.sample_rate / 4.)
elif opts.qtransform_frange_upper is None or \
          opts.qtransform_frange_lower is None:
    parser.error('Must provide either both --qtransform-frange-upper and '
                 '--qtransform-frange-lower or neither option.')
else:
    curr_frange = (opts.qtransform_frange_lower, opts.qtransform_frange_upper)

# Keep the edge samples of the whitened data when the center time lies
# within 2 s of either end of the strain.
# NOTE(review): presumably 2 s is what whiten(4, 4) trims from each edge when
# remove_corrupted is True, so trimming would cut away the time of interest
# -- confirm against pycbc's TimeSeries.whiten.
near_edge = ((center_time - strain.start_time) < 2
             or (strain.end_time - center_time) < 2)
rem_corrupted = not near_edge

strain = strain.whiten(4, 4, remove_corrupted=rem_corrupted)

wins = opts.time_windows
fig, axes = plt.subplots(len(wins), 1)
# plt.subplots returns a bare Axes for a single panel and an array otherwise;
# normalize to a list so the loops below need no special casing.
axes_list = list(axes) if len(wins) > 1 else [axes]
times_all = []
freqs_all = []
qvals_all = []

qrange = (opts.qtransform_qrange_lower, opts.qtransform_qrange_upper)

# First pass: compute the Q-transform of every window. Plotting is deferred
# because the shared color scale needs the maximum over all panels.
for curr_win in wins:
    # Catch the case that not enough data is available. The window is
    # shrunk in place so that the axis limits set later match the data
    # that was actually transformed.
    if center_time - curr_win[0] < strain.start_time:
        curr_win[0] = float(center_time - strain.start_time - 0.01)
    if center_time + curr_win[1] > strain.end_time:
        curr_win[1] = float(strain.end_time - center_time - 0.01)
    strain_zoom = strain.time_slice(center_time - curr_win[0],
                                    center_time + curr_win[1])

    times, freqs, qvals = strain_zoom.qtransform(
            delta_t=opts.qtransform_delta_t, delta_f=opts.qtransform_delta_f,
            logfsteps=opts.qtransform_logfsteps, frange=curr_frange,
            qrange=qrange, mismatch=opts.qtransform_mismatch)
    times_all.append(times)
    freqs_all.append(freqs)
    qvals_all.append(qvals)

max_qval = max(q.max() for q in qvals_all)

# One normalization shared by every panel so that the single colorbar is
# valid for all of them.
# NOTE(review): with --linear-colorbar no norm or limits are passed, so each
# panel is autoscaled on its own and the colorbar reflects only the last one.
norm = LogNorm(vmin=1, vmax=max_qval) if opts.log_colorbar else None

# Second pass: draw each window in its own panel, with times shown relative
# to the center time.
for ax, curr_win, times, freqs, qvals in zip(axes_list, wins, times_all,
                                             freqs_all, qvals_all):
    im = ax.pcolormesh(times - center_time, freqs, qvals, norm=norm,
                       cmap=opts.colormap)
    ax.set_xlim(-curr_win[0], curr_win[1])
    ax.set_ylim(curr_frange[0], curr_frange[1])
    if opts.log_y:
        ax.set_yscale('log')

# Invisible full-figure axes that only carry the common x/y labels.
# https://stackoverflow.com/questions/6963035/pyplot-axes-labels-for-subplots
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False,
                right=False)
plt.grid(False)
plt.xlabel('Time from {:.3f} (s)'.format(center_time))
plt.ylabel('Frequency (Hz)')

# https://stackoverflow.com/questions/13784201/matplotlib-2-subplots-1-colorbar
cb = fig.colorbar(im, ax=(axes_list if len(wins) > 1 else axes))
cb.set_label('Normalized power')

if plot_chirp:
    from pycbc.pnutils import get_inspiral_tf
    f_low = 20.
    # Approximant is chosen on total mass: below 4 use SPAtmplt, otherwise
    # SEOBNRv4_ROM.
    approximant = 'SPAtmplt' if opts.mass1+opts.mass2<4 else 'SEOBNRv4_ROM'
    track_t, track_f = get_inspiral_tf(center_time, opts.mass1,
                                       opts.mass2, opts.spin1z, opts.spin2z,
                                       f_low, approximant=approximant)
    for ax in axes_list:
        ax.plot(track_t - center_time, track_f, 'r-', lw=1.5)

if opts.channel_name is not None:
    plt.suptitle(opts.channel_name)

# Fill in default metadata for anything the user did not supply.
if opts.plot_title is None:
    opts.plot_title = 'Q-transform plot around {:.3f}'.format(center_time)
if opts.plot_caption is None:
    opts.plot_caption = ("This shows the Q-transform as a function of time and "
                         "frequency.")
    if opts.channel_name is not None:
        opts.plot_caption += (' The strain channel is ' + opts.channel_name
                              + '.')
    if plot_chirp:
        chirp_caption = (' The red curve is the time-frequency curve of a'
                         ' quadrupole-order quasicircular aligned-spin'
                         ' inspiral with mass1={:.3f}, mass2={:.3f},'
                         ' spin1z={:.3f}, spin2z={:.3f}.')
        opts.plot_caption += chirp_caption.format(opts.mass1, opts.mass2,
                                                  opts.spin1z, opts.spin2z)

pycbc.results.save_fig_with_metadata(
        fig, opts.output_file, cmd=' '.join(sys.argv), fig_kwds={'dpi': 150},
        title=opts.plot_title, caption=opts.plot_caption)
