import os
import random
import shutil
import subprocess
import tempfile
from typing import Tuple, Optional, Callable, List

import numpy as np
from ovito.io import *
from ovito.modifiers import *
from ovito.pipeline import *

from config import Config

class TDE_simulation:
    """Search for the threshold displacement energy (TDE) of a single atom.

    The atom identified by ``atom_id`` is the primary knock-on atom (PKA).
    ``run_tde_loop`` repeatedly rewrites the LAMMPS input ``tde.in`` with a
    PKA velocity of a given kinetic energy, launches ``run_line`` and inspects
    the resulting trajectory with OVITO's Wigner-Seitz analysis until a defect
    is detected.

    Attributes:
        direction: Array-like recoil direction of the PKA (any non-zero length).
        mass_data: Dict mapping LAMMPS atom type (int) to mass in amu.
        atom_id: LAMMPS atom ID of the PKA.
        lammps_data_file: Path of the LAMMPS data file listing the atoms.
        pka_mass: Mass of the PKA, looked up in ``mass_data``.
        run_line: Shell command that launches one simulation.
    """

    # Trajectory file each simulation run is expected to write into the
    # working directory; it is what ``defect_check`` analyses by default.
    OUTPUT_FILE = 'prod.data'

    def __init__(self, atom_id, mass_data, direction=None,
                 lammps_data_file=None, run_line=None):
        """Validate the inputs and look up the PKA mass from the data file.

        Args:
            atom_id: LAMMPS atom ID of the PKA (Python or NumPy integer).
            mass_data: Dict mapping LAMMPS atom type to mass in amu.
            direction: Array-like recoil direction of the PKA.
            lammps_data_file: Path of the LAMMPS data file; required even
                though the signature gives it a ``None`` default.
            run_line: Shell command that launches one simulation.

        Raises:
            TypeError: If ``mass_data`` is not a dict, ``atom_id`` is not an
                integer, or ``lammps_data_file`` is ``None``.
            ValueError: If the data file has no Atoms section or does not
                list ``atom_id``.
            KeyError: If the PKA's atom type has no entry in ``mass_data``.
        """
        self.direction = direction

        if not isinstance(mass_data, dict):
            raise TypeError('mass_data must be a dictionary of form [int,int], for lammps atom type -> mass')
        self.mass_data = mass_data

        # NumPy integers (e.g. an ID taken from an array) are as valid as int.
        if not isinstance(atom_id, (int, np.integer)):
            raise TypeError('atom_id must be an integer')
        self.atom_id = int(atom_id)

        if lammps_data_file is None:
            raise TypeError('lammps_data_file must be the path of a LAMMPS data file')
        self.lammps_data_file = lammps_data_file

        self.run_line = run_line

        # Initialise the PKA mass up front so later velocity calculations
        # do not have to touch the data file again.
        self.get_pka_mass()

    def get_pka_mass(self):
        """Read the PKA's atom type from the data file and store its mass.

        The atom type is taken from the second column of the matching line of
        the ``Atoms`` section, i.e. the file is assumed to use an atom style
        whose columns start with ``id type`` (e.g. ``atomic`` or ``charge``).
        The result is stored in ``self.pka_mass``.

        Raises:
            ValueError: If there is no ``Atoms`` section or it does not
                contain ``self.atom_id``.
            KeyError: If the atom's type is not a key of ``self.mass_data``.
        """
        with open(self.lammps_data_file, 'r') as f:
            lines = f.readlines()

        # Locate the section header; it may carry a trailing "# style" hint.
        section_start = None
        for index, line in enumerate(lines):
            if line.split('#', 1)[0].strip() == 'Atoms':
                section_start = index + 1
                break
        if section_start is None:
            raise ValueError(f"no 'Atoms' section found in {self.lammps_data_file}")

        atom_type = None
        for line in lines[section_start:]:
            fields = line.split()
            if not fields:
                # Blank separator after the header or before the next section.
                continue
            try:
                current_id = int(fields[0])
            except ValueError:
                # A non-numeric first token is the next section's header.
                break
            if current_id == self.atom_id:
                atom_type = int(fields[1])
                break

        if atom_type is None:
            raise ValueError(
                f"atom id {self.atom_id} not found in the 'Atoms' section of "
                f"{self.lammps_data_file}"
            )

        try:
            self.pka_mass = self.mass_data[atom_type]
        except KeyError:
            raise KeyError(f'mass_data has no entry for LAMMPS atom type {atom_type}') from None

    def calculate_pka_vel(self, pka_vector, pka_energy):
        """Return the PKA velocity vector for a given direction and energy.

        Args:
            pka_vector: Array-like recoil direction; only its orientation
                matters, it is normalised internally.
            pka_energy: Kinetic energy of the PKA in eV (non-negative).

        Returns:
            NumPy array with the velocity components in Angstrom/ps, the
            velocity unit of LAMMPS ``units metal``.

        Raises:
            TypeError: If ``pka_vector`` is ``None``.
            ValueError: If the direction has zero or non-finite length, or
                the energy is negative.
        """
        eV_to_J = 1.60218e-19  # J/eV
        amu_to_kg = 1.66e-27   # kg/amu

        if pka_vector is None:
            raise TypeError('pka_vector must be an array-like direction, not None')
        if pka_energy < 0:
            raise ValueError('pka_energy must be non-negative')

        direction = np.asarray(pka_vector, dtype=float)
        magnitude = np.linalg.norm(direction)
        # A zero-length direction would otherwise turn into a NaN velocity
        # that is silently written into the LAMMPS input.
        if not np.isfinite(magnitude) or magnitude == 0:
            raise ValueError('pka_vector must have a finite, non-zero length')

        # E = m v^2 / 2 in SI units gives the speed in m/s ...
        speed_si = np.sqrt(2 * pka_energy * eV_to_J / (self.pka_mass * amu_to_kg))
        # ... and 1 Angstrom/ps equals 100 m/s.
        speed = speed_si / 100

        return speed * (direction / magnitude)

    def defect_check(self, filename=None):
        """Count point defects in the last frame of a simulation trajectory.

        A Wigner-Seitz analysis with per-type occupancies is applied to the
        final frame. Frame 0 is never analysed (the original loop started at
        frame 1, presumably because frame 0 is the reference configuration).

        Args:
            filename: Trajectory to analyse; defaults to ``OUTPUT_FILE``.

        Returns:
            int: vacancies (sites with occupancy 0) + interstitials (each
            atom beyond the first on a site) + antisites (singly occupied
            sites whose occupant is of a different type than the site).

        Raises:
            ValueError: If the trajectory has fewer than two frames.
        """
        if filename is None:
            filename = self.OUTPUT_FILE

        pipeline = import_file(filename)
        pipeline.modifiers.append(
            WignerSeitzAnalysisModifier(per_type_occupancies=True)
        )

        num_frames = pipeline.source.num_frames
        if num_frames < 2:
            raise ValueError(
                f'{filename} contains {num_frames} frame(s); at least two are '
                'needed to compare the final state against frame 0'
            )

        # Only the final frame determines the result, so earlier frames are
        # not evaluated.
        data = pipeline.compute(num_frames - 1)
        occupancies = data.particles['Occupancy'].array
        site_type = data.particles['Particle Type'].array

        # With several atom types the occupancy is one column per type;
        # otherwise it is already the per-site total.
        per_type = np.ndim(occupancies) > 1
        total_occupancy = occupancies.sum(axis=1) if per_type else occupancies

        vacancy_count = int(np.count_nonzero(total_occupancy == 0))
        crowded = total_occupancy[total_occupancy >= 2]
        interstitial_count = int(np.sum(crowded - 1))

        antisite_mask = np.zeros(np.shape(total_occupancy), dtype=bool)
        if per_type:
            singly_occupied = total_occupancy == 1
            # A singly occupied site holding no atom of its own type must
            # hold a foreign atom. Checking every site type covers both
            # directions of an exchange (A on B-site and B on A-site).
            for host_type in range(1, len(self.mass_data) + 1):
                antisite_mask |= (
                    (site_type == host_type)
                    & singly_occupied
                    & (occupancies[:, host_type - 1] == 0)
                )
        antisite_count = int(np.count_nonzero(antisite_mask))

        # Expose the antisite count as a global attribute of the frame data.
        data.attributes['Antisite_count'] = antisite_count

        return antisite_count + vacancy_count + interstitial_count

    def modify_lammps_in(self, filename, pka_vector):
        """Rewrite the PKA group and velocity lines of a LAMMPS input file.

        Every line containing ``velocity PKA set`` is replaced by one carrying
        the components of ``pka_vector``; every line containing
        ``group PKA id`` is replaced by one naming ``self.atom_id``. All other
        lines are written back unchanged. The whole matching line is
        replaced, so trailing keywords on it are dropped.

        NOTE(review): the LAMMPS ``velocity`` command defaults to
        ``units lattice``; confirm that the input's lattice settings make the
        written values mean what ``calculate_pka_vel`` intends.

        Args:
            filename: Path of the LAMMPS input file, modified in place.
            pka_vector: Sequence with the three velocity components.
        """
        with open(filename, 'r') as f:
            lines = f.readlines()

        velocity_line = f'velocity PKA set {pka_vector[0]} {pka_vector[1]} {pka_vector[2]}\n'
        group_line = f'group PKA id {self.atom_id}\n'

        updated = []
        for line in lines:
            if 'velocity PKA set' in line:
                updated.append(velocity_line)
            elif 'group PKA id' in line:
                updated.append(group_line)
            else:
                updated.append(line)

        with open(filename, 'w') as f:
            f.writelines(updated)

    def _run_at_energy(self, energy):
        """Run one simulation at ``energy`` eV in the current directory and return its defect count."""
        vel_vector = self.calculate_pka_vel(self.direction, pka_energy=energy)
        self.modify_lammps_in(filename='tde.in', pka_vector=vel_vector)

        # Remove the previous run's trajectory so that a run which fails to
        # write output cannot be mistaken for a fresh result.
        if os.path.exists(self.OUTPUT_FILE):
            os.remove(self.OUTPUT_FILE)

        # run_line comes from the user's configuration and is executed through
        # the shell as-is; its exit status is not interpreted here.
        subprocess.run(self.run_line, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return self.defect_check()

    def run_tde_loop(self, start_energy=None):
        """Scan PKA energies until a defect appears and return the threshold.

        A scratch directory is created inside the current working directory
        and ``tde.in`` is copied into it. The energy is raised in 2 eV steps
        from ``start_energy`` until a run produces a defect, then lowered in
        1 eV steps until a run is defect-free again. The working directory is
        restored and the scratch directory removed even if a run fails.

        Args:
            start_energy: First energy to try in eV; defaults to 10.

        Returns:
            float: Lowest energy in eV at which a defect was found.

        Raises:
            FileNotFoundError: If ``tde.in`` is missing from the current
                working directory.
            RuntimeError: If no defect is found below the 1000 eV cap.
        """
        energy = start_energy if start_energy is not None else 10
        energy_increment = 2
        max_energy = 1000
        tde_energy = None

        parent_dir = os.getcwd()
        # mkdtemp always yields a fresh, unique directory, so there is no
        # retry loop and the caller's NumPy random stream is left untouched.
        workdir = tempfile.mkdtemp(prefix='tde_', dir=parent_dir)
        try:
            shutil.copy(os.path.join(parent_dir, 'tde.in'), workdir)
            os.chdir(workdir)

            while energy < max_energy:
                defect_count = self._run_at_energy(energy)

                if defect_count != 0:
                    # Step back down by 1 eV until no defect is produced.
                    # Stopping at energy < 1 keeps the energy non-negative
                    # (same runs as before for integer energies).
                    while defect_count != 0 and energy >= 1:
                        energy -= 1
                        defect_count = self._run_at_energy(energy)

                    # energy is now the first defect-free value of the
                    # descent, so the threshold sits one step above it.
                    tde_energy = energy + 1
                    break

                energy += energy_increment
        finally:
            os.chdir(parent_dir)
            shutil.rmtree(workdir, ignore_errors=True)

        if tde_energy is None:
            raise RuntimeError(
                f'no defect was detected for PKA energies below {max_energy} eV'
            )
        return float(tde_energy)
                
    
def evaluate_tde(cfg, u: np.ndarray, start_energy=10) -> Tuple[float, float]:
    """Run a TDE search along direction ``u`` and return its result.

    Args:
        cfg: Configuration object providing ``atom_id``, ``mass_data``,
            ``data_file`` and ``run_line``.
        u: Recoil direction with exactly three components.
        start_energy: First PKA energy in eV tried by the search.

    Returns:
        ``(energy, 0.5)``: the threshold energy reported by
        ``TDE_simulation.run_tde_loop`` and a fixed second value of 0.5
        (presumably an uncertainty estimate for the caller; TODO confirm).

    Raises:
        ValueError: If ``u`` does not have exactly three components.
    """
    # Validate the direction explicitly before any simulation is set up.
    direction = np.asarray(u, dtype=float)
    if direction.shape != (3,):
        raise ValueError(
            f'u must have exactly three components, got shape {direction.shape}'
        )

    simulation = TDE_simulation(
        atom_id=cfg.atom_id,
        mass_data=cfg.mass_data,
        direction=direction,
        lammps_data_file=cfg.data_file,
        run_line=cfg.run_line,
    )
    energy = simulation.run_tde_loop(start_energy=start_energy)
    return float(energy), 0.5






