#!/usr/bin/env python3
import numpy as np
from subscript.defaults import ParamKeys
from numpy import testing as npt
from subscript.external import symphony_to_galacticus_like_dict, KEY_MAP_SYMPHONY_DEFAULT


def test_symphony_conversion():
    """Check ``symphony_to_galacticus_like_dict`` on a small mock Symphony tree.

    ``h_mock`` mimics a Symphony halo catalogue. After the final ``.T`` it has
    shape (11 halos, 2 snapshots). At the last snapshot (``isnap=-1``), halos
    0-9 carry valid data with ``ok=True``. Halo 10 is a ``-1``/``ok=False``
    placeholder, so the converted output is expected to hold 10 entries.
    ``hist_mock`` holds one history record per halo (11 rows). ``z_snap`` is
    indexed by snapshot number.
    """
    h_mock = np.array([[(       -1, -1.0000000e+00, -1.  , -1.       , [  -1.      ,   -1.      ,   -1.      ], [  -1.      ,   -1.      ,   -1.      ], False, -1.       , -1.      ),
        (       -1, -1.0000000e+00, -1.  , -1.       , [  -1.      ,   -1.      ,   -1.      ], [  -1.      ,   -1.      ,   -1.      ], False, -1.       , -1.      ),
        (       -1, -1.0000000e+00, -1.  , -1.       , [  -1.      ,   -1.      ,   -1.      ], [  -1.      ,   -1.      ,   -1.      ], False, -1.       , -1.      ),
        (       -1, -1.0000000e+00, -1.  , -1.       , [  -1.      ,   -1.      ,   -1.      ], [  -1.      ,   -1.      ,   -1.      ], False, -1.       , -1.      ),
        (       -1, -1.0000000e+00, -1.  , -1.       , [  -1.      ,   -1.      ,   -1.      ], [  -1.      ,   -1.      ,   -1.      ], False, -1.       , -1.      ),
        (       -1, -1.0000000e+00, -1.  , -1.       , [  -1.      ,   -1.      ,   -1.      ], [  -1.      ,   -1.      ,   -1.      ], False, -1.       , -1.      ),
        (       -1, -1.0000000e+00, -1.  , -1.       , [  -1.      ,   -1.      ,   -1.      ], [  -1.      ,   -1.      ,   -1.      ], False, -1.       , -1.      ),
        (       -1, -1.0000000e+00, -1.  , -1.       , [  -1.      ,   -1.      ,   -1.      ], [  -1.      ,   -1.      ,   -1.      ], False, -1.       , -1.      ),
        (       -1, -1.0000000e+00, -1.  , -1.       , [  -1.      ,   -1.      ,   -1.      ], [  -1.      ,   -1.      ,   -1.      ], False, -1.       , -1.      ),
        (       -1, -1.0000000e+00, -1.  , -1.       , [  -1.      ,   -1.      ,   -1.      ], [  -1.      ,   -1.      ,   -1.      ], False, -1.       , -1.      ),
        (       -1, -1.0000000e+00, -1.  , -1.       , [  -1.      ,   -1.      ,   -1.      ], [  -1.      ,   -1.      ,   -1.      ], False, -1.       , -1.      ),],
       [(161191255,  1.1802858e+08, 10.73,  1.4001675, [-352.33643 , -466.7435  ,  399.61328 ], [ 230.34    ,   86.240005, -149.66    ],  True, 11.92505  , 31.378775),
        (161763033,  1.1762858e+08, 10.58,  2.1060712, [-314.90338 , -456.68347 ,  377.55533 ], [ 239.53    ,   99.67    , -161.12    ],  True, 12.00846  , 30.496603),
        (162334212,  1.1682858e+08, 10.83,  1.5075221, [-274.99142 , -444.41937 ,  353.44006 ], [ 244.91    ,  112.77    , -170.87    ],  True, 12.077489 , 33.516068),
        (162905122,  1.1601429e+08, 10.75,  1.677781 , [-235.31105 , -429.24124 ,  326.8972  ], [ 249.5     ,  126.96    , -182.73001 ],  True, 12.145057 , 33.39693 ),
        (163473737,  1.1601429e+08, 11.17,  1.9568021, [-192.91576 , -411.1923  ,  297.21194 ], [ 255.76001 ,  143.56999 , -195.86    ],  True, 12.240316 , 38.10785 ),
        (164041230,  1.1641429e+08, 10.75,  1.6198878, [-149.09872 , -389.31595 ,  264.40155 ], [ 263.41    ,  166.73001 , -210.83    ],  True, 12.349269 , 34.10137 ),
        (164614912,  1.1641429e+08, 10.82,  1.5401733, [-103.52247 , -361.36618 ,  227.9414  ], [ 272.7     ,  195.22    , -228.29    ],  True, 12.4436655, 35.222027),
        (165176558,  1.1601429e+08, 10.94,  1.2020737, [ -55.2754  , -326.9403  ,  187.49542 ], [ 283.72    ,  230.87999 , -247.72    ],  True, 12.523146 , 37.016132),
        (165729990,  1.1521429e+08, 10.46,  1.8375131, [  -4.342098, -285.045   ,  142.50906 ], [ 291.55    ,  276.11002 , -272.67    ],  True, 12.587252 , 32.767822),
        (166275341,  1.1360000e+08,  9.95,  1.1902572, [  48.24502 , -232.74557 ,   91.75982 ], [ 291.81    ,  339.43    , -304.88998 ],  True, 12.620098 , 28.667397),
        (       -1, -1.0000000e+00, -1.  , -1.       , [  -1.      ,   -1.      ,   -1.      ], [  -1.      ,   -1.      ,   -1.      ], False, -1.       , -1.      ),]],
      dtype=[('id', '<i4'), ('mvir', '<f4'), ('vmax', '<f4'), ('rvmax', '<f4'), ('x', '<f4', (3,)), ('v', '<f4', (3,)), ('ok', '?'), ('rvir', '<f4'), ('cvir', '<f4')]).T

    hist_mock = np.array([(1.49857152e+08, 22.2 ,  81, 6.6185818e-04,  13600455,  13600542,  True, False,  True,      -1,  81,  232141, False, 1.21257144e+08, 0, 0, 0, 0, 0, 0),
       (1.24471432e+08, 13.12, 134, 6.8571426e-05,  30015901,  30016099,  True, False,  True,      -1, 134,  569486, False, 1.22057144e+08, 0, 0, 0, 0, 0, 0),
       (1.21257144e+08, 12.83, 157, 2.1584741e-05,  29358255,  29358454,  True, False,  True,      -1, 157,  561426, False, 1.21257144e+08, 0, 0, 0, 0, 0, 0),
       (1.21257144e+08, 12.02, 225, 9.6484882e-06,  27526539,  27526714,  True, False,  True,      -1, 225,  536412, False, 1.21257144e+08, 0, 0, 0, 0, 0, 0),
       (1.21257144e+08, 13.89, 168, 4.8911315e-06,  31447571,  31447779,  True, False,  True,  307964, 127,  585027, False, 1.21257144e+08, 0, 0, 0, 0, 0, 0),
       (1.24071432e+08, 13.45, 154, 2.6225385e-05,  29468890,  29469053,  True, False,  True,  272054, 153,  562801, False, 1.23671432e+08, 0, 0, 0, 0, 0, 0),
       (1.21257144e+08, 12.65, 206, 8.1003436e-06, 105872378, 105872578,  True, False,  True, 1769091, 169, 1877999, False, 1.21257144e+08, 0, 0, 0, 0, 0, 0),
       (1.21257144e+08, 16.5 , 140, 1.9619740e-05,  14126887,  14127033,  True, False,  True,  562239, 123,  247818, False, 1.21257144e+08, 0, 0, 0, 0, 0, 0),
       (1.21257144e+08, 15.84, 165, 1.7540825e-06,  16509517,  16509676,  True, False,  True,  303394, 108,  307590, False, 1.21257144e+08, 0, 0, 0, 0, 0, 0),
       (1.21257144e+08, 11.88, 230, 9.2452192e-06,  27391623,  27391782,  True, False,  True,      -1, 230,  534378, False, 1.21257144e+08, 0, 0, 0, 0, 0, 0),
       (1.21257144e+08, 15.84, 165, 1.7540825e-06,  16509517,  16509676,  True, False,  True,  303394, 108,  307590, False, 1.21257144e+08, 0, 0, 0, 0, 0, 0),],
      dtype=[('mpeak', '<f4'), ('vpeak', '<f4'), ('merger_snap', '<i4'), ('merger_ratio', '<f4'), ('start', '<i4'), ('end', '<i4'), ('is_real', '?'), ('is_disappear', '?'), ('is_main_sub', '?'), ('preprocess', '<i4'), ('first_infall_snap', '<i4'), ('branch_idx', '<i4'), ('false_selection', '?'), ('mpeak_pre', '<f4'), ('conv_snap_discrete', '<i4'), ('conv_snap_eps', '<i4'), ('conv_snap_relax', '<i4'), ('conv_snap_relax_hydro', '<i4'), ('disrupt_snap', '<i4'), ('disrupt_snap_rs', '<i4')])

    z_snap = np.array([1.90000003e+01, 1.87466623e+01, 1.84965337e+01, 1.82495730e+01,
       1.80057406e+01, 1.77649972e+01, 1.75273028e+01, 1.72926197e+01,
       1.70609089e+01, 1.68321335e+01, 1.66062556e+01, 1.63832393e+01,
       1.61630475e+01, 1.59456452e+01, 1.57309964e+01, 1.55190669e+01,
       1.53098215e+01, 1.51032269e+01, 1.48992489e+01, 1.46978550e+01,
       1.44990119e+01, 1.43026877e+01, 1.41088501e+01, 1.39174681e+01,
       1.37285100e+01, 1.35419455e+01, 1.33577443e+01, 1.31758762e+01,
       1.29963121e+01, 1.28190222e+01, 1.26439783e+01, 1.24711514e+01,
       1.23005139e+01, 1.21320376e+01, 1.19656956e+01, 1.18014605e+01,
       1.16393059e+01, 1.14792051e+01, 1.13211324e+01, 1.11650619e+01,
       1.10109685e+01, 1.08588267e+01, 1.07086124e+01, 1.05603006e+01,
       1.04138676e+01, 1.02692893e+01, 1.01265423e+01, 9.98560373e+00,
       9.84645019e+00, 9.70905948e+00, 9.57340890e+00, 9.43947677e+00,
       9.30724097e+00, 9.17668036e+00, 9.04777337e+00, 8.92049941e+00,
       8.79483744e+00, 8.67076738e+00, 8.54826874e+00, 8.42732195e+00,
       8.30790701e+00, 8.19000484e+00, 8.07359598e+00, 7.95866180e+00,
       7.84518333e+00, 7.73314229e+00, 7.62252060e+00, 7.51330000e+00,
       7.40546303e+00, 7.29899187e+00, 7.19386952e+00, 7.09007859e+00,
       6.98760252e+00, 6.88642436e+00, 6.78652795e+00, 6.68789678e+00,
       6.59051510e+00, 6.49436681e+00, 6.39943654e+00, 6.30570863e+00,
       6.21316807e+00, 6.12179960e+00, 6.03158860e+00, 5.94252017e+00,
       5.85458009e+00, 5.76775381e+00, 5.68202735e+00, 5.59738690e+00,
       5.51381846e+00, 5.43130869e+00, 5.34984395e+00, 5.26941122e+00,
       5.18999722e+00, 5.11158925e+00, 5.03417436e+00, 4.95774019e+00,
       4.88227409e+00, 4.80776402e+00, 4.73419765e+00, 4.66156325e+00,
       4.58984879e+00, 4.51904284e+00, 4.44913368e+00, 4.38011014e+00,
       4.31196082e+00, 4.24467484e+00, 4.17824107e+00, 4.11264881e+00,
       4.04788749e+00, 3.98394641e+00, 3.92081534e+00, 3.85848387e+00,
       3.79694203e+00, 3.73617965e+00, 3.67618701e+00, 3.61695422e+00,
       3.55847181e+00, 3.50073010e+00, 3.44371988e+00, 3.38743173e+00,
       3.33185665e+00, 3.27698546e+00, 3.22280938e+00, 3.16931948e+00,
       3.11650720e+00, 3.06436382e+00, 3.01288100e+00, 2.96205024e+00,
       2.91186335e+00, 2.86231223e+00, 2.81338871e+00, 2.76508497e+00,
       2.71739301e+00, 2.67030523e+00, 2.62381385e+00, 2.57791142e+00,
       2.53259038e+00, 2.48784347e+00, 2.44366331e+00, 2.40004283e+00,
       2.35697483e+00, 2.31445242e+00, 2.27246859e+00, 2.23101661e+00,
       2.19008965e+00, 2.14968116e+00, 2.10978446e+00, 2.07039319e+00,
       2.03150083e+00, 1.99310111e+00, 1.95518785e+00, 1.91775479e+00,
       1.88079593e+00, 1.84430518e+00, 1.80827670e+00, 1.77270454e+00,
       1.73758302e+00, 1.70290633e+00, 1.66866893e+00, 1.63486517e+00,
       1.60148964e+00, 1.56853683e+00, 1.53600148e+00, 1.50387820e+00,
       1.47216187e+00, 1.44084725e+00, 1.40992932e+00, 1.37940299e+00,
       1.34926333e+00, 1.31950549e+00, 1.29012455e+00, 1.26111581e+00,
       1.23247448e+00, 1.20419599e+00, 1.17627566e+00, 1.14870903e+00,
       1.12149155e+00, 1.09461887e+00, 1.06808654e+00, 1.04189033e+00,
       1.01602591e+00, 9.90489147e-01, 9.65275821e-01, 9.40381903e-01,
       9.15803281e-01, 8.91536026e-01, 8.67576129e-01, 8.43919762e-01,
       8.20563016e-01, 7.97502127e-01, 7.74733379e-01, 7.52253010e-01,
       7.30057427e-01, 7.08142963e-01, 6.86506117e-01, 6.65143314e-01,
       6.44051139e-01, 6.23226108e-01, 6.02664894e-01, 5.82364099e-01,
       5.62320479e-01, 5.42530722e-01, 5.22991667e-01, 5.03700085e-01,
       4.84652892e-01, 4.65846944e-01, 4.47279233e-01, 4.28946692e-01,
       4.10846392e-01, 3.92975343e-01, 3.75330664e-01, 3.57909512e-01,
       3.40709009e-01, 3.23726406e-01, 3.06958896e-01, 2.90403801e-01,
       2.74058386e-01, 2.57920038e-01, 2.41986092e-01, 2.26254000e-01,
       2.10721164e-01, 1.95385101e-01, 1.80243278e-01, 1.65293275e-01,
       1.50532623e-01, 1.35958961e-01, 1.21569883e-01, 1.07363089e-01,
       9.33362322e-02, 7.94870706e-02, 6.58133165e-02, 5.23127664e-02,
       3.89832442e-02, 2.58225479e-02, 1.28285742e-02, 4.44089210e-16])

    # Single source of truth for the tree index passed in and checked below.
    tree_index = 2
    n_expected = 10

    out = symphony_to_galacticus_like_dict(
                                           sim_data=(h_mock, hist_mock),
                                           z_snap=z_snap,
                                           key_map=KEY_MAP_SYMPHONY_DEFAULT,
                                           isnap=-1,
                                           tree_index=tree_index
                                          )
    # Ensure output size is correct: halo 10 (ok=False at the last snapshot)
    # is expected to be dropped.
    npt.assert_equal(out[ParamKeys.mass_basic].shape, (n_expected,))

    # Ensure outputs are expected.
    # Basic masses equal hist_mock['mpeak'][:10].
    mass_basic_expect = np.asarray((1.49857152e+08, 1.24471432e+08, 1.21257144e+08, 1.21257144e+08, 1.21257144e+08, 1.24071432e+08, 1.21257144e+08, 1.21257144e+08, 1.21257144e+08, 1.21257144e+08))
    npt.assert_allclose(out[ParamKeys.mass_basic], mass_basic_expect)

    # Bound masses equal the `mvir` values of halos 0-9 at the last snapshot.
    mass_bound_expect = np.asarray((1.1802858e+08, 1.1762858e+08, 1.1682858e+08, 1.1601429e+08, 1.1601429e+08, 1.1641429e+08, 1.1641429e+08, 1.1601429e+08, 1.1521429e+08, 1.1360000e+08))
    npt.assert_allclose(out[ParamKeys.mass_bound], mass_bound_expect)

    # Radii and positions are the last-snapshot `rvir` / `x` values scaled by
    # 1e-3 (presumably a kpc -> Mpc unit conversion; confirm against the key map).
    rvir_expect = 1E-3 * np.asarray((11.92505, 12.00846, 12.077489, 12.145057, 12.240316, 12.349269, 12.4436655, 12.523146, 12.587252, 12.620098))
    npt.assert_allclose(out[ParamKeys.rvir], rvir_expect)

    r_expect = 1E-3 * np.asarray(([-352.33643 , -466.7435  ,  399.61328 ],
                                  [-314.90338 , -456.68347 ,  377.55533 ],
                                  [-274.99142 , -444.41937 ,  353.44006 ],
                                  [-235.31105 , -429.24124 ,  326.8972  ],
                                  [-192.91576 , -411.1923  ,  297.21194 ],
                                  [-149.09872 , -389.31595 ,  264.40155 ],
                                  [-103.52247 , -361.36618 ,  227.9414  ],
                                  [ -55.2754  , -326.9403  ,  187.49542 ],
                                  [  -4.342098, -285.045   ,  142.50906 ],
                                  [  48.24502 , -232.74557 ,   91.75982 ]))

    npt.assert_allclose(out[ParamKeys.x], r_expect[:, 0])
    npt.assert_allclose(out[ParamKeys.y], r_expect[:, 1])
    npt.assert_allclose(out[ParamKeys.z], r_expect[:, 2])

    # These values equal z_snap[hist_mock['merger_snap'][:10]].
    z_lastisolated_expect = np.array([6.1217996 , 2.62381385, 1.70290633, 0.13595896, 1.34926333,
       1.8082767 , 0.44727923, 2.35697483, 1.44084725, 0.06581332])
    npt.assert_allclose(out[ParamKeys.z_lastisolated], z_lastisolated_expect)

    # Flags, IDs and tree indices are exact integers and must match exactly.
    # assert_allclose's default rtol=1e-7 would accept an error of ~16 at the
    # ~1.6e8 magnitude of these halo IDs.
    iso_expect = np.ones(n_expected, dtype=int)
    npt.assert_array_equal(out[ParamKeys.is_isolated], iso_expect)

    id_expect = np.asarray((161191255, 161763033, 162334212, 162905122, 163473737, 164041230, 164614912, 165176558, 165729990, 166275341))
    npt.assert_array_equal(out['custom_id'], id_expect)

    tree_expect = np.full(n_expected, tree_index, dtype=int)
    npt.assert_array_equal(out['custom_node_tree'], tree_expect)






if __name__ == "__main__":
    # Allow running this file directly (``python <file>``) without pytest.
    test_symphony_conversion()
