docs for muutils v0.8.12
View Source on GitHub

muutils.tensor_utils

utilities for working with tensors and arrays.

notably:

  • TYPE_TO_JAX_DTYPE : a mapping from python, numpy, and torch types to jaxtyping types
  • DTYPE_MAP mapping string representations of types to their type
  • TORCH_DTYPE_MAP mapping string representations of types to torch types
  • compare_state_dicts for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match

  1"""utilities for working with tensors and arrays.
  2
  3notably:
  4
  5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types
  6- `DTYPE_MAP` mapping string representations of types to their type
  7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types
  8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match
  9
 10"""
 11
 12from __future__ import annotations
 13
 14import json
 15import typing
 16
 17import jaxtyping
 18import numpy as np
 19import torch
 20
 21from muutils.errormode import ErrorMode
 22from muutils.dictmagic import dotlist_to_nested_dict
 23
 24# pylint: disable=missing-class-docstring
 25
 26
 27TYPE_TO_JAX_DTYPE: dict = {
 28    float: jaxtyping.Float,
 29    int: jaxtyping.Int,
 30    jaxtyping.Float: jaxtyping.Float,
 31    jaxtyping.Int: jaxtyping.Int,
 32    # bool
 33    bool: jaxtyping.Bool,
 34    jaxtyping.Bool: jaxtyping.Bool,
 35    np.bool_: jaxtyping.Bool,
 36    torch.bool: jaxtyping.Bool,
 37    # numpy float
 38    np.float16: jaxtyping.Float,
 39    np.float32: jaxtyping.Float,
 40    np.float64: jaxtyping.Float,
 41    np.half: jaxtyping.Float,
 42    np.single: jaxtyping.Float,
 43    np.double: jaxtyping.Float,
 44    # numpy int
 45    np.int8: jaxtyping.Int,
 46    np.int16: jaxtyping.Int,
 47    np.int32: jaxtyping.Int,
 48    np.int64: jaxtyping.Int,
 49    np.longlong: jaxtyping.Int,
 50    np.short: jaxtyping.Int,
 51    np.uint8: jaxtyping.Int,
 52    # torch float
 53    torch.float: jaxtyping.Float,
 54    torch.float16: jaxtyping.Float,
 55    torch.float32: jaxtyping.Float,
 56    torch.float64: jaxtyping.Float,
 57    torch.half: jaxtyping.Float,
 58    torch.double: jaxtyping.Float,
 59    torch.bfloat16: jaxtyping.Float,
 60    # torch int
 61    torch.int: jaxtyping.Int,
 62    torch.int8: jaxtyping.Int,
 63    torch.int16: jaxtyping.Int,
 64    torch.int32: jaxtyping.Int,
 65    torch.int64: jaxtyping.Int,
 66    torch.long: jaxtyping.Int,
 67    torch.short: jaxtyping.Int,
 68}
 69"dict mapping python, numpy, and torch types to `jaxtyping` types"
 70
 71# we check for version here, so it shouldn't error
 72if np.version.version < "2.0.0":
 73    TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float  # type: ignore[attr-defined]
 74    TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int  # type: ignore[attr-defined]
 75
 76
 77# TODO: add proper type annotations to this signature
 78# TODO: maybe get rid of this altogether?
 79def jaxtype_factory(
 80    name: str,
 81    array_type: type,
 82    default_jax_dtype=jaxtyping.Float,
 83    legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN,
 84) -> type:
 85    """usage:
 86    ```
 87    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
 88    x: ATensor["dim1 dim2", np.float32]
 89    ```
 90    """
 91    legacy_mode_ = ErrorMode.from_any(legacy_mode)
 92
 93    class _BaseArray:
 94        """jaxtyping shorthand
 95        (backwards compatible with older versions of muutils.tensor_utils)
 96
 97        default_jax_dtype = {default_jax_dtype}
 98        array_type = {array_type}
 99        """
