#! /usr/bin/env python

# Copyright (C) 2017 Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
"""Extracts posterior samples from a Sampler HDF file, writing to a
posterior HDF file. Parameters can be renamed in the output file using the
--parameters option.

If more than one file is provided, the samples from all of the files will
be combined. Some attempt is made to check that the files share common
attributes. Sampler info (stored in the 'sampler_info' group) will not be
written when multiple files are combined. All other groups, unless explicitly
skipped (see the --skip-groups option), will be written using the first file.
No attempt is made to check that the data in these other groups is the same
across combined files.
"""

import os
import argparse
import logging
import numpy
import pycbc
from pycbc.inference.io import (ResultsArgumentParser, results_from_cli,
                                PosteriorFile, loadfile)
from pycbc.inference.io.base_hdf import format_attr


def isthesame(current_val, val):
    """Checks that attrs values from two hdf files are the same (or nearly so).

    Floats and numerical arrays are compared with ``numpy.isclose``, with NaNs
    treated as equal to each other. Arrays are sorted (along their last axis)
    before being compared, so the order of the elements is ignored. Arrays
    that do not have the same shape are never considered the same. Anything
    else is compared with ``==``.

    Parameters
    ----------
    current_val : object
        The reference value. Its type determines how the comparison is done.
    val : object
        The value to compare to ``current_val``.

    Returns
    -------
    bool
        True if the two values are the same (or, for floating-point numbers,
        close to each other); False otherwise.
    """
    if isinstance(current_val, numpy.ndarray):
        val = numpy.asarray(val)
        # a different number of elements means the values differ; checking
        # this up front keeps the comparison below from failing on (or
        # silently broadcasting) mismatched shapes
        if current_val.shape != val.shape:
            return False
        # sort before comparing; 0-dimensional arrays cannot be sorted
        if current_val.ndim > 0:
            current_val = numpy.sort(current_val)
            val = numpy.sort(val)
    if isinstance(current_val, (float, numpy.ndarray)):
        try:
            # in case of floats, just check that they're close
            issame = numpy.isclose(current_val, val, equal_nan=True)
        except TypeError:
            # wasn't float, do direct comparison
            issame = current_val == val
    else:
        issame = current_val == val
    # an element-wise result means every element has to agree
    if isinstance(issame, numpy.ndarray):
        issame = issame.all()
    return bool(issame)

# Build the command-line parser and register this program's own options.
parser = ResultsArgumentParser(description=__doc__, defaultparams='all',
                               autoparamlabels=False)

# (flag, keyword arguments) pairs; they are added in the order listed
_extra_options = (
    ("--output-file",
     dict(type=str, required=True,
          help="Output file to create.")),
    ("--force",
     dict(action="store_true", default=False,
          help=("If the output-file already exists, overwrite it. Otherwise, "
                "an IOError is raised."))),
    ("--skip-groups",
     dict(default=None, nargs="+",
          help=("Don't write the specified groups in the output (aside from "
                "samples; samples are always written), for example, "
                "'sampler_info'. If 'all' skip all groups, only writing the "
                "samples. Default is to write all groups if only one file "
                "is provided, and all groups from the first file except "
                "sampler_info if multiple files are provided."))),
    ("--verbose",
     dict(action="store_true", default=False,
          help="Be verbose")),
)
for _flag, _settings in _extra_options:
    parser.add_argument(_flag, **_settings)

opts = parser.parse_args()

pycbc.init_logging(opts.verbose)

# refuse to replace an existing output file unless --force was given
if not opts.force and os.path.exists(opts.output_file):
    raise IOError("output file already exists; use --force if you wish to "
                  "overwrite.")

# read the requested samples from the input file(s)
fps, params, labels, samples = results_from_cli(opts)
n_inputs = len(opts.input_file)
if n_inputs == 1:
    # put the lone file handler in a list so that later code can loop over it
    fps = [fps]

# re-key the samples by their output label; if several files were provided,
# each parameter's samples from all of the files are joined into one array
relabelled = {}
for param in params:
    if n_inputs > 1:
        values = numpy.concatenate([filesamps[param]
                                    for filesamps in samples])
    else:
        values = samples[param]
    relabelled[labels[param]] = values
samples = relabelled

# open the output as a posterior file and store the samples in it
out = loadfile(opts.output_file, 'w', filetype=PosteriorFile.name)
out.write_samples(samples)

# Preserve samples group metadata: every attribute attached to the samples
# group of an input file is copied to the samples group of the output file.
# If an attribute is found in more than one input file, its value has to be
# the same in all of them, otherwise the files cannot be combined.
for fp in fps:
    out_attrs = out[out.samples_group].attrs
    for key, val in fp[fp.samples_group].attrs.items():
        val = format_attr(val)
        try:
            # format the stored value in the same way as the incoming one
            # (as is done for the top-level attrs), so that the two are
            # compared in the same form
            current_val = format_attr(out_attrs[key])
        except KeyError:
            # first time this attribute is seen: store it; there is nothing
            # to compare it against yet
            out_attrs[key] = val
        else:
            # enforce that the metadata must be the same across multiple files
            if not isthesame(current_val, val):
                raise ValueError("cannot combine all files; samples group "
                                 "attr {} is not the same across all files. "
                                 "({} vs {})".format(key, current_val, val))

# Copy the top-level attributes of the input file(s) to the output, leaving
# out the ones that should not be carried over:
#   - filetype: already set when the output file was loaded
#   - thin_start, thin_interval, thin_end, thinned_by: the automatic thinning
#     settings no longer make sense for the output
#   - cmd: to be replaced with this program's command
#   - resume_points: may differ from one input file to the next
#   - effective_nsamples: can be obtained from the size of the samples
skip_attrs = [
    'filetype',
    'thin_start', 'thin_interval', 'thin_end', 'thinned_by',
    'cmd',
    'resume_points',
    'effective_nsamples',
]
# the evidence will not be the same across files, so it is dropped as well
# when several files are being combined
if len(opts.input_file) > 1:
    skip_attrs.extend(['log_evidence', 'dlog_evidence'])

mismatch_msg = ("cannot combine all files; file attr {} is not the same "
                "across all files ({} vs {})")
for infile in fps:
    for name in map(format_attr, infile.attrs):
        if name in skip_attrs:
            continue
        new_val = format_attr(infile.attrs[name])
        try:
            stored_val = format_attr(out.attrs[name])
        except KeyError:
            # not in the output yet: write it, then read it back so that the
            # check below uses the value as it was actually stored
            out.attrs[name] = new_val
            stored_val = format_attr(out.attrs[name])
        if isthesame(stored_val, new_val):
            continue
        # the same attribute has different values in different files
        raise ValueError(mismatch_msg.format(name, stored_val, new_val))

# keep a record of the (parameter, label) pairs used for the output
out.attrs['remapped_params'] = list(labels.items())

# every other group is taken from the first input file only
first_file = fps[0]
skip_groups = opts.skip_groups
if skip_groups is not None and 'all' in skip_groups:
    # 'all' stands for every group in the file other than the samples group
    skip_groups = [name for name in first_file.keys()
                   if name != first_file.samples_group]
if len(opts.input_file) > 1:
    # the sampler info is not written when files are combined
    if skip_groups is None:
        skip_groups = []
    skip_groups.append('sampler_info')

first_file.copy_info(out, ignore=skip_groups)

# done: close all of the input files, then the output file
for handler in fps:
    handler.close()
out.close()
