docs for muutils v0.8.12
View Source on GitHub

muutils.tensor_info

get metadata about a tensor, mostly for muutils.dbg


  1"get metadata about a tensor, mostly for `muutils.dbg`"
  2
  3from __future__ import annotations
  4
  5import numpy as np
  6from typing import Union, Any, Literal, List, Dict, overload, Optional
  7
  8# Global color definitions
  9COLORS: Dict[str, Dict[str, str]] = {
 10    "latex": {
 11        "range": r"\textcolor{purple}",
 12        "mean": r"\textcolor{teal}",
 13        "std": r"\textcolor{orange}",
 14        "median": r"\textcolor{green}",
 15        "warning": r"\textcolor{red}",
 16        "shape": r"\textcolor{magenta}",
 17        "dtype": r"\textcolor{gray}",
 18        "device": r"\textcolor{gray}",
 19        "requires_grad": r"\textcolor{gray}",
 20        "sparkline": r"\textcolor{blue}",
 21        "torch": r"\textcolor{orange}",
 22        "dtype_bool": r"\textcolor{gray}",
 23        "dtype_int": r"\textcolor{blue}",
 24        "dtype_float": r"\textcolor{red!70}",  # 70% red intensity
 25        "dtype_str": r"\textcolor{red}",
 26        "device_cuda": r"\textcolor{green}",
 27        "reset": "",
 28    },
 29    "terminal": {
 30        "range": "\033[35m",  # purple
 31        "mean": "\033[36m",  # cyan/teal
 32        "std": "\033[33m",  # yellow/orange
 33        "median": "\033[32m",  # green
 34        "warning": "\033[31m",  # red
 35        "shape": "\033[95m",  # bright magenta
 36        "dtype": "\033[90m",  # gray
 37        "device": "\033[90m",  # gray
 38        "requires_grad": "\033[90m",  # gray
 39        "sparkline": "\033[34m",  # blue
 40        "torch": "\033[38;5;208m",  # bright orange
 41        "dtype_bool": "\033[38;5;245m",  # medium grey
 42        "dtype_int": "\033[38;5;39m",  # bright blue
 43        "dtype_float": "\033[38;5;167m",  # softer red/coral
 44        "device_cuda": "\033[38;5;76m",  # NVIDIA-style bright green
 45        "reset": "\033[0m",
 46    },
 47    "none": {
 48        "range": "",
 49        "mean": "",
 50        "std": "",
 51        "median": "",
 52        "warning": "",
 53        "shape": "",
 54        "dtype": "",
 55        "device": "",
 56        "requires_grad": "",
 57        "sparkline": "",
 58        "torch": "",
 59        "dtype_bool": "",
 60        "dtype_int": "",
 61        "dtype_float": "",
 62        "dtype_str": "",
 63        "device_cuda": "",
 64        "reset": "",
 65    },
 66}
 67
 68OutputFormat = Literal["unicode", "latex", "ascii"]
 69
 70SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
 71    "latex": {
 72        "range": r"\mathcal{R}",
 73        "mean": r"\mu",
 74        "std": r"\sigma",
 75        "median": r"\tilde{x}",
 76        "distribution": r"\mathbb{P}",
 77        "distribution_log": r"\mathbb{P}_L",
 78        "nan_values": r"\text{NANvals}",
 79        "warning": "!!!",
 80        "requires_grad": r"\nabla",
 81        "true": r"\checkmark",
 82        "false": r"\times",
 83    },
 84    "unicode": {
 85        "range": "R",
 86        "mean": "μ",
 87        "std": "σ",
 88        "median": "x̃",
 89        "distribution": "ℙ",
 90        "distribution_log": "ℙ˪",
 91        "nan_values": "NANvals",
 92        "warning": "🚨",
 93        "requires_grad": "∇",
 94        "true": "✓",
 95        "false": "✗",
 96    },
 97    "ascii": {
 98        "range": "range",
 99        "mean": "mean",
100        "std": "std",
101        "median": "med",
102        "distribution": "dist",
103        "distribution_log": "dist_log",
104        "nan_values": "NANvals",
105        "warning": "!!!",
106        "requires_grad": "requires_grad",
107        "true": "1",
108        "false": "0",
109    },
110}
111"Symbols for different formats"
112
113SPARK_CHARS: Dict[OutputFormat, List[str]] = {
114    "unicode": list(" ▁▂▃▄▅▆▇█"),
115    "ascii": list(" _.-~=#"),
116    "latex": list(" ▁▂▃▄▅▆▇█"),
117}
118"characters for sparklines in different formats"
119
120
def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    Every step is best-effort: a failure while converting or measuring `A`
    leaves the affected fields as `None` instead of raising.

    # Parameters:
     - `A : array-like`
            Array to analyze (numpy array or torch tensor)
     - `hist_bins : int`
            Number of bins of the histogram stored under `"histogram"` / `"bins"`
            (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
            Dictionary containing raw statistical information with numeric values.
            `"status"` is one of `"ok"`, `"empty array"`, `"all NaN"` or
            `"error: <message>"`.
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Check if it's a tensor by looking at its class name
    # This avoids importing torch directly
    is_tensor: bool = type(A).__name__ == "Tensor"
    result["is_tensor"] = is_tensor

    # Try to get device information if it's a tensor
    if is_tensor:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to numpy array for calculations; any failure degrades to an
    # empty array, which is then reported as "empty array" below
    A_np: Any
    try:
        if is_tensor:
            is_cuda: bool = False
            try:
                is_cuda = bool(getattr(A, "is_cuda", False))
            except Exception:
                pass

            # Try to get a CPU tensor first, then detach and convert
            cpu_tensor = getattr(A, "cpu", lambda: A)() if is_cuda else A
            detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
            A_np = getattr(detached, "numpy", lambda: np.array([]))()
        else:
            # For numpy arrays and other array-like objects
            A_np = np.asarray(A)
    except Exception:
        A_np = np.array([])

    # Get basic information
    try:
        result["shape"] = A_np.shape
        result["dtype"] = str(A.dtype if is_tensor else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    # If array is empty, return early
    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Flatten array for statistics if it's multi-dimensional
    try:
        A_flat = A_np.flatten() if len(A_np.shape) > 1 else A_np
    except Exception:
        A_flat = A_np

    # Check for NaN values (stored as plain Python int/bool/float, not numpy scalars)
    try:
        nan_mask = np.isnan(A_flat)
        result["nan_count"] = int(np.sum(nan_mask))
        result["has_nans"] = result["nan_count"] > 0
        if result["size"] > 0:
            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
    except Exception:
        # e.g. dtypes for which `np.isnan` is undefined; NaN fields stay `None`
        pass

    # If all values are NaN, return early
    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    # Calculate statistics
    try:
        if result["has_nans"]:
            result["min"] = float(np.nanmin(A_flat))
            result["max"] = float(np.nanmax(A_flat))
            result["mean"] = float(np.nanmean(A_flat))
            result["std"] = float(np.nanstd(A_flat))
            result["median"] = float(np.nanmedian(A_flat))
            result["range"] = (result["min"], result["max"])

            # Remove NaNs for histogram
            A_hist = A_flat[~nan_mask]
        else:
            result["min"] = float(np.min(A_flat))
            result["max"] = float(np.max(A_flat))
            result["mean"] = float(np.mean(A_flat))
            result["std"] = float(np.std(A_flat))
            result["median"] = float(np.median(A_flat))
            result["range"] = (result["min"], result["max"])

            A_hist = A_flat

        # Calculate histogram data for sparklines
        if A_hist.size > 0:
            try:
                # `np.histogram` converts bool input to uint8 itself but emits a
                # RuntimeWarning while doing so; convert up front to stay quiet
                if A_hist.dtype == np.bool_:
                    A_hist = A_hist.astype(np.uint8)
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except Exception:
                # histogram is optional; leave it as `None` on failure
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result
276
277
def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
    - `histogram : np.ndarray`
        Histogram data (bin counts). Any 1D sequence of numbers is accepted
        and converted with `np.asarray`.
    - `format : Literal["unicode", "latex", "ascii"]`
        Output format; unknown formats fall back to the `"ascii"` characters
        (defaults to `"unicode"`)
    - `log_y : bool|None`
        Whether to use logarithmic y-scale. `None` for automatic detection
        (defaults to `None`)

    # Returns:
    - `tuple[str, bool]`
        Sparkline visualization and whether log scale was used.
        An empty or missing histogram yields `("", False)`.
    """
    if histogram is None or len(histogram) == 0:
        return "", False

    # Accept plain sequences as well as arrays; `.max()` below needs an ndarray
    counts = np.asarray(histogram)

    # Get the appropriate character set
    chars: List[str] = SPARK_CHARS.get(format, SPARK_CHARS["ascii"])

    # automatic detection of log_y
    if log_y is None:
        # we bin the histogram values to the number of levels in our sparkline characters
        hist_hist = np.histogram(counts, bins=len(chars))[0]
        # if every bin except the smallest (first) and largest (last) is empty,
        # then we should use the log scale. if those bins are nonempty, keep the linear scale
        log_y = not bool(hist_hist[1:-1].max() > 0)

    # log1p rather than log so that empty bins (count 0) stay at 0
    hist_data = np.log1p(counts) if log_y else counts

    # Normalize to character set range
    peak = hist_data.max()
    if peak > 0:
        normalized = hist_data / peak * (len(chars) - 1)
    else:
        normalized = np.zeros_like(hist_data)

    # Convert to characters; `np.rint` rounds half to even, like the builtin `round`
    levels = np.rint(normalized).astype(int)
    spark = "".join(chars[level] for level in levels)

    return spark, bool(log_y)
339
340
# Values `array_summary` falls back to for every argument left at `_USE_DEFAULT`.
DEFAULT_SETTINGS: Dict[str, Any] = {
    "fmt": "unicode",
    "precision": 2,
    "stats": True,
    "shape": True,
    "dtype": True,
    "device": True,
    "requires_grad": True,
    "sparkline": False,
    "sparkline_bins": 5,
    "sparkline_logy": None,
    "colored": False,
    "as_list": False,
    "eq_char": "=",
}
356
357
def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Wrap `text` in the color registered under `color_key`.

    # Parameters:
     - `text : str`
            Text to color
     - `color_key : str`
            Key into `colors` selecting the color prefix
     - `colors : Dict[str, str]`
            Color scheme; must contain `color_key`, and `"reset"` when a
            non-empty, non-TeX color is applied
     - `using_tex : bool`
            If `True`, emit `<prefix>{<text>}`; otherwise `<prefix><text><reset>`

    # Returns:
     - `str`
            The colored text, or `text` unchanged when the color prefix is empty
    """
    prefix: str = colors[color_key]
    if not prefix:
        return text
    if using_tex:
        return f"{prefix}{{{text}}}"
    return f"{prefix}{text}{colors['reset']}"
367
368
def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Colorize a dtype string, coloring a `torch.` prefix separately from the type name.

    The type name is colored as `dtype_bool`, `dtype_int` or `dtype_float`
    when the (lowercased) dtype string contains `bool`, `int` or `float`,
    checked in that order, and as plain `dtype` otherwise.
    """
    # A split into exactly two pieces means "torch." occurs exactly once;
    # only then is the prefix colored on its own
    pieces = dtype_str.split("torch.")
    has_torch_prefix: bool = len(pieces) == 2

    torch_colored: str = ""
    name_part: str = dtype_str
    if has_torch_prefix:
        torch_colored = apply_color("torch", "torch", colors, using_tex)
        name_part = pieces[1]

    # The first matching kind wins; the check looks at the full dtype string
    lowered = dtype_str.lower()
    key: str = next(
        (f"dtype_{kind}" for kind in ("bool", "int", "float") if kind in lowered),
        "dtype",
    )

    name_colored: str = apply_color(name_part, key, colors, using_tex)
    return f"{torch_colored}.{name_colored}" if has_torch_prefix else name_colored
396
397
def format_shape_colored(
    shape_val: tuple[int, ...], colors: Dict[str, str], using_tex: bool
) -> str:
    """Format shape with proper coloring for both 1D and multi-D arrays.

    # Parameters:
     - `shape_val`
            Shape to format (sequence of dimension sizes)
     - `colors : Dict[str, str]`
            Color scheme; the `"shape"` entry is used for every dimension
     - `using_tex : bool`
            Whether to emit TeX-style coloring, see `apply_color`

    # Returns:
     - `str`
            The single colored dimension for 1D shapes, otherwise the colored
            dimensions joined by `,` and wrapped in parentheses
    """
    # Uses the module-level `apply_color` rather than a nested copy of it
    if len(shape_val) == 1:
        # For 1D arrays, still color the dimension value
        return apply_color(str(shape_val[0]), "shape", colors, using_tex)

    # For multi-D arrays, color each dimension
    dims: str = ",".join(
        apply_color(str(dim), "shape", colors, using_tex) for dim in shape_val
    )
    return f"({dims})"
417
418
def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format device string with CUDA highlighting.

    Devices whose name contains `cuda` (case-insensitive) are colored with the
    `"device_cuda"` entry of `colors`; all others with the `"device"` entry.
    """
    # Uses the module-level `apply_color` rather than a nested copy of it
    color_key: str = "device_cuda" if "cuda" in device_str.lower() else "device"
    return apply_color(device_str, color_key, colors, using_tex)
438
439
class _UseDefaultType:
    """Type of the `_USE_DEFAULT` sentinel.

    `array_summary` uses the sentinel as the default of its arguments so it
    can tell "argument not passed" apart from any real value, including `None`.
    """

    def __repr__(self) -> str:
        # Readable in signatures and `help()` instead of an object address
        return "_USE_DEFAULT"


_USE_DEFAULT = _UseDefaultType()
445
446
# `as_list` is keyword-only in the overloads: in the implementation the second
# positional parameter is `fmt`, not `as_list`.
@overload
def array_summary(
    array: Any,
    *,
    as_list: Literal[True],
    **kwargs: Any,
) -> List[str]: ...
@overload
def array_summary(
    array: Any,
    *,
    as_list: Literal[False],
    **kwargs: Any,
) -> str: ...
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: Optional[bool] = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Any argument that is not passed takes its value from `DEFAULT_SETTINGS`
    at call time; the defaults quoted below are the initial contents of
    `DEFAULT_SETTINGS`.

    # Parameters:
     - `array`
            array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
            Output format (defaults to `"unicode"`)
     - `precision : int`
            Decimal places (defaults to `2`)
     - `stats : bool`
            Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
     - `shape : bool`
            Whether to include shape info (defaults to `True`)
     - `dtype : bool`
            Whether to include dtype info (defaults to `True`)
     - `device : bool`
            Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
            Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
            Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
            Number of histogram bins, i.e. characters in the sparkline (defaults to `5`)
     - `sparkline_logy : bool|None`
            Whether to use logarithmic y-scale for sparkline,
            `None` for automatic detection (defaults to `None`)
     - `colored : bool`
            Whether to add color to output (defaults to `False`)
     - `eq_char : str`
            String placed between each label and its value (defaults to `"="`)
     - `as_list : bool`
            Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
            Formatted statistical summary, either as string or list of strings
    """
    if fmt is _USE_DEFAULT:
        fmt = DEFAULT_SETTINGS["fmt"]
    if precision is _USE_DEFAULT:
        precision = DEFAULT_SETTINGS["precision"]
    if stats is _USE_DEFAULT:
        stats = DEFAULT_SETTINGS["stats"]
    if shape is _USE_DEFAULT:
        shape = DEFAULT_SETTINGS["shape"]
    if dtype is _USE_DEFAULT:
        dtype = DEFAULT_SETTINGS["dtype"]
    if device is _USE_DEFAULT:
        device = DEFAULT_SETTINGS["device"]
    if requires_grad is _USE_DEFAULT:
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if sparkline is _USE_DEFAULT:
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if sparkline_bins is _USE_DEFAULT:
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if sparkline_logy is _USE_DEFAULT:
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if colored is _USE_DEFAULT:
        colored = DEFAULT_SETTINGS["colored"]
    if as_list is _USE_DEFAULT:
        as_list = DEFAULT_SETTINGS["as_list"]
    if eq_char is _USE_DEFAULT:
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    # Helper function to colorize text with the active scheme
    def colorize(text: str, color_key: str) -> str:
        return apply_color(text, color_key, colors, using_tex)

    # Check if dtype is integer type
    # (the "dtype" key always exists but holds `None` when it could not be determined)
    dtype_str: str = array_data.get("dtype") or ""
    is_int_dtype: bool = any(
        int_type in dtype_str.lower() for int_type in ["int", "uint", "bool"]
    )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}{eq_char}{stat_colored}")

            # Range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                if is_int_dtype:
                    min_str: str = f"{int(min_val):d}"
                    max_str: str = f"{int(max_val):d}"
                else:
                    min_str = f"{min_val:{float_fmt}}"
                    max_str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = (
                    f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
                )
                result_parts.append(range_str)

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        # generate_sparkline reports whether the log scale was used,
        # which decides the symbol shown in front of the sparkline
        spark, used_log = generate_sparkline(
            array_data["histogram"],
            format=fmt,
            log_y=sparkline_logy,
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            dist_symbol = (
                symbols["distribution_log"] if used_log else symbols["distribution"]
            )
            result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    # Add shape if requested
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        shape_str = format_shape_colored(shape_val, colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype{eq_char}{dtype_colored}")

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored = format_device_colored(array_data["device"], colors, using_tex)
        result_parts.append(f"device{eq_char}{device_colored}")

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    else:
        joinchar: str = r" \quad " if using_tex else " "
        return joinchar.join(result_parts)

COLORS: Dict[str, Dict[str, str]] = {'latex': {'range': '\\textcolor{purple}', 'mean': '\\textcolor{teal}', 'std': '\\textcolor{orange}', 'median': '\\textcolor{green}', 'warning': '\\textcolor{red}', 'shape': '\\textcolor{magenta}', 'dtype': '\\textcolor{gray}', 'device': '\\textcolor{gray}', 'requires_grad': '\\textcolor{gray}', 'sparkline': '\\textcolor{blue}', 'torch': '\\textcolor{orange}', 'dtype_bool': '\\textcolor{gray}', 'dtype_int': '\\textcolor{blue}', 'dtype_float': '\\textcolor{red!70}', 'dtype_str': '\\textcolor{red}', 'device_cuda': '\\textcolor{green}', 'reset': ''}, 'terminal': {'range': '\x1b[35m', 'mean': '\x1b[36m', 'std': '\x1b[33m', 'median': '\x1b[32m', 'warning': '\x1b[31m', 'shape': '\x1b[95m', 'dtype': '\x1b[90m', 'device': '\x1b[90m', 'requires_grad': '\x1b[90m', 'sparkline': '\x1b[34m', 'torch': '\x1b[38;5;208m', 'dtype_bool': '\x1b[38;5;245m', 'dtype_int': '\x1b[38;5;39m', 'dtype_float': '\x1b[38;5;167m', 'dtype_str': '\x1b[31m', 'device_cuda': '\x1b[38;5;76m', 'reset': '\x1b[0m'}, 'none': {'range': '', 'mean': '', 'std': '', 'median': '', 'warning': '', 'shape': '', 'dtype': '', 'device': '', 'requires_grad': '', 'sparkline': '', 'torch': '', 'dtype_bool': '', 'dtype_int': '', 'dtype_float': '', 'dtype_str': '', 'device_cuda': '', 'reset': ''}}
OutputFormat = typing.Literal['unicode', 'latex', 'ascii']
SYMBOLS: Dict[Literal['unicode', 'latex', 'ascii'], Dict[str, str]] = {'latex': {'range': '\\mathcal{R}', 'mean': '\\mu', 'std': '\\sigma', 'median': '\\tilde{x}', 'distribution': '\\mathbb{P}', 'distribution_log': '\\mathbb{P}_L', 'nan_values': '\\text{NANvals}', 'warning': '!!!', 'requires_grad': '\\nabla', 'true': '\\checkmark', 'false': '\\times'}, 'unicode': {'range': 'R', 'mean': 'μ', 'std': 'σ', 'median': 'x̃', 'distribution': 'ℙ', 'distribution_log': 'ℙ˪', 'nan_values': 'NANvals', 'warning': '🚨', 'requires_grad': '∇', 'true': '✓', 'false': '✗'}, 'ascii': {'range': 'range', 'mean': 'mean', 'std': 'std', 'median': 'med', 'distribution': 'dist', 'distribution_log': 'dist_log', 'nan_values': 'NANvals', 'warning': '!!!', 'requires_grad': 'requires_grad', 'true': '1', 'false': '0'}}

Symbols for different formats

SPARK_CHARS: Dict[Literal['unicode', 'latex', 'ascii'], List[str]] = {'unicode': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'], 'ascii': [' ', '_', '.', '-', '~', '=', '#'], 'latex': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█']}

characters for sparklines in different formats

def array_info(A: Any, hist_bins: int = 5) -> Dict[str, Any]:
def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    Every step is best-effort: a failure while converting or measuring `A`
    leaves the affected fields as `None` instead of raising.

    # Parameters:
     - `A : array-like`
            Array to analyze (numpy array or torch tensor)
     - `hist_bins : int`
            Number of bins of the histogram stored under `"histogram"` / `"bins"`
            (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
            Dictionary containing raw statistical information with numeric values.
            `"status"` is one of `"ok"`, `"empty array"`, `"all NaN"` or
            `"error: <message>"`.
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Check if it's a tensor by looking at its class name
    # This avoids importing torch directly
    is_tensor: bool = type(A).__name__ == "Tensor"
    result["is_tensor"] = is_tensor

    # Try to get device information if it's a tensor
    if is_tensor:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to numpy array for calculations; any failure degrades to an
    # empty array, which is then reported as "empty array" below
    A_np: Any
    try:
        if is_tensor:
            is_cuda: bool = False
            try:
                is_cuda = bool(getattr(A, "is_cuda", False))
            except Exception:
                pass

            # Try to get a CPU tensor first, then detach and convert
            cpu_tensor = getattr(A, "cpu", lambda: A)() if is_cuda else A
            detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
            A_np = getattr(detached, "numpy", lambda: np.array([]))()
        else:
            # For numpy arrays and other array-like objects
            A_np = np.asarray(A)
    except Exception:
        A_np = np.array([])

    # Get basic information
    try:
        result["shape"] = A_np.shape
        result["dtype"] = str(A.dtype if is_tensor else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    # If array is empty, return early
    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Flatten array for statistics if it's multi-dimensional
    try:
        A_flat = A_np.flatten() if len(A_np.shape) > 1 else A_np
    except Exception:
        A_flat = A_np

    # Check for NaN values (stored as plain Python int/bool/float, not numpy scalars)
    try:
        nan_mask = np.isnan(A_flat)
        result["nan_count"] = int(np.sum(nan_mask))
        result["has_nans"] = result["nan_count"] > 0
        if result["size"] > 0:
            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
    except Exception:
        # e.g. dtypes for which `np.isnan` is undefined; NaN fields stay `None`
        pass

    # If all values are NaN, return early
    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    # Calculate statistics
    try:
        if result["has_nans"]:
            result["min"] = float(np.nanmin(A_flat))
            result["max"] = float(np.nanmax(A_flat))
            result["mean"] = float(np.nanmean(A_flat))
            result["std"] = float(np.nanstd(A_flat))
            result["median"] = float(np.nanmedian(A_flat))
            result["range"] = (result["min"], result["max"])

            # Remove NaNs for histogram
            A_hist = A_flat[~nan_mask]
        else:
            result["min"] = float(np.min(A_flat))
            result["max"] = float(np.max(A_flat))
            result["mean"] = float(np.mean(A_flat))
            result["std"] = float(np.std(A_flat))
            result["median"] = float(np.median(A_flat))
            result["range"] = (result["min"], result["max"])

            A_hist = A_flat

        # Calculate histogram data for sparklines
        if A_hist.size > 0:
            try:
                # `np.histogram` converts bool input to uint8 itself but emits a
                # RuntimeWarning while doing so; convert up front to stay quiet
                if A_hist.dtype == np.bool_:
                    A_hist = A_hist.astype(np.uint8)
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except Exception:
                # histogram is optional; leave it as `None` on failure
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result

Extract statistical information from an array-like object.

Parameters:

  • A : array-like Array to analyze (numpy array or torch tensor)

Returns:

  • Dict[str, Any] Dictionary containing raw statistical information with numeric values
def generate_sparkline( histogram: numpy.ndarray, format: Literal['unicode', 'latex', 'ascii'] = 'unicode', log_y: Optional[bool] = None) -> tuple[str, bool]:
def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
    - `histogram : np.ndarray`
        Histogram data (bin counts). Plain sequences are accepted as well and
        are converted with `np.asarray`
    - `format : Literal["unicode", "latex", "ascii"]`
        Output format; a value that is not a key of `SPARK_CHARS` falls back
        to the `"ascii"` character set (defaults to `"unicode"`)
    - `log_y : bool|None`
        Whether to use logarithmic y-scale. `None` for automatic detection
        (defaults to `None`)

    # Returns:
    - `tuple[str, bool]`
        Sparkline visualization and whether log scale was used.
        `("", False)` is returned when `histogram` is `None` or empty
    """
    if histogram is None or len(histogram) == 0:
        return "", False

    # Normalize the input once so that every path below can rely on ndarray
    # methods; previously a list only worked when the log branch happened to
    # convert it. For an ndarray this is a no-op (no copy).
    histogram = np.asarray(histogram)

    # Get the appropriate character set
    chars: List[str] = SPARK_CHARS[format] if format in SPARK_CHARS else SPARK_CHARS["ascii"]

    # automatic detection of log_y
    if log_y is None:
        # we bin the histogram values to the number of levels in our sparkline characters
        level_counts = np.histogram(histogram, bins=len(chars))[0]
        # if every bin except the smallest (first) and largest (last) is empty,
        # then we should use the log scale. if those bins are nonempty, keep the linear scale.
        # With fewer than three levels there are no interior bins at all; the
        # condition is then vacuously true instead of raising on an empty `.max()`.
        interior = level_counts[1:-1]
        log_y = not (interior.size > 0 and interior.max() > 0)

    # Handle log scale; log1p(x) = log(1 + x), so zero counts stay at zero
    hist_data = np.log1p(histogram) if log_y else histogram

    # Normalize to character set range
    peak = hist_data.max()
    if peak > 0:
        normalized = hist_data / peak * (len(chars) - 1)
    else:
        # all-zero histogram: every position maps to the lowest character
        normalized = np.zeros_like(hist_data)

    # Convert to characters
    spark = "".join(chars[round(val)] for val in normalized)

    return spark, log_y

Generate a sparkline visualization of the histogram.

Parameters:

  • histogram : np.ndarray Histogram data
  • format : Literal["unicode", "latex", "ascii"] Output format (defaults to "unicode")
  • log_y : bool|None Whether to use logarithmic y-scale. None for automatic detection (defaults to None)

Returns:

  • tuple[str, bool] Sparkline visualization and whether log scale was used
DEFAULT_SETTINGS: Dict[str, Any] = {'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': False, 'sparkline_bins': 5, 'sparkline_logy': None, 'colored': False, 'as_list': False, 'eq_char': '='}
def apply_color( text: str, color_key: str, colors: Dict[str, str], using_tex: bool) -> str:
def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Wrap `text` in the color registered under `color_key`.

    # Parameters:
    - `text : str`
        Text to colorize
    - `color_key : str`
        Key into `colors`; a missing key raises `KeyError`
    - `colors : Dict[str, str]`
        Mapping from color key to color code. An empty code means "no color"
    - `using_tex : bool`
        If `True`, emit `<code>{<text>}`; otherwise emit
        `<code><text><colors['reset']>`

    # Returns:
    - `str`
        The colorized text, or `text` unchanged when the code is empty
    """
    code = colors[color_key]
    if not code:
        return text
    if using_tex:
        return f"{code}{{{text}}}"
    return f"{code}{text}{colors['reset']}"
def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Colorize dtype string with specific colors for torch and type names.

    # Parameters:
    - `dtype_str : str`
        dtype name, e.g. `"float32"` or `"torch.int64"`
    - `colors : Dict[str, str]`
        Mapping from color key to color code
    - `using_tex : bool`
        Whether to emit LaTeX-style coloring

    # Returns:
    - `str`
        The colored dtype. When `dtype_str` contains `"torch."` exactly once,
        the result is the colored word `torch`, a dot, and the colored text
        that followed the marker
    """
    # Exactly one "torch." marker <=> splitting on it would give two parts.
    _before, marker, after = dtype_str.partition("torch.")
    torch_colored: Optional[str] = None
    name_part: str = dtype_str
    if marker and "torch." not in after:
        torch_colored = apply_color("torch", "torch", colors, using_tex)
        name_part = after

    # First matching family wins; the check runs on the full, lowercased string.
    lowered: str = dtype_str.lower()
    family_keys = (
        ("bool", "dtype_bool"),
        ("int", "dtype_int"),
        ("float", "dtype_float"),
    )
    kind_key: str = next(
        (key for needle, key in family_keys if needle in lowered),
        "dtype",
    )

    name_colored: str = apply_color(name_part, kind_key, colors, using_tex)
    return f"{torch_colored}.{name_colored}" if torch_colored else name_colored

Colorize dtype string with specific colors for torch and type names.

def format_shape_colored(shape_val, colors: Dict[str, str], using_tex: bool) -> str:
def format_shape_colored(shape_val, colors: Dict[str, str], using_tex: bool) -> str:
    """Format shape with proper coloring for both 1D and multi-D arrays.

    # Parameters:
    - `shape_val`
        Sized, indexable sequence of dimensions
    - `colors : Dict[str, str]`
        Mapping from color key to color code; the `"shape"` entry is used
    - `using_tex : bool`
        Whether to emit LaTeX-style coloring

    # Returns:
    - `str`
        The single colored dimension for a length-1 shape, otherwise the
        colored dimensions joined by `,` inside parentheses
    """

    def paint(dim) -> str:
        # color one dimension; an empty color code leaves the text untouched
        dim_text = str(dim)
        code = colors["shape"]
        if not code:
            return dim_text
        if using_tex:
            return f"{code}{{{dim_text}}}"
        return f"{code}{dim_text}{colors['reset']}"

    if len(shape_val) == 1:
        # 1D arrays are shown as a bare (still colored) number
        return paint(shape_val[0])

    # every other rank: color each dimension separately
    return "(" + ",".join(map(paint, shape_val)) + ")"

Format shape with proper coloring for both 1D and multi-D arrays.

def format_device_colored(device_str: str, colors: Dict[str, str], using_tex: bool) -> str:
def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format device string with CUDA highlighting.

    # Parameters:
    - `device_str : str`
        Device name, e.g. `"cpu"` or `"cuda:0"`
    - `colors : Dict[str, str]`
        Mapping from color key to color code; `"device_cuda"` is used when
        the name contains `cuda` (case-insensitive), `"device"` otherwise
    - `using_tex : bool`
        Whether to emit LaTeX-style coloring

    # Returns:
    - `str`
        The colored device string, or `device_str` unchanged when the
        selected color code is empty
    """
    key: str = "device_cuda" if "cuda" in device_str.lower() else "device"
    code: str = colors[key]
    if not code:
        return device_str
    if using_tex:
        return f"{code}{{{device_str}}}"
    return f"{code}{device_str}{colors['reset']}"

Format device string with CUDA highlighting.

def array_summary( array, fmt: Literal['unicode', 'latex', 'ascii'] = <muutils.tensor_info._UseDefaultType object>, precision: int = <muutils.tensor_info._UseDefaultType object>, stats: bool = <muutils.tensor_info._UseDefaultType object>, shape: bool = <muutils.tensor_info._UseDefaultType object>, dtype: bool = <muutils.tensor_info._UseDefaultType object>, device: bool = <muutils.tensor_info._UseDefaultType object>, requires_grad: bool = <muutils.tensor_info._UseDefaultType object>, sparkline: bool = <muutils.tensor_info._UseDefaultType object>, sparkline_bins: int = <muutils.tensor_info._UseDefaultType object>, sparkline_logy: Optional[bool] = <muutils.tensor_info._UseDefaultType object>, colored: bool = <muutils.tensor_info._UseDefaultType object>, eq_char: str = <muutils.tensor_info._UseDefaultType object>, as_list: bool = <muutils.tensor_info._UseDefaultType object>) -> Union[str, List[str]]:
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: Optional[bool] = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Every option left at its sentinel default is looked up in
    `DEFAULT_SETTINGS` at call time; the defaults quoted below are the
    values `DEFAULT_SETTINGS` ships with.

    # Parameters:
     - `array`
            array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
            Output format (defaults to `"unicode"`)
     - `precision : int`
            Decimal places (defaults to `2`)
     - `stats : bool`
            Whether to include statistical info (μ, σ, x̃) (defaults to `True`)
     - `shape : bool`
            Whether to include shape info (defaults to `True`)
     - `dtype : bool`
            Whether to include dtype info (defaults to `True`)
     - `device : bool`
            Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
            Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
            Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
            Number of histogram bins, i.e. characters in the sparkline (defaults to `5`)
     - `sparkline_logy : bool|None`
            Whether to use logarithmic y-scale for sparkline,
            `None` for automatic detection (defaults to `None`)
     - `colored : bool`
            Whether to add color to output (defaults to `False`)
     - `eq_char : str`
            Separator placed between each label and its value (defaults to `"="`)
     - `as_list : bool`
            Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
            Formatted statistical summary, either as string or list of strings
    """
    if fmt is _USE_DEFAULT:
        fmt = DEFAULT_SETTINGS["fmt"]
    if precision is _USE_DEFAULT:
        precision = DEFAULT_SETTINGS["precision"]
    if stats is _USE_DEFAULT:
        stats = DEFAULT_SETTINGS["stats"]
    if shape is _USE_DEFAULT:
        shape = DEFAULT_SETTINGS["shape"]
    if dtype is _USE_DEFAULT:
        dtype = DEFAULT_SETTINGS["dtype"]
    if device is _USE_DEFAULT:
        device = DEFAULT_SETTINGS["device"]
    if requires_grad is _USE_DEFAULT:
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if sparkline is _USE_DEFAULT:
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if sparkline_bins is _USE_DEFAULT:
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if sparkline_logy is _USE_DEFAULT:
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if colored is _USE_DEFAULT:
        colored = DEFAULT_SETTINGS["colored"]
    if as_list is _USE_DEFAULT:
        as_list = DEFAULT_SETTINGS["as_list"]
    if eq_char is _USE_DEFAULT:
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    # Helper function to colorize text with the scheme selected above;
    # delegates to the module-level `apply_color` instead of duplicating it
    def colorize(text: str, color_key: str) -> str:
        return apply_color(text, color_key, colors, using_tex)

    # Check if dtype is integer-like, in which case the range is printed
    # without decimals ("uint" is already covered by the "int" substring)
    dtype_str: str = array_data.get("dtype", "")
    is_int_dtype: bool = any(
        int_type in dtype_str.lower() for int_type in ["int", "uint", "bool"]
    )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Degenerate inputs get a single warning entry instead of statistics.
    # NOTE(review): an "error: ..." status from `array_info` is not in this
    # list, so it takes the statistics branch below -- confirm that is intended.
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            # a bare "%" would start a comment in LaTeX
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    # use eq_char like every other labelled field
                    result_parts.append(
                        f"{symbols[stat_key]}{eq_char}{stat_colored}"
                    )

            # Range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                min_str: str
                max_str: str
                if is_int_dtype:
                    min_str = f"{int(min_val):d}"
                    max_str = f"{int(max_val):d}"
                else:
                    min_str = f"{min_val:{float_fmt}}"
                    max_str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = (
                    f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
                )
                result_parts.append(range_str)

    # Add sparkline if requested
    if sparkline and array_data["histogram"] is not None:
        # generate_sparkline reports whether the log scale was used, which
        # decides the symbol shown in front of the sparkline
        spark, used_log = generate_sparkline(
            array_data["histogram"],
            format=fmt,
            log_y=sparkline_logy,
        )
        if spark:
            spark_colored = colorize(spark, "sparkline")
            dist_symbol = (
                symbols["distribution_log"] if used_log else symbols["distribution"]
            )
            result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    # Add shape if requested (an empty shape is falsy and therefore omitted)
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        shape_str = format_shape_colored(shape_val, colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype{eq_char}{dtype_colored}")

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored = format_device_colored(array_data["device"], colors, using_tex)
        result_parts.append(f"device{eq_char}{device_colored}")

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    else:
        joinchar: str = r" \quad " if using_tex else " "
        return joinchar.join(result_parts)

Format array information into a readable summary.

Parameters:

  • array array-like object (numpy array or torch tensor)
  • precision : int Decimal places (defaults to 2)
  • fmt : Literal["unicode", "latex", "ascii"] Output format (defaults to "unicode")
  • stats : bool Whether to include statistical info (μ, σ, x̃) (defaults to True)
  • shape : bool Whether to include shape info (defaults to True)
  • dtype : bool Whether to include dtype info (defaults to True)
  • device : bool Whether to include device info for torch tensors (defaults to True)
  • requires_grad : bool Whether to include requires_grad info for torch tensors (defaults to True)
  • sparkline : bool Whether to include a sparkline visualization (defaults to False)
  • sparkline_bins : int Number of histogram bins used for the sparkline (defaults to 5)
  • sparkline_logy : bool|None Whether to use logarithmic y-scale for sparkline (defaults to None)
  • colored : bool Whether to add color to output (defaults to False)
  • as_list : bool Whether to return as list of strings instead of joined string (defaults to False)

Returns:

  • Union[str, List[str]] Formatted statistical summary, either as string or list of strings