#!/usr/bin/env python3
#
# Python script to run and analyse MMS test
#

# Cores: 2
# requires: zoidberg

from boututils.run_wrapper import build_and_log, launch_safe
from boutdata.collect import collect
import boutconfig as conf

from numpy import array, log, polyfit, linspace, arange

import pickle

from sys import stdout

import zoidberg as zb

# ---------------------------------------------------------------------------
# Test configuration
# ---------------------------------------------------------------------------

# Number of x points; not changed for these tests
nx = 4

# Resolutions in y and z that are scanned: 8, 16, 32, 64, 128
nlist = [2**power for power in range(3, 8)]

# Number of parallel slices (in each direction) to test
nslices = [1]

# Run with periodic Y?
yperiodic = True

# Directory the simulation output is collected from
directory = "data"

# Number of MPI processes and threads handed to launch_safe
nproc, mthread = 2, 2

# ---------------------------------------------------------------------------
# Result bookkeeping
# ---------------------------------------------------------------------------

# Error norms and difference-scheme descriptions, keyed by number of slices
error_2, error_inf, method_orders = {}, {}, {}

# Overall status, plus the names of the cases that did not converge
failures = []
success = True

build_and_log("FCI MMS test")

# Scan every (number of slices, interpolation method) combination: for each
# resolution in nlist generate the FCI grid, run ./fci_mms, collect the error
# norms, and finally check the measured convergence order.

# Mesh spacing for each resolution; independent of the slice count and the
# interpolation method. The plotting code later in this file also reads it.
dx = 1.0 / array(nlist)

# Length of the domain in y
ylength = 10.0

for nslice in nslices:
    # Which central difference scheme to use and its expected order
    expected_order = nslice * 2
    method_orders[nslice] = {
        "name": "C{}".format(expected_order),
        "order": expected_order,
    }

    for method in [
        "hermitespline",
        "lagrange4pt",
        "bilinear",
        # "monotonichermitespline",
    ]:
        # NOTE: the results are keyed by nslice only, so each method starts
        # from empty lists and replaces the entries of the previous method.
        error_2[nslice] = []
        error_inf[nslice] = []

        for n in nlist:
            # Define the magnetic field using new poloidal gridding method
            # Note that the Bz and Bzprime parameters here must be the same as in mms.py
            field = zb.field.Slab(Bz=0.05, Bzprime=0.1)
            # Create rectangular poloidal grids
            poloidal_grid = zb.poloidal_grid.RectangularPoloidalGrid(
                nx, n, 0.1, 1.0, MXG=1
            )

            # Set the y locations
            if yperiodic:
                ycoords = linspace(0.0, ylength, n, endpoint=False)
            else:
                # Doesn't include the end points
                ycoords = (arange(n) + 0.5) * ylength / float(n)

            # Create the grid
            grid = zb.grid.Grid(poloidal_grid, ycoords, ylength, yperiodic=yperiodic)
            # Make and write maps
            maps = zb.make_maps(grid, field, nslice=nslice, quiet=True, MXG=1)
            zb.write_maps(
                grid,
                field,
                maps,
                new_names=False,
                metric2d=conf.isMetric2D(),
                quiet=True,
            )

            # NXPE is 2 only for hermitespline when PETSc is available,
            # otherwise 1
            nxpe = 2 if conf.has["petsc"] and method == "hermitespline" else 1

            args = " MZ={} MYG={} mesh:paralleltransform:y_periodic={} mesh:ddy:first={} NXPE={}".format(
                n,
                nslice,
                yperiodic,
                method_orders[nslice]["name"],
                nxpe,
            )
            args += f" mesh:paralleltransform:xzinterpolation:type={method}"

            # Command to run
            cmd = "./fci_mms " + args

            print("Running command: " + cmd)

            # Launch using MPI
            s, out = launch_safe(cmd, nproc=nproc, mthread=mthread, pipe=True)

            # Save output to log file
            # NOTE: the name depends on n only, so the log of an earlier
            # method at the same resolution is overwritten.
            with open("run.log." + str(n), "w") as f:
                f.write(out)

            if s:
                print("Run failed!\nOutput was:\n")
                print(out)
                # Raise SystemExit directly instead of calling the exit()
                # helper, which only exists when the site module is loaded.
                raise SystemExit(s)

            # Collect data
            l_2 = collect(
                "l_2",
                tind=[1, 1],
                info=False,
                path=directory,
                xguards=False,
                yguards=False,
            )
            l_inf = collect(
                "l_inf",
                tind=[1, 1],
                info=False,
                path=directory,
                xguards=False,
                yguards=False,
            )

            error_2[nslice].append(l_2)
            error_inf[nslice].append(l_inf)

            print("Errors : l-2 {:f} l-inf {:f}".format(l_2, l_inf))

        # Calculate convergence order from a straight-line fit in log-log space
        fit = polyfit(log(dx), log(error_2[nslice]), 1)
        fit_order = fit[0]
        stdout.write("Convergence order = {:f} (fit)".format(fit_order))

        # ... and from the two finest resolutions only
        small_spacing_order = log(error_2[nslice][-2] / error_2[nslice][-1]) / log(
            dx[-2] / dx[-1]
        )
        stdout.write(", {:f} (small spacing)".format(small_spacing_order))

        # Should be close to the expected order
        if small_spacing_order > method_orders[nslice]["order"] * 0.95:
            print("............ PASS\n")
        else:
            print("............ FAIL\n")
            success = False
            # Record the interpolation method as well as the difference
            # scheme: the scheme name alone is identical for every method,
            # so it cannot tell the failing cases apart.
            failures.append(
                "{} ({})".format(method, method_orders[nslice]["name"])
            )


# Store the results as consecutive pickle records in a single file: first the
# list of resolutions, then the l-2 and l-inf error lists for each slice count.
with open("fci_mms.pkl", "wb") as output:
    records = [nlist]
    for nslice in nslices:
        records.append(error_2[nslice])
        records.append(error_inf[nslice])

    # One dump call per record, so each can be read back with its own load
    for record in records:
        pickle.dump(record, output)

# Whether to open the figure interactively in addition to saving it to file.
showPlot = True

# Plotting is currently switched off by the constant condition below.
if False:
    try:
        # Plot using matplotlib if available
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(1, 1)

        for nslice in nslices:
            scheme = method_orders[nslice]["name"]
            # (error values, line style, legend label) for each norm
            curves = (
                (error_2[nslice], "-", f"{scheme} $l_2$"),
                (error_inf[nslice], "--", f"{scheme} $l_\\inf$"),
            )
            for errors, linestyle, label in curves:
                ax.plot(dx, errors, linestyle, label=label)

        ax.legend(loc="upper left")
        ax.grid()

        # Log-log axes
        ax.set_yscale("log")
        ax.set_xscale("log")

        ax.set_title("error scaling")
        ax.set_xlabel(r"Mesh spacing $\delta x$")
        ax.set_ylabel("Error norm")

        plt.savefig("fci_mms.pdf")
        print("Plot saved to fci_mms.pdf")

        if showPlot:
            plt.show()
        plt.close()
    except ImportError:
        print("No matplotlib")

# Report the overall result and terminate with exit status 0 on success or 1
# on failure. SystemExit is raised directly rather than calling exit(), which
# is a helper installed by the site module and is not guaranteed to exist.
if success:
    print("All tests passed")
    raise SystemExit(0)

print("Some tests failed:")
for failure in failures:
    print("\t" + failure)
raise SystemExit(1)
