docs for muutils v0.8.12
View Source on GitHub

muutils.mlutils

miscellaneous utilities for ML pipelines


  1"miscellaneous utilities for ML pipelines"
  2
  3from __future__ import annotations
  4
  5import json
  6import os
  7import random
  8import typing
  9import warnings
 10from itertools import islice
 11from pathlib import Path
 12from typing import Any, Callable, Optional, TypeVar, Union
 13
 14ARRAY_IMPORTS: bool
 15try:
 16    import numpy as np
 17    import torch
 18
 19    ARRAY_IMPORTS = True
 20except ImportError as e:
 21    warnings.warn(
 22        f"Numpy or torch not installed. Array operations will not be available.\n{e}"
 23    )
 24    ARRAY_IMPORTS = False
 25
 26DEFAULT_SEED: int = 42
 27GLOBAL_SEED: int = DEFAULT_SEED
 28
 29
 30def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device":
 31    """Get the torch.device instance on which `torch.Tensor`s should be allocated."""
 32    if not ARRAY_IMPORTS:
 33        raise ImportError(
 34            "Numpy or torch not installed. Array operations will not be available."
 35        )
 36    try:
 37        # if device is given
 38        if device is not None:
 39            device = torch.device(device)
 40            if any(
 41                [
 42                    torch.cuda.is_available() and device.type == "cuda",
 43                    torch.backends.mps.is_available() and device.type == "mps",
 44                    device.type == "cpu",
 45                ]
 46            ):
 47                # if device is given and available
 48                pass
 49            else:
 50                warnings.warn(
 51                    f"Specified device {device} is not available, falling back to CPU"
 52                )
 53                return torch.device("cpu")
 54
 55        # no device given, infer from availability
 56        else:
 57            if torch.cuda.is_available():
 58                device = torch.device("cuda")
 59            elif torch.backends.mps.is_available():
 60                device = torch.device("mps")
 61            else:
 62                device = torch.device("cpu")
 63
 64        # put a dummy tensor on the device to check if it is available
 65        _dummy = torch.zeros(1, device=device)
 66
 67        return device
 68
 69    except Exception as e:
 70        warnings.warn(
 71            f"Error while getting device, falling back to CPU. Error: {e}",
 72            RuntimeWarning,
 73        )
 74        return torch.device("cpu")
 75
 76
 77def set_reproducibility(seed: int = DEFAULT_SEED):
 78    """
 79    Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.
 80
 81    Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades
 82    off performance for reproducibility. Set use_deterministic_algorithms to True to improve performance.
 83    """
 84    global GLOBAL_SEED
 85
 86    GLOBAL_SEED = seed
 87
 88    random.seed(seed)
 89
 90    if ARRAY_IMPORTS:
 91        np.random.seed(seed)
 92        torch.manual_seed(seed)
 93
 94        torch.use_deterministic_algorithms(True)
 95        # Ensure reproducibility for concurrent CUDA streams
 96        # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility.
 97        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 98
 99
def chunks(it, chunk_size):
    """Yield consecutive lists of up to `chunk_size` items drawn from `it`.

    The last list may be shorter than `chunk_size`. Nothing is yielded for an
    empty iterable. Items are pulled lazily from a single iterator over `it`.
    """
    # approach adapted from https://stackoverflow.com/a/61435714
    source = iter(it)
    while True:
        batch = list(islice(source, chunk_size))
        if not batch:
            return
        yield batch
106
107
def get_checkpoint_paths_for_run(
    run_path: Path,
    extension: typing.Literal["pt", "zanj"],
    checkpoints_format: str = "checkpoints/model.iter_*.{extension}",
) -> list[tuple[int, Path]]:
    """get checkpoints of the format from the run_path

    note that `checkpoints_format` should contain a glob pattern with:
     - unresolved "{extension}" format term for the extension
     - a wildcard for the iteration number

    # Parameters:
     - `run_path : Path`
        run directory to search (a `str` is also accepted and converted to `Path`)
     - `extension : typing.Literal["pt", "zanj"]`
        substituted for `{extension}` in `checkpoints_format`
     - `checkpoints_format : str`
        glob pattern relative to `run_path`
        (defaults to `"checkpoints/model.iter_*.{extension}"`)

    # Returns:
     - `list[tuple[int, Path]]`
        `(iteration, path)` pairs sorted numerically by iteration (ties broken by
        path). The iteration is the text after the last `_` in the file stem, up
        to the next `.`.

    # Raises:
     - `AssertionError` : if `run_path` is not a directory
     - `ValueError` : if a matched file's stem does not end in an integer
    """
    # normalise first so `str` paths work with `.is_dir()` / `.glob()`
    run_path = Path(run_path)

    assert run_path.is_dir(), (
        f"Model path {run_path} is not a directory (expect run directory, not model files)"
    )

    checkpoints: list[tuple[int, Path]] = [
        (int(checkpoint_path.stem.split("_")[-1].split(".")[0]), checkpoint_path)
        for checkpoint_path in run_path.glob(
            checkpoints_format.format(extension=extension)
        )
    ]
    # sort by the parsed iteration number; sorting paths as strings would
    # put e.g. `iter_10` before `iter_2`
    checkpoints.sort()
    return checkpoints
