#!/usr/bin/env python
import argparse
import numpy as np
import dynaphopy
import dynaphopy.interface.iofile as reading
import dynaphopy.interface.interactive_ui as interactive_ui
import dynaphopy.generate_cell as generate
from dynaphopy.interface.phonopy_link import get_force_sets_from_file, get_force_constants_from_file

from fractions import Fraction

# ---------------------------------------------------------------------------
# Command line interface definition
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='DynaPhoPy options')
add_option = parser.add_argument

# Mandatory positional argument: the input file
add_option('input_file', nargs=1, type=str, metavar='data_file',
           help='input file containing structure related data')

# One (and only one) of the arguments registered in this group has to be given:
# they select where the MD data comes from or request a supercell file.
group = parser.add_mutually_exclusive_group(required=True)
add_source = group.add_argument

add_source('md_file', nargs='?', type=str, metavar='trajectory',
           help='Output file containing MD trajectory')
add_source('-lv', '--load_data', nargs=1, type=str, metavar='file',
           help='load MD data from hdf5 file')

add_option('-i', '--interactive', action='store_true',
           help='enter interactive mode')

# Values are kept as strings at parse time (type=str)
add_option('-q', nargs=3, type=str, metavar='F',
           help='wave vector used to calculate projections (default: 0 0 0)')

add_option('-r', '--frequency_range', nargs=2, type=float, metavar='F',
           help='frequency range of power spectrum in THz (min, max)')

add_option('-n', type=int, default=0, metavar='N',
           help='number of MD last steps to take (default: All)')

add_option('-ts', '--time_step', type=float, default=None, metavar='step',
           help='time step in ps (default: Read from trajectory file)')

# Plot requests
add_option('-pd', '--plot_full', action='store_true',
           help='plot full power spectrum')
add_option('-pw', '--plot_wave_vector', action='store_true',
           help='plot projection into wave vector')
add_option('-pp', '--plot_phonon_mode', action='store_true',
           help='plot projection into phonon modes')

# Save requests
add_option('-sd', '--save_full', nargs=1, type=str, metavar='file',
           help='save full power spectrum to file')
add_option('-sw', '--save_wave_vector', nargs=1, type=str, metavar='file',
           help='save projection into wave vector to file')
add_option('-sp', '--save_phonon_mode', nargs=1, type=str, metavar='file',
           help='save projection into phonon modes to file')
add_option('-sv', '--save_data', nargs=1, type=str, default=False, metavar='file',
           help='save MD data into hdf5 file')
add_option('-svc', '--save_vc_hdf5', nargs=1, type=str, default=False, metavar='file',
           help='save wave vector projected velocity into hdf5 file')

# Power spectrum algorithm and analysis requests
add_option('-psm', '--power_spectrum_algorithm', nargs=1, type=int, metavar='N',
           help='select power spectrum calculation algorithm (default MEM)')
add_option('-cf', '--number_of_mem_coefficients', nargs=1, type=int, metavar='N',
           help='number of coefficients to use in MEM algorithm (default 300)')
add_option('-csa', '--coefficient_scan_analysis', action='store_true',
           help='request coefficient scan analysis')
add_option('-pa', '--peak_analysis', action='store_true',
           help='request a peak analysis')

# Atomic displacements
add_option('-pad', '--plot_atomic_displacements', nargs=3, type=str, metavar='F',
           help='plot atomic displacements respect to specified direction [F F F]')
add_option('-sad', '--save_atomic_displacements', nargs=4, type=str, metavar='S',
           help='save atomic displacements into a file [ F F F filename ]')

add_option('-sfc', '--save_force_constants', nargs=1, type=str, default=False, metavar='file',
           help='save the renormalized force constants into a file')

# Printed properties
add_option('-adp', action='store_true',
           help='print anisotropic displacement parameters')
add_option('-thm', action='store_true',
           help='print thermal properties')
add_option('--thm_full', action='store_true',
           help='print thermal properties from power spectrum')
add_option('--thm_reference', action='store_true',
           help='print reference thermal properties computed with phonopy ')

# YAML output
add_option('-sdata', '--save_quasiparticle_data', action='store_true',
           help='save quasiparticle data into a YAML formatted file')
add_option('-smesh', '--save_mesh_data', action='store_true',
           help='save mesh data into a YAML formatted file')

# ---------------------------------------------------------------------------
# Extra options
# ---------------------------------------------------------------------------
add_option('--temperature', type=float, default=None, metavar='F',
           help='set temperature (thermal properties)')
add_option('--silent', action='store_true',
           help='executes without showing plots')
add_option('--velocity_only', action='store_true',
           help='loads only velocity data from hdf5 file')
add_option('--fitting_function', type=int, default=0, metavar='index',
           help='define fitting function (default: Lorentzian)')
