#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) 2022 Ignacio Magana Hernandez
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
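This script provides a low-latency estimate of the sky localization of a given trigger using
the search pipeline output, e.g. a coinc.xml file. Only sky positions whose predicted arrival-time
differences between each pair of detectors are consistent with the measured ones are kept. The
output of this script is a sky map file in which pixels with value 1 have sky support and pixels
with value 0 do not. The sky map is written in the standard HEALPix FITS format, which makes it
compatible with most LSC codes.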
"""
from __future__ import print_function

# System imports
import os
import sys
import argparse

import numpy as np
import bilby
import healpy as hp

from ligo.lw import ligolw, lsctables                                           
from ligo.lw import utils as ligolw_utils
from ligo.skymap.io.fits import write_sky_map

# Command line options
argp = argparse.ArgumentParser()
argp.add_argument("-c", "--coinc-file", help="File path for the coinc.xml file Required.", required=True)
argp.add_argument("-n", "--nside", type=int, default=32, help="Determines the resolution for the skymap output.")
# --factor only ever divides the allowed delays, so any positive real number is meaningful.
argp.add_argument("-f", "--factor", type=float, default=1, help="Factor to which one can scale down the allowed time delay matrix tau_ij.")
argp.add_argument("-o", "--output", type=str, default='./triangulate.fits.gz', help="Output '.fits.gz' file where the skymap will be written.")
args = argp.parse_args()
# A zero factor would raise ZeroDivisionError when the delay table is built,
# and a negative one would make every allowed delay negative (empty sky map).
if args.factor <= 0:
    argp.error("--factor must be a positive number, got " + str(args.factor))

# Parse the arguments
coincfile = args.coinc_file
nside = args.nside
factor = args.factor
output = args.output

# Read in the coinc.xml file with a content handler subclass that has been
# passed through lsctables.use_in (equivalent to applying it as a decorator).
class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
    """Content handler used to load the coinc.xml document."""


LIGOLWContentHandler = lsctables.use_in(LIGOLWContentHandler)
coinc_xml_obj = ligolw_utils.load_filename(
    coincfile,
    contenthandler=LIGOLWContentHandler,
)

# Parse the coinc.xml file: take its sngl_inspiral table (one row per detector)
sngl_table = lsctables.SnglInspiralTable.get_table(coinc_xml_obj)

# Read in detectors in coinc file and create bilby IFOs
detectors = [row.ifo for row in sngl_table]
ifos = bilby.gw.detector.InterferometerList(detectors)

# Time of arrival (GPS seconds) and SNR for each detector, keyed by IFO name.
toa = {}
snr = {}
for s in sngl_table:
    # end_time_ns is an integer nanosecond count, so it must be scaled, not
    # string-concatenated: "1234" + "." + "5000000" would read as 1234.5 s
    # instead of 1234.005 s.
    toa[s.ifo] = s.end_time + s.end_time_ns * 1e-9
    snr[s.ifo] = float(s.snr)
    print("For " + s.ifo + " TOA is " + str(toa[s.ifo]) + " with SNR " + str(snr[s.ifo]))

# Prepare the healpy grid. pix2ang is called without nest=True, so pixel
# indices here follow healpy's default RING ordering.
npix = hp.pixelfunc.nside2npix(nside)
print("The number of pixels will be " + str(npix))

# Sky position of every pixel centre: healpy returns colatitude theta and
# longitude phi, used as declination (pi/2 - theta) and right ascension.
theta, phi = hp.pix2ang(nside=nside, ipix=np.arange(npix))
ra = phi
dec = np.pi/2 - theta

# For each detector, the time delay from the geocenter of a signal coming from
# every pixel, evaluated at that detector's TOA (bilby documents this as the
# detector arrival time minus the geocenter arrival time).
dts = {}
for ifo in ifos:
    dts[ifo.name] = np.array([
        ifo.time_delay_from_geocenter(ra_k, dec_k, toa[ifo.name])
        for ra_k, dec_k in zip(ra, dec)
    ])
    
# Maximum allowed |difference of implied geocenter times| per detector pair,
# in seconds. Determined through recovered search pipeline injections and
# scaled down by --factor.
delays = {'H1L1':0.010/factor,
          'L1H1':0.010/factor,
          'H1V1':0.027/factor,
          'V1H1':0.027/factor,
          'L1V1':0.026/factor,
          'V1L1':0.026/factor}

# Generate all unordered detector pairs, e.g. ['H1', 'L1', 'V1'] -> ['H1L1', 'H1V1', 'L1V1']
baselines = [a + b for idx, a in enumerate(detectors) for b in detectors[idx + 1:]]

# Calculate the allowed (ra, dec) pixels for each baseline
masks = {}
for baseline in baselines:
    if baseline not in delays:
        sys.exit("No allowed time delay is defined for baseline " + baseline
                 + "; known baselines are " + ", ".join(sorted(delays)))
    detector_1 = baseline[0:2]
    detector_2 = baseline[2:4]
    # time_delay_from_geocenter is t_ifo - t_geocenter, so a detector's TOA
    # implies the geocenter time toa - dt; using toa + dt would test the
    # antipodal sky direction. Differencing the ~1e9 s GPS times first keeps
    # the millisecond-scale delays from being rounded against them.
    delay_geocenter = np.abs((toa[detector_1] - toa[detector_2])
                             - (dts[detector_1] - dts[detector_2]))
    masks[baseline] = np.where(delay_geocenter < delays[baseline])

# Convert the above into healpy format: one integer map per baseline holding
# 1 on the pixels that baseline allows and 0 elsewhere.
skymaps = {}
for baseline in baselines:
    # masks[baseline] already holds indices into the pixel grid that ra/dec
    # were built from, so flag them directly instead of recomputing pixels
    # from (ra, dec). This avoids an ang2pix round trip (which needs
    # theta = pi/2 - dec, not pi/2 + dec, or the map is mirrored across the
    # equator) and an O(npix^2) per-pixel count.
    skymap = np.zeros(npix, dtype=int)
    skymap[masks[baseline]] = 1
    skymaps[baseline] = skymap

# A pixel keeps support only if every baseline allows it, so the per-baseline
# maps are multiplied together (with no baselines the whole sky is kept).
mtiming = np.ones(npix)
for baseline_map in (skymaps[b] for b in baselines):
    mtiming *= baseline_map

# Clamp every supported pixel to exactly 1 so the map is a 0/1 mask, then
# report the fraction of pixels that remain.
index = np.flatnonzero(mtiming > 0)
mtiming[index] = 1
ntot, nsupport = mtiming.size, index.size
print("The fraction of the sky with posterior support is " + str(nsupport/ntot))

# Save the skymap in the standard format using ligo.skymap. The grid was
# built with healpy's default RING ordering (pix2ang/ang2pix without
# nest=True), so it must be labelled RING, not NESTED, or readers will
# place every pixel at the wrong sky position.
write_sky_map(output, mtiming, nest=False)