#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Duncan Macleod
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Cluster triggers from a pyGRB run
"""

from __future__ import print_function

import argparse
import operator
import os
import re
from collections import defaultdict
from itertools import compress
try:
    from functools import reduce
except ImportError:  # python < 2
    pass

import tqdm

import numpy

import h5py

from gwdatafind.utils import filename_metadata

from ligo.segments import segmentlist
from ligo.segments.utils import fromsegwizard

from pycbc import __version__
from pycbc.inject import InjectionSet
from pycbc.io import FieldArray

__author__ = "Duncan Macleod <duncan.macleod@ligo.org>"

SPLIT_FILENAME = re.compile(r"_(SPLIT\d+)")

NETWORK_IFO_EVENT_ID_REGEX = re.compile(
    r"\Anetwork/(?P<ifo>[A-Z]1)_event_id\Z",
)
EVENT_ID_REGEX = re.compile(r"event_id\Z")

TQDM_BAR_FORMAT = ("{desc}: |{bar}| "
                   "{n_fmt}/{total_fmt} {unit} ({percentage:3.0f}%) "
                   "[{elapsed} | ETA {remaining}]{postfix}")
TQDM_KW = {
    "ascii": " -=#",
    "bar_format": TQDM_BAR_FORMAT,
    "smoothing": 0.05,
}


# -- utilities ----------------------------------

def find_split_files(filelist, split):
    """Yield the files in ``filelist`` that belong to the given split

    Parameters
    ----------
    filelist : iterable of `str`
        the file paths to filter

    split : `str` or `None`
        the split tag to match (e.g. ``"SPLIT3"``), or `None` to select
        only those files whose path does not contain ``SPLIT`` at all

    Yields
    ------
    fn : `str`
        each matching path, in the order given in ``filelist``
    """
    if split is None:
        # handled separately: `None in fn` would raise TypeError
        for fn in filelist:
            if 'SPLIT' not in fn:
                yield fn
        return

    # a tag ending in a digit must not be followed by another digit,
    # otherwise "SPLIT1" would also select "SPLIT10", "SPLIT11", ...
    pattern = re.escape(split)
    if split[-1:].isdigit():
        pattern += r"(?!\d)"
    matcher = re.compile(pattern)

    for fn in filelist:
        if matcher.search(fn):
            yield fn


def read_segment_files(segfiles):
    """Read several segment files and combine them into one segment list

    Parameters
    ----------
    segfiles : iterable of `str`
        paths of the files to read, each parsed with `fromsegwizard`

    Returns
    -------
    segments : `ligo.segments.segmentlist`
        the result of OR-ing (``|``) the contents of all files together,
        empty if ``segfiles`` is empty
    """
    combined = segmentlist()
    for path in segfiles:
        with open(path, "r") as fobj:
            combined = combined | fromsegwizard(fobj)
    return segmentlist(combined)


def read_hdf5_triggers(inputfiles, verbose=False):
    """Read and concatenate the ``network`` datasets of several HDF5 files

    Each dataset whose path starts with ``network`` is read from every
    input file and concatenated, in ``inputfiles`` order, along its first
    axis. Datasets matching ``network/<IFO>_event_id`` are skipped.
    Datasets whose name ends in ``event_id`` have the number of rows
    already read from earlier files added to their values.

    Parameters
    ----------
    inputfiles : `list` of `str`
        the paths of the input HDF5 files to read; the list is iterated
        twice, so it must not be a one-shot iterator

    verbose : `bool`, optional
        currently unused, kept for interface compatibility

    Returns
    -------
    out : `dict` of `numpy.ndarray`
        the concatenated arrays, keyed by dataset path with the first
        eight characters (``network/``) removed

    Raises
    ------
    ValueError
        if a dataset has a different dtype or trailing shape in
        different files
    KeyError
        if a dataset seen in one file is missing from another
    """
    # dataset path -> (total length, trailing shape, dtype)
    datasets = {}

    def _scan_dataset(name, obj):
        # NOTE: must always return None, any other value would stop
        #       h5py's visititems() traversal early
        if not name.startswith("network") or not isinstance(obj, h5py.Dataset):
            return
        if NETWORK_IFO_EVENT_ID_REGEX.match(name):
            return
        length, tail = obj.shape[0], obj.shape[1:]
        try:
            total, known_tail, known_dtype = datasets[name]
        except KeyError:  # first time we have seen this dataset
            datasets[name] = (length, tail, obj.dtype)
            return
        if obj.dtype != known_dtype:
            raise ValueError(
                "Cannot merge {0}/{1}, does not match dtype".format(
                    obj.file.filename, name,
                ))
        if tail != known_tail:
            raise ValueError(
                "Cannot merge {0}/{1}, does not match shape".format(
                    obj.file.filename, name,
                ))
        datasets[name] = (total + length, tail, known_dtype)

    # get list of datasets and their total sizes
    for filename in inputfiles:
        with h5py.File(filename, 'r') as h5f:
            h5f.visititems(_scan_dataset)

    # create output arrays (dset[8:] strips the leading "network/")
    out = {}
    for dset, (total, tail, dtype) in datasets.items():
        out[dset[8:]] = numpy.empty((total,) + tail, dtype=dtype)

    # copy dataset contents
    position = defaultdict(int)
    for filename in inputfiles:
        with h5py.File(filename, 'r') as h5in:
            for dset in datasets:
                data = h5in[dset][:]
                size = data.shape[0]
                pos = position[dset]
                if EVENT_ID_REGEX.search(dset):
                    # shift IDs by the number of rows already copied
                    data = data + pos
                out[dset[8:]][pos:pos+size] = data
                position[dset] += size

    return out


# -- parse command line -------------------------

# The module docstring doubles as the --help description.
parser = argparse.ArgumentParser(
    description=__doc__,
)

parser.add_argument(
    "-v",
    "--verbose",
    action="store_true",
    default=False,
    help="verbose output with microsecond timer (default: %(default)s)",
)
parser.add_argument(
    "-V",
    "--version",
    action="version",
    version=__version__,
    help="show version number and exit",
)

# input/output
# NOTE: both file lists are iterated/indexed unconditionally further down,
#       so they are declared required to give a clear usage error instead
#       of a TypeError on None
parser.add_argument(
    "-f",
    "--input-files",
    nargs="+",
    required=True,
    help="path(s) to input trigger file(s)",
)
parser.add_argument(
    "-j",
    "--inj-files",
    nargs="+",
    required=True,
    help="path(s) to input injection file(s)",
)
parser.add_argument(
    "-o",
    "--output-dir",
    default=os.getcwd(),
    help="output directory (default: %(default)s)",
)
parser.add_argument(
    "-i",
    "--ifo-tag",
    help="the ifo tag, H1L1 or H1L1V1 for instance",
)

# parameters
parser.add_argument(
    "-W",
    "--time-window",
    type=float,
    default=0.,
    help="the found time window (default: %(default)s)",
)
parser.add_argument(
    "-c",
    "--rank-column",
    default="coherent_snr",
    help="column over which to rank events (default: %(default)s)",
)
parser.add_argument(
    "-b",
    "--exclude-segments",
    action="append",
    default=[],
    help="ignore injections in segments found within "
         "these files, e.g. buffer segments (may be given "
         "more than once)",
)

args = parser.parse_args()

# `str` acts as a silent stand-in for `print` when not verbose
vprint = print if args.verbose else str
win = args.time_window

# -- find injections ----------------------------

# segments in which injections are ignored
exclude = read_segment_files(args.exclude_segments)

nexcluded = 0
# indices (into `allinjections`) of injections without/with a trigger
# inside the time window
missed = []
found = []
# column name -> value of the matched trigger for each found injection
trigs = defaultdict(list)

allinjections = None

with tqdm.tqdm(args.inj_files, desc="Finding injections",
               disable=not args.verbose, unit="files",
               postfix=dict(found=0, missed=0, excluded=nexcluded),
               **TQDM_KW) as bar:
    for injf in bar:
        # get list of trigger files
        try:
            split, = SPLIT_FILENAME.search(injf).groups()
        except AttributeError:  # no regex match
            split = None
        trigfiles = list(find_split_files(args.input_files, split))

        # read injections and filter by exclude segments
        tmpinj = InjectionSet(injf).table
        injtime = numpy.array([inj.time_geocent for inj in tmpinj])
        # force a boolean array so that `~keep` and mask-indexing also
        # work when the injection table is empty
        keep = numpy.asarray(
            [t not in exclude for t in injtime],
            dtype=bool,
        )
        nexcluded += int((~keep).sum())
        injections = type(tmpinj)()
        injections.extend(compress(tmpinj, keep))
        injtime = injtime[keep]

        # read triggers
        triggers = read_hdf5_triggers(trigfiles)
        trigtime = triggers["end_time_gc"]
        snr = triggers[args.rank_column]
        time_sorting = trigtime.argsort()
        sorted_time = trigtime[time_sorting]

        # determine found or missed: [_left, _right) is the range of
        # positions in the time-sorted triggers inside each window
        _left = numpy.searchsorted(
            sorted_time,
            injtime - args.time_window,
            side='left',
        )
        _right = numpy.searchsorted(
            sorted_time,
            injtime + args.time_window,
            side='right',
        )

        # injections from this file are appended after those already seen
        offset = 0 if allinjections is None else len(allinjections)

        # get indices of found/missed injections and finding trigger
        for i, (lo, hi) in enumerate(zip(_left, _right), start=offset):
            if hi == lo:
                missed.append(i)
                continue
            found.append(i)
            # map positions in the sorted array back to indices in the
            # unsorted trigger arrays, then pick the highest-ranked one
            candidates = time_sorting[lo:hi]
            best = candidates[snr[candidates].argmax()]
            for col in triggers:
                trigs[col].append(triggers[col][best])

        # record all injections
        if allinjections is None:
            allinjections = injections
        else:
            allinjections.extend(injections)

        # update progress bar
        bar.set_postfix(
            found=len(found),
            missed=len(missed),
            excluded=nexcluded,
        )

# convert the injection table into a FieldArray, replacing the "lstring"
# column type with "string"
dtypes = {
    key: ("string" if val == "lstring" else val)
    for key, val in allinjections.validcolumns.items()
}
injarr = FieldArray.from_ligolw_table(
    allinjections,
    cast_to_dtypes={key: (key, val) for key, val in dtypes.items()},
)

# build the output file name from the metadata of the first injection
# file, with any split tag (and everything after it) removed from the
# description
ifotag, desc, segment = filename_metadata(args.inj_files[0])
desc = SPLIT_FILENAME.split(desc)[0]
outfilename = os.path.join(
    args.output_dir,
    "{}-{}_FOUNDMISSED-{}-{}.h5".format(
        args.ifo_tag or ifotag, desc, segment[0], abs(segment),
    ),
)

with h5py.File(outfilename, "w") as h5out:
    # one group per injection category, one dataset per injection field
    for name, injlist in zip(("found", "missed"), (found, missed)):
        grp = h5out.create_group(name)
        injs = injarr[injlist]
        for field in injs.fieldnames:
            grp[field] = injs[field]

    # one compressed dataset per trigger column
    netg = h5out.create_group("network")
    for col, values in trigs.items():
        netg.create_dataset(
            col,
            data=values,
            compression="gzip",
            compression_opts=9,
        )

if args.verbose:
    print("Found/missed written to {}".format(outfilename))