add_option('--read_from', type=int, default=1, metavar='step',
           help='define interval of trajectory to read (default: 1)')
add_option('--read_to', type=int, default=None, metavar='step',
           help='define interval of trajectory to read (default: end)')
add_option('--fcsymm', action='store_true',
           help='symmetrize force constants')
add_option('--symprec', type=float, default=1e-5, metavar='F',
           help='set symmetry search tolerance for phonopy')
add_option('--no_symmetry', action='store_true',
           help='do not use crystal symmetry')
add_option('--save_renormalized_frequencies', action='store_true',
           help='save renormalized frequencies in a file when renormalized force constants are calculated')

# Unlike the other "save" options this one takes a plain string (no nargs)
add_option('--save_band_structure', type=str, default=False, metavar='file',
           help='save renormalized phonon dispersion bands into a yaml file')

add_option('--resolution', nargs=1, type=float, default=None, metavar='F',
           help='define power spectrum resolution in THz (default 0.05)')
add_option('--dim', nargs=3, type=int, default=[1, 1, 1], metavar='N',
           help='Set the dimensions of generated structure/trajectory supercell (Default: [1, 1, 1])')

# Remaining members of the mutually exclusive group created above
add_source('--generate_trajectory', nargs=3, type=float, default=None, metavar='F',
           help='generate a test harmonic trajectory using force constants [total time (ps), '
                'time step (ps), amplitude (Kelvin)]')
add_source('--run_lammps', nargs=4, type=str, default=None, metavar='F',
           help='Run lammps input file [input_file  total_time time_step relaxation_time]')
add_source('-c_poscar', nargs=1, type=str, default=None, metavar='file',
           help='generate supercell in VASP POSCAR file format')
add_source('-c_lammps', nargs=1, type=str, default=None, metavar='file',
           help='generate supercell in LAMMPS data file format')

add_option('-average', action='store_true',
           help='returns average atomic positions')
add_option('--MD_commensurate', action='store_true',
           help='use commensurate points in MD supercell instead of force constants supercell')
add_option('--normalize_dos', action='store_true',
           help='normalize DoS obtained from PS in thermal properties calculation')
add_option('--memmap', action='store_true',
           help='map largest arrays into files to reduce RAM memory usage')
add_option('--qha_force_constants', nargs=1, type=str, metavar='file',
           help='Adds QHA contribution to shifts via renormalized force constants')

# ---------------------------------------------------------------------------
# Temporary interface (still under development)
# ---------------------------------------------------------------------------
add_option('--auto_order', action='store_true',
           help='set the atoms order automatically from the first step of the MD)')
add_option('--project_on_atom', type=int, default=None, metavar='atom_index',
           help='calculate PS projected onto one atom')
add_option('--save_power_spectrum_partials', type=str, default=None, metavar='filename',
           help='store the full power spectrum separated by atom contributions')
add_option('--nac', action='store_true',
           help='use non-analytical term corrections (BORN file should exist) ')
add_option('--save_vq', nargs=1, type=str, default=False, metavar='file',
           help='save phonon mode projected velocity into file')

# Two values are collected for this option (nargs=2)
add_option('--save_vc', nargs=2, type=str, default=False, metavar='file',
           help='save wave vector mode projected velocity into file')
add_option('--save_extra_properties', action='store_true',
           help='stores extra q-point properties when using -sdata (group velocity)')

args = parser.parse_args()

# Read the input file and build the structure it refers to.
# An OUTCAR entry takes precedence; otherwise the POSCAR entry is used.
input_parameters = reading.read_parameters_from_input_file(args.input_file[0])

if 'structure_file_name_outcar' in input_parameters:
    structure_file = input_parameters['structure_file_name_outcar']
    structure = reading.read_from_file_structure_outcar(structure_file)
else:
    structure_file = input_parameters['structure_file_name_poscar']
    structure = reading.read_from_file_structure_poscar(structure_file)

# Complete the structure with the remaining data found in the input file
structure.get_data_from_dict(input_parameters)

# The input file may name either a FORCE SETS file or a FORCE CONSTANTS file,
# but not both. Reject the ambiguous case before reading any of the files.
has_force_sets = 'force_sets_file_name' in input_parameters
has_force_constants = 'force_constants_file_name' in input_parameters

if has_force_sets and has_force_constants:
    print('Both FORCE SETS and FORCE CONSTANTS are found in input file, only one is allowed')
    # Non-zero status so that calling scripts can detect the failure
    raise SystemExit(1)

# Supercell passed to the force sets / force constants readers;
# identity matrix when the input file does not define one.
if 'supercell_phonon' in input_parameters:
    supercell_phonon = input_parameters['supercell_phonon']
else:
    supercell_phonon = np.identity(3)

if has_force_sets:
    structure.set_force_set(get_force_sets_from_file(file_name=input_parameters['force_sets_file_name'],
                                                     fs_supercell=supercell_phonon))
