import os
import math
import inspect
import textwrap
from pathlib import Path

import black

# NOTE: typeinfer and lowering are expected to use the default implementations, so no code is generated for them here.


def builtin_math_functions():
    """Yield ``(name, obj, signature)`` for each supported builtin in ``math``.

    Members are visited in ``inspect.getmembers`` order. A member is yielded
    only if it is a builtin (``inspect.isbuiltin``), is not in the skip list
    below, and has an introspectable signature. Builtins for which
    ``inspect.signature`` raises ``ValueError`` or ``TypeError`` (no
    machine-readable signature available) are skipped.
    """
    # Special cases skipped for now: mostly functions whose arguments or
    # results are not plain floats. "cbrt" and "exp2" are excluded so the
    # generated code stays compatible with Python 3.10.
    skipped = frozenset(
        {
            "prod",
            "perm",
            "modf",
            "ldexp",
            "lcm",
            "isqrt",
            "isclose",
            "gcd",
            "fsum",
            "frexp",
            "factorial",
            "acosh",
            "comb",
            "dist",
            "sumprod",
            "nextafter",
            # 3.10 compat
            "cbrt",
            "exp2",
        }
    )

    for name, obj in inspect.getmembers(math):
        if name in skipped or not inspect.isbuiltin(obj):
            continue
        try:
            sig = inspect.signature(obj)
        except (ValueError, TypeError):
            # No signature can be provided for this builtin; skip it.
            continue
        # Yield outside the try block: a handler around `yield` would swallow
        # GeneratorExit (from close()) or exceptions thrown in by the consumer.
        yield name, obj, sig


# Emit stmts.py: one `ir.Statement` subclass per supported math builtin, with
# one Float-typed SSA argument per parameter and a single Float result.
with open(os.path.join(os.path.dirname(__file__), "stmts.py"), "w") as out:
    out.write(
        "# This file is generated by gen.py\n"
        "from kirin import ir, types, lowering2\n"
        "from kirin.decl import statement, info\n"
        "from kirin.dialects.math.dialect import dialect\n"
        "\n"
    )
    for fn_name, _fn, fn_sig in builtin_math_functions():
        arg_decls = "\n".join(
            f"    {param} : ir.SSAValue = info.argument(types.Float)"
            for param in fn_sig.parameters
        )
        class_lines = [
            "\n",
            "@statement(dialect=dialect)\n",
            f"class {fn_name}(ir.Statement):\n",
            f'    """{fn_name} statement, wrapping the math.{fn_name} function\n',
            '    """\n',
            f'    name = "{fn_name}"\n',
            "    traits = frozenset({ir.Pure(), lowering2.FromPythonCall()})\n",
            f"{arg_decls}\n",
            "    result: ir.ResultValue = info.result(types.Float)\n",
        ]
        out.write("".join(class_lines))


# Emit interp.py: one interpreter method per generated statement, forwarding
# the statement's argument values positionally to the matching `math`
# function. All methods are collected on one MethodTable registered with the
# dialect.
with open(os.path.join(os.path.dirname(__file__), "interp.py"), "w") as out:
    out.write(
        "# This file is generated by gen.py\n"
        "import math\n"
        "from kirin.dialects.math.dialect import dialect\n"
        "from kirin.dialects.math import stmts\n"
        "from kirin.interp import MethodTable, Frame, impl\n"
        "\n"
    )

    method_sources = []
    for fn_name, _fn, fn_sig in builtin_math_functions():
        # values[0], values[1], ... -- one positional argument per parameter.
        call_args = ", ".join(
            f"values[{position}]" for position in range(len(fn_sig.parameters))
        )
        method_sources.append(
            "\n"
            f"    @impl(stmts.{fn_name})\n"
            f"    def {fn_name}(self, interp, frame: Frame, stmt: stmts.{fn_name}):\n"
            "        values = frame.get_values(stmt.args)\n"
            f"        return (math.{fn_name}({call_args}),)"
        )

    class_header = "\n@dialect.register\nclass MathMethodTable(MethodTable):\n"
    out.write(class_header + "\n\n".join(method_sources) + "\n")

