#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) 2015 Chris Pankow
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Given a set of extrinsic evidence calculations on a given set of intrinsic parameters, refines the grid to do additional calculations.
"""

__author__ = "Chris Pankow <chris.pankow@ligo.org>"

import os
import sys
import glob
import json
import bisect
import re
import math
import operator
import warnings
from collections import defaultdict
from functools import reduce
from argparse import ArgumentParser
from copy import copy

import h5py
import numpy
import numpy as np
from scipy.special import binom
from sklearn.neighbors import BallTree

from ligo.lw import utils, ligolw, lsctables
lsctables.use_in(ligolw.LIGOLWContentHandler)
from ligo.lw.utils import process

import lalsimulation
from rapid_pe import amrlib, lalsimutils, common_cl

def get_cr_from_grid(cells, weight, cr_thr=0.9, min_n=None, max_n=None):
    """
    Given a set of cells and the weight of that cell, calculate a N% CR
    including cells which contribute to that probability mass.

    Parameters
    ----------
    cells : numpy.ndarray, shape (npts, ndim)
        Coordinates of the grid cells.
    weight : numpy.ndarray, shape (npts,)
        (Unnormalized) probability weight of each cell.
    cr_thr : float
        Fraction of the total probability mass to keep. If 0.0, an empty
        array is returned immediately.
    min_n, max_n : int or None
        If set, clamp the number of selected cells to be at least ``min_n``
        and/or at most ``max_n``, overriding the ``cr_thr`` selection.

    Returns
    -------
    numpy.ndarray
        The subset of ``cells`` (ordered by increasing weight) which make up
        the requested credible region.
    """
    # Nothing requested, or nothing to select from (the hstack below would
    # fail on an empty weight array)
    if cr_thr == 0.0 or len(weight) == 0:
        return numpy.empty((0,))

    # Arrange them all with their respective weight
    cell_sort = numpy.hstack((weight[:, numpy.newaxis], cells))

    # Sort (ascending in weight) and form the normalized CDF
    cell_sort = cell_sort[cell_sort[:, 0].argsort()]
    cell_sort[:, 0] = cell_sort[:, 0].cumsum()
    cell_sort[:, 0] /= cell_sort[-1, 0]

    # Cells from idx onward carry the top cr_thr of the probability mass
    idx = cell_sort[:, 0].searchsorted(1 - cr_thr)
    n_select = cell_sort.shape[0] - idx
    if min_n is not None:
        n_select = max(n_select, min_n)
    if max_n is not None:
        n_select = min(n_select, max_n)
    idx = cell_sort.shape[0] - n_select

    # Strip the weight column, return only the cell coordinates
    return cell_sort[idx:, 1:]

def determine_region(pt, pts, ovrlp, ovrlp_thresh, expand_prms={}):
    """
    Given a point (pt) in a set of points (pts), with a function value at
    those points (ovrlp, assumed sorted ascending), return a rectangular hull
    such that the function exceeds the value ovrlp_thresh, along with the
    index of the first point which passes the threshold.
    """
    # Everything from this index onward clears the threshold
    sidx = bisect.bisect(ovrlp, ovrlp_thresh)
    print("Found %d neighbors with overlap >= %f" % (len(ovrlp[sidx:]), ovrlp_thresh))

    region = amrlib.Cell.make_cell_from_boundaries(pt, pts[sidx:])

    # Append any externally requested parameter ranges as extra dimensions
    for _prm, lim in expand_prms.items():
        region._bounds = numpy.vstack((region._bounds, lim))
        # FIXME: Need to do center?

    if opts.distance_coordinates == "mchirp_eta":
        # Force eta to be bounded at 0.25 (and off the exact boundary)
        if region._bounds[1][1] > 0.25:
            region._bounds[1][1] = 0.25
        if region._bounds[1][0] >= 0.25:
            region._bounds[1][0] = 0.25 - 1e-7
        if region._bounds[1][0] < 0.01:
            # As of 201905 SEOBNRv4 can't handle mass ratios > 100, so don't
            # let eta drop below 0.01
            region._bounds[1][0] = 0.01

    # If ranges for Mc are specified, require bins to be in this range
    if opts.distance_coordinates == "mchirp_eta" and (opts.mc_min is not None or opts.mc_max is not None):
        mc_lo = -1 if opts.mc_min is None else opts.mc_min
        mc_hi = 99999999.0 if opts.mc_max is None else opts.mc_max
        region._bounds[0][0] = max(region._bounds[0][0], mc_lo)
        region._bounds[0][1] = min(region._bounds[0][1], mc_hi)
        if region._bounds[0][1] < mc_lo or region._bounds[0][0] > mc_hi:
            sys.exit("ERROR: All grid points are outside the specified mc-range. Are you sure you want to set --mc-min or mc-max? You probably don't.")

    return region, sidx

def find_olap_index(tree, intr_prms, exact=True, **kwargs):
    """
    Given an object that can retrieve distance via a 'query' function (e.g.
    KDTree or BallTree), find the index of a point closest to the input point.
    Note that kwargs is used to get the current known values of the event. E.g.

    intr_prms = {'mass1': 1.4, 'mass2': 1.35}
    find_olap_index(tree, **intr_prms)

    Returns a tuple (index of closest point, query point, distance to the
    closest point). If ``exact`` is True, the program exits with an error
    when the closest point is farther than ~1e-6 away, i.e. the requested
    point is not actually a member of the bank.
    """
    pt = numpy.array([kwargs[k] for k in intr_prms])

    # FIXME: Replace with standard function
    dist, m_idx = tree.query(numpy.atleast_2d(pt), k=1)
    dist, m_idx = dist[0][0], int(m_idx[0][0])

    # FIXME: There's still some tolerance from floating point conversions
    if exact and dist > 0.000001:
        # Use sys.exit rather than the site-module exit() builtin, which is
        # not guaranteed to be available (e.g. under python -S); also keeps
        # this consistent with the rest of the file.
        sys.exit("Could not find template in bank, closest pt was %f away" % dist)
    return m_idx, pt, dist

def write_to_xml(cells, intr_prms, pin_prms={}, fvals=None, fname=None, verbose=False):
    """
    Write a set of cells, with dimensions corresponding to intr_prms to an XML
    file as sim_inspiral rows.

    Parameters
    ----------
    cells : iterable of amrlib.Cell
        Cells whose centers are written out, one sim_inspiral row per cell.
    intr_prms : list of str
        Column names for the intrinsic dimensions of each cell center.
        NOTE: this list is modified in place if it contains "eff_lambda" or
        "deff_lambda" (shoehorned into the psi0/psi3 columns).
    pin_prms : dict
        Parameter name -> pinned value, written identically to every row.
    fvals : array-like or None
        Optional per-cell values (e.g. overlaps) stored in the alpha1 column.
    fname : str or None
        Output file name; defaults to "<ifos>-MASS_POINTS-<start>-1.xml.gz".
    verbose : bool
        Passed through to the ligo.lw writer.
    """
    xmldoc = ligolw.Document()
    xmldoc.appendChild(ligolw.LIGO_LW())
    # Record this program's invocation in the process table (side effect on
    # xmldoc; the returned row itself is not needed)
    process.register_to_xmldoc(xmldoc, sys.argv[0], opts.__dict__)

    rows = ["simulation_id", "numrel_data"]
    # Override eff_lambda to with psi0, its shoehorn column
    if "eff_lambda" in intr_prms:
        intr_prms[intr_prms.index("eff_lambda")] = "psi0"
    if "deff_lambda" in intr_prms:
        intr_prms[intr_prms.index("deff_lambda")] = "psi3"
    rows += list(intr_prms)
    rows += list(pin_prms)
    if fvals is not None:
        rows.append("alpha1")
    sim_insp_tbl = lsctables.New(lsctables.SimInspiralTable, rows)
    for itr, intr_prm in enumerate(cells):
        sim_insp = sim_insp_tbl.RowType()
        # FIXME: Need better IDs
        sim_insp.numrel_data = "INTR_SET_%d" % itr
        sim_insp.simulation_id = sim_insp_tbl.get_next_id()
        # BUGFIX: was "if fvals:", which is ambiguous (raises ValueError) for
        # multi-element numpy arrays; the alpha1 column is added above
        # whenever fvals is not None, so test the same condition here.
        if fvals is not None:
            sim_insp.alpha1 = fvals[itr]
        for p, v in zip(intr_prms, intr_prm._center):
            setattr(sim_insp, p, v)
        for p, v in pin_prms.items():
            setattr(sim_insp, p, v)
        sim_insp_tbl.append(sim_insp)

    xmldoc.childNodes[0].appendChild(sim_insp_tbl)
    if fname is None:
        channel_name = ["H=H", "L=L"]
        ifos = "".join([o.split("=")[0][0] for o in channel_name])
        #start = int(event_time)
        start = 0
        fname = "%s-MASS_POINTS-%d-1.xml.gz" % (ifos, start)
    utils.write_filename(xmldoc, fname, verbose=verbose)

def get_evidence_grid(points, res_pts, intr_prms, exact=False):
    """
    Associate the "z-axis" value (evidence, overlap, etc...) res_pts with its
    corresponding point in the template bank (points). If exact is True, then
    the point must exactly match the point in the bank.

    NOTE: ``exact`` and ``intr_prms`` are currently unused by this function.
    """
    # BUGFIX: the tree was previously built from the module-level "selected"
    # array instead of the "points" argument, silently coupling this function
    # to script state and indexing "points" with positions found in a
    # different array. Build the lookup structure from the argument itself.
    grid_tree = BallTree(points)
    grid_idx = []
    # Reorder the grid points to match their weight indices
    for res in res_pts:
        dist, idx = grid_tree.query(numpy.atleast_2d(res), k=1)
        # Stupid floating point inexactitude...
        #assert numpy.allclose(res, points[idx[0][0]])
        grid_idx.append(idx[0][0])
    return points[grid_idx]

#
# Plotting utilities
#
def plot_grid_cells(cells, color, axis1=0, axis2=1):
    """
    Draw each cell as an unfilled rectangle, projected onto the axis1/axis2
    dimensions of its bounds, on the current matplotlib axes.
    """
    from matplotlib.patches import Rectangle
    from matplotlib import pyplot
    axes = pyplot.gca()
    for region in cells:
        x_lo, x_hi = region._bounds[axis1]
        y_lo, y_hi = region._bounds[axis2]
        patch = Rectangle((x_lo, y_lo), x_hi - x_lo, y_hi - y_lo,
                          edgecolor=color, facecolor='none')
        axes.add_patch(patch)

# Command-line interface
argp = ArgumentParser()

argp.add_argument("-d", "--distance-coordinates", default=None, help="Coordinate system in which to calculate 'closeness'. Default is tau0_tau3.")
argp.add_argument("-n", "--no-exact-match", action="store_true", help="Loosen criteria that the input intrinsic point must be a member of the input template bank.")
argp.add_argument("-v", "--verbose", action='store_true', help="Be verbose.")

# FIXME: These two probably should only be for the initial set up. While it
# could work, in theory, for refinement, the procedure would be a bit more
# tricky.
# FIXME: This could be a single value (lock a point in) or a range (adapt across
# this is range). No argument given implies use entire known range (if
# available).
argp.add_argument("-i", "--intrinsic-param", action="append", help="Adapt in this intrinsic parameter. If a pre-existing value is known (e.g. a search template was identified), specify this parameter as -i mass1=1.4 . This will indicate to the program to choose grid points which are commensurate with this value.")
argp.add_argument("-p", "--pin-param", action="append", help="Pin the parameter to this value in the template bank. If spin is not defined, spin1z,spin2z will be pinned to 0. ")
#argp.add_argument( "--fmin-template",default=15.0, help="Min template frequency. Used in some mass transforms.") #Not implemented

grid_section = argp.add_argument_group("initial gridding options", "Options for setting up the initial grid.")
grid_section.add_argument("--setup", help="Set up the initial grid based on template bank overlaps. The new grid will be saved to this argument, e.g. --setup grid will produce a grid.hdf file.")
grid_section.add_argument("--output-xml-file-name",default="", help="Set the name of the output xml file. The default name is HL-MASS_POINTS_LEVEL_%%d-0-1.xml.gz")
grid_section.add_argument("-t", "--tmplt-bank", help="XML file with template bank.")
grid_section.add_argument("-O", "--use-overlap", action="append",help="Use overlap information to define 'closeness'. If a list of files is given, the script will find the file with the closest template, and select nearby templates only from that file.")
grid_section.add_argument("-T", "--overlap-threshold", default=0.9,type=float, help="Threshold on overlap value.")
point_specification = grid_section.add_mutually_exclusive_group()
point_specification.add_argument("-s", "--points-per-side", type=int, help="Number of points per side.")
point_specification.add_argument("--total-points", type=int, help="Requested total number of points in initial grid.  Note that actual number will only approximate this.")
grid_section.add_argument("-I", "--initial-region", action="append", help="Override the initial region with a custom specification. Specify multiple times like, -I mass1=1.0,2.0 -I mass2=1.0,1.5")
grid_section.add_argument("-D", "--deactivate", action="store_true", help="Deactivate cells initially which have no template within them.")
grid_section.add_argument("-P", "--prerefine", help="Refine this initial grid based on overlap values.")
grid_section.add_argument("--mc-min", type=float, default=None, help="Restrict chirp mass grid points to be > mc-min. This is used when generating pp-plots, so that recovered Mc isn't outside of injected prior range.It should not be used otherwise.")
grid_section.add_argument("--mc-max", type=float, default=None, help="Restrict chirp mass grid points to be < mc-max. This is used when generating pp-plots, so that recovered Mc isn't outside of injected prior range.It should not be used otherwise.")

refine_section = argp.add_argument_group("refine options", "Options for refining a pre-existing grid.")
refine_section.add_argument("--refine", help="Refine a prexisting grid. Pass this option the grid points from previous levels (or the --setup) option.")
refine_section.add_argument("-r", "--result-file", help="Input XML file containing newest result to refine.")
# BUGFIX: these two counts previously parsed as str (no type=), which breaks
# integer comparisons (e.g. the min_n/max_n clamping in get_cr_from_grid)
# under Python 3.
refine_section.add_argument("-M", "--max-n-points", type=int, help="Refine *at most* this many points, can override confidence region thresholds.")
refine_section.add_argument("-m", "--min-n-points", type=int, help="Refine *at least* this many points, can override confidence region thresholds.")

opts = argp.parse_args()

# At least one mode of operation must be requested
if not (opts.setup or opts.refine or opts.prerefine):
    # Use sys.exit rather than the site-module exit() builtin, consistent
    # with the other fatal errors in this script
    sys.exit("Either --setup or --refine or --prerefine must be chosen")

# Transparently remap the deprecated coordinate-system name onto its
# replacement so downstream code only sees the new spelling
if opts.distance_coordinates == "mu1_mu2_q_s2z":
    warnings.warn(
        "--distance-coordinates=mu1_mu2_q_s2z has been deprecated, please use "
        "--distance-coordinates=mu1_mu2_q_spin2z instead"
    )
    opts.distance_coordinates = "mu1_mu2_q_spin2z"


# Hopefully the point is already present and we can just get it, otherwise it
# could incur an overlap calculation, or suffer from the effects of being close
# only in Euclidean terms

# Parse -i/--intrinsic-param and -p/--pin-param into {name: value} mappings
intr_prms, expand_prms = common_cl.parse_param(opts.intrinsic_param)
pin_prms, _ = common_cl.parse_param(opts.pin_param)
# intr_prms is a dict at this point; extract values in sorted-key order so
# the coordinate ordering of the point is deterministic
intr_pt = numpy.array([intr_prms[k] for k in sorted(intr_prms)])
# This keeps the list of parameters consistent across runs
intr_prms = sorted(intr_prms.keys())

#If spin 1 and 2 are not specified, they are pinned. This means the spin columns still appear in the output grid.
spin_transform=None
if not "spin1z" in intr_prms or not "spin2z" in intr_prms:
    if not "spin1z" in intr_prms and not "spin2z" in intr_prms:
        # Neither spin is searched over: pin any spin the user did not
        # explicitly pin to zero
        if not "spin1z" in pin_prms:
            pin_prms["spin1z"] = 0.0
        if not "spin2z" in pin_prms:
            pin_prms["spin2z"] = 0.0
    else:
        sys.exit("spin1z or spin2z is specified but not the other spin. compute intrinsic grid is not setup to search just one")
else:
    # Both spins are intrinsic: select the spin coordinate transform used by
    # the distance-coordinate mapping
    if opts.distance_coordinates == "mu1_mu2_q_spin2z":
        spin_transform = opts.distance_coordinates
    else:
        spin_transform = "chi_z"

#
# Step 2: Set up metric space
#
# If asked, retrieve bank overlap
# You need to use the overlap if generating the initial grid, there is no other option
# Also, if initial_region = None, it won't be used anywhere
ovrlp = []
if opts.use_overlap is not None:
    # Transform and repack initial point
    intr_pt = amrlib.apply_transform(intr_pt[numpy.newaxis,:], intr_prms, opts.distance_coordinates,spin_transform)[0]
    intr_pt = dict(zip(intr_prms, intr_pt))
    # Check if there are many overlap files
    overlap_filename = ""
    if len(opts.use_overlap) > 1:
        # If yes, loop over each one and record the distance to the closest
        # template found in each file
        dists = []
        for hdf_filename in opts.use_overlap:
            print(hdf_filename)
            h5file = h5py.File(hdf_filename, "r")
            # Just use the first waveform family in the file
            wfrm_fam = next(iter(h5file))
            odata = h5file[wfrm_fam]
            ovrlp = odata["overlaps"]

            # This only needs to be done when generating the initial grid,
            # using overlaps, if overlap file doesn't have all the necessary
            # information
            if opts.tmplt_bank is not None:
                sys.exit("ERROR: Not setup to handle reading stuff from template bank when multiple overlap files provided. You shouldn't need to provide the template bank at all at this point, all the info you need is probably in the overlap file. The template bank option is only kept for backwards compatibility")

            pts = numpy.array([odata[a] for a in intr_prms]).T
            pts = amrlib.apply_transform(pts, intr_prms, opts.distance_coordinates,spin_transform)
            tree = BallTree(pts[:ovrlp.shape[0]])

            # Can't require an exact match, because we don't yet know whether
            # this is the right file
            unused_idx, unused_pt, dist = find_olap_index(tree, intr_prms, False, **intr_pt)
            dists.append(dist)
            # Everything needed has been copied out of this file; don't leak
            # the open handle
            h5file.close()

        # Select the file whose bank contains the closest template.
        # BUGFIX: this previously used numpy.where(dists == min(dists)),
        # which only worked via numpy's reflected elementwise comparison on
        # a numpy scalar and fails outright for plain Python floats; argmin
        # is the direct, robust way to get the index of the minimum.
        min_olap_file_index = int(numpy.argmin(dists))
        overlap_filename = opts.use_overlap[min_olap_file_index]
        print("File with closest template", overlap_filename)

    else:
        overlap_filename = opts.use_overlap[0]

    h5file = h5py.File(overlap_filename, "r")

    # FIXME:
    #wfrm_fam = args.waveform_type
    # Just get the first one
    wfrm_fam = next(iter(h5file))

    odata = h5file[wfrm_fam]
    ovrlp = odata["overlaps"]
    if opts.verbose:
        print("Using overlap data from %s" % wfrm_fam)

    # This only needs to be done when generating the initial grid, using
    # overlaps, if the overlap file doesn't have all the necessary information
    tmplt_bank = []
    if opts.tmplt_bank is not None:
        xmldoc_tmplt_bank = utils.load_filename(opts.tmplt_bank, contenthandler=ligolw.LIGOLWContentHandler)
        tmplt_bank = lsctables.SnglInspiralTable.get_table(xmldoc_tmplt_bank)

    if ovrlp.shape[1] != len(tmplt_bank):
        pts = numpy.array([odata[a] for a in intr_prms]).T
    else:
        # NOTE: We use the template bank here because the overlap results might not
        # have all the intrinsic information stored (e.g.: no spins, even though the
        # bank is aligned-spin).
        # FIXME: This is an oversight in the overlap calculator which was rectified
        # but this remains for legacy banks
        # FIXME: this is the only place where the template bank would possibly be used.
        pts = numpy.array([tuple(getattr(t, a) for a in intr_prms) for t in tmplt_bank])

    pts = amrlib.apply_transform(pts, intr_prms, opts.distance_coordinates,spin_transform)

    # FIXME: Can probably be moved to point index identification function -- it's
    # not used again
    # The slicing here is a slight hack to work around uberbank overlaps where the
    # overlap matrix is non square. This can be slightly dangerous because it
    # assumes the first N points are from the bank in question. That's okay for now
    # but we're getting increasingly complex in how we do construction, so we should
    # be more sophisticated by matching template IDs instead.
    tree = BallTree(pts[:ovrlp.shape[0]])

    #
    # Step 3: Get the row of the overlap matrix to work with
    #
    m_idx, pt, unused_dist_var = find_olap_index(tree, intr_prms, not opts.no_exact_match, **intr_pt)

    #
    # Rearrange data to correspond to input point
    #
    sort_order = ovrlp[m_idx].argsort()
    ovrlp = numpy.array(ovrlp[m_idx])[sort_order]

    # DANGEROUS: This assumes the (template bank) points are the same order as the
    # overlaps. While we've taken every precaution to ensure this is true, it may
    # not always be.
    pts = pts[sort_order]
    m_idx = sort_order[m_idx]

#
# Step 1: Retrieve results from previous integration
# Step 2 is before step 1 because the code was rearranged to put all the
# overlap handling together
#

# Expanded parameters are now part of the intrinsic set
intr_prms = list(intr_prms) + list(expand_prms.keys())

# Gather any results we may want to use -- this is either the evidence values
# we've calculated, or overlaps of points we've looked at
results = []
if opts.result_file:
    for arg in glob.glob(opts.result_file):
        # FIXME: Bad hardcode
        # This is here because I'm too lazy to figure out the glob syntax to
        # exclude the samples files which would be both double counting and
        # slow to load because of their potential size
        if "samples" in arg:
            continue
        xmldoc = utils.load_filename(arg, contenthandler=ligolw.LIGOLWContentHandler)

        # FIXME: The template banks we make are sim inspirals, we should
        # revisit this decision -- it isn't really helping anything
        if opts.prerefine:
            results.extend(lsctables.SimInspiralTable.get_table(xmldoc))
        else:
            results.extend(lsctables.SnglInspiralTable.get_table(xmldoc))

    # Intrinsic coordinates of every result row, mapped into the working
    # (distance) coordinate system
    res_pts = numpy.array([tuple(getattr(t, a) for a in intr_prms) for t in results])
    res_pts = amrlib.apply_transform(res_pts, intr_prms, opts.distance_coordinates,spin_transform)

    # In the prerefine case, the "result" is the overlap values, which we use as
    # a surrogate for the true evidence value.
    if opts.prerefine:
        # We only want the overlap values
        # FIXME: this needs to be done in a more consistent way
        results = numpy.array([res.alpha1 for res in results])
    else:
        # Normalize
        # We're gathering the evidence values. We normalize here so as to avoid
        # overflows later on
        # FIXME: If we have more than 1 copies -- This is tricky because we need
        # to pare down the duplicate sngl rows too
        # NOTE: the snr column holds log-evidence here; subtracting the max
        # before exponentiating keeps the sum numerically stable
        maxlnevid = numpy.max([s.snr for s in results])
        total_evid = numpy.exp([s.snr - maxlnevid for s in results]).sum()
        for res in results:
            res.snr = numpy.exp(res.snr - maxlnevid)/total_evid

        # FIXME: this needs to be done in a more consistent way
        results = numpy.array([res.snr for res in results])

#
# Build (or retrieve) the initial region
#
# Tracks whether the dense-grid truncation step (used with --total-points on
# a custom --initial-region) must be applied later on.
apply_truncation = False # Set to true if initial region truncation is needed
if opts.refine or opts.prerefine:
    grid_init_region, region_labels = amrlib.load_init_region(
        opts.refine or opts.prerefine, get_labels=True,
    )
else:
    points_per_side = opts.points_per_side
    if opts.total_points is not None:
        points_per_side = math.ceil(opts.total_points ** (1./len(intr_prms)))

    ####### BEGIN INITIAL GRID CODE #########
    if opts.initial_region is None:
        #This is the only time anything from the overlap file is used anywhere
        grid_init_region, idx = determine_region(
            pt, pts, ovrlp, opts.overlap_threshold, expand_prms
        )
        region_labels = intr_prms
#        print "init trgion",len(pts[idx:])
        # FIXME: To be reimplemented in a different way
        #if opts.expand_param is not None:
            #expand_param(grid_init_region, opts.expand_param)

    else:
        # Override initial region -- use with care
        _, boundary_init_region = common_cl.parse_param(opts.initial_region)

        if len(boundary_init_region) != len(intr_prms):
            raise ValueError(
                "Boundary and gridding coordinate systems must have the same "
                "number of dimensions."
            )

        # HACK: These labels do not actually match the data, but they mis-match
        #       in a way that is self-consistent with the rest of the code.
        region_labels = intr_prms

        # HACK: We map each input parameter to the mislabeled parameter it is
        #       associated with.
        param_mapping = {}

        def map_param(source, target):
            """
            Record that -I parameter ``source`` corresponds to the
            (mislabeled) intrinsic parameter ``target``, storing target's
            index within ``region_labels`` into ``param_mapping``.

            Raises ValueError if ``target`` is not among the intrinsic
            parameters (i.e. the matching -i option was not supplied).
            """
            try:
                # Indicate that the input parameter maps to the given index
                param_mapping[source] = region_labels.index(target)
            except ValueError:
                # The given -I parameter does not have its associated
                # -i parameter
                raise ValueError(
                    f"Cannot map {source} to {target} as {target} is missing "
                    "from the intrinsic parameters"
                )

        if opts.distance_coordinates == "mu1_mu2_q_spin2z":
            map_param("mu1", "mass1")
            map_param("mu2", "mass2")
            map_param("q", "spin1z")
            map_param("spin2z", "spin2z")
        else:
            if opts.distance_coordinates == "mchirp_eta":
                map_param("mchirp", "mass1")
                map_param("eta", "mass2")
            elif opts.distance_coordinates == "mchirp_q":
                map_param("mchirp", "mass1")
                map_param("q", "mass2")
            elif opts.distance_coordinates == "tau0_tau3":
                map_param("tau0", "mass1")
                map_param("tau3", "mass2")
            elif opts.distance_coordinates == "mass1_mass2":
                map_param("mass1","mass1")
                map_param("mass2","mass2")
            else:
                raise ValueError(
                    f"Unknown distance coordinates {opts.distance_coordinates}"
                )

            if spin_transform is None:
                # No spins to map
                pass
            elif spin_transform == "chi_z":
                map_param("chieff", "spin1z")
                map_param("chia", "spin2z")
            else:
                raise ValueError(f"Unknown spin transform {spin_transform}")

        if len(param_mapping) != len(boundary_init_region):
            raise ValueError(
                f"Provided -I parameters {set(boundary_init_region.keys())}, but "
                f"expected parameters {set(param_mapping.keys())}"
            )

        # ### OLD ###
        # NOTE: init_region has been re-named grid_init_region or
        #       boundary_init_region depending on coordinates

        # # Map the input ranges into an amrlib.Cell
        # init_region_array = numpy.empty((len(region_labels), 2))
        # for param_name, index in param_mapping.items():
        #     init_region_array[index] = init_region[param_name]
        # init_region = amrlib.Cell(init_region_array)

        ### END ###

        # There are two sets of coordinates:
        # - the prior-boundary coordinates: `boundary_init_region`
        # - the coordinates in which the grid is rectilinear:
        #   `distance_coordinates` and `spin_transform`
        #   which we'll call the "grid coordinates"
        #
        # Here we use a dense grid in the prior-boundary coordinates to
        # approximately find the limits of the grid coordinates.  In all but
        # one case, the spin coordinates are the same for both, so we only worry
        # about the mass coordinates.

        if boundary_init_region.keys() == {"mu1", "mu2", "q", "spin2z"}:
            # mu1-mu2-q-spin2z is a special case, where mass and spin are
            # combined through mu1, mu2, and spin2z, so we have to consider
            # everything.
            boundary_check_mass_coordinate_names = boundary_init_region.keys()
            spin_transform_bound = spin_transform
        else:
            # The standard case has 2 mass coordinates along with chieff, chia,
            # so we strip out the two spins as they will not impact the
            # boundaries.
            boundary_check_mass_coordinate_names = (
                boundary_init_region.keys() - {"chieff", "chia"}
            )
            assert len(boundary_check_mass_coordinate_names) == 2, \
              f"Expected 2 mass parameters, but got: " + \
              ", ".join(boundary_check_mass_coordinate_names)

            spin_transform_bound = None if spin_transform is None else 'chi_z'


        # Get a string representation of the prior-boundary coordinates
        boundary_check_mass_coordinates = "_".join(
            boundary_check_mass_coordinate_names
        )
        # Get the limits of the prior boundary, turning param_name -> [min, max]
        # into a list of each [min, max]
        boundary_values = list(boundary_init_region.values())

        # Determine how dense to make the prior-boundary coordinate grids for
        # bounds checking.  While we'd like to go very dense, in higher
        # dimensions this becomes computationally infeasible.  In the future we
        # should take a smarter approach than we do now.
        if len(boundary_init_region.keys())>2:
            warnings.warn('4 dimensions in initial boundary region. To reduce '
                          'computation, using only 7 points per dimension '
                          'to find the initial region boundaries.'
            )
            boundary_check_points_per_side = 7
        else:
            boundary_check_points_per_side = 50

        # Construct 1-D grids along each of the prior-boundary coordinate axes
        boundary_check_1d_grids = [
            np.linspace(lower_bound, upper_bound,
                        boundary_check_points_per_side)
            for lower_bound, upper_bound in boundary_values
        ]
        # Convert 1-D grids to a dense mesh grid
        boundary_check_meshgrids = np.asarray(
            np.meshgrid(*boundary_check_1d_grids, indexing="ij")
        ).T

        # Transform mesh grid in prior-boundary coordinates...
        #
        # ...first from prior-boundary coordinates to m1, m2[, spin1z, spin2z]
        boundary_check_converted = amrlib.apply_inv_transform(
            boundary_check_meshgrids,
            intr_prms,
            mass_transform=boundary_check_mass_coordinates,
            spin_transform=spin_transform_bound,
        )
        # ...then from m1, m2[, spin1z, spin2z] to the grid coordinates.
        boundary_check_converted = amrlib.apply_transform(
            boundary_check_converted,
            intr_prms,
            mass_transform=opts.distance_coordinates,
            spin_transform=spin_transform,
        )

        # Create a column matrix with the min and max of each grid coordinate
        #
        # [ min(grid_coord[0]), max(grid_coord[0]) ]
        # [        ...        ,        ...         ]
        # [ min(grid_coord[N]), max(grid_coord[N]) ]
        axes_to_min_max = tuple(range(0, len(region_labels)))
        grid_init_region_boundaries = np.column_stack((
            np.nanmin(boundary_check_converted, axis=axes_to_min_max),
            np.nanmax(boundary_check_converted, axis=axes_to_min_max),
        ))

        # Convert matrix into the initial AMR Cell
        grid_init_region = amrlib.Cell(grid_init_region_boundaries)

        # AMR requires `points_per_side`, but if `opts.total_points` was
        # specified, and truncation may have occurred due to the different
        # prior-boundary and grid coordinates, we need to find the minimal
        # `points_per_side` that meets our `opts.total_points` requirement.
        if opts.total_points is not None:
            apply_truncation = True # Ensures truncation applied later
            while True:
                # Construct 1-D grids along each of the grid coordinate axes
                total_points_check_1d_grids = [
                    np.linspace(lower_bound, upper_bound, points_per_side)
                    for lower_bound, upper_bound in grid_init_region_boundaries
                ]
                # Convert 1-D grids to a dense mesh grid
                # (one row per grid point, one column per intrinsic parameter)
                total_points_check_meshgrids = np.column_stack(
                    tuple(mesh.flatten() for mesh in np.meshgrid(
                        *total_points_check_1d_grids, indexing="ij")
                ))
                # Compute mask array which is True where the corresponding
                # grid value is physical.  Skipping this could cause issues
                # with transformation functions.
                bounds_mask = amrlib.check_grid(
                    total_points_check_meshgrids,
                    intr_prms,
                    opts.distance_coordinates,
                )

                # Transform mesh grid in grid coordinates...
                #
                # ...first from grid coordinates to m1, m2[, spin1z, spin2z]
                # NOTE(review): only the physical points survive the mask, so
                # `total_points_check_converted` has fewer rows than the mesh.
                total_points_check_converted = amrlib.apply_inv_transform(
                    total_points_check_meshgrids[bounds_mask],
                    intr_prms,
                    mass_transform=opts.distance_coordinates,
                    spin_transform=spin_transform,
                )
                # ...then from m1, m2[, spin1z, spin2z] to the prior-boundary
                # coordinates.
                total_points_check_converted = amrlib.apply_transform(
                    total_points_check_converted,
                    intr_prms,
                    mass_transform=boundary_check_mass_coordinates,
                    spin_transform=spin_transform_bound,
                )
                # Determine whether each point is within the prior-boundary
                # limits.  The reduce() takes the elementwise AND across every
                # coordinate axis, yielding one boolean per surviving point.
                total_points_check_is_in_bounds = reduce(operator.and_, [
                    (lower_bound <= points) & (points <= upper_bound)
                    for points, (lower_bound, upper_bound)
                    in zip(total_points_check_converted.T, boundary_values)
                ])
                # Count the number of points which were in our prior boundary,
                total_points = np.count_nonzero(total_points_check_is_in_bounds)
                # If the count meets the `opts.total_points` requirement then
                # we're done, otherwise try again with one more point per side.
                # The whole mesh is rebuilt each iteration; cost grows as
                # points_per_side**ndim, so this is only viable for the small
                # dimensionalities this tool targets.
                if total_points >= opts.total_points:
                    break
                else:
                    points_per_side += 1

            # Later we'll want to re-use the boundary checking we've already
            # done.  We need to know the indices in
            # `total_points_check_meshgrids` which were both physical and within
            # the prior bounds.  For that we need to combine the information in
            # `bounds_mask` and `truncation_mask` below.
            # `truncation_mask` has one entry per row of the full mesh grid
            # (zeros_like of its first column); the scatter-assignment through
            # `bounds_mask` re-expands the in-bounds flags of the physical
            # subset back onto the full mesh, leaving non-physical rows False.
            truncation_mask = (
                np.zeros_like(total_points_check_meshgrids[:,0], dtype=bool)
            )
            truncation_mask[bounds_mask] = total_points_check_is_in_bounds


        # Old code below: used incorrect labeling
#        region_labels = init_region.keys()
#        init_region = amrlib.Cell(numpy.vstack(init_region[k] for k in region_labels))

    # TODO: Alternatively, check density of points in the region to determine
    # the points to a side
    # Build the regular initial grid (and its cell spacing) over the initial
    # region, with `points_per_side` samples along each axis.
    grid, spacing = amrlib.create_regular_grid_from_cell(
        grid_init_region, side_pts=points_per_side,
        return_cells=True,
    )

    # "Deactivate" cells not close to template points
    # FIXME: This gets more and more dangerous in higher dimensions
    # FIXME: Move to function
    #tree = BallTree(grid)
    # Nearest-neighbor structure over the grid points; list() works around
    # BallTree's input handling of the grid object.
    tree = BallTree(list(grid))
    if opts.deactivate:
        # Keep only grid points that are the nearest neighbor of at least one
        # of the remaining template points.  `pts` and `idx` are defined
        # earlier in the file (outside this view) — presumably the template
        # bank points and an offset into them; TODO confirm.
        # The set dedupes grid indices hit by multiple templates.
        get_idx = set()
        for pt in pts[idx:]:
            get_idx.add(tree.query(numpy.atleast_2d(pt), k=1, return_distance=False)[0][0])
        selected = grid[numpy.array(list(get_idx))]
    else:
        selected = grid
#    print "selected",len(selected)
#    sys.exit()

# Make sure all our dimensions line up
# FIXME: We just need to be consistent from the beginning
# `reindex[i]` is the position in `region_labels` of the i-th entry of
# `intr_prms`; applying it to the columns of `res_pts` permutes the result
# points into `region_labels` column order, after which `intr_prms` is
# redefined to match that ordering.
reindex = numpy.array([list(region_labels).index(l) for l in intr_prms])
intr_prms = list(region_labels)
if opts.refine or opts.prerefine:
    res_pts = res_pts[:,reindex]

# Human-readable summary of the initial region: per-axis (low, high) extents
# and the region center, labeled with the intrinsic parameter names.
extent_str = " ".join("(%f, %f)" % bnd for bnd in map(tuple, grid_init_region._bounds))
center_str = " ".join(map(str, grid_init_region._center))
label_str = ", ".join(region_labels)
print("Initial region (" + label_str + ") has center " + center_str + " and extent " + extent_str)

#### BEGIN REFINEMENT OF RESULTS #########

if opts.result_file is not None:
    # Load the most recent (-1) grid level from the existing HDF grid file
    # (`opts.refine` / `opts.prerefine` hold the filename here — NOTE(review):
    # confirm against the option parser, which is outside this view).
    (prev_cells, spacing), level, _ = amrlib.load_grid_level(opts.refine or opts.prerefine, -1, True)

    # Work with the cell centers, mapped into the grid (distance) coordinates.
    selected = numpy.array([c._center for c in prev_cells])
    selected = amrlib.apply_transform(selected, intr_prms, opts.distance_coordinates,spin_transform)

    # Match grid points against the evidence result points loaded earlier.
    selected = get_evidence_grid(selected, res_pts, intr_prms)

    if opts.verbose:
        print("Loaded %d result points" % len(selected))

    if opts.refine:
        # FIXME: We use overlap threshold as a proxy for confidence level
        # Keep only the cells contributing to the requested confidence region,
        # subject to the optional min/max point-count clamps.
        selected = get_cr_from_grid(selected, results, cr_thr=opts.overlap_threshold, min_n=opts.min_n_points, max_n=opts.max_n_points)
        print("Selected %d cells from %3.2f%% confidence region" % (len(selected), opts.overlap_threshold*100))

if opts.prerefine:
    # Pre-refinement: keep only points whose overlap with the intrinsic point
    # exceeds the threshold, then subdivide the grid around the survivors.
    print("Performing refinement for points with overlap > %1.3f" % opts.overlap_threshold)
    pt_select = results > opts.overlap_threshold
    selected = selected[pt_select]
    results = results[pt_select]
    grid, spacing = amrlib.refine_regular_grid(selected, spacing, return_cntr=True)

elif opts.refine:
#    print "selected",len(selected)
    # Refinement proper: subdivide around the confidence-region cells chosen
    # above.  `return_cntr` follows opts.setup — NOTE(review): confirm why the
    # center-return behavior is tied to setup mode.
    grid, spacing = amrlib.refine_regular_grid(selected, spacing, return_cntr=opts.setup)
#    print "refine grid",len(grid)

print("%d cells after refinement" % len(grid))
# Remove points duplicated by refinement (neighboring cells subdivide onto
# shared boundaries) and points outside the initial region.
grid = amrlib.prune_duplicate_pts(grid, grid_init_region._bounds, spacing)

#
# Clean up
#

grid = numpy.array(grid)
# Mask out unphysical points (e.g. invalid mass/spin combinations in the
# chosen distance coordinates).
bounds_mask = amrlib.check_grid(grid, intr_prms, opts.distance_coordinates)

if apply_truncation:
    # Also apply the prior-boundary truncation computed during initial-grid
    # sizing (see `truncation_mask` construction above in this file).
    bounds_mask &= truncation_mask.flatten()

grid = grid[bounds_mask]
print("%d cells after bounds checking" % len(grid))

if len(grid) == 0:
    # sys.exit (not the site-module `exit` builtin, which is not guaranteed
    # to exist, e.g. under `python -S`) prints the message to stderr and
    # terminates with a nonzero status.
    sys.exit("All cells would be removed by physical boundaries.")

# Convert back to physical mass coordinates for output.
grid = amrlib.apply_inv_transform(grid, intr_prms, opts.distance_coordinates, spin_transform)

# Wrap the refined grid points back into cells of the given spacing.
cells = amrlib.grid_to_cells(grid, spacing)

if opts.setup:
    # Initial setup: create a fresh HDF file (appending ".hdf" if the user
    # didn't supply an extension) to hold the grid hierarchy, and store this
    # grid as level 0.
    hdf_filename = opts.setup if ".hdf" in opts.setup else opts.setup + ".hdf"
    grid_group = amrlib.init_grid_hdf(grid_init_region, hdf_filename, opts.overlap_threshold, opts.distance_coordinates, intr_prms=intr_prms)
    level = amrlib.save_grid_cells_hdf(grid_group, cells, "mass1_mass2", intr_prms=intr_prms)
else:
    # Refinement: append this grid as a new level of the existing HDF file.
    grp = amrlib.load_grid_level(opts.refine, None)
    level = amrlib.save_grid_cells_hdf(grp, cells, "mass1_mass2", intr_prms)

print("Selected %d cells for further analysis." % len(cells))

# Write the cells to XML.  An explicit --output-xml-file-name always wins;
# otherwise the level number is encoded in the default name (level 0 for a
# fresh setup, else the level just saved to the HDF file).
if opts.output_xml_file_name == "":
    fname = "HL-MASS_POINTS_LEVEL_%d-0-1.xml.gz" % (0 if opts.setup else level)
else:
    fname = opts.output_xml_file_name
write_to_xml(cells, intr_prms, pin_prms, None, fname, verbose=opts.verbose)
