# Copyright 2021-2025 Xing Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utility functions
"""
from functools import partial
from pyscfad import ops
from pyscfad import pytree

def pytree_node(leaf_names, num_args=0, exclude_aux_name=()):
    """Build a class decorator that registers the decorated class as a pytree.

    See `jax document <https://jax.readthedocs.io/en/latest/pytrees.html>`_
    for the definition of pytrees.

    Parameters
    ----------
    leaf_names : list or tuple
        Attributes of the class that are traced as pytree leaves.
    num_args : int, optional
        Number of positional arguments in ``leaf_names``.
        This is useful when the ``__init__`` method of the class
        has positional arguments that are named differently than
        the actual attribute names. Default value is 0.
    exclude_aux_name : tuple, default=()
        Names of static (auxiliary) attributes that are ignored when
        comparing pytrees. Note that ``jax.jit`` recompiles the function
        for input pytrees with different static attribute values.

    Returns
    -------
    decorator : functools.partial
        ``pytree.class_as_pytree_node`` with the three options above bound
        as keyword arguments; applying it to a class performs the
        registration.

    Notes
    -----
    The ``__init__`` method of the class can't have positional arguments
    that are not included in ``leaf_names``. If ``num_args`` is greater
    than 0, the sequence of positional arguments in ``leaf_names`` must
    follow that in the ``__init__`` method.
    """
    # Collect the registration options once; the decorated class is the
    # only remaining (positional) argument supplied at decoration time.
    options = {
        "leaf_names": leaf_names,
        "num_args": num_args,
        "exclude_aux_name": exclude_aux_name,
    }
    return partial(pytree.class_as_pytree_node, **options)

def to_pyscf(obj, nocopy_names=(), out=None):
    """Convert the pyscfad object to its pyscf counterpart.

    The conversion effectively removes the tracing of the object
    and its members.

    Parameters
    ----------
    obj : object
        The pyscfad object to be converted.
    nocopy_names : tuple, default=()
        Names of attributes that are not copied to the pyscf object.
    out : object, optional
        The target pyscf object. If not given, an instance of the class
        with the same name is created from the module whose path is
        ``obj.__module__`` with the leading ``pyscfad`` replaced by ``pyscf``.

    Returns
    -------
    out : object
        The pyscf object with the shared attributes copied from ``obj``.
        If ``obj`` already lives in a ``pyscf.`` module it is returned
        as-is (``out`` is ignored in that case).

    Notes
    -----
    Member arrays will be converted (whether a copy is made depends on
    the implementation of ``__array__`` function) to numpy arrays.
    Members providing a ``to_pyscf`` method are converted by calling it.
    Only attributes present in ``obj.__dict__`` and also known to the
    target (in its ``__dict__`` or in a ``_keys`` attribute of any class
    along its MRO, excluding ``object``) are copied.
    """
    if obj.__module__.startswith("pyscf."):
        return obj

    if out is None:
        from importlib import import_module
        from pyscf.lib.misc import omniobj, set_class

        def _pyscf_counterpart(klass):
            # Look up the class of the same name in the matching pyscf module.
            # Only the first occurrence (the leading package name) is swapped,
            # so sub-module names containing "pyscfad" are left intact.
            mod = import_module(klass.__module__.replace("pyscfad", "pyscf", 1))
            return getattr(mod, klass.__name__)

        cls = _pyscf_counterpart(obj.__class__)
        if cls.__name__ == "_DFHF":
            # NOTE(review): assumes index 2 of the pyscfad DF class's MRO is
            # the underlying mean-field class -- confirm against the pyscfad
            # class hierarchy if it changes.
            mf_cls = _pyscf_counterpart(obj.__class__.__mro__[2])
            # Need to initialize the mean-field object to get the attributes.
            out = cls(mf_cls(omniobj))
            out = set_class(out, (cls, mf_cls))
        else:
            out = cls(omniobj)

    # Attribute names the target knows about: its instance attributes plus
    # every ``_keys`` declared along its MRO (``object`` excluded).
    known_keys = set(out.__dict__).union(
        *(getattr(klass, "_keys", ()) for klass in out.__class__.__mro__[:-1])
    )
    # Only overwrite the attributes of the same name.
    keys = (set(obj.__dict__) & known_keys) - set(nocopy_names)

    # Sorted so the copy order is reproducible across runs (set iteration
    # order of strings depends on hash randomization).
    for key in sorted(keys):
        val = getattr(obj, key)
        if ops.is_array(val):
            val = ops.to_numpy(val)
        elif hasattr(val, "to_pyscf"):
            val = val.to_pyscf()
        setattr(out, key, val)
    return out

def is_tracer(a):
    """Check whether an object is a tracer.

    The check is purely name based: ``a`` is reported as a tracer when the
    name of its class, or of any class in that class's MRO, ends with
    ``"Tracer"``.

    Parameters
    ----------
    a : object
        The object to be tested.

    Returns
    -------
    bool
        ``True`` if a class named ``*Tracer`` appears in the MRO of
        ``a.__class__``, otherwise ``False``.

    Notes
    -----
    Only works for the jax backend.
    """
    for klass in a.__class__.__mro__:
        if klass.__name__.endswith("Tracer"):
            return True
    return False
