#!/usr/bin/env python

# Copyright (C) 2020 Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Recalculates log likelihood and prior for points in a given inference or
posterior file and writes them to a new file. Also records auxiliary model
stats that may have been ignored by the sampler.
"""

import os
import argparse
import shutil
import logging
import time
import numpy

import pycbc
from pycbc.pool import (use_mpi, choose_pool)
from pycbc.io import FieldArray
from pycbc.inference.io import loadfile
from pycbc.inference import models
from pycbc.workflow import WorkflowConfigParser


# Command-line interface. The module docstring doubles as the program
# description shown by --help.
parser = argparse.ArgumentParser(description=__doc__)

# Each entry pairs an option flag with the keyword arguments handed to
# ``add_argument``; options are registered in the order listed here.
_cli_options = (
    ("--input-file", dict(
        type=str, required=True,
        help="Input HDF file to read. Can be either an inference "
             "file or a posterior file.")),
    ("--output-file", dict(
        type=str, required=True,
        help="Output file to create.")),
    ("--force", dict(
        action="store_true", default=False,
        help="If the output-file already exists, overwrite it. "
             "Otherwise, an OSError is raised.")),
    ("--nprocesses", dict(
        type=int, default=1,
        help="Number of processes to use. If not given then only "
             "a single core will be used.")),
    ("--config-file", dict(
        nargs="+", type=str, default=None,
        help="Override the config file stored in the input file "
             "with the given file(s).")),
    ("--verbose", dict(
        action="store_true", default=False,
        help="Print logging messages.")),
)
for _flag, _kwargs in _cli_options:
    parser.add_argument(_flag, **_kwargs)


# parse command line
opts = parser.parse_args()

# refuse to overwrite an existing output file unless --force was given
if not opts.force and os.path.exists(opts.output_file):
    raise OSError("output-file already exists; use force if you "
                  "wish to overwrite it.")

# setup log
# When running in MPI mode only the parent (rank 0) is allowed to print, so
# verbosity is switched off on every other rank.
use_mpi, size, rank = pycbc.pool.use_mpi(log=False)
if use_mpi and rank != 0:
    opts.verbose = False
pycbc.init_logging(opts.verbose)

# load the config file to get the model
logging.info("Loading config file")
if opts.config_file is not None:
    # config file(s) given on the command line take precedence over whatever
    # is stored in the input file
    cp = WorkflowConfigParser(opts.config_file)
else:
    # fall back to the config file saved inside the input file
    try:
        with loadfile(opts.input_file, 'r') as fp:
            cp = fp.read_config_file()
    except ValueError:
        raise ValueError("no config file found in {}; please provide one "
                         "using the --config-file".format(opts.input_file))

# now load the model
logging.info("Loading model")
model = models.read_from_config(cp)
# The stored samples are passed to the model directly in the variable
# parameter space, so any sampling transforms are switched off.
model.sampling_transforms = None

# create function for calling the model to get the stats
def callmodel(paramvals):
    """Evaluate the model at a single sample and return its current stats.

    ``paramvals`` should be a numpy record instance that can be indexed by
    the name of each of the model's variable parameters.
    """
    point = {}
    for name in model.variable_params:
        point[name] = paramvals[name]
    model.update(**point)
    # the value itself is not needed; evaluating the log posterior is what
    # gets all of the stats populated
    _ = model.logposterior
    return model.get_current_stats()

# these help for parallelization
# ``callmodel`` is stored in the ``models`` module's global slot, and the
# module-level ``_call_global_model`` is the function handed to the pool.
# NOTE(review): presumably ``_call_global_model`` forwards its argument to
# ``_global_instance`` so that only a module-level function has to be sent
# to the workers -- confirm against pycbc.inference.models.
models._global_instance = callmodel
model_call = models._call_global_model

# setup the pool
# NOTE(review): the global above is assigned before the pool is created;
# that ordering presumably matters for workers that inherit module state
# when they are started -- keep it unless confirmed otherwise.
pool = choose_pool(processes=opts.nprocesses)

# load the samples
logging.info('Getting samples')
with loadfile(opts.input_file, 'r') as fp:
    group = fp[fp.samples_group]
    # All arrays in the samples group should share one shape, so the shape of
    # the first variable parameter is taken as the shape for everything. It
    # is kept so the flattened results can be reshaped later.
    shape = group[model.variable_params[0]].shape
    # read every variable parameter into memory as a flat array
    samples = {p: group[p][()].flatten() for p in model.variable_params}
# a FieldArray gives one record per sample, which is easy to iterate over
samples = FieldArray.from_kwargs(**samples)
logging.info("Loaded %i samples", samples.size)

# Time a small leading batch of samples to give a rough estimate of how long
# the remaining calculation will take. The stats from that batch are kept, so
# no sample is evaluated twice.
estsize = 100
if samples.size > estsize:
    logging.info("Estimating run time using first %i samples", estsize)
    # estimate the time it will take
    before = time.time()
    stats = list(pool.map(model_call, samples[:estsize]))
    now = time.time()
    # The batch was evaluated through the pool, so the measured wall-clock
    # time already includes whatever parallel speed-up is in effect. It is
    # therefore scaled only by the number of samples left; dividing by the
    # number of processes as well would under-estimate the remaining time.
    total = (samples.size - estsize) * (now - before) / estsize
    logging.info("Will take ~%.2f more minutes", total/60.)
    # only the samples that have not been evaluated yet remain to be done
    samples = samples[estsize:]
else:
    stats = []

# get the stats
logging.info("Calculating stats")
stats += list(pool.map(model_call, samples))

# write to the output file
# The output starts out as a copy of the input file; the recalculated stats
# are then written into its samples group.
logging.info("Copying input to output")
shutil.copy(opts.input_file, opts.output_file)

logging.info("Writing stats to output")
# Opening the file in a ``with`` block makes sure it is closed even if one of
# the writes fails.
with loadfile(opts.output_file, 'a') as out:
    for pi, p in enumerate(model.default_stats):
        # Take the pi-th value from every sample's stats and restore the
        # shape the samples had in the input file before they were flattened.
        vals = numpy.array([pointstats[pi] for pointstats in stats])
        # The samples group is read from the handle being written to, rather
        # than from the input handle that was closed earlier.
        out.write_data(p, vals.reshape(shape), path=out.samples_group,
                       append=False)

logging.info("Done")