elif has_force_constants:
    structure.set_force_constants(get_force_constants_from_file(file_name=input_parameters['force_constants_file_name'],
                                                                fc_supercell=supercell_phonon))

# Supercell generation requests: write the requested file and end the run.
# The content is generated before the output file is opened, so a failure in
# the generator does not leave a truncated file behind, and the context
# manager guarantees the file is closed.
if args.c_poscar:
    poscar_text = generate.generate_VASP_structure(structure, scaled=True, supercell=args.dim)
    with open(args.c_poscar[0], 'w') as poscar_file:
        poscar_file.write(poscar_text)
    raise SystemExit(0)

if args.c_lammps:
    lammps_text = generate.generate_LAMMPS_structure(structure, supercell=args.dim)
    with open(args.c_lammps[0], 'w') as lammps_file:
        lammps_file.write(lammps_text)
    raise SystemExit(0)


# Load trajectory
# The sources below belong to the same mutually exclusive argument group, so
# at most one of them is set; a single if/elif chain makes that explicit and
# lets the final branch report the case in which none of them is usable.
if args.load_data:
    # Previously stored MD data (hdf5 file)
    trajectory = reading.initialize_from_hdf5_file(args.load_data[0],
                                                   structure,
                                                   read_trajectory=not args.velocity_only,
                                                   initial_cut=args.read_from,
                                                   final_cut=args.read_to,
                                                   memmap=args.memmap)
    structure_file = args.load_data[0]

elif args.md_file:
    # MD trajectory file: the parser function is selected by the reading module
    trajectory_reading_function = reading.get_trajectory_parser(args.md_file)
    if trajectory_reading_function is None:
        print('Trajectory file format not recognized')
        # Non-zero status so that calling scripts can detect the failure
        raise SystemExit(1)

    if args.auto_order:
        print('Use auto order')
        template = reading.check_atoms_order(args.md_file, trajectory_reading_function, structure)
    else:
        template = None

    trajectory = trajectory_reading_function(args.md_file,
                                             structure,
                                             args.time_step,
                                             initial_cut=args.read_from,
                                             end_cut=args.read_to,
                                             memmap=args.memmap,
                                             template=template)

elif args.generate_trajectory:
    # Test trajectory: [total time (ps), time step (ps), amplitude (Kelvin)]
    trajectory = reading.generate_test_trajectory(structure,
                                                  total_time=args.generate_trajectory[0],
                                                  time_step=args.generate_trajectory[1],
                                                  temperature=args.generate_trajectory[2],
                                                  supercell=args.dim,
                                                  silent=args.silent,
                                                  memmap=args.memmap)

elif args.run_lammps:
    # Imported here so that the LAMMPS interface is only required when used.
    # Arguments: [input_file total_time time_step relaxation_time] (strings)
    from dynaphopy.interface.lammps_link import generate_lammps_trajectory
    trajectory = generate_lammps_trajectory(structure, args.run_lammps[0],
                                            total_time=float(args.run_lammps[1]),
                                            time_step=float(args.run_lammps[2]),
                                            relaxation_time=float(args.run_lammps[3]),
                                            silent=args.silent,
                                            supercell=args.dim,
                                            memmap=args.memmap,
                                            velocity_only=args.velocity_only,
                                            temperature=args.temperature)

else:
    # E.g. an empty trajectory file name: fail with a clear message instead of
    # a NameError when `trajectory` is used later on.
    print('No trajectory source specified')
    raise SystemExit(1)

# A list/tuple is handled as "projected velocity only" data: element 0 is
# passed as vc, element 1 as the reduced q vector and element 2 as the first
# argument of Quasiparticle.
if isinstance(trajectory, (list, tuple)):
    print('Loading projected velocity only (limited features only)')
    calculation = dynaphopy.Quasiparticle(trajectory[2], vc=trajectory[0], last_steps=args.n)
    projection_settings = {'_reduced_q_vector': trajectory[1], '_use_symmetry': False}
    input_parameters.update(projection_settings)
else:
    calculation = dynaphopy.Quasiparticle(trajectory, last_steps=args.n)

parameters = calculation.parameters
parameters.get_data_from_dict(input_parameters)

# Set Parameters
# (command line switches, assigned after the input file values are loaded)
parameters.silent = args.silent
parameters.degenerate = args.fcsymm
parameters.use_MD_cell_commensurate = args.MD_commensurate
parameters.save_renormalized_frequencies = args.save_renormalized_frequencies
parameters.use_NAC = args.nac
parameters.symprec = args.symprec

calculation.select_fitting_function(args.fitting_function)
calculation.set_temperature(args.temperature)

# Optional settings: only applied when requested on the command line
if args.qha_force_constants is not None:
    qha_file = args.qha_force_constants[0]
    calculation.set_qha_force_constants(qha_file)