130
131
F = TypeVar("F", bound=Callable[..., Any])


def register_method(
    method_dict: dict[str, Callable[..., Any]],
    custom_name: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator to add a method to the method_dict

    # Parameters:
     - `method_dict : dict[str, Callable[..., Any]]`
        registry the decorated callable is inserted into (mutated in place)
     - `custom_name : Optional[str]`
        key to register under; when given, it is also assigned to the callable's
        `__name__`. When `None`, the callable's `__name__` is used, or a
        sanitized `repr` if it has none.
        (defaults to `None`)

    # Returns:
     - `Callable[[F], F]`
        decorator that registers its argument and returns it (renamed only
        when `custom_name` is given)

    # Raises:
     - `AssertionError` : if the resolved name is already a key of `method_dict`
    """

    def decorator(method: F) -> F:
        method_name: str
        if custom_name is None:
            method_name_orig: str | None = getattr(method, "__name__", None)
            if method_name_orig is None:
                warnings.warn(
                    f"Method {method} does not have a name, using sanitized repr"
                )
                from muutils.misc import sanitize_identifier

                method_name = sanitize_identifier(repr(method))
            else:
                method_name = method_name_orig
        else:
            method_name = custom_name
        # check for duplicates *before* touching `method`, so a rejected
        # registration leaves the callable unmodified
        assert method_name not in method_dict, (
            f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }"
        )
        if custom_name is not None:
            method.__name__ = custom_name
        method_dict[method_name] = method
        return method

    return decorator
164
165
def pprint_summary(summary: dict):
    """Print `summary` to stdout as JSON with a two-space indent."""
    rendered: str = json.dumps(summary, indent=2)
    print(rendered)

ARRAY_IMPORTS: bool = True
DEFAULT_SEED: int = 42
GLOBAL_SEED: int = 42
def get_device(device: Union[str, torch.device, NoneType] = None) -> torch.device:
def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device":
    """Get the torch.device instance on which `torch.Tensor`s should be allocated.

    If `device` is given it is used when its type is available: ``cpu`` always,
    ``cuda`` / ``mps`` only when the matching backend reports availability.
    Otherwise a warning is emitted and CPU is returned. If `device` is None, the
    first available of cuda, mps, cpu is chosen.

    A dummy tensor is allocated on the chosen device as a final check. Any error
    during selection or allocation is reported as a `RuntimeWarning` and CPU is
    returned instead.

    # Raises:
     - `ImportError` : if numpy or torch could not be imported
    """
    if not ARRAY_IMPORTS:
        raise ImportError(
            "Numpy or torch not installed. Array operations will not be available."
        )

    def _mps_available() -> bool:
        # `torch.backends.mps` is absent on older torch builds; probing it lazily
        # keeps a missing backend from turning a valid cuda/cpu request into a
        # CPU fallback via the broad `except` below
        mps_backend = getattr(torch.backends, "mps", None)
        return mps_backend is not None and mps_backend.is_available()

    try:
        if device is not None:
            device = torch.device(device)
            # only probe the backend that matches the requested device type
            if device.type == "cpu":
                available = True
            elif device.type == "cuda":
                available = torch.cuda.is_available()
            elif device.type == "mps":
                available = _mps_available()
            else:
                available = False

            if not available:
                warnings.warn(
                    f"Specified device {device} is not available, falling back to CPU"
                )
                return torch.device("cpu")

        # no device given, infer from availability
        elif torch.cuda.is_available():
            device = torch.device("cuda")
        elif _mps_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

        # put a dummy tensor on the device to check that it is actually usable
        # (e.g. "cuda:7" on a single-GPU machine passes the type check above)
        torch.zeros(1, device=device)

        return device

    except Exception as e:
        warnings.warn(
            f"Error while getting device, falling back to CPU. Error: {e}",
            RuntimeWarning,
        )
        return torch.device("cpu")

Get the torch.device instance on which torch.Tensors should be allocated.

def set_reproducibility(seed: int = 42):
def set_reproducibility(seed: int = DEFAULT_SEED):
    """
    Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.

    Stores `seed` in the module-level `GLOBAL_SEED` and seeds python's `random`.
    When numpy and torch are available, it also seeds numpy and torch, sets
    `CUBLAS_WORKSPACE_CONFIG`, and calls `torch.use_deterministic_algorithms(True)`.

    Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades
    off performance for reproducibility. Call `torch.use_deterministic_algorithms(False)` afterwards to regain
    performance at the cost of reproducibility.
    """
    global GLOBAL_SEED

    GLOBAL_SEED = seed

    random.seed(seed)

    if ARRAY_IMPORTS:
        # Ensure reproducibility for concurrent CUDA streams
        # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility.
        # cuBLAS picks this up when it is initialised, so set it before any
        # seeding / deterministic-mode work that could trigger CUDA setup
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

        np.random.seed(seed)
        torch.manual_seed(seed)

        torch.use_deterministic_algorithms(True)

Improve model reproducibility. See https://github.com/NVIDIA/framework-determinism for more information.

Deterministic operations tend to have worse performance than nondeterministic operations, so this method trades off performance for reproducibility. Call torch.use_deterministic_algorithms(False) afterwards to regain performance at the cost of reproducibility.

def chunks(it, chunk_size):
def chunks(it, chunk_size):
    """Yield consecutive lists of up to `chunk_size` items drawn from `it`.

    The last list may be shorter than `chunk_size`. Nothing is yielded for an
    empty iterable. Items are pulled lazily from a single iterator over `it`.
    """
    # approach adapted from https://stackoverflow.com/a/61435714
    source = iter(it)
    while True:
        batch = list(islice(source, chunk_size))
        if not batch:
            return
        yield batch

Yield successive chunks from an iterator.

def get_checkpoint_paths_for_run( run_path: pathlib._local.Path, extension: Literal['pt', 'zanj'], checkpoints_format: str = 'checkpoints/model.iter_*.{extension}') -> list[tuple[int, pathlib._local.Path]]:
def get_checkpoint_paths_for_run(
    run_path: Path,
    extension: typing.Literal["pt", "zanj"],
    checkpoints_format: str = "checkpoints/model.iter_*.{extension}",
) -> list[tuple[int, Path]]:
    """get checkpoints of the format from the run_path

    note that `checkpoints_format` should contain a glob pattern with:
     - unresolved "{extension}" format term for the extension
     - a wildcard for the iteration number

    # Parameters:
     - `run_path : Path`
        run directory to search (a `str` is also accepted and converted to `Path`)
     - `extension : typing.Literal["pt", "zanj"]`
        substituted for `{extension}` in `checkpoints_format`
     - `checkpoints_format : str`
        glob pattern relative to `run_path`
        (defaults to `"checkpoints/model.iter_*.{extension}"`)

    # Returns:
     - `list[tuple[int, Path]]`
        `(iteration, path)` pairs sorted numerically by iteration (ties broken by
        path). The iteration is the text after the last `_` in the file stem, up
        to the next `.`.

    # Raises:
     - `AssertionError` : if `run_path` is not a directory
     - `ValueError` : if a matched file's stem does not end in an integer
    """
    # normalise first so `str` paths work with `.is_dir()` / `.glob()`
    run_path = Path(run_path)

    assert run_path.is_dir(), (
        f"Model path {run_path} is not a directory (expect run directory, not model files)"
    )

    checkpoints: list[tuple[int, Path]] = [
        (int(checkpoint_path.stem.split("_")[-1].split(".")[0]), checkpoint_path)
        for checkpoint_path in run_path.glob(
            checkpoints_format.format(extension=extension)
        )
    ]
    # sort by the parsed iteration number; sorting paths as strings would
    # put e.g. `iter_10` before `iter_2`
    checkpoints.sort()
    return checkpoints

get checkpoints of the format from the run_path

note that checkpoints_format should contain a glob pattern with:

  • unresolved "{extension}" format term for the extension
  • a wildcard for the iteration number
def register_method( method_dict: dict[str, typing.Callable[..., typing.Any]], custom_name: Optional[str] = None) -> Callable[[~F], ~F]:
def register_method(
    method_dict: dict[str, Callable[..., Any]],
    custom_name: Optional[str] = None,
) -> Callable[[F], F]:
    """Decorator to add a method to the method_dict

    # Parameters:
     - `method_dict : dict[str, Callable[..., Any]]`
        registry the decorated callable is inserted into (mutated in place)
     - `custom_name : Optional[str]`
        key to register under; when given, it is also assigned to the callable's
        `__name__`. When `None`, the callable's `__name__` is used, or a
        sanitized `repr` if it has none.
        (defaults to `None`)

    # Returns:
     - `Callable[[F], F]`
        decorator that registers its argument and returns it (renamed only
        when `custom_name` is given)

    # Raises:
     - `AssertionError` : if the resolved name is already a key of `method_dict`
    """

    def decorator(method: F) -> F:
        method_name: str
        if custom_name is None:
            method_name_orig: str | None = getattr(method, "__name__", None)
            if method_name_orig is None:
                warnings.warn(
                    f"Method {method} does not have a name, using sanitized repr"
                )
                from muutils.misc import sanitize_identifier

                method_name = sanitize_identifier(repr(method))
            else:
                method_name = method_name_orig
        else:
            method_name = custom_name
        # check for duplicates *before* touching `method`, so a rejected
        # registration leaves the callable unmodified
        assert method_name not in method_dict, (
            f"Method name already exists in method_dict: {method_name = }, {list(method_dict.keys()) = }"
        )
        if custom_name is not None:
            method.__name__ = custom_name
        method_dict[method_name] = method
        return method

    return decorator

Decorator to add a method to the method_dict

def pprint_summary(summary: dict):
def pprint_summary(summary: dict):
    """Print `summary` to stdout as JSON with a two-space indent."""
    rendered: str = json.dumps(summary, indent=2)
    print(rendered)