"""
 Copyright 2024 - The Minton Group at Purdue University
 This file is part of Swiftest.
 Swiftest is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License 
 as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
 Swiftest is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty 
 of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 You should have received a copy of the GNU General Public License along with Swiftest. 
 If not, see: https://www.gnu.org/licenses. 
"""
import swiftest
import unittest
import os
import tempfile
import numpy as np
import warnings
# Promote every RuntimeWarning (e.g. numpy floating-point warnings such as division by zero or
# invalid values) to an exception so numerical problems make the tests fail instead of passing
# silently. NOTE: this changes the process-wide warnings filters at import time, not just for
# this module's tests; the individual tests also re-apply the same filter locally.
warnings.simplefilter('error', RuntimeWarning)



class TestFraggle(unittest.TestCase):
    """Integration tests for the Fraggle collision model in Swiftest.

    Each test builds a Sun + target + projectile system, runs a short integration with
    ``collision_model="fraggle"``, and checks what the simulation produced.
    """

    def setUp(self):
        # Each test gets its own temporary directory, used as the simulation directory.
        self.tmpdir = tempfile.TemporaryDirectory()
        self.simdir = self.tmpdir.name

    def tearDown(self):
        # Remove the temporary simulation directory and everything written into it.
        self.tmpdir.cleanup()

    def test_collision_outcomes(self):
        '''
        Check that each collision scenario produces the expected outcome in ``collisions.log``.

        For every style in ``collision_type`` a fresh simulation is run and ``collisions.log`` is
        checked for:

        * the expected outcome message (``expected_outcome``),
        * the expected collision regime, read from the first ``Regime:`` line (``expected_regime``),
        * a count of lines containing ``Newbody`` within
          [``nfrag_minimum_expected``, ``nfrag_maximum_expected``].

        ``compute_conservation_values=True`` is passed to the simulation, but conservation is not
        asserted by this test.
        '''
        # Collision initial conditions taken from Fragmentation_Movie.py in the example directory

        collision_type = ["disruption_headon", "disruption_off_axis", "supercatastrophic_headon", "supercatastrophic_off_axis","hitandrun_disrupt", "hitandrun_pure", "merge", "merge_spinner"]

        # ----------------------------------------------------------------------------------------------------------------------
        # To increase the number of bodies generated in each collision type, decrease the value of the corresponding nfrag_reduction number 
        # ----------------------------------------------------------------------------------------------------------------------
        nfrag_reduction = {"disruption_headon" : 10.0,
                "disruption_off_axis"         : 10.0,
                "supercatastrophic_headon"    : 10.0,
                "supercatastrophic_off_axis"  : 10.0,
                "hitandrun_disrupt"           : 10.0,
                "hitandrun_pure"              : 1.0,
                "merge"                       : 1.0,
                "merge_spinner"               : 1.0,
                }


        # These initial conditions were generated by trial and error
        names = ["Target","Projectile"]
        pos_vectors = {"disruption_headon"         : [np.array([1.0, -5.0e-05, 0.0]),
                                                    np.array([1.0,  5.0e-05 ,0.0])],
                    "disruption_off_axis"        : [np.array([1.0, -5.0e-05, 0.0]),
                                                    np.array([1.0,  5.0e-05 ,0.0])], 
                    "supercatastrophic_headon":   [np.array([1.0, -5.0e-05, 0.0]),
                                                    np.array([1.0,  5.0e-05, 0.0])],
                    "supercatastrophic_off_axis": [np.array([1.0, -5.0e-05, 0.0]),
                                                    np.array([1.0,  5.0e-05, 0.0])],
                    "hitandrun_disrupt"         : [np.array([1.0, -4.2e-05, 0.0]),
                                                    np.array([1.0,  4.2e-05, 0.0])],
                    "hitandrun_pure"            : [np.array([1.0, -4.2e-05, 0.0]),
                                                    np.array([1.0,  4.2e-05, 0.0])],
                    "merge"                      : [np.array([1.0, -5.0e-05, 0.0]),
                                                    np.array([1.0,  5.0e-05 ,0.0])],
                    "merge_spinner"               : [np.array([1.0, -5.0e-05, 0.0]),
                                                    np.array([1.0,  5.0e-05 ,0.0])]                
                    }

        vel_vectors = {"disruption_headon"         : [np.array([ 0.00,  6.280005, 0.0]),
                                                    np.array([ 0.00,  3.90,     0.0])],
                    "disruption_off_axis"       : [np.array([ 0.00,  6.280005, 0.0]),
                                                    np.array([ 0.05,  3.90,     0.0])],
                    "supercatastrophic_headon":   [np.array([ 0.00,  6.28,     0.0]),
                                                    np.array([ 0.00,  4.28,     0.0])],
                    "supercatastrophic_off_axis": [np.array([ 0.00,  6.28,     0.0]),
                                                    np.array([ 0.05,  4.28,     0.0])],
                    "hitandrun_disrupt"         : [np.array([ 0.00,  6.28,     0.0]),
                                                    np.array([-1.45, -6.28,     0.0])],
                    "hitandrun_pure"            : [np.array([ 0.00,  6.28,     0.0]),
                                                    np.array([-1.52, -6.28,     0.0])],
                    "merge"                     : [np.array([ 0.04,  6.28,     0.0]),
                                                    np.array([ 0.05,  6.18,     0.0])],
                    "merge_spinner"             : [np.array([ 0.04,  6.28,     0.0]),
                                                    np.array([ 0.05,  6.18,     0.0])] 
                    }

        rot_vectors = {"disruption_headon"         : [np.array([0.0, 0.0, 1.0e5]),
                                                    np.array([0.0, 0.0, -5e5])],
                    "disruption_off_axis":        [np.array([0.0, 0.0, 2.0e5]),
                                                    np.array([0.0, 0.0, -1.0e5])],
                    "supercatastrophic_headon":   [np.array([0.0, 0.0, 1.0e5]),
                                                    np.array([0.0, 0.0, -5.0e5])],
                    "supercatastrophic_off_axis": [np.array([0.0, 0.0, 1.0e5]),
                                                    np.array([0.0, 0.0, -1.0e4])],
                    "hitandrun_disrupt"         : [np.array([0.0, 0.0, 0.0]),
                                                    np.array([0.0, 0.0, 1.0e5])],
                    "hitandrun_pure"            : [np.array([0.0, 0.0, 0.0]),
                                                    np.array([0.0, 0.0, 1.0e5])],
                    "merge"                     : [np.array([0.0, 0.0, 0.0]),
                                                    np.array([0.0, 0.0, 0.0])],
                    "merge_spinner"             : [np.array([0.0, 0.0, -1.2e6]),
                                                    np.array([0.0, 0.0, 0.0])],
                    }

        body_Gmass = {"disruption_headon"        : [1e-7, 1e-9],
                    "disruption_off_axis"       : [1e-7, 1e-9],
                    "supercatastrophic_headon"  : [1e-7, 1e-8],
                    "supercatastrophic_off_axis": [1e-7, 1e-8],
                    "hitandrun_disrupt"         : [1e-7, 7e-10],
                    "hitandrun_pure"            : [1e-7, 7e-10],
                    "merge"                     : [1e-7, 1e-8],
                    "merge_spinner"             : [1e-7, 1e-8] 
                    }

        tstop = {"disruption_headon"         : 2.0e-3,
                "disruption_off_axis"       : 2.0e-3,
                "supercatastrophic_headon"  : 2.0e-3,
                "supercatastrophic_off_axis": 2.0e-3,
                "hitandrun_disrupt"         : 2.0e-4,
                "hitandrun_pure"            : 2.0e-4,
                "merge"                     : 5.0e-3,
                "merge_spinner"             : 5.0e-3,
                }

        nfrag_minimum_expected = {"disruption_headon"         : 20,
                                "disruption_off_axis"       : 20,
                                "supercatastrophic_headon"  : 20,
                                "supercatastrophic_off_axis"  : 20,
                                "hitandrun_disrupt"         : 20,
                                "hitandrun_pure"            : 0,
                                "merge"                     : 0,
                                "merge_spinner"             : 0,
                                }

        nfrag_maximum_expected = {"disruption_headon"         : 100,
                                "disruption_off_axis"       : 100,
                                "supercatastrophic_headon"  : 100,
                                "supercatastrophic_off_axis"  : 100,
                                "hitandrun_disrupt"         : 100,
                                "hitandrun_pure"            : 0,
                                "merge"                     : 0,
                                "merge_spinner"             : 0,
                                }

        expected_regime = {"disruption_headon"         : "Disruption",
                        "disruption_off_axis"       : "Disruption",
                        "supercatastrophic_headon"  : "Supercatastrophic disruption",
                        "supercatastrophic_off_axis" : "Supercatastrophic disruption",
                        "hitandrun_disrupt"         : "Hit and run",
                        "hitandrun_pure"            : "Hit and run",
                        "merge"                   : "Merge",
                        "merge_spinner"           : "Hit and run"
                        }
        expected_outcome =    {"disruption_headon"         : "calculation converged",
                        "disruption_off_axis"       : "calculation converged",
                        "supercatastrophic_headon"  : "calculation converged",
                        "supercatastrophic_off_axis" : "calculation converged",
                        "hitandrun_disrupt"         : "calculation converged",
                        "hitandrun_pure"            : "No new fragments generated",
                        "merge"                   : "Merging",
                        "merge_spinner"           : "No new fragments generated"
                        }

        # Radii of uniform-density spheres: 3000 (presumably kg/m^3) converted to the code's
        # mass/length units, with Gmass / GU converting G*M back to a mass.
        density = 3000 * swiftest.AU2M**3 / swiftest.MSun
        GU = swiftest.GMSun * swiftest.YR2S**2 / swiftest.AU2M**3
        body_radius = {
            k: [((Gmass / GU) / (4. / 3. * np.pi * density)) ** (1. / 3.) for Gmass in v]
            for k, v in body_Gmass.items()
        }

        # The hit-and-run scenarios use hand-picked radii instead of the density-derived ones.
        body_radius["hitandrun_disrupt"] = [7e-6, 3.25e-6]
        body_radius["hitandrun_pure"] = [7e-6, 3.25e-6]

        with warnings.catch_warnings():
            warnings.simplefilter("error", RuntimeWarning)
            for style in collision_type:
                sim = swiftest.Simulation(simdir=self.simdir, rotation=True, compute_conservation_values=True)
                sim.add_solar_system_body("Sun")
                sim.add_body(name=names, Gmass=body_Gmass[style], radius=body_radius[style], rh=pos_vectors[style], vh=vel_vectors[style], rot=rot_vectors[style])

                # Fragmentation parameters are scaled to the projectile's G*mass.
                minimum_fragment_gmass = 0.01 * body_Gmass[style][1]
                gmtiny = 0.50 * body_Gmass[style][1]
                sim.set_parameter(collision_model="fraggle",
                                encounter_save="both",
                                gmtiny=gmtiny,
                                minimum_fragment_gmass=minimum_fragment_gmass,
                                nfrag_reduction=nfrag_reduction[style])
                sim.run(dt=tstop[style]/4, tstop=tstop[style], istep_out=1, dump_cadence=0)

                collision_logfile = os.path.join(self.simdir, "collisions.log")
                with open(collision_logfile, "r") as f:
                    content = f.read()

                self.assertIn(expected_outcome[style], content, f'{style}: The collisions.log file does not contain "{expected_outcome[style]}"')

                # Guard the parse below so a missing entry fails with a clear message rather than an IndexError.
                self.assertIn("Regime:", content, f'{style}: The collisions.log file does not contain a "Regime:" entry')
                regime_name = content.split("Regime:")[1].split('\n')[0].strip()
                self.assertIn(expected_regime[style], regime_name, f'{style}: The collisions.log file does not contain the expected regime name.\nExpected: "{expected_regime[style]}"\nGot: "{regime_name}"')

                # Each log line mentioning "Newbody" is counted as one new body; reuse the content already read.
                newbody_count = sum("Newbody" in line for line in content.split('\n'))

                self.assertGreaterEqual(newbody_count, nfrag_minimum_expected[style], f"{style}: Expected at least {nfrag_minimum_expected[style]} new bodies, got {newbody_count}")
                self.assertLessEqual(newbody_count, nfrag_maximum_expected[style], f"{style}: Expected at most {nfrag_maximum_expected[style]} new bodies, got {newbody_count}")
                sim.clean()


    def test_rotation_direction(self):
        """
        Check that the target's final rotation is consistent with the impact geometry.

        The expected rotation is estimated as the projectile's angular momentum relative to the
        target, ``m_p * (r_c x v_c)``, divided by the target's uniform-sphere moment of inertia
        ``0.4 * M * R**2`` and converted from radians to degrees. Every component of the relative
        error must be below 2. For the off-axis styles, the z component must also have the expected
        sign and exceed 100 times both the x and y components.

        NOTE(review): for "disruption_headon", ``r_c`` and ``v_c`` are both along y, so
        ``expected_rot`` is exactly zero. The relative error then reduces to
        ``rot_final / |rot_final|``, whose components never exceed 1, so this case only checks that
        the run completes without a RuntimeWarning. It does not bound the magnitude of the
        final rotation.
        """
        collision_type = ["disruption_off_axis_clockwise", "disruption_off_axis_counterclockwise", "disruption_headon"]
        nfrag_reduction = 10.0
        names = ["Target","Projectile"]
        pos_vectors =  [np.array([1.0, -5.0e-05, 0.0]),
                        np.array([1.0,  5.0e-05 ,0.0])]
        vel_vectors = {
                    "disruption_headon"                    : [np.array([ 0.00,  6.280005, 0.0]),
                                                              np.array([ 0.00,  3.90,     0.0])],
                    "disruption_off_axis_clockwise"        : [np.array([ 0.00,  6.280005, 0.0]),
                                                              np.array([ 0.05,  3.90,     0.0])], 
                    "disruption_off_axis_counterclockwise" : [np.array([ 0.00,  6.280005, 0.0]),
                                                              np.array([-0.05,  3.90,     0.0])],
                    }

        tstop = 2e-3
        dt = tstop/4

        # Uniform-density spheres, using the same unit conversions as test_collision_outcomes.
        GU = swiftest.GMSun * swiftest.YR2S**2 / swiftest.AU2M**3
        body_Gmass = np.array([1e-7, 1e-9])
        body_mass = body_Gmass / GU
        density = 3000 * swiftest.AU2M**3 / swiftest.MSun
        body_radius = ((body_mass)/(4./3.*np.pi*density))**(1./3.)

        with warnings.catch_warnings():
            warnings.simplefilter("error", RuntimeWarning)
            for style in collision_type:
                sim = swiftest.Simulation(simdir=self.simdir, rotation=True, init_cond_format = "XV", compute_conservation_values=True)
                sim.add_solar_system_body("Sun")
                sim.add_body(name=names, Gmass=body_Gmass, radius=body_radius, rh=pos_vectors, vh=vel_vectors[style])

                # Fragmentation parameters are scaled to the projectile's G*mass.
                minimum_fragment_gmass = 0.01 * body_Gmass[1]
                gmtiny = 0.50 * body_Gmass[1]
                sim.set_parameter(collision_model="fraggle",
                                encounter_save="both",
                                gmtiny=gmtiny,
                                minimum_fragment_gmass=minimum_fragment_gmass,
                                nfrag_reduction=nfrag_reduction)
                sim.run(dt=dt, tstop=tstop, istep_out=1, dump_cadence=0)

                rot_final = sim.data.isel(time=-1).sel(name='Target').rot
                rotmag_final = rot_final.magnitude().values[()]
                vc = vel_vectors[style][1] - vel_vectors[style][0]
                rc = pos_vectors[1] - pos_vectors[0]

                # Projectile angular momentum relative to the target, spread over a uniform sphere's inertia.
                dL = body_mass[1] * np.cross(rc, vc)
                expected_rot = np.rad2deg(dL / (body_mass[0] * body_radius[0]**2 * 0.4))
                rot_error =((rot_final - expected_rot) / rotmag_final).values

                # Assert that the absolute value of all components of rot_error is less than 2
                self.assertTrue(np.all(np.abs(rot_error) < 2), f"{style}: The relative error in the final rotation is greater than 2")

                # For the off-axis collision styles, the z component must carry the expected sign and
                # dominate both in-plane components.
                if style != "disruption_headon":
                    rot_xyz = rot_final.values
                    self.assertEqual(np.sign(rot_xyz[2]), np.sign(expected_rot[2]), f"{style}: The sign of the z component of rotation is not consistent with the expected rotation")

                    # |z| must exceed 100 times |x| AND 100 times |y|.
                    self.assertGreater(np.abs(rot_xyz[2]), 100*np.abs(rot_xyz[0]), f"{style}: The z component of rotation is not significantly larger than the x component")
                    self.assertGreater(np.abs(rot_xyz[2]), 100*np.abs(rot_xyz[1]), f"{style}: The z component of rotation is not significantly larger than the y component")
                        
if __name__ == "__main__":
    # When run directly as a script, turn off HDF5 file locking before the tests start
    # (presumably to avoid lock errors on filesystems without locking support), then
    # hand control to the unittest runner.
    os.environ.update(HDF5_USE_FILE_LOCKING="FALSE")
    unittest.main()