#!/usr/bin/env python3
#
# Python script to run the FCI MPI test and check that the results
# do not depend on the processor decomposition (NXPE x NYPE)
#

# Cores: 8
# requires: metric_3d

from boututils.run_wrapper import build_and_log, launch_safe, shell_safe
from boutdata.collect import collect
import boutconfig as conf
import itertools

import numpy as np

# Numbers of processes to try in x (NXPE) and in y (NYPE); every
# combination whose product fits within maxcores is run.
nlist = [1, 2, 4]

# Upper limit on NXPE * NYPE; also used to pick the thread count per process.
maxcores = 8

# For each entry n, the outputs with offsets -n..-1 and +1..+n are compared.
# NOTE(review): nslice is not passed on to ./fci_mpi; it only selects the
# variable names and the log file name - confirm this matches the input file.
nslices = [1]

# Cleared as soon as any decomposition disagrees with the reference run.
success = True

build_and_log("FCI MMS test")

# Options shared by every collect() call: guard cells are excluded so that
# arrays from different decompositions can be compared directly.
collect_kw = dict(info=False, xguards=False, yguards=False, path="data")

for nslice in nslices:
    # Reference data, filled in by the NXPE = NYPE = 1 run.
    ref = None

    for NXPE, NYPE in itertools.product(nlist, nlist):
        nproc = NXPE * NYPE
        if nproc > maxcores:
            continue

        args = f"NXPE={NXPE} NYPE={NYPE}"
        # Command to run
        cmd = f"./fci_mpi {args}"

        print(f"Running command: {cmd}")

        # Give each process an equal share of the available cores as threads
        mthread = maxcores // nproc
        # Launch using MPI
        _, out = launch_safe(cmd, nproc=nproc, mthread=mthread, pipe=True)

        # Save output to log file
        with open(f"run.log.{NXPE}.{NYPE}.{nslice}.log", "w") as f:
            f.write(out)

        if NXPE == NYPE == 1:
            # Single-process run: store its outputs as the reference data
            ref = {}
            for i in range(4):
                for yp in range(1, nslice + 1):
                    for y in [-yp, yp]:
                        name = f"output_{i}_{y:+d}"
                        ref[name] = collect(name, **collect_kw)
            continue

        if ref is None:
            # Without this guard a reordered nlist would surface as a NameError
            raise RuntimeError(
                "No reference data: the NXPE=1 NYPE=1 case must be run first"
            )

        # Compare explicitly rather than with `assert`, so the check survives
        # `python -O` and every mismatch is reported, not just the first one.
        for name, val in ref.items():
            data = collect(name, **collect_kw)
            if np.shape(data) != np.shape(val):
                print(
                    f"FAILED: {name} for {args}: shape {np.shape(data)} "
                    f"differs from reference shape {np.shape(val)}"
                )
                success = False
            elif not np.allclose(val, data):
                max_err = np.max(np.abs(val - data))
                print(
                    f"FAILED: {name} for {args}: differs from reference "
                    f"(max abs difference {max_err})"
                )
                success = False

if not success:
    print(" => Some failed tests")
    raise SystemExit(1)

print(" => All tests passed")