if args.project_on_atom is not None:
    calculation.set_projection_onto_atom_type(args.project_on_atom)

if args.no_symmetry:
    parameters.use_symmetry = False

# Process properties arguments
# (options declared with nargs=1 arrive as one-element lists)
if args.resolution:
    resolution = args.resolution[0]
    calculation.set_spectra_resolution(resolution)

if args.q:
    # Components may be written as fractions (e.g. 1/2), hence Fraction
    reduced_q = np.array([float(Fraction(component)) for component in args.q])
    calculation.set_reduced_q_vector(reduced_q)

if args.power_spectrum_algorithm:
    algorithm_index = args.power_spectrum_algorithm[0]
    calculation.select_power_spectra_algorithm(algorithm_index)

if args.number_of_mem_coefficients:
    mem_coefficients = args.number_of_mem_coefficients[0]
    calculation.parameters.number_of_coefficients_mem = mem_coefficients

# Process save properties
if args.save_vc_hdf5:
    vc_hdf5_file = args.save_vc_hdf5[0]
    calculation.save_vc_hdf5(vc_hdf5_file)

if args.save_vq:
    vq_file = args.save_vq[0]
    calculation.save_vq(vq_file)

if args.save_vc:
    # Two values: the file name and a value converted to int
    vc_file = args.save_vc[0]
    vc_index = int(args.save_vc[1])
    calculation.save_vc(vc_file, vc_index)

if args.save_data:
    md_data_file = args.save_data[0]
    calculation.save_velocity_hdf5(md_data_file, save_trajectory=not args.velocity_only)

# Process calculation arguments
if args.frequency_range:
    calculation.set_frequency_limits(args.frequency_range)

# --- Files written on request ---
if args.save_force_constants:
    force_constants_file = args.save_force_constants[0]
    calculation.write_renormalized_constants(force_constants_file)

if args.save_band_structure:
    # This option is a plain string (not a one-element list)
    calculation.write_renormalized_phonon_dispersion_bands(filename=args.save_band_structure)

if args.save_full:
    full_spectrum_file = args.save_full[0]
    calculation.write_power_spectrum_full(full_spectrum_file)

if args.save_wave_vector:
    wave_vector_file = args.save_wave_vector[0]
    calculation.write_power_spectrum_wave_vector(wave_vector_file)

if args.save_phonon_mode:
    phonon_mode_file = args.save_phonon_mode[0]
    calculation.write_power_spectrum_phonon(phonon_mode_file)

# --- Plots ---
if args.plot_full:
    calculation.plot_power_spectrum_full()

if args.plot_wave_vector:
    calculation.plot_power_spectrum_wave_vector()

if args.plot_phonon_mode:
    calculation.plot_power_spectrum_phonon()

# --- Atomic displacements (direction components may be given as fractions) ---
if args.plot_atomic_displacements:
    plot_direction = [float(Fraction(component)) for component in args.plot_atomic_displacements]
    calculation.plot_trajectory_distribution(plot_direction)

if args.save_atomic_displacements:
    # First three values: direction; fourth value: output file name
    save_direction = [float(Fraction(component)) for component in args.save_atomic_displacements[0:3]]
    displacements_file = args.save_atomic_displacements[3]
    calculation.write_atomic_displacements(save_direction, displacements_file)

# --- YAML data ---
if args.save_quasiparticle_data:
    calculation.write_quasiparticles_data(with_extra=args.save_extra_properties)

if args.save_mesh_data:
    calculation.write_mesh_data()

# --- Analysis ---
if args.coefficient_scan_analysis:
    calculation.phonon_width_scan_analysis()

if args.peak_analysis:
    calculation.phonon_individual_analysis()

if args.adp:
    calculation.get_anisotropic_displacement_parameters()

# --- Thermal properties (both requests share the same display options) ---
thermal_display_options = {'normalize_dos': args.normalize_dos,
                           'print_phonopy': args.thm_reference}

if args.thm:
    calculation.display_thermal_properties(**thermal_display_options)

if args.thm_full:
    calculation.display_thermal_properties(from_power_spectrum=True, **thermal_display_options)

# Temporary interface (name may change)
if args.save_power_spectrum_partials:
    calculation.get_power_spectrum_partials(save_to_file=args.save_power_spectrum_partials)
# -----------------

if args.average:
    calculation.get_average_atomic_positions()


# Interactive mode is entered last, after all the other requests were processed.
if args.interactive:
    # A list/tuple trajectory holds wave vector projected velocity only, which
    # the interactive interface cannot work with.
    if isinstance(trajectory, (list, tuple)):
        print('Interactive mode cannot be used loading wave vector projected velocity')
        # Non-zero status so that calling scripts can detect the failure
        raise SystemExit(1)
    interactive_ui.interactive_interface(calculation, trajectory, args, structure_file)