#!/usr/bin/env python
#
# Integration of the area lying in the different CBC regions
# By A. Curiel Barroso
# August 2019

"""This script computes the area corresponding to different CBC on
the m1 & m2 plane when given a central mchirp value and uncertainty.
"""

import argparse
from pycbc.mchirp_area import calc_areas
from pycbc.mchirp_area import src_mass_from_z_det_mass
from pycbc.conversions import mass2_from_mchirp_mass1 as m2mcm1
import numpy
from matplotlib import use; use("Agg")
from matplotlib import pyplot

# ARGUMENT PARSER
parser = argparse.ArgumentParser()
# The chirp mass, the mass-plane limits and the region boundaries have no
# meaningful defaults, so they are declared as required: a missing value is
# reported by argparse with a usage message instead of surfacing later as a
# NameError/TypeError.
parser.add_argument("--central-mc", type=float, required=True,
                    help="Central value of mchirp")
parser.add_argument("--delta-mc", type=float, required=True,
                    help="Uncertainty for mchirp")
parser.add_argument("--min-m2", type=float, required=True,
                    help="Minimum value for m2")
parser.add_argument("--max-m1", type=float, required=True,
                    help="Maximum value for m1")
# The redshift is optional; each component independently falls back to 0.0
# when it is not given on the command line.
parser.add_argument("--central-z", type=float, default=0.0,
                    help="Central redshift value")
parser.add_argument("--delta-z", type=float, default=0.0,
                    help="Redshift uncertainty")
parser.add_argument("--ns-max", type=float, required=True,
                    help="Maximum neutron star mass")
parser.add_argument("--gap-max", type=float, required=True,
                    help="Minimum black hole mass")

args = parser.parse_args()

# type=float has already converted every value, so they are copied as-is.
# No truthiness tests are used here: 0.0 is a legitimate value (for example
# "--delta-z 0") and must not be mistaken for "argument not given".
central_mc = args.central_mc
delta_mc = args.delta_mc

m2_min = args.min_m2
m1_max = args.max_m1

ns_max = args.ns_max
gap_max = args.gap_max

central_z = args.central_z
delta_z = args.delta_z

# Bundle the command line values into the four dictionaries handed to
# calc_areas: chirp mass, mass-plane limits, region boundaries and redshift.
mass_limits = dict(max_m1=m1_max, min_m2=m2_min)
mass_bdary = dict(ns_max=ns_max, gap_max=gap_max)
z = dict(central=central_z, delta=delta_z)
trig_mc = dict(central=central_mc, delta=delta_mc)

areas = calc_areas(trig_mc, mass_limits, mass_bdary, z)

# One "a<region> = <value>" line per region, in a fixed order.
for region in ("bbh", "bhg", "gg", "nsbh", "gns", "bns"):
    print("a" + region + " = " + str(areas[region]))

# PLOT GENERATION
# Element 0 of the result is used below as the central chirp mass and
# element 1 as the half-width of the band around it.
src_mchirp = src_mass_from_z_det_mass(central_z, delta_z,
                                      central_mc, delta_mc)

mc_upper = src_mchirp[0] + src_mchirp[1]
mc_lower = src_mchirp[0] - src_mchirp[1]

# A chirp mass curve crosses the equal mass line at
# m1 = m2 = (2**0.2)*mchirp
equal_mass_factor = 2**0.2
m_eq_upper = equal_mass_factor * mc_upper
m_eq_lower = equal_mass_factor * mc_lower

# Each curve is sampled from the equal mass line up to whichever comes
# first: the m1 limit, or the m1 at which the curve reaches the m2 limit.
m1_end_upper = min(m1_max, m2mcm1(mc_upper, m2_min))
m1_upper = numpy.linspace(m_eq_upper, m1_end_upper, num=100)
m2_upper = m2mcm1(mc_upper, m1_upper)

m1_end_lower = min(m1_max, m2mcm1(mc_lower, m2_min))
m1_lower = numpy.linspace(m_eq_lower, m1_end_lower, num=100)
m2_lower = m2mcm1(mc_lower, m1_lower)

band_style = "b"

# m2 of the lower curve at its last sampled m1; both branches below start
# the vertical edge at m1_max from this value.
m2_lower_end = m2mcm1(mc_lower, m1_end_lower)

if m_eq_upper > m1_max:
    # The upper curve starts beyond the m1 limit, so only the vertical edge
    # is drawn, reaching up to the equal mass line.
    pyplot.plot((m1_max, m1_max), (m2_lower_end, m1_max), band_style)
else:
    pyplot.plot(m1_upper, m2_upper, band_style)
    m2_upper_end = m2mcm1(mc_upper, m1_end_upper)
    pyplot.plot((m1_max, m1_max), (m2_lower_end, m2_upper_end), band_style)

if m_eq_lower >= m2_min:
    pyplot.plot(m1_lower, m2_lower, band_style)
    pyplot.plot((m1_end_lower, m1_end_upper), (m2_min, m2_min), band_style)
else:
    # The lower curve starts below the m2 limit: close the region along the
    # m2 = m2_min line starting from the equal mass point.
    pyplot.plot((m2_min, m1_end_upper), (m2_min, m2_min), band_style)

# Equal mass diagonal (dashed), then the dotted lines marking the ns_max and
# gap_max boundaries: first the vertical ones, then the horizontal ones.
pyplot.plot((m2_min, m1_max), (m2_min, m1_max), "k--")
for boundary in (ns_max, gap_max):
    pyplot.plot((boundary, boundary), (m2_min, boundary), "k:")
for boundary in (ns_max, gap_max):
    pyplot.plot((boundary, m1_max), (boundary, boundary), "k:")

pyplot.xlabel("M1")
pyplot.ylabel("M2")

# The title reports the centre and half-width of the plotted chirp mass band.
mc_centre = 0.5 * (mc_upper + mc_lower)
mc_half_width = (mc_upper - mc_lower) * 0.5
pyplot.title("MChirp = " + str(mc_centre) + " +/- " + str(mc_half_width))
pyplot.savefig("mass_plot.png")