# __init__.py: package docstring, re-exports of the dialect and generated
# submodules, the `math` constants, and one typed Python stub per generated
# statement, decorated with `lowering2.wraps(stmts.<name>)`.
with open(os.path.join(os.path.dirname(__file__), "__init__.py"), "w") as f:
    # The docstring needs its own trailing newline; without it the header
    # comment written next would end up on the docstring's line.
    f.write('"math dialect, modeling functions in python\'s `math` stdlib"\n')
    f.write("# This file is generated by gen.py\n")
    f.write("from kirin.dialects.math.dialect import dialect as dialect\n")
    f.write("from . import stmts as stmts, interp as interp\n")
    f.write("import math as pymath\n")
    f.write("pi = pymath.pi\n")
    f.write("e = pymath.e\n")
    f.write("tau = pymath.tau\n")
    f.write("from kirin import lowering2\n")

    for name, _obj, sig in builtin_math_functions():
        # Every parameter is annotated as float, matching the Float-typed
        # arguments declared on the generated statements.
        params = ", ".join(f"{arg}: float" for arg in sig.parameters)
        f.write(
            f"\n@lowering2.wraps(stmts.{name})\n"
            f"def {name}({params}) -> float: ...\n"
        )
    f.write("\n")

# Reformat the generated package files in place using black's default mode
# (`black.FileMode()`; the project's pyproject.toml settings are NOT read
# here). `format_file_in_place` only checks formatting unless `write_back`
# is `WriteBack.YES`, so it must be passed explicitly to rewrite the files.
for filename in ["__init__.py", "interp.py", "stmts.py"]:
    black.format_file_in_place(
        Path(os.path.join(os.path.dirname(__file__), filename)),
        fast=False,
        mode=black.FileMode(),
        write_back=black.WriteBack.YES,
    )


# import math as pymath

# from kirin.compile import compile
# from kirin.dialects import math


# # print(math.sin(x=TestValue()))
# # print(inspect.getargspec(math.sin.__init__))
# # print(math.sin.__init__)
# @basic
# def complicated_math_expr(x):
#     return math.sin(math.cos(x) + math.tan(0.5))


# def test_math():
#     complicated_math_expr.code.print()
#     complicated_math_expr.narrow_types()
#     truth = pymath.sin(pymath.cos(1) + pymath.tan(0.5))
#     assert (complicated_math_expr(1) - truth) / truth < 1e-6

# test_basic.py: for each supported math builtin, emit a @basic kernel that
# calls `math.<name>` plus a test comparing it against CPython's `math` with
# 0.42 passed for every argument.
# Five levels up from this file (presumably the repository root, which holds
# the `test/` tree -- confirm if this file moves).
project_dir = Path(__file__).parent.parent.parent.parent.parent
with open(project_dir.joinpath("test", "dialects", "math", "test_basic.py"), "w") as f:
    f.write("# type: ignore\n")
    f.write("# This file is generated by gen.py\n")
    f.write("import math as pymath\n")
    f.write("from kirin.prelude import basic\n")
    f.write("from kirin.dialects import math\n")
    f.write("\n")
    f.write("\n")

    for name, _obj, sig in builtin_math_functions():
        args = ", ".join(sig.parameters)
        inputs = ", ".join("0.42" for _param in sig.parameters)

        # The generated check uses abs(): a signed difference would let any
        # result smaller than `truth` pass.
        f.write(
            textwrap.dedent(
                f"""
                @basic
                def {name}_func({args}):
                    return math.{name}({args})

                def test_{name}():
                    truth = pymath.{name}({inputs})
                    assert abs({name}_func({inputs}) - truth) < 1e-6
                """
            )
        )