100
101        def __new__(cls, *args, **kwargs):
102            raise TypeError("Type FArray cannot be instantiated.")
103
104        def __init_subclass__(cls, *args, **kwargs):
105            raise TypeError(f"Cannot subclass {cls.__name__}")
106
107        @classmethod
108        def param_info(cls, params) -> str:
109            """useful for error printing"""
110            return "\n".join(
111                f"{k} = {v}"
112                for k, v in {
113                    "cls.__name__": cls.__name__,
114                    "cls.__doc__": cls.__doc__,
115                    "params": params,
116                    "type(params)": type(params),
117                }.items()
118            )
119
120        @typing._tp_cache  # type: ignore
121        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:  # type: ignore
122            # MyTensor["dim1 dim2"]
123            if isinstance(params, str):
124                return default_jax_dtype[array_type, params]
125
126            elif isinstance(params, tuple):
127                if len(params) != 2:
128                    raise Exception(
129                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
130                    )
131
132                if isinstance(params[0], str):
133                    # MyTensor["dim1 dim2", int]
134                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]
135
136                elif isinstance(params[0], tuple):
137                    legacy_mode_.process(
138                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
139                        except_cls=Exception,
140                    )
141                    # MyTensor[("dim1", "dim2"), int]
142                    shape_anot: list[str] = list()
143                    for x in params[0]:
144                        if isinstance(x, str):
145                            shape_anot.append(x)
146                        elif isinstance(x, int):
147                            shape_anot.append(str(x))
148                        elif isinstance(x, tuple):
149                            shape_anot.append("".join(str(y) for y in x))
150                        else:
151                            raise Exception(
152                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
153                            )
154
155                    return TYPE_TO_JAX_DTYPE[params[1]][
156                        array_type, " ".join(shape_anot)
157                    ]
158            else:
159                raise Exception(
160                    f"unexpected type for params:\n{cls.param_info(params)}"
161                )
162
163    _BaseArray.__name__ = name
164
165    if _BaseArray.__doc__ is None:
166        _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"
167
168    _BaseArray.__doc__ = _BaseArray.__doc__.format(
169        default_jax_dtype=repr(default_jax_dtype),
170        array_type=repr(array_type),
171    )
172
173    return _BaseArray
174
175
if typing.TYPE_CHECKING:
    # these class definitions are only used here to make pylint happy,
    # but they make mypy unhappy and there is no way to only run if not mypy
    # so, later on we have more ignores
    class ATensor(torch.Tensor):
        # static-analysis stub only; the runtime ATensor is created below
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()

    class NDArray(torch.Tensor):
        # static-analysis stub only; the runtime NDArray is created below
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()


# runtime shorthand types: ATensor["dim1 dim2", dtype] for torch tensors,
# NDArray["dim1 dim2", dtype] for numpy arrays
ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]

NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]
194
195
def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert a numpy dtype to the corresponding torch dtype

    torch dtypes are passed through unchanged; numpy dtypes are converted by
    round-tripping a scalar array through `torch.from_numpy` and letting torch
    pick the matching dtype"""
    return (
        dtype
        if isinstance(dtype, torch.dtype)
        else torch.from_numpy(np.array(0, dtype=dtype)).dtype
    )
202
203
DTYPE_LIST: list = [
    *[
        bool,
        int,
        float,
    ],
    *[
        # ----------
        # pytorch
        # ----------
        # floats
        torch.float,
        torch.float32,
        torch.float64,
        torch.half,
        torch.double,
        torch.bfloat16,
        # complex
        torch.complex64,
        torch.complex128,
        # ints
        torch.int,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.long,
        torch.short,
        # simplest
        torch.uint8,
        torch.bool,
    ],
    *[
        # ----------
        # numpy
        # ----------
        # floats
        np.float16,
        np.float32,
        np.float64,
        np.half,
        np.single,
        np.double,
        # complex
        np.complex64,
        np.complex128,
        # ints
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.longlong,
        np.short,
        # simplest
        np.uint8,
        np.bool_,
    ],
]
"list of all the python, numpy, and torch numerical types I could think of"

# `np.float_` / `np.int_` only exist on numpy 1.x. Compare the parsed integer
# major version rather than the raw string: lexicographic comparison misorders
# versions (e.g. "10.0.0" < "2.0.0" is True as strings).
if int(np.version.version.split(".")[0]) < 2:
    DTYPE_LIST.extend([np.float_, np.int_])  # type: ignore[attr-defined]
266
DTYPE_MAP: dict = {
    # `str(x)` yields keys like "<class 'bool'>" for python/numpy types and
    # "torch.float32" for torch dtypes
    **{str(x): x for x in DTYPE_LIST},
    # additionally map bare numpy names like "float32" to the numpy type
    **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"},
}
"mapping from string representations of types to their type"

TORCH_DTYPE_MAP: dict = {
    key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items()
}
"mapping from string representations of types to specifically torch types"

# force a plain "bool" key to map to the boolean types
# NOTE(review): presumably needed because `np.bool_.__name__` is "bool_" on
# some numpy versions, so the comprehensions above may not produce a "bool"
# key — confirm against supported numpy versions
DTYPE_MAP["bool"] = np.bool_
TORCH_DTYPE_MAP["bool"] = torch.bool
281
282
# mapping from optimizer class name to the `torch.optim` optimizer class
TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
    "Adagrad": torch.optim.Adagrad,
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SparseAdam": torch.optim.SparseAdam,
    "Adamax": torch.optim.Adamax,
    "ASGD": torch.optim.ASGD,
    "LBFGS": torch.optim.LBFGS,
    "NAdam": torch.optim.NAdam,
    "RAdam": torch.optim.RAdam,
    "RMSprop": torch.optim.RMSprop,
    "Rprop": torch.optim.Rprop,
    "SGD": torch.optim.SGD,
}
297
298
def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead

    (`padded_length - tensor.shape[0]` must be non-negative)"""
    # build the padding on the same dtype/device as the input
    padding = torch.full(
        (padded_length - tensor.shape[0],),
        pad_value,
        dtype=tensor.dtype,
        device=tensor.device,
    )
    pieces = (tensor, padding) if rpad else (padding, tensor)
    return torch.cat(pieces)
323
324
def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
    return pad_tensor(
        tensor=tensor,
        padded_length=padded_length,
        pad_value=pad_value,
        rpad=False,
    )
330
331
def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
    return pad_tensor(
        tensor=tensor,
        padded_length=pad_length,
        pad_value=pad_value,
        rpad=True,
    )
337
338
def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead

    (`padded_length - array.shape[0]` must be non-negative)"""
    # build the padding with the same dtype as the input
    padding = np.full(
        (padded_length - array.shape[0],),
        pad_value,
        dtype=array.dtype,
    )
    pieces = (array, padding) if rpad else (padding, array)
    return np.concatenate(pieces)
362
363
def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`"""
    return pad_array(
        array=array,
        padded_length=padded_length,
        pad_value=pad_value,
        rpad=False,
    )
