muutils.tensor_utils
utilities for working with tensors and arrays.
notably:
- `TYPE_TO_JAX_DTYPE`: a mapping from python, numpy, and torch types to `jaxtyping` types
- `DTYPE_MAP`: mapping string representations of types to their type
- `TORCH_DTYPE_MAP`: mapping string representations of types to torch types
- `compare_state_dicts`: for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match
1"""utilities for working with tensors and arrays. 2 3notably: 4 5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types 6- `DTYPE_MAP` mapping string representations of types to their type 7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types 8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether if was keys, shapes, or values that didn't match 9 10""" 11 12from __future__ import annotations 13 14import json 15import typing 16 17import jaxtyping 18import numpy as np 19import torch 20 21from muutils.errormode import ErrorMode 22from muutils.dictmagic import dotlist_to_nested_dict 23 24# pylint: disable=missing-class-docstring 25 26 27TYPE_TO_JAX_DTYPE: dict = { 28 float: jaxtyping.Float, 29 int: jaxtyping.Int, 30 jaxtyping.Float: jaxtyping.Float, 31 jaxtyping.Int: jaxtyping.Int, 32 # bool 33 bool: jaxtyping.Bool, 34 jaxtyping.Bool: jaxtyping.Bool, 35 np.bool_: jaxtyping.Bool, 36 torch.bool: jaxtyping.Bool, 37 # numpy float 38 np.float16: jaxtyping.Float, 39 np.float32: jaxtyping.Float, 40 np.float64: jaxtyping.Float, 41 np.half: jaxtyping.Float, 42 np.single: jaxtyping.Float, 43 np.double: jaxtyping.Float, 44 # numpy int 45 np.int8: jaxtyping.Int, 46 np.int16: jaxtyping.Int, 47 np.int32: jaxtyping.Int, 48 np.int64: jaxtyping.Int, 49 np.longlong: jaxtyping.Int, 50 np.short: jaxtyping.Int, 51 np.uint8: jaxtyping.Int, 52 # torch float 53 torch.float: jaxtyping.Float, 54 torch.float16: jaxtyping.Float, 55 torch.float32: jaxtyping.Float, 56 torch.float64: jaxtyping.Float, 57 torch.half: jaxtyping.Float, 58 torch.double: jaxtyping.Float, 59 torch.bfloat16: jaxtyping.Float, 60 # torch int 61 torch.int: jaxtyping.Int, 62 torch.int8: jaxtyping.Int, 63 torch.int16: jaxtyping.Int, 64 torch.int32: jaxtyping.Int, 65 torch.int64: jaxtyping.Int, 66 torch.long: jaxtyping.Int, 67 torch.short: jaxtyping.Int, 68} 69"dict mapping python, numpy, and torch types to `jaxtyping` types" 70 71# we check for version here, so it shouldn't error 72if np.version.version < "2.0.0": 73 TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float # type: ignore[attr-defined] 74 TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int # type: ignore[attr-defined] 75 76 77# TODO: add proper type annotations to this signature 78# TODO: maybe get rid of this altogether? 
79def jaxtype_factory( 80 name: str, 81 array_type: type, 82 default_jax_dtype=jaxtyping.Float, 83 legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN, 84) -> type: 85 """usage: 86 ``` 87 ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) 88 x: ATensor["dim1 dim2", np.float32] 89 ``` 90 """ 91 legacy_mode_ = ErrorMode.from_any(legacy_mode) 92 93 class _BaseArray: 94 """jaxtyping shorthand 95 (backwards compatible with older versions of muutils.tensor_utils) 96 97 default_jax_dtype = {default_jax_dtype} 98 array_type = {array_type} 99 """ 100 101 def __new__(cls, *args, **kwargs): 102 raise TypeError("Type FArray cannot be instantiated.") 103 104 def __init_subclass__(cls, *args, **kwargs): 105 raise TypeError(f"Cannot subclass {cls.__name__}") 106 107 @classmethod 108 def param_info(cls, params) -> str: 109 """useful for error printing""" 110 return "\n".join( 111 f"{k} = {v}" 112 for k, v in { 113 "cls.__name__": cls.__name__, 114 "cls.__doc__": cls.__doc__, 115 "params": params, 116 "type(params)": type(params), 117 }.items() 118 ) 119 120 @typing._tp_cache # type: ignore 121 def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type: # type: ignore 122 # MyTensor["dim1 dim2"] 123 if isinstance(params, str): 124 return default_jax_dtype[array_type, params] 125 126 elif isinstance(params, tuple): 127 if len(params) != 2: 128 raise Exception( 129 f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}" 130 ) 131 132 if isinstance(params[0], str): 133 # MyTensor["dim1 dim2", int] 134 return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]] 135 136 elif isinstance(params[0], tuple): 137 legacy_mode_.process( 138 f"legacy type annotation was used:\n{cls.param_info(params) = }", 139 except_cls=Exception, 140 ) 141 # MyTensor[("dim1", "dim2"), int] 142 shape_anot: list[str] = list() 143 for x in params[0]: 144 if isinstance(x, str): 145 shape_anot.append(x) 146 elif isinstance(x, int): 147 shape_anot.append(str(x)) 148 elif isinstance(x, tuple): 149 shape_anot.append("".join(str(y) for y in x)) 150 else: 151 raise Exception( 152 f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}" 153 ) 154 155 return TYPE_TO_JAX_DTYPE[params[1]][ 156 array_type, " ".join(shape_anot) 157 ] 158 else: 159 raise Exception( 160 f"unexpected type for params:\n{cls.param_info(params)}" 161 ) 162 163 _BaseArray.__name__ = name 164 165 if _BaseArray.__doc__ is None: 166 _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }" 167 168 _BaseArray.__doc__ = _BaseArray.__doc__.format( 169 default_jax_dtype=repr(default_jax_dtype), 170 array_type=repr(array_type), 171 ) 172 173 return _BaseArray 174 175 176if typing.TYPE_CHECKING: 177 # these class definitions are only used here to make pylint happy, 178 # but they make mypy unhappy and there is no way to only run if not mypy 179 # so, later on we have more ignores 180 class ATensor(torch.Tensor): 181 @typing._tp_cache # type: ignore 182 def __class_getitem__(cls, params): 183 raise NotImplementedError() 184 185 class NDArray(torch.Tensor): 186 @typing._tp_cache # type: ignore 187 def __class_getitem__(cls, params): 188 raise NotImplementedError() 189 190 191ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) # type: ignore[misc, assignment] 192 193NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float) # type: ignore[misc, assignment] 194 195 196def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> 
torch.dtype: 197 """convert numpy dtype to torch dtype""" 198 if isinstance(dtype, torch.dtype): 199 return dtype 200 else: 201 return torch.from_numpy(np.array(0, dtype=dtype)).dtype 202 203 204DTYPE_LIST: list = [ 205 *[ 206 bool, 207 int, 208 float, 209 ], 210 *[ 211 # ---------- 212 # pytorch 213 # ---------- 214 # floats 215 torch.float, 216 torch.float32, 217 torch.float64, 218 torch.half, 219 torch.double, 220 torch.bfloat16, 221 # complex 222 torch.complex64, 223 torch.complex128, 224 # ints 225 torch.int, 226 torch.int8, 227 torch.int16, 228 torch.int32, 229 torch.int64, 230 torch.long, 231 torch.short, 232 # simplest 233 torch.uint8, 234 torch.bool, 235 ], 236 *[ 237 # ---------- 238 # numpy 239 # ---------- 240 # floats 241 np.float16, 242 np.float32, 243 np.float64, 244 np.half, 245 np.single, 246 np.double, 247 # complex 248 np.complex64, 249 np.complex128, 250 # ints 251 np.int8, 252 np.int16, 253 np.int32, 254 np.int64, 255 np.longlong, 256 np.short, 257 # simplest 258 np.uint8, 259 np.bool_, 260 ], 261] 262"list of all the python, numpy, and torch numerical types I could think of" 263 264if np.version.version < "2.0.0": 265 DTYPE_LIST.extend([np.float_, np.int_]) # type: ignore[attr-defined] 266 267DTYPE_MAP: dict = { 268 **{str(x): x for x in DTYPE_LIST}, 269 **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"}, 270} 271"mapping from string representations of types to their type" 272 273TORCH_DTYPE_MAP: dict = { 274 key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items() 275} 276"mapping from string representations of types to specifically torch types" 277 278# no idea why we have to do this, smh 279DTYPE_MAP["bool"] = np.bool_ 280TORCH_DTYPE_MAP["bool"] = torch.bool 281 282 283TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = { 284 "Adagrad": torch.optim.Adagrad, 285 "Adam": torch.optim.Adam, 286 "AdamW": torch.optim.AdamW, 287 "SparseAdam": torch.optim.SparseAdam, 288 "Adamax": torch.optim.Adamax, 289 "ASGD": torch.optim.ASGD, 290 "LBFGS": torch.optim.LBFGS, 291 "NAdam": torch.optim.NAdam, 292 "RAdam": torch.optim.RAdam, 293 "RMSprop": torch.optim.RMSprop, 294 "Rprop": torch.optim.Rprop, 295 "SGD": torch.optim.SGD, 296} 297 298 299def pad_tensor( 300 tensor: jaxtyping.Shaped[torch.Tensor, "dim1"], # noqa: F821 301 padded_length: int, 302 pad_value: float = 0.0, 303 rpad: bool = False, 304) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]: # noqa: F821 305 """pad a 1-d tensor on the left with pad_value to length `padded_length` 306 307 set `rpad = True` to pad on the right instead""" 308 309 temp: list[torch.Tensor] = [ 310 torch.full( 311 (padded_length - tensor.shape[0],), 312 pad_value, 313 dtype=tensor.dtype, 314 device=tensor.device, 315 ), 316 tensor, 317 ] 318 319 if rpad: 320 temp.reverse() 321 322 return torch.cat(temp) 323 324 325def lpad_tensor( 326 tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0 327) -> torch.Tensor: 328 """pad a 1-d tensor on the left with pad_value to length `padded_length`""" 329 return pad_tensor(tensor, padded_length, pad_value, rpad=False) 330 331 332def rpad_tensor( 333 tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0 334) -> torch.Tensor: 335 """pad a 1-d tensor on the right with pad_value to length `pad_length`""" 336 return pad_tensor(tensor, pad_length, pad_value, rpad=True) 337 338 339def pad_array( 340 array: jaxtyping.Shaped[np.ndarray, "dim1"], # noqa: F821 341 padded_length: int, 342 pad_value: float = 0.0, 343 rpad: bool = False, 344) 
-> jaxtyping.Shaped[np.ndarray, "padded_length"]: # noqa: F821 345 """pad a 1-d array on the left with pad_value to length `padded_length` 346 347 set `rpad = True` to pad on the right instead""" 348 349 temp: list[np.ndarray] = [ 350 np.full( 351 (padded_length - array.shape[0],), 352 pad_value, 353 dtype=array.dtype, 354 ), 355 array, 356 ] 357 358 if rpad: 359 temp.reverse() 360 361 return np.concatenate(temp) 362 363 364def lpad_array( 365 array: np.ndarray, padded_length: int, pad_value: float = 0.0 366) -> np.ndarray: 367 """pad a 1-d array on the left with pad_value to length `padded_length`""" 368 return pad_array(array, padded_length, pad_value, rpad=False) 369 370 371def rpad_array( 372 array: np.ndarray, pad_length: int, pad_value: float = 0.0 373) -> np.ndarray: 374 """pad a 1-d array on the right with pad_value to length `pad_length`""" 375 return pad_array(array, pad_length, pad_value, rpad=True) 376 377 378def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]: 379 """given a state dict or cache dict, compute the shapes and put them in a nested dict""" 380 return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()}) 381 382 383def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str: 384 """printable version of get_dict_shapes""" 385 return json.dumps( 386 dotlist_to_nested_dict( 387 { 388 k: str( 389 tuple(v.shape) 390 ) # to string, since indent wont play nice with tuples 391 for k, v in d.items() 392 } 393 ), 394 indent=2, 395 ) 396 397 398class StateDictCompareError(AssertionError): 399 """raised when state dicts don't match""" 400 401 pass 402 403 404class StateDictKeysError(StateDictCompareError): 405 """raised when state dict keys don't match""" 406 407 pass 408 409 410class StateDictShapeError(StateDictCompareError): 411 """raised when state dict shapes don't match""" 412 413 pass 414 415 416class StateDictValueError(StateDictCompareError): 417 """raised when state dict values don't match""" 418 419 pass 420 421 422def compare_state_dicts( 423 d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True 424) -> None: 425 """compare two dicts of tensors 426 427 # Parameters: 428 429 - `d1 : dict` 430 - `d2 : dict` 431 - `rtol : float` 432 (defaults to `1e-5`) 433 - `atol : float` 434 (defaults to `1e-8`) 435 - `verbose : bool` 436 (defaults to `True`) 437 438 # Raises: 439 440 - `StateDictKeysError` : keys don't match 441 - `StateDictShapeError` : shapes don't match (but keys do) 442 - `StateDictValueError` : values don't match (but keys and shapes do) 443 """ 444 # check keys match 445 d1_keys: set = set(d1.keys()) 446 d2_keys: set = set(d2.keys()) 447 symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys) 448 keys_diff_1: set = d1_keys - d2_keys 449 keys_diff_2: set = d2_keys - d1_keys 450 # sort sets for easier debugging 451 symmetric_diff = set(sorted(symmetric_diff)) 452 keys_diff_1 = set(sorted(keys_diff_1)) 453 keys_diff_2 = set(sorted(keys_diff_2)) 454 diff_shapes_1: str = ( 455 string_dict_shapes({k: d1[k] for k in keys_diff_1}) 456 if verbose 457 else "(verbose = False)" 458 ) 459 diff_shapes_2: str = ( 460 string_dict_shapes({k: d2[k] for k in keys_diff_2}) 461 if verbose 462 else "(verbose = False)" 463 ) 464 if not len(symmetric_diff) == 0: 465 raise StateDictKeysError( 466 f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}" 467 ) 468 469 # check tensors match 470 shape_failed: list[str] = list() 
471 vals_failed: list[str] = list() 472 for k, v1 in d1.items(): 473 v2 = d2[k] 474 # check shapes first 475 if not v1.shape == v2.shape: 476 shape_failed.append(k) 477 else: 478 # if shapes match, check values 479 if not torch.allclose(v1, v2, rtol=rtol, atol=atol): 480 vals_failed.append(k) 481 482 str_shape_failed: str = ( 483 string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else "" 484 ) 485 str_vals_failed: str = ( 486 string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else "" 487 ) 488 489 if not len(shape_failed) == 0: 490 raise StateDictShapeError( 491 f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}" 492 ) 493 if not len(vals_failed) == 0: 494 raise StateDictValueError( 495 f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}" 496 )
dict mapping python, numpy, and torch types to jaxtyping types
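For illustration, a minimal sketch of looking up a `jaxtyping` wrapper through this mapping and parameterizing it; the names `FloatTensor` and `x` are hypothetical:

```python
import jaxtyping
import numpy as np
import torch

from muutils.tensor_utils import TYPE_TO_JAX_DTYPE

# concrete dtypes map to their jaxtyping wrapper
assert TYPE_TO_JAX_DTYPE[np.float32] is jaxtyping.Float
assert TYPE_TO_JAX_DTYPE[torch.int64] is jaxtyping.Int

# the wrapper can then be parameterized as usual for jaxtyping
FloatTensor = TYPE_TO_JAX_DTYPE[torch.float32][torch.Tensor, "batch dim"]
x: FloatTensor = torch.zeros(2, 3)
```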
````python
def jaxtype_factory(
    name: str,
    array_type: type,
    default_jax_dtype=jaxtyping.Float,
    legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN,
) -> type:
    """usage:
    ```
    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
    x: ATensor["dim1 dim2", np.float32]
    ```
    """
    legacy_mode_ = ErrorMode.from_any(legacy_mode)

    class _BaseArray:
        """jaxtyping shorthand
        (backwards compatible with older versions of muutils.tensor_utils)

        default_jax_dtype = {default_jax_dtype}
        array_type = {array_type}
        """

        def __new__(cls, *args, **kwargs):
            raise TypeError(f"Type {cls.__name__} cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__name__}")

        @classmethod
        def param_info(cls, params) -> str:
            """useful for error printing"""
            return "\n".join(
                f"{k} = {v}"
                for k, v in {
                    "cls.__name__": cls.__name__,
                    "cls.__doc__": cls.__doc__,
                    "params": params,
                    "type(params)": type(params),
                }.items()
            )

        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:  # type: ignore
            # MyTensor["dim1 dim2"]
            if isinstance(params, str):
                return default_jax_dtype[array_type, params]

            elif isinstance(params, tuple):
                if len(params) != 2:
                    raise Exception(
                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
                    )

                if isinstance(params[0], str):
                    # MyTensor["dim1 dim2", int]
                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]

                elif isinstance(params[0], tuple):
                    legacy_mode_.process(
                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
                        except_cls=Exception,
                    )
                    # MyTensor[("dim1", "dim2"), int]
                    shape_anot: list[str] = list()
                    for x in params[0]:
                        if isinstance(x, str):
                            shape_anot.append(x)
                        elif isinstance(x, int):
                            shape_anot.append(str(x))
                        elif isinstance(x, tuple):
                            shape_anot.append("".join(str(y) for y in x))
                        else:
                            raise Exception(
                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
                            )

                    return TYPE_TO_JAX_DTYPE[params[1]][
                        array_type, " ".join(shape_anot)
                    ]
            else:
                raise Exception(
                    f"unexpected type for params:\n{cls.param_info(params)}"
                )

    _BaseArray.__name__ = name

    if _BaseArray.__doc__ is None:
        _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"

    _BaseArray.__doc__ = _BaseArray.__doc__.format(
        default_jax_dtype=repr(default_jax_dtype),
        array_type=repr(array_type),
    )

    return _BaseArray
````
usage:
```python
ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
x: ATensor["dim1 dim2", np.float32]
```
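Expanding on the docstring example, a sketch of the subscript forms the factory accepts, assuming the default `legacy_mode` (`WARN`); the alias name `MyTensor` is hypothetical:

```python
import jaxtyping
import numpy as np
import torch

from muutils.tensor_utils import jaxtype_factory

MyTensor = jaxtype_factory("MyTensor", torch.Tensor, jaxtyping.Float)

a: MyTensor["batch dim"]  # bare shape string; uses default_jax_dtype (Float)
b: MyTensor["batch dim", np.float32]  # shape string plus explicit dtype
c: MyTensor[("batch", "dim"), np.float32]  # legacy tuple form; warns by default
```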
```python
def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert numpy dtype to torch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    else:
        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
```
convert numpy dtype to torch dtype
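A few illustrative conversions (the values are arbitrary):

```python
import numpy as np
import torch

from muutils.tensor_utils import numpy_to_torch_dtype

assert numpy_to_torch_dtype(np.dtype(np.float32)) == torch.float32
assert numpy_to_torch_dtype(np.int64) == torch.int64  # bare numpy scalar types also work
assert numpy_to_torch_dtype(torch.bool) == torch.bool  # torch dtypes pass through unchanged
```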
list of all the python, numpy, and torch numerical types I could think of
mapping from string representations of types to their type
mapping from string representations of types to specifically torch types
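A sketch of how the two maps differ, using keys that should exist given the construction above (string keys come from `str(x)` and, for numpy types, from `__name__`):

```python
import numpy as np
import torch

from muutils.tensor_utils import DTYPE_MAP, TORCH_DTYPE_MAP

assert DTYPE_MAP["float32"] is np.float32  # numpy types are keyed by __name__
assert DTYPE_MAP["torch.float32"] == torch.float32  # torch types are keyed by str(x)
assert DTYPE_MAP["bool"] is np.bool_  # special-cased, see the module source

# TORCH_DTYPE_MAP resolves every key to a torch dtype
assert TORCH_DTYPE_MAP["float32"] == torch.float32
assert TORCH_DTYPE_MAP["bool"] == torch.bool
```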
```python
def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[torch.Tensor] = [
        torch.full(
            (padded_length - tensor.shape[0],),
            pad_value,
            dtype=tensor.dtype,
            device=tensor.device,
        ),
        tensor,
    ]

    if rpad:
        temp.reverse()

    return torch.cat(temp)
```
pad a 1-d tensor on the left with pad_value to length `padded_length`

set `rpad = True` to pad on the right instead
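A quick illustration with arbitrary values:

```python
import torch

from muutils.tensor_utils import pad_tensor

x = torch.tensor([1.0, 2.0, 3.0])
assert pad_tensor(x, 5).tolist() == [0.0, 0.0, 1.0, 2.0, 3.0]  # left-pad by default
assert pad_tensor(x, 5, rpad=True).tolist() == [1.0, 2.0, 3.0, 0.0, 0.0]
assert pad_tensor(x, 5, pad_value=9.0).tolist() == [9.0, 9.0, 1.0, 2.0, 3.0]
```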
```python
def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
    return pad_tensor(tensor, padded_length, pad_value, rpad=False)
```
pad a 1-d tensor on the left with pad_value to length `padded_length`
```python
def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
    return pad_tensor(tensor, pad_length, pad_value, rpad=True)
```
pad a 1-d tensor on the right with pad_value to length `pad_length`
```python
def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[np.ndarray] = [
        np.full(
            (padded_length - array.shape[0],),
            pad_value,
            dtype=array.dtype,
        ),
        array,
    ]

    if rpad:
        temp.reverse()

    return np.concatenate(temp)
```
pad a 1-d array on the left with pad_value to length `padded_length`

set `rpad = True` to pad on the right instead
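The numpy counterpart behaves the same way; an illustrative example:

```python
import numpy as np

from muutils.tensor_utils import pad_array

a = np.array([1.0, 2.0, 3.0])
assert pad_array(a, 5).tolist() == [0.0, 0.0, 1.0, 2.0, 3.0]  # left-pad by default
assert pad_array(a, 5, rpad=True).tolist() == [1.0, 2.0, 3.0, 0.0, 0.0]
```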
```python
def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`"""
    return pad_array(array, padded_length, pad_value, rpad=False)
```
pad a 1-d array on the left with pad_value to length `padded_length`
```python
def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`"""
    return pad_array(array, pad_length, pad_value, rpad=True)
```
pad a 1-d array on the right with pad_value to length `pad_length`
```python
def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})
```
given a state dict or cache dict, compute the shapes and put them in a nested dict
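For example, with a hypothetical two-entry state dict:

```python
import torch

from muutils.tensor_utils import get_dict_shapes

state_dict = {
    "layer1.weight": torch.zeros(4, 3),
    "layer1.bias": torch.zeros(4),
}
# dotted keys become nesting levels, values become shape tuples
assert get_dict_shapes(state_dict) == {"layer1": {"weight": (4, 3), "bias": (4,)}}
```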
```python
def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes"""
    return json.dumps(
        dotlist_to_nested_dict(
            {
                k: str(
                    tuple(v.shape)
                )  # to string, since indent won't play nice with tuples
                for k, v in d.items()
            }
        ),
        indent=2,
    )
```
printable version of get_dict_shapes
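The same hypothetical input, printed:

```python
import torch

from muutils.tensor_utils import string_dict_shapes

print(string_dict_shapes({"layer1.weight": torch.zeros(4, 3)}))
# {
#   "layer1": {
#     "weight": "(4, 3)"
#   }
# }
```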
```python
class StateDictCompareError(AssertionError):
    """raised when state dicts don't match"""

    pass
```
raised when state dicts don't match
```python
class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""

    pass
```
raised when state dict keys don't match
```python
class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""

    pass
```
raised when state dict shapes don't match
```python
class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""

    pass
```
raised when state dict values don't match
```python
def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two dicts of tensors

    # Parameters:

    - `d1 : dict`
    - `d2 : dict`
    - `rtol : float`
       (defaults to `1e-5`)
    - `atol : float`
       (defaults to `1e-8`)
    - `verbose : bool`
       (defaults to `True`)

    # Raises:

    - `StateDictKeysError` : keys don't match
    - `StateDictShapeError` : shapes don't match (but keys do)
    - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
    keys_diff_1: set = d1_keys - d2_keys
    keys_diff_2: set = d2_keys - d1_keys
    # sort sets for easier debugging
    symmetric_diff = set(sorted(symmetric_diff))
    keys_diff_1 = set(sorted(keys_diff_1))
    keys_diff_2 = set(sorted(keys_diff_2))
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if not len(symmetric_diff) == 0:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first
        if not v1.shape == v2.shape:
            shape_failed.append(k)
        else:
            # if shapes match, check values
            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
                vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if not len(shape_failed) == 0:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if not len(vals_failed) == 0:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )
```
compare two dicts of tensors
Parameters:
- `d1 : dict`
- `d2 : dict`
- `rtol : float` (defaults to `1e-5`)
- `atol : float` (defaults to `1e-8`)
- `verbose : bool` (defaults to `True`)
Raises:
- `StateDictKeysError`: keys don't match
- `StateDictShapeError`: shapes don't match (but keys do)
- `StateDictValueError`: values don't match (but keys and shapes do)
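A sketch of typical use, with hypothetical state dicts chosen so that keys match but one shape differs; since all three errors subclass `StateDictCompareError` (itself an `AssertionError`), a single `except` clause can catch any mismatch:

```python
import torch

from muutils.tensor_utils import (
    StateDictCompareError,
    StateDictShapeError,
    compare_state_dicts,
)

d1 = {"w": torch.zeros(2, 3), "b": torch.zeros(3)}
d2 = {"w": torch.zeros(2, 4), "b": torch.zeros(3)}  # same keys, "w" differs in shape

try:
    compare_state_dicts(d1, d2)
except StateDictShapeError as e:
    print(e)  # lists which keys failed the shape check

# catching the common base class covers key, shape, and value mismatches alike
try:
    compare_state_dicts(d1, d2)
except StateDictCompareError:
    pass
```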