#!/usr/bin/env python
"""
Analyse Networks
================

This script provides a command-line interface (CLI) to analyse polymer networks
generated by LAMMPS simulations. For each LAMMPS data file passed, it reads the
structure, computes various properties of the network, and prints statistics
such as bond lengths, end-to-end distances, the stoichiometric imbalance, and a
shear modulus obtained from a force-balance analysis.
"""

import click
import statistics

from pylimer_tools.calc.structure_analysis import (
    compute_crosslinker_conversion,
    compute_extent_of_reaction,
    compute_stoichiometric_imbalance,
)
from pylimer_tools.io.bead_spring_parameter_provider import (
    get_parameters_for_polymer,
    get_supported_polymer_names,
)
from pylimer_tools.io.read_lammps_output_file import read_data_file
from pylimer_tools_cpp import MEHPForceBalance2


def _stat_or_nan(stat, values):
    """Return ``stat(values)``, or NaN when the list ``values`` is empty.

    ``statistics.mean``/``statistics.median`` raise ``StatisticsError`` and
    ``min``/``max`` raise ``ValueError`` on empty input. Returning NaN instead
    lets a file without usable bonds or strands be reported rather than
    aborting the analysis of all remaining files.
    """
    return stat(values) if values else float("nan")


@click.command()
@click.argument("files", nargs=-1, type=click.Path(exists=True))
@click.option(
    "--crosslinker-type",
    type=int,
    default=2,
    help="Atom type identifying the crosslinkers.",
)
@click.option(
    "--polymer-name",
    type=click.Choice(get_supported_polymer_names(), case_sensitive=False),
    default="PDMS",
    help="Name of the polymer to use for parameter retrieval.",
)
def cli(files, crosslinker_type, polymer_name):
    """
    Read every file in FILES and print statistics on the structure therein.

    For each file, the size of the system, bond-length and end-to-end-distance
    statistics, stoichiometric quantities, a shear modulus obtained from a
    force-balance analysis, and the soluble, dangling and active weight
    fractions are printed.

    FILES are the paths of the structure files to read.
    """
    # NOTE: click renders the docstring above as the `--help` text, so it is
    # written for the user of the command rather than for the maintainer.
    #   files:            tuple of existing paths (validated by click).
    #   crosslinker_type: integer type passed through to all pylimer_tools calls.
    #   polymer_name:     one of `get_supported_polymer_names()`.
    click.echo("Processing {} files".format(len(files)))
    params = get_parameters_for_polymer(
        polymer_name=polymer_name,
    )
    base_distance_unit = params.get_base_distance_units().units

    for file_path in files:
        click.echo(
            "\nAnalysing File with {} units ".format(
                params.get_name()) + file_path
        )

        universe = read_data_file(file_path)
        nr_of_atoms = universe.get_nr_of_atoms()
        volume = universe.get_volume()

        click.echo(
            "Size: {}. Volume: {} {}^3 (ρ = {} atoms/{}^3)".format(
                nr_of_atoms,
                volume,
                base_distance_unit,
                nr_of_atoms / volume,
                base_distance_unit,
            )
        )
        click.echo(
            "{} atoms and {} bonds, {} angles, {} dihedrals".format(
                nr_of_atoms,
                universe.get_nr_of_bonds(),
                universe.get_nr_of_angles(),
                universe.get_nr_of_dihedral_angles(),
            )
        )

        molecules = universe.get_molecules(crosslinker_type)

        # Only bonds with a known, strictly positive length enter the
        # statistics. The same population is used for <b> and <b^2> so that
        # the two moments are consistent with each other.
        valid_bond_lengths = [
            bl
            for bl in universe.compute_bond_lengths()
            if bl is not None and bl > 0
        ]
        click.echo(
            "Bond length b: <b> = {} {}, (min: {}, max: {}, median: {}) {}, <b^2> = {} {}^2".format(
                _stat_or_nan(statistics.mean, valid_bond_lengths),
                base_distance_unit,
                _stat_or_nan(min, valid_bond_lengths),
                _stat_or_nan(max, valid_bond_lengths),
                _stat_or_nan(statistics.median, valid_bond_lengths),
                base_distance_unit,
                _stat_or_nan(
                    statistics.mean, [bl**2 for bl in valid_bond_lengths]
                ),
                base_distance_unit,
            )
        )

        # Same filtering as for the bond lengths: drop unknown and
        # non-positive end-to-end distances.
        valid_end_to_end_distances = [
            e
            for e in universe.compute_end_to_end_distances(
                crosslinker_type=crosslinker_type, derive_image_flags=True
            )
            if e is not None and e > 0
        ]
        click.echo(
            "End to end distance R_ee: <R_ee> = {} {}, <R_ee^2> = {} {}^2".format(
                _stat_or_nan(statistics.mean, valid_end_to_end_distances),
                base_distance_unit,
                _stat_or_nan(
                    statistics.mean,
                    [e**2 for e in valid_end_to_end_distances],
                ),
                base_distance_unit,
            )
        )
        click.echo(
            "For {} molecules of mean length of {} atoms".format(
                len(molecules),
                _stat_or_nan(
                    statistics.mean, [m.get_nr_of_atoms() for m in molecules]
                ),
            )
        )
        # f: functionality of the crosslinker type, r: stoichiometric
        # imbalance, p: extent of reaction (crosslinker conversion in
        # parentheses), D: polydispersity index.
        click.echo(
            "f = {}, r = {}, p = {} ({}), D = {}".format(
                universe.determine_functionality_per_type()[crosslinker_type],
                compute_stoichiometric_imbalance(universe, crosslinker_type),
                compute_extent_of_reaction(universe, crosslinker_type),
                compute_crosslinker_conversion(universe, crosslinker_type),
                universe.compute_polydispersity_index(crosslinker_type),
            )
        )

        # Conduct force balance analysis. The number of entanglements to
        # sample is the polymer's entanglement density scaled by the volume
        # of this system, truncated to an integer.
        mehp = MEHPForceBalance2(
            universe,
            crosslinker_type=crosslinker_type,
            nr_of_entanglements_to_sample=int(
                params.get_entanglement_density() * volume
            ),
            upper_sampling_cutoff=params.get_sampling_cutoff(),
        )

        mehp.run_force_relaxation()

        # NOTE(review): `R02` and `distance_units` look like unit-carrying
        # quantities (`.to(...)`, `.magnitude`); R02 is converted to squared
        # distance units before being handed over as a plain number — confirm
        # against the parameter provider.
        r02_magnitude = (
            params.get("R02").to(params.get("distance_units") ** 2).magnitude
        )
        click.echo(
            "Shear Modulus [MPa]: {:.2f}".format(
                params.get_gamma_conversion_factor().to("MPa").magnitude
                * sum(mehp.get_gamma_factors(r02_magnitude))
                / volume
            )
        )

        # Query each fraction once; the active fraction is the remainder.
        w_sol = mehp.get_soluble_weight_fraction()
        w_dang = mehp.get_dangling_weight_fraction()
        click.echo(
            "w_sol = {:.3f}, w_dang = {:.3f}, w_active = {:.3f}".format(
                w_sol,
                w_dang,
                1 - w_sol - w_dang,
            )
        )

        click.echo("")
    # NOTE(review): this closing hint predates the unit-aware output above and
    # may no longer be accurate — confirm before changing the message.
    click.echo("Arbitrary units used. E.g.: Length: u")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    cli()