369
370
def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`"""
    return pad_array(
        array=array,
        padded_length=pad_length,
        pad_value=pad_value,
        rpad=True,
    )
376
377
def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
    shapes: dict[str, tuple[int, ...]] = {
        key: tuple(tensor.shape) for key, tensor in d.items()
    }
    return dotlist_to_nested_dict(shapes)
381
382
def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes"""
    # shapes are stringified, since json indent won't play nice with tuples
    shape_strs: dict[str, str] = {
        key: str(tuple(tensor.shape)) for key, tensor in d.items()
    }
    return json.dumps(dotlist_to_nested_dict(shape_strs), indent=2)
396
397
class StateDictCompareError(AssertionError):
    """raised when state dicts don't match"""
402
403
class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""
408
409
class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""
414
415
class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""
420
421
def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two dicts of tensors

    # Parameters:

     - `d1 : dict`
     - `d2 : dict`
     - `rtol : float`
       relative tolerance passed to `torch.allclose`
       (defaults to `1e-5`)
     - `atol : float`
       absolute tolerance passed to `torch.allclose`
       (defaults to `1e-8`)
     - `verbose : bool`
       include per-key shape dumps in error messages
       (defaults to `True`)

    # Raises:

     - `StateDictKeysError` : keys don't match
     - `StateDictShapeError` : shapes don't match (but keys do)
     - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    # sorted lists give deterministic, readable error messages
    # (the previous code wrapped `sorted()` results back in `set()`, which
    # discarded the ordering again, since sets are unordered)
    symmetric_diff: list = sorted(d1_keys.symmetric_difference(d2_keys))
    keys_diff_1: list = sorted(d1_keys - d2_keys)
    keys_diff_2: list = sorted(d2_keys - d1_keys)
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if len(symmetric_diff) > 0:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first; values are only comparable when shapes agree
        if v1.shape != v2.shape:
            shape_failed.append(k)
        elif not torch.allclose(v1, v2, rtol=rtol, atol=atol):
            vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if len(shape_failed) > 0:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if len(vals_failed) > 0:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )

TYPE_TO_JAX_DTYPE: dict = {<class 'float'>: <class 'jaxtyping.Float'>, <class 'int'>: <class 'jaxtyping.Int'>, <class 'jaxtyping.Float'>: <class 'jaxtyping.Float'>, <class 'jaxtyping.Int'>: <class 'jaxtyping.Int'>, <class 'bool'>: <class 'jaxtyping.Bool'>, <class 'jaxtyping.Bool'>: <class 'jaxtyping.Bool'>, <class 'numpy.bool'>: <class 'jaxtyping.Bool'>, torch.bool: <class 'jaxtyping.Bool'>, <class 'numpy.float16'>: <class 'jaxtyping.Float'>, <class 'numpy.float32'>: <class 'jaxtyping.Float'>, <class 'numpy.float64'>: <class 'jaxtyping.Float'>, <class 'numpy.int8'>: <class 'jaxtyping.Int'>, <class 'numpy.int16'>: <class 'jaxtyping.Int'>, <class 'numpy.int32'>: <class 'jaxtyping.Int'>, <class 'numpy.int64'>: <class 'jaxtyping.Int'>, <class 'numpy.longlong'>: <class 'jaxtyping.Int'>, <class 'numpy.uint8'>: <class 'jaxtyping.Int'>, torch.float32: <class 'jaxtyping.Float'>, torch.float16: <class 'jaxtyping.Float'>, torch.float64: <class 'jaxtyping.Float'>, torch.bfloat16: <class 'jaxtyping.Float'>, torch.int32: <class 'jaxtyping.Int'>, torch.int8: <class 'jaxtyping.Int'>, torch.int16: <class 'jaxtyping.Int'>, torch.int64: <class 'jaxtyping.Int'>}

dict mapping python, numpy, and torch types to jaxtyping types

def jaxtype_factory( name: str, array_type: type, default_jax_dtype=<class 'jaxtyping.Float'>, legacy_mode: Union[muutils.errormode.ErrorMode, str] = ErrorMode.Warn) -> type:
 80def jaxtype_factory(
 81    name: str,
 82    array_type: type,
 83    default_jax_dtype=jaxtyping.Float,
 84    legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN,
 85) -> type:
 86    """usage:
 87    ```
 88    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
 89    x: ATensor["dim1 dim2", np.float32]
 90    ```
 91    """
 92    legacy_mode_ = ErrorMode.from_any(legacy_mode)
 93
 94    class _BaseArray:
 95        """jaxtyping shorthand
 96        (backwards compatible with older versions of muutils.tensor_utils)
 97
 98        default_jax_dtype = {default_jax_dtype}
 99        array_type = {array_type}
100        """
101
102        def __new__(cls, *args, **kwargs):
103            raise TypeError("Type FArray cannot be instantiated.")
104
105        def __init_subclass__(cls, *args, **kwargs):
106            raise TypeError(f"Cannot subclass {cls.__name__}")
107
108        @classmethod
109        def param_info(cls, params) -> str:
110            """useful for error printing"""
111            return "\n".join(
112                f"{k} = {v}"
113                for k, v in {
114                    "cls.__name__": cls.__name__,
115                    "cls.__doc__": cls.__doc__,
116                    "params": params,
117                    "type(params)": type(params),
118                }.items()
119            )
120
121        @typing._tp_cache  # type: ignore
122        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:  # type: ignore
123            # MyTensor["dim1 dim2"]
124            if isinstance(params, str):
125                return default_jax_dtype[array_type, params]
126
127            elif isinstance(params, tuple):
128                if len(params) != 2:
129                    raise Exception(
130                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
131                    )
132
133                if isinstance(params[0], str):
134                    # MyTensor["dim1 dim2", int]
135                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]
136
137                elif isinstance(params[0], tuple):
138                    legacy_mode_.process(
139                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
140                        except_cls=Exception,
141                    )
142                    # MyTensor[("dim1", "dim2"), int]
143                    shape_anot: list[str] = list()
144                    for x in params[0]:
145                        if isinstance(x, str):
146                            shape_anot.append(x)
147                        elif isinstance(x, int):
148                            shape_anot.append(str(x))
149                        elif isinstance(x, tuple):
150                            shape_anot.append("".join(str(y) for y in x))
151                        else:
152                            raise Exception(
153                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
154                            )
155
156                    return TYPE_TO_JAX_DTYPE[params[1]][
157                        array_type, " ".join(shape_anot)
158                    ]
159            else:
160                raise Exception(
161                    f"unexpected type for params:\n{cls.param_info(params)}"
162                )
163
164    _BaseArray.__name__ = name
165
166    if _BaseArray.__doc__ is None:
167        _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"
168
169    _BaseArray.__doc__ = _BaseArray.__doc__.format(
170        default_jax_dtype=repr(default_jax_dtype),
171        array_type=repr(array_type),
172    )
173
174    return _BaseArray

usage:

ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
x: ATensor["dim1 dim2", np.float32]
ATensor = <class 'jaxtype_factory.<locals>._BaseArray'>
NDArray = <class 'jaxtype_factory.<locals>._BaseArray'>
def numpy_to_torch_dtype(dtype: Union[numpy.dtype, torch.dtype]) -> torch.dtype:
197def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
198    """convert numpy dtype to torch dtype"""
199    if isinstance(dtype, torch.dtype):
200        return dtype
201    else:
202        return torch.from_numpy(np.array(0, dtype=dtype)).dtype

convert numpy dtype to torch dtype

DTYPE_LIST: list = [<class 'bool'>, <class 'int'>, <class 'float'>, torch.float32, torch.float32, torch.float64, torch.float16, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.int32, torch.int8, torch.int16, torch.int32, torch.int64, torch.int64, torch.int16, torch.uint8, torch.bool, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.complex64'>, <class 'numpy.complex128'>, <class 'numpy.int8'>, <class 'numpy.int16'>, <class 'numpy.int32'>, <class 'numpy.int64'>, <class 'numpy.longlong'>, <class 'numpy.int16'>, <class 'numpy.uint8'>, <class 'numpy.bool'>]

list of all the python, numpy, and torch numerical types I could think of

DTYPE_MAP: dict = {"<class 'bool'>": <class 'bool'>, "<class 'int'>": <class 'int'>, "<class 'float'>": <class 'float'>, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float16'>": <class 'numpy.float16'>, "<class 'numpy.float32'>": <class 'numpy.float32'>, "<class 'numpy.float64'>": <class 'numpy.float64'>, "<class 'numpy.complex64'>": <class 'numpy.complex64'>, "<class 'numpy.complex128'>": <class 'numpy.complex128'>, "<class 'numpy.int8'>": <class 'numpy.int8'>, "<class 'numpy.int16'>": <class 'numpy.int16'>, "<class 'numpy.int32'>": <class 'numpy.int32'>, "<class 'numpy.int64'>": <class 'numpy.int64'>, "<class 'numpy.longlong'>": <class 'numpy.longlong'>, "<class 'numpy.uint8'>": <class 'numpy.uint8'>, "<class 'numpy.bool'>": <class 'numpy.bool'>, 'float16': <class 'numpy.float16'>, 'float32': <class 'numpy.float32'>, 'float64': <class 'numpy.float64'>, 'complex64': <class 'numpy.complex64'>, 'complex128': <class 'numpy.complex128'>, 'int8': <class 'numpy.int8'>, 'int16': <class 'numpy.int16'>, 'int32': <class 'numpy.int32'>, 'int64': <class 'numpy.int64'>, 'longlong': <class 'numpy.longlong'>, 'uint8': <class 'numpy.uint8'>, 'bool': <class 'numpy.bool'>}

mapping from string representations of types to their type

TORCH_DTYPE_MAP: dict = {"<class 'bool'>": torch.bool, "<class 'int'>": torch.int64, "<class 'float'>": torch.float64, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float16'>": torch.float16, "<class 'numpy.float32'>": torch.float32, "<class 'numpy.float64'>": torch.float64, "<class 'numpy.complex64'>": torch.complex64, "<class 'numpy.complex128'>": torch.complex128, "<class 'numpy.int8'>": torch.int8, "<class 'numpy.int16'>": torch.int16, "<class 'numpy.int32'>": torch.int32, "<class 'numpy.int64'>": torch.int64, "<class 'numpy.longlong'>": torch.int64, "<class 'numpy.uint8'>": torch.uint8, "<class 'numpy.bool'>": torch.bool, 'float16': torch.float16, 'float32': torch.float32, 'float64': torch.float64, 'complex64': torch.complex64, 'complex128': torch.complex128, 'int8': torch.int8, 'int16': torch.int16, 'int32': torch.int32, 'int64': torch.int64, 'longlong': torch.int64, 'uint8': torch.uint8, 'bool': torch.bool}

mapping from string representations of types to specifically torch types

TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.optimizer.Optimizer]] = {'Adagrad': <class 'torch.optim.adagrad.Adagrad'>, 'Adam': <class 'torch.optim.adam.Adam'>, 'AdamW': <class 'torch.optim.adamw.AdamW'>, 'SparseAdam': <class 'torch.optim.sparse_adam.SparseAdam'>, 'Adamax': <class 'torch.optim.adamax.Adamax'>, 'ASGD': <class 'torch.optim.asgd.ASGD'>, 'LBFGS': <class 'torch.optim.lbfgs.LBFGS'>, 'NAdam': <class 'torch.optim.nadam.NAdam'>, 'RAdam': <class 'torch.optim.radam.RAdam'>, 'RMSprop': <class 'torch.optim.rmsprop.RMSprop'>, 'Rprop': <class 'torch.optim.rprop.Rprop'>, 'SGD': <class 'torch.optim.sgd.SGD'>}
def pad_tensor( tensor: jaxtyping.Shaped[Tensor, 'dim1'], padded_length: int, pad_value: float = 0.0, rpad: bool = False) -> jaxtyping.Shaped[Tensor, 'padded_length']:
300def pad_tensor(
301    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
302    padded_length: int,
303    pad_value: float = 0.0,
304    rpad: bool = False,
305) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
306    """pad a 1-d tensor on the left with pad_value to length `padded_length`
307
308    set `rpad = True` to pad on the right instead"""
309
310    temp: list[torch.Tensor] = [
311        torch.full(
312            (padded_length - tensor.shape[0],),
313            pad_value,
314            dtype=tensor.dtype,
315            device=tensor.device,
316        ),
317        tensor,
318    ]
319
320    if rpad:
321        temp.reverse()
322
323    return torch.cat(temp)

pad a 1-d tensor on the left with pad_value to length padded_length

set rpad = True to pad on the right instead

def lpad_tensor( tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0) -> torch.Tensor:
326def lpad_tensor(
327    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
328) -> torch.Tensor:
329    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
330    return pad_tensor(tensor, padded_length, pad_value, rpad=False)

pad a 1-d tensor on the left with pad_value to length padded_length

def rpad_tensor( tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0) -> torch.Tensor:
333def rpad_tensor(
334    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
335) -> torch.Tensor:
336    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
337    return pad_tensor(tensor, pad_length, pad_value, rpad=True)

pad a 1-d tensor on the right with pad_value to length pad_length

def pad_array( array: jaxtyping.Shaped[ndarray, 'dim1'], padded_length: int, pad_value: float = 0.0, rpad: bool = False) -> jaxtyping.Shaped[ndarray, 'padded_length']:
340def pad_array(
341    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
342    padded_length: int,
343    pad_value: float = 0.0,
344    rpad: bool = False,
345) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
346    """pad a 1-d array on the left with pad_value to length `padded_length`
347
348    set `rpad = True` to pad on the right instead"""
349
350    temp: list[np.ndarray] = [
351        np.full(
352            (padded_length - array.shape[0],),
353            pad_value,
354            dtype=array.dtype,
355        ),
356        array,
357    ]
358
359    if rpad:
360        temp.reverse()
361
362    return np.concatenate(temp)

pad a 1-d array on the left with pad_value to length padded_length

set rpad = True to pad on the right instead

def lpad_array( array: numpy.ndarray, padded_length: int, pad_value: float = 0.0) -> numpy.ndarray:
365def lpad_array(
366    array: np.ndarray, padded_length: int, pad_value: float = 0.0
367) -> np.ndarray:
368    """pad a 1-d array on the left with pad_value to length `padded_length`"""
369    return pad_array(array, padded_length, pad_value, rpad=False)

pad a 1-d array on the left with pad_value to length padded_length

def rpad_array( array: numpy.ndarray, pad_length: int, pad_value: float = 0.0) -> numpy.ndarray:
372def rpad_array(
373    array: np.ndarray, pad_length: int, pad_value: float = 0.0
374) -> np.ndarray:
375    """pad a 1-d array on the right with pad_value to length `pad_length`"""
376    return pad_array(array, pad_length, pad_value, rpad=True)

pad a 1-d array on the right with pad_value to length pad_length

def get_dict_shapes(d: dict[str, torch.Tensor]) -> dict[str, tuple[int, ...]]:
379def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
380    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
381    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})

given a state dict or cache dict, compute the shapes and put them in a nested dict

def string_dict_shapes(d: dict[str, torch.Tensor]) -> str:
384def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
385    """printable version of get_dict_shapes"""
386    return json.dumps(
387        dotlist_to_nested_dict(
388            {
389                k: str(
390                    tuple(v.shape)
391                )  # to string, since indent wont play nice with tuples
392                for k, v in d.items()
393            }
394        ),
395        indent=2,
396    )

printable version of get_dict_shapes

class StateDictCompareError(builtins.AssertionError):
399class StateDictCompareError(AssertionError):
400    """raised when state dicts don't match"""
401
402    pass

raised when state dicts don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
class StateDictKeysError(StateDictCompareError):
class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""

raised when state dict keys don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
class StateDictShapeError(StateDictCompareError):
class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""

raised when state dict shapes don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
class StateDictValueError(StateDictCompareError):
class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""

raised when state dict values don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
def compare_state_dicts( d1: dict, d2: dict, rtol: float = 1e-05, atol: float = 1e-08, verbose: bool = True) -> None:
def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two dicts of tensors

    # Parameters:

     - `d1 : dict`
     - `d2 : dict`
     - `rtol : float`
       relative tolerance passed to `torch.allclose`
       (defaults to `1e-5`)
     - `atol : float`
       absolute tolerance passed to `torch.allclose`
       (defaults to `1e-8`)
     - `verbose : bool`
       include shape summaries in error messages
       (defaults to `True`)

    # Raises:

     - `StateDictKeysError` : keys don't match
     - `StateDictShapeError` : shapes don't match (but keys do)
     - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    # sorted lists (not sets) so error messages are deterministic and easy to scan
    # NOTE: the previous `set(sorted(...))` was a no-op, since sets are unordered
    symmetric_diff: list = sorted(set.symmetric_difference(d1_keys, d2_keys))
    keys_diff_1: list = sorted(d1_keys - d2_keys)
    keys_diff_2: list = sorted(d2_keys - d1_keys)
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if symmetric_diff:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match: shapes first, values only where shapes agree
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        if v1.shape != v2.shape:
            shape_failed.append(k)
        elif not torch.allclose(v1, v2, rtol=rtol, atol=atol):
            vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if shape_failed:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if vals_failed:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )

compare two dicts of tensors

Parameters:

  • d1 : dict
  • d2 : dict
  • rtol : float (defaults to 1e-5)
  • atol : float (defaults to 1e-8)
  • verbose : bool (defaults to True)

Raises: