# Stats
- 61 files
- 12504 (13K) lines
- 430432 (430K) chars
- 46022 (46K) `whitespace-split` tokens

# File Tree

```
muutils                                  
├── .github                              
│   └── workflows                        
│    ├── checks.yml                      [  110L  2,773C   301T]
│    └── make-docs.yml                   [   48L  1,261C   140T]
├── muutils                              
│   ├── json_serialize                   
│   │   ├── __init__.py                  [   51L  2,416C   297T]
│   │   ├── array.py                     [  226L  7,956C   789T]
│   │   ├── dataclass_transform_mock.py  [   29L    817C    69T]
│   │   ├── json_serialize.py            [  333L 11,880C 1,135T]
│   │   ├── serializable_dataclass.py    [  905L 35,786C 3,564T]
│   │   ├── serializable_field.py        [  308L 12,195C 1,247T]
│   │   └── util.py                      [  281L  9,185C   983T]
│   ├── logger                           
│   │   ├── __init__.py                  [   30L    684C    55T]
│   │   ├── exception_context.py         [   43L  1,183C   110T]
│   │   ├── headerfuncs.py               [   68L  1,695C   207T]
│   │   ├── log_util.py                  [   81L  2,112C   268T]
│   │   ├── logger.py                    [  306L 10,726C 1,091T]
│   │   ├── loggingstream.py             [   95L  3,856C   412T]
│   │   ├── simplelogger.py              [   81L  2,173C   239T]
│   │   └── timing.py                    [   87L  2,613C   257T]
│   ├── math                             
│   │   ├── __init__.py                  [    4L     47C     6T]
│   │   ├── bins.py                      [   67L  2,143C   165T]
│   │   └── matrix_powers.py             [  164L  5,323C   649T]
│   ├── misc                             
│   │   ├── __init__.py                  [   83L  1,941C   153T]
│   │   ├── b64_decode.py                [    9L    276C    30T]
│   │   ├── classes.py                   [   97L  3,425C   409T]
│   │   ├── freezing.py                  [  121L  3,698C   371T]
│   │   ├── func.py                      [  277L  8,982C   931T]
│   │   ├── hashing.py                   [   38L  1,049C   121T]
│   │   ├── numerical.py                 [  165L  4,601C   522T]
│   │   ├── sequence.py                  [  234L  7,251C   877T]
│   │   └── string.py                    [  108L  3,008C   349T]
│   ├── nbutils                          
│   │   ├── __init__.py                  [   21L    488C    51T]
│   │   ├── configure_notebook.py        [  320L  9,792C 1,038T]
│   │   ├── convert_ipynb_to_script.py   [  374L 13,400C 1,219T]
│   │   ├── mermaid.py                   [   20L    571C    52T]
│   │   ├── print_tex.py                 [   21L    495C    70T]
│   │   └── run_notebook_tests.py        [  255L  9,284C   826T]
│   ├── web                              
│   │   ├── __init__.py                  [    3L     33C     5T]
│   │   └── bundle_html.py               [  388L 13,088C 1,271T]
│   ├── __init__.py                      [   34L    544C    43T]
│   ├── collect_warnings.py              [  132L  4,073C   375T]
│   ├── console_unicode.py               [   34L  1,070C   133T]
│   ├── dbg.py                           [  516L 15,967C 1,750T]
│   ├── dictmagic.py                     [  522L 18,132C 1,968T]
│   ├── errormode.py                     [  241L  7,939C   816T]
│   ├── group_equiv.py                   [   66L  2,060C   246T]
│   ├── interval.py                      [  532L 18,188C 1,720T]
│   ├── jsonlines.py                     [   77L  1,993C   246T]
│   ├── kappa.py                         [   46L  1,260C   150T]
│   ├── mlutils.py                       [  167L  5,214C   497T]
│   ├── parallel.py                      [  279L  9,245C 1,091T]
│   ├── py.typed                         [    0L      0C     0T]
│   ├── spinner.py                       [  511L 17,879C 1,688T]
│   ├── statcounter.py                   [  231L  7,373C   768T]
│   ├── sysinfo.py                       [  210L  7,178C   514T]
│   ├── tensor_info.py                   [  651L 21,853C 2,145T]
│   ├── tensor_utils.py                  [  496L 14,934C 1,437T]
│   ├── timeit_fancy.py                  [  107L  3,900C   482T]
│   └── validate_type.py                 [  237L  8,401C   896T]
├── LICENSE                              [  674L 35,149C 5,644T]
├── README.md                            [  131L  6,356C   551T]
├── makefile                             [1,695L 52,472C 6,256T]
├── pyproject.toml                       [  224L  6,583C   752T]
```

# File Contents

``````{ path=".github/workflows/checks.yml"  }
name: Checks

on:
  pull_request:
    branches:
      - main
      - "*"
  push:
    branches:
      - main

jobs:
  lint:
    name: Formatting
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with: 
          fetch-depth: 1

      - name: install format tools
        run: pip install -r .meta/requirements/requirements-lint.txt

      - name: Run Format Checks
        run: make format-check RUN_GLOBAL=1
  
  check-deps:
    name: Check dependencies
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 1

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'

      - name: set up uv
        run: curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: check dependencies
        run: make dep-check
  
  test:
    name: Test and Lint
    runs-on: ubuntu-latest
    # needs: [lint, check-deps] # for conditionally running this job
    strategy:
      matrix:
        python: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
        pkg:
          - group: "legacy"
            torch: "1.13.1"
            numpy: "1.24.4"
          - group: "latest"
            torch: ""
            numpy: ""
        exclude:
          - python: "3.12"
            pkg:
              group: "legacy"
          - python: "3.13"
            pkg:
              group: "legacy"
          - python: "3.14"
            pkg:
              group: "legacy"
    
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with: 
          fetch-depth: 1

      - name: Set up python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}

      - name: set up uv
        run: curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: install
        run: make setup

      - name: Install different pytorch version
        if: ${{ matrix.pkg.torch != '' && matrix.python != '3.14' }}
        run: |
          uv pip install torch==${{ matrix.pkg.torch }}+cpu --extra-index-url https://download.pytorch.org/whl/cpu
      
      - name: Install different numpy version
        if: ${{ matrix.pkg.numpy != '' }}
        run: uv pip install numpy==${{ matrix.pkg.numpy }}
    
      - name: tests
        run: make test UV_NOSYNC=1

      # - name: tests in strict mode
      #   # TODO: until zanj ported to 3.8 and 3.9
      #   if: ${{ matrix.python != '3.8' && matrix.python != '3.9' }}
      #   run: make test WARN_STRICT=1

      - name: check typing
        # TODO[torch-python-3.14]: lack of torch causes mypy issues in 3.14
        if: ${{ matrix.python != '3.14' }}
        run: make typing UV_NOSYNC=1
``````{ end_of_file=".github/workflows/checks.yml" }

``````{ path=".github/workflows/make-docs.yml"  }
# this workflow partially copied from
# https://github.com/TransformerLensOrg/TransformerLens/blob/main/.github/workflows/checks.yml
name: make docs

on:
  pull_request:
    branches:
      - main
      - "*"
  push:
    branches:
      - main

jobs:
  build-docs:
    # When running on a PR, this just checks we can build the docs without errors
    # When running on merge to main, it builds the docs and then another job deploys them
    name: 'Build Docs'
    runs-on: ubuntu-latest
    if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev') || contains(github.head_ref, 'docs')
    steps:
      - name: Install pandoc
        uses: awalsh128/cache-apt-pkgs-action@latest
        with:
          packages: pandoc
          version: '3.3'
      
      - name: Check pandoc version
        run: pandoc --version

      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.13'

      - name: set up uv
        run: curl -LsSf https://astral.sh/uv/install.sh | sh
  
      - name: Install
        run: make setup

      - name: Build Docs
        run: make docs
``````{ end_of_file=".github/workflows/make-docs.yml" }

``````{ path="muutils/json_serialize/__init__.py"  }
"""submodule for serializing things to json in a recoverable way

you can throw *any* object into `muutils.json_serialize.json_serialize`
and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mapping strings to `JSONitem`s.

The goal of this is: if you just want to store something as relatively human-readable JSON and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).

it will do so by looking in `DEFAULT_HANDLERS`, which will keep the object as-is if it's already valid, then try to find a `.serialize()` method on the object, and then fall back to a number of special cases. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre`

additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and using `serializable_field` in place of `dataclasses.field` when defining non-standard fields.

This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.

"""

from __future__ import annotations

from muutils.json_serialize.array import arr_metadata, load_array
from muutils.json_serialize.json_serialize import (
    BASE_HANDLERS,
    JsonSerializer,
    json_serialize,
)
from muutils.json_serialize.serializable_dataclass import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from muutils.json_serialize.util import try_catch, JSONitem, dc_eq

__all__ = [
    # submodules
    "array",
    "json_serialize",
    "serializable_dataclass",
    "serializable_field",
    "util",
    # imports
    "arr_metadata",
    "load_array",
    "BASE_HANDLERS",
    "JSONitem",
    "JsonSerializer",
    "json_serialize",
    "try_catch",
    "JSONitem",
    "dc_eq",
    "serializable_dataclass",
    "serializable_field",
    "SerializableDataclass",
]

``````{ end_of_file="muutils/json_serialize/__init__.py" }
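
The module docstring above sketches the intended entry points; as a quick orientation, here is a minimal usage sketch of `json_serialize` together with a serializable dataclass (the `Point` class is hypothetical, purely for illustration):

```python
# minimal sketch of the usage described in the docstring above;
# `Point` is a hypothetical example class, not part of muutils
from muutils.json_serialize import (
    SerializableDataclass,
    json_serialize,
    serializable_dataclass,
)


@serializable_dataclass
class Point(SerializableDataclass):
    x: int
    y: int


# arbitrary nested structures are reduced to plain JSON-compatible values
print(json_serialize({"points": [Point(x=1, y=2)], "note": "anything goes"}))
```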

``````{ path="muutils/json_serialize/array.py"  }
"""this utilities module handles serialization and loading of numpy and torch arrays as json

- `array_list_meta` is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability.
- `array_b64_meta` is the most efficient, but is not human readable.
- `external` is mostly for use in [`ZANJ`](https://github.com/mivanit/ZANJ)

"""

from __future__ import annotations

import base64
import typing
import warnings
from typing import Any, Iterable, Literal, Optional, Sequence

try:
    import numpy as np
except ImportError as e:
    warnings.warn(
        f"numpy is not installed, array serialization will not work: \n{e}",
        ImportWarning,
    )

from muutils.json_serialize.util import _FORMAT_KEY, JSONitem

# pylint: disable=unused-argument

ArrayMode = Literal[
    "list",
    "array_list_meta",
    "array_hex_meta",
    "array_b64_meta",
    "external",
    "zero_dim",
]


def array_n_elements(arr) -> int:  # type: ignore[name-defined]
    """get the number of elements in an array"""
    if isinstance(arr, np.ndarray):
        return arr.size
    elif str(type(arr)) == "<class 'torch.Tensor'>":
        return arr.nelement()
    else:
        raise TypeError(f"invalid type: {type(arr)}")


def arr_metadata(arr) -> dict[str, list[int] | str | int]:
    """get metadata for a numpy array"""
    return {
        "shape": list(arr.shape),
        "dtype": (
            arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype)
        ),
        "n_elements": array_n_elements(arr),
    }


def serialize_array(
    jser: "JsonSerializer",  # type: ignore[name-defined] # noqa: F821
    arr: np.ndarray,
    path: str | Sequence[str | int],
    array_mode: ArrayMode | None = None,
) -> JSONitem:
    """serialize a numpy or pytorch array in one of several modes

    if the object is zero-dimensional, simply get the unique item

    `array_mode: ArrayMode` can be one of:
    - `list`: serialize as a list of values, no metadata (equivalent to `arr.tolist()`)
    - `array_list_meta`: serialize dict with metadata, actual list under the key `data`
    - `array_hex_meta`: serialize dict with metadata, actual hex string under the key `data`
    - `array_b64_meta`: serialize dict with metadata, actual base64 string under the key `data`

    for `array_list_meta`, `array_hex_meta`, and `array_b64_meta`, the serialized object is:
    ```
    {
        _FORMAT_KEY: <array_list_meta|array_hex_meta|array_b64_meta>,
        "shape": arr.shape,
        "dtype": str(arr.dtype),
        "data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
    }
    ```

    # Parameters:
     - `arr : Any` array to serialize
     - `array_mode : ArrayMode` mode in which to serialize the array
       (defaults to `None` and inheriting from `jser: JsonSerializer`)

    # Returns:
     - `JSONitem`
       json serialized array

    # Raises:
     - `KeyError` : if the array mode is not valid
    """

    if array_mode is None:
        array_mode = jser.array_mode

    arr_type: str = f"{type(arr).__module__}.{type(arr).__name__}"
    arr_np: np.ndarray = arr if isinstance(arr, np.ndarray) else np.array(arr)

    # handle zero-dimensional arrays
    if len(arr.shape) == 0:
        return {
            _FORMAT_KEY: f"{arr_type}:zero_dim",
            "data": arr.item(),
            **arr_metadata(arr),
        }

    if array_mode == "array_list_meta":
        return {
            _FORMAT_KEY: f"{arr_type}:array_list_meta",
            "data": arr_np.tolist(),
            **arr_metadata(arr_np),
        }
    elif array_mode == "list":
        return arr_np.tolist()
    elif array_mode == "array_hex_meta":
        return {
            _FORMAT_KEY: f"{arr_type}:array_hex_meta",
            "data": arr_np.tobytes().hex(),
            **arr_metadata(arr_np),
        }
    elif array_mode == "array_b64_meta":
        return {
            _FORMAT_KEY: f"{arr_type}:array_b64_meta",
            "data": base64.b64encode(arr_np.tobytes()).decode(),
            **arr_metadata(arr_np),
        }
    else:
        raise KeyError(f"invalid array_mode: {array_mode}")


def infer_array_mode(arr: JSONitem) -> ArrayMode:
    """given a serialized array, infer the mode

    assumes the array was serialized via `serialize_array()`
    """
    if isinstance(arr, typing.Mapping):
        fmt: str = arr.get(_FORMAT_KEY, "")  # type: ignore
        if fmt.endswith(":array_list_meta"):
            if not isinstance(arr["data"], Iterable):
                raise ValueError(f"invalid list format: {type(arr['data']) = }\t{arr}")
            return "array_list_meta"
        elif fmt.endswith(":array_hex_meta"):
            if not isinstance(arr["data"], str):
                raise ValueError(f"invalid hex format: {type(arr['data']) = }\t{arr}")
            return "array_hex_meta"
        elif fmt.endswith(":array_b64_meta"):
            if not isinstance(arr["data"], str):
                raise ValueError(f"invalid b64 format: {type(arr['data']) = }\t{arr}")
            return "array_b64_meta"
        elif fmt.endswith(":external"):
            return "external"
        elif fmt.endswith(":zero_dim"):
            return "zero_dim"
        else:
            raise ValueError(f"invalid format: {arr}")
    elif isinstance(arr, list):
        return "list"
    else:
        raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }")


def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
    """load a json-serialized array, infer the mode if not specified"""
    # return arr if it's already a numpy array
    if isinstance(arr, np.ndarray) and array_mode is None:
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])  # type: ignore
        return data.reshape(arr["shape"])  # type: ignore

    elif array_mode == "array_b64_meta":
        assert isinstance(arr, typing.Mapping), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])  # type: ignore
        return data.reshape(arr["shape"])  # type: ignore

    elif array_mode == "list":
        assert isinstance(arr, typing.Sequence), (
            f"invalid list format: {type(arr) = }\n{arr = }"
        )
        return np.array(arr)  # type: ignore
    elif array_mode == "external":
        # assume ZANJ has taken care of it
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        return arr["data"]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        data = np.array(arr["data"])
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")

``````{ end_of_file="muutils/json_serialize/array.py" }
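
As a quick round-trip sketch for the helpers above (assuming numpy is installed; `JsonSerializer` comes from the sibling `json_serialize` module shown below):

```python
# hedged round-trip sketch for serialize_array / load_array, assuming numpy is available
import numpy as np

from muutils.json_serialize.array import load_array, serialize_array
from muutils.json_serialize.json_serialize import JsonSerializer

jser = JsonSerializer(array_mode="array_b64_meta")
arr = np.arange(6, dtype=np.float32).reshape(2, 3)

ser = serialize_array(jser, arr, path=("arr",))  # dict with format key, shape, dtype, base64 data
loaded = load_array(ser)                         # mode is inferred from the format string
assert np.array_equal(arr, loaded)
```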

``````{ path="muutils/json_serialize/dataclass_transform_mock.py"  }
from __future__ import annotations

import typing
from typing import Any, Union


def dataclass_transform(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    frozen_default: bool = False,
    field_specifiers: tuple[Union[type[Any], typing.Callable[..., Any]], ...] = (),
    **kwargs: Any,
) -> typing.Callable:
    "mock `typing.dataclass_transform` for python <3.11"

    def decorator(cls_or_fn):
        cls_or_fn.__dataclass_transform__ = {
            "eq_default": eq_default,
            "order_default": order_default,
            "kw_only_default": kw_only_default,
            "frozen_default": frozen_default,
            "field_specifiers": field_specifiers,
            "kwargs": kwargs,
        }
        return cls_or_fn

    return decorator

``````{ end_of_file="muutils/json_serialize/dataclass_transform_mock.py" }

``````{ path="muutils/json_serialize/json_serialize.py"  }
"""provides the basic framework for json serialization of objects

notably:

- `SerializerHandler` defines how to serialize a specific type of object
- `JsonSerializer` handles configuration for which handlers to use
- `json_serialize` provides the default configuration if you don't care -- call it on any object!

"""

from __future__ import annotations

import inspect
import warnings
from dataclasses import dataclass, is_dataclass
from pathlib import Path
from typing import Any, Callable, Iterable, Mapping, Set, Union

from muutils.errormode import ErrorMode

try:
    from muutils.json_serialize.array import ArrayMode, serialize_array
except ImportError as e:
    ArrayMode = str  # type: ignore[misc]
    serialize_array = lambda *args, **kwargs: None  # noqa: E731
    warnings.warn(
        f"muutils.json_serialize.array could not be imported probably because missing numpy, array serialization will not work: \n{e}",
        ImportWarning,
    )

from muutils.json_serialize.util import (
    _FORMAT_KEY,
    Hashableitem,
    JSONitem,
    MonoTuple,
    SerializationException,
    _recursive_hashify,
    isinstance_namedtuple,
    safe_getsource,
    string_as_lines,
    try_catch,
)

# pylint: disable=protected-access

SERIALIZER_SPECIAL_KEYS: MonoTuple[str] = (
    "__name__",
    "__doc__",
    "__module__",
    "__class__",
    "__dict__",
    "__annotations__",
)

SERIALIZER_SPECIAL_FUNCS: dict[str, Callable] = {
    "str": str,
    "dir": dir,
    "type": try_catch(lambda x: str(type(x).__name__)),
    "repr": try_catch(lambda x: repr(x)),
    "code": try_catch(lambda x: inspect.getsource(x)),
    "sourcefile": try_catch(lambda x: inspect.getsourcefile(x)),
}

SERIALIZE_DIRECT_AS_STR: Set[str] = {
    "<class 'torch.device'>",
    "<class 'torch.dtype'>",
}

ObjectPath = MonoTuple[Union[str, int]]


@dataclass
class SerializerHandler:
    """a handler for a specific type of object

    # Parameters:
        - `check : Callable[[JsonSerializer, Any, ObjectPath], bool]` takes a JsonSerializer, an object, and the current path; returns whether to use this handler
        - `serialize_func : Callable[[JsonSerializer, Any, ObjectPath], JSONitem]` takes a JsonSerializer, an object, and the current path; returns the serialized object
        - `uid : str` unique identifier for the handler
        - `desc : str` description of the handler
    """

    # (self_config, object) -> whether to use this handler
    check: Callable[["JsonSerializer", Any, ObjectPath], bool]
    # (self_config, object, path) -> serialized object
    serialize_func: Callable[["JsonSerializer", Any, ObjectPath], JSONitem]
    # unique identifier for the handler
    uid: str
    # description of this serializer
    desc: str

    def serialize(self) -> dict:
        """serialize the handler info"""
        return {
            # get the code and doc of the check function
            "check": {
                "code": safe_getsource(self.check),
                "doc": string_as_lines(self.check.__doc__),
            },
            # get the code and doc of the load function
            "serialize_func": {
                "code": safe_getsource(self.serialize_func),
                "doc": string_as_lines(self.serialize_func.__doc__),
            },
            # get the uid, source_pckg, priority, and desc
            "uid": str(self.uid),
            "source_pckg": getattr(self.serialize_func, "source_pckg", None),
            "__module__": getattr(self.serialize_func, "__module__", None),
            "desc": str(self.desc),
        }


BASE_HANDLERS: MonoTuple[SerializerHandler] = (
    SerializerHandler(
        check=lambda self, obj, path: isinstance(
            obj, (bool, int, float, str, type(None))
        ),
        serialize_func=lambda self, obj, path: obj,
        uid="base types",
        desc="base types (bool, int, float, str, None)",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Mapping),
        serialize_func=lambda self, obj, path: {
            str(k): self.json_serialize(v, tuple(path) + (k,)) for k, v in obj.items()
        },
        uid="dictionaries",
        desc="dictionaries",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, (list, tuple)),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
        ],
        uid="(list, tuple) -> list",
        desc="lists and tuples as lists",
    ),
)


def _serialize_override_serialize_func(
    self: "JsonSerializer", obj: Any, path: ObjectPath
) -> JSONitem:
    # obj_cls: type = type(obj)
    # if hasattr(obj_cls, "_register_self") and callable(obj_cls._register_self):
    #     obj_cls._register_self()

    # get the serialized object
    return obj.serialize()


DEFAULT_HANDLERS: MonoTuple[SerializerHandler] = tuple(BASE_HANDLERS) + (
    SerializerHandler(
        # TODO: allow for custom serialization handler name
        check=lambda self, obj, path: hasattr(obj, "serialize")
        and callable(obj.serialize),
        serialize_func=_serialize_override_serialize_func,
        uid=".serialize override",
        desc="objects with .serialize method",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance_namedtuple(obj),
        serialize_func=lambda self, obj, path: self.json_serialize(dict(obj._asdict())),
        uid="namedtuple -> dict",
        desc="namedtuples as dicts",
    ),
    SerializerHandler(
        check=lambda self, obj, path: is_dataclass(obj),
        serialize_func=lambda self, obj, path: {
            k: self.json_serialize(getattr(obj, k), tuple(path) + (k,))
            for k in obj.__dataclass_fields__
        },
        uid="dataclass -> dict",
        desc="dataclasses as dicts",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Path),
        serialize_func=lambda self, obj, path: obj.as_posix(),
        uid="path -> str",
        desc="Path objects as posix strings",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) in SERIALIZE_DIRECT_AS_STR,
        serialize_func=lambda self, obj, path: str(obj),
        uid="obj -> str(obj)",
        desc="directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'numpy.ndarray'>",
        serialize_func=lambda self, obj, path: serialize_array(self, obj, path=path),
        uid="numpy.ndarray",
        desc="numpy arrays",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'torch.Tensor'>",
        serialize_func=lambda self, obj, path: serialize_array(
            self, obj.detach().cpu(), path=path
        ),
        uid="torch.Tensor",
        desc="pytorch tensors",
    ),
    SerializerHandler(
        check=lambda self, obj, path: (
            str(type(obj)) == "<class 'pandas.core.frame.DataFrame'>"
        ),
        serialize_func=lambda self, obj, path: {
            _FORMAT_KEY: "pandas.DataFrame",
            "columns": obj.columns.tolist(),
            "data": obj.to_dict(orient="records"),
            "path": path,  # type: ignore
        },
        uid="pandas.DataFrame",
        desc="pandas DataFrames",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, (set, list, tuple))
        or isinstance(obj, Iterable),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
        ],
        uid="(set, list, tuple, Iterable) -> list",
        desc="sets, lists, tuples, and Iterables as lists",
    ),
    SerializerHandler(
        check=lambda self, obj, path: True,
        serialize_func=lambda self, obj, path: {
            **{k: str(getattr(obj, k, None)) for k in SERIALIZER_SPECIAL_KEYS},
            **{k: f(obj) for k, f in SERIALIZER_SPECIAL_FUNCS.items()},
        },
        uid="fallback",
        desc="fallback handler -- serialize object attributes and special functions as strings",
    ),
)


class JsonSerializer:
    """Json serialization class (holds configs)

    # Parameters:
    - `array_mode : ArrayMode`
    how to write arrays
    (defaults to `"array_list_meta"`)
    - `error_mode : ErrorMode`
    what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
    (defaults to `"except"`)
    - `handlers_pre : MonoTuple[SerializerHandler]`
    handlers to use before the default handlers
    (defaults to `tuple()`)
    - `handlers_default : MonoTuple[SerializerHandler]`
    default handlers to use
    (defaults to `DEFAULT_HANDLERS`)
    - `write_only_format : bool`
    changes _FORMAT_KEY keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
    (defaults to `False`)

    # Raises:
    - `ValueError`: on init, if `args` is not empty
    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`

    """

    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        if isinstance(output, dict) and _FORMAT_KEY in output:
                            new_fmt: JSONitem = output.pop(_FORMAT_KEY)
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            if self.error_mode == "except":
                obj_str: str = repr(obj)
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == "warn":
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

            return repr(obj)

    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable"""
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)


GLOBAL_JSON_SERIALIZER: JsonSerializer = JsonSerializer()


def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
    """serialize object to json-serializable object with default config"""
    return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path)

``````{ end_of_file="muutils/json_serialize/json_serialize.py" }
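
To illustrate the `handlers_pre` mechanism described in the `JsonSerializer` docstring above, here is a hedged sketch of registering a custom handler ahead of the defaults (the `Secret` class and the handler are hypothetical, purely for illustration):

```python
# hypothetical custom handler registered ahead of the defaults via `handlers_pre`
from muutils.json_serialize.json_serialize import JsonSerializer, SerializerHandler


class Secret:
    "illustrative wrapper for a value we never want serialized verbatim"

    def __init__(self, value: str) -> None:
        self.value = value


redact_handler = SerializerHandler(
    check=lambda self, obj, path: isinstance(obj, Secret),
    serialize_func=lambda self, obj, path: "<redacted>",
    uid="redact-secrets",
    desc="serialize Secret objects as a redacted placeholder",
)

jser = JsonSerializer(handlers_pre=(redact_handler,))
print(jser.json_serialize({"token": Secret("hunter2"), "n": 1}))
# expected: {'token': '<redacted>', 'n': 1}
```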

``````{ path="muutils/json_serialize/serializable_dataclass.py"  }
"""save and load objects to and from json or compatible formats in a recoverable way

`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
you will get an error when you call `json.dumps(d)`. This module provides a way around that.

Instead, you define your class:

```python
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```

and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

    >>> my_obj = MyClass(a=1, b="q")
    >>> s = json.dumps(my_obj.serialize())
    >>> s
    '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
    >>> read_obj = MyClass.load(json.loads(s))
    >>> read_obj == my_obj
    True

This isn't too impressive on its own, but it gets more useful when you have nested classes,
or fields that are not json-serializable by default:

```python
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```

which gives us:

    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
    >>> s = json.dumps(nc.serialize())
    >>> s
    '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
    >>> read_nc = NestedClass.load(json.loads(s))
    >>> read_nc == nc
    True

"""

from __future__ import annotations

import abc
import dataclasses
import functools
import json
import sys
import typing
import warnings
from typing import Any, Optional, Type, TypeVar

from muutils.errormode import ErrorMode
from muutils.validate_type import validate_type
from muutils.json_serialize.serializable_field import (
    SerializableField,
    serializable_field,
)
from muutils.json_serialize.util import _FORMAT_KEY, array_safe_eq, dc_eq

# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access

# this is quite horrible, but unfortunately mypy fails if we try to assign to `dataclass_transform` directly
# and every time we try to init a serializable dataclass it says the argument doesn't exist
try:
    try:
        # type ignore here for legacy versions
        from typing import dataclass_transform  # type: ignore[attr-defined]
    except Exception:
        from typing_extensions import dataclass_transform
except Exception:
    from muutils.json_serialize.dataclass_transform_mock import dataclass_transform

T = TypeVar("T")


class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"

    pass


class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"

    pass


_zanj_loading_needs_import: bool = True
"flag to keep track of if we have successfully imported ZANJ"


def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts


    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesn't matter
            # warnings.warn(
            #     "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
            #     ZanjMissingWarning,
            # )
            return

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh


_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT


class FieldIsNotInitOrSerializeWarning(UserWarning):
    pass


def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is used as `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate; looked up in `self.__dataclass_fields__` if given as a `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(_field, SerializableField), (
        f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"
    )

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + "  - use python 3.9.x or higher\n"
                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False


def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # if except, bundle the exceptions
    results: dict[str, bool] = dict()
    exceptions: dict[str, Exception] = dict()

    # for each field in the class
    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            results[field.name] = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            results[field.name] = False
            exceptions[field.name] = e

    # figure out what to do with the exceptions
    if len(exceptions) > 0:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so raise from the first exception in the dict
            except_from=list(exceptions.values())[0],
        )

    return results


def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )


@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))


# cache this so we don't have to keep getting it
# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)


def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    "helper function to get type hints for a class"
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            + f"  {e = }"
        ) from e

    return cls_type_hints


class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.9"

    pass


class FieldError(ValueError):
    "base class for field errors"

    pass


class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"

    pass


class FieldSerializationError(FieldError):
    "error while serializing a field"

    pass


class FieldLoadingError(FieldError):
    "error while loading a field"

    pass


class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"

    pass


@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs,
):
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

    - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
    - `init : bool`
       whether to add an `__init__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
    - `repr : bool`
       whether to add a `__repr__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
    - `order : bool`
       whether to add rich comparison methods
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
    - `unsafe_hash : bool`
       whether to add a `__hash__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
    - `frozen : bool`
       whether to make the class frozen
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
    - `properties_to_serialize : Optional[list[str]]`
       which properties to add to the serialized data dict
       **SerializableDataclass only**
       (defaults to `None`)
    - `register_handler : bool`
        if true, register the class with ZANJ for loading
        **SerializableDataclass only**
        (defaults to `True`)
    - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
        **SerializableDataclass only**
    - `on_typecheck_mismatch : ErrorMode`
        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
        **SerializableDataclass only**
    - `methods_no_override : list[str]|None`
        list of methods that should not be overridden by the decorator
        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
        **SerializableDataclass only**
        (defaults to `None`)
    - `**kwargs`
        *(passed to dataclasses.dataclass)*

    # Returns:

    - `_type_`
       the decorated class

    # Raises:

    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
    - `NotSerializableFieldException` : if a field is not a `SerializableField`
    - `FieldSerializationError` : if there is an error serializing a field
    - `AttributeError` : if a property is not found on the class
    - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isn't a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> Type[T]:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(data, typing.Mapping), (
                f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
            )

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(field, SerializableField), (
                    f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
                )

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        _methods_no_override: set[str]
        if methods_no_override is None:
            _methods_no_override = set()
        else:
            _methods_no_override = set(methods_no_override)

        if _methods_no_override - {
            "__eq__",
            "serialize",
            "load",
            "validate_fields_types",
        }:
            warnings.warn(
                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
            )

        # mypy says "Type cannot be declared in assignment to non-self attribute" so that's why I've left the hints in the comments
        if "serialize" not in _methods_no_override:
            # type is `Callable[[T], dict]`
            cls.serialize = serialize  # type: ignore[attr-defined]
        if "load" not in _methods_no_override:
            # type is `Callable[[dict], T]`
            cls.load = load  # type: ignore[attr-defined]

        if "validate_field_type" not in _methods_no_override:
            # type is `Callable[[T, ErrorMode], bool]`
            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        if "__eq__" not in _methods_no_override:
            # type is `Callable[[T, T], bool]`
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)

``````{ end_of_file="muutils/json_serialize/serializable_dataclass.py" }
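
A minimal round-trip sketch of the decorator defined above, matching its docstring example; the `Point` class and its fields are illustrative only, and the imports assume the module paths shown in this dump (with `SerializableDataclass` defined earlier in the same module, as it is referenced above).

```python
from muutils.json_serialize.serializable_dataclass import (
    SerializableDataclass,
    serializable_dataclass,
)
from muutils.json_serialize.serializable_field import serializable_field


@serializable_dataclass
class Point(SerializableDataclass):
    x: int
    y: int
    label: str = serializable_field(default="origin")


p = Point(x=1, y=2)
data = p.serialize()   # dict with the "__muutils_format__" key plus the field values
p2 = Point.load(data)  # reconstructs the instance and type-checks the fields
assert p == p2         # __eq__ is provided via dc_eq
```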

``````{ path="muutils/json_serialize/serializable_field.py"  }
"""extends `dataclasses.Field` for use with `SerializableDataclass`

In particular, instead of using `dataclasses.field`, use `serializable_field` to define fields in a `SerializableDataclass`.
You provide information on how the field should be serialized and loaded (as well as anything that goes into `dataclasses.field`)
when you define the field, and the `SerializableDataclass` will automatically use those functions.

"""

from __future__ import annotations

import dataclasses
import sys
import types
from typing import Any, Callable, Optional, Union, overload, TypeVar


# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access


class SerializableField(dataclasses.Field):
    """extension of `dataclasses.Field` with additional serialization properties"""

    __slots__ = (
        # from dataclasses.Field.__slots__
        "name",
        "type",
        "default",
        "default_factory",
        "repr",
        "hash",
        "init",
        "compare",
        "doc",
        "metadata",
        "kw_only",
        "_field_type",  # Private: not to be used by user code.
        # new ones
        "serialize",
        "serialization_fn",
        "loading_fn",
        "deserialize_fn",  # new alternative to loading_fn
        "assert_type",
        "custom_typecheck_fn",
    )

    def __init__(
        self,
        default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
        default_factory: Union[
            Callable[[], Any], dataclasses._MISSING_TYPE
        ] = dataclasses.MISSING,
        init: bool = True,
        repr: bool = True,
        hash: Optional[bool] = None,
        compare: bool = True,
        doc: str | None = None,
        # TODO: add field for custom comparator (such as serializing)
        metadata: Optional[types.MappingProxyType] = None,
        kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
        serialize: bool = True,
        serialization_fn: Optional[Callable[[Any], Any]] = None,
        loading_fn: Optional[Callable[[Any], Any]] = None,
        deserialize_fn: Optional[Callable[[Any], Any]] = None,
        assert_type: bool = True,
        custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    ):
        # TODO: should we do this check, or assume the user knows what they are doing?
        if init and not serialize:
            raise ValueError("Cannot have init=True and serialize=False")

        # need to assemble kwargs in this hacky way so as not to upset type checking
        super_kwargs: dict[str, Any] = dict(
            default=default,
            default_factory=default_factory,
            init=init,
            repr=repr,
            hash=hash,
            compare=compare,
            kw_only=kw_only,
        )

        if metadata is not None:
            super_kwargs["metadata"] = metadata
        else:
            super_kwargs["metadata"] = types.MappingProxyType({})

        # only pass `doc` to super if python >=3.14
        if sys.version_info >= (3, 14):
            super_kwargs["doc"] = doc

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if super_kwargs["kw_only"] == True:  # noqa: E712
                raise ValueError("kw_only is not supported in python >=3.9")
            else:
                del super_kwargs["kw_only"]

        # actually init the super class
        super().__init__(**super_kwargs)  # type: ignore[call-arg]

        # init doc if python <3.14
        if sys.version_info < (3, 14):
            self.doc: str | None = doc

        # now init the new fields
        self.serialize: bool = serialize
        self.serialization_fn: Optional[Callable[[Any], Any]] = serialization_fn

        if loading_fn is not None and deserialize_fn is not None:
            raise ValueError(
                "Cannot pass both loading_fn and deserialize_fn, pass only one. "
                "`loading_fn` is the older interface and takes the dict of the class, "
                "`deserialize_fn` is the new interface and takes only the field's value."
            )
        self.loading_fn: Optional[Callable[[Any], Any]] = loading_fn
        self.deserialize_fn: Optional[Callable[[Any], Any]] = deserialize_fn

        self.assert_type: bool = assert_type
        self.custom_typecheck_fn: Optional[Callable[[type], bool]] = custom_typecheck_fn

    @classmethod
    def from_Field(cls, field: dataclasses.Field) -> "SerializableField":
        """copy all values from a `dataclasses.Field` to new `SerializableField`"""
        return cls(
            default=field.default,
            default_factory=field.default_factory,
            init=field.init,
            repr=field.repr,
            hash=field.hash,
            compare=field.compare,
            doc=getattr(field, "doc", None),  # `doc` only exists in python >=3.14
            metadata=field.metadata,
            kw_only=getattr(field, "kw_only", dataclasses.MISSING),  # for python <3.10
            serialize=field.repr,  # serialize if it's going to be repr'd
            serialization_fn=None,
            loading_fn=None,
            deserialize_fn=None,
        )


Sfield_T = TypeVar("Sfield_T")


@overload
def serializable_field(  # only `default_factory` is provided
    *_args,
    default_factory: Callable[[], Sfield_T],
    default: dataclasses._MISSING_TYPE = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    doc: str | None = None,
    metadata: Optional[types.MappingProxyType] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any,
) -> Sfield_T: ...
@overload
def serializable_field(  # only `default` is provided
    *_args,
    default: Sfield_T,
    default_factory: dataclasses._MISSING_TYPE = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    doc: str | None = None,
    metadata: Optional[types.MappingProxyType] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any,
) -> Sfield_T: ...
@overload
def serializable_field(  # both `default` and `default_factory` are MISSING
    *_args,
    default: dataclasses._MISSING_TYPE = dataclasses.MISSING,
    default_factory: dataclasses._MISSING_TYPE = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    doc: str | None = None,
    metadata: Optional[types.MappingProxyType] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any,
) -> Any: ...
def serializable_field(  # general implementation
    *_args,
    default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    doc: str | None = None,
    metadata: Optional[types.MappingProxyType] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any,
) -> Any:
    """Create a new `SerializableField`

    ```
    default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    default_factory: Callable[[], Sfield_T]
    | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    doc: str | None = None, # new in python 3.14. can alternately pass `description` to match pydantic, but this is discouraged
    metadata: types.MappingProxyType | None = None,
    kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    # ----------------------------------------------------------------------
    # new in `SerializableField`, not in `dataclasses.Field`
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    loading_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    ```

    # new Parameters:
    - `serialize`: whether to serialize this field when serializing the class'
    - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
    - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
    - `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
    - `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
    - `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

    # Gotchas:
    - `loading_fn` takes the dict of the **class**, not the field. so even a `loading_fn` that just recovers the field's value has to index into the dict:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            loading_fn=lambda x: int(x["my_field"])
        )
    ```

    using `deserialize_fn` instead:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x)
        )
    ```

    In the above code, `my_field` is an int but will be serialized as a string.

    note that if not using ZANJ, and you have a class inside a container, you MUST provide
    `serialization_fn` and `loading_fn` to serialize and load the container.
    ZANJ will automatically do this for you.

    # TODO: `custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
    """
    assert len(_args) == 0, f"unexpected positional arguments: {_args}"

    if "description" in kwargs:
        import warnings

        warnings.warn(
            "`description` is deprecated, use `doc` instead",
            DeprecationWarning,
        )
        if doc is not None:
            err_msg: str = f"cannot pass both `doc` and `description`: {doc=}, {kwargs['description']=}"
            raise ValueError(err_msg)
        doc = kwargs.pop("description")

    return SerializableField(
        default=default,
        default_factory=default_factory,
        init=init,
        repr=repr,
        hash=hash,
        compare=compare,
        metadata=metadata,
        kw_only=kw_only,
        serialize=serialize,
        serialization_fn=serialization_fn,
        deserialize_fn=deserialize_fn,
        assert_type=assert_type,
        custom_typecheck_fn=custom_typecheck_fn,
        **kwargs,
    )

``````{ end_of_file="muutils/json_serialize/serializable_field.py" }
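
A short sketch of the `serialization_fn` / `deserialize_fn` pair described in the docstring above: the field is kept as an int in memory but stored as a string. `Config`, `my_field`, and its default are illustrative, not part of the library.

```python
from muutils.json_serialize.serializable_dataclass import (
    SerializableDataclass,
    serializable_dataclass,
)
from muutils.json_serialize.serializable_field import serializable_field


@serializable_dataclass
class Config(SerializableDataclass):
    my_field: int = serializable_field(
        default=0,
        serialization_fn=lambda x: str(x),  # int -> str when serializing
        deserialize_fn=lambda x: int(x),    # str -> int when loading
    )


cfg = Config(my_field=3)
assert cfg.serialize()["my_field"] == "3"
assert Config.load({"my_field": "5"}).my_field == 5
```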

``````{ path="muutils/json_serialize/util.py"  }
"""utilities for json_serialize"""

from __future__ import annotations

import dataclasses
import functools
import inspect
import sys
import typing
import warnings
from typing import Any, Callable, Iterable, Union

_NUMPY_WORKING: bool
try:
    import numpy  # noqa: F401  # imported only to check that numpy is available
    _NUMPY_WORKING = True
except ImportError:
    warnings.warn("numpy not found, cannot serialize numpy arrays!")
    _NUMPY_WORKING = False


BaseType = Union[
    bool,
    int,
    float,
    str,
    None,
]

JSONitem = Union[
    BaseType,
    # mypy doesn't like recursive types, so we just go down a few levels manually
    typing.List[Union[BaseType, typing.List[Any], typing.Dict[str, Any]]],
    typing.Dict[str, Union[BaseType, typing.List[Any], typing.Dict[str, Any]]],
]
JSONdict = typing.Dict[str, JSONitem]

Hashableitem = Union[bool, int, float, str, tuple]


_FORMAT_KEY: str = "__muutils_format__"
_REF_KEY: str = "$ref"

# or if python version <3.9
if typing.TYPE_CHECKING or sys.version_info < (3, 9):
    MonoTuple = typing.Sequence
else:

    class MonoTuple:
        """tuple type hint, but for a tuple of any length with all the same type"""

        __slots__ = ()

        def __new__(cls, *args, **kwargs):
            raise TypeError("Type MonoTuple cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__module__}")

        # idk why mypy thinks there is no such function in typing
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            if getattr(params, "__origin__", None) == typing.Union:
                return typing.GenericAlias(tuple, (params, Ellipsis))
            elif isinstance(params, type):
                return typing.GenericAlias(tuple, (params, Ellipsis))
            # test if has len and is iterable
            elif isinstance(params, Iterable):
                if len(params) == 0:
                    return tuple
                elif len(params) == 1:
                    return typing.GenericAlias(tuple, (params[0], Ellipsis))
            else:
                raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")


class UniversalContainer:
    """contains everything -- `x in UniversalContainer()` is always True"""

    def __contains__(self, x: Any) -> bool:
        return True


def isinstance_namedtuple(x: Any) -> bool:
    """checks if `x` is a `namedtuple`

    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    """
    t: type = type(x)
    b: tuple = t.__bases__
    if len(b) != 1 or (b[0] is not tuple):
        return False
    f: Any = getattr(t, "_fields", None)
    if not isinstance(f, tuple):
        return False
    return all(isinstance(n, str) for n in f)


def try_catch(func: Callable):
    """wraps the function to catch exceptions, returns serialized error message on exception

    returned func will return normal result on success, or error message on exception
    """

    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"

    return newfunc


def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem:
    if isinstance(obj, typing.Mapping):
        return tuple((k, _recursive_hashify(v)) for k, v in obj.items())
    elif isinstance(obj, (tuple, list, Iterable)):
        return tuple(_recursive_hashify(v) for v in obj)
    elif isinstance(obj, (bool, int, float, str)):
        return obj
    else:
        if force:
            return str(obj)
        else:
            raise ValueError(f"cannot hashify:\n{obj}")


class SerializationException(Exception):
    pass


def string_as_lines(s: str | None) -> list[str]:
    """for easier reading of long strings in json, split up by newlines

    sort of like how jupyter notebooks do it
    """
    if s is None:
        return list()
    else:
        return s.splitlines(keepends=False)


def safe_getsource(func) -> list[str]:
    try:
        return string_as_lines(inspect.getsource(func))
    except Exception as e:
        return string_as_lines(f"Error: Unable to retrieve source code:\n{e}")


# credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises
def array_safe_eq(a: Any, b: Any) -> bool:
    """check if two objects are equal, account for if numpy arrays or torch tensors"""
    if a is b:
        return True

    if type(a) is not type(b):
        return False

    if (
        str(type(a)) == "<class 'numpy.ndarray'>"
        and str(type(b)) == "<class 'numpy.ndarray'>"
    ) or (
        str(type(a)) == "<class 'torch.Tensor'>"
        and str(type(b)) == "<class 'torch.Tensor'>"
    ):
        return (a == b).all()

    if (
        str(type(a)) == "<class 'pandas.core.frame.DataFrame'>"
        and str(type(b)) == "<class 'pandas.core.frame.DataFrame'>"
    ):
        return a.equals(b)

    if isinstance(a, typing.Sequence) and isinstance(b, typing.Sequence):
        if len(a) == 0 and len(b) == 0:
            return True
        return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b))

    if isinstance(a, (dict, typing.Mapping)) and isinstance(b, (dict, typing.Mapping)):
        return len(a) == len(b) and all(
            array_safe_eq(k1, k2) and array_safe_eq(a[k1], b[k2])
            for k1, k2 in zip(a.keys(), b.keys())
        )

    try:
        return bool(a == b)
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        return NotImplemented  # type: ignore[return-value]


def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will attempt to compare the fields.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the fields are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the dataclasses are of different classes
    - `AttributeError`: if the dataclasses have different fields

    # TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?
    ```
              [START]
                 ▼
           ┌───────────┐  ┌─────────┐
           │dc1 is dc2?├─►│ classes │
           └──┬────────┘No│ match?  │
      ────    │           ├─────────┤
     (True)◄──┘Yes        │No       │Yes
      ────                ▼         ▼
          ┌────────────────┐ ┌────────────┐
          │ except when    │ │ fields keys│
          │ class mismatch?│ │ match?     │
          ├───────────┬────┘ ├───────┬────┘
          │Yes        │No    │No     │Yes
          ▼           ▼      ▼       ▼
     ───────────  ┌──────────┐  ┌────────┐
    { raise     } │ except   │  │ field  │
    { TypeError } │ when     │  │ values │
     ───────────  │ field    │  │ match? │
                  │ mismatch?│  ├────┬───┘
                  ├───────┬──┘  │    │Yes
                  │Yes    │No   │No  ▼
                  ▼       ▼     │   ────
     ───────────────     ─────  │  (True)
    { raise         }   (False)◄┘   ────
    { AttributeError}    ─────
     ───────────────
    ```

    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        if except_when_field_mismatch:
            dc1_fields: set = set([fld.name for fld in dataclasses.fields(dc1)])
            dc2_fields: set = set([fld.name for fld in dataclasses.fields(dc2)])
            fields_match: bool = set(dc1_fields) == set(dc2_fields)
            if not fields_match:
                # if the fields don't match, raise an error
                raise AttributeError(
                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                )
        return False

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )

``````{ end_of_file="muutils/json_serialize/util.py" }
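
A small sketch of `dc_eq` (and, through it, `array_safe_eq`) on a plain dataclass holding a numpy array; the `Sample` class is illustrative only.

```python
import dataclasses

import numpy as np

from muutils.json_serialize.util import dc_eq


@dataclasses.dataclass
class Sample:
    name: str
    values: np.ndarray


a = Sample("run1", np.arange(3))
b = Sample("run1", np.arange(3))
c = Sample("run1", np.array([0, 1, 99]))

assert dc_eq(a, b)      # element-wise comparison of the arrays via array_safe_eq
assert not dc_eq(a, c)  # same class, same fields, different array contents
```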

``````{ path="muutils/logger/__init__.py"  }
"""(deprecated) experimenting with logging utilities"""

import warnings

from muutils.logger.logger import Logger
from muutils.logger.loggingstream import LoggingStream
from muutils.logger.simplelogger import SimpleLogger
from muutils.logger.timing import TimerContext

warnings.warn(
    DeprecationWarning(
        "muutils.logger is no longer maintained. Consider using [trnbl](https://github.com/mivanit/trnbl) instead."
    )
)

__all__ = [
    # submodules
    "exception_context",
    "headerfuncs",
    "log_util",
    "logger",
    "loggingstream",
    "simplelogger",
    "timing",
    # imports
    "Logger",
    "LoggingStream",
    "SimpleLogger",
    "TimerContext",
]

``````{ end_of_file="muutils/logger/__init__.py" }

``````{ path="muutils/logger/exception_context.py"  }
import json

from muutils.json_serialize import json_serialize


class ExceptionContext:
    """context manager which catches all exceptions happening while the context is open, `.write()` the exception trace to the given stream, and then raises the exception


    for example:

    ```python
    errorfile = open('error.log', 'w')

    with ExceptionContext(errorfile):
            # do something that might throw an exception
            # if it does, the exception trace will be written to errorfile
            # and then the exception will be raised
    ```

    """

    def __init__(self, stream):
        self.stream = stream

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if exc_type is not None:
            self.stream.write(
                json.dumps(
                    json_serialize(
                        {
                            "exc_type": exc_type,
                            "exc_value": exc_value,
                            "exc_traceback": exc_traceback,
                        }
                    )
                )
            )
            return False
        return True

``````{ end_of_file="muutils/logger/exception_context.py" }
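
A runnable version of the docstring example above, capturing the serialized exception in an `io.StringIO` rather than a file; note that the exception is still re-raised after being written, so it is caught here just to keep the sketch self-contained.

```python
import io

from muutils.logger.exception_context import ExceptionContext

buf = io.StringIO()
try:
    with ExceptionContext(buf):
        raise ValueError("something went wrong")
except ValueError:
    pass  # the exception propagates after being written to the stream

print(buf.getvalue())  # JSON with the exc_type / exc_value / exc_traceback keys
```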

``````{ path="muutils/logger/headerfuncs.py"  }
from __future__ import annotations

import json
from typing import Any, Mapping, Protocol

from muutils.json_serialize import json_serialize

# takes message, level, other data, and outputs message with appropriate header
# HeaderFunction = Callable[[str, int, Any], str]


class HeaderFunction(Protocol):
    def __call__(self, msg: Any, lvl: int, **kwargs) -> str: ...


def md_header_function(
    msg: Any,
    lvl: int,
    stream: str | None = None,
    indent_lvl: str = "  ",
    extra_indent: str = "",
    **kwargs,
) -> str:
    """standard header function. will output

    - `# {msg}`

            for levels in [0, 9]

    - `## {msg}`

            for levels in [10, 19], and so on

    - `[{stream}] # {msg}`

            for a non-`None` stream, with level headers as before

    - `!WARNING! [{stream}] {msg}`

            for level in [-9, -1]

    - `!!WARNING!! [{stream}] {msg}`

            for level in [-19, -10] and so on

    """
    stream_prefix: str = ""
    if stream is not None:
        stream_prefix = f"[{stream}] "

    lvl_div_10: int = lvl // 10

    msg_processed: str
    if isinstance(msg, Mapping):
        msg_processed = ", ".join([f"{k}: {json_serialize(v)}" for k, v in msg.items()])
    else:
        msg_processed = json.dumps(json_serialize(msg))

    if lvl >= 0:
        return f"{extra_indent}{indent_lvl * (lvl_div_10 - 1)}{stream_prefix}#{'#' * lvl_div_10 if lvl else ''} {msg_processed}"
    else:
        exclamation_pts: str = "!" * (abs(lvl) // 10)
        return f"{extra_indent}{exclamation_pts}WARNING{exclamation_pts} {stream_prefix} {msg_processed}"


HEADER_FUNCTIONS: dict[str, HeaderFunction] = {
    "md": md_header_function,
}

``````{ end_of_file="muutils/logger/headerfuncs.py" }
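
A quick sketch of what `md_header_function` produces at a few levels and with a stream name; the messages themselves are illustrative.

```python
from muutils.logger.headerfuncs import md_header_function

print(md_header_function("starting run", lvl=0))                  # "# ..." style header
print(md_header_function("epoch done", lvl=10, stream="train"))   # stream prefix, deeper header
print(md_header_function({"loss": 0.25}, lvl=-10, stream="train"))  # negative levels render as warnings
```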

``````{ path="muutils/logger/log_util.py"  }
from __future__ import annotations

from typing import Any

from muutils.jsonlines import jsonl_load_log


def get_any_from_stream(stream: list[dict], key: str) -> Any:
    """get the first value of a key from a stream. errors if not found"""
    for msg in stream:
        if key in msg:
            return msg[key]

    raise KeyError(f"key '{key}' not found in stream")


def gather_log(file: str) -> dict[str, list[dict]]:
    """gathers and sorts all streams from a log"""
    data: list[dict] = jsonl_load_log(file)
    output: dict[str, list[dict]] = dict()

    for item in data:
        stream: str = item.get("_stream", "default")
        if stream not in output:
            output[stream] = list()
        output[stream].append(item)

    return output


def gather_stream(
    file: str,
    stream: str,
) -> list[dict]:
    """gets all entries from a specific stream in a log file"""
    data: list[dict] = jsonl_load_log(file)

    output: list[dict] = list()

    for item in data:
        # select for the stream
        if ("_stream" in item) and (item["_stream"] == stream):
            output.append(item)
    return output


def gather_val(
    file: str,
    stream: str,
    keys: tuple[str],
    allow_skip: bool = True,
) -> list[list]:
    """gather specific keys from a specific stream in a log file

    example:
    if "log.jsonl" has contents:
    ```jsonl
    {"a": 1, "b": 2, "c": 3, "_stream": "s1"}
    {"a": 4, "b": 5, "c": 6, "_stream": "s1"}
    {"a": 7, "b": 8, "c": 9, "_stream": "s2"}
    ```
    then `gather_val("log.jsonl", "s1", ("a", "b"))` will return
    ```python
    [
        [1, 2],
        [4, 5]
    ]
    ```

    """
    data: list[dict] = jsonl_load_log(file)

    output: list[list] = list()

    for item in data:
        # select for the stream
        if ("_stream" in item) and (item["_stream"] == stream):
            # select for the keys
            if all(k in item for k in keys):
                output.append(list(item[k] for k in keys))
            elif not allow_skip:
                raise ValueError(f"missing keys '{keys = }' in '{item = }'")

    return output

``````{ end_of_file="muutils/logger/log_util.py" }

``````{ path="muutils/logger/logger.py"  }
"""logger with streams & levels, and a timer context manager

- `SimpleLogger` is an extremely simple logger that can write to both console and a file
- `Logger` class handles levels in a slightly different way than default python `logging`,
        and also has "streams" which allow for different sorts of output in the same logger
        this was mostly made with training models in mind and storing both metadata and loss
- `TimerContext` is a context manager that can be used to time the duration of a block of code
"""

from __future__ import annotations

import json
import time
import typing
from functools import partial
from typing import Callable, Sequence

from muutils.json_serialize import JSONitem, json_serialize
from muutils.logger.exception_context import ExceptionContext
from muutils.logger.headerfuncs import HEADER_FUNCTIONS, HeaderFunction
from muutils.logger.loggingstream import LoggingStream
from muutils.logger.simplelogger import AnyIO, SimpleLogger

# pylint: disable=arguments-differ, bad-indentation, trailing-whitespace, trailing-newlines, unnecessary-pass, consider-using-with, use-dict-literal


def decode_level(level: int) -> str:
    if not isinstance(level, int):
        raise TypeError(f"level must be int, got {type(level) = } {level = }")

    if level < -255:
        return f"FATAL_ERROR({level})"
    elif level < 0:
        return f"WARNING({level})"
    else:
        return f"INFO({level})"


# todo: add a context which catches and logs all exceptions
class Logger(SimpleLogger):
    """logger with more features, including log levels and streams

    # Parameters:
            - `log_path : str | None`
            default log file path
            (defaults to `None`)
            - `log_file : AnyIO | None`
            default log io, should have a `.write()` method (pass only this or `log_path`, not both)
            (defaults to `None`)
            - `timestamp : bool`
            whether to add timestamps to every log message (under the `_timestamp` key)
            (defaults to `True`)
            - `default_level : int`
            default log level for streams/messages that don't specify a level
            (defaults to `0`)
            - `console_print_threshold : int`
            log level at which to print to the console, anything greater will not be printed unless overridden by `console_print`
            (defaults to `50`)
            - `level_header : HeaderFunction`
            function for formatting log messages when printing to console
            (defaults to `HEADER_FUNCTIONS["md"]`)
            - `keep_last_msg_time : bool`
            whether to keep the last message time
            (defaults to `True`)


    # Raises:
            - `ValueError` : if both `log_path` and `log_file` are given, if unrecognized `kwargs` are passed, if `timestamp` is `False`, or if duplicate stream aliases are found
    """

    def __init__(
        self,
        log_path: str | None = None,
        log_file: AnyIO | None = None,
        default_level: int = 0,
        console_print_threshold: int = 50,
        level_header: HeaderFunction = HEADER_FUNCTIONS["md"],
        streams: dict[str | None, LoggingStream] | Sequence[LoggingStream] = (),
        keep_last_msg_time: bool = True,
        # junk args
        timestamp: bool = True,
        **kwargs,
    ):
        # junk arg checking
        # ==================================================
        if len(kwargs) > 0:
            raise ValueError(f"unrecognized kwargs: {kwargs}")

        if not timestamp:
            raise ValueError(
                "timestamp must be True -- why would you not want timestamps?"
            )

        # timing
        # ==================================================
        # timing compares
        self._keep_last_msg_time: bool = keep_last_msg_time
        # TODO: handle per stream?
        self._last_msg_time: float | None = time.time()

        # basic setup
        # ==================================================
        # init BaseLogger
        super().__init__(log_file=log_file, log_path=log_path, timestamp=timestamp)

        # level-related
        self._console_print_threshold: int = console_print_threshold
        self._default_level: int = default_level

        # set up streams
        self._streams: dict[str | None, LoggingStream] = (
            streams
            if isinstance(streams, typing.Mapping)
            else {s.name: s for s in streams}
        )
        # default error stream
        if "error" not in self._streams:
            self._streams["error"] = LoggingStream(
                "error",
                aliases={
                    "err",
                    "except",
                    "Exception",
                    "exception",
                    "exceptions",
                    "errors",
                },
            )

        # check alias duplicates
        alias_set: set[str | None] = set()
        for stream in self._streams.values():
            for alias in stream.aliases:
                if alias in alias_set:
                    raise ValueError(f"alias {alias} is already in use")
                alias_set.add(alias)

        # add aliases
        for stream in tuple(self._streams.values()):
            for alias in stream.aliases:
                if alias not in self._streams:
                    self._streams[alias] = stream

        # print formatting
        self._level_header: HeaderFunction = level_header

        print({k: str(v) for k, v in self._streams.items()})

    def _exception_context(
        self,
        stream: str = "error",
        # level: int = -256,
        # **kwargs,
    ) -> ExceptionContext:
        s: LoggingStream = self._streams[stream]
        return ExceptionContext(stream=s)

    def log(  # type: ignore # yes, the signatures are different here.
        self,
        msg: JSONitem = None,
        lvl: int | None = None,
        stream: str | None = None,
        console_print: bool = False,
        extra_indent: str = "",
        **kwargs,
    ):
        """logging function

        ### Parameters:
         - `msg : JSONitem`
           message (usually string or dict) to be logged
         - `lvl : int | None`
           level of message (lower levels are more important)
           (defaults to `None`)
         - `console_print : bool`
           override `console_print_threshold` setting
           (defaults to `False`)
         - `stream : str | None`
           which stream to log to; `None` logs to the default stream
           (defaults to `None`)
        """

        # add to known stream names if not present
        if stream not in self._streams:
            self._streams[stream] = LoggingStream(stream)

        # set default level to either global or stream-specific default level
        # ========================================
        if lvl is None:
            if stream is None:
                lvl = self._default_level
            else:
                if self._streams[stream].default_level is not None:
                    lvl = self._streams[stream].default_level
                else:
                    lvl = self._default_level

        assert lvl is not None, "lvl should not be None at this point"

        # print to console with formatting
        # ========================================
        _printed: bool = False
        if console_print or (lvl <= self._console_print_threshold):
            # add some formatting
            print(
                self._level_header(
                    msg=msg,
                    lvl=lvl,
                    stream=stream,
                    extra_indent=extra_indent,
                )
            )

            # store the last message time
            if self._last_msg_time is not None:
                self._last_msg_time = time.time()

            _printed = True

        # convert and add data
        # ========================================
        # converting to dict
        msg_dict: typing.Mapping
        if not isinstance(msg, typing.Mapping):
            msg_dict = {"_msg": msg}
        else:
            msg_dict = msg

        # level+stream metadata
        if lvl is not None:
            msg_dict["_lvl"] = lvl

        # msg_dict["_stream"] = stream # moved to LoggingStream

        # extra data in kwargs
        if len(kwargs) > 0:
            msg_dict["_kwargs"] = kwargs

        # add default contents (timing, etc)
        msg_dict = {
            **{k: v() for k, v in self._streams[stream].default_contents.items()},
            **msg_dict,
        }

        # write
        # ========================================
        logfile_msg: str = json.dumps(json_serialize(msg_dict)) + "\n"
        if (
            (stream is None)
            or (stream not in self._streams)
            or (self._streams[stream].handler is None)
        ):
            # write to the main log file if no stream is specified
            self._log_file_handle.write(logfile_msg)
        else:
            # otherwise, write to the stream-specific file
            s_handler: AnyIO | None = self._streams[stream].handler
            if s_handler is not None:
                s_handler.write(logfile_msg)
            else:
                raise ValueError(
                    f"stream handler is None! something in the logging stream setup is wrong:\n{self}"
                )

        # if it was important enough to print, flush all streams
        if _printed:
            self.flush_all()

    def log_elapsed_last(
        self,
        lvl: int | None = None,
        stream: str | None = None,
        console_print: bool = True,
        **kwargs,
    ) -> float:
        """logs the time elapsed since the last message was printed to the console (in any stream)"""
        if self._last_msg_time is None:
            raise ValueError("no last message time!")
        else:
            return self.log(
                {"elapsed_time": round(time.time() - self._last_msg_time, 6)},
                lvl=(lvl if lvl is not None else self._console_print_threshold),
                stream=stream,
                console_print=console_print,
                **kwargs,
            )

    def flush_all(self):
        """flush all streams"""

        self._log_file_handle.flush()

        for stream in self._streams.values():
            if stream.handler is not None:
                stream.handler.flush()

    def __getattr__(self, stream: str) -> Callable:
        if stream.startswith("_"):
            raise AttributeError(f"invalid stream name {stream} (no underscores)")
        return partial(self.log, stream=stream)

    def __getitem__(self, stream: str):
        return partial(self.log, stream=stream)

    def __call__(self, *args, **kwargs):
        return self.log(*args, **kwargs)

``````{ end_of_file="muutils/logger/logger.py" }
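
A minimal sketch of the `Logger` above with a dedicated `train` stream written to its own file; the file names are illustrative. Attribute access (`logger.train(...)`) is routed through `__getattr__` to `log(..., stream="train")`.

```python
from muutils.logger import Logger, LoggingStream

logger = Logger(
    log_path="run.log.jsonl",
    console_print_threshold=50,
    streams=(LoggingStream(name="train", file="train.log.jsonl"),),
)

logger.log("starting up", lvl=0)                 # default stream, written to run.log.jsonl
logger.train({"epoch": 1, "loss": 0.5}, lvl=10)  # same as logger.log(..., stream="train")
logger.log_elapsed_last(stream="train")          # time since the last console-printed message
logger.flush_all()
```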

``````{ path="muutils/logger/loggingstream.py"  }
from __future__ import annotations

import time
from dataclasses import dataclass, field
from typing import Any, Callable

from muutils.logger.simplelogger import AnyIO, NullIO
from muutils.misc import sanitize_fname


@dataclass
class LoggingStream:
    """properties of a logging stream

    - `name: str` name of the stream
    - `aliases: set[str]` aliases for the stream
            (calls to these names will be redirected to this stream. duplicate aliases will result in errors)
            TODO: perhaps duplicate aliases should result in duplicate writes?
    - `file: str|bool|AnyIO|None` file to write to
            - if `None`, will write to standard log
            - if `True`, will write to `name + ".log"`
            - if `False` will "write" to `NullIO` (throw it away)
            - if a string, will write to that file
            - if a fileIO type object, will write to that object
    - `default_level: int|None` default level for this stream
    - `default_contents: dict[str, Callable[[], Any]]` default contents for this stream
    - `last_msg: tuple[float, Any]|None` last message written to this stream (timestamp, message)
    """

    name: str | None
    aliases: set[str | None] = field(default_factory=set)
    file: str | bool | AnyIO | None = None
    default_level: int | None = None
    default_contents: dict[str, Callable[[], Any]] = field(default_factory=dict)
    handler: AnyIO | None = None

    # TODO: implement last-message caching
    # last_msg: tuple[float, Any]|None = None

    def make_handler(self) -> AnyIO | None:
        if self.file is None:
            return None
        elif isinstance(self.file, str):
            # if it's a string, open a file
            return open(
                self.file,
                "w",
                encoding="utf-8",
            )
        elif isinstance(self.file, bool):
            # if it's a bool and True, open a file with the same name as the stream (in the current dir)
            # TODO: make this happen in the same dir as the main logfile?
            if self.file:
                return open(  # type: ignore[return-value]
                    f"{sanitize_fname(self.name)}.log.jsonl",
                    "w",
                    encoding="utf-8",
                )
            else:
                return NullIO()
        else:
            # if it's neither, check it has `.write()`, `.flush()`, and `.close()` methods
            if (
                (
                    not hasattr(self.file, "write")
                    or (not callable(self.file.write))
                    or (not hasattr(self.file, "flush"))
                    or (not callable(self.file.flush))
                )
                or (not hasattr(self.file, "close"))
                or (not callable(self.file.close))
            ):
                raise ValueError(f"stream {self.name} has invalid handler {self.file}")
            # ignore type check because we know it has a .write() method,
            # assume the user knows what they're doing
            return self.file  # type: ignore

    def __post_init__(self):
        self.aliases = set(self.aliases)
        if any(x.startswith("_") for x in self.aliases if x is not None):
            raise ValueError(
                "stream names or aliases cannot start with an underscore, sorry"
            )
        self.aliases.add(self.name)
        self.default_contents["_timestamp"] = time.time
        self.default_contents["_stream"] = lambda: self.name
        self.handler = self.make_handler()

    def __del__(self):
        if self.handler is not None:
            self.handler.flush()
            self.handler.close()

    def __str__(self):
        return f"LoggingStream(name={self.name}, aliases={self.aliases}, file={self.file}, default_level={self.default_level}, default_contents={self.default_contents})"

``````{ end_of_file="muutils/logger/loggingstream.py" }

``````{ path="muutils/logger/simplelogger.py"  }
from __future__ import annotations

import json
import sys
import time
import typing
from typing import TextIO, Union

from muutils.json_serialize import JSONitem, json_serialize


class NullIO:
    """null IO class"""

    def __init__(self) -> None:
        pass

    def write(self, msg: str) -> int:
        """write to nothing! this throws away the message"""
        return len(msg)

    def flush(self) -> None:
        """flush nothing! this is a no-op"""
        pass

    def close(self) -> None:
        """close nothing! this is a no-op"""
        pass


AnyIO = Union[TextIO, NullIO]


class SimpleLogger:
    """logs training data to a jsonl file"""

    def __init__(
        self,
        log_path: str | None = None,
        log_file: AnyIO | None = None,
        timestamp: bool = True,
    ):
        self._timestamp: bool = timestamp
        self._log_path: str | None = log_path

        self._log_file_handle: AnyIO

        if (log_path is None) and (log_file is None):
            print(
                "[logger_internal] # no log file specified, will only write to console",
                file=sys.stderr,
            )
            self._log_file_handle = sys.stdout

        elif (log_path is not None) and (log_file is not None):
            raise ValueError(
                "cannot specify both log_path and log_file, use streams in `SimpleLogger`"
            )
        else:
            # now exactly one of the two is None
            if log_file is not None:
                self._log_file_handle = log_file
            else:
                assert log_path is not None
                self._log_file_handle = open(log_path, "w", encoding="utf-8")

    def log(self, msg: JSONitem, console_print: bool = False, **kwargs):
        """log a message to the log file, and optionally to the console"""
        if console_print:
            print(msg)

        if not isinstance(msg, typing.Mapping):
            msg = {"_msg": msg}

        if self._timestamp:
            msg["_timestamp"] = time.time()

        if len(kwargs) > 0:
            msg["_kwargs"] = kwargs

        self._log_file_handle.write(json.dumps(json_serialize(msg)) + "\n")

``````{ end_of_file="muutils/logger/simplelogger.py" }
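
A tiny sketch of `SimpleLogger`: one JSON object per line, with an automatic `_timestamp` key; the path is illustrative.

```python
from muutils.logger.simplelogger import SimpleLogger

log = SimpleLogger(log_path="simple.log.jsonl")
log.log({"step": 1, "loss": 0.9})
log.log("checkpoint saved", console_print=True)  # also echoed to the console
```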

``````{ path="muutils/logger/timing.py"  }
from __future__ import annotations

import time
from typing import Literal


class TimerContext:
    """context manager for timing code"""

    def __init__(self) -> None:
        self.start_time: float
        self.end_time: float
        self.elapsed_time: float

    def __enter__(self) -> "TimerContext":
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]:
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        return False


def filter_time_str(time: str) -> str:
    """assuming format `h:mm:ss`, clips off the hours if its 0"""
    if (len(time) == 7) and (time[0] == "0"):
        return time[3:]
    else:
        return time


class ProgressEstimator:
    """estimates progress and can give a progress bar"""

    def __init__(
        self,
        n_total: int,
        pbar_fill: str = "█",
        pbar_empty: str = " ",
        pbar_bounds: tuple[str, str] = ("|", "|"),
    ):
        self.n_total: int = n_total
        self.starttime: float = time.time()
        self.pbar_fill: str = pbar_fill
        self.pbar_empty: str = pbar_empty
        self.pbar_bounds: tuple[str, str] = pbar_bounds
        self.total_str_len: int = len(str(n_total))

    def get_timing_raw(self, i: int) -> dict[str, float]:
        """returns dict(elapsed, per_iter, remaining, percent)"""
        elapsed: float = time.time() - self.starttime
        per_iter: float = elapsed / i
        return dict(
            elapsed=elapsed,
            per_iter=per_iter,
            remaining=(self.n_total - i) * per_iter,
            percent=i / self.n_total,
        )

    def get_pbar(
        self,
        i: int,
        width: int = 30,
    ) -> str:
        """returns a progress bar"""
        percent_filled: float = i / self.n_total
        # round to nearest integer
        n_filled: int = int(round(percent_filled * width))
        return "".join(
            [
                self.pbar_bounds[0],
                self.pbar_fill * n_filled,
                self.pbar_empty * (width - n_filled),
                self.pbar_bounds[1],
            ]
        )

    def get_progress_default(self, i: int) -> str:
        """returns a progress string"""
        timing_raw: dict[str, float] = self.get_timing_raw(i)

        percent_str: str = str(int(timing_raw["percent"] * 100)).ljust(2)
        # TODO: get_progress_default
        # iters_str: str = f"{str(i).ljust(self.total_str_len)}/{self.n_total}"
        # timing_str: str
        return f"{percent_str}% {self.get_pbar(i)}"

``````{ end_of_file="muutils/logger/timing.py" }
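
A small sketch of how `TimerContext` and `ProgressEstimator` might be used together; the loop body and sleep duration are placeholders.

```python
import time

from muutils.logger.timing import ProgressEstimator, TimerContext

with TimerContext() as timer:
    prog = ProgressEstimator(n_total=5)
    for i in range(1, 6):
        time.sleep(0.01)  # placeholder for real work
        print(prog.get_progress_default(i))  # e.g. "40% |############                  |"
print(f"total elapsed: {timer.elapsed_time:.3f}s")
```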

``````{ path="muutils/math/__init__.py"  }
__all__ = [
    "bins",
    "matrix_powers",
]

``````{ end_of_file="muutils/math/__init__.py" }

``````{ path="muutils/math/bins.py"  }
from __future__ import annotations

from dataclasses import dataclass
from functools import cached_property
from typing import Literal

import numpy as np
from jaxtyping import Float


@dataclass(frozen=True)
class Bins:
    n_bins: int = 32
    start: float = 0
    stop: float = 1.0
    scale: Literal["lin", "log"] = "log"

    _log_min: float = 1e-3
    _zero_in_small_start_log: bool = True

    @cached_property
    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
        if self.scale == "lin":
            return np.linspace(self.start, self.stop, self.n_bins + 1)
        elif self.scale == "log":
            if self.start < 0:
                raise ValueError(
                    f"start must be positive for log scale, got {self.start}"
                )
            if self.start == 0:
                return np.concatenate(
                    [
                        np.array([0]),
                        np.logspace(
                            np.log10(self._log_min), np.log10(self.stop), self.n_bins
                        ),
                    ]
                )
            elif self.start < self._log_min and self._zero_in_small_start_log:
                return np.concatenate(
                    [
                        np.array([0]),
                        np.logspace(
                            np.log10(self.start), np.log10(self.stop), self.n_bins
                        ),
                    ]
                )
            else:
                return np.logspace(
                    np.log10(self.start), np.log10(self.stop), self.n_bins + 1
                )
        else:
            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")

    @cached_property
    def centers(self) -> Float[np.ndarray, "n_bins"]:
        return (self.edges[:-1] + self.edges[1:]) / 2

    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
        return Bins(
            n_bins=n_bins,
            start=self.start,
            stop=self.stop,
            scale=self.scale,
            _log_min=self._log_min,
            _zero_in_small_start_log=self._zero_in_small_start_log,
        )

``````{ end_of_file="muutils/math/bins.py" }
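
A brief sketch of constructing `Bins` and reading off edges and centers; printed values in the comments are approximate.

```python
from muutils.math.bins import Bins

lin_bins = Bins(n_bins=4, start=0.0, stop=1.0, scale="lin")
print(lin_bins.edges)    # [0.   0.25 0.5  0.75 1.  ]
print(lin_bins.centers)  # [0.125 0.375 0.625 0.875]

# log scale with start=0 prepends a zero edge and spans [_log_min, stop] logarithmically
log_bins = Bins(n_bins=4, start=0.0, stop=1.0, scale="log")
print(log_bins.edges.shape)  # (5,)

# copy with a different bin count, keeping all other settings
print(lin_bins.changed_n_bins_copy(8).edges.shape)  # (9,)
```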

``````{ path="muutils/math/matrix_powers.py"  }
from __future__ import annotations

from typing import List, Sequence, TYPE_CHECKING

import numpy as np
from jaxtyping import Float, Int

if TYPE_CHECKING:
    pass


def matrix_powers(
    A: Float[np.ndarray, "n n"],
    powers: Sequence[int],
) -> Float[np.ndarray, "n_powers n n"]:
    """Compute multiple powers of a matrix efficiently.

    Uses binary exponentiation to compute powers in O(log max(powers))
    matrix multiplications, avoiding redundant calculations when
    computing multiple powers.

    # Parameters:
     - `A : Float[np.ndarray, "n n"]`
            Square matrix to exponentiate
     - `powers : Sequence[int]`
            List of powers to compute (non-negative integers)

    # Returns:
     - `Float[np.ndarray, "n_powers n n"]`
            Array of the requested matrix powers, stacked along the first axis in ascending order of the unique powers
    """
    dim_n: int = A.shape[0]
    assert A.shape[0] == A.shape[1], f"Matrix must be square, but got {A.shape = }"
    powers_np: Int[np.ndarray, "n_powers_unique"] = np.array(
        sorted(set(powers)), dtype=int
    )
    n_powers_unique: int = len(powers_np)

    if n_powers_unique < 1:
        raise ValueError(f"No powers requested: {powers = }")

    # NaN can only be cast to floating dtypes; fall back to a zero fill for integer matrices
    fill_value = np.nan if np.issubdtype(A.dtype, np.floating) else 0
    output: Float[np.ndarray, "n_powers_unique n n"] = np.full(
        (n_powers_unique, dim_n, dim_n),
        fill_value=fill_value,
        dtype=A.dtype,
    )

    # Find the maximum power to compute
    max_power: int = max(powers_np)

    # Precompute all powers of 2 up to the largest power needed
    # This forms our basis for binary decomposition
    powers_of_two: dict[int, Float[np.ndarray, "n n"]] = {}
    powers_of_two[0] = np.eye(dim_n, dtype=A.dtype)
    powers_of_two[1] = A.copy()

    # Compute powers of 2: A^2, A^4, A^8, ...
    p: int = 1
    while p < max_power:
        if p <= max_power:
            A_power_p = powers_of_two[p]
            powers_of_two[p * 2] = A_power_p @ A_power_p
        p = p * 2

    # For each requested power, compute it using the powers of 2
    for p_idx, power in enumerate(powers_np):
        # Decompose power into sum of powers of 2
        temp_result: Float[np.ndarray, "n n"] = powers_of_two[0].copy()
        temp_power: int = power
        p_temp: int = 1

        while temp_power > 0:
            if temp_power % 2 == 1:
                temp_result = temp_result @ powers_of_two[p_temp]
            temp_power = temp_power // 2
            p_temp *= 2

        output[p_idx] = temp_result

    return output


# NOTE: the NaN fill value only applies to floating dtypes; integer matrices get a zero fill instead
# TYPING: jaxtyping hints not working here, separate file for torch implementation?
def matrix_powers_torch(
    A,  # : Float["torch.Tensor", "n n"],
    powers: Sequence[int],
):  # Float["torch.Tensor", "n_powers n n"]:
    """Compute multiple powers of a matrix efficiently.

    Uses binary exponentiation to compute powers in O(log max(powers))
    matrix multiplications, avoiding redundant calculations when
    computing multiple powers.

    # Parameters:
     - `A : Float[torch.Tensor, "n n"]`
        Square matrix to exponentiate
     - `powers : Sequence[int]`
        List of powers to compute (non-negative integers)

    # Returns:
     - `Float[torch.Tensor, "n_powers n n"]`
        Tensor containing the requested matrix powers stacked along the first dimension

    # Raises:
     - `ValueError` : If no powers are requested or if A is not a square matrix
    """

    import torch

    if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"Matrix must be square, but got {A.shape = }")

    dim_n: int = A.shape[0]
    # Get unique powers and sort them
    unique_powers: List[int] = sorted(set(powers))
    n_powers_unique: int = len(unique_powers)
    powers_tensor: Int[torch.Tensor, "n_powers_unique"] = torch.tensor(
        unique_powers, dtype=torch.int64, device=A.device
    )

    if n_powers_unique < 1:
        raise ValueError(f"No powers requested: {powers = }")

    # NaN can only be cast to floating dtypes; fall back to a zero fill for integer matrices
    fill_value = float("nan") if A.dtype.is_floating_point else 0
    output: Float[torch.Tensor, "n_powers_unique n n"] = torch.full(
        (n_powers_unique, dim_n, dim_n),
        fill_value,
        dtype=A.dtype,
        device=A.device,
    )

    # Find the maximum power to compute
    max_power: int = int(powers_tensor.max().item())

    # Precompute all powers of 2 up to the largest power needed
    # This forms our basis for binary decomposition
    powers_of_two: dict[int, Float[torch.Tensor, "n n"]] = {}
    powers_of_two[0] = torch.eye(dim_n, dtype=A.dtype, device=A.device)
    powers_of_two[1] = A.clone()

    # Compute powers of 2: A^2, A^4, A^8, ...
    p: int = 1
    while p < max_power:
        if p <= max_power:
            A_power_p: Float[torch.Tensor, "n n"] = powers_of_two[p]
            powers_of_two[p * 2] = A_power_p @ A_power_p
        p = p * 2

    # For each requested power, compute it using the powers of 2
    for p_idx, power in enumerate(unique_powers):
        # Decompose power into sum of powers of 2
        temp_result: Float[torch.Tensor, "n n"] = powers_of_two[0].clone()
        temp_power: int = power
        p_temp: int = 1

        while temp_power > 0:
            if temp_power % 2 == 1:
                temp_result = temp_result @ powers_of_two[p_temp]
            temp_power = temp_power // 2
            p_temp *= 2

        output[p_idx] = temp_result

    return output

``````{ end_of_file="muutils/math/matrix_powers.py" }
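
A quick sketch exercising `matrix_powers` on a small float matrix, checked against `np.linalg.matrix_power`.

```python
import numpy as np

from muutils.math.matrix_powers import matrix_powers

A = np.array([[0.5, 0.1], [0.2, 0.7]])
requested = [1, 2, 5, 8]

# result is stacked along the first axis, in ascending order of the unique powers
stacked = matrix_powers(A, requested)
for idx, p in enumerate(sorted(set(requested))):
    assert np.allclose(stacked[idx], np.linalg.matrix_power(A, p))
```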

``````{ path="muutils/misc/__init__.py"  }
"""miscellaneous utilities

- `stable_hash` for hashing that is stable across runs
- `muutils.misc.sequence` for sequence manipulation, applying mappings, and string-like operations on lists
- `muutils.misc.string` for sanitizing things for filenames, adjusting docstrings, and converting dicts to filenames
- `muutils.misc.numerical` for turning numbers into nice strings and back
- `muutils.misc.freezing` for freezing things
- `muutils.misc.classes` for some weird class utilities
"""

from muutils.misc.hashing import stable_hash
from muutils.misc.sequence import (
    WhenMissing,
    empty_sequence_if_attr_false,
    flatten,
    list_split,
    list_join,
    apply_mapping,
    apply_mapping_chain,
)
from muutils.misc.string import (
    sanitize_name,
    sanitize_fname,
    sanitize_identifier,
    dict_to_filename,
    dynamic_docstring,
)
from muutils.misc.numerical import (
    shorten_numerical_to_str,
    str_to_numeric,
    _SHORTEN_MAP,
)
from muutils.misc.freezing import (
    FrozenDict,
    FrozenList,
    freeze,
)
from muutils.misc.classes import (
    is_abstract,
    get_all_subclasses,
    isinstance_by_type_name,
    IsDataclass,
    get_hashable_eq_attrs,
    dataclass_set_equals,
)


__all__ = [
    # submodules
    "classes",
    "freezing",
    "func",
    "hashing",
    "numerical",
    "sequence",
    "string",
    # imports
    "stable_hash",
    "WhenMissing",
    "empty_sequence_if_attr_false",
    "flatten",
    "list_split",
    "list_join",
    "apply_mapping",
    "apply_mapping_chain",
    "sanitize_name",
    "sanitize_fname",
    "sanitize_identifier",
    "dict_to_filename",
    "dynamic_docstring",
    "shorten_numerical_to_str",
    "str_to_numeric",
    "_SHORTEN_MAP",
    "FrozenDict",
    "FrozenList",
    "freeze",
    "is_abstract",
    "get_all_subclasses",
    "isinstance_by_type_name",
    "IsDataclass",
    "get_hashable_eq_attrs",
    "dataclass_set_equals",
]

``````{ end_of_file="muutils/misc/__init__.py" }

``````{ path="muutils/misc/b64_decode.py"  }
from sys import argv
from pathlib import Path
from base64 import b64decode
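# decode a base64-encoded text file (argv[1], newlines stripped) and write the raw bytes to argv[2]
# e.g. `python -m muutils.misc.b64_decode encoded.txt decoded.bin`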

if __name__ == "__main__":
    input_file: Path = Path(argv[1])
    out: Path = Path(argv[2])
    input_text: str = input_file.read_text().replace("\n", "")
    out.write_bytes(b64decode(input_text))

``````{ end_of_file="muutils/misc/b64_decode.py" }

``````{ path="muutils/misc/classes.py"  }
from __future__ import annotations

from typing import (
    Iterable,
    Any,
    Protocol,
    ClassVar,
    runtime_checkable,
)

from muutils.misc.sequence import flatten


def is_abstract(cls: type) -> bool:
    """
    Returns if a class is abstract.
    """
    if not hasattr(cls, "__abstractmethods__"):
        return False  # an ordinary class
    elif len(cls.__abstractmethods__) == 0:
        return False  # a concrete implementation of an abstract class
    else:
        return True  # an abstract class


def get_all_subclasses(class_: type, include_self=False) -> set[type]:
    """
    Returns a set containing all child classes in the subclass graph of `class_`.
    I.e., includes subclasses of subclasses, etc.

    # Parameters
    - `include_self`: Whether to include `class_` itself in the returned set
    - `class_`: Superclass

    # Development
    Since most class hierarchies are small, the inefficiencies of the existing recursive implementation aren't problematic.
    It might be valuable to refactor with memoization if the need arises to use this function on a very large class hierarchy.
    """
    subs: set[type] = set(
        flatten(
            get_all_subclasses(sub, include_self=True)
            for sub in class_.__subclasses__()
            if sub is not None
        )
    )
    if include_self:
        subs.add(class_)
    return subs


def isinstance_by_type_name(o: object, type_name: str):
    """Behaves like stdlib `isinstance` except it accepts a string representation of the type rather than the type itself.
    This is a hacky function intended to circumvent the need to import a type into a module.
    It is susceptible to type name collisions.

    # Parameters
    `o`: Object (not the type itself) whose type to interrogate
    `type_name`: The string returned by `type_.__name__`.
    Generic types are not supported, only types that would appear in `type_.__mro__`.
    """
    return type_name in {s.__name__ for s in type(o).__mro__}


# dataclass magic
# --------------------------------------------------------------------------------


@runtime_checkable
class IsDataclass(Protocol):
    # Generic type for any dataclass instance
    # https://stackoverflow.com/questions/54668000/type-hint-for-an-instance-of-a-non-specific-dataclass
    __dataclass_fields__: ClassVar[dict[str, Any]]


def get_hashable_eq_attrs(dc: IsDataclass) -> tuple[Any]:
    """Returns a tuple of all fields used for equality comparison, including the type of the dataclass itself.
    The type is included so that instances of different dataclasses with identical field values still compare unequal.
    Essentially used to generate a hashable dataclass representation for equality comparison even if it's not frozen.
    """
    return *(
        getattr(dc, fld.name)
        for fld in filter(lambda x: x.compare, dc.__dataclass_fields__.values())
    ), type(dc)


def dataclass_set_equals(
    coll1: Iterable[IsDataclass], coll2: Iterable[IsDataclass]
) -> bool:
    """Compares 2 collections of dataclass instances as if they were sets.
    Duplicates are ignored in the same manner as a set.
    Unfrozen dataclasses can't be placed in sets since they're not hashable.
    Collections of them may be compared using this function.
    """

    return {get_hashable_eq_attrs(x) for x in coll1} == {
        get_hashable_eq_attrs(y) for y in coll2
    }

``````{ end_of_file="muutils/misc/classes.py" }
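
A short sketch of `dataclass_set_equals` and `isinstance_by_type_name`; the `Point` dataclass is a made-up example.

```python
from dataclasses import dataclass

from muutils.misc.classes import dataclass_set_equals, isinstance_by_type_name

@dataclass
class Point:  # unfrozen, so instances are unhashable and cannot go in a set directly
    x: int
    y: int

assert dataclass_set_equals([Point(1, 2), Point(1, 2)], [Point(1, 2)])  # duplicates are ignored
assert not dataclass_set_equals([Point(1, 2)], [Point(3, 4)])
assert isinstance_by_type_name(Point(0, 0), "Point")
```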

``````{ path="muutils/misc/freezing.py"  }
from __future__ import annotations
from typing import Any, TypeVar, overload


class FrozenDict(dict):
    def __setitem__(self, key, value):
        raise AttributeError("dict is frozen")

    def __delitem__(self, key):
        raise AttributeError("dict is frozen")


class FrozenList(list):
    def __setitem__(self, index, value):
        raise AttributeError("list is frozen")

    def __delitem__(self, index):
        raise AttributeError("list is frozen")

    def append(self, value):
        raise AttributeError("list is frozen")

    def extend(self, iterable):
        raise AttributeError("list is frozen")

    def insert(self, index, value):
        raise AttributeError("list is frozen")

    def remove(self, value):
        raise AttributeError("list is frozen")

    def pop(self, index=-1):
        raise AttributeError("list is frozen")

    def clear(self):
        raise AttributeError("list is frozen")


FreezeMe = TypeVar("FreezeMe")


@overload
def freeze(instance: dict) -> FrozenDict: ...
@overload
def freeze(instance: list) -> FrozenList: ...
@overload
def freeze(instance: tuple) -> tuple: ...
@overload
def freeze(instance: set) -> frozenset: ...
@overload
def freeze(instance: FreezeMe) -> FreezeMe: ...
def freeze(instance: Any) -> Any:
    """recursively freeze an object in-place so that its attributes and elements cannot be changed

    messy in the sense that sometimes the object is modified in place, but you can't rely on that. always use the return value.

    the [gelidum](https://github.com/diegojromerolopez/gelidum/) package is a more complete implementation of this idea

    """

    # mark as frozen
    if hasattr(instance, "_IS_FROZEN"):
        if instance._IS_FROZEN:
            return instance

    # try to mark as frozen
    try:
        instance._IS_FROZEN = True  # type: ignore[attr-defined]
    except AttributeError:
        pass

    # skip basic types, weird things, or already frozen things
    if isinstance(instance, (bool, int, float, str, bytes)):
        pass

    elif isinstance(instance, (type(None), type(Ellipsis))):
        pass

    elif isinstance(instance, (FrozenList, FrozenDict, frozenset)):
        pass

    # handle containers
    elif isinstance(instance, list):
        for i in range(len(instance)):
            instance[i] = freeze(instance[i])
        instance = FrozenList(instance)

    elif isinstance(instance, tuple):
        instance = tuple(freeze(item) for item in instance)

    elif isinstance(instance, set):
        instance = frozenset({freeze(item) for item in instance})  # type: ignore[assignment]

    elif isinstance(instance, dict):
        for key, value in instance.items():
            instance[key] = freeze(value)
        instance = FrozenDict(instance)

    # handle custom classes
    else:
        # set everything in the __dict__ to frozen
        instance.__dict__ = freeze(instance.__dict__)  # type: ignore[assignment]

        # create a new class which inherits from the original class
        class FrozenClass(instance.__class__):  # type: ignore[name-defined]
            def __setattr__(self, name, value):
                raise AttributeError("class is frozen")

        FrozenClass.__name__ = f"FrozenClass__{instance.__class__.__name__}"
        FrozenClass.__module__ = instance.__class__.__module__
        FrozenClass.__doc__ = instance.__class__.__doc__

        # set the instance's class to the new class
        try:
            instance.__class__ = FrozenClass
        except TypeError as e:
            raise TypeError(
                f"Cannot freeze:\n{instance = }\n{instance.__class__ = }\n{FrozenClass = }"
            ) from e

    return instance

``````{ end_of_file="muutils/misc/freezing.py" }
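
A minimal sketch of `freeze` on a nested container; as the docstring notes, always use the return value.

```python
from muutils.misc.freezing import freeze

cfg = freeze({"lr": 0.1, "layers": [64, 32]})

try:
    cfg["lr"] = 0.2
except AttributeError as e:
    print(e)  # dict is frozen

try:
    cfg["layers"].append(16)
except AttributeError as e:
    print(e)  # list is frozen
```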

``````{ path="muutils/misc/func.py"  }
from __future__ import annotations
import functools
import sys
from types import CodeType
import warnings
from typing import Any, Callable, Tuple, cast, TypeVar

try:
    if sys.version_info >= (3, 11):
        # 3.11+
        from typing import Unpack, TypeVarTuple, ParamSpec
    else:
        # 3.9+
        from typing_extensions import Unpack, TypeVarTuple, ParamSpec  # type: ignore[assignment]
except ImportError:
    warnings.warn(
        "muutils.misc.func could not import Unpack and TypeVarTuple from typing or typing_extensions, typed_lambda may not work"
    )
    ParamSpec = TypeVar  # type: ignore
    Unpack = Any  # type: ignore
    TypeVarTuple = TypeVar  # type: ignore


from muutils.errormode import ErrorMode

warnings.warn("muutils.misc.func is experimental, use with caution")

ReturnType = TypeVar("ReturnType")
T_kwarg = TypeVar("T_kwarg")
T_process_in = TypeVar("T_process_in")
T_process_out = TypeVar("T_process_out")

FuncParams = ParamSpec("FuncParams")
FuncParamsPreWrap = ParamSpec("FuncParamsPreWrap")


def process_kwarg(
    kwarg_name: str,
    processor: Callable[[T_process_in], T_process_out],
) -> Callable[
    [Callable[FuncParamsPreWrap, ReturnType]], Callable[FuncParams, ReturnType]
]:
    """Decorator that applies a processor to a keyword argument.

    The underlying function is expected to take a keyword argument
    (with name `kwarg_name`) of type `T_process_out`, but the caller provides
    a value of type `T_process_in` that is converted via `processor`.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to process.
     - `processor : Callable[[T_process_in], T_process_out]`
        A callable that converts the input value (`T_process_in`) into the
        type expected by the function (`T_process_out`).

    # Returns:
     - A decorator that converts a function of type
       `Callable[FuncParamsPreWrap, ReturnType]` (expecting `kwarg_name` of type `T_process_out`)
       into one of type `Callable[FuncParams, ReturnType]` (accepting `kwarg_name` of type `T_process_in`).
    """

    def decorator(
        func: Callable[FuncParamsPreWrap, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> ReturnType:
            if kwarg_name in kwargs:
                # Convert the caller’s value (of type T_in) to T_out
                kwargs[kwarg_name] = processor(kwargs[kwarg_name])
            return func(*args, **kwargs)  # type: ignore[arg-type]

        return cast(Callable[FuncParams, ReturnType], wrapper)

    return decorator


@process_kwarg("action", ErrorMode.from_any)
def validate_kwarg(
    kwarg_name: str,
    validator: Callable[[T_kwarg], bool],
    description: str | None = None,
    action: ErrorMode = ErrorMode.EXCEPT,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that validates a specific keyword argument.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to validate.
     - `validator : Callable[[Any], bool]`
        A callable that returns True if the keyword argument is valid.
     - `description : str | None`
        A message template if validation fails.
     - `action : ErrorMode`
        What to do when validation fails; converted via `ErrorMode.from_any`, defaults to `ErrorMode.EXCEPT` (raise).

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that validates the keyword argument.

    # Modifies:
     - If validation fails and `action=="warn"`, emits a warning.
       Otherwise, raises a ValueError.

    # Usage:

    ```python
    @validate_kwarg("x", lambda val: val > 0, "Invalid {kwarg_name}: {value}")
    def my_func(x: int) -> int:
        return x

    assert my_func(x=1) == 1
    ```

    # Raises:
     - `ValueError` if validation fails and `action` is not `"warn"`.
    """

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:
            if kwarg_name in kwargs:
                value: Any = kwargs[kwarg_name]
                if not validator(value):
                    msg: str = (
                        description.format(kwarg_name=kwarg_name, value=value)
                        if description
                        else f"Validation failed for keyword '{kwarg_name}' with value {value}"
                    )
                    if action == "warn":
                        warnings.warn(msg, UserWarning)
                    else:
                        raise ValueError(msg)
            return func(*args, **kwargs)

        return cast(Callable[FuncParams, ReturnType], wrapper)

    return decorator


def replace_kwarg(
    kwarg_name: str,
    check: Callable[[T_kwarg], bool],
    replacement_value: T_kwarg,
    replace_if_missing: bool = False,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that replaces a specific keyword argument value when a check on it passes.

    # Parameters:
     - `kwarg_name : str`
        The name of the keyword argument to replace.
     - `check : Callable[[T_kwarg], bool]`
        A callable that returns True if the keyword argument should be replaced.
     - `replacement_value : T_kwarg`
        The value to replace with.
     - `replace_if_missing : bool`
        If True, replaces the keyword argument even if it's missing.

    # Returns:
     - `Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]`
        A decorator that replaces the keyword argument value.

    # Modifies:
     - Updates `kwargs[kwarg_name]` to `replacement_value` when `check` returns `True` for its value.

    # Usage:

    ```python
    @replace_kwarg("x", is_none, "default_string")
    def my_func(*, x: str | None = None) -> str:
        return x

    assert my_func(x=None) == "default_string"
    ```
    """

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        @functools.wraps(func)
        def wrapper(*args: FuncParams.args, **kwargs: FuncParams.kwargs) -> ReturnType:
            if kwarg_name in kwargs:
                # TODO: no way to type hint this, I think
                if check(kwargs[kwarg_name]):  # type: ignore[arg-type]
                    kwargs[kwarg_name] = replacement_value
            elif replace_if_missing and kwarg_name not in kwargs:
                kwargs[kwarg_name] = replacement_value
            return func(*args, **kwargs)

        return cast(Callable[FuncParams, ReturnType], wrapper)

    return decorator


def is_none(value: Any) -> bool:
    return value is None


def always_true(value: Any) -> bool:
    return True


def always_false(value: Any) -> bool:
    return False


def format_docstring(
    **fmt_kwargs: Any,
) -> Callable[[Callable[FuncParams, ReturnType]], Callable[FuncParams, ReturnType]]:
    """Decorator that formats a function's docstring with the provided keyword arguments."""

    def decorator(
        func: Callable[FuncParams, ReturnType],
    ) -> Callable[FuncParams, ReturnType]:
        if func.__doc__ is not None:
            func.__doc__ = func.__doc__.format(**fmt_kwargs)
        return func

    return decorator


# TODO: no way to make the type system understand this afaik
LambdaArgs = TypeVarTuple("LambdaArgs")
LambdaArgsTypes = TypeVar("LambdaArgsTypes", bound=Tuple[type, ...])


def typed_lambda(
    fn: Callable[[Unpack[LambdaArgs]], ReturnType],
    in_types: LambdaArgsTypes,
    out_type: type[ReturnType],
) -> Callable[[Unpack[LambdaArgs]], ReturnType]:
    """Wraps a lambda function with type hints.

    # Parameters:
     - `fn : Callable[[Unpack[LambdaArgs]], ReturnType]`
        The lambda function to wrap.
     - `in_types : tuple[type, ...]`
        Tuple of input types.
     - `out_type : type[ReturnType]`
        The output type.

    # Returns:
     - `Callable[..., ReturnType]`
        A new function with annotations matching the given signature.

    # Usage:

    ```python
    add = typed_lambda(lambda x, y: x + y, (int, int), int)
    assert add(1, 2) == 3
    ```

    # Raises:
     - `ValueError` if the number of input types doesn't match the lambda's parameters.
    """
    code: CodeType = fn.__code__
    n_params: int = code.co_argcount

    if len(in_types) != n_params:
        raise ValueError(
            f"Number of input types ({len(in_types)}) doesn't match number of parameters ({n_params})"
        )

    param_names: tuple[str, ...] = code.co_varnames[:n_params]
    annotations: dict[str, type] = {  # type: ignore[var-annotated]
        name: typ
        for name, typ in zip(param_names, in_types)  # type: ignore[arg-type]
    }
    annotations["return"] = out_type

    @functools.wraps(fn)
    def wrapped(*args: Unpack[LambdaArgs]) -> ReturnType:
        return fn(*args)

    wrapped.__annotations__ = annotations
    return wrapped

``````{ end_of_file="muutils/misc/func.py" }
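
A compact sketch of the decorators above; `describe`, `scale`, and `label` are made-up names, and note that importing the module emits an experimental-use warning.

```python
from __future__ import annotations

from muutils.misc.func import is_none, replace_kwarg, typed_lambda, validate_kwarg

@validate_kwarg("scale", lambda v: v > 0, "bad {kwarg_name}: {value}")
@replace_kwarg("label", is_none, "unnamed")
def describe(x: float, *, scale: float = 1.0, label: str | None = None) -> str:
    return f"{label}: {x * scale}"

assert describe(2.0, scale=3.0, label=None) == "unnamed: 6.0"

add = typed_lambda(lambda a, b: a + b, (int, int), int)
assert add(1, 2) == 3
```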

``````{ path="muutils/misc/hashing.py"  }
from __future__ import annotations

import base64
import hashlib
import json


def stable_hash(s: str | bytes) -> int:
    """Returns a stable hash of the given string. not cryptographically secure, but stable between runs"""
    # init hash object and update with string
    s_bytes: bytes
    if isinstance(s, str):
        s_bytes = s.encode("utf-8")
    else:
        s_bytes = s
    hash_obj: hashlib._Hash = hashlib.md5(s_bytes)
    # get digest and convert to int
    return int.from_bytes(hash_obj.digest(), "big")


def stable_json_dumps(d) -> str:
    return json.dumps(
        d,
        sort_keys=True,
        indent=None,
    )


def base64_hash(s: str | bytes) -> str:
    """Returns a base64 representation of the hash of the given string. not cryptographically secure"""
    s_bytes: bytes
    if isinstance(s, str):
        s_bytes = bytes(s, "UTF-8")
    else:
        s_bytes = s
    hash_bytes: bytes = hashlib.md5(s_bytes).digest()
    hash_b64: str = base64.b64encode(hash_bytes, altchars=b"-_").decode()
    return hash_b64

``````{ end_of_file="muutils/misc/hashing.py" }
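
A tiny sketch of the hashing helpers; the point of `stable_hash` is that the value depends only on the input, not on the interpreter run.

```python
from muutils.misc.hashing import base64_hash, stable_hash, stable_json_dumps

assert stable_hash("hello") == stable_hash("hello")  # same value across runs and processes
print(base64_hash("hello"))                          # short url-safe digest, e.g. "XUFAKrxLKna5cZ2REBfFkg=="
print(stable_json_dumps({"b": 1, "a": 2}))           # keys sorted: {"a": 2, "b": 1}
```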

``````{ path="muutils/misc/numerical.py"  }
from __future__ import annotations


_SHORTEN_MAP: dict[int | float, str] = {
    1e3: "K",
    1e6: "M",
    1e9: "B",
    1e12: "t",
    1e15: "q",
    1e18: "Q",
}

_SHORTEN_TUPLES: list[tuple[int | float, str]] = sorted(
    ((val, suffix) for val, suffix in _SHORTEN_MAP.items()),
    key=lambda x: -x[0],
)


_REVERSE_SHORTEN_MAP: dict[str, int | float] = {v: k for k, v in _SHORTEN_MAP.items()}


def shorten_numerical_to_str(
    num: int | float,
    small_as_decimal: bool = True,
    precision: int = 1,
) -> str:
    """shorten a large numerical value to a string
    1234 -> 1K

    precision is guaranteed to at least 1 part in 10, but can be higher. inverse of `str_to_numeric`
    """

    # small values are returned as is
    num_abs: float = abs(num)
    if num_abs < 1e3:
        return str(num)

    # iterate over suffixes from largest to smallest
    for i, (val, suffix) in enumerate(_SHORTEN_TUPLES):
        if num_abs > val or i == len(_SHORTEN_TUPLES) - 1:
            if (num_abs < val * 10) and small_as_decimal:
                return f"{num / val:.{precision}f}{suffix}"
            elif num_abs < val * 1e3:
                return f"{int(round(num / val))}{suffix}"

    return f"{num:.{precision}f}"


def str_to_numeric(
    quantity: str,
    mapping: None | bool | dict[str, int | float] = True,
) -> int | float:
    """Convert a string representing a quantity to a numeric value.

    The string can represent an integer, python float, fraction, or shortened via `shorten_numerical_to_str`.

    # Examples:
    ```
    >>> str_to_numeric("5")
    5
    >>> str_to_numeric("0.1")
    0.1
    >>> str_to_numeric("1/5")
    0.2
    >>> str_to_numeric("-1K")
    -1000.0
    >>> str_to_numeric("1.5M")
    1500000.0
    >>> str_to_numeric("1.2e2")
    120.0
    ```

    """

    # check is string
    if not isinstance(quantity, str):
        raise TypeError(
            f"quantity must be a string, got '{type(quantity) = }' '{quantity = }'"
        )

    # basic int conversion
    try:
        quantity_int: int = int(quantity)
        return quantity_int
    except ValueError:
        pass

    # basic float conversion
    try:
        quantity_float: float = float(quantity)
        return quantity_float
    except ValueError:
        pass

    # mapping
    _mapping: dict[str, int | float]
    if mapping is True or mapping is None:
        _mapping = _REVERSE_SHORTEN_MAP
    else:
        _mapping = mapping  # type: ignore[assignment]

    quantity_original: str = quantity

    quantity = quantity.strip()

    result: int | float
    multiplier: int | float = 1

    # detect if it has a suffix
    suffixes_detected: list[bool] = [suffix in quantity for suffix in _mapping]
    n_suffixes_detected: int = sum(suffixes_detected)
    if n_suffixes_detected == 0:
        # no suffix
        pass
    elif n_suffixes_detected == 1:
        # find multiplier
        for suffix, mult in _mapping.items():
            if quantity.endswith(suffix):
                # remove suffix, store multiplier, and break
                quantity = quantity[: -len(suffix)].strip()
                multiplier = mult
                break
        else:
            raise ValueError(f"Invalid suffix in {quantity_original}")
    else:
        # multiple suffixes
        raise ValueError(f"Multiple suffixes detected in {quantity_original}")

    # fractions
    if "/" in quantity:
        try:
            assert quantity.count("/") == 1, "too many '/'"
            # split and strip
            num, den = quantity.split("/")
            num = num.strip()
            den = den.strip()
            num_sign: int = 1
            # negative numbers
            if num.startswith("-"):
                num_sign = -1
                num = num[1:]
            # assert that both are digits
            assert num.isdigit() and den.isdigit(), (
                "numerator and denominator must be digits"
            )
            # return the fraction
            result = num_sign * (
                int(num) / int(den)
            )  # this allows for fractions with suffixes, which is weird, but whatever
        except AssertionError as e:
            raise ValueError(f"Invalid fraction {quantity_original}: {e}") from e

    # decimals
    else:
        try:
            result = int(quantity)
        except ValueError:
            try:
                result = float(quantity)
            except ValueError as e:
                raise ValueError(
                    f"Invalid quantity {quantity_original} ({quantity})"
                ) from e

    return result * multiplier

``````{ end_of_file="muutils/misc/numerical.py" }
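
A round-trip sketch for the numeric string helpers; the values follow from the behaviour documented above.

```python
from muutils.misc.numerical import shorten_numerical_to_str, str_to_numeric

assert shorten_numerical_to_str(1234) == "1.2K"
assert shorten_numerical_to_str(1_500_000) == "1.5M"
assert str_to_numeric("1.2K") == 1200.0
assert str_to_numeric("-1K") == -1000.0
assert str_to_numeric("1/5") == 0.2
```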

``````{ path="muutils/misc/sequence.py"  }
from __future__ import annotations

from typing import (
    Iterable,
    Any,
    Generator,
    Callable,
    Union,
)

import typing
from typing import (
    Literal,
    Mapping,
)


WhenMissing = Literal["except", "skip", "include"]


def empty_sequence_if_attr_false(
    itr: Iterable[Any],
    attr_owner: Any,
    attr_name: str,
) -> Iterable[Any]:
    """Returns `itr` if `attr_owner` has the attribute `attr_name` and it boolean casts to `True`. Returns an empty sequence otherwise.

    Particularly useful for optionally inserting delimiters into a sequence depending on a `TokenizerElement` attribute.

    # Parameters:
    - `itr: Iterable[Any]`
        The iterable to return if the attribute is `True`.
    - `attr_owner: Any`
        The object to check for the attribute.
    - `attr_name: str`
        The name of the attribute to check.

    # Returns:
    - `itr: Iterable` if `attr_owner` has the attribute `attr_name` and it boolean casts to `True`, otherwise an empty sequence.
    - `()` an empty sequence if the attribute is `False` or not present.
    """
    return itr if bool(getattr(attr_owner, attr_name, False)) else ()


def flatten(it: Iterable[Any], levels_to_flatten: int | None = None) -> Generator:
    """
    Flattens an arbitrarily nested iterable.
    Flattens all iterable data types except for `str` and `bytes`.

    # Returns
    Generator over the flattened sequence.

    # Parameters
    - `it`: Any arbitrarily nested iterable.
    - `levels_to_flatten`: Number of levels to flatten by, starting at the outermost layer. If `None`, performs full flattening.
    """
    for x in it:
        # TODO: swap type check with more general check for __iter__() or __next__() or whatever
        if (
            hasattr(x, "__iter__")
            and not isinstance(x, (str, bytes))
            and (levels_to_flatten is None or levels_to_flatten > 0)
        ):
            yield from flatten(
                x, None if levels_to_flatten is None else levels_to_flatten - 1
            )
        else:
            yield x


# string-like operations on lists
# --------------------------------------------------------------------------------


def list_split(lst: list, val: Any) -> list[list]:
    """split a list into sublists by `val`. similar to "a_b_c".split("_")

    ```python
    >>> list_split([1,2,3,0,4,5,0,6], 0)
    [[1, 2, 3], [4, 5], [6]]
    >>> list_split([0,1,2,3], 0)
    [[], [1, 2, 3]]
    >>> list_split([1,2,3], 0)
    [[1, 2, 3]]
    >>> list_split([], 0)
    [[]]
    ```

    """

    if len(lst) == 0:
        return [[]]

    output: list[list] = [
        [],
    ]

    for x in lst:
        if x == val:
            output.append([])
        else:
            output[-1].append(x)
    return output


def list_join(lst: list, factory: Callable) -> list:
    """add a *new* instance of `factory()` between each element of `lst`

    ```python
    >>> list_join([1,2,3], lambda : 0)
    [1,0,2,0,3]
    >>> list_join([1,2,3], lambda: [time.sleep(0.1), time.time()][1])
    [1, 1600000000.0, 2, 1600000000.1, 3]
    ```
    """

    if len(lst) == 0:
        return []

    output: list = [
        lst[0],
    ]

    for x in lst[1:]:
        output.append(factory())
        output.append(x)

    return output


# applying mappings
# --------------------------------------------------------------------------------

_AM_K = typing.TypeVar("_AM_K")
_AM_V = typing.TypeVar("_AM_V")


def apply_mapping(
    mapping: Mapping[_AM_K, _AM_V],
    iter: Iterable[_AM_K],
    when_missing: WhenMissing = "skip",
) -> list[Union[_AM_K, _AM_V]]:
    """Given an iterable and a mapping, apply the mapping to the iterable with certain options

    Gotcha: if `when_missing` is invalid, this is totally fine until a missing key is actually encountered.

    Note: you can use this with `muutils.kappa.Kappa` if you want to pass a function instead of a dict

    # Parameters:
     - `mapping : Mapping[_AM_K, _AM_V]`
        must have `__contains__` and `__getitem__`, both of which take `_AM_K` and the latter returns `_AM_V`
     - `iter : Iterable[_AM_K]`
        the iterable to apply the mapping to
     - `when_missing : WhenMissing`
        what to do when a key is missing from the mapping -- this is what distinguishes this function from `map`
        you can choose from `"skip"`, `"include"` (without converting), and `"except"`
       (defaults to `"skip"`)

    # Returns:
    return type is one of:
     - `list[_AM_V]` if `when_missing` is `"skip"` or `"except"`
     - `list[Union[_AM_K, _AM_V]]` if `when_missing` is `"include"`

    # Raises:
     - `KeyError` : if the item is missing from the mapping and `when_missing` is `"except"`
     - `ValueError` : if `when_missing` is invalid
    """
    output: list[Union[_AM_K, _AM_V]] = list()
    item: _AM_K
    for item in iter:
        if item in mapping:
            output.append(mapping[item])
            continue
        if when_missing == "skip":
            continue
        elif when_missing == "include":
            output.append(item)
        elif when_missing == "except":
            raise KeyError(f"item {item} is missing from mapping {mapping}")
        else:
            raise ValueError(
                f"invalid value for {when_missing = }\n{item = }\n{mapping = }"
            )
    return output


def apply_mapping_chain(
    mapping: Mapping[_AM_K, Iterable[_AM_V]],
    iter: Iterable[_AM_K],
    when_missing: WhenMissing = "skip",
) -> list[Union[_AM_K, _AM_V]]:
    """Given an iterable and a mapping, chain the mappings together

    Gotcha: if `when_missing` is invalid, this is totally fine until a missing key is actually encountered.

    Note: you can use this with `muutils.kappa.Kappa` if you want to pass a function instead of a dict

    # Parameters:
    - `mapping : Mapping[_AM_K, Iterable[_AM_V]]`
        must have `__contains__` and `__getitem__`, both of which take `_AM_K` and the latter returns `Iterable[_AM_V]`
    - `iter : Iterable[_AM_K]`
        the iterable to apply the mapping to
    - `when_missing : WhenMissing`
        what to do when a key is missing from the mapping -- this is what distinguishes this function from `map`
        you can choose from `"skip"`, `"include"` (without converting), and `"except"`
    (defaults to `"skip"`)

    # Returns:
    return type is one of:
     - `list[_AM_V]` if `when_missing` is `"skip"` or `"except"`
     - `list[Union[_AM_K, _AM_V]]` if `when_missing` is `"include"`

    # Raises:
    - `KeyError` : if the item is missing from the mapping and `when_missing` is `"except"`
    - `ValueError` : if `when_missing` is invalid

    """
    output: list[Union[_AM_K, _AM_V]] = list()
    item: _AM_K
    for item in iter:
        if item in mapping:
            output.extend(mapping[item])
            continue
        if when_missing == "skip":
            continue
        elif when_missing == "include":
            output.append(item)
        elif when_missing == "except":
            raise KeyError(f"item {item} is missing from mapping {mapping}")
        else:
            raise ValueError(
                f"invalid value for {when_missing = }\n{item = }\n{mapping = }"
            )
    return output

``````{ end_of_file="muutils/misc/sequence.py" }
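
A few sketched calls showing the sequence helpers; the mapping dict is illustrative.

```python
from muutils.misc.sequence import apply_mapping, flatten, list_join, list_split

assert list(flatten([1, [2, [3, 4]], 5])) == [1, 2, 3, 4, 5]
assert list(flatten([1, [2, [3, 4]]], levels_to_flatten=1)) == [1, 2, [3, 4]]
assert list_split([1, 2, 3, 0, 4, 5, 0, 6], 0) == [[1, 2, 3], [4, 5], [6]]
assert list_join([1, 2, 3], lambda: 0) == [1, 0, 2, 0, 3]

mapping = {"a": 1, "b": 2}
assert apply_mapping(mapping, ["a", "b", "z"], when_missing="skip") == [1, 2]
assert apply_mapping(mapping, ["a", "z"], when_missing="include") == [1, "z"]
```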

``````{ path="muutils/misc/string.py"  }
from __future__ import annotations


from muutils.misc.hashing import stable_hash


def sanitize_name(
    name: str | None,
    additional_allowed_chars: str = "",
    replace_invalid: str = "",
    when_none: str | None = "_None_",
    leading_digit_prefix: str = "",
) -> str:
    """sanitize a string, leaving only alphanumerics and `additional_allowed_chars`

    # Parameters:
     - `name : str | None`
       input string
     - `additional_allowed_chars : str`
       additional characters to allow, none by default
       (defaults to `""`)
     - `replace_invalid : str`
        character to replace invalid characters with
       (defaults to `""`)
     - `when_none : str | None`
        string to return if `name` is `None`. if `None`, raises an exception
       (defaults to `"_None_"`)
     - `leading_digit_prefix : str`
        character to prefix the string with if it starts with a digit
       (defaults to `""`)

    # Returns:
     - `str`
        sanitized string
    """

    if name is None:
        if when_none is None:
            raise ValueError("name is None")
        else:
            return when_none

    sanitized: str = ""
    for char in name:
        if char.isalnum():
            sanitized += char
        elif char in additional_allowed_chars:
            sanitized += char
        else:
            sanitized += replace_invalid

    # sanitized may be empty if every character was invalid
    if sanitized and sanitized[0].isdigit():
        sanitized = leading_digit_prefix + sanitized

    return sanitized


def sanitize_fname(fname: str | None, **kwargs) -> str:
    """sanitize a filename to posix standards

    - leave only alphanumerics, `_` (underscore), '-' (dash) and `.` (period)
    """
    return sanitize_name(fname, additional_allowed_chars="._-", **kwargs)


def sanitize_identifier(fname: str | None, **kwargs) -> str:
    """sanitize an identifier (variable or function name)

    - leave only alphanumerics and `_` (underscore)
    - prefix with `_` if it starts with a digit
    """
    return sanitize_name(
        fname, additional_allowed_chars="_", leading_digit_prefix="_", **kwargs
    )


def dict_to_filename(
    data: dict,
    format_str: str = "{key}_{val}",
    separator: str = ".",
    max_length: int = 255,
):
    # Convert the dictionary items to a list of strings using the format string
    formatted_items: list[str] = [
        format_str.format(key=k, val=v) for k, v in data.items()
    ]

    # Join the formatted items using the separator
    joined_str: str = separator.join(formatted_items)

    # Remove special characters and spaces
    sanitized_str: str = sanitize_fname(joined_str)

    # Check if the length is within limits
    if len(sanitized_str) <= max_length:
        return sanitized_str

    # If the string is too long, generate a hash
    return f"h_{stable_hash(sanitized_str)}"


def dynamic_docstring(**doc_params):
    def decorator(func):
        if func.__doc__:
            func.__doc__ = func.__doc__.format(**doc_params)
        return func

    return decorator

``````{ end_of_file="muutils/misc/string.py" }
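
A sketch of the filename/identifier sanitizers; the outputs follow directly from the rules above.

```python
from muutils.misc.string import dict_to_filename, sanitize_fname, sanitize_identifier

assert sanitize_fname("my file (v2).txt") == "myfilev2.txt"
assert sanitize_identifier("2nd-run") == "_2ndrun"
print(dict_to_filename({"lr": 0.1, "n_layers": 4}))  # lr_0.1.n_layers_4
```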

``````{ path="muutils/nbutils/__init__.py"  }
"""utilities for working with notebooks

- configuring figure modes and torch devices: `configure_notebook`
- converting them to scripts: `convert_ipynb_to_script`
- running them as tests: `run_notebook_tests`
- and working with diagrams/LaTeX: `mermaid`, `print_tex`

"""

from muutils.nbutils.mermaid import mm

__all__ = [
    # sub-modules
    "configure_notebook",
    "convert_ipynb_to_script",
    "mermaid",
    "print_tex",
    "run_notebook_tests",
    # functions
    "mm",
]

``````{ end_of_file="muutils/nbutils/__init__.py" }

``````{ path="muutils/nbutils/configure_notebook.py"  }
"""shared utilities for setting up a notebook"""

from __future__ import annotations

import os
import typing
import warnings

import matplotlib.pyplot as plt  # type: ignore[import]


class PlotlyNotInstalledWarning(UserWarning):
    pass


# handle plotly importing
PLOTLY_IMPORTED: bool
try:
    import plotly.io as pio  # type: ignore[import]
except ImportError:
    warnings.warn(
        "Plotly not installed. Plotly plots will not be available.",
        PlotlyNotInstalledWarning,
    )
    PLOTLY_IMPORTED = False
else:
    PLOTLY_IMPORTED = True

# figure out if we're in a jupyter notebook
try:
    from IPython import get_ipython  # type: ignore[import-not-found]

    IN_JUPYTER = get_ipython() is not None
except ImportError:
    IN_JUPYTER = False

# muutils imports
from muutils.mlutils import get_device, set_reproducibility  # noqa: E402

# handling figures
PlottingMode = typing.Literal["ignore", "inline", "widget", "save"]
PLOT_MODE: PlottingMode = "inline"
CONVERSION_PLOTMODE_OVERRIDE: PlottingMode | None = None
FIG_COUNTER: int = 0
FIG_OUTPUT_FMT: str | None = None
FIG_NUMBERED_FNAME: str = "figure-{num}"
FIG_CONFIG: dict | None = None
FIG_BASEPATH: str | None = None
CLOSE_AFTER_PLOTSHOW: bool = False

MATPLOTLIB_FORMATS = ["pdf", "png", "jpg", "jpeg", "svg", "eps", "ps", "tif", "tiff"]
TIKZPLOTLIB_FORMATS = ["tex", "tikz"]


class UnknownFigureFormatWarning(UserWarning):
    pass


def universal_savefig(fname: str, fmt: str | None = None) -> None:
    # try to infer format from fname
    if fmt is None:
        fmt = fname.split(".")[-1]

    if not (fmt in MATPLOTLIB_FORMATS or fmt in TIKZPLOTLIB_FORMATS):
        warnings.warn(
            f"Unknown format '{fmt}', defaulting to '{FIG_OUTPUT_FMT}'",
            UnknownFigureFormatWarning,
        )
        fmt = FIG_OUTPUT_FMT

    # not sure why linting is throwing an error here
    if not fname.endswith(fmt):  # type: ignore[arg-type]
        fname += f".{fmt}"

    if fmt in MATPLOTLIB_FORMATS:
        plt.savefig(fname, format=fmt, bbox_inches="tight")
    elif fmt in TIKZPLOTLIB_FORMATS:
        import tikzplotlib  # type: ignore[import]

        tikzplotlib.save(fname)
    else:
        warnings.warn(f"Unknown format '{fmt}', going with matplotlib default")
        plt.savefig(fname, bbox_inches="tight")


def setup_plots(
    plot_mode: PlottingMode = "inline",
    fig_output_fmt: str | None = "pdf",
    fig_numbered_fname: str = "figure-{num}",
    fig_config: dict | None = None,
    fig_basepath: str | None = None,
    close_after_plotshow: bool = False,
) -> None:
    """Set up plot saving/rendering options"""
    global \
        PLOT_MODE, \
        CONVERSION_PLOTMODE_OVERRIDE, \
        FIG_COUNTER, \
        FIG_OUTPUT_FMT, \
        FIG_NUMBERED_FNAME, \
        FIG_CONFIG, \
        FIG_BASEPATH, \
        CLOSE_AFTER_PLOTSHOW

    # set plot mode, handling override
    if CONVERSION_PLOTMODE_OVERRIDE is not None:
        # override if set
        PLOT_MODE = CONVERSION_PLOTMODE_OVERRIDE
    else:
        # otherwise use the given plot mode
        PLOT_MODE = plot_mode

    FIG_COUNTER = 0
    CLOSE_AFTER_PLOTSHOW = close_after_plotshow

    if PLOT_MODE == "inline":
        if IN_JUPYTER:
            ipython = get_ipython()
            ipython.magic("matplotlib inline")
        else:
            raise RuntimeError(
                f"Cannot use inline plotting outside of Jupyter\n{PLOT_MODE = }\t{CONVERSION_PLOTMODE_OVERRIDE = }"
            )
        return
    elif PLOT_MODE == "widget":
        if IN_JUPYTER:
            ipython = get_ipython()
            ipython.magic("matplotlib widget")
        else:
            # matplotlib outside of jupyter will bring up a new window by default
            pass
        return
    elif PLOT_MODE == "ignore":
        # disable plotting
        plt.show = lambda: None  # type: ignore[misc]
        return

    # everything except saving handled up to this point
    assert PLOT_MODE == "save", f"Invalid plot mode: {PLOT_MODE}"

    FIG_OUTPUT_FMT = fig_output_fmt
    FIG_NUMBERED_FNAME = fig_numbered_fname
    FIG_CONFIG = fig_config

    # set default figure format in rcParams savefig.format
    plt.rcParams["savefig.format"] = FIG_OUTPUT_FMT
    if FIG_OUTPUT_FMT in TIKZPLOTLIB_FORMATS:
        try:
            import tikzplotlib  # type: ignore[import] # noqa: F401
        except ImportError:
            warnings.warn(
                f"Tikzplotlib not installed. Cannot save figures in Tikz format '{FIG_OUTPUT_FMT}', things might break."
            )
    else:
        if FIG_OUTPUT_FMT not in MATPLOTLIB_FORMATS:
            warnings.warn(
                f'Unknown figure format, things might break: {plt.rcParams["savefig.format"] = }'
            )

    # if base path not given, make one
    if fig_basepath is None:
        if fig_config is None:
            # if no config, use the current time
            from datetime import datetime

            fig_basepath = f"figures/{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
        else:
            # if config given, convert to string
            from muutils.misc import dict_to_filename

            fig_basepath = f"figures/{dict_to_filename(fig_config)}"

    FIG_BASEPATH = fig_basepath
    os.makedirs(fig_basepath, exist_ok=True)

    # if config given, serialize and save that config
    if fig_config is not None:
        import json

        from muutils.json_serialize import json_serialize

        with open(f"{fig_basepath}/config.json", "w") as f:
            json.dump(
                json_serialize(fig_config),
                f,
                indent="\t",
            )

    print(f"Figures will be saved to: '{fig_basepath}'")


def configure_notebook(
    *args,
    seed: int = 42,
    device: typing.Any = None,  # this can be a string, torch.device, or None
    dark_mode: bool = True,
    plot_mode: PlottingMode = "inline",
    fig_output_fmt: str | None = "pdf",
    fig_numbered_fname: str = "figure-{num}",
    fig_config: dict | None = None,
    fig_basepath: str | None = None,
    close_after_plotshow: bool = False,
) -> "torch.device|None":  # type: ignore[name-defined] # noqa: F821
    """Shared Jupyter notebook setup steps

    - Set random seeds and library reproducibility settings
    - Set device based on availability
    - Set module reloading before code execution
    - Set plot formatting
    - Set plot saving/rendering options

    # Parameters:
     - `seed : int`
        random seed across libraries including torch, numpy, and random
       (defaults to `42`)
     - `device : typing.Any`
       pytorch device to use
       (defaults to `None`)
     - `dark_mode : bool`
       figures in dark mode
       (defaults to `True`)
     - `plot_mode : PlottingMode`
       how to display plots, one of `PlottingMode` or `["ignore", "inline", "widget", "save"]`
       (defaults to `"inline"`)
     - `fig_output_fmt : str | None`
       format for saving figures
       (defaults to `"pdf"`)
     - `fig_numbered_fname : str`
        format for saving figures with numbers (if they aren't named)
       (defaults to `"figure-{num}"`)
     - `fig_config : dict | None`
       metadata to save with the figures
       (defaults to `None`)
     - `fig_basepath : str | None`
        base path for saving figures
       (defaults to `None`)
     - `close_after_plotshow : bool`
        close figures after showing them
       (defaults to `False`)

    # Returns:
     - `torch.device|None`
       the device set, if torch is installed
    """

    # set some globals related to plotting
    setup_plots(
        plot_mode=plot_mode,
        fig_output_fmt=fig_output_fmt,
        fig_numbered_fname=fig_numbered_fname,
        fig_config=fig_config,
        fig_basepath=fig_basepath,
        close_after_plotshow=close_after_plotshow,
    )

    global PLOT_MODE, FIG_OUTPUT_FMT, FIG_BASEPATH

    print(f"set up plots with {PLOT_MODE = }, {FIG_OUTPUT_FMT = }, {FIG_BASEPATH = }")

    # Set seeds and other reproducibility-related library options
    set_reproducibility(seed)

    # Reload modules before executing user code
    if IN_JUPYTER:
        ipython = get_ipython()
        if "IPython.extensions.autoreload" not in ipython.extension_manager.loaded:
            ipython.magic("load_ext autoreload")
            ipython.magic("autoreload 2")

        # Specify plotly renderer for vscode
        if PLOTLY_IMPORTED:
            pio.renderers.default = "notebook_connected"

            if dark_mode:
                pio.templates.default = "plotly_dark"
                plt.style.use("dark_background")

    try:
        # Set device
        device = get_device(device)
        return device
    except ImportError:
        warnings.warn("Torch not installed. Cannot get/set device.")
        return None


def plotshow(
    fname: str | None = None,
    plot_mode: PlottingMode | None = None,
    fmt: str | None = None,
):
    """Show the active plot, depending on global configs"""
    global FIG_COUNTER, CLOSE_AFTER_PLOTSHOW, PLOT_MODE
    FIG_COUNTER += 1

    if plot_mode is None:
        plot_mode = PLOT_MODE

    if plot_mode == "save":
        # get numbered figure name if not given
        if fname is None:
            fname = FIG_NUMBERED_FNAME.format(num=FIG_COUNTER)

        # save figure
        assert FIG_BASEPATH is not None
        universal_savefig(os.path.join(FIG_BASEPATH, fname), fmt=fmt)
    elif plot_mode == "ignore":
        # do nothing
        pass
    elif plot_mode == "inline":
        # show figure
        plt.show()
    elif plot_mode == "widget":
        # show figure
        plt.show()
    else:
        warnings.warn(f"Invalid plot mode: {plot_mode}")

    if CLOSE_AFTER_PLOTSHOW:
        plt.close()

``````{ end_of_file="muutils/nbutils/configure_notebook.py" }
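
A hedged sketch of a typical setup cell; "save" mode also works outside Jupyter, `device` is `None` when torch is not installed, and the figure path here is just an example.

```python
import matplotlib.pyplot as plt

from muutils.nbutils.configure_notebook import configure_notebook, plotshow

# save figures to disk instead of rendering them inline (useful in CI)
device = configure_notebook(
    seed=0,
    plot_mode="save",
    fig_output_fmt="png",
    fig_basepath="figures/demo",  # example path
)

plt.plot([0, 1, 2], [0, 1, 4])
plotshow("quadratic.png")  # written to figures/demo/quadratic.png
```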

``````{ path="muutils/nbutils/convert_ipynb_to_script.py"  }
"""fast conversion of Jupyter Notebooks to scripts, with some basic and hacky filtering and formatting."""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
import sys
import typing
import warnings

from muutils.spinner import SpinnerContext

DISABLE_PLOTS: dict[str, list[str]] = {
    "matplotlib": [
        """
# ------------------------------------------------------------
# Disable matplotlib plots, done during processing by `convert_ipynb_to_script.py`
import matplotlib.pyplot as plt
plt.show = lambda: None
# ------------------------------------------------------------
"""
    ],
    "circuitsvis": [
        """
# ------------------------------------------------------------
# Disable circuitsvis plots, done during processing by `convert_ipynb_to_script.py`
from circuitsvis.utils.convert_props import PythonProperty, convert_props
from circuitsvis.utils.render import RenderedHTML, render, render_cdn, render_local

def new_render(
    react_element_name: str,
    **kwargs: PythonProperty
) -> RenderedHTML:
    "return a visualization as raw HTML"
    local_src = render_local(react_element_name, **kwargs)
    cdn_src = render_cdn(react_element_name, **kwargs)
    # return as string instead of RenderedHTML for CI
    return str(RenderedHTML(local_src, cdn_src))

render = new_render
# ------------------------------------------------------------
"""
    ],
    "muutils": [
        """import muutils.nbutils.configure_notebook as nb_conf
nb_conf.CONVERSION_PLOTMODE_OVERRIDE = "ignore"
"""
    ],
}

DISABLE_PLOTS_WARNING: list[str] = [
    """
# ------------------------------------------------------------
# WARNING: this script is auto-generated by `convert_ipynb_to_script.py`
# showing plots has been disabled, so this is presumably in a temp dir for CI or something
# so don't modify this code, it will be overwritten!
# ------------------------------------------------------------
""".lstrip()
]


def disable_plots_in_script(script_lines: list[str]) -> list[str]:
    """Disable plots in a script by adding cursed things after the import statements"""
    result_str_TEMP: str = "\n\n".join(script_lines)
    script_lines_new: list[str] = script_lines

    if "muutils" in result_str_TEMP:
        script_lines_new = DISABLE_PLOTS["muutils"] + script_lines_new

    if "matplotlib" in result_str_TEMP:
        assert "import matplotlib.pyplot as plt" in result_str_TEMP, (
            "matplotlib.pyplot must be imported as plt"
        )

        # find the last import statement involving matplotlib, and the first line that uses plt
        mpl_last_import_index: int = -1
        mpl_first_usage_index: int = -1
        for i, line in enumerate(script_lines_new):
            if "matplotlib" in line and (("import" in line) or ("from" in line)):
                mpl_last_import_index = i

            if "configure_notebook" in line:
                mpl_last_import_index = i

            if "plt." in line:
                mpl_first_usage_index = i

        assert mpl_last_import_index != -1, (
            f"matplotlib imports not found! see line {mpl_last_import_index}"
        )
        if mpl_first_usage_index != -1:
            assert mpl_first_usage_index > mpl_last_import_index, (
                f"matplotlib plots created before import! see lines {mpl_first_usage_index}, {mpl_last_import_index}"
            )
        else:
            warnings.warn(
                "could not find where matplotlib is used, plot disabling might not work!"
            )

        # insert the cursed things
        script_lines_new = (
            script_lines_new[: mpl_last_import_index + 1]
            + DISABLE_PLOTS["matplotlib"]
            + script_lines_new[mpl_last_import_index + 1 :]
        )
        result_str_TEMP = "\n\n".join(script_lines_new)

    if "circuitsvis" in result_str_TEMP:
        # find the last import statement involving circuitsvis, and the first line that uses it
        cirv_last_import_index: int = -1
        cirv_first_usage_index: int = -1

        for i, line in enumerate(script_lines_new):
            if "circuitsvis" in line:
                if (("import" in line) or ("from" in line)) and "circuitsvis" in line:
                    cirv_last_import_index = i
                else:
                    cirv_first_usage_index = i

                if "configure_notebook" in line:
                    cirv_last_import_index = i

                if "render" in line:
                    cirv_first_usage_index = i

        assert cirv_last_import_index != -1, (
            f"circuitsvis imports not found! see line {cirv_last_import_index}"
        )
        if cirv_first_usage_index != -1:
            assert cirv_first_usage_index > cirv_last_import_index, (
                f"circuitsvis plots created before import! see lines {cirv_first_usage_index}, {cirv_last_import_index}"
            )
        else:
            warnings.warn(
                "could not find where circuitsvis is used, plot disabling might not work!"
            )

        # insert the cursed things
        script_lines_new = (
            script_lines_new[: cirv_last_import_index + 1]
            + DISABLE_PLOTS["circuitsvis"]
            + script_lines_new[cirv_last_import_index + 1 :]
        )
        result_str_TEMP = "\n\n".join(script_lines_new)

    return script_lines_new


def convert_ipynb(
    notebook: dict,
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = (
        "%",
        "!",
    ),  # ignore notebook magic commands and shell commands
) -> str:
    """Convert Jupyter Notebook to a script, doing some basic filtering and formatting.

    # Arguments
        - `notebook: dict`: Jupyter Notebook loaded as json.
        - `strip_md_cells: bool = False`: Remove markdown cells from the output script.
        - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
        - `disable_plots: bool = False`: Disable plots in the output script.
        - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
            if a string is passed, it will be split by char and each char will be treated as a separate filter.

    # Returns
        - `str`: Converted script.
    """

    if isinstance(filter_out_lines, str):
        filter_out_lines = tuple(filter_out_lines)
    filter_out_lines_set: set = set(filter_out_lines)

    result: list[str] = []

    all_cells: list[dict] = notebook["cells"]

    for cell in all_cells:
        cell_type: str = cell["cell_type"]

        if not strip_md_cells and cell_type == "markdown":
            result.append(f'{header_comment}\n"""\n{"".join(cell["source"])}\n"""')
        elif cell_type == "code":
            source: list[str] = cell["source"]
            if filter_out_lines:
                source = [
                    (
                        f"#{line}"
                        if any(
                            line.startswith(filter_prefix)
                            for filter_prefix in filter_out_lines_set
                        )
                        else line
                    )
                    for line in source
                ]
            result.append(f"{header_comment}\n{''.join(source)}")

    if disable_plots:
        result = disable_plots_in_script(result)
        result = DISABLE_PLOTS_WARNING + result

    return "\n\n".join(result)


def process_file(
    in_file: str,
    out_file: str | None = None,
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = ("%", "!"),
):
    print(f"\tProcessing {in_file}...", file=sys.stderr)
    assert os.path.exists(in_file), f"File {in_file} does not exist."
    assert os.path.isfile(in_file), f"Path {in_file} is not a file."
    assert in_file.endswith(".ipynb"), f"File {in_file} is not a Jupyter Notebook."

    with open(in_file, "r") as file:
        notebook: dict = json.load(file)

    try:
        converted_script: str = convert_ipynb(
            notebook=notebook,
            strip_md_cells=strip_md_cells,
            header_comment=header_comment,
            disable_plots=disable_plots,
            filter_out_lines=filter_out_lines,
        )
    except AssertionError as e:
        print(f"Error converting {in_file}: {e}", file=sys.stderr)
        raise e

    if out_file:
        with open(out_file, "w") as file:
            file.write(converted_script)
    else:
        print(converted_script)


def process_dir(
    input_dir: typing.Union[str, Path],
    output_dir: typing.Union[str, Path],
    strip_md_cells: bool = False,
    header_comment: str = r"#%%",
    disable_plots: bool = False,
    filter_out_lines: str | typing.Sequence[str] = ("%", "!"),
):
    """Convert all Jupyter Notebooks in a directory to scripts.

    # Arguments
        - `input_dir: str`: Input directory.
        - `output_dir: str`: Output directory.
        - `strip_md_cells: bool = False`: Remove markdown cells from the output script.
        - `header_comment: str = r'#%%'`: Comment string to separate cells in the output script.
        - `disable_plots: bool = False`: Disable plots in the output script.
        - `filter_out_lines: str|typing.Sequence[str] = ('%', '!')`: comment out lines starting with these strings (in code blocks).
            if a string is passed, it will be split by char and each char will be treated as a separate filter.
    """

    assert os.path.exists(input_dir), f"Directory {input_dir} does not exist."
    assert os.path.isdir(input_dir), f"Path {input_dir} is not a directory."

    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    filenames: list[str] = [
        fname for fname in os.listdir(input_dir) if fname.endswith(".ipynb")
    ]

    assert filenames, f"Directory {input_dir} does not contain any Jupyter Notebooks."
    n_files: int = len(filenames)
    print(f"Converting {n_files} notebooks:", file=sys.stderr)

    with SpinnerContext(
        spinner_chars="braille",
        update_interval=0.01,
        format_string_when_updated=True,
        output_stream=sys.stderr,
    ) as spinner:
        for idx, fname in enumerate(filenames):
            spinner.update_value(f"\tConverting {idx + 1}/{n_files}: {fname}")
            in_file: str = os.path.join(input_dir, fname)
            out_file: str = os.path.join(output_dir, fname.replace(".ipynb", ".py"))

            with open(in_file, "r", encoding="utf-8") as file_in:
                notebook: dict = json.load(file_in)

            try:
                converted_script: str = convert_ipynb(
                    notebook=notebook,
                    strip_md_cells=strip_md_cells,
                    header_comment=header_comment,
                    disable_plots=disable_plots,
                    filter_out_lines=filter_out_lines,
                )
            except AssertionError as e:
                spinner.stop()
                raise Exception(f"Error converting {in_file}") from e

            with open(out_file, "w", encoding="utf-8") as file_out:
                file_out.write(converted_script)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert Jupyter Notebook to a script with cell separators."
    )
    parser.add_argument(
        "in_path",
        type=str,
        help="Input Jupyter Notebook file (.ipynb) or directory of files.",
    )
    parser.add_argument(
        "--out-file",
        type=str,
        help="Output script file. If not specified, the result will be printed to stdout.",
    )
    parser.add_argument(
        "--output-dir", type=str, help="Output directory for converted script files."
    )
    parser.add_argument(
        "--strip-md-cells",
        action="store_true",
        help="Remove markdown cells from the output script.",
    )
    parser.add_argument(
        "--header-comment",
        type=str,
        default=r"#%%",
        help="Comment string to separate cells in the output script.",
    )
    parser.add_argument(
        "--disable-plots",
        action="store_true",
        help="Disable plots in the output script. Useful for testing in CI.",
    )
    parser.add_argument(
        "--filter-out-lines",
        type=str,
        default="%",
        help="Comment out lines starting with these characters.",
    )

    args = parser.parse_args()

    if args.output_dir:
        assert not args.out_file, "Cannot specify both --out-file and --output-dir."
        process_dir(
            input_dir=args.in_path,
            output_dir=args.output_dir,
            strip_md_cells=args.strip_md_cells,
            header_comment=args.header_comment,
            disable_plots=args.disable_plots,
            filter_out_lines=args.filter_out_lines,
        )

    else:
        process_file(
            in_file=args.in_path,
            out_file=args.out_file,
            strip_md_cells=args.strip_md_cells,
            header_comment=args.header_comment,
            disable_plots=args.disable_plots,
            filter_out_lines=args.filter_out_lines,
        )


print("muutils.nbutils.convert_ipynb_to_script.py loaded.")

``````{ end_of_file="muutils/nbutils/convert_ipynb_to_script.py" }

``````{ path="muutils/nbutils/mermaid.py"  }
"""display mermaid.js diagrams in jupyter notebooks by the `mermaid.ink/img` service"""

import base64

try:
    from IPython.display import Image, display
except ImportError:
    import warnings

    warnings.warn(
        "IPython.display could not be imported, mermaid will not work", ImportWarning
    )


def mm(graph):
    """for plotting mermaid.js diagrams"""
    graphbytes = graph.encode("ascii")
    base64_bytes = base64.b64encode(graphbytes)
    base64_string = base64_bytes.decode("ascii")
    display(Image(url="https://mermaid.ink/img/" + base64_string))
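
# Usage sketch (run inside a Jupyter notebook so IPython.display is available):
#
#     mm("""
#     graph TD
#         A[notebook] --> B[converted script]
#         B --> C[CI run]
#     """)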

``````{ end_of_file="muutils/nbutils/mermaid.py" }

``````{ path="muutils/nbutils/print_tex.py"  }
"""quickly print a sympy expression in latex"""

import sympy as sp  # type: ignore
from IPython.display import Math, display  # type: ignore


def print_tex(
    expr: sp.Expr,
    name: str | None = None,
    plain: bool = False,
    rendered: bool = True,
):
    """function for easily rendering a sympy expression in latex"""
    out: str = sp.latex(expr)
    if name is not None:
        out = f"{name} = {out}"

    if plain:
        print(out)
    if rendered:
        display(Math(out))
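
# Usage sketch (inside a notebook; `x` is a hypothetical sympy symbol):
#
#     x = sp.symbols("x")
#     print_tex(sp.diff(sp.sin(x) ** 2, x), name="f'(x)", plain=True)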

``````{ end_of_file="muutils/nbutils/print_tex.py" }

``````{ path="muutils/nbutils/run_notebook_tests.py"  }
"""turn a folder of notebooks into scripts, run them, and make sure they work.

made to be called as

```bash
python -m muutils.nbutils.run_notebook_tests --notebooks-dir <notebooks_dir> --converted-notebooks-temp-dir <converted_notebooks_temp_dir>
```
"""

import os
import subprocess
import sys
from pathlib import Path
from typing import Optional
import warnings

from muutils.console_unicode import get_console_safe_str
from muutils.spinner import SpinnerContext


class NotebookTestError(Exception):
    pass


SUCCESS_STR: str = get_console_safe_str("✅", "[OK]")
FAILURE_STR: str = get_console_safe_str("❌", "[!!]")


def run_notebook_tests(
    notebooks_dir: Path,
    converted_notebooks_temp_dir: Path,
    CI_output_suffix: str = ".CI-output.txt",
    run_python_cmd: Optional[str] = None,
    run_python_cmd_fmt: str = "{python_tool} run python",
    python_tool: str = "poetry",
    exit_on_first_fail: bool = False,
):
    """Run converted Jupyter notebooks as Python scripts and verify they execute successfully.

    Takes a directory of notebooks and their corresponding converted Python scripts,
    executes each script, and captures the output. Failures are collected and reported,
    with optional early exit on first failure.

    # Parameters:
     - `notebooks_dir : Path`
        Directory containing the original .ipynb notebook files
     - `converted_notebooks_temp_dir : Path`
        Directory containing the corresponding converted .py files
     - `CI_output_suffix : str`
        Suffix to append to output files capturing execution results
        (defaults to `".CI-output.txt"`)
     - `run_python_cmd : str | None`
        Custom command to run Python scripts. Overrides python_tool and run_python_cmd_fmt if provided
        (defaults to `None`)
     - `run_python_cmd_fmt : str`
        Format string for constructing the Python run command
        (defaults to `"{python_tool} run python"`)
     - `python_tool : str`
        Tool used to run Python (e.g. poetry, uv)
        (defaults to `"poetry"`)
     - `exit_on_first_fail : bool`
        Whether to raise exception immediately on first notebook failure
        (defaults to `False`)

    # Returns:
     - `None`

    # Modifies:
     - Working directory: Temporarily changes to notebooks_dir during execution
     - Filesystem: Creates output files with CI_output_suffix for each notebook

    # Raises:
     - `NotebookTestError`: If any notebooks fail to execute, or if input directories are invalid
     - `TypeError`: If run_python_cmd is provided but not a string

    # Usage:
    ```python
    >>> run_notebook_tests(
    ...     notebooks_dir=Path("notebooks"),
    ...     converted_notebooks_temp_dir=Path("temp/converted"),
    ...     python_tool="poetry"
    ... )
    # testing notebooks in 'notebooks'
    # reading converted notebooks from 'temp/converted'
    Running 1/2: temp/converted/notebook1.py
        Output in temp/converted/notebook1.CI-output.txt
        {SUCCESS_STR} Run completed with return code 0
    ```
    """

    run_python_cmd_: str
    if run_python_cmd is None:
        run_python_cmd_ = run_python_cmd_fmt.format(python_tool=python_tool)
    elif isinstance(run_python_cmd, str):
        run_python_cmd_ = run_python_cmd
        warnings.warn(
            "You have specified a custom run_python_cmd, this will override the `python_tool` parameter and `run_python_cmd_fmt` parameter. This will be removed in a future version.",
            DeprecationWarning,
        )
    else:
        raise TypeError(
            f"run_python_cmd must be a string or None, got {run_python_cmd =}, {type(run_python_cmd) =}"
        )

    original_cwd: Path = Path.cwd()
    # get paths
    notebooks_dir = Path(notebooks_dir)
    converted_notebooks_temp_dir = Path(converted_notebooks_temp_dir)
    root_relative_to_notebooks: Path = Path(os.path.relpath(".", notebooks_dir))

    term_width: int
    try:
        term_width = os.get_terminal_size().columns
    except OSError:
        term_width = 80

    exceptions: dict[str, str] = dict()

    print(f"# testing notebooks in '{notebooks_dir}'")
    print(
        f"# reading converted notebooks from '{converted_notebooks_temp_dir.as_posix()}'"
    )

    try:
        # check things exist
        if not notebooks_dir.exists():
            raise NotebookTestError(f"Notebooks dir '{notebooks_dir}' does not exist")
        if not notebooks_dir.is_dir():
            raise NotebookTestError(
                f"Notebooks dir '{notebooks_dir}' is not a directory"
            )
        if not converted_notebooks_temp_dir.exists():
            raise NotebookTestError(
                f"Converted notebooks dir '{converted_notebooks_temp_dir}' does not exist"
            )
        if not converted_notebooks_temp_dir.is_dir():
            raise NotebookTestError(
                f"Converted notebooks dir '{converted_notebooks_temp_dir}' is not a directory"
            )

        notebooks: list[Path] = list(notebooks_dir.glob("*.ipynb"))
        if not notebooks:
            raise NotebookTestError(f"No notebooks found in '{notebooks_dir}'")

        converted_notebooks: list[Path] = list()
        for nb in notebooks:
            converted_file: Path = (
                converted_notebooks_temp_dir / nb.with_suffix(".py").name
            )
            if not converted_file.exists():
                raise NotebookTestError(
                    f"Did not find converted notebook '{converted_file}' for '{nb}'"
                )
            converted_notebooks.append(converted_file)

        del converted_file

        # the location of this line is important
        os.chdir(notebooks_dir)

        n_notebooks: int = len(converted_notebooks)
        for idx, file in enumerate(converted_notebooks):
            # run the file
            print(f"Running {idx + 1}/{n_notebooks}: {file.as_posix()}")
            output_file: Path = file.with_suffix(CI_output_suffix)
            print(f"    Output in {output_file.as_posix()}")
            with SpinnerContext(
                spinner_chars="braille",
                update_interval=0.5,
                format_string="\r    {spinner} ({elapsed_time:.2f}s) {message}{value}",
            ):
                command: str = f"{run_python_cmd_} {root_relative_to_notebooks / file} > {root_relative_to_notebooks / output_file} 2>&1"
                process: subprocess.CompletedProcess = subprocess.run(
                    command,
                    shell=True,
                    text=True,
                    env={**os.environ, "PYTHONIOENCODING": "utf-8"},
                )

            if process.returncode == 0:
                print(
                    f"    {SUCCESS_STR} Run completed with return code {process.returncode}"
                )
            else:
                print(
                    f"    {FAILURE_STR} Run failed with return code {process.returncode}!!! Check {output_file.as_posix()}"
                )

            # print the output of the file to the console if it failed
            if process.returncode != 0:
                with open(root_relative_to_notebooks / output_file, "r") as f:
                    file_output: str = f.read()
                err: str = f"Error in {file}:\n{'-' * term_width}\n{file_output}"
                exceptions[file.as_posix()] = err
                if exit_on_first_fail:
                    raise NotebookTestError(err)

            del process

        if len(exceptions) > 0:
            exceptions_str: str = ("\n" + "=" * term_width + "\n").join(
                list(exceptions.values())
            )
            raise NotebookTestError(
                exceptions_str
                + "=" * term_width
                + f"\n{FAILURE_STR} {len(exceptions)}/{n_notebooks} notebooks failed:\n{list(exceptions.keys())}"
            )

    except NotebookTestError as e:
        print("!" * term_width, file=sys.stderr)
        print(e, file=sys.stderr)
        print("!" * term_width, file=sys.stderr)
        raise e
    finally:
        # return to original cwd
        os.chdir(original_cwd)


if __name__ == "__main__":
    import argparse

    parser: argparse.ArgumentParser = argparse.ArgumentParser()

    parser.add_argument(
        "--notebooks-dir",
        type=str,
        help="The directory from which to run the notebooks",
    )
    parser.add_argument(
        "--converted-notebooks-temp-dir",
        type=str,
        help="The directory containing the converted notebooks to test",
    )
    parser.add_argument(
        "--python-tool",
        type=str,
        default="poetry",
        help="The python tool to use to run the notebooks (usually uv or poetry)",
    )
    parser.add_argument(
        "--run-python-cmd-fmt",
        type=str,
        default="{python_tool} run python",
        help="The command to run python with the python tool. if you don't want to use poetry or uv, you can just set this to 'python'",
    )

    args: argparse.Namespace = parser.parse_args()

    run_notebook_tests(
        notebooks_dir=Path(args.notebooks_dir),
        converted_notebooks_temp_dir=Path(args.converted_notebooks_temp_dir),
        python_tool=args.python_tool,
        run_python_cmd_fmt=args.run_python_cmd_fmt,
    )

``````{ end_of_file="muutils/nbutils/run_notebook_tests.py" }

``````{ path="muutils/web/__init__.py"  }
__all__ = [
    "bundle_html",
]

``````{ end_of_file="muutils/web/__init__.py" }

``````{ path="muutils/web/bundle_html.py"  }
"""
Inline / bundle external assets (CSS, JS, SVG, PNG) into an HTML document.

Default mode uses **zero external dependencies** and a few well-targeted
regular expressions.  If you install *beautifulsoup4* you can enable the
far more robust BS4 mode by passing `InlineConfig(use_bs4=True)`.
"""

from __future__ import annotations

import base64
import re
import urllib.request
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Final, Literal

# bs4 import deferred to avoid an unconditional dependency.

# constants
# ---------------------------------------------------------------------

AssetExt = Literal[".css", ".js", ".svg", ".png"]

DEFAULT_ALLOWED_EXTENSIONS: Final[set[AssetExt]] = {".css", ".js", ".svg", ".png"}

DEFAULT_TAG_ATTR: Final[dict[str, str]] = {
    "link": "href",  # <link rel="stylesheet" href="...">
    "script": "src",  # <script src="..."></script>
    "img": "src",  # <img src="...">
    "use": "xlink:href",  # <use xlink:href="sprite.svg#id">
}

MIME_BY_EXT: Final[dict[AssetExt, str]] = {
    ".css": "text/css",
    ".js": "application/javascript",
    ".svg": "image/svg+xml",
    ".png": "image/png",
}

# Configuration
# ---------------------------------------------------------------------


@dataclass
class InlineConfig:
    """High-level configuration for the inliner.

    # Parameters
    - `allowed_extensions : set[AssetExt]`
        Extensions that may be inlined.
    - `tag_attr : dict[str, str]`
        Mapping *tag -> attribute* that holds the asset reference.
    - `max_bytes : int`
        Assets larger than this are ignored.
    - `local : bool`
        Allow local filesystem assets.
    - `remote : bool`
        Allow remote http/https assets.
    - `include_filename_comments : bool`
        Surround every replacement with `<!-- begin '...' -->`
        and `<!-- end '...' -->`.
    - `use_bs4 : bool`
        Parse the document with BeautifulSoup if available.
    """

    allowed_extensions: set[AssetExt] = field(
        default_factory=lambda: set(DEFAULT_ALLOWED_EXTENSIONS)
    )
    tag_attr: dict[str, str] = field(default_factory=lambda: dict(DEFAULT_TAG_ATTR))
    max_bytes: int = 128 * 1024
    local: bool = True
    remote: bool = False
    include_filename_comments: bool = True
    use_bs4: bool = False


# Low-level helpers
# ---------------------------------------------------------------------


def _is_remote(url: str) -> bool:
    """Return *True* if *url* starts with http:// or https://."""
    return url.lower().startswith(("http://", "https://"))


def _fetch_bytes(src: str, base: Path) -> bytes:
    """Fetch *src* (local or remote) and return its raw bytes."""
    if _is_remote(src):
        with urllib.request.urlopen(src) as resp:
            return resp.read()
    return (base / src).read_bytes()


def _decode_text(buf: bytes) -> str:
    """Decode *buf* as UTF-8, falling back to replacement."""
    try:
        return buf.decode()
    except UnicodeDecodeError:
        return buf.decode("utf-8", "replace")


# Regex-based implementation (no deps)
# ---------------------------------------------------------------------


def _apply_indent(html: str, start: int, replacement: str) -> str:
    """Indent *replacement* to match the line that starts at *start*."""
    line_start: int = html.rfind("\n", 0, start) + 1
    indent: str = html[line_start:start]
    return "\n".join(indent + line for line in replacement.splitlines())


def _inline_with_regex(html: str, base: Path, cfg: InlineConfig) -> str:
    """Inline assets using pure-regex parsing (no third-party libs)."""
    tag: str
    attr: str
    for tag, attr in cfg.tag_attr.items():
        pattern: str
        if tag == "script":
            pattern = (
                rf"<script\b[^>]*\s{attr}\s*=\s*['\"]([^'\"]+)['\"][^>]*>\s*</script>"
            )
        elif tag == "link":
            pattern = rf"<link\b[^>]*\s{attr}\s*=\s*['\"]([^'\"]+)['\"][^>]*>"
        else:  # img, use, etc.
            pattern = rf"<{tag}\b[^>]*\s{attr}\s*=\s*['\"]([^'\"]+)['\"][^>]*>"

        matches: list[re.Match[str]] = list(re.finditer(pattern, html, re.IGNORECASE))
        m: re.Match[str]
        for m in reversed(matches):
            raw_src: str = m.group(1)  # may contain #fragment
            clean_src: str = re.split(r"[?#]", raw_src, maxsplit=1)[0]  # file path only
            ext: str = Path(clean_src).suffix.lower()

            if ext not in cfg.allowed_extensions:
                continue
            if _is_remote(clean_src) and not cfg.remote:
                continue
            if not _is_remote(clean_src) and not cfg.local:
                continue

            try:
                data: bytes = _fetch_bytes(clean_src, base)
            except Exception as err:
                warnings.warn(f"skip '{raw_src}': {err}")
                continue

            if len(data) > cfg.max_bytes:
                continue

            # build replacement
            replacement: str
            if ext in {".css", ".js"}:
                tag_name: str = "style" if ext == ".css" else "script"
                replacement = f"<{tag_name}>\n{_decode_text(data)}\n</{tag_name}>"
            else:  # .svg or .png
                b64: str = base64.b64encode(data).decode()
                # TYPING: we check earlier, so ext is for sure in MIME_BY_EXT
                data_uri: str = f"data:{MIME_BY_EXT[ext]};base64,{b64}"  # type: ignore[index]
                replacement = m.group(0).replace(raw_src, data_uri, 1)

            if cfg.include_filename_comments:
                replacement = f"<!-- begin '{clean_src}' -->\n{replacement}\n<!-- end '{clean_src}' -->"

            replacement = _apply_indent(html, m.start(), replacement)
            html = html[: m.start()] + replacement + html[m.end() :]

    return html


# BeautifulSoup-based implementation (optional)
# ---------------------------------------------------------------------


def _inline_with_bs4(html: str, base: Path, cfg: InlineConfig) -> str:
    """Inline assets using BeautifulSoup when available."""
    try:
        from bs4 import BeautifulSoup, Comment, Tag
    except ModuleNotFoundError as exc:  # pragma: no cover
        raise RuntimeError("BeautifulSoup requested but not installed") from exc

    soup: BeautifulSoup = BeautifulSoup(html, "html.parser")

    tag: Tag  # TYPING: i think soup.find_all() returns a list of Tag objects? mypy thinks it should be PageElement (of which Tag is a subclass)
    for tag in list(soup.find_all(cfg.tag_attr.keys())):  # type: ignore[assignment]
        attr: str = cfg.tag_attr[tag.name]
        # TYPING: error: Incompatible types in assignment (expression has type "str | AttributeValueList | None", variable has type "str | None")  [assignment]
        src_full: str | None = tag.get(attr)  # type: ignore[assignment]
        if not src_full:
            continue

        clean_src: str = re.split(r"[?#]", src_full, maxsplit=1)[0]
        ext: str = Path(clean_src).suffix.lower()

        if ext not in cfg.allowed_extensions:
            continue
        if _is_remote(clean_src) and not cfg.remote:
            continue
        if not _is_remote(clean_src) and not cfg.local:
            continue

        try:
            data: bytes = _fetch_bytes(clean_src, base)
        except Exception as err:
            warnings.warn(f"skip '{src_full}': {err}")
            continue

        if len(data) > cfg.max_bytes:
            continue

        if ext in {".css", ".js"}:
            new_tag: Tag = soup.new_tag("style" if ext == ".css" else "script")
            new_tag.string = _decode_text(data)
            if cfg.include_filename_comments:
                tag.insert_before(Comment(f" begin '{src_full}' "))
                tag.insert_after(Comment(f" end '{src_full}' "))
            tag.replace_with(new_tag)
        else:  # .svg or .png
            b64: str = base64.b64encode(data).decode()
            # we are sure ext is in MIME_BY_EXT, so ignore type error
            tag[attr] = f"data:{MIME_BY_EXT[ext]};base64,{b64}"  # type: ignore[index]
            if cfg.include_filename_comments:
                tag.insert_before(Comment(f" begin '{src_full}' "))
                tag.insert_after(Comment(f" end '{src_full}' "))

    return str(soup)


# Public API
# ---------------------------------------------------------------------


def inline_html_assets(
    html: str,
    *,
    base_path: Path,
    config: InlineConfig | None = None,
    prettify: bool = False,  # kept for API compatibility (ignored in regex mode)
) -> str:
    """Inline permitted external assets inside *html*.

    # Parameters
    - `html : str`
        Raw HTML text.
    - `base_path : Path`
        Directory used to resolve relative asset paths.
    - `config : InlineConfig | None`
        Inlining options (see `InlineConfig`).
    - `prettify : bool`
        Pretty-print output (only effective in BS4 mode).

    # Returns
    - `str`
        Modified HTML.
    """
    cfg: InlineConfig = config or InlineConfig()
    if cfg.use_bs4:
        html_out: str = _inline_with_bs4(html, base_path, cfg)
        if prettify:
            # lazy import to avoid unconditional dependency
            from bs4 import BeautifulSoup

            # TYPING: .prettify() returns str if no encoding is set
            html_out = str(BeautifulSoup(html_out, "html.parser").prettify())
    else:
        html_out = _inline_with_regex(html, base_path, cfg)
    return html_out


def inline_html_file(
    html_path: Path,
    output_path: Path,
    base_path: Path | None = None,
    config: InlineConfig | None = None,
    prettify: bool = False,
) -> Path:
    """Read *html_path*, inline its assets, and write the result.

    # Parameters
    - `html_path : Path`
        Source HTML file.
    - `output_path : Path`
        Destination path to write the modified HTML.
    - `base_path : Path | None`
        Directory used to resolve relative asset paths.
        (default: `None` -> use `html_path.parent`)
    - `config : InlineConfig | None`
        Inlining options.
        If `None`, uses default configuration.
        (default: `None` -> use `InlineConfig()`)
    - `prettify : bool`
        Pretty-print when `use_bs4=True`.
        (default: `False`)

    # Returns
    - `Path`
        Path actually written.
    """
    if base_path is None:
        base_path = html_path.parent
    html_raw: str = html_path.read_text()
    html_new: str = inline_html_assets(
        html_raw,
        base_path=base_path,
        config=config,
        prettify=prettify,
    )
    dest: Path = output_path or html_path
    dest.write_text(html_new)
    return dest
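
# Usage sketch (hypothetical file names; relative asset paths are resolved next
# to the source HTML file by default):
#
#     cfg = InlineConfig(remote=False, max_bytes=64 * 1024)
#     inline_html_file(
#         Path("report.html"),
#         output_path=Path("report.bundled.html"),
#         config=cfg,
#     )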


# CLI
# ---------------------------------------------------------------------

if __name__ == "__main__":
    import argparse

    parser: argparse.ArgumentParser = argparse.ArgumentParser(
        description="Inline / bundle CSS, JS, SVG, PNG assets. "
        "Uses regex parsing by default; pass --bs4 to require BeautifulSoup."
    )
    parser.add_argument("html", type=Path, help="input HTML file")
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        help="output file",
        required=True,
    )
    parser.add_argument(
        "--source-dir",
        type=Path,
        default=None,
        help="base directory for relative asset paths (defaults to the HTML file's directory)",
    )
    parser.add_argument("--remote", action="store_true", help="allow remote URLs")
    parser.add_argument("--bs4", action="store_true", help="use BeautifulSoup parser")
    parser.add_argument(
        "--prettify", action="store_true", help="pretty-print with BeautifulSoup)"
    )
    parser.add_argument(
        "--max-bytes", type=int, default=128 * 1024, help="size limit per asset"
    )
    parser.add_argument(
        "--ext",
        nargs="+",
        default=list(DEFAULT_ALLOWED_EXTENSIONS),
        help="extensions to inline",
    )
    parser.add_argument(
        "--tag-attr",
        type=str,
        default=None,
        help='override tag->attr map. format: "tag1=attr1,tag2=attr2"',
    )
    parser.add_argument("--no-comments", dest="comments", action="store_false")
    args: argparse.Namespace = parser.parse_args()

    tag_attr: dict[str, str]
    if args.tag_attr:
        tag_attr = {
            tag: attr
            for tag, attr in (item.split("=") for item in args.tag_attr.split(","))
        }

    else:
        tag_attr = dict(DEFAULT_TAG_ATTR)

    cfg: InlineConfig = InlineConfig(
        allowed_extensions=set(args.ext),  # type: ignore[arg-type]
        tag_attr=tag_attr,
        max_bytes=args.max_bytes,
        remote=args.remote,
        include_filename_comments=args.comments,
        use_bs4=args.bs4,
    )

    inline_html_file(
        args.html,
        output_path=args.output,
        base_path=args.source_dir,
        config=cfg,
        prettify=args.prettify,
    )

``````{ end_of_file="muutils/web/bundle_html.py" }

``````{ path="muutils/__init__.py"  }
"""
.. include:: ../README.md
"""

from __future__ import annotations

__all__ = [
    # submodules (with sub-submodules)
    "json_serialize",
    "logger",
    "math",
    "misc",
    "nbutils",
    "web",
    # submodules
    "collect_warnings",
    "console_unicode",
    "dbg",
    "dictmagic",
    "errormode",
    "group_equiv",
    "interval",
    "jsonlines",
    "kappa",
    "mlutils",
    "parallel",
    "spinner",
    "statcounter",
    "sysinfo",
    "tensor_info",
    "tensor_utils",
    "timeit_fancy",
    "validate_type",
]

``````{ end_of_file="muutils/__init__.py" }

``````{ path="muutils/collect_warnings.py"  }
from __future__ import annotations

import sys
import warnings
from collections import Counter
from contextlib import AbstractContextManager
from types import TracebackType
from typing import Any, Literal


class CollateWarnings(AbstractContextManager["CollateWarnings"]):
    """Capture every warning issued inside a `with` block and print a collated
    summary when the block exits.

    Internally this wraps `warnings.catch_warnings(record=True)` so that all
    warnings raised in the block are recorded.  When the context exits, identical
    warnings are grouped and (optionally) printed with a user-defined format.

    # Parameters:
     - `print_on_exit : bool`
       Whether to print the summary when the context exits
       (defaults to `True`)
     - `fmt : str`
       Format string used for printing each line of the summary.
       Available fields are:

       * `{count}`     : number of occurrences
       * `{filename}`  : file where the warning originated
       * `{lineno}`    : line number
       * `{category}`  : warning class name
       * `{message}`   : warning message text

       (defaults to `"({count}x) {filename}:{lineno} {category}: {message}"`)

    # Returns:
     - `CollateWarnings`
       The context-manager instance.  After exit, the attribute
       `counts` holds a mapping

       ```python
       {(filename, lineno, category, message): count}
       ```

    # Usage:
    ```python
    >>> import warnings
    >>> with CollateWarnings() as cw:
    ...     warnings.warn("deprecated", DeprecationWarning)
    ...     warnings.warn("deprecated", DeprecationWarning)
    ...     warnings.warn("other", UserWarning)
    (2x) /tmp/example.py:42 DeprecationWarning: deprecated
    (1x) /tmp/example.py:43 UserWarning: other
    >>> cw.counts
    {('/tmp/example.py', 42, 'DeprecationWarning', 'deprecated'): 2,
     ('/tmp/example.py', 43, 'UserWarning', 'other'): 1}
    ```
    """

    _active: bool
    _catcher: Any
    _records: list[warnings.WarningMessage]
    counts: Counter[
        tuple[
            str,  # filename
            int,  # lineno
            str,  # category name
            str,  # message
        ]
    ]
    print_on_exit: bool
    fmt: str

    def __init__(
        self,
        print_on_exit: bool = True,
        fmt: str = "({count}x) {filename}:{lineno} {category}: {message}",
    ) -> None:
        self.print_on_exit = print_on_exit
        self.fmt = fmt
        self._active = False
        self._records = []
        self.counts = Counter()

    def __enter__(self) -> CollateWarnings:
        if self._active:
            raise RuntimeError("CollateWarnings cannot be re-entered")

        self._active = True
        self._catcher = warnings.catch_warnings(record=True)
        self._records = self._catcher.__enter__()
        warnings.simplefilter("always")  # capture every warning
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> Literal[False]:
        if not self._active:
            raise RuntimeError("CollateWarnings exited twice")

        self._active = False
        # stop capturing
        self._catcher.__exit__(exc_type, exc_val, exc_tb)

        # collate
        self.counts = Counter(
            (
                rec.filename,
                rec.lineno,
                rec.category.__name__,
                str(rec.message),
            )
            for rec in self._records
        )

        if self.print_on_exit:
            for (filename, lineno, category, message), count in self.counts.items():
                print(
                    self.fmt.format(
                        count=count,
                        filename=filename,
                        lineno=lineno,
                        category=category,
                        message=message,
                    ),
                    file=sys.stderr,
                )

        # propagate any exception from the with-block
        return False

``````{ end_of_file="muutils/collect_warnings.py" }

``````{ path="muutils/console_unicode.py"  }
import locale


def get_console_safe_str(
    default: str,
    fallback: str,
) -> str:
    """Determine a console-safe string based on the preferred encoding.

    This function attempts to encode a given `default` string using the system's preferred encoding.
    If encoding is successful, it returns the `default` string; otherwise, it returns a `fallback` string.

    # Parameters:
     - `default : str`
        The primary string intended for use, to be tested against the system's preferred encoding.
     - `fallback : str`
        The alternative string to be used if `default` cannot be encoded in the system's preferred encoding.

    # Returns:
     - `str`
        Either `default` or `fallback` based on whether `default` can be encoded safely.

    # Usage:

    ```python
    >>> get_console_safe_str("café", "cafe")
    "café"  # This result may vary based on the system's preferred encoding.
    ```
    """
    try:
        default.encode(locale.getpreferredencoding())
        return default
    except UnicodeEncodeError:
        return fallback

``````{ end_of_file="muutils/console_unicode.py" }

``````{ path="muutils/dbg.py"  }
"""

this code is based on an implementation of the Rust builtin `dbg!` for Python, originally from
https://github.com/tylerwince/pydbg/blob/master/pydbg.py
although it has been significantly modified

licensed under MIT:

Copyright (c) 2019 Tyler Wince

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

"""

from __future__ import annotations

import inspect
import sys
import typing
from pathlib import Path
import re

# type defs
_ExpType = typing.TypeVar("_ExpType")
_ExpType_dict = typing.TypeVar(
    "_ExpType_dict", bound=typing.Dict[typing.Any, typing.Any]
)
_ExpType_list = typing.TypeVar("_ExpType_list", bound=typing.List[typing.Any])


# Sentinel type for no expression passed
class _NoExpPassedSentinel:
    """Unique sentinel type used to indicate that no expression was passed."""

    pass


_NoExpPassed = _NoExpPassedSentinel()

# global variables
_CWD: Path = Path.cwd().absolute()
_COUNTER: int = 0

# configuration
PATH_MODE: typing.Literal["relative", "absolute"] = "relative"
DEFAULT_VAL_JOINER: str = " = "


# path processing
def _process_path(path: Path) -> str:
    path_abs: Path = path.absolute()
    fname: Path
    if PATH_MODE == "absolute":
        fname = path_abs
    elif PATH_MODE == "relative":
        try:
            # if it's inside the cwd, print the relative path
            fname = path.relative_to(_CWD)
        except ValueError:
            # if it's not in the subpath, use the absolute path
            fname = path_abs
    else:
        raise ValueError("PATH_MODE must be either 'relative' or 'absolute'")

    return fname.as_posix()


# actual dbg function
@typing.overload
def dbg() -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _NoExpPassedSentinel,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _NoExpPassedSentinel: ...
@typing.overload
def dbg(
    exp: _ExpType,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> _ExpType: ...
def dbg(
    exp: typing.Union[_ExpType, _NoExpPassedSentinel] = _NoExpPassed,
    formatter: typing.Optional[typing.Callable[[typing.Any], str]] = None,
    val_joiner: str = DEFAULT_VAL_JOINER,
) -> typing.Union[_ExpType, _NoExpPassedSentinel]:
    """Call dbg with any variable or expression.

    Calling dbg will print to stderr the current filename and lineno,
    as well as the passed expression and what the expression evaluates to:

            from muutils.dbg import dbg

            a = 2
            b = 5

            dbg(a+b)

            def square(x: int) -> int:
                    return x * x

            dbg(square(a))

    """
    global _COUNTER

    # get the context
    line_exp: str = "unknown"
    current_file: str = "unknown"
    dbg_frame: typing.Optional[inspect.FrameInfo] = None
    for frame in inspect.stack():
        if frame.code_context is None:
            continue
        line: str = frame.code_context[0]
        if "dbg" in line:
            current_file = _process_path(Path(frame.filename))
            dbg_frame = frame
            start: int = line.find("(") + 1
            end: int = line.rfind(")")
            if end == -1:
                end = len(line)
            line_exp = line[start:end]
            break

    fname: str = "unknown"
    if current_file.startswith("/tmp/ipykernel_"):
        stack: list[inspect.FrameInfo] = inspect.stack()
        filtered_functions: list[str] = []
        # this loop will find, in this order:
        # - the dbg function call
        # - the functions we care about displaying
        # - `<module>`
        # - a bunch of jupyter internals we don't care about
        for frame_info in stack:
            if _process_path(Path(frame_info.filename)) != current_file:
                continue
            if frame_info.function == "<module>":
                break
            if frame_info.function.startswith("dbg"):
                continue
            filtered_functions.append(frame_info.function)
        if dbg_frame is not None:
            filtered_functions.append(f"<ipykernel>:{dbg_frame.lineno}")
        else:
            filtered_functions.append(current_file)
        filtered_functions.reverse()
        fname = " -> ".join(filtered_functions)
    elif dbg_frame is not None:
        fname = f"{current_file}:{dbg_frame.lineno}"

    # assemble the message
    msg: str
    if exp is _NoExpPassed:
        # if no expression is passed, just show location and counter value
        msg = f"[ {fname} ] <dbg {_COUNTER}>"
        _COUNTER += 1
    else:
        # if expression passed, format its value and show location, expr, and value
        exp_val: str = formatter(exp) if formatter else repr(exp)
        msg = f"[ {fname} ] {line_exp}{val_joiner}{exp_val}"

    # print the message
    print(
        msg,
        file=sys.stderr,
    )

    # return the expression itself
    return exp


# formatted `dbg_*` functions with their helpers

DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS: typing.Dict[
    str, typing.Union[None, bool, int, str]
] = dict(
    fmt="unicode",
    precision=2,
    stats=True,
    shape=True,
    dtype=True,
    device=True,
    requires_grad=True,
    sparkline=True,
    sparkline_bins=7,
    sparkline_logy=None,  # None means auto-detect
    colored=True,
    eq_char="=",
)


DBG_TENSOR_VAL_JOINER: str = ": "


def tensor_info(tensor: typing.Any) -> str:
    from muutils.tensor_info import array_summary

    return array_summary(tensor, as_list=False, **DBG_TENSOR_ARRAY_SUMMARY_DEFAULTS)


DBG_DICT_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = dict(
    key_types=True,
    val_types=True,
    max_len=32,
    indent="  ",
    max_depth=3,
)

DBG_LIST_DEFAULTS: typing.Dict[str, typing.Union[bool, int, str]] = dict(
    max_len=16,
    summary_show_types=True,
)


def list_info(
    lst: typing.List[typing.Any],
) -> str:
    len_l: int = len(lst)
    output: str
    # TYPING: make `DBG_LIST_DEFAULTS` and the others typed dicts
    if len_l > DBG_LIST_DEFAULTS["max_len"]:  # type: ignore[operator]
        output = f"<list of len()={len_l}"
        if DBG_LIST_DEFAULTS["summary_show_types"]:
            val_types: typing.Set[str] = set(type(x).__name__ for x in lst)
            output += f", types={{{', '.join(sorted(val_types))}}}"
        output += ">"
    else:
        output = "[" + ", ".join(repr(x) for x in lst) + "]"

    return output


TENSOR_STR_TYPES: typing.Set[str] = {
    "<class 'torch.Tensor'>",
    "<class 'numpy.ndarray'>",
}


def dict_info(
    d: typing.Dict[typing.Any, typing.Any],
    depth: int = 0,
) -> str:
    len_d: int = len(d)
    indent: str = DBG_DICT_DEFAULTS["indent"]  # type: ignore[assignment]

    # summary line
    output: str = f"{indent * depth}<dict of len()={len_d}"

    if DBG_DICT_DEFAULTS["key_types"] and len_d > 0:
        key_types: typing.Set[str] = set(type(k).__name__ for k in d.keys())
        key_types_str: str = "{" + ", ".join(sorted(key_types)) + "}"
        output += f", key_types={key_types_str}"

    if DBG_DICT_DEFAULTS["val_types"] and len_d > 0:
        val_types: typing.Set[str] = set(type(v).__name__ for v in d.values())
        val_types_str: str = "{" + ", ".join(sorted(val_types)) + "}"
        output += f", val_types={val_types_str}"

    output += ">"

    # keys/values if not too deep and not too many
    if depth < DBG_DICT_DEFAULTS["max_depth"]:  # type: ignore[operator]
        if len_d > 0 and len_d < DBG_DICT_DEFAULTS["max_len"]:  # type: ignore[operator]
            for k, v in d.items():
                key_str: str = repr(k) if not isinstance(k, str) else k

                val_str: str
                val_type_str: str = str(type(v))
                if isinstance(v, dict):
                    val_str = dict_info(v, depth + 1)
                elif val_type_str in TENSOR_STR_TYPES:
                    val_str = tensor_info(v)
                elif isinstance(v, list):
                    val_str = list_info(v)
                else:
                    val_str = repr(v)

                output += (
                    f"\n{indent * (depth + 1)}{key_str}{DBG_TENSOR_VAL_JOINER}{val_str}"
                )

    return output


def info_auto(
    obj: typing.Any,
) -> str:
    """Automatically format an object for debugging."""
    if isinstance(obj, dict):
        return dict_info(obj)
    elif isinstance(obj, list):
        return list_info(obj)
    elif str(type(obj)) in TENSOR_STR_TYPES:
        return tensor_info(obj)
    else:
        return repr(obj)


def dbg_tensor(
    tensor: _ExpType,  # numpy array or torch tensor
) -> _ExpType:
    """dbg function for tensors, using tensor_info formatter."""
    return dbg(
        tensor,
        formatter=tensor_info,
        val_joiner=DBG_TENSOR_VAL_JOINER,
    )


def dbg_dict(
    d: _ExpType_dict,
) -> _ExpType_dict:
    """dbg function for dictionaries, using dict_info formatter."""
    return dbg(
        d,
        formatter=dict_info,
        val_joiner=DBG_TENSOR_VAL_JOINER,
    )


def dbg_auto(
    obj: _ExpType,
) -> _ExpType:
    """dbg function for automatic formatting based on type."""
    return dbg(
        obj,
        formatter=info_auto,
        val_joiner=DBG_TENSOR_VAL_JOINER,
    )
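
# Usage sketch for the formatted helpers (hypothetical values; `dbg_tensor`
# expects a numpy array or torch tensor):
#
#     dbg_dict({"lr": 1e-3, "layers": [64, 64]})
#     dbg_auto([1, 2, 3])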


def _normalize_for_loose(text: str) -> str:
    """Normalize text for loose matching by replacing non-alphanumeric chars with spaces."""
    normalized: str = re.sub(r"[^a-zA-Z0-9]+", " ", text)
    return " ".join(normalized.split())


def _compile_pattern(
    pattern: str | re.Pattern[str],
    *,
    cased: bool = False,
    loose: bool = False,
) -> re.Pattern[str]:
    """Compile pattern with appropriate flags for case sensitivity and loose matching."""
    if isinstance(pattern, re.Pattern):
        return pattern

    # Start with no flags for case-insensitive default
    flags: int = 0
    if not cased:
        flags |= re.IGNORECASE

    if loose:
        pattern = _normalize_for_loose(pattern)

    return re.compile(pattern, flags)


def grep_repr(
    obj: typing.Any,
    pattern: str | re.Pattern[str],
    *,
    char_context: int | None = 20,
    line_context: int | None = None,
    before_context: int = 0,
    after_context: int = 0,
    context: int | None = None,
    max_count: int | None = None,
    cased: bool = False,
    loose: bool = False,
    line_numbers: bool = False,
    highlight: bool = True,
    color: str = "31",
    separator: str = "--",
    quiet: bool = False,
) -> typing.List[str] | None:
    """grep-like search on ``repr(obj)`` with improved grep-style options.

    By default, string patterns are case-insensitive. Pre-compiled regex
    patterns use their own flags.

    Parameters:
    - obj: Object to search (its repr() string is scanned)
    - pattern: Regular expression pattern (string or pre-compiled)
    - char_context: Characters of context before/after each match (default: 20)
    - line_context: Lines of context before/after; overrides char_context
    - before_context: Lines of context before match (like grep -B)
    - after_context: Lines of context after match (like grep -A)
    - context: Lines of context before AND after (like grep -C)
    - max_count: Stop after this many matches
    - cased: Force case-sensitive search for string patterns
    - loose: Normalize spaces/punctuation for flexible matching
    - line_numbers: Show line numbers in output
    - highlight: Wrap matches with ANSI color codes
    - color: ANSI color code (default: "31" for red)
    - separator: Separator between multiple matches
    - quiet: Return results instead of printing

    Returns:
    - None if quiet=False (prints to stdout)
    - List[str] if quiet=True (returns formatted output lines)
    """
    # Handle context parameter shortcuts
    if context is not None:
        before_context = after_context = context

    # Prepare text and pattern
    text: str = repr(obj)
    if loose:
        text = _normalize_for_loose(text)

    regex: re.Pattern[str] = _compile_pattern(pattern, cased=cased, loose=loose)

    def _color_match(segment: str) -> str:
        if not highlight:
            return segment
        return regex.sub(lambda m: f"\033[1;{color}m{m.group(0)}\033[0m", segment)

    output_lines: list[str] = []
    match_count: int = 0

    # Determine if we're using line-based context
    using_line_context = (
        line_context is not None or before_context > 0 or after_context > 0
    )

    if using_line_context:
        lines: list[str] = text.splitlines()
        line_starts: list[int] = []
        pos: int = 0
        for line in lines:
            line_starts.append(pos)
            pos += len(line) + 1  # +1 for newline

        processed_lines: set[int] = set()

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            # Find which line contains this match
            match_line = max(
                i for i, start in enumerate(line_starts) if start <= match.start()
            )

            # Calculate context range
            ctx_before: int
            ctx_after: int
            if line_context is not None:
                ctx_before = ctx_after = line_context
            else:
                ctx_before, ctx_after = before_context, after_context

            start_line: int = max(0, match_line - ctx_before)
            end_line: int = min(len(lines), match_line + ctx_after + 1)

            # Avoid duplicate output for overlapping contexts
            line_range: set[int] = set(range(start_line, end_line))
            if line_range & processed_lines:
                continue
            processed_lines.update(line_range)

            # Format the context block
            context_lines: list[str] = []
            for i in range(start_line, end_line):
                line_text = lines[i]
                if line_numbers:
                    line_prefix = f"{i + 1}:"
                    line_text = f"{line_prefix}{line_text}"
                context_lines.append(_color_match(line_text))

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.extend(context_lines)
            match_count += 1

    else:
        # Character-based context
        ctx: int = 0 if char_context is None else char_context

        for match in regex.finditer(text):
            if max_count is not None and match_count >= max_count:
                break

            start: int = max(0, match.start() - ctx)
            end: int = min(len(text), match.end() + ctx)
            snippet: str = text[start:end]

            if output_lines and separator:
                output_lines.append(separator)
            output_lines.append(_color_match(snippet))
            match_count += 1

    if quiet:
        return output_lines
    else:
        for line in output_lines:
            print(line)
        return None
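
# a minimal illustrative call pattern for the grep-style helper defined above;
# `grep_repr` is a stand-in name for however this function is actually exported:
#
#     grep_repr(some_object, r"weight", line_context=1, highlight=True)
#     matches = grep_repr(some_object, r"bias", quiet=True)  # list[str] of formatted hits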

``````{ end_of_file="muutils/dbg.py" }

``````{ path="muutils/dictmagic.py"  }
"""making working with dictionaries easier

- `DefaulterDict`: like a defaultdict, but default_factory is passed the key as an argument
- various methods for working with dotlist-nested dicts, converting to and from them
- `condense_nested_dicts`: condense a nested dict, by condensing numeric or matching keys with matching values to ranges
- `condense_tensor_dict`: convert a dictionary of tensors to a dictionary of shapes
- `kwargs_to_nested_dict`: given kwargs from fire, convert them to a nested dict
"""

from __future__ import annotations

import typing
import warnings
from collections import defaultdict
from typing import (
    Any,
    Callable,
    Generic,
    Hashable,
    Iterable,
    Literal,
    Optional,
    TypeVar,
    Union,
)

from muutils.errormode import ErrorMode

_KT = TypeVar("_KT")
_VT = TypeVar("_VT")


class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]):
    """like a defaultdict, but default_factory is passed the key as an argument"""

    def __init__(self, default_factory: Callable[[_KT], _VT], *args, **kwargs):
        if args:
            raise TypeError(
                f"DefaulterDict does not support positional arguments: *args = {args}"
            )
        super().__init__(**kwargs)
        self.default_factory: Callable[[_KT], _VT] = default_factory

    def __getitem__(self, k: _KT) -> _VT:
        if k in self:
            return dict.__getitem__(self, k)
        else:
            v: _VT = self.default_factory(k)
            dict.__setitem__(self, k, v)
            return v
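
# a minimal usage sketch for `DefaulterDict` (illustrative; names are hypothetical):
#
#     lengths: DefaulterDict[str, int] = DefaulterDict(lambda k: len(k))
#     lengths["hello"]  # -> 5, and "hello" is now stored as a key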


def _recursive_defaultdict_ctor() -> defaultdict:
    return defaultdict(_recursive_defaultdict_ctor)


def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict:
    """Convert a defaultdict or DefaulterDict to a normal dict, recursively"""
    return {
        key: (
            defaultdict_to_dict_recursive(value)
            if isinstance(value, (defaultdict, DefaulterDict))
            else value
        )
        for key, value in dd.items()
    }


def dotlist_to_nested_dict(
    dot_dict: typing.Dict[str, Any], sep: str = "."
) -> typing.Dict[str, Any]:
    """Convert a dict with dot-separated keys to a nested dict

    Example:

        >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
        {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    """
    nested_dict: defaultdict = _recursive_defaultdict_ctor()
    for key, value in dot_dict.items():
        if not isinstance(key, str):
            raise TypeError(f"key must be a string, got {type(key)}")
        keys: list[str] = key.split(sep)
        current: defaultdict = nested_dict
        # iterate over the keys except the last one
        for sub_key in keys[:-1]:
            current = current[sub_key]
        current[keys[-1]] = value
    return defaultdict_to_dict_recursive(nested_dict)


def nested_dict_to_dotlist(
    nested_dict: typing.Dict[str, Any],
    sep: str = ".",
    allow_lists: bool = False,
) -> dict[str, Any]:
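    """Convert a nested dict to a flat dict with dot-separated keys (inverse of `dotlist_to_nested_dict`)

    if `allow_lists` is `True`, list indices are treated as keys (e.g. `a.0`, `a.1`)

    Example:

        >>> nested_dict_to_dotlist({'a': {'b': {'c': 1, 'd': 2}, 'e': 3}})
        {'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}
    """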
    def _recurse(current: Any, parent_key: str = "") -> typing.Dict[str, Any]:
        items: dict = dict()

        new_key: str
        if isinstance(current, dict):
            # dict case
            if not current and parent_key:
                items[parent_key] = current
            else:
                for k, v in current.items():
                    new_key = f"{parent_key}{sep}{k}" if parent_key else k
                    items.update(_recurse(v, new_key))

        elif allow_lists and isinstance(current, list):
            # list case
            for i, item in enumerate(current):
                new_key = f"{parent_key}{sep}{i}" if parent_key else str(i)
                items.update(_recurse(item, new_key))

        else:
            # anything else (write value)
            items[parent_key] = current

        return items

    return _recurse(nested_dict)


def update_with_nested_dict(
    original: dict[str, Any],
    update: dict[str, Any],
) -> dict[str, Any]:
    """Update a dict with a nested dict

    Example:
    >>> update_with_nested_dict({'a': {'b': 1}, "c": -1}, {'a': {"b": 2}})
    {'a': {'b': 2}, 'c': -1}

    # Arguments
    - `original: dict[str, Any]`
        the dict to update (will be modified in-place)
    - `update: dict[str, Any]`
        the dict to update with

    # Returns
    - `dict`
        the updated dict
    """
    for key, value in update.items():
        if key in original:
            if isinstance(original[key], dict) and isinstance(value, dict):
                update_with_nested_dict(original[key], value)
            else:
                original[key] = value
        else:
            original[key] = value

    return original


def kwargs_to_nested_dict(
    kwargs_dict: dict[str, Any],
    sep: str = ".",
    strip_prefix: Optional[str] = None,
    when_unknown_prefix: Union[ErrorMode, str] = ErrorMode.WARN,
    transform_key: Optional[Callable[[str], str]] = None,
) -> dict[str, Any]:
    """given kwargs from fire, convert them to a nested dict

    if strip_prefix is not None, then all keys must start with the prefix. by default,
    will warn if an unknown prefix is found, but can be set to raise an error or ignore it:
    `when_unknown_prefix: ErrorMode`

    Example:
    ```python
    def main(**kwargs):
        print(kwargs_to_nested_dict(kwargs))
    fire.Fire(main)
    ```
    running the above script will give:
    ```bash
    $ python test.py --a.b.c=1 --a.b.d=2 --a.e=3
    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    ```

    # Arguments
    - `kwargs_dict: dict[str, Any]`
        the kwargs dict to convert
    - `sep: str = "."`
        the separator to use for nested keys
    - `strip_prefix: Optional[str] = None`
        if not None, then all keys must start with this prefix
    - `when_unknown_prefix: ErrorMode = ErrorMode.WARN`
        what to do when an unknown prefix is found
    - `transform_key: Callable[[str], str] | None = None`
        a function to apply to each key before adding it to the dict (applied after stripping the prefix)
    """
    when_unknown_prefix_ = ErrorMode.from_any(when_unknown_prefix)
    filtered_kwargs: dict[str, Any] = dict()
    for key, value in kwargs_dict.items():
        if strip_prefix is not None:
            if not key.startswith(strip_prefix):
                when_unknown_prefix_.process(
                    f"key '{key}' does not start with '{strip_prefix}'",
                    except_cls=ValueError,
                )
            else:
                key = key[len(strip_prefix) :]

        if transform_key is not None:
            key = transform_key(key)

        filtered_kwargs[key] = value

    return dotlist_to_nested_dict(filtered_kwargs, sep=sep)


def is_numeric_consecutive(lst: list[str]) -> bool:
    """Check if the list of keys is numeric and consecutive."""
    try:
        numbers: list[int] = [int(x) for x in lst]
        return sorted(numbers) == list(range(min(numbers), max(numbers) + 1))
    except ValueError:
        return False


def condense_nested_dicts_numeric_keys(
    data: dict[str, Any],
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric keys with matching values to ranges

    # Examples:
    ```python
    >>> condense_nested_dicts_numeric_keys({'1': 1, '2': 1, '3': 1, '4': 2, '5': 2, '6': 2})
    {'[1-3]': 1, '[4-6]': 2}
    >>> condense_nested_dicts_numeric_keys({'1': {'1': 'a', '2': 'a'}, '2': 'b'})
    {"1": {"[1-2]": "a"}, "2": "b"}
    ```
    """

    if not isinstance(data, dict):
        return data

    # Process each sub-dictionary
    for key, value in list(data.items()):
        data[key] = condense_nested_dicts_numeric_keys(value)

    # Find all numeric, consecutive keys
    if is_numeric_consecutive(list(data.keys())):
        keys: list[str] = sorted(data.keys(), key=lambda x: int(x))
    else:
        return data

    # output dict
    condensed_data: dict[str, Any] = {}

    # Identify ranges of identical values and condense
    i: int = 0
    while i < len(keys):
        j: int = i
        while j + 1 < len(keys) and data[keys[j]] == data[keys[j + 1]]:
            j += 1
        if j > i:  # Found consecutive keys with identical values
            condensed_key: str = f"[{keys[i]}-{keys[j]}]"
            condensed_data[condensed_key] = data[keys[i]]
            i = j + 1
        else:
            condensed_data[keys[i]] = data[keys[i]]
            i += 1

    return condensed_data


def condense_nested_dicts_matching_values(
    data: dict[str, Any],
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing keys with matching values

    # Examples:
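    ```python
    >>> condense_nested_dicts_matching_values({'a': 1, 'b': 1, 'c': 2})
    {'[a, b]': 1, 'c': 2}
    ```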

    # Parameters:
     - `data : dict[str, Any]`
        data to process
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable)
        (defaults to `None`)

    """

    if isinstance(data, dict):
        data = {
            key: condense_nested_dicts_matching_values(
                value, val_condense_fallback_mapping
            )
            for key, value in data.items()
        }
    else:
        return data

    # Find all identical values and condense by stitching together keys
    values_grouped: defaultdict[Any, list[str]] = defaultdict(list)
    data_persist: dict[str, Any] = dict()
    for key, value in data.items():
        if not isinstance(value, dict):
            try:
                values_grouped[value].append(key)
            except TypeError:
                # If the value is unhashable, use a fallback mapping to find a hashable representation
                if val_condense_fallback_mapping is not None:
                    values_grouped[val_condense_fallback_mapping(value)].append(key)
                else:
                    data_persist[key] = value
        else:
            data_persist[key] = value

    condensed_data = data_persist
    for value, keys in values_grouped.items():
        if len(keys) > 1:
            merged_key = f"[{', '.join(keys)}]"  # Choose an appropriate method to represent merged keys
            condensed_data[merged_key] = value
        else:
            condensed_data[keys[0]] = value

    return condensed_data


def condense_nested_dicts(
    data: dict[str, Any],
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
) -> dict[str, Any]:
    """condense a nested dict, by condensing numeric or matching keys with matching values to ranges

    combines the functionality of `condense_nested_dicts_numeric_keys()` and `condense_nested_dicts_matching_values()`

    # NOTE: this process is not meant to be reversible, and is intended for pretty-printing and visualization purposes
    it's not reversible because types are lost to make the printing pretty

    # Parameters:
     - `data : dict[str, Any]`
        data to process
     - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]")
       (defaults to `True`)
     - `condense_matching_values : bool`
        whether to condense keys with matching values
       (defaults to `True`)
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable)
       (defaults to `None`)

    """

    condensed_data: dict = data
    if condense_numeric_keys:
        condensed_data = condense_nested_dicts_numeric_keys(condensed_data)
    if condense_matching_values:
        condensed_data = condense_nested_dicts_matching_values(
            condensed_data, val_condense_fallback_mapping
        )
    return condensed_data


def tuple_dims_replace(
    t: tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None
) -> tuple[Union[int, str], ...]:
    if dims_names_map is None:
        return t
    else:
        return tuple(dims_names_map.get(x, x) for x in t)


TensorDict = typing.Dict[str, "torch.Tensor|np.ndarray"]  # type: ignore[name-defined] # noqa: F821
TensorIterable = Iterable[typing.Tuple[str, "torch.Tensor|np.ndarray"]]  # type: ignore[name-defined] # noqa: F821
TensorDictFormats = Literal["dict", "json", "yaml", "yml"]


def _default_shapes_convert(x: tuple) -> str:
    return str(x).replace('"', "").replace("'", "")


def condense_tensor_dict(
    data: TensorDict | TensorIterable,
    fmt: TensorDictFormats = "dict",
    *args,
    shapes_convert: Callable[[tuple], Any] = _default_shapes_convert,
    drop_batch_dims: int = 0,
    sep: str = ".",
    dims_names_map: Optional[dict[int, str]] = None,
    condense_numeric_keys: bool = True,
    condense_matching_values: bool = True,
    val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None,
    return_format: Optional[TensorDictFormats] = None,
) -> Union[str, dict[str, str | tuple[int, ...]]]:
    """Convert a dictionary of tensors to a dictionary of shapes.

    by default, values are converted to strings of their shapes (for nice printing).
    If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`.

    # Parameters:
     - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]`
        either a `TensorDict` mapping strings to tensors, or a `TensorIterable` of `(key, tensor)` pairs (like you might get from `dict().items()`)
     - `fmt : TensorDictFormats`
        format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. raises a `ValueError` if yaml output is requested but PyYAML is not installed.
        (defaults to `'dict'`)
     - `shapes_convert : Callable[[tuple], Any]`
        conversion of a shape tuple to a string or other format (defaults to turning it into a string and removing quotes)
        (defaults to `lambda x: str(x).replace('"', '').replace("'", '')`)
     - `drop_batch_dims : int`
        number of leading dimensions to drop from the shape
        (defaults to `0`)
     - `sep : str`
        separator to use for nested keys
        (defaults to `'.'`)
     - `dims_names_map : dict[int, str] | None`
        convert certain dimension values in shape. not perfect, can be buggy
        (defaults to `None`)
     - `condense_numeric_keys : bool`
        whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]"), passed on to `condense_nested_dicts`
        (defaults to `True`)
     - `condense_matching_values : bool`
        whether to condense keys with matching values, passed on to `condense_nested_dicts`
        (defaults to `True`)
     - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None`
        a function to apply to each value before adding it to the dict (if it's not hashable), passed on to `condense_nested_dicts`
        (defaults to `None`)
     - `return_format : TensorDictFormats | None`
        legacy alias for `fmt` kwarg

    # Returns:
     - `str|dict[str, str|tuple[int, ...]]`
        dict if `fmt='dict'`, a string for `json` or `yaml` output

    # Examples:
    ```python
    >>> model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
    >>> print(condense_tensor_dict(model.named_parameters(), fmt='yaml'))
    ```
    ```yaml
    embed:
      W_E: (50257, 768)
    pos_embed:
      W_pos: (1024, 768)
    blocks:
      '[0-11]':
        attn:
          '[W_Q, W_K, W_V]': (12, 768, 64)
          W_O: (12, 64, 768)
          '[b_Q, b_K, b_V]': (12, 64)
          b_O: (768,)
        mlp:
          W_in: (768, 3072)
          b_in: (3072,)
          W_out: (3072, 768)
          b_out: (768,)
    unembed:
      W_U: (768, 50257)
      b_U: (50257,)
    ```

    # Raises:
     - `ValueError` :  if `return_format` is not one of 'dict', 'json', or 'yaml', or if you try to use 'yaml' output without having PyYAML installed
    """

    # handle arg processing:
    # ----------------------------------------------------------------------
    # make all args except data and format keyword-only
    assert len(args) == 0, f"unexpected positional args: {args}"
    # handle legacy return_format
    if return_format is not None:
        warnings.warn(
            "return_format is deprecated, use fmt instead",
            DeprecationWarning,
        )
        fmt = return_format

    # identity function for shapes_convert if not provided
    if shapes_convert is None:
        shapes_convert = lambda x: x  # noqa: E731

    # convert to iterable
    data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = (  # type: ignore # noqa: F821
        data.items() if hasattr(data, "items") and callable(data.items) else data  # type: ignore
    )

    # get shapes
    data_shapes: dict[str, Union[str, tuple[int, ...]]] = {
        k: shapes_convert(
            tuple_dims_replace(
                tuple(v.shape)[drop_batch_dims:],
                dims_names_map,
            )
        )
        for k, v in data_items
    }

    # nest the dict
    data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep)

    # condense the nested dict
    data_condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts(
        data=data_nested,
        condense_numeric_keys=condense_numeric_keys,
        condense_matching_values=condense_matching_values,
        val_condense_fallback_mapping=val_condense_fallback_mapping,
    )

    # return in the specified format
    fmt_lower: str = fmt.lower()
    if fmt_lower == "dict":
        return data_condensed
    elif fmt_lower == "json":
        import json

        return json.dumps(data_condensed, indent=2)
    elif fmt_lower in ["yaml", "yml"]:
        try:
            import yaml  # type: ignore[import-untyped]

            return yaml.dump(data_condensed, sort_keys=False)
        except ImportError as e:
            raise ValueError("PyYAML is required for YAML output") from e
    else:
        raise ValueError(f"Invalid return format: {fmt}")

``````{ end_of_file="muutils/dictmagic.py" }

``````{ path="muutils/errormode.py"  }
"""provides `ErrorMode` enum for handling errors consistently

pass an `error_mode: ErrorMode` to a function to specify how to handle a certain kind of exception.
That function then instead of `raise`ing or `warnings.warn`ing, calls `error_mode.process` with the message and the exception.

you can also specify the exception class to raise, the warning class to use, and the source of the exception/warning.

"""

from __future__ import annotations

import sys
import typing
import types
import warnings
from enum import Enum


class WarningFunc(typing.Protocol):
    def __call__(
        self,
        msg: str,
        category: typing.Type[Warning],
        source: typing.Any = None,
    ) -> None: ...


LoggingFunc = typing.Callable[[str], None]

GLOBAL_WARN_FUNC: WarningFunc = warnings.warn  # type: ignore[assignment]
GLOBAL_LOG_FUNC: LoggingFunc = print


def custom_showwarning(
    message: Warning | str,
    category: typing.Type[Warning] | None = None,
    filename: str | None = None,
    lineno: int | None = None,
    file: typing.Optional[typing.TextIO] = None,
    line: typing.Optional[str] = None,
) -> None:
    if category is None:
        category = UserWarning
    # Get the frame where process() was called
    # Adjusted to account for the extra function call
    frame: types.FrameType = sys._getframe(2)
    # get globals and traceback
    traceback: types.TracebackType = types.TracebackType(
        None, frame, frame.f_lasti, frame.f_lineno
    )
    _globals: dict[str, typing.Any] = frame.f_globals
    # init the new warning and add the traceback
    if isinstance(message, str):
        message = category(message)
    message = message.with_traceback(traceback)

    # Call the original showwarning function
    warnings.warn_explicit(
        message=message,
        category=category,
        # always use the calling frame's filename and line number
        filename=frame.f_code.co_filename,
        lineno=frame.f_lineno,
        module=frame.f_globals.get("__name__", "__main__"),
        registry=_globals.setdefault("__warningregistry__", {}),
        module_globals=_globals,
    )
    # warnings._showwarning_orig(
    #     message,
    #     category,
    #     frame.f_code.co_filename,
    #     frame.f_lineno,
    #     file,
    #     line,
    # )


class ErrorMode(Enum):
    """Enum for handling errors consistently

    pass one of the instances of this enum to a function to specify how to handle a certain kind of exception.

    That function then instead of `raise`ing or `warnings.warn`ing, calls `error_mode.process` with the message and the exception.
    """

    EXCEPT = "except"
    WARN = "warn"
    LOG = "log"
    IGNORE = "ignore"

    def process(
        self,
        msg: str,
        except_cls: typing.Type[Exception] = ValueError,
        warn_cls: typing.Type[Warning] = UserWarning,
        except_from: typing.Optional[Exception] = None,
        warn_func: WarningFunc | None = None,
        log_func: LoggingFunc | None = None,
    ):
        """process an exception or warning according to the error mode

        # Parameters:
         - `msg : str`
           message to pass to `except_cls` or `warn_func`
         - `except_cls : typing.Type[Exception]`
            exception class to raise, must be a subclass of `Exception`
           (defaults to `ValueError`)
         - `warn_cls : typing.Type[Warning]`
            warning class to use, must be a subclass of `Warning`
           (defaults to `UserWarning`)
         - `except_from : typing.Optional[Exception]`
            will `raise except_cls(msg) from except_from` if not `None`
           (defaults to `None`)
         - `warn_func : WarningFunc | None`
            function to use for warnings, must have the signature `warn_func(msg: str, category: typing.Type[Warning], source: typing.Any = None) -> None`
           (defaults to `None`)
         - `log_func : LoggingFunc | None`
            function to use for logging, must have the signature `log_func(msg: str) -> None`
           (defaults to `None`)

        # Raises:
         - `except_cls` : if the error mode is `ErrorMode.EXCEPT`, raised with `msg` (chained from `except_from` if it is not `None`)
         - `ValueError` : if the error mode is not a recognized `ErrorMode`
        """
        if self is ErrorMode.EXCEPT:
            # except, possibly with a chained exception
            frame: types.FrameType = sys._getframe(1)
            traceback: types.TracebackType = types.TracebackType(
                None, frame, frame.f_lasti, frame.f_lineno
            )

            # Attach the new traceback to the exception and raise it without the internal call stack
            if except_from is not None:
                raise except_cls(msg).with_traceback(traceback) from except_from
            else:
                raise except_cls(msg).with_traceback(traceback)
        elif self is ErrorMode.WARN:
            # get global warn function if not passed
            if warn_func is None:
                warn_func = GLOBAL_WARN_FUNC
            # augment warning message with source
            if except_from is not None:
                msg = f"{msg}\n\tSource of warning: {except_from}"
            if warn_func == warnings.warn:
                custom_showwarning(msg, category=warn_cls)
            else:
                # Use the provided warn_func as-is
                warn_func(msg, category=warn_cls)
        elif self is ErrorMode.LOG:
            # get global log function if not passed
            if log_func is None:
                log_func = GLOBAL_LOG_FUNC
            # log
            log_func(msg)
        elif self is ErrorMode.IGNORE:
            # do nothing
            pass
        else:
            raise ValueError(f"Unknown error mode {self}")

    @classmethod
    def from_any(
        cls,
        mode: "str|ErrorMode",
        allow_aliases: bool = True,
        allow_prefix: bool = True,
    ) -> ErrorMode:
        """initialize an `ErrorMode` from a string or an `ErrorMode` instance"""
        if isinstance(mode, ErrorMode):
            return mode
        elif isinstance(mode, str):
            # strip
            mode = mode.strip()

            # remove prefix
            if allow_prefix and mode.startswith("ErrorMode."):
                mode = mode[len("ErrorMode.") :]

            # lowercase and strip again
            mode = mode.strip().lower()

            if not allow_aliases:
                # try without aliases
                try:
                    return ErrorMode(mode)
                except ValueError as e:
                    raise KeyError(f"Unknown error mode {mode = }") from e
            else:
                # look up in aliases map
                return ERROR_MODE_ALIASES[mode]
        else:
            raise TypeError(
                f"Expected {ErrorMode = } or str, got {type(mode) = } {mode = }"
            )

    def __str__(self) -> str:
        return f"ErrorMode.{self.value.capitalize()}"

    def __repr__(self) -> str:
        return str(self)

    def serialize(self) -> str:
        return str(self)

    @classmethod
    def load(cls, data: str) -> ErrorMode:
        return cls.from_any(
            data,
            allow_aliases=False,
            allow_prefix=True,
        )


ERROR_MODE_ALIASES: dict[str, ErrorMode] = {
    # base
    "except": ErrorMode.EXCEPT,
    "warn": ErrorMode.WARN,
    "log": ErrorMode.LOG,
    "ignore": ErrorMode.IGNORE,
    # except
    "e": ErrorMode.EXCEPT,
    "error": ErrorMode.EXCEPT,
    "err": ErrorMode.EXCEPT,
    "raise": ErrorMode.EXCEPT,
    # warn
    "w": ErrorMode.WARN,
    "warning": ErrorMode.WARN,
    # log
    "l": ErrorMode.LOG,
    "print": ErrorMode.LOG,
    "output": ErrorMode.LOG,
    "show": ErrorMode.LOG,
    "display": ErrorMode.LOG,
    # ignore
    "i": ErrorMode.IGNORE,
    "silent": ErrorMode.IGNORE,
    "quiet": ErrorMode.IGNORE,
    "nothing": ErrorMode.IGNORE,
}
"map of string aliases to `ErrorMode` instances"

``````{ end_of_file="muutils/errormode.py" }

``````{ path="muutils/group_equiv.py"  }
"group items by assuming that `eq_func` defines an equivalence relation"

from __future__ import annotations

from itertools import chain
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")


def group_by_equivalence(
    items_in: Sequence[T],
    eq_func: Callable[[T, T], bool],
) -> list[list[T]]:
    """group items by assuming that `eq_func` implies an equivalence relation but might not be transitive

    so, if f(a,b) and f(b,c) then f(a,c) might be false, but we still want to put [a,b,c] in the same class

    note that lists are used to avoid the need for hashable items, and to allow for duplicates

    # Arguments
     - `items_in: Sequence[T]` the items to group
     - `eq_func: Callable[[T, T], bool]` a function that returns true if two items are equivalent. need not be transitive
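
    # Example (numbers within 1 of a chained neighbor end up grouped together):
    >>> group_by_equivalence([1, 2, 3, 10], lambda a, b: abs(a - b) <= 1)
    [[1, 2, 3], [10]]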
    """

    items: list[T] = list(items_in)
    items.reverse()
    output: list[list[T]] = list()

    while items:
        x: T = items.pop()

        # try to add to an existing class
        found_classes: list[int] = list()
        for i, c in enumerate(output):
            if any(eq_func(x, y) for y in c):
                found_classes.append(i)

        # if one class found, add to it
        if len(found_classes) == 1:
            output[found_classes.pop()].append(x)

        elif len(found_classes) > 1:
            # if multiple classes found, merge the classes

            # first sort the ones to be merged
            output_new: list[list[T]] = list()
            to_merge: list[list[T]] = list()
            for i, c in enumerate(output):
                if i in found_classes:
                    to_merge.append(c)
                else:
                    output_new.append(c)

            # then merge them back in, along with the element `x`
            merged: list[T] = list(chain.from_iterable(to_merge))
            merged.append(x)

            output_new.append(merged)
            output = output_new

        # if no class found, make a new one
        else:
            output.append([x])

    return output

``````{ end_of_file="muutils/group_equiv.py" }

``````{ path="muutils/interval.py"  }
"represents a mathematical `Interval` over the real numbers"

from __future__ import annotations

import math
import typing
from typing import Optional, Iterable, Sequence, Union, Any

from muutils.misc import str_to_numeric

_EPSILON: float = 1e-10

Number = Union[float, int]
# TODO: make this also work with decimals, fractions, numpy types, etc.
# except we must somehow avoid importing them? idk

_EMPTY_INTERVAL_ARGS: tuple[Number, Number, bool, bool, set[Number]] = (
    math.nan,
    math.nan,
    False,
    False,
    set(),
)


class Interval:
    """
    Represents a mathematical interval, open by default.

    The Interval class can represent both open and closed intervals, as well as half-open intervals.
    It supports various initialization methods and provides containment checks.

    Examples:

        >>> i1 = Interval(1, 5)  # Default open interval (1, 5)
        >>> 3 in i1
        True
        >>> 1 in i1
        False
        >>> i2 = Interval([1, 5])  # Closed interval [1, 5]
        >>> 1 in i2
        True
        >>> i3 = Interval(1, 5, closed_L=True)  # Half-open interval [1, 5)
        >>> str(i3)
        '[1, 5)'
        >>> i4 = ClosedInterval(1, 5)  # Closed interval [1, 5]
        >>> i5 = OpenInterval(1, 5)  # Open interval (1, 5)

    """

    def __init__(
        self,
        *args: Union[Sequence[Number], Number],
        is_closed: Optional[bool] = None,
        closed_L: Optional[bool] = None,
        closed_R: Optional[bool] = None,
    ):
        self.lower: Number
        self.upper: Number
        self.closed_L: bool
        self.closed_R: bool
        self.singleton_set: Optional[set[Number]] = None
        try:
            if len(args) == 0:
                (
                    self.lower,
                    self.upper,
                    self.closed_L,
                    self.closed_R,
                    self.singleton_set,
                ) = _EMPTY_INTERVAL_ARGS
                return
            # Handle different types of input arguments
            if len(args) == 1 and isinstance(
                args[0], (list, tuple, Sequence, Iterable)
            ):
                assert len(args[0]) == 2, (
                    "if arg is a list or tuple, it must have length 2"
                )
                self.lower = args[0][0]
                self.upper = args[0][1]
                # Determine closure type based on the container type
                default_closed = isinstance(args[0], list)
            elif len(args) == 1 and isinstance(
                args[0], (int, float, typing.SupportsFloat, typing.SupportsInt)
            ):
                # a singleton, but this will be handled later
                self.lower = args[0]
                self.upper = args[0]
                default_closed = False
            elif len(args) == 2:
                self.lower, self.upper = args  # type: ignore[assignment]
                default_closed = False  # Default to open interval if two args
            else:
                raise ValueError(f"Invalid input arguments: {args}")

            # if both of the bounds are NaN or None, return an empty interval
            if any(x is None for x in (self.lower, self.upper)) or any(
                math.isnan(x) for x in (self.lower, self.upper)
            ):
                if (self.lower is None and self.upper is None) or (
                    self.lower is not None
                    and self.upper is not None
                    and math.isnan(self.lower)
                    and math.isnan(self.upper)
                ):
                    (
                        self.lower,
                        self.upper,
                        self.closed_L,
                        self.closed_R,
                        self.singleton_set,
                    ) = _EMPTY_INTERVAL_ARGS
                    return
                else:
                    raise ValueError(
                        "Both bounds must be NaN or None to create an empty interval. Also, just use `Interval.get_empty()` instead."
                    )

            # Ensure lower bound is not greater than upper bound
            if self.lower > self.upper:
                raise ValueError("Lower bound must be less than or equal to upper bound")

            if math.isnan(self.lower) or math.isnan(self.upper):
                raise ValueError("NaN is not allowed as an interval bound")

            # Determine closure properties
            if is_closed is not None:
                # can't specify both is_closed and closed_L/R
                if (closed_L is not None) or (closed_R is not None):
                    raise ValueError("Cannot specify both is_closed and closed_L/R")
                self.closed_L = is_closed
                self.closed_R = is_closed
            else:
                self.closed_L = closed_L if closed_L is not None else default_closed
                self.closed_R = closed_R if closed_R is not None else default_closed

            # handle singleton/empty case
            if self.lower == self.upper and not (self.closed_L or self.closed_R):
                (
                    self.lower,
                    self.upper,
                    self.closed_L,
                    self.closed_R,
                    self.singleton_set,
                ) = _EMPTY_INTERVAL_ARGS
                return

            elif self.lower == self.upper and (self.closed_L or self.closed_R):
                self.singleton_set = {self.lower}  # Singleton interval
                self.closed_L = True
                self.closed_R = True
                return
            # otherwise `singleton_set` is `None`

        except (AssertionError, ValueError) as e:
            raise ValueError(
                f"Invalid input arguments to Interval: {args = }, {is_closed = }, {closed_L = }, {closed_R = }\n{e}\nUsage:\n{self.__doc__}"
            ) from e

    @property
    def is_closed(self) -> bool:
        if self.is_empty:
            return True
        if self.is_singleton:
            return True
        return self.closed_L and self.closed_R

    @property
    def is_open(self) -> bool:
        if self.is_empty:
            return True
        if self.is_singleton:
            return False
        return not self.closed_L and not self.closed_R

    @property
    def is_half_open(self) -> bool:
        return (self.closed_L and not self.closed_R) or (
            not self.closed_L and self.closed_R
        )

    @property
    def is_singleton(self) -> bool:
        return self.singleton_set is not None and len(self.singleton_set) == 1

    @property
    def is_empty(self) -> bool:
        return self.singleton_set is not None and len(self.singleton_set) == 0

    @property
    def is_finite(self) -> bool:
        return not math.isinf(self.lower) and not math.isinf(self.upper)

    @property
    def singleton(self) -> Number:
        if not self.is_singleton:
            raise ValueError("Interval is not a singleton")
        return next(iter(self.singleton_set))  # type: ignore[arg-type]

    @staticmethod
    def get_empty() -> Interval:
        return Interval(math.nan, math.nan, closed_L=None, closed_R=None)

    @staticmethod
    def get_singleton(value: Number) -> Interval:
        if value is None or math.isnan(value):
            return Interval.get_empty()
        return Interval(value, value, closed_L=True, closed_R=True)

    def numerical_contained(self, item: Number) -> bool:
        if self.is_empty:
            return False
        if math.isnan(item):
            raise ValueError("NaN cannot be checked for containment in an interval")
        if self.is_singleton:
            return item in self.singleton_set  # type: ignore[operator]
        return ((self.closed_L and item >= self.lower) or item > self.lower) and (
            (self.closed_R and item <= self.upper) or item < self.upper
        )

    def interval_contained(self, item: Interval) -> bool:
        if item.is_empty:
            return True
        if self.is_empty:
            return False
        if item.is_singleton:
            return self.numerical_contained(item.singleton)
        if self.is_singleton:
            if not item.is_singleton:
                return False
            return self.singleton == item.singleton

        lower_contained: bool = (
            # either strictly wider bound
            self.lower < item.lower
            # if bounds are equal, self must be closed wherever item is closed
            or (self.lower == item.lower and self.closed_L >= item.closed_L)
        )

        upper_contained: bool = (
            # either strictly wider bound
            self.upper > item.upper
            # if bounds are equal, self must be closed wherever item is closed
            or (self.upper == item.upper and self.closed_R >= item.closed_R)
        )

        return lower_contained and upper_contained

    def __contains__(self, item: Any) -> bool:
        if isinstance(item, Interval):
            return self.interval_contained(item)
        else:
            return self.numerical_contained(item)

    def __repr__(self) -> str:
        if self.is_empty:
            return r"∅"
        if self.is_singleton:
            return "{" + str(self.singleton) + "}"
        left: str = "[" if self.closed_L else "("
        right: str = "]" if self.closed_R else ")"
        return f"{left}{self.lower}, {self.upper}{right}"

    def __str__(self) -> str:
        return repr(self)

    @classmethod
    def from_str(cls, input_str: str) -> Interval:
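        """parse an `Interval` from a string like `"[1, 5)"`, `"(1, 5)"`, `"{3}"`, or `"∅"`

        # Examples (illustrative):
        >>> Interval.from_str("[1, 5)")
        [1, 5)
        >>> Interval.from_str("{3}")
        {3}
        """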
        input_str = input_str.strip()
        # empty and singleton
        if input_str.count(",") == 0:
            # empty set
            if input_str == "∅":
                return cls.get_empty()
            assert input_str.startswith("{") and input_str.endswith("}"), (
                "Invalid input string"
            )
            input_str_set_interior: str = input_str.strip("{}").strip()
            if len(input_str_set_interior) == 0:
                return cls.get_empty()
            # singleton set
            return cls.get_singleton(str_to_numeric(input_str_set_interior))

        # expect commas
        if not input_str.count(",") == 1:
            raise ValueError("Invalid input string")

        # get bounds
        lower: str
        upper: str
        lower, upper = input_str.strip("[]()").split(",")
        lower = lower.strip()
        upper = upper.strip()

        lower_num: Number = str_to_numeric(lower)
        upper_num: Number = str_to_numeric(upper)

        # figure out closure
        closed_L: bool
        closed_R: bool
        if input_str[0] == "[":
            closed_L = True
        elif input_str[0] == "(":
            closed_L = False
        else:
            raise ValueError("Invalid input string")

        if input_str[-1] == "]":
            closed_R = True
        elif input_str[-1] == ")":
            closed_R = False
        else:
            raise ValueError("Invalid input string")

        return cls(lower_num, upper_num, closed_L=closed_L, closed_R=closed_R)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Interval):
            return False
        if self.is_empty and other.is_empty:
            return True
        if self.is_singleton and other.is_singleton:
            return self.singleton == other.singleton
        return (self.lower, self.upper, self.closed_L, self.closed_R) == (
            other.lower,
            other.upper,
            other.closed_L,
            other.closed_R,
        )

    def __iter__(self):
        if self.is_empty:
            return
        elif self.is_singleton:
            yield self.singleton
            return
        else:
            yield self.lower
            yield self.upper

    def __getitem__(self, index: int) -> float:
        if self.is_empty:
            raise IndexError("Empty interval has no bounds")
        if self.is_singleton:
            if index == 0:
                return self.singleton
            else:
                raise IndexError("Singleton interval has only one bound")
        if index == 0:
            return self.lower
        elif index == 1:
            return self.upper
        else:
            raise IndexError("Interval index out of range")

    def __len__(self) -> int:
        return 0 if self.is_empty else 1 if self.is_singleton else 2

    def copy(self) -> Interval:
        if self.is_empty:
            return Interval.get_empty()
        if self.is_singleton:
            return Interval.get_singleton(self.singleton)
        return Interval(
            self.lower, self.upper, closed_L=self.closed_L, closed_R=self.closed_R
        )

    def size(self) -> float:
        """
        Returns the size of the interval.

        # Returns:

         - `float`
            the size of the interval
        """
        if self.is_empty or self.is_singleton:
            return 0
        else:
            return self.upper - self.lower

    def clamp(self, value: Union[int, float], epsilon: float = _EPSILON) -> float:
        """
        Clamp the given value to the interval bounds.

        For open bounds, the clamped value will be slightly inside the interval (by epsilon).

        # Parameters:

         - `value : Union[int, float]`
           the value to clamp.
         - `epsilon : float`
           margin for open bounds
           (defaults to `_EPSILON`)

        # Returns:

         - `float`
            the clamped value

        # Raises:

         - `ValueError` : If the input value is NaN.
        """

        if math.isnan(value):
            raise ValueError("Cannot clamp NaN value")

        if math.isnan(epsilon):
            raise ValueError("Epsilon cannot be NaN")

        if epsilon < 0:
            raise ValueError(f"Epsilon must be non-negative: {epsilon = }")

        if self.is_empty:
            raise ValueError("Cannot clamp to an empty interval")

        if self.is_singleton:
            return self.singleton

        if epsilon > self.size():
            raise ValueError(
                f"epsilon is greater than the size of the interval: {epsilon = }, {self.size() = }, {self = }"
            )

        # make type work with decimals and stuff
        if not isinstance(value, (int, float)):
            epsilon = value.__class__(epsilon)

        clamped_min: Number
        if self.closed_L:
            clamped_min = self.lower
        else:
            clamped_min = self.lower + epsilon

        clamped_max: Number
        if self.closed_R:
            clamped_max = self.upper
        else:
            clamped_max = self.upper - epsilon

        return max(clamped_min, min(value, clamped_max))

    def intersection(self, other: Interval) -> Interval:
        if not isinstance(other, Interval):
            raise TypeError("Can only intersect with another Interval")

        if self.is_empty or other.is_empty:
            return Interval.get_empty()

        if self.is_singleton:
            if other.numerical_contained(self.singleton):
                return self.copy()
            else:
                return Interval.get_empty()

        if other.is_singleton:
            if self.numerical_contained(other.singleton):
                return other.copy()
            else:
                return Interval.get_empty()

        if self.upper < other.lower or other.upper < self.lower:
            return Interval.get_empty()

        lower: Number = max(self.lower, other.lower)
        upper: Number = min(self.upper, other.upper)
        # at a shared endpoint, the intersection is closed only if both intervals are closed there
        closed_L: bool = (
            (self.closed_L and other.closed_L)
            if self.lower == other.lower
            else (self.closed_L if self.lower > other.lower else other.closed_L)
        )
        closed_R: bool = (
            (self.closed_R and other.closed_R)
            if self.upper == other.upper
            else (self.closed_R if self.upper < other.upper else other.closed_R)
        )

        return Interval(lower, upper, closed_L=closed_L, closed_R=closed_R)

    def union(self, other: Interval) -> Interval:
        if not isinstance(other, Interval):
            raise TypeError("Can only union with another Interval")

        # empty set case
        if self.is_empty:
            return other.copy()
        if other.is_empty:
            return self.copy()

        # special case where the intersection is empty but the intervals are contiguous
        if self.upper == other.lower:
            if self.closed_R or other.closed_L:
                return Interval(
                    self.lower,
                    other.upper,
                    closed_L=self.closed_L,
                    closed_R=other.closed_R,
                )
        elif other.upper == self.lower:
            if other.closed_R or self.closed_L:
                return Interval(
                    other.lower,
                    self.upper,
                    closed_L=other.closed_L,
                    closed_R=self.closed_R,
                )

        # non-intersecting nonempty and non-contiguous intervals
        if self.intersection(other) == Interval.get_empty():
            raise NotImplementedError(
                "Union of non-intersecting nonempty non-contiguous intervals is not implemented "
                + f"{self = }, {other = }, {self.intersection(other) = }"
            )

        # singleton case
        if self.is_singleton:
            return other.copy()
        if other.is_singleton:
            return self.copy()

        # regular case
        lower: Number = min(self.lower, other.lower)
        upper: Number = max(self.upper, other.upper)
        # at a shared endpoint, the union is closed if either interval is closed there
        closed_L: bool = (
            (self.closed_L or other.closed_L)
            if self.lower == other.lower
            else (self.closed_L if self.lower < other.lower else other.closed_L)
        )
        closed_R: bool = (
            (self.closed_R or other.closed_R)
            if self.upper == other.upper
            else (self.closed_R if self.upper > other.upper else other.closed_R)
        )

        return Interval(lower, upper, closed_L=closed_L, closed_R=closed_R)


class ClosedInterval(Interval):
    def __init__(self, *args: Union[Sequence[float], float], **kwargs: Any):
        if any(key in kwargs for key in ("is_closed", "closed_L", "closed_R")):
            raise ValueError("Cannot specify closure properties for ClosedInterval")
        super().__init__(*args, is_closed=True)


class OpenInterval(Interval):
    def __init__(self, *args: Union[Sequence[float], float], **kwargs: Any):
        if any(key in kwargs for key in ("is_closed", "closed_L", "closed_R")):
            raise ValueError("Cannot specify closure properties for OpenInterval")
        super().__init__(*args, is_closed=False)

``````{ end_of_file="muutils/interval.py" }

``````{ path="muutils/jsonlines.py"  }
"utilities for reading and writing jsonlines files, including gzip support"

from __future__ import annotations

import gzip
import json
from typing import Callable, Sequence

from muutils.json_serialize import JSONitem

_GZIP_EXTENSIONS: tuple = (".gz", ".gzip")


def _file_is_gzip(path: str) -> bool:
    return any(str(path).endswith(ext) for ext in _GZIP_EXTENSIONS)


def _get_opener(
    path: str,
    use_gzip: bool | None = None,
) -> Callable:
    if use_gzip is None:
        use_gzip = _file_is_gzip(path)

    # appears to be another mypy bug
    # https://github.com/python/mypy/issues/10740
    return open if not use_gzip else gzip.open  # type: ignore


def jsonl_load(
    path: str,
    /,
    *,
    use_gzip: bool | None = None,
) -> list[JSONitem]:
    opener: Callable = _get_opener(path, use_gzip)

    data: list[JSONitem] = list()
    with opener(path, "rt", encoding="UTF-8") as f:
        for line in f:
            data.append(json.loads(line))

    return data


def jsonl_load_log(
    path: str,
    /,
    *,
    use_gzip: bool | None = None,
) -> list[dict]:
    data: list[JSONitem] = jsonl_load(path, use_gzip=use_gzip)
    for idx, item in enumerate(data):
        assert isinstance(item, dict), (
            f"item {idx = } from file {path} is not a dict: {type(item) = }\t{item = }"
        )

    # mypy complains that we are returning a list[JSONitem] but the function signature says list[dict]
    # it can't figure out that we are asserting that all items are dicts
    return data  # type: ignore


def jsonl_write(
    path: str,
    items: Sequence[JSONitem],
    use_gzip: bool | None = None,
    gzip_compresslevel: int = 2,
) -> None:
    opener: Callable = _get_opener(path, use_gzip)

    opener_kwargs: dict = dict()
    if use_gzip:
        opener_kwargs = dict(compresslevel=gzip_compresslevel)

    with opener(path, "wt", encoding="UTF-8", **opener_kwargs) as f:
        for item in items:
            f.write(json.dumps(item) + "\n")
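
# a minimal usage sketch (illustrative; the file name is hypothetical):
#
#     jsonl_write("runs.jsonl.gz", [{"step": 1, "loss": 0.5}, {"step": 2, "loss": 0.3}])
#     items = jsonl_load("runs.jsonl.gz")  # gzip is inferred from the extension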

``````{ end_of_file="muutils/jsonlines.py" }

``````{ path="muutils/kappa.py"  }
"""anonymous getitem class

util for constructing a class which has a getitem method which just calls a function

a `lambda` is an anonymous function: kappa is the letter before lambda in the greek alphabet,
hence the name of this class"""

from __future__ import annotations

from typing import Callable, Mapping, TypeVar

_kappa_K = TypeVar("_kappa_K")
_kappa_V = TypeVar("_kappa_V")

# get the docstring of this file
_BASE_DOC: str = (
    __doc__
    + """

source function docstring:
==============================\n
"""
)


class Kappa(Mapping[_kappa_K, _kappa_V]):
    def __init__(self, func_getitem: Callable[[_kappa_K], _kappa_V]) -> None:
        self.func_getitem = func_getitem
        # `__doc__` is present but `None` for undocumented functions, so fall back explicitly
        self.doc = _BASE_DOC + str(
            getattr(func_getitem, "__doc__", None)
            or "<no docstring provided for source function>"
        )

    def __getitem__(self, x) -> _kappa_V:
        return self.func_getitem(x)

    def __iter__(self):
        raise NotImplementedError(
            "This method is not implemented for Kappa, we don't know the valid inputs"
        )

    def __len__(self):
        raise NotImplementedError(
            "This method is not implemented for Kappa, no idea how many valid inputs there are"
        )
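
# a minimal usage sketch (illustrative):
#
#     double = Kappa(lambda x: x * 2)
#     double[3]  # -> 6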

``````{ end_of_file="muutils/kappa.py" }

``````{ path="muutils/mlutils.py"  }
"miscellaneous utilities for ML pipelines"

from __future__ import annotations

import json
import os
import random
import typing
import warnings
from itertools import islice
from pathlib import Path
from typing import Any, Callable, Optional, TypeVar, Union

ARRAY_IMPORTS: bool
try:
    import numpy as np
    import torch

    ARRAY_IMPORTS = True
except ImportError as e:
    warnings.warn(
        f"Numpy or torch not installed. Array operations will not be available.\n{e}"
    )
    ARRAY_IMPORTS = False

DEFAULT_SEED: int = 42
GLOBAL_SEED: int = DEFAULT_SEED


def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device":
    """Get the torch.device instance on which `torch.Tensor`s should be allocated."""
    if not ARRAY_IMPORTS:
        raise ImportError(
            "Numpy or torch not installed. Array operations will not be available."
        )
    try:
        # if device is given
        if device is not None:
            device = torch.device(device)
            if any(
                [
                    torch.cuda.is_available() and device.type == "cuda",
                    torch.backends.mps.is_available() and device.type == "mps",
                    device.type == "cpu",
                ]
            ):
                # if device is given and available
                pass
            else:
                warnings.warn(
                    f"Specified device {device} is not available, falling back to CPU"
                )
                return torch.device("cpu")

        # no device given, infer from availability
        else:
            if torch.cuda.is_available():
                device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                device = torch.device("mps")
            else:
                device = torch.device("cpu")

        # put a dummy tensor on the device to check if it is available
        _dummy = torch.zeros(1, device=device)

        return device

    except Exception as e:
        warnings.warn(
            f"Error while getting device, falling back to CPU. Error: {e}",
            RuntimeWarning,
        )
        return torch.device("cpu")


def set_reproducibility(seed: int = DEFAULT_SEED):
    """
    Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.

    Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades
    off performance for reproducibility. Call `torch.use_deterministic_algorithms(False)` afterwards if you need that performance back.
    """
    global GLOBAL_SEED

    GLOBAL_SEED = seed

    random.seed(seed)

    if ARRAY_IMPORTS:
        np.random.seed(seed)
        torch.manual_seed(seed)

        torch.use_deterministic_algorithms(True)
        # Ensure reproducibility for concurrent CUDA streams
        # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


def chunks(it, chunk_size):
    """Yield successive chunks from an iterator."""
    # https://stackoverflow.com/a/61435714
    iterator = iter(it)
    while chunk := list(islice(iterator, chunk_size)):
        yield chunk
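
# a minimal usage sketch (illustrative):
#
#     list(chunks(range(5), 2))  # -> [[0, 1], [2, 3], [4]]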


def get_checkpoint_paths_for_run(
    run_path: Path,
    extension: typing.Literal["pt", "zanj"],
    checkpoints_format: str = "checkpoints/model.iter_*.{extension}",
) -> list[tuple[int, Path]]:
    """get checkpoints of the format from the run_path

    note that `checkpoints_format` should contain a glob pattern with:
     - unresolved "{extension}" format term for the extension
     - a wildcard for the iteration number
    """

    assert run_path.is_dir(), (
        f"Model path {run_path} is not a directory (expect run directory, not model files)"
    )

    return [
        (int(checkpoint_path.stem.split("_")[-1].split(".")[0]), checkpoint_path)
        for checkpoint_path in sorted(
            Path(run_path).glob(checkpoints_format.format(extension=extension))
        )
    ]


F = TypeVar("F", bound=Callable[..., Any])


def register_method(
    method_dict: dict[str, Callable[..., Any]],
    custom_name: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator to add a method to the method_dict"""

    def decorator(method: F) -> F:
        method_name: str
        if custom_name is None:
            method_name_orig: str | None = getattr(method, "__name__", None)
            if method_name_orig is None:
                warnings.warn(
                    f"Method {method} does not have a name, using sanitized repr"
                )
                from muutils.misc import sanitize_identifier

                method_name = sanitize_identifier(repr(method))
            else:
                method_name = method_name_orig
        else:
            method_name = custom_name
            method.__name__ = custom_name
        assert method_name not in method_dict, (
            f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }"
        )
        method_dict[method_name] = method
        return method

    return decorator
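
# a minimal usage sketch for `register_method` (illustrative; names are hypothetical):
#
#     EVAL_FUNCS: dict[str, Callable[..., Any]] = {}
#
#     @register_method(EVAL_FUNCS)
#     def accuracy(preds, labels):
#         ...
#
#     # EVAL_FUNCS is now {"accuracy": accuracy}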


def pprint_summary(summary: dict):
    print(json.dumps(summary, indent=2))

``````{ end_of_file="muutils/mlutils.py" }

``````{ path="muutils/parallel.py"  }
"parallel processing utilities, chiefly `run_maybe_parallel`"

from __future__ import annotations

import multiprocessing
import functools
from typing import (
    Any,
    Callable,
    Iterable,
    Literal,
    Optional,
    Tuple,
    TypeVar,
    Dict,
    List,
    Union,
    Protocol,
)

# for no tqdm fallback
from muutils.spinner import SpinnerContext
from muutils.validate_type import get_fn_allowed_kwargs


InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")
# typevars for our iterable and map


class ProgressBarFunction(Protocol):
    "a protocol for a progress bar function"

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable: ...


ProgressBarOption = Literal["tqdm", "spinner", "none", None]
# type for the progress bar option


DEFAULT_PBAR_FN: ProgressBarOption
# default progress bar function

try:
    # use tqdm if it's available
    import tqdm  # type: ignore[import-untyped]

    DEFAULT_PBAR_FN = "tqdm"

except ImportError:
    # fall back to the spinner if tqdm is not installed
    DEFAULT_PBAR_FN = "spinner"


def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    "spinner wrapper"
    spinnercontext_allowed_kwargs: set[str] = get_fn_allowed_kwargs(
        SpinnerContext.__init__
    )
    mapped_kwargs: dict = {
        k: v for k, v in kwargs.items() if k in spinnercontext_allowed_kwargs
    }
    if "desc" in kwargs and "message" not in mapped_kwargs:
        mapped_kwargs["message"] = kwargs["desc"]

    if "message" not in mapped_kwargs and "total" in kwargs:
        mapped_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**mapped_kwargs):
        output = list(x)

    return output


def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    "map kwargs for tqdm, cant wrap because the pbar dissapears?"
    tqdm_allowed_kwargs: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in tqdm_allowed_kwargs}

    if "desc" not in kwargs:
        if "message" in kwargs:
            mapped_kwargs["desc"] = kwargs["message"]

        elif "total" in kwargs:
            mapped_kwargs["desc"] = f"Processing {kwargs.get('total')} items"
    return mapped_kwargs


def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    "fallback to no progress bar"
    return x


def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """set up the progress bar function and its kwargs

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function
       (defaults to `None`)

    # Returns:
     - `Tuple[ProgressBarFunction, dict]`
         a tuple of the progress bar function and its kwargs

    # Raises:
     - `ValueError` : if `pbar` is not one of the valid options
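
    # Usage:
    (a minimal sketch; uses the `"spinner"` option so it does not require `tqdm`)
    ```python
    pbar_fn, pbar_kwargs = set_up_progress_bar_fn("spinner", dict(message="working..."))
    results = list(pbar_fn(range(10), **pbar_kwargs))
    ```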
    """
    pbar_fn: ProgressBarFunction

    if pbar_kwargs is None:
        pbar_kwargs = dict()

    pbar_kwargs = {**extra_kwargs, **pbar_kwargs}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
    if (pbar is None) or (pbar == "none") or pbar_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # if `pbar` is a different string, figure out which progress bar to use
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            pbar_fn = tqdm.tqdm
            pbar_kwargs = map_kwargs_for_tqdm(pbar_kwargs)
        elif pbar == "spinner":
            pbar_fn = functools.partial(spinner_fn_wrap, **pbar_kwargs)
            pbar_kwargs = dict()
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )
    else:
        # `pbar` is already a callable progress bar function -- use it as-is and pass the kwargs through
        pbar_fn = pbar

    return pbar_fn, pbar_kwargs


# TODO: if `parallel` is a negative int, use `multiprocessing.cpu_count() + parallel` to determine the number of processes
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the maximum number of processes is given by the `min(len(iterable), multiprocessing.cpu_count())`

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
       function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
       iterable passed to either `map` or `Pool.imap`
     - `parallel : bool | int`
       whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Dict[str, Any]`
       kwargs passed to the progress bar function (defaults to `None`)
     - `chunksize : Optional[int]`
       chunk size for the parallel map; if `None`, uses `max(1, n_inputs // num_processes)` (defaults to `None`)
     - `keep_ordered : bool`
       whether to use `imap` (order-preserving) instead of `imap_unordered` (defaults to `True`)
     - `use_multiprocess : bool`
       use the `multiprocess` package instead of `multiprocessing`; requires `parallel` (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar function or option (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
       a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
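
    # Usage:
    (an illustrative sketch; `square` is a hypothetical top-level function)
    ```python
    def square(x: int) -> int:
        return x * x

    results = run_maybe_parallel(square, list(range(10)), parallel=False, pbar="none")
    assert results == [x * x for x in range(10)]
    ```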
    """

    # number of inputs in iterable
    n_inputs: int = len(iterable)  # type: ignore[arg-type]
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # make sure we don't have more processes than iterable, and don't bother with parallel if there's only one process
    num_processes = min(num_processes, n_inputs)
    mp = multiprocessing
    if num_processes == 1:
        parallel = False

    if use_multiprocess:
        if not parallel:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # set up the map function -- maybe it's parallel, maybe it's just `map`
    do_map: Callable[
        [Callable[[InputType], OutputType], Iterable[InputType]],
        Iterable[OutputType],
    ]
    if parallel:
        # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
        pool = mp.Pool(num_processes)

        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        if keep_ordered:
            do_map = pool.imap
        else:
            do_map = pool.imap_unordered

        # figure out a smart chunksize if one is not given
        chunksize_int: int
        if chunksize is None:
            chunksize_int = max(1, n_inputs // num_processes)
        else:
            chunksize_int = chunksize

        # set the chunksize
        do_map = functools.partial(do_map, chunksize=chunksize_int)  # type: ignore

    else:
        do_map = map

    # run the map function with a progress bar
    output: List[OutputType] = list(
        pbar_fn(
            do_map(
                func,
                iterable,
            ),
            **pbar_kwargs_processed,
        )
    )

    # close the pool if we used one
    if parallel:
        pool.close()
        pool.join()

    # return the output as a list
    return output

``````{ end_of_file="muutils/parallel.py" }

``````{ path="muutils/py.typed"  }

``````{ end_of_file="muutils/py.typed" }

``````{ path="muutils/spinner.py"  }
"""decorator `spinner_decorator` and context manager `SpinnerContext` to display a spinner

using the base `Spinner` class while some code is running.
"""

import os
import time
from dataclasses import dataclass, field
import threading
import sys
from functools import wraps
from typing import (
    List,
    Dict,
    Callable,
    Any,
    Literal,
    Optional,
    TextIO,
    TypeVar,
    Sequence,
    Union,
    ContextManager,
)
import warnings

DecoratedFunction = TypeVar("DecoratedFunction", bound=Callable[..., Any])
"Define a generic type for the decorated function"


@dataclass
class SpinnerConfig:
    working: List[str] = field(default_factory=lambda: ["|", "/", "-", "\\"])
    success: str = "✔️"
    fail: str = "❌"

    def is_ascii(self) -> bool:
        "whether all characters are ascii"
        return all(s.isascii() for s in self.working + [self.success, self.fail])

    def eq_lens(self) -> bool:
        "whether all working characters are the same length"
        expected_len: int = len(self.working[0])
        return all(
            [
                len(char) == expected_len
                for char in self.working + [self.success, self.fail]
            ]
        )

    def is_valid(self) -> bool:
        "whether the spinner config is valid"
        return all(
            [
                len(self.working) > 0,
                isinstance(self.working, list),
                isinstance(self.success, str),
                isinstance(self.fail, str),
                all(isinstance(char, str) for char in self.working),
            ]
        )

    def __post_init__(self):
        if not self.is_valid():
            raise ValueError(f"Invalid SpinnerConfig: {self}")

    @classmethod
    def from_any(cls, arg: "SpinnerConfigArg") -> "SpinnerConfig":
        if isinstance(arg, str):
            return SPINNERS[arg]
        elif isinstance(arg, list):
            return SpinnerConfig(working=arg)
        elif isinstance(arg, dict):
            return SpinnerConfig(**arg)
        elif isinstance(arg, SpinnerConfig):
            return arg
        else:
            raise TypeError(
                f"to create a SpinnerConfig, you must pass a string (key), list (working seq), dict (kwargs to SpinnerConfig), or SpinnerConfig, but got {type(arg) = }, {arg = }"
            )


SpinnerConfigArg = Union[str, List[str], SpinnerConfig, dict]
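
# usage sketch (illustrative): `SpinnerConfig.from_any` accepts a key into `SPINNERS`,
# a list of working frames, a kwargs dict, or an existing `SpinnerConfig`:
#   SpinnerConfig.from_any("dots")
#   SpinnerConfig.from_any(["-", "=", "#"])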

SPINNERS: Dict[str, SpinnerConfig] = dict(
    default=SpinnerConfig(working=["|", "/", "-", "\\"], success="#", fail="X"),
    dots=SpinnerConfig(working=[".  ", ".. ", "..."], success="***", fail="xxx"),
    bars=SpinnerConfig(working=["|  ", "|| ", "|||"], success="|||", fail="///"),
    arrows=SpinnerConfig(working=["<", "^", ">", "v"], success="►", fail="✖"),
    arrows_2=SpinnerConfig(
        working=["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"], success="→", fail="↯"
    ),
    bouncing_bar=SpinnerConfig(
        working=["[    ]", "[=   ]", "[==  ]", "[=== ]", "[ ===]", "[  ==]", "[   =]"],
        success="[====]",
        fail="[XXXX]",
    ),
    bar=SpinnerConfig(
        working=["[  ]", "[- ]", "[--]", "[ -]"],
        success="[==]",
        fail="[xx]",
    ),
    bouncing_ball=SpinnerConfig(
        working=[
            "( ●    )",
            "(  ●   )",
            "(   ●  )",
            "(    ● )",
            "(     ●)",
            "(    ● )",
            "(   ●  )",
            "(  ●   )",
            "( ●    )",
            "(●     )",
        ],
        success="(●●●●●●)",
        fail="(  ✖  )",
    ),
    ooo=SpinnerConfig(working=[".", "o", "O", "o"], success="O", fail="x"),
    braille=SpinnerConfig(
        working=["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"],
        success="⣿",
        fail="X",
    ),
    clock=SpinnerConfig(
        working=[
            "🕛",
            "🕐",
            "🕑",
            "🕒",
            "🕓",
            "🕔",
            "🕕",
            "🕖",
            "🕗",
            "🕘",
            "🕙",
            "🕚",
        ],
        success="✔️",
        fail="❌",
    ),
    hourglass=SpinnerConfig(working=["⏳", "⌛"], success="✔️", fail="❌"),
    square_corners=SpinnerConfig(working=["◰", "◳", "◲", "◱"], success="◼", fail="✖"),
    triangle=SpinnerConfig(working=["◢", "◣", "◤", "◥"], success="◆", fail="✖"),
    square_dot=SpinnerConfig(
        working=["⣷", "⣯", "⣟", "⡿", "⢿", "⣻", "⣽", "⣾"], success="⣿", fail="❌"
    ),
    box_bounce=SpinnerConfig(working=["▌", "▀", "▐", "▄"], success="■", fail="✖"),
    hamburger=SpinnerConfig(working=["☱", "☲", "☴"], success="☰", fail="✖"),
    earth=SpinnerConfig(working=["🌍", "🌎", "🌏"], success="✔️", fail="❌"),
    growing_dots=SpinnerConfig(
        working=["⣀", "⣄", "⣤", "⣦", "⣶", "⣷", "⣿"], success="⣿", fail="✖"
    ),
    dice=SpinnerConfig(working=["⚀", "⚁", "⚂", "⚃", "⚄", "⚅"], success="🎲", fail="✖"),
    wifi=SpinnerConfig(
        working=["▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"], success="✔️", fail="❌"
    ),
    bounce=SpinnerConfig(working=["⠁", "⠂", "⠄", "⠂"], success="⠿", fail="⢿"),
    arc=SpinnerConfig(working=["◜", "◠", "◝", "◞", "◡", "◟"], success="○", fail="✖"),
    toggle=SpinnerConfig(working=["⊶", "⊷"], success="⊷", fail="⊗"),
    toggle2=SpinnerConfig(working=["▫", "▪"], success="▪", fail="✖"),
    toggle3=SpinnerConfig(working=["□", "■"], success="■", fail="✖"),
    toggle4=SpinnerConfig(working=["■", "□", "▪", "▫"], success="■", fail="✖"),
    toggle5=SpinnerConfig(working=["▮", "▯"], success="▮", fail="✖"),
    toggle7=SpinnerConfig(working=["⦾", "⦿"], success="⦿", fail="✖"),
    toggle8=SpinnerConfig(working=["◍", "◌"], success="◍", fail="✖"),
    toggle9=SpinnerConfig(working=["◉", "◎"], success="◉", fail="✖"),
    arrow2=SpinnerConfig(
        working=["⬆️ ", "↗️ ", "➡️ ", "↘️ ", "⬇️ ", "↙️ ", "⬅️ ", "↖️ "], success="➡️", fail="❌"
    ),
    point=SpinnerConfig(
        working=["∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙"], success="●●●", fail="xxx"
    ),
    layer=SpinnerConfig(working=["-", "=", "≡"], success="≡", fail="✖"),
    speaker=SpinnerConfig(
        working=["🔈 ", "🔉 ", "🔊 ", "🔉 "], success="🔊", fail="🔇"
    ),
    orangePulse=SpinnerConfig(
        working=["🔸 ", "🔶 ", "🟠 ", "🟠 ", "🔷 "], success="🟠", fail="❌"
    ),
    bluePulse=SpinnerConfig(
        working=["🔹 ", "🔷 ", "🔵 ", "🔵 ", "🔷 "], success="🔵", fail="❌"
    ),
    satellite_signal=SpinnerConfig(
        working=["📡   ", "📡·  ", "📡·· ", "📡···", "📡 ··", "📡  ·"],
        success="📡 ✔️ ",
        fail="📡 ❌ ",
    ),
    rocket_orbit=SpinnerConfig(
        working=["🌍🚀  ", "🌏 🚀 ", "🌎  🚀"], success="🌍  ✨", fail="🌍  💥"
    ),
    ogham=SpinnerConfig(working=["ᚁ ", "ᚂ ", "ᚃ ", "ᚄ", "ᚅ"], success="᚛᚜", fail="✖"),
    eth=SpinnerConfig(
        working=["᛫", "፡", "፥", "፤", "፧", "።", "፨"], success="፠", fail="✖"
    ),
)
# spinner configurations


class Spinner:
    """displays a spinner, and optionally elapsed time and a mutable value while a function is running.

    # Parameters:

    - `update_interval : float`
        how often to update the spinner display in seconds
        (defaults to `0.1`)
    - `initial_value : str`
        initial value to display with the spinner
        (defaults to `""`)
    - `message : str`
        message to display with the spinner
        (defaults to `""`)
    - `format_string : str`
        string to format the spinner with. must have `"\\r"` prepended to clear the line.
        allowed keys are `spinner`, `elapsed_time`, `message`, and `value`
        (defaults to `"\\r{spinner} ({elapsed_time:.2f}s) {message}{value}"`)
    - `output_stream : TextIO`
        stream to write the spinner to
        (defaults to `sys.stdout`)
    - `format_string_when_updated : Union[bool,str]`
        whether to use a different format string when the value is updated.
        if `True`, use the default format string with a newline appended. if a string, use that string.
        this is useful if you want each `update_value` call to print a line to the console and be preserved.
        (defaults to `False`)

    # Deprecated Parameters:

    - `spinner_chars : Union[str, Sequence[str]]`
        deprecated and ignored -- pass a key or a sequence of frames via `config` instead
        (defaults to `None`)
    - `spinner_complete : str`
        deprecated and ignored -- set the success character via `config` instead
        (defaults to `None`)

    # Methods:
    - `update_value(value: Any) -> None`
        update the current value displayed by the spinner

    # Usage:

    ## As a context manager:
    ```python
    with SpinnerContext() as sp:
        for i in range(1):
            time.sleep(0.1)
            sp.update_value(f"Step {i+1}")
    ```

    ## As a decorator:
    ```python
    @spinner_decorator(mutable_kwarg_key="update_status")
    def long_running_function(update_status):
        for i in range(1):
            time.sleep(0.1)
            update_status(f"Step {i+1}")
        return "Function completed"
    ```
    """

    def __init__(
        self,
        # no positional args
        *args,
        config: SpinnerConfigArg = "default",
        update_interval: float = 0.1,
        initial_value: str = "",
        message: str = "",
        format_string: str = "\r{spinner} ({elapsed_time:.2f}s) {message}{value}",
        output_stream: TextIO = sys.stdout,
        format_string_when_updated: Union[str, bool] = False,
        # deprecated
        spinner_chars: Optional[Union[str, Sequence[str]]] = None,
        spinner_complete: Optional[str] = None,
        # no other kwargs accepted
        **kwargs: Any,
    ):
        if args:
            raise ValueError(f"Spinner does not accept positional arguments: {args}")
        if kwargs:
            raise ValueError(
                f"Spinner did not recognize these keyword arguments: {kwargs}"
            )

        # old spinner display
        if (spinner_chars is not None) or (spinner_complete is not None):
            warnings.warn(
                "spinner_chars and spinner_complete are deprecated and will have no effect. Use `config` instead.",
                DeprecationWarning,
            )

        # config
        self.config: SpinnerConfig = SpinnerConfig.from_any(config)

        # special format string for when the value is updated
        self.format_string_when_updated: Optional[str] = None
        "format string to use when the value is updated"
        if format_string_when_updated is not False:
            if format_string_when_updated is True:
                # modify the default format string
                self.format_string_when_updated = format_string + "\n"
            elif isinstance(format_string_when_updated, str):
                # use the provided format string
                self.format_string_when_updated = format_string_when_updated
            else:
                raise TypeError(
                    "format_string_when_updated must be a string or True, got"
                    + f" {type(format_string_when_updated) = }{format_string_when_updated}"
                )

        # copy other kwargs
        self.update_interval: float = update_interval
        self.message: str = message
        self.current_value: Any = initial_value
        self.format_string: str = format_string
        self.output_stream: TextIO = output_stream

        # test out format string
        try:
            self.format_string.format(
                spinner=self.config.working[0],
                elapsed_time=0.0,
                message=self.message,
                value=self.current_value,
            )
        except Exception as e:
            raise ValueError(
                f"Invalid format string: {format_string}. Must take keys "
                + "'spinner: str', 'elapsed_time: float', 'message: str', and 'value: Any'."
            ) from e

        # init
        self.start_time: float = 0
        "for measuring elapsed time"
        self.stop_spinner: threading.Event = threading.Event()
        "to stop the spinner"
        self.spinner_thread: Optional[threading.Thread] = None
        "the thread running the spinner"
        self.value_changed: bool = False
        "whether the value has been updated since the last display"
        self.term_width: int
        "width of the terminal, for padding with spaces"
        try:
            self.term_width = os.get_terminal_size().columns
        except OSError:
            self.term_width = 80

        # state of the spinner
        self.state: Literal["initialized", "running", "success", "fail"] = "initialized"

    def spin(self) -> None:
        "Function to run in a separate thread, displaying the spinner and optional information"
        i: int = 0
        while not self.stop_spinner.is_set():
            # get current spinner str
            spinner: str = self.config.working[i % len(self.config.working)]

            # args for display string
            display_parts: Dict[str, Any] = dict(
                spinner=spinner,  # str
                elapsed_time=time.time() - self.start_time,  # float
                message=self.message,  # str
                value=self.current_value,  # Any, but will be formatted as str
            )

            # use the special one if needed
            format_str: str = self.format_string
            if self.value_changed and (self.format_string_when_updated is not None):
                self.value_changed = False
                format_str = self.format_string_when_updated

            # write and flush the display string
            output: str = format_str.format(**display_parts).ljust(self.term_width)
            self.output_stream.write(output)
            self.output_stream.flush()

            # wait for the next update
            time.sleep(self.update_interval)
            i += 1

    def update_value(self, value: Any) -> None:
        "Update the current value displayed by the spinner"
        self.current_value = value
        self.value_changed = True

    def start(self) -> None:
        "Start the spinner"
        self.start_time = time.time()
        self.spinner_thread = threading.Thread(target=self.spin)
        self.spinner_thread.start()
        self.state = "running"

    def stop(self, failed: bool = False) -> None:
        "Stop the spinner"
        self.output_stream.write(
            self.format_string.format(
                spinner=self.config.success if not failed else self.config.fail,
                elapsed_time=time.time() - self.start_time,  # float
                message=self.message,  # str
                value=self.current_value,  # Any, but will be formatted as str
            ).ljust(self.term_width)
        )
        self.stop_spinner.set()
        if self.spinner_thread:
            self.spinner_thread.join()
        self.output_stream.write("\n")
        self.output_stream.flush()

        self.state = "fail" if failed else "success"


class NoOpContextManager(ContextManager):
    """A context manager that does nothing."""

    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass


class SpinnerContext(Spinner, ContextManager):
    "see `Spinner` for parameters"

    def __enter__(self) -> "SpinnerContext":
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.stop(failed=exc_type is not None)


SpinnerContext.__doc__ = Spinner.__doc__


# TODO: type hint that the `update_status` kwarg is not needed when calling the function we just decorated
def spinner_decorator(
    *args,
    # passed to `Spinner.__init__`
    config: SpinnerConfigArg = "default",
    update_interval: float = 0.1,
    initial_value: str = "",
    message: str = "",
    format_string: str = "{spinner} ({elapsed_time:.2f}s) {message}{value}",
    output_stream: TextIO = sys.stdout,
    # new kwarg
    mutable_kwarg_key: Optional[str] = None,
    # deprecated
    spinner_chars: Union[str, Sequence[str], None] = None,
    spinner_complete: Optional[str] = None,
    **kwargs,
) -> Callable[[DecoratedFunction], DecoratedFunction]:
    """see `Spinner` for parameters. Also takes `mutable_kwarg_key`

    `mutable_kwarg_key` is the key with which `Spinner().update_value`
    will be passed to the decorated function. if `None`, won't pass it.

    """

    if len(args) > 1:
        raise ValueError(
            f"spinner_decorator does not accept positional arguments: {args}"
        )
    if kwargs:
        raise ValueError(
            f"spinner_decorator did not recognize these keyword arguments: {kwargs}"
        )

    def decorator(func: DecoratedFunction) -> DecoratedFunction:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            spinner: Spinner = Spinner(
                config=config,
                update_interval=update_interval,
                initial_value=initial_value,
                message=message,
                format_string=format_string,
                output_stream=output_stream,
                spinner_chars=spinner_chars,
                spinner_complete=spinner_complete,
            )

            if mutable_kwarg_key:
                kwargs[mutable_kwarg_key] = spinner.update_value

            spinner.start()
            try:
                result: Any = func(*args, **kwargs)
                spinner.stop(failed=False)
            except Exception as e:
                spinner.stop(failed=True)
                raise e

            return result

        # TODO: fix this type ignore
        return wrapper  # type: ignore[return-value]

    if not args:
        # called as `@spinner_decorator(stuff)`
        return decorator
    else:
        # called as `@spinner_decorator` without parens
        return decorator(args[0])


spinner_decorator.__doc__ = Spinner.__doc__

``````{ end_of_file="muutils/spinner.py" }

``````{ path="muutils/statcounter.py"  }
"""`StatCounter` class for counting and calculating statistics on numbers

cleaner and more efficient than just using a `Counter` or array"""

from __future__ import annotations

import json
import math
from collections import Counter
from functools import cached_property
from itertools import chain
from typing import Callable, Optional, Sequence, Union


# _GeneralArray = Union[np.ndarray, "torch.Tensor"]
NumericSequence = Sequence[Union[float, int, "NumericSequence"]]

# pylint: disable=abstract-method

# misc
# ==================================================


def universal_flatten(
    arr: Union[NumericSequence, float, int], require_rectangular: bool = True
) -> NumericSequence:
    """flattens any iterable"""

    # mypy complains that the sequence has no attribute "flatten"
    if hasattr(arr, "flatten") and callable(arr.flatten):  # type: ignore
        return arr.flatten()  # type: ignore
    elif isinstance(arr, Sequence):
        elements_iterable: list[bool] = [isinstance(x, Sequence) for x in arr]
        if require_rectangular and (all(elements_iterable) != any(elements_iterable)):
            raise ValueError("arr contains mixed iterable and non-iterable elements")
        if any(elements_iterable):
            return list(chain.from_iterable(universal_flatten(x) for x in arr))  # type: ignore[misc]
        else:
            return arr
    else:
        return [arr]


# StatCounter
# ==================================================


class StatCounter(Counter):
    """`Counter`, but with some stat calculation methods which assume the keys are numerical

    works best when the keys are `int`s
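
    for example (illustrative): `StatCounter([1, 1, 2, 3])` counts `{1: 2, 2: 1, 3: 1}`,
    so `.mean()` returns `1.75` and `.mode()` returns `1`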
    """

    def validate(self) -> bool:
        """validate the counter as being all floats or ints"""
        return all(isinstance(k, (bool, int, float, type(None))) for k in self.keys())

    def min(self):
        "minimum value"
        return min(x for x, v in self.items() if v > 0)

    def max(self):
        "maximum value"
        return max(x for x, v in self.items() if v > 0)

    def total(self):
        """Sum of the counts"""
        return sum(self.values())

    @cached_property
    def keys_sorted(self) -> list:
        """return the keys"""
        return sorted(list(self.keys()))

    def percentile(self, p: float):
        """return the value at the given percentile

        this could be log time if we did binary search, but that would be a lot of added complexity
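
        for example (illustrative): `StatCounter([1, 2, 3, 4]).percentile(0.5)` interpolates
        between the two middle keys and returns `2.5`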
        """

        if p < 0 or p > 1:
            raise ValueError(f"percentile must be between 0 and 1: {p}")
        # flip for speed
        sorted_keys: list[float] = [float(x) for x in self.keys_sorted]
        sort: int = 1
        if p > 0.51:
            sort = -1
            p = 1 - p

        sorted_keys = sorted_keys[::sort]
        real_target: float = p * (self.total() - 1)

        n_target_f: int = math.floor(real_target)
        n_target_c: int = math.ceil(real_target)

        n_sofar: float = -1

        # print(f'{p = } {real_target = } {n_target_f = } {n_target_c = }')

        for i, k in enumerate(sorted_keys):
            n_sofar += self[k]

            # print(f'{k = } {n_sofar = }')

            if n_sofar > n_target_f:
                return k

            elif n_sofar == n_target_f:
                if n_sofar == n_target_c:
                    return k
                else:
                    # print(
                    #     sorted_keys[i], (n_sofar + 1 - real_target),
                    #     sorted_keys[i + 1], (real_target - n_sofar),
                    # )
                    return sorted_keys[i] * (n_sofar + 1 - real_target) + sorted_keys[
                        i + 1
                    ] * (real_target - n_sofar)
            else:
                continue

        raise ValueError(f"percentile {p} not found???")

    def median(self) -> float:
        return self.percentile(0.5)

    def mean(self) -> float:
        """return the mean of the values"""
        return float(sum(k * c for k, c in self.items()) / self.total())

    def mode(self) -> float:
        return self.most_common()[0][0]

    def std(self) -> float:
        """return the standard deviation of the values"""
        mean: float = self.mean()
        deviations: float = sum(c * (k - mean) ** 2 for k, c in self.items())

        return (deviations / self.total()) ** 0.5

    def summary(
        self,
        typecast: Callable = lambda x: x,
        *,
        extra_percentiles: Optional[list[float]] = None,
    ) -> dict[str, Union[float, int]]:
        """return a summary of the stats, without the raw data. human readable and small"""
        # common stats that always work
        output: dict = dict(
            total_items=self.total(),
            n_keys=len(self.keys()),
            mode=self.mode(),
        )

        if self.total() > 0:
            if self.validate():
                # if it's a numeric counter, we can do some stats
                output = {
                    **output,
                    **dict(
                        mean=float(self.mean()),
                        std=float(self.std()),
                        min=typecast(self.min()),
                        q1=typecast(self.percentile(0.25)),
                        median=typecast(self.median()),
                        q3=typecast(self.percentile(0.75)),
                        max=typecast(self.max()),
                    ),
                }

                if extra_percentiles is not None:
                    for p in extra_percentiles:
                        output[f"percentile_{p}"] = typecast(self.percentile(p))
            else:
                # if it's not, we can only do the simpler things
                # mode, n_keys, and total_items are already set in the initial declaration of `output`
                pass

        return output

    def serialize(
        self,
        typecast: Callable = lambda x: x,
        *,
        extra_percentiles: Optional[list[float]] = None,
    ) -> dict:
        """return a json-serializable version of the counter

        includes both the output of `summary` and the raw data:

        ```json
        {
            "StatCounter": { <keys, values from raw data> },
            "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
        }
        ```

        """

        return {
            "StatCounter": {
                typecast(k): v
                for k, v in sorted(dict(self).items(), key=lambda x: x[0])
            },
            "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
        }

    def __str__(self) -> str:
        "summary as json with 2 space indent, good for printing"
        return json.dumps(self.summary(), indent=2)

    def __repr__(self) -> str:
        return json.dumps(self.serialize(), indent=2)

    @classmethod
    def load(cls, data: dict) -> "StatCounter":
        "load from a the output of `StatCounter.serialize`"
        if "StatCounter" in data:
            loadme = data["StatCounter"]
        else:
            loadme = data

        return cls({float(k): v for k, v in loadme.items()})

    @classmethod
    def from_list_arrays(
        cls,
        arr,
        map_func: Callable = float,
    ) -> "StatCounter":
        """calls `map_func` on each element of `universal_flatten(arr)`"""
        return cls([map_func(x) for x in universal_flatten(arr)])

``````{ end_of_file="muutils/statcounter.py" }

``````{ path="muutils/sysinfo.py"  }
"utilities for getting information about the system, see `SysInfo` class"

from __future__ import annotations

import subprocess
import sys
import typing
from importlib.metadata import distributions


def _popen(cmd: list[str], split_out: bool = False) -> dict[str, typing.Any]:
    p: subprocess.Popen = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )

    stdout, stderr = p.communicate()

    p_out: typing.Union[str, list[str], None]
    if stdout:
        p_out = stdout.decode("utf-8")
        if split_out:
            assert isinstance(p_out, str)
            p_out = p_out.strip().split("\n")
    else:
        p_out = None

    return {
        "stdout": p_out,
        "stderr": stderr.decode("utf-8") if stderr else None,
        "returncode": p.returncode if p.returncode is None else int(p.returncode),
    }


class SysInfo:
    """getters for various information about the system"""

    @staticmethod
    def python() -> dict:
        """details about python version"""
        ver_tup = sys.version_info
        return {
            "version": sys.version,
            "version_info": ver_tup,
            "major": ver_tup[0],
            "minor": ver_tup[1],
            "micro": ver_tup[2],
            "releaselevel": ver_tup[3],
            "serial": ver_tup[4],
        }

    @staticmethod
    def pip() -> dict:
        """installed packages info"""
        # in python <= 3.9  `Distribution` has no attribute `name`
        pckgs: list[tuple[str, str]] = [
            (
                (
                    x.metadata.get("Name", "<unknown>")  # type: ignore[attr-defined]
                    if sys.version_info < (3, 10)
                    else x.name  # type: ignore[attr-defined]
                ),
                x.version,
            )
            for x in distributions()
        ]
        return {
            "n_packages": len(pckgs),
            "packages": pckgs,
        }

    @staticmethod
    def pytorch() -> dict:
        """pytorch and cuda information"""
        try:
            import torch
            import torch.version
        except Exception as e:
            return {
                "importable": False,
                "error": str(e),
            }

        output: dict = {"importable": True}

        output["torch.__version__"] = torch.__version__
        output["torch.version.cuda"] = torch.version.cuda
        output["torch.version.debug"] = torch.version.debug
        output["torch.version.git_version"] = torch.version.git_version
        output["torch.version.hip"] = torch.version.hip
        output["torch.cuda.is_available()"] = torch.cuda.is_available()
        output["torch.cuda.device_count()"] = torch.cuda.device_count()
        output["torch.cuda.is_initialized()"] = torch.cuda.is_initialized()

        if torch.cuda.is_available():
            import os

            cuda_version_nvcc: str = os.popen("nvcc --version").read()
            output["nvcc --version"] = cuda_version_nvcc.split("\n")

            if torch.cuda.device_count() > 0:
                n_devices: int = torch.cuda.device_count()
                output["torch.cuda.current_device()"] = torch.cuda.current_device()
                output["torch devices"] = []
                for current_device in range(n_devices):
                    try:
                        # print(f'checking current device {current_device} of {torch.cuda.device_count()} devices')
                        # print(f'\tdevice {current_device}')
                        # dev_prop = torch.cuda.get_device_properties(torch.device(0))
                        # print(f'\t    name:                   {dev_prop.name}')
                        # print(f'\t    version:                {dev_prop.major}.{dev_prop.minor}')
                        # print(f'\t    total_memory:           {dev_prop.total_memory}')
                        # print(f'\t    multi_processor_count:  {dev_prop.multi_processor_count}')
                        # print(f'\t')
                        dev_prop = torch.cuda.get_device_properties(current_device)
                        output["torch devices"].append(
                            {
                                "device": current_device,
                                "name": dev_prop.name,
                                "version": {
                                    "major": dev_prop.major,
                                    "minor": dev_prop.minor,
                                },
                                "total_memory": dev_prop.total_memory,
                                "multi_processor_count": dev_prop.multi_processor_count,
                            }
                        )
                    except Exception as e:
                        output["torch devices"].append(
                            {
                                "device": current_device,
                                "error": str(e),
                            }
                        )
        return output

    @staticmethod
    def platform() -> dict:
        import platform

        items = [
            "platform",
            "machine",
            "processor",
            "system",
            "version",
            "architecture",
            "uname",
            "node",
            "python_branch",
            "python_build",
            "python_compiler",
            "python_implementation",
        ]

        return {x: getattr(platform, x)() for x in items}

    @staticmethod
    def git_info(with_log: bool = False) -> dict:
        git_version: dict = _popen(["git", "version"])
        git_status: dict = _popen(["git", "status"])
        if not git_status["stderr"] or git_status["stderr"].startswith(
            "fatal: not a git repository"
        ):
            return {
                "git version": git_version["stdout"],
                "git status": git_status,
            }
        else:
            output: dict = {
                "git version": git_version["stdout"],
                "git status": git_status,
                "git branch": _popen(["git", "branch"], split_out=True),
                "git remote -v": _popen(["git", "remote", "-v"], split_out=True),
            }
            if with_log:
                output["git log"] = _popen(["git", "log"], split_out=False)

            return output

    @classmethod
    def get_all(
        cls,
        include: typing.Optional[tuple[str, ...]] = None,
        exclude: tuple[str, ...] = tuple(),
    ) -> dict:
        include_meta: tuple[str, ...]
        if include is None:
            include_meta = tuple(cls.__dict__.keys())
        else:
            include_meta = include

        return {
            x: getattr(cls, x)()
            for x in include_meta
            if all(
                [
                    not x.startswith("_"),
                    x not in exclude,
                    callable(getattr(cls, x)),
                    x != "get_all",
                    x in include if include is not None else True,
                ]
            )
        }
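
# usage sketch (illustrative): skip the slower `pip` and `pytorch` getters
#   info: dict = SysInfo.get_all(exclude=("pip", "pytorch"))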


if __name__ == "__main__":
    import pprint

    pprint.pprint(SysInfo.get_all())

``````{ end_of_file="muutils/sysinfo.py" }

``````{ path="muutils/tensor_info.py"  }
"get metadata about a tensor, mostly for `muutils.dbg`"

from __future__ import annotations

import numpy as np
from typing import Union, Any, Literal, List, Dict, overload, Optional

# Global color definitions
COLORS: Dict[str, Dict[str, str]] = {
    "latex": {
        "range": r"\textcolor{purple}",
        "mean": r"\textcolor{teal}",
        "std": r"\textcolor{orange}",
        "median": r"\textcolor{green}",
        "warning": r"\textcolor{red}",
        "shape": r"\textcolor{magenta}",
        "dtype": r"\textcolor{gray}",
        "device": r"\textcolor{gray}",
        "requires_grad": r"\textcolor{gray}",
        "sparkline": r"\textcolor{blue}",
        "torch": r"\textcolor{orange}",
        "dtype_bool": r"\textcolor{gray}",
        "dtype_int": r"\textcolor{blue}",
        "dtype_float": r"\textcolor{red!70}",  # 70% red intensity
        "dtype_str": r"\textcolor{red}",
        "device_cuda": r"\textcolor{green}",
        "reset": "",
    },
    "terminal": {
        "range": "\033[35m",  # purple
        "mean": "\033[36m",  # cyan/teal
        "std": "\033[33m",  # yellow/orange
        "median": "\033[32m",  # green
        "warning": "\033[31m",  # red
        "shape": "\033[95m",  # bright magenta
        "dtype": "\033[90m",  # gray
        "device": "\033[90m",  # gray
        "requires_grad": "\033[90m",  # gray
        "sparkline": "\033[34m",  # blue
        "torch": "\033[38;5;208m",  # bright orange
        "dtype_bool": "\033[38;5;245m",  # medium grey
        "dtype_int": "\033[38;5;39m",  # bright blue
        "dtype_float": "\033[38;5;167m",  # softer red/coral
        "device_cuda": "\033[38;5;76m",  # NVIDIA-style bright green
        "reset": "\033[0m",
    },
    "none": {
        "range": "",
        "mean": "",
        "std": "",
        "median": "",
        "warning": "",
        "shape": "",
        "dtype": "",
        "device": "",
        "requires_grad": "",
        "sparkline": "",
        "torch": "",
        "dtype_bool": "",
        "dtype_int": "",
        "dtype_float": "",
        "dtype_str": "",
        "device_cuda": "",
        "reset": "",
    },
}

OutputFormat = Literal["unicode", "latex", "ascii"]

SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
    "latex": {
        "range": r"\mathcal{R}",
        "mean": r"\mu",
        "std": r"\sigma",
        "median": r"\tilde{x}",
        "distribution": r"\mathbb{P}",
        "distribution_log": r"\mathbb{P}_L",
        "nan_values": r"\text{NANvals}",
        "warning": "!!!",
        "requires_grad": r"\nabla",
        "true": r"\checkmark",
        "false": r"\times",
    },
    "unicode": {
        "range": "R",
        "mean": "μ",
        "std": "σ",
        "median": "x̃",
        "distribution": "ℙ",
        "distribution_log": "ℙ˪",
        "nan_values": "NANvals",
        "warning": "🚨",
        "requires_grad": "∇",
        "true": "✓",
        "false": "✗",
    },
    "ascii": {
        "range": "range",
        "mean": "mean",
        "std": "std",
        "median": "med",
        "distribution": "dist",
        "distribution_log": "dist_log",
        "nan_values": "NANvals",
        "warning": "!!!",
        "requires_grad": "requires_grad",
        "true": "1",
        "false": "0",
    },
}
"Symbols for different formats"

SPARK_CHARS: Dict[OutputFormat, List[str]] = {
    "unicode": list(" ▁▂▃▄▅▆▇█"),
    "ascii": list(" _.-~=#"),
    "latex": list(" ▁▂▃▄▅▆▇█"),
}
"characters for sparklines in different formats"


def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    # Parameters:
     - `A : array-like`
            Array to analyze (numpy array or torch tensor)

    # Returns:
     - `Dict[str, Any]`
            Dictionary containing raw statistical information with numeric values
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Check if it's a tensor by looking at its class name
    # This avoids importing torch directly
    A_type: str = type(A).__name__
    result["is_tensor"] = A_type == "Tensor"

    # Try to get device information if it's a tensor
    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except:  # noqa: E722
            pass

    # Convert to numpy array for calculations
    try:
        # For PyTorch tensors
        if result["is_tensor"]:
            # Check if tensor is on GPU
            is_cuda: bool = False
            try:
                is_cuda = bool(getattr(A, "is_cuda", False))
            except:  # noqa: E722
                pass

            if is_cuda:
                try:
                    # Try to get CPU tensor first
                    cpu_tensor = getattr(A, "cpu", lambda: A)()
                except:  # noqa: E722
                    A_np = np.array([])
            else:
                cpu_tensor = A
            try:
                # For CPU tensor, just detach and convert
                detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
                A_np = getattr(detached, "numpy", lambda: np.array([]))()
            except:  # noqa: E722
                A_np = np.array([])
        else:
            # For numpy arrays and other array-like objects
            A_np = np.asarray(A)
    except:  # noqa: E722
        A_np = np.array([])

    # Get basic information
    try:
        result["shape"] = A_np.shape
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except:  # noqa: E722
        pass

    # If array is empty, return early
    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Flatten array for statistics if it's multi-dimensional
    # TODO: type checks fail on 3.10, see https://github.com/mivanit/muutils/actions/runs/18883100459/job/53891346225
    try:
        if len(A_np.shape) > 1:
            A_flat = A_np.flatten()  # type: ignore[assignment]
        else:
            A_flat = A_np  # type: ignore[assignment]
    except:  # noqa: E722
        A_flat = A_np  # type: ignore[assignment]

    # Check for NaN values
    try:
        nan_mask = np.isnan(A_flat)
        result["nan_count"] = np.sum(nan_mask)
        result["has_nans"] = result["nan_count"] > 0
        if result["size"] > 0:
            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
    except:  # noqa: E722
        pass

    # If all values are NaN, return early
    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    # Calculate statistics
    try:
        if result["has_nans"]:
            result["min"] = float(np.nanmin(A_flat))
            result["max"] = float(np.nanmax(A_flat))
            result["mean"] = float(np.nanmean(A_flat))
            result["std"] = float(np.nanstd(A_flat))
            result["median"] = float(np.nanmedian(A_flat))
            result["range"] = (result["min"], result["max"])

            # Remove NaNs for histogram
            A_hist = A_flat[~nan_mask]
        else:
            result["min"] = float(np.min(A_flat))
            result["max"] = float(np.max(A_flat))
            result["mean"] = float(np.mean(A_flat))
            result["std"] = float(np.std(A_flat))
            result["median"] = float(np.median(A_flat))
            result["range"] = (result["min"], result["max"])

            A_hist = A_flat

        # Calculate histogram data for sparklines
        if A_hist.size > 0:
            try:
                # TODO: handle bool tensors correctly
                # muutils/tensor_info.py:238: RuntimeWarning: Converting input from bool to <class 'numpy.uint8'> for compatibility.
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except:  # noqa: E722
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result
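
# usage sketch (illustrative):
#   info: dict = array_info(np.array([1.0, 2.0, 3.0]))
#   assert info["status"] == "ok" and info["mean"] == 2.0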


def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
    - `histogram : np.ndarray`
        Histogram data
    - `format : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
    - `log_y : bool|None`
        Whether to use logarithmic y-scale. `None` for automatic detection
        (defaults to `None`)

    # Returns:
    - `tuple[str, bool]`
        Sparkline visualization and whether log scale was used
    """
    if histogram is None or len(histogram) == 0:
        return "", False

    # Get the appropriate character set
    chars: List[str]
    if format in SPARK_CHARS:
        chars = SPARK_CHARS[format]
    else:
        chars = SPARK_CHARS["ascii"]

    # automatic detection of log_y
    if log_y is None:
        # we bin the histogram values to the number of levels in our sparkline characters
        hist_hist = np.histogram(histogram, bins=len(chars))[0]
        # if every bin except the smallest (first) and largest (last) is empty,
        # then we should use the log scale. if those bins are nonempty, keep the linear scale
        if hist_hist[1:-1].max() > 0:
            log_y = False
        else:
            log_y = True

    # Handle log scale
    if log_y:
        # Add small value to avoid log(0)
        hist_data = np.log1p(histogram)
    else:
        hist_data = histogram

    # Normalize to character set range
    if hist_data.max() > 0:
        normalized = hist_data / hist_data.max() * (len(chars) - 1)
    else:
        normalized = np.zeros_like(hist_data)

    # Convert to characters
    spark = ""
    for val in normalized:
        idx = round(val)
        spark += chars[idx]

    return spark, log_y
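
# usage sketch (illustrative): one block character per histogram bin
#   counts, _bins = np.histogram(np.random.randn(1000), bins=8)
#   spark, used_log = generate_sparkline(counts)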


DEFAULT_SETTINGS: Dict[str, Any] = dict(
    fmt="unicode",
    precision=2,
    stats=True,
    shape=True,
    dtype=True,
    device=True,
    requires_grad=True,
    sparkline=False,
    sparkline_bins=5,
    sparkline_logy=None,
    colored=False,
    as_list=False,
    eq_char="=",
)


def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    if using_tex:
        return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text
    else:
        return (
            f"{colors[color_key]}{text}{colors['reset']}" if colors[color_key] else text
        )


def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Colorize dtype string with specific colors for torch and type names."""

    # Handle torch prefix
    type_part: str = dtype_str
    prefix_part: Optional[str] = None
    if "torch." in dtype_str:
        parts = dtype_str.split("torch.")
        if len(parts) == 2:
            prefix_part = apply_color("torch", "torch", colors, using_tex)
            type_part = parts[1]

    # Handle type coloring
    color_key: str = "dtype"
    if "bool" in dtype_str.lower():
        color_key = "dtype_bool"
    elif "int" in dtype_str.lower():
        color_key = "dtype_int"
    elif "float" in dtype_str.lower():
        color_key = "dtype_float"

    type_colored: str = apply_color(type_part, color_key, colors, using_tex)

    if prefix_part:
        return f"{prefix_part}.{type_colored}"
    else:
        return type_colored


def format_shape_colored(shape_val, colors: Dict[str, str], using_tex: bool) -> str:
    """Format shape with proper coloring for both 1D and multi-D arrays."""

    def apply_color(text: str, color_key: str) -> str:
        if using_tex:
            return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text
        else:
            return (
                f"{colors[color_key]}{text}{colors['reset']}"
                if colors[color_key]
                else text
            )

    if len(shape_val) == 1:
        # For 1D arrays, still color the dimension value
        return apply_color(str(shape_val[0]), "shape")
    else:
        # For multi-D arrays, color each dimension
        return "(" + ",".join(apply_color(str(dim), "shape") for dim in shape_val) + ")"


def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format device string with CUDA highlighting."""

    def apply_color(text: str, color_key: str) -> str:
        if using_tex:
            return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text
        else:
            return (
                f"{colors[color_key]}{text}{colors['reset']}"
                if colors[color_key]
                else text
            )

    if "cuda" in device_str.lower():
        return apply_color(device_str, "device_cuda")
    else:
        return apply_color(device_str, "device")


class _UseDefaultType:
    pass


_USE_DEFAULT = _UseDefaultType()


@overload
def array_summary(
    array: Any,
    as_list: Literal[True],
    **kwargs,
) -> List[str]: ...
@overload
def array_summary(
    array: Any,
    as_list: Literal[False],
    **kwargs,
) -> str: ...
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: Optional[bool] = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    # Parameters:
     - `array`
            array-like object (numpy array or torch tensor)
     - `precision : int`
            Decimal places (defaults to `2`)
     - `fmt : Literal["unicode", "latex", "ascii"]`
            Output format (defaults to `"unicode"`)
     - `stats : bool`
            Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
     - `shape : bool`
            Whether to include shape info (defaults to `True`)
     - `dtype : bool`
            Whether to include dtype info (defaults to `True`)
     - `device : bool`
            Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
            Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
            Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
            Number of histogram bins used for the sparkline (defaults to `DEFAULT_SETTINGS["sparkline_bins"]`)
     - `sparkline_logy : bool|None`
            Whether to use logarithmic y-scale for sparkline (defaults to `None`)
     - `colored : bool`
            Whether to add color to output (defaults to `False`)
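     - `eq_char : str`
            Character placed between a label and its value in the output (defaults to `DEFAULT_SETTINGS["eq_char"]`)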
     - `as_list : bool`
            Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
            Formatted statistical summary, either as string or list of strings
    """
    if fmt is _USE_DEFAULT:
        fmt = DEFAULT_SETTINGS["fmt"]
    if precision is _USE_DEFAULT:
        precision = DEFAULT_SETTINGS["precision"]
    if stats is _USE_DEFAULT:
        stats = DEFAULT_SETTINGS["stats"]
    if shape is _USE_DEFAULT:
        shape = DEFAULT_SETTINGS["shape"]
    if dtype is _USE_DEFAULT:
        dtype = DEFAULT_SETTINGS["dtype"]
    if device is _USE_DEFAULT:
        device = DEFAULT_SETTINGS["device"]
    if requires_grad is _USE_DEFAULT:
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if sparkline is _USE_DEFAULT:
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if sparkline_bins is _USE_DEFAULT:
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if sparkline_logy is _USE_DEFAULT:
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if colored is _USE_DEFAULT:
        colored = DEFAULT_SETTINGS["colored"]
    if as_list is _USE_DEFAULT:
        as_list = DEFAULT_SETTINGS["as_list"]
    if eq_char is _USE_DEFAULT:
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    # Helper function to colorize text
    def colorize(text: str, color_key: str) -> str:
        if using_tex:
            return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text
        else:
            return (
                f"{colors[color_key]}{text}{colors['reset']}"
                if colors[color_key]
                else text
            )

    # Check if dtype is integer type
    dtype_str: str = array_data.get("dtype", "")
    is_int_dtype: bool = any(
        int_type in dtype_str.lower() for int_type in ["int", "uint", "bool"]
    )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}={stat_colored}")

            # Range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                if is_int_dtype:
                    min_str: str = f"{int(min_val):d}"
                    max_str: str = f"{int(max_val):d}"
                else:
                    min_str = f"{min_val:{float_fmt}}"
                    max_str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = f"{symbols['range']}=[{min_colored},{max_colored}]"
                result_parts.append(range_str)

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        # generate_sparkline reports whether log-y scaling was actually used, so the symbol can be chosen accordingly
        spark, used_log = generate_sparkline(
            array_data["histogram"],
            format=fmt,
            log_y=sparkline_logy,
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            dist_symbol = (
                symbols["distribution_log"] if used_log else symbols["distribution"]
            )
            result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    # Add shape if requested
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        shape_str = format_shape_colored(shape_val, colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype={dtype_colored}")

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored = format_device_colored(array_data["device"], colors, using_tex)
        result_parts.append(f"device{eq_char}{device_colored}")

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    else:
        joinchar: str = r" \quad " if using_tex else " "
        return joinchar.join(result_parts)

``````{ end_of_file="muutils/tensor_info.py" }

``````{ path="muutils/tensor_utils.py"  }
"""utilities for working with tensors and arrays.

notably:

- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types
- `DTYPE_MAP` mapping string representations of types to their type
- `TORCH_DTYPE_MAP` mapping string representations of types to torch types
- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match

"""

from __future__ import annotations

import json
import typing

import jaxtyping
import numpy as np
import torch

from muutils.errormode import ErrorMode
from muutils.dictmagic import dotlist_to_nested_dict

# pylint: disable=missing-class-docstring


TYPE_TO_JAX_DTYPE: dict = {
    float: jaxtyping.Float,
    int: jaxtyping.Int,
    jaxtyping.Float: jaxtyping.Float,
    jaxtyping.Int: jaxtyping.Int,
    # bool
    bool: jaxtyping.Bool,
    jaxtyping.Bool: jaxtyping.Bool,
    np.bool_: jaxtyping.Bool,
    torch.bool: jaxtyping.Bool,
    # numpy float
    np.float16: jaxtyping.Float,
    np.float32: jaxtyping.Float,
    np.float64: jaxtyping.Float,
    np.half: jaxtyping.Float,
    np.single: jaxtyping.Float,
    np.double: jaxtyping.Float,
    # numpy int
    np.int8: jaxtyping.Int,
    np.int16: jaxtyping.Int,
    np.int32: jaxtyping.Int,
    np.int64: jaxtyping.Int,
    np.longlong: jaxtyping.Int,
    np.short: jaxtyping.Int,
    np.uint8: jaxtyping.Int,
    # torch float
    torch.float: jaxtyping.Float,
    torch.float16: jaxtyping.Float,
    torch.float32: jaxtyping.Float,
    torch.float64: jaxtyping.Float,
    torch.half: jaxtyping.Float,
    torch.double: jaxtyping.Float,
    torch.bfloat16: jaxtyping.Float,
    # torch int
    torch.int: jaxtyping.Int,
    torch.int8: jaxtyping.Int,
    torch.int16: jaxtyping.Int,
    torch.int32: jaxtyping.Int,
    torch.int64: jaxtyping.Int,
    torch.long: jaxtyping.Int,
    torch.short: jaxtyping.Int,
}
"dict mapping python, numpy, and torch types to `jaxtyping` types"

# we check for version here, so it shouldn't error
if np.version.version < "2.0.0":
    TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float  # type: ignore[attr-defined]
    TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int  # type: ignore[attr-defined]


# TODO: add proper type annotations to this signature
# TODO: maybe get rid of this altogether?
def jaxtype_factory(
    name: str,
    array_type: type,
    default_jax_dtype=jaxtyping.Float,
    legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN,
) -> type:
    """usage:
    ```
    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
    x: ATensor["dim1 dim2", np.float32]
    ```
    """
    legacy_mode_ = ErrorMode.from_any(legacy_mode)

    class _BaseArray:
        """jaxtyping shorthand
        (backwards compatible with older versions of muutils.tensor_utils)

        default_jax_dtype = {default_jax_dtype}
        array_type = {array_type}
        """

        def __new__(cls, *args, **kwargs):
            raise TypeError("Type FArray cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__name__}")

        @classmethod
        def param_info(cls, params) -> str:
            """useful for error printing"""
            return "\n".join(
                f"{k} = {v}"
                for k, v in {
                    "cls.__name__": cls.__name__,
                    "cls.__doc__": cls.__doc__,
                    "params": params,
                    "type(params)": type(params),
                }.items()
            )

        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:  # type: ignore
            # MyTensor["dim1 dim2"]
            if isinstance(params, str):
                return default_jax_dtype[array_type, params]

            elif isinstance(params, tuple):
                if len(params) != 2:
                    raise Exception(
                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
                    )

                if isinstance(params[0], str):
                    # MyTensor["dim1 dim2", int]
                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]

                elif isinstance(params[0], tuple):
                    legacy_mode_.process(
                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
                        except_cls=Exception,
                    )
                    # MyTensor[("dim1", "dim2"), int]
                    shape_anot: list[str] = list()
                    for x in params[0]:
                        if isinstance(x, str):
                            shape_anot.append(x)
                        elif isinstance(x, int):
                            shape_anot.append(str(x))
                        elif isinstance(x, tuple):
                            shape_anot.append("".join(str(y) for y in x))
                        else:
                            raise Exception(
                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
                            )

                    return TYPE_TO_JAX_DTYPE[params[1]][
                        array_type, " ".join(shape_anot)
                    ]
            else:
                raise Exception(
                    f"unexpected type for params:\n{cls.param_info(params)}"
                )

    _BaseArray.__name__ = name

    if _BaseArray.__doc__ is None:
        _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"

    _BaseArray.__doc__ = _BaseArray.__doc__.format(
        default_jax_dtype=repr(default_jax_dtype),
        array_type=repr(array_type),
    )

    return _BaseArray


if typing.TYPE_CHECKING:
    # these class definitions are only used here to make pylint happy,
    # but they make mypy unhappy and there is no way to only run if not mypy
    # so, later on we have more ignores
    class ATensor(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()

    class NDArray(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()


ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]

NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]


def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert numpy dtype to torch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    else:
        return torch.from_numpy(np.array(0, dtype=dtype)).dtype


DTYPE_LIST: list = [
    *[
        bool,
        int,
        float,
    ],
    *[
        # ----------
        # pytorch
        # ----------
        # floats
        torch.float,
        torch.float32,
        torch.float64,
        torch.half,
        torch.double,
        torch.bfloat16,
        # complex
        torch.complex64,
        torch.complex128,
        # ints
        torch.int,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.long,
        torch.short,
        # simplest
        torch.uint8,
        torch.bool,
    ],
    *[
        # ----------
        # numpy
        # ----------
        # floats
        np.float16,
        np.float32,
        np.float64,
        np.half,
        np.single,
        np.double,
        # complex
        np.complex64,
        np.complex128,
        # ints
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.longlong,
        np.short,
        # simplest
        np.uint8,
        np.bool_,
    ],
]
"list of all the python, numpy, and torch numerical types I could think of"

if np.version.version < "2.0.0":
    DTYPE_LIST.extend([np.float_, np.int_])  # type: ignore[attr-defined]

DTYPE_MAP: dict = {
    **{str(x): x for x in DTYPE_LIST},
    **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"},
}
"mapping from string representations of types to their type"

TORCH_DTYPE_MAP: dict = {
    key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items()
}
"mapping from string representations of types to specifically torch types"

# no idea why we have to do this, smh
DTYPE_MAP["bool"] = np.bool_
TORCH_DTYPE_MAP["bool"] = torch.bool


TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
    "Adagrad": torch.optim.Adagrad,
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SparseAdam": torch.optim.SparseAdam,
    "Adamax": torch.optim.Adamax,
    "ASGD": torch.optim.ASGD,
    "LBFGS": torch.optim.LBFGS,
    "NAdam": torch.optim.NAdam,
    "RAdam": torch.optim.RAdam,
    "RMSprop": torch.optim.RMSprop,
    "Rprop": torch.optim.Rprop,
    "SGD": torch.optim.SGD,
}


def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[torch.Tensor] = [
        torch.full(
            (padded_length - tensor.shape[0],),
            pad_value,
            dtype=tensor.dtype,
            device=tensor.device,
        ),
        tensor,
    ]

    if rpad:
        temp.reverse()

    return torch.cat(temp)


def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
    return pad_tensor(tensor, padded_length, pad_value, rpad=False)


def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
    return pad_tensor(tensor, pad_length, pad_value, rpad=True)


def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[np.ndarray] = [
        np.full(
            (padded_length - array.shape[0],),
            pad_value,
            dtype=array.dtype,
        ),
        array,
    ]

    if rpad:
        temp.reverse()

    return np.concatenate(temp)


def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`"""
    return pad_array(array, padded_length, pad_value, rpad=False)


def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`"""
    return pad_array(array, pad_length, pad_value, rpad=True)


def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})


def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes"""
    return json.dumps(
        dotlist_to_nested_dict(
            {
                k: str(
                    tuple(v.shape)
                )  # to string, since indent wont play nice with tuples
                for k, v in d.items()
            }
        ),
        indent=2,
    )


class StateDictCompareError(AssertionError):
    """raised when state dicts don't match"""

    pass


class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""

    pass


class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""

    pass


class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""

    pass


def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two dicts of tensors

    # Parameters:

     - `d1 : dict`
     - `d2 : dict`
     - `rtol : float`
       (defaults to `1e-5`)
     - `atol : float`
       (defaults to `1e-8`)
     - `verbose : bool`
       (defaults to `True`)

    # Raises:

     - `StateDictKeysError` : keys don't match
     - `StateDictShapeError` : shapes don't match (but keys do)
     - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
    keys_diff_1: set = d1_keys - d2_keys
    keys_diff_2: set = d2_keys - d1_keys
    # sort sets for easier debugging
    symmetric_diff = set(sorted(symmetric_diff))
    keys_diff_1 = set(sorted(keys_diff_1))
    keys_diff_2 = set(sorted(keys_diff_2))
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if not len(symmetric_diff) == 0:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first
        if not v1.shape == v2.shape:
            shape_failed.append(k)
        else:
            # if shapes match, check values
            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
                vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if not len(shape_failed) == 0:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if not len(vals_failed) == 0:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )

``````{ end_of_file="muutils/tensor_utils.py" }

``````{ path="muutils/timeit_fancy.py"  }
"`timeit_fancy` is just a fancier version of timeit with more options"

from __future__ import annotations

import pstats
import timeit
import cProfile
from typing import Callable, Union, TypeVar, NamedTuple, Any
import warnings

from muutils.statcounter import StatCounter

T_return = TypeVar("T_return")


class FancyTimeitResult(NamedTuple):
    """return type of `timeit_fancy`"""

    timings: StatCounter
    return_value: T_return  # type: ignore[valid-type]
    profile: Union[pstats.Stats, None]


def timeit_fancy(
    cmd: Union[Callable[[], T_return], str],
    setup: Union[str, Callable[[], Any]] = lambda: None,
    repeats: int = 5,
    namespace: Union[dict[str, Any], None] = None,
    get_return: bool = True,
    do_profiling: bool = False,
) -> FancyTimeitResult:
    """
    Wrapper for `timeit` to get the fastest run of a callable with more customization options.

    Approximates the functionality of the %timeit magic or command line interface in a Python callable.

    # Parameters

    - `cmd: Callable[[], T_return] | str`
        The callable to time. If a string, it will be passed to `timeit.Timer` as the `stmt` argument.
    - `setup: str | Callable[[], Any]`
        The setup code to run before `cmd`. If a string, it will be passed to `timeit.Timer` as the `setup` argument.
    - `repeats: int`
        The number of times to run `cmd` to get a reliable measurement.
    - `namespace: dict[str, Any]`
        Passed to `timeit.Timer` constructor.
        If `cmd` or `setup` use local or global variables, they must be passed here. See `timeit` documentation for details.
    - `get_return: bool`
        Whether to capture the value returned from `cmd`. If `True`, the return value is included in the returned `FancyTimeitResult`.
        This is for speed and convenience so that `cmd` doesn't need to be run again in the calling scope if the return value is needed.
        (default: `True`)
    - `do_profiling: bool`
        Whether to return a `pstats.Stats` object in addition to the time and return value.
        (default: `False`)

    # Returns

    `FancyTimeitResult`, which is a NamedTuple with the following fields:

    - `timings: StatCounter`
        A `StatCounter` over the measured run times in seconds (one entry per repeat).
    - `return_value: T|None`
        The return value of `cmd` if `get_return` is `True`, otherwise `None`.
    - `profile: pstats.Stats|None`
        A `pstats.Stats` object if `do_profiling` is `True`, otherwise `None`.
    """
    timer: timeit.Timer = timeit.Timer(cmd, setup, globals=namespace)

    # Perform the timing
    times: list[float] = timer.repeat(repeats, 1)

    # Optionally capture the return value
    profile: pstats.Stats | None = None

    return_value: T_return | None = None
    if (get_return or do_profiling) and isinstance(cmd, str):
        warnings.warn(
            "Can't do profiling or get return value from `cmd` because it is a string."
            " If you want to get the return value, pass a callable instead.",
            UserWarning,
        )
    if (get_return or do_profiling) and not isinstance(cmd, str):
        # Optionally perform profiling
        if do_profiling:
            profiler = cProfile.Profile()
            profiler.enable()

        try:
            return_value = cmd()
        except TypeError as e:
            warnings.warn(
                f"Failed to get return value from `cmd` due to error (probably passing a string). will return `return_value=None`\n{e}",
            )

        if do_profiling:
            profiler.disable()
            profile = pstats.Stats(profiler).strip_dirs().sort_stats("cumulative")

    # reset the return value if it wasn't requested
    if not get_return:
        return_value = None

    return FancyTimeitResult(
        timings=StatCounter(times),
        return_value=return_value,
        profile=profile,
    )

``````{ end_of_file="muutils/timeit_fancy.py" }

``````{ path="muutils/validate_type.py"  }
"""experimental utility for validating types in python, see `validate_type`"""

from __future__ import annotations

from inspect import signature, unwrap
import types
import typing
import functools

# this is also for python <3.10 compatibility
_GenericAliasTypeNames: typing.List[str] = [
    "GenericAlias",
    "_GenericAlias",
    "_UnionGenericAlias",
    "_BaseGenericAlias",
]

_GenericAliasTypesList: list = [
    getattr(typing, name, None) for name in _GenericAliasTypeNames
]

GenericAliasTypes: tuple = tuple([t for t in _GenericAliasTypesList if t is not None])


class IncorrectTypeException(TypeError):
    pass


class TypeHintNotImplementedError(NotImplementedError):
    pass


class InvalidGenericAliasError(TypeError):
    pass


def _return_validation_except(
    return_val: bool, value: typing.Any, expected_type: typing.Any
) -> bool:
    if return_val:
        return True
    else:
        raise IncorrectTypeException(
            f"Expected {expected_type = } for {value = }",
            f"{type(value) = }",
            f"{type(value).__mro__ = }",
            f"{typing.get_origin(expected_type) = }",
            f"{typing.get_args(expected_type) = }",
            "\ndo --tb=long in pytest to see full trace",
        )
        return False


def _return_validation_bool(return_val: bool) -> bool:
    return return_val


def validate_type(
    value: typing.Any, expected_type: typing.Any, do_except: bool = False
) -> bool:
    """Validate that a `value` is of the `expected_type`

    # Parameters
    - `value`: the value to check the type of
    - `expected_type`: the type to check against. Not all types are supported
    - `do_except`: if `True`, raise an exception if the type is incorrect (instead of returning `False`)
        (default: `False`)

    # Returns
    - `bool`: `True` if the value is of the expected type, `False` otherwise.

    # Raises
    - `IncorrectTypeException(TypeError)`: if the type is incorrect and `do_except` is `True`
    - `TypeHintNotImplementedError(NotImplementedError)`: if the type hint is not implemented
    - `InvalidGenericAliasError(TypeError)`: if the generic alias is invalid

    use `typeguard` for a more robust solution: https://github.com/agronholm/typeguard
    """
    if expected_type is typing.Any:
        return True

    # set up the return function depending on `do_except`
    _return_func: typing.Callable[[bool], bool] = (
        # functools.partial doesn't hint the function signature
        functools.partial(  # type: ignore[assignment]
            _return_validation_except, value=value, expected_type=expected_type
        )
        if do_except
        else _return_validation_bool
    )

    # base type without args
    if isinstance(expected_type, type):
        try:
            # if you use args on a type like `dict[str, int]`, this will fail
            return _return_func(isinstance(value, expected_type))
        except TypeError as e:
            if isinstance(e, IncorrectTypeException):
                raise e

    origin: typing.Any = typing.get_origin(expected_type)
    args: tuple = typing.get_args(expected_type)

    # useful for debugging
    # print(f"{value = },   {expected_type = },   {origin = },   {args = }")
    UnionType = getattr(types, "UnionType", None)

    if (origin is typing.Union) or (  # this works in python <3.10
        False
        if UnionType is None  # return False if UnionType is not available
        else origin is UnionType  # return True if UnionType is available
    ):
        return _return_func(any(validate_type(value, arg) for arg in args))

    # generic alias, more complicated
    item_type: type
    if isinstance(expected_type, GenericAliasTypes):
        if origin is list:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, list))
            # incorrect number of args
            if len(args) != 1:
                raise InvalidGenericAliasError(
                    f"Too many arguments for list expected 1, got {args = },   {expected_type = },   {value = },   {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # check is list
            if not isinstance(value, list):
                return _return_func(False)
            # check all items in list are of the correct type
            item_type = args[0]
            return _return_func(all(validate_type(item, item_type) for item in value))

        if origin is dict:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, dict))
            # incorrect number of args
            if len(args) != 2:
                raise InvalidGenericAliasError(
                    f"Expected 2 arguments for dict, expected 2, got {args = },   {expected_type = },   {value = },   {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # check is dict
            if not isinstance(value, dict):
                return _return_func(False)
            # check all items in dict are of the correct type
            key_type: type = args[0]
            value_type: type = args[1]
            return _return_func(
                all(
                    validate_type(key, key_type) and validate_type(val, value_type)
                    for key, val in value.items()
                )
            )

        if origin is set:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, set))
            # incorrect number of args
            if len(args) != 1:
                raise InvalidGenericAliasError(
                    f"Expected 1 argument for Set, got {args = },   {expected_type = },   {value = },   {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # check is set
            if not isinstance(value, set):
                return _return_func(False)
            # check all items in set are of the correct type
            item_type = args[0]
            return _return_func(all(validate_type(item, item_type) for item in value))

        if origin is tuple:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, tuple))
            # check is tuple
            if not isinstance(value, tuple):
                return _return_func(False)
            # check correct number of items in tuple
            if len(value) != len(args):
                return _return_func(False)
            # check all items in tuple are of the correct type
            return _return_func(
                all(validate_type(item, arg) for item, arg in zip(value, args))
            )

        if origin is type:
            # no args
            if len(args) == 0:
                return _return_func(isinstance(value, type))
            # incorrect number of args
            if len(args) != 1:
                raise InvalidGenericAliasError(
                    f"Expected 1 argument for Type, got {args = },   {expected_type = },   {value = },   {origin = }",
                    f"{GenericAliasTypes = }",
                )
            # check is type
            item_type = args[0]
            if item_type in value.__mro__:
                return _return_func(True)
            else:
                return _return_func(False)

        # TODO: Callables, etc.

        raise TypeHintNotImplementedError(
            f"Unsupported generic alias {expected_type = } for {value = },   {origin = },   {args = }",
            f"{origin = }, {args = }",
            f"\n{GenericAliasTypes = }",
        )

    else:
        raise TypeHintNotImplementedError(
            f"Unsupported type hint {expected_type = } for {value = }",
            f"{origin = }, {args = }",
            f"\n{GenericAliasTypes = }",
        )


def get_fn_allowed_kwargs(fn: typing.Callable) -> typing.Set[str]:
    """Get the allowed kwargs for a function, raising an exception if the signature cannot be determined."""
    try:
        fn = unwrap(fn)
        params = signature(fn).parameters
    except ValueError as e:
        raise ValueError(
            f"Cannot retrieve signature for {fn.__name__ = } {fn = }: {str(e)}"
        ) from e

    return {
        param.name
        for param in params.values()
        if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
    }

``````{ end_of_file="muutils/validate_type.py" }

``````{ path="LICENSE"  }
                    GNU GENERAL PUBLIC LICENSE
                       Version 3, 29 June 2007

 Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
 Everyone is permitted to copy and distribute verbatim copies
 of this license document, but changing it is not allowed.

                            Preamble

  The GNU General Public License is a free, copyleft license for
software and other kinds of works.

  The licenses for most software and other practical works are designed
to take away your freedom to share and change the works.  By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.  We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors.  You can apply it to
your programs, too.

  When we speak of free software, we are referring to freedom, not
price.  Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.

  To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights.  Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.

  For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received.  You must make sure that they, too, receive
or can get the source code.  And you must show them these terms so they
know their rights.

  Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.

  For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software.  For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.

  Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so.  This is fundamentally incompatible with the aim of
protecting users' freedom to change the software.  The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable.  Therefore, we
have designed this version of the GPL to prohibit the practice for those
products.  If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.

  Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary.  To prevent this, the GPL assures that
patents cannot be used to render the program non-free.

  The precise terms and conditions for copying, distribution and
modification follow.

                       TERMS AND CONDITIONS

  0. Definitions.

  "This License" refers to version 3 of the GNU General Public License.

  "Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.

  "The Program" refers to any copyrightable work licensed under this
License.  Each licensee is addressed as "you".  "Licensees" and
"recipients" may be individuals or organizations.

  To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy.  The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.

  A "covered work" means either the unmodified Program or a work based
on the Program.

  To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy.  Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.

  To "convey" a work means any kind of propagation that enables other
parties to make or receive copies.  Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.

  An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License.  If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.

  1. Source Code.

  The "source code" for a work means the preferred form of the work
for making modifications to it.  "Object code" means any non-source
form of a work.

  A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.

  The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form.  A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.

  The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities.  However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work.  For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.

  The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.

  The Corresponding Source for a work in source code form is that
same work.

  2. Basic Permissions.

  All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met.  This License explicitly affirms your unlimited
permission to run the unmodified Program.  The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work.  This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.

  You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force.  You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright.  Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.

  Conveying under any other circumstances is permitted solely under
the conditions stated below.  Sublicensing is not allowed; section 10
makes it unnecessary.

  3. Protecting Users' Legal Rights From Anti-Circumvention Law.

  No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.

  When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.

  4. Conveying Verbatim Copies.

  You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.

  You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.

  5. Conveying Modified Source Versions.

  You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:

    a) The work must carry prominent notices stating that you modified
    it, and giving a relevant date.

    b) The work must carry prominent notices stating that it is
    released under this License and any conditions added under section
    7.  This requirement modifies the requirement in section 4 to
    "keep intact all notices".

    c) You must license the entire work, as a whole, under this
    License to anyone who comes into possession of a copy.  This
    License will therefore apply, along with any applicable section 7
    additional terms, to the whole of the work, and all its parts,
    regardless of how they are packaged.  This License gives no
    permission to license the work in any other way, but it does not
    invalidate such permission if you have separately received it.

    d) If the work has interactive user interfaces, each must display
    Appropriate Legal Notices; however, if the Program has interactive
    interfaces that do not display Appropriate Legal Notices, your
    work need not make them do so.

  A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit.  Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.

  6. Conveying Non-Source Forms.

  You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:

    a) Convey the object code in, or embodied in, a physical product
    (including a physical distribution medium), accompanied by the
    Corresponding Source fixed on a durable physical medium
    customarily used for software interchange.

    b) Convey the object code in, or embodied in, a physical product
    (including a physical distribution medium), accompanied by a
    written offer, valid for at least three years and valid for as
    long as you offer spare parts or customer support for that product
    model, to give anyone who possesses the object code either (1) a
    copy of the Corresponding Source for all the software in the
    product that is covered by this License, on a durable physical
    medium customarily used for software interchange, for a price no
    more than your reasonable cost of physically performing this
    conveying of source, or (2) access to copy the
    Corresponding Source from a network server at no charge.

    c) Convey individual copies of the object code with a copy of the
    written offer to provide the Corresponding Source.  This
    alternative is allowed only occasionally and noncommercially, and
    only if you received the object code with such an offer, in accord
    with subsection 6b.

    d) Convey the object code by offering access from a designated
    place (gratis or for a charge), and offer equivalent access to the
    Corresponding Source in the same way through the same place at no
    further charge.  You need not require recipients to copy the
    Corresponding Source along with the object code.  If the place to
    copy the object code is a network server, the Corresponding Source
    may be on a different server (operated by you or a third party)
    that supports equivalent copying facilities, provided you maintain
    clear directions next to the object code saying where to find the
    Corresponding Source.  Regardless of what server hosts the
    Corresponding Source, you remain obligated to ensure that it is
    available for as long as needed to satisfy these requirements.

    e) Convey the object code using peer-to-peer transmission, provided
    you inform other peers where the object code and Corresponding
    Source of the work are being offered to the general public at no
    charge under subsection 6d.

  A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.

  A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling.  In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage.  For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product.  A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.

  "Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source.  The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.

  If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information.  But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).

  The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed.  Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.

  Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.

  7. Additional Terms.

  "Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law.  If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.

  When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it.  (Additional permissions may be written to require their own
removal in certain cases when you modify the work.)  You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.

  Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:

    a) Disclaiming warranty or limiting liability differently from the
    terms of sections 15 and 16 of this License; or

    b) Requiring preservation of specified reasonable legal notices or
    author attributions in that material or in the Appropriate Legal
    Notices displayed by works containing it; or

    c) Prohibiting misrepresentation of the origin of that material, or
    requiring that modified versions of such material be marked in
    reasonable ways as different from the original version; or

    d) Limiting the use for publicity purposes of names of licensors or
    authors of the material; or

    e) Declining to grant rights under trademark law for use of some
    trade names, trademarks, or service marks; or

    f) Requiring indemnification of licensors and authors of that
    material by anyone who conveys the material (or modified versions of
    it) with contractual assumptions of liability to the recipient, for
    any liability that these contractual assumptions directly impose on
    those licensors and authors.

  All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10.  If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term.  If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.

  If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.

  Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.

  8. Termination.

  You may not propagate or modify a covered work except as expressly
provided under this License.  Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).

  However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.

  Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.

  Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License.  If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.

  9. Acceptance Not Required for Having Copies.

  You are not required to accept this License in order to receive or
run a copy of the Program.  Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance.  However,
nothing other than this License grants you permission to propagate or
modify any covered work.  These actions infringe copyright if you do
not accept this License.  Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.

  10. Automatic Licensing of Downstream Recipients.

  Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License.  You are not responsible
for enforcing compliance by third parties with this License.

  An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations.  If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.

  You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License.  For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.

  11. Patents.

  A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based.  The
work thus licensed is called the contributor's "contributor version".

  A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version.  For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.

  Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.

  In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement).  To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.

  If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients.  "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.

  If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.

  A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License.  You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.

  Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.

  12. No Surrender of Others' Freedom.

  If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License.  If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all.  For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.

  13. Use with the GNU Affero General Public License.

  Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work.  The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.

  14. Revised Versions of this License.

  The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time.  Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.

  Each version is given a distinguishing version number.  If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation.  If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.

  If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.

  Later license versions may give you additional or different
permissions.  However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.

  15. Disclaimer of Warranty.

  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE.  THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.

  16. Limitation of Liability.

  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.

  17. Interpretation of Sections 15 and 16.

  If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.

                     END OF TERMS AND CONDITIONS

            How to Apply These Terms to Your New Programs

  If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.

  To do so, attach the following notices to the program.  It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.

    <one line to give the program's name and a brief idea of what it does.>
    Copyright (C) <year>  <name of author>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.

Also add information on how to contact you by electronic and paper mail.

  If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:

    <program>  Copyright (C) <year>  <name of author>
    This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
    This is free software, and you are welcome to redistribute it
    under certain conditions; type `show c' for details.

The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License.  Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".

  You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.

  The GNU General Public License does not permit incorporating your program
into proprietary programs.  If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library.  If this is what you want to do, use the GNU Lesser General
Public License instead of this License.  But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

``````{ end_of_file="LICENSE" }

``````{ path="README.md"  }
[![PyPI](https://img.shields.io/pypi/v/muutils)](https://pypi.org/project/muutils/)
![PyPI - Downloads](https://img.shields.io/pypi/dm/muutils)
[![docs](https://img.shields.io/badge/docs-latest-blue)](https://miv.name/muutils)

[![Checks](https://github.com/mivanit/muutils/actions/workflows/checks.yml/badge.svg)](https://github.com/mivanit/muutils/actions/workflows/checks.yml)
[![Docs](https://github.com/mivanit/muutils/actions/workflows/make-docs.yml/badge.svg)](https://github.com/mivanit/muutils/actions/workflows/make-docs.yml)
[![Coverage](docs/coverage/coverage.svg)](docs/coverage/html/)

![GitHub commits](https://img.shields.io/github/commit-activity/t/mivanit/muutils)
![GitHub commit activity](https://img.shields.io/github/commit-activity/m/mivanit/muutils)
![GitHub closed pull requests](https://img.shields.io/github/issues-pr-closed/mivanit/muutils)
![code size, bytes](https://img.shields.io/github/languages/code-size/mivanit/muutils)
<!-- ![Lines of code](https://img.shields.io/tokei/lines/github.com/mivanit/muutils) -->

`muutils`, stylized as "$\mu$utils" or "μutils", is a collection of miscellaneous python utilities, meant to be small and to have no dependencies outside of the python standard library.

# installation

PyPi: [muutils](https://pypi.org/project/muutils/)

```
pip install muutils
```

Note that for using `mlutils`, `tensor_utils`, `nbutils.configure_notebook`, or the array serialization features of `json_serialize`, you will need to install with optional `array` dependencies:
```
pip install muutils[array]
```

# documentation

[**hosted html docs:**](https://miv.name/muutils) https://miv.name/muutils

- [single-page html docs](https://miv.name/muutils/combined/muutils.html) [(absolute source link)](https://github.com/mivanit/muutils/tree/main/docs/combined/muutils.html)
- [single-page markdown docs](https://miv.name/muutils/combined/muutils.md) [(absolute source link)](https://github.com/mivanit/muutils/tree/main/docs/combined/muutils.md)
- Test coverage: [![Test Coverage](https://miv.name/muutils/coverage/coverage.svg)](https://miv.name/muutils/coverage/html/) [webpage](https://miv.name/muutils/coverage/html/) [(absolute source link)](https://github.com/mivanit/muutils/tree/main/docs/coverage/html/) [(plain text)](https://github.com/mivanit/muutils/tree/main/docs/coverage/coverage.txt)

# modules

## [`statcounter`](https://github.com/mivanit/muutils/tree/main/muutils/statcounter.py)

an extension of `collections.Counter` that provides "smart" computation of stats (mean, variance, median, other percentiles) from the counter object without using `Counter.elements()`
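
for example (a minimal sketch; it assumes `StatCounter` can be built from an iterable like a normal `Counter` and exposes `mean()` / `median()` as methods, per the description above):

```python
from muutils.statcounter import StatCounter

# counts elements just like collections.Counter
c = StatCounter([1, 2, 2, 3, 3, 3])

# stats are computed directly from the counts,
# without expanding the counter via Counter.elements()
print(c.mean())
print(c.median())
```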

## [`dictmagic`](https://github.com/mivanit/muutils/tree/main/muutils/dictmagic.py)

has utilities for working with dictionaries, like:

  - converting dotlist-dictionaries to nested dictionaries and back:
      ```python
      >>> dotlist_to_nested_dict({'a.b.c': 1, 'a.b.d': 2, 'a.e': 3})
      {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
      >>> nested_dict_to_dotlist({'a': {'b': {'c': 1, 'd': 2}, 'e': 3}})
      {'a.b.c': 1, 'a.b.d': 2, 'a.e': 3}
      ```
  - `DefaulterDict` which works like a `defaultdict` but can generate the default value based on the key (a minimal sketch follows after this list)
  - `condense_tensor_dict` takes a dict of dotlist-tensors and gives a more human-readable summary:
      ```python
      >>> model = MyGPT()
      >>> print(condense_tensor_dict(model.named_parameters(), 'yaml'))
      ```
      ```yaml
      embed:
        W_E: (50257, 768)
      pos_embed:
        W_pos: (1024, 768)
      blocks:
        '[0-11]':
          attn:
            '[W_Q, W_K, W_V]': (12, 768, 64)
            W_O: (12, 64, 768)
            '[b_Q, b_K, b_V]': (12, 64)
            b_O: (768,)
          <...>
      ```
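
the `DefaulterDict` sketch mentioned above (this assumes the factory is passed as `default_factory` and is called with the missing key; check `muutils.dictmagic` for the exact signature):

```python
from muutils.dictmagic import DefaulterDict

# unlike defaultdict, the factory receives the missing key itself
d: DefaulterDict = DefaulterDict(default_factory=lambda key: key.upper())

print(d["hello"])  # generates the default from the key, presumably "HELLO"
d["x"] = 1         # normal assignment still works
```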

## [`kappa`](https://github.com/mivanit/muutils/tree/main/muutils/kappa.py)

Anonymous getitem, so you can do things like

```python
>>> k = Kappa(lambda x: x**2)
>>> k[2]
4
```

## [`sysinfo`](https://github.com/mivanit/muutils/tree/main/muutils/sysinfo.py)

utility for getting a bunch of system information. useful for logging.
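
a hedged sketch of how this might be used -- the `SysInfo.get_all()` classmethod here is an assumption based on the module's purpose, not a documented signature:

```python
from muutils.sysinfo import SysInfo

# gather system info (python version, platform, packages, etc.) as a plain dict,
# which can then be dumped into a log file or attached to experiment metadata
info: dict = SysInfo.get_all()
print(sorted(info.keys()))
```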

## [`misc`](https://github.com/mivanit/muutils/tree/main/muutils/misc)

contains a few utilities (a short usage sketch follows the list):

- `stable_hash()` uses `hashlib.sha256` to compute a hash of an object that is stable across runs of python
- `list_join` and `list_split`, which behave like `str.join` and `str.split` but for lists
- `sanitize_fname` and `dict_to_filename` for simplifying the creation of unique filenames
- `shorten_numerical_to_str()` and `str_to_numeric`, which turn numbers like `123456789` into `"123M"` and back
- `freeze`, which prevents an object from being modified. Also see [gelidum](https://github.com/diegojromerolopez/gelidum/)
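
the sketch (assuming these names are re-exported from `muutils.misc`; exact signatures aren't spelled out above, so the comments stay approximate):

```python
from muutils.misc import list_split, shorten_numerical_to_str, stable_hash

# sha256-based hash that is stable across interpreter runs, unlike builtin hash()
h = stable_hash("some string to hash")

# like str.split, but over a list: split on a separator element
parts = list_split([1, 2, -1, 3, 4, -1, 5], -1)
# presumably [[1, 2], [3, 4], [5]], with separators dropped like str.split

# human-readable shortening of large numbers, and back again
print(shorten_numerical_to_str(123456789))  # "123M"
```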


## [`nbutils`](https://github.com/mivanit/muutils/tree/main/muutils/nbutils)

contains utilities for working with jupyter notebooks, such as:

- quickly converting notebooks to python scripts (and running those scripts) for testing in CI
- configuring notebooks, to make it easier to switch between figure output formats, locations, and more
- shorthand for displaying mermaid diagrams and TeX (sketched below)
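
a sketch of the mermaid shorthand in a notebook cell (this assumes the helper is named `mm` and takes the diagram source as a string):

```python
from muutils.nbutils.mermaid import mm

# render a small mermaid diagram inline in a jupyter notebook
mm("""
graph TD;
    load_data --> preprocess;
    preprocess --> train;
""")
```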

## [`json_serialize`](https://github.com/mivanit/muutils/tree/main/muutils/json_serialize)

a tool for serializing and loading arbitrary python objects into json. plays nicely with [`ZANJ`](https://github.com/mivanit/ZANJ/)
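
the usual entry point is the serializable-dataclass pattern (a minimal sketch using names exported from `muutils.json_serialize`; field defaults and nesting behave like regular dataclasses, but check the module docs for the full API):

```python
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

@serializable_dataclass
class Config(SerializableDataclass):
    name: str
    layers: int = serializable_field(default=2)

cfg = Config(name="tiny")
data = cfg.serialize()    # plain json-compatible dict
cfg2 = Config.load(data)  # round-trips back to an equal instance
assert cfg == cfg2
```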


## [`tensor_utils`](https://github.com/mivanit/muutils/tree/main/muutils/tensor_utils.py)

contains minor utilities for working with pytorch tensors and numpy arrays, mostly for making type conversions easier

## [`group_equiv`](https://github.com/mivanit/muutils/tree/main/muutils/group_equiv.py)

groups elements from a sequence according to a given equivalence relation, without assuming that the equivalence relation obeys the transitive property
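
a hedged sketch (assuming the grouping function is `group_by_equivalence` and takes the sequence plus a binary predicate):

```python
from muutils.group_equiv import group_by_equivalence

# "close to each other" is not transitive: 1~2 and 2~3, but maybe not 1~3.
# this is exactly the kind of relation this module is meant to handle
def close(a: float, b: float) -> bool:
    return abs(a - b) <= 1

groups = group_by_equivalence([1, 2, 3, 10, 11, 20], close)
print(groups)  # something like [[1, 2, 3], [10, 11], [20]]
```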



## [`jsonlines`](https://github.com/mivanit/muutils/tree/main/muutils/jsonlines.py)

an extremely simple utility for reading/writing `jsonl` files
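
a minimal sketch (assuming `jsonl_load` / `jsonl_write` helpers that take a path and a list of json-compatible objects; check the module for the exact names and gzip options):

```python
from muutils.jsonlines import jsonl_load, jsonl_write

records = [{"step": 1, "loss": 0.5}, {"step": 2, "loss": 0.3}]

# write one json object per line
jsonl_write("log.jsonl", records)

# read them back as a list of dicts
assert jsonl_load("log.jsonl") == records
```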

## [`ZANJ`](https://github.com/mivanit/ZANJ/)

is a human-readable and simple format for ML models, datasets, and arbitrary objects. It's built around a zip file containing `json` and `npy` files, and has been spun off into its [own project](https://github.com/mivanit/ZANJ/).

There are a couple of work-in-progress utilities in [`_wip`](https://github.com/mivanit/muutils/tree/main/muutils/_wip/) that aren't ready for anything; then again, nothing in this repo is guaranteed to be suitable for production. Use at your own risk!

``````{ end_of_file="README.md" }

``````{ path="makefile" processed_with="makefile_recipes" }
# first/default target is help
.PHONY: default
default: help
	...

# this recipe is weird. we need it because:
# - a one liner for getting the version with toml is unwieldy, and using regex is fragile
# - using $$SCRIPT_GET_VERSION within $(shell ...) doesn't work because of escaping issues
# - trying to write to the file inside the `gen-version-info` recipe doesn't work, since
# 	shell eval happens before our `python -c ...` gets run, so `cat` doesn't see the new file
.PHONY: write-proj-version
write-proj-version:
	...

# gets version info from $(PYPROJECT), last version from $(LAST_VERSION_FILE), and python version
# uses just `python` for everything except getting the python version. no echo here, because this is "private"
.PHONY: gen-version-info
gen-version-info: write-proj-version
	...

# getting commit log since the tag specified in $(LAST_VERSION_FILE)
# will write to $(COMMIT_LOG_FILE)
# when publishing, the contents of $(COMMIT_LOG_FILE) will be used as the tag description (but can be edited during the process)
# no echo here, because this is "private"
.PHONY: gen-commit-log
gen-commit-log: gen-version-info
	...

# force the version info to be read, printing it out
# also force the commit log to be generated, and cat it out
.PHONY: version
version: gen-commit-log
	@echo "Current version is $(PROJ_VERSION), last auto-uploaded version is $(LAST_VERSION)"
	...

.PHONY: setup
setup: dep-check
	@echo "install and update via uv"
	...

.PHONY: dep-check-torch
dep-check-torch:
	@echo "see if torch is installed, and which CUDA version and devices it sees"
	...

.PHONY: dep
dep:
	@echo "Exporting dependencies as per $(PYPROJECT) section 'tool.uv-exports.exports'"
	...

.PHONY: dep-check
dep-check:
	@echo "Checking that exported requirements are up to date"
	...

.PHONY: dep-clean
dep-clean:
	@echo "clean up lock files, .venv, and requirements files"
	...

# extra tests with python >=3.10 type hints
.PHONY: gen-extra-tests
gen-extra-tests:
	...

# runs ruff and pycln to format the code
.PHONY: format
format:
	@echo "format the source code"
	...

# runs ruff and pycln to check if the code is formatted correctly
.PHONY: format-check
format-check:
	@echo "check if the source code is formatted correctly"
	...

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# runs type checks with mypy
# at some point, need to add back --check-untyped-defs to mypy call
# but it complains when we specify arguments by keyword where positional is fine
# not sure how to fix this
.PHONY: typing
typing: gen-extra-tests
	@echo "running type checks"
	...

# generates a report of the mypy output
.PHONY: typing-report
typing-report: clean gen-extra-tests
	@echo "generate a report of the type check output -- errors per file"
	...

.PHONY: test
test: clean gen-extra-tests
	@echo "running tests"
	...

.PHONY: check
check: clean format-check test typing
	@echo "run format checks, tests, and typing checks"
	...

# generates a whole tree of documentation in html format.
# see `$(MAKE_DOCS_SCRIPT_PATH)` and the templates in `$(DOCS_RESOURCES_DIR)/templates/html/` for more info
.PHONY: docs-html
docs-html:
	@echo "generate html docs"
	...

# instead of a whole website, generates a single markdown file with all docs using the templates in `$(DOCS_RESOURCES_DIR)/templates/markdown/`.
# this is useful if you want to have a copy that you can grep/search, but those docs are much messier.
# docs-combined will use pandoc to convert them to other formats.
.PHONY: docs-md
docs-md:
	@echo "generate combined (single-file) docs in markdown"
	...

# after running docs-md, this will convert the combined markdown file to other formats:
# gfm (github-flavored markdown), plain text, and html
# requires pandoc in path, pointed to by $(PANDOC)
# pdf output would be nice but requires other deps
.PHONY: docs-combined
docs-combined: docs-md
	@echo "generate combined (single-file) docs in markdown and convert to other formats"
	...

# generates coverage reports as html and text with `pytest-cov`, and a badge with `coverage-badge`
# if `.coverage` is not found, will run tests first
# also removes the `.gitignore` file that `coverage html` creates, since we count that as part of the docs
.PHONY: cov
cov:
	@echo "generate coverage reports"
	...

# runs the coverage report, then the docs, then the combined docs
.PHONY: docs
docs: cov docs-html docs-combined todo lmcat
	@echo "generate all documentation and coverage reports"
	...

# removes all generated documentation files, but leaves everything in `$DOCS_RESOURCES_DIR`
# and leaves things defined in `pyproject.toml:tool.makefile.docs.no_clean`
# (templates, svg, css, make_docs.py script)
# distinct from `make clean`
.PHONY: docs-clean
docs-clean:
	@echo "remove generated docs except resources"
	...

.PHONY: todo
todo:
	@echo "get all TODO's from the code"
	...

.PHONY: lmcat-tree
lmcat-tree:
	@echo "show in console the lmcat tree view"
	...

.PHONY: lmcat
lmcat:
	@echo "write the lmcat full output to pyproject.toml:[tool.lmcat.output]"
	...

# verifies that the current branch is $(PUBLISH_BRANCH) and that git is clean
# used before publishing
.PHONY: verify-git
verify-git: 
	@echo "checking git status"
	...

.PHONY: build
build: 
	@echo "build the package"
	...

# gets the commit log, checks everything, builds, and then publishes with twine
# will ask the user to confirm the new version number (and this allows for editing the tag info)
# will also print the contents of $(PYPI_TOKEN_FILE) to the console for the user to copy and paste in when prompted by twine
.PHONY: publish
publish: gen-commit-log check build verify-git version gen-version-info
	@echo "run all checks, build, and then publish"
	...

# cleans up temp files from formatter, type checking, tests, coverage
# removes all built files
# removes $(TESTS_TEMP_DIR) to remove temporary test files
# recursively removes all `__pycache__` directories and `*.pyc` or `*.pyo` files
# distinct from `make docs-clean`, which only removes generated documentation files
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# slight modification in last line for extra tests
.PHONY: clean
clean:
	@echo "clean up temporary files"
	...

.PHONY: clean-all
clean-all: clean docs-clean dep-clean
	@echo "clean up all temporary files, dep files, venv, and generated docs"
	...

.PHONY: info
info: gen-version-info
	@echo "# makefile variables"
	...

.PHONY: info-long
info-long: info
	@echo "# other variables"
	...

# immediately print out the help targets, and then local variables (but those take a bit longer)
.PHONY: help
help: help-targets info
	@echo -n ""
	...

``````{ end_of_file="makefile" }

``````{ path="pyproject.toml"  }
# metadata
# ==================================================
[project]
	name = "muutils"
	version = "0.8.12"
	description = "miscellaneous python utilities"
	readme = "README.md"
	requires-python = ">=3.8"
	license = { text = "GPL-3.0-only" }
	authors = [
		{ name = "mivanit", email = "mivanits@umich.edu" }
	]
	classifiers = [
		"Programming Language :: Python :: 3.8",
		"Programming Language :: Python :: 3.9",
		"Programming Language :: Python :: 3.10",
		"Programming Language :: Python :: 3.11",
		"Programming Language :: Python :: 3.12",
		"Programming Language :: Python :: 3.13",
		"Development Status :: 4 - Beta",
		"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
		"Operating System :: OS Independent",
		"Topic :: Utilities",
		"Typing :: Typed",
	]

	dependencies = [] # no required deps!

[project.urls]
	Homepage = "https://miv.name/muutils"
	Repository = "https://github.com/mivanit/muutils"
	Documentation = "https://miv.name/muutils/"
	Issues = "https://github.com/mivanit/muutils/issues"

# dependencies
# ==================================================

[project.optional-dependencies]
	array = [
		"numpy>=1.24.4; python_version < '3.9'",
		"numpy>1.24.4; python_version >= '3.9'",
		"torch>=1.13.1,<2.5.0; python_version < '3.9'",
		"torch>=1.13.1; python_version >= '3.9' and python_version < '3.13'",
		"torch>=2.5.0; python_version >= '3.13' and python_version < '3.14'",
		"jaxtyping>=0.2.12",
	]

	# special group for CI, where we install cpu torch separately
	array_no_torch = [
		"numpy>=1.24.4; python_version < '3.9'",
		"numpy>1.24.4; python_version >= '3.9'",
		"jaxtyping>=0.2.12",
	]

	notebook = [
		"ipython>=8.0.0",
	]

	parallel = [
		"multiprocess>=0.70.17",
		"tqdm>=4.67.1",
	]

[dependency-groups]
	dev = [
		# typing
		"mypy>=1.0.1; python_version < '3.9'",
		"mypy>=1.15; python_version >= '3.9'",
		"typing-extensions; python_version < '3.11'",
		"beartype>=0.14.1",
		"ty",
		# tests & coverage
		"pytest>=8.2.2",
		"pytest-cov>=4.1.0",
		"coverage-badge>=1.1.0",
		"setuptools>=78.1.1; python_version >= '3.9'", # https://github.com/mivanit/muutils/security/dependabot/31
		# for testing plotting and notebooks
		"ipykernel",
		"jupyter",
		# for jupyter
		"h11>=0.16.0", # https://github.com/mivanit/muutils/security/dependabot/23
		"tornado>=6.5; python_version >= '3.9'", # https://github.com/mivanit/muutils/security/dependabot/33
		# plotting
		"pandas",
		"matplotlib>=3.0.0",
		"plotly>=5.0.0",
		"beautifulsoup4",
		# generating docs
		"pdoc>=14.6.0",
		# https://github.com/mivanit/muutils/security/dependabot/7
		"jinja2>=3.1.6",
		# lmcat -- a custom library. not exactly docs, but lets an LLM see all the code
		"lmcat>=0.2.0; python_version >= '3.11'",
		# tomli since no tomlib in python < 3.11
		"tomli>=2.1.0; python_version < '3.11'",
		# twine dep
		"twine",
	]
	lint = [
		# lint
		"pycln>=2.1.3",
		"ruff>=0.4.8",
	]

# build system and tooling configuration
# ==================================================

[build-system]
	requires = ["hatchling"]
	build-backend = "hatchling.build"

[tool.pytest.ini_options]
	filterwarnings = [
		"ignore::muutils.nbutils.configure_notebook.UnknownFigureFormatWarning", # don't show warning for unknown figure format
		"ignore::muutils.nbutils.configure_notebook.PlotlyNotInstalledWarning", # don't show warning for missing plotly
		"ignore::muutils.json_serialize.serializable_dataclass.ZanjMissingWarning", # don't show warning for missing zanj (can't have as a dep since zanj depends on muutils)
		"ignore: PEP 484 type hint*:beartype.roar._roarwarn.BeartypeDecorHintPep585DeprecationWarning",
	]
	addopts = "--jaxtyping-packages=beartype.beartype"

[tool.ruff]
	# Exclude the directories specified in the global excludes
	exclude = ["tests/input_data", "tests/junk_data", "_wip/"]
	[tool.ruff.lint.per-file-ignores]
		"muutils/tensor_info.py" = [
			"E701", # multiple statements on one line (colon)
		]
		"tests/unit/math/test_matrix_powers_torch.py" = [
			"F722", # jaxtyping stuff
		]
		"muutils/math/matrix_powers.py" = [
			"F722", # jaxtyping stuff
		]

[tool.pycln]
	all = true
	exclude = ["tests/input_data", "tests/junk_data", "_wip/"]

[tool.mypy]
	exclude = [
		# tests
		"tests/input_data",
		"tests/junk_data",
		"tests/_temp/",
		# wip stuff
		"_wip/",
		# not our problem
		"docs/resources/make_docs.py",
	]
	show_error_codes = true
	# we disable this in the makefile for old versions
	check_untyped_defs = true

[tool.lmcat]
	output = "docs/other/lmcat.txt" # changing this might mean it won't be accessible from the docs
	ignore_patterns = [
		"docs/**",
		".venv/**",
		".git/**",
		".meta/**",
		"uv.lock",
		".ruff_cache/**",
		".github/ISSUE_TEMPLATE/**",
		"_wip/**",
		"sweep.yaml",
		# there are... a lot of tests. we usually don't need to put these in lmcat
		"tests/**",
	]
	[tool.lmcat.glob_process]
		"[mM]akefile" = "makefile_recipes"



# [tool.makefile]
# =================================================================
[tool.makefile.docs]
	warnings_ignore = [
		"Error parsing type annotation .* for muutils\\..*\\. Import of np failed:",
		"Error parsing type annotation .* for muutils\\..*\\. Import of JsonSerializer failed:",
		"Error parsing type annotation .* for muutils\\..*\\. Import of StatCounter failed:",
		"Error parsing type annotation .* for muutils\\..*\\. Import of Union failed:"
	]

# Custom export configurations
[tool.makefile.uv-exports]
	args = [
		"--no-hashes"
	]
	exports = [
		# no groups, no extras, just the base dependencies
		{ name = "base", groups = false, extras = false },
		# all extras but no groups
		{ name = "extras", groups = false, extras = true },
		# include the dev group (this is the default behavior)
		{ name = "dev", groups = true },
		# only the lint group -- custom options for this
		{ name = "lint", options = ["--only-group", "lint"] },
		# all groups and extras
		{ name = "all", filename="requirements.txt", groups = true, extras=true },
		# all groups and extras, a different way
		{ name = "all", groups = true, options = ["--all-extras"] },
	]



[tool.makefile.inline-todo]
	search_dir = "."
	out_file_base = "docs/other/todo-inline.md" # changing this might mean it won't be accessible from the docs
	context_lines = 5
	extensions = ["py", "md"]
	tags = ["CRIT", "TODO", "FIXME", "HACK", "BUG", "NOTE"]
	exclude = [
		"docs/**",
		".venv/**",
		"scripts/get_todos.py",
		"_wip/**",
	]
	[tool.inline-todo.tag_label_map]
		NOTE = "documentation"
		CRIT = "bug"
		TODO = "enhancement"
		FIXME = "bug"
		BUG = "bug"
		HACK = "enhancement"
``````{ end_of_file="pyproject.toml" }