import zoidberg as zb
import numpy as np
import xarray as xr
from mms_helper import lst

# Amplitude of the non-uniform theta spacing produced by var1..var4.
nonlin = 0.8
# Each varN maps p -> p + nonlin * g(p) with |g'(p)| <= 1, so its derivative
# is at least 1 - |nonlin|.  The spacing therefore keeps the points strictly
# ordered only while |nonlin| < 1 (negative values are constrained too).
# An explicit raise is used so the check is not stripped by ``python -O``.
if not abs(nonlin) < 1:
    raise ValueError(f"nonlin must satisfy |nonlin| < 1, got {nonlin!r}")


def var1(*args):
    """Return ``np.linspace(*args)`` with every point p moved to p + nonlin*sin(p).

    The shift vanishes at multiples of pi, so an interval [0, 2*pi] keeps its
    end points.  The points only concentrate in some regions.
    """
    base = np.linspace(*args)
    offset = nonlin * np.sin(base)
    return base + offset


def var2(*args):
    """Return ``np.linspace(*args)`` with every point p moved to p + nonlin*cos(p).

    Unlike ``var1`` this also shifts the first point (cos(0) == 1), which is
    harmless for a periodic angle.
    """
    base = np.linspace(*args)
    offset = nonlin * np.cos(base)
    return base + offset


def var3(*args):
    """Return ``np.linspace(*args)`` with every point p moved to p + nonlin*cos(2p)/2.

    The doubled frequency gives two regions of clustering per 2*pi period.
    """
    base = np.linspace(*args)
    offset = np.cos(2 * base) / 2 * nonlin
    return base + offset


def var4(*args):
    """Return ``np.linspace(*args)`` with every point p moved to p + nonlin*sin(p)/2.

    This is the same shape as ``var1`` at half the amplitude, i.e. a milder
    non-uniformity.
    """
    base = np.linspace(*args)
    offset = np.sin(base) / 2 * nonlin
    return base + offset


# Available theta-spacing functions as (label, function) pairs.  The integer
# ``mode`` used by gen_name/gen_grid is an index into this list; the label is
# embedded in the output filename.  Every function has np.linspace's signature.
modes = list(
    zip(
        ("const", "var1", "var2", "var3", "var4"),
        (np.linspace, var1, var2, var3, var4),
    )
)


def gen_name(*args):
    """Build the output filename for the grid described by ``args``.

    ``args`` must be exactly ``(nx, ny, nz, R0, r0, r1, mode)``, the same
    tuple accepted by ``gen_grid``.  ``mode`` indexes ``modes`` and its label
    is embedded in the name.  ``r0`` and ``r1`` do not appear in the name, so
    grids that differ only in radial extent would map to the same file.
    """
    nx, ny, nz, R0, _r0, _r1, mode = args
    label, _spacing = modes[mode]
    sizes = f"{nx}_{ny}_{nz}_{R0}"
    return f"poloidal_{label}_{sizes}.fci.nc"


def gen_grid(*args):
    """Generate one FCI grid file with a circular poloidal cross-section.

    ``args`` must be exactly ``(nx, ny, nz, R0, r0, r1, mode)``:

    - ``nx`` minor-radius points, uniformly spaced over ``[r0, r1]``;
    - ``ny`` toroidal planes ``phi``, uniformly spaced over ``[0, 2*pi/5)``;
    - ``nz`` poloidal angles ``theta`` over ``[0, 2*pi)``, spaced by the
      function ``modes[mode][1]``;
    - ``R0`` is the major radius of the circle centres.

    Each plane uses ``R = R0 + r*cos(theta)`` and ``Z = r*sin(theta)``.
    zoidberg computes field-line maps for a ``CurvedSlab`` field with
    ``Bz=0`` and ``Bzprime=0``.  The resulting file is then augmented with
    ``r_minor``, ``phi``, ``theta``, a field of ones, and zeroed
    ``g12``/``g23``/``g_12``/``g_23`` components.  It is written to
    ``gen_name(*args)`` relative to the current directory.
    """
    import os
    import tempfile

    nx, ny, nz, R0, r0, r1, mode = args
    spacing = modes[mode][1]
    # The y (toroidal) domain covers one fifth of the torus.
    ly = 2 * np.pi / 5

    one = np.ones((nx, ny, nz))
    r = np.linspace(r0, r1, nx)[:, None]  # shape (nx, 1)
    theta = spacing(0, 2 * np.pi, nz, False)[None, :]  # shape (1, nz)
    phi = np.linspace(0, ly, ny, False)
    R = R0 + np.cos(theta) * r
    Z = np.sin(theta) * r
    pol_grid = zb.poloidal_grid.StructuredPoloidalGrid(R, Z)

    field = zb.field.CurvedSlab(Bz=0, Bzprime=0, Rmaj=R0)
    # zoidberg's Grid takes the length of the y domain (its period when
    # yperiodic=True).  It must match the extent spanned by ``phi``; the
    # previous hard-coded 5 did not.
    grid = zb.grid.Grid(pol_grid, phi, ly, yperiodic=True)

    fn = gen_name(*args)

    maps = zb.make_maps(grid, field, quiet=True, MXG=1)
    # Stage zoidberg's raw output in a private temporary directory.
    # Concurrent runs then cannot clobber each other's intermediate file,
    # and nothing is left behind in the working directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_fn = os.path.join(tmpdir, "tmp.nc")
        zb.write_maps(
            grid,
            field,
            maps,
            tmp_fn,
            metric2d=False,
        )
        with xr.open_dataset(tmp_fn) as ds:
            # NOTE(review): assumes ds.dz is 3D with dims ordered like
            # ``one`` (x, y, z), as expected with metric2d=False -- confirm.
            dims = ds.dz.dims
            # Broadcast the 1D coordinates over the full (nx, ny, nz) shape.
            ds["r_minor"] = dims, one * r[:, None, :]
            ds["phi"] = "y", phi
            ds["theta"] = dims, one * theta[:, None, :]
            ds["one"] = dims, one
            for key in ("g12", "g23", "g_12", "g_23"):
                ds[key] = dims, np.zeros_like(one)
            ds.to_netcdf(fn)


def _togen(*args):
    """Return ``(filename, args)``, pairing ``gen_grid``'s arguments with its output name."""
    filename = gen_name(*args)
    return (filename, args)


# Every grid this script produces, keyed by the spacing label from ``modes``.
# Each value holds one ``(filename, gen_grid_args, nz)`` entry per nz in
# ``lst``.  The fixed arguments are nx=4, ny=2, R0=1, r0=0.1 and r1=0.5.
# Iterating over ``modes`` itself keeps this in sync when modes are added,
# and avoids leaking a module-level loop variable.
grids = {
    label: [_togen(4, 2, nz, 1, 0.1, 0.5, idx) + (nz,) for nz in lst]
    for idx, (label, _spacing) in enumerate(modes)
}

if __name__ == "__main__":
    # Write every grid registered in ``grids``.  Each entry is
    # (filename, gen_grid arguments, nz); only the arguments are needed here.
    for entries in grids.values():
        for _filename, grid_args, _nz in entries:
            gen_grid(*grid_args)
