muutils.tensor_info
Get metadata about a tensor, mostly for `muutils.dbg`.
"get metadata about a tensor, mostly for `muutils.dbg`"

from __future__ import annotations

import numpy as np
from typing import Union, Any, Literal, List, Dict, overload, Optional

# Global color definitions
COLORS: Dict[str, Dict[str, str]] = {
    "latex": {
        "range": r"\textcolor{purple}",
        "mean": r"\textcolor{teal}",
        "std": r"\textcolor{orange}",
        "median": r"\textcolor{green}",
        "warning": r"\textcolor{red}",
        "shape": r"\textcolor{magenta}",
        "dtype": r"\textcolor{gray}",
        "device": r"\textcolor{gray}",
        "requires_grad": r"\textcolor{gray}",
        "sparkline": r"\textcolor{blue}",
        "torch": r"\textcolor{orange}",
        "dtype_bool": r"\textcolor{gray}",
        "dtype_int": r"\textcolor{blue}",
        "dtype_float": r"\textcolor{red!70}",  # 70% red intensity
        "dtype_str": r"\textcolor{red}",
        "device_cuda": r"\textcolor{green}",
        "reset": "",
    },
    "terminal": {
        "range": "\033[35m",  # purple
        "mean": "\033[36m",  # cyan/teal
        "std": "\033[33m",  # yellow/orange
        "median": "\033[32m",  # green
        "warning": "\033[31m",  # red
        "shape": "\033[95m",  # bright magenta
        "dtype": "\033[90m",  # gray
        "device": "\033[90m",  # gray
        "requires_grad": "\033[90m",  # gray
        "sparkline": "\033[34m",  # blue
        "torch": "\033[38;5;208m",  # bright orange
        "dtype_bool": "\033[38;5;245m",  # medium grey
        "dtype_int": "\033[38;5;39m",  # bright blue
        "dtype_float": "\033[38;5;167m",  # softer red/coral
        "dtype_str": "\033[31m",  # red, mirrors the latex palette
        "device_cuda": "\033[38;5;76m",  # NVIDIA-style bright green
        "reset": "\033[0m",
    },
    "none": {
        "range": "",
        "mean": "",
        "std": "",
        "median": "",
        "warning": "",
        "shape": "",
        "dtype": "",
        "device": "",
        "requires_grad": "",
        "sparkline": "",
        "torch": "",
        "dtype_bool": "",
        "dtype_int": "",
        "dtype_float": "",
        "dtype_str": "",
        "device_cuda": "",
        "reset": "",
    },
}

OutputFormat = Literal["unicode", "latex", "ascii"]

SYMBOLS: Dict[OutputFormat, Dict[str, str]] = {
    "latex": {
        "range": r"\mathcal{R}",
        "mean": r"\mu",
        "std": r"\sigma",
        "median": r"\tilde{x}",
        "distribution": r"\mathbb{P}",
        "distribution_log": r"\mathbb{P}_L",
        "nan_values": r"\text{NANvals}",
        "warning": "!!!",
        "requires_grad": r"\nabla",
        "true": r"\checkmark",
        "false": r"\times",
    },
    "unicode": {
        "range": "R",
        "mean": "μ",
        "std": "σ",
        "median": "x̃",
        "distribution": "ℙ",
        "distribution_log": "ℙ˪",
        "nan_values": "NANvals",
        "warning": "🚨",
        "requires_grad": "∇",
        "true": "✓",
        "false": "✗",
    },
    "ascii": {
        "range": "range",
        "mean": "mean",
        "std": "std",
        "median": "med",
        "distribution": "dist",
        "distribution_log": "dist_log",
        "nan_values": "NANvals",
        "warning": "!!!",
        "requires_grad": "requires_grad",
        "true": "1",
        "false": "0",
    },
}
"Symbols for different formats"

SPARK_CHARS: Dict[OutputFormat, List[str]] = {
    "unicode": list(" ▁▂▃▄▅▆▇█"),
    "ascii": list(" _.-~=#"),
    "latex": list(" ▁▂▃▄▅▆▇█"),
}
"characters for sparklines in different formats"


def _tensor_to_numpy(A: Any) -> np.ndarray:
    """Best-effort conversion of a torch-style tensor to a CPU numpy array.

    Moves CUDA tensors to the CPU, detaches them from the autograd graph and
    calls `.numpy()`. Returns an empty array when every conversion attempt fails.
    """
    try:
        is_cuda: bool = bool(getattr(A, "is_cuda", False))
    except Exception:
        is_cuda = False

    try:
        cpu_tensor = getattr(A, "cpu", lambda: A)() if is_cuda else A
        detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
    except Exception:
        return np.array([])

    try:
        return getattr(detached, "numpy", lambda: np.array([]))()
    except Exception:
        pass

    # `.numpy()` raises for dtypes numpy has no equivalent of (e.g.
    # `torch.bfloat16`); upcast to float32 so statistics can still be computed
    # instead of misreporting the tensor as empty
    try:
        return detached.float().numpy()
    except Exception:
        return np.array([])


def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    Torch tensors are recognized by their class name (`Tensor`), so torch is
    never imported here. Tensors are moved to the CPU and detached before being
    converted to numpy; if conversion fails entirely, the data is treated as an
    empty array.

    # Parameters:
     - `A : array-like`
        Array to analyze (numpy array or torch tensor)
     - `hist_bins : int`
        Number of bins of the histogram computed for sparklines
        (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values.
        Keys: `is_tensor`, `device`, `requires_grad`, `shape`, `dtype`, `size`,
        `has_nans`, `nan_count`, `nan_percent`, `min`, `max`, `range`, `mean`,
        `std`, `median`, `histogram`, `bins`, `status`. Values that could not
        be computed stay `None`. `status` is one of `"empty array"`,
        `"all NaN"`, `"ok"` or `"error: <message>"`.
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Detect tensors by class name, which avoids importing torch
    result["is_tensor"] = type(A).__name__ == "Tensor"

    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to a numpy array for the calculations below
    A_np: np.ndarray
    try:
        A_np = _tensor_to_numpy(A) if result["is_tensor"] else np.asarray(A)
    except Exception:
        A_np = np.array([])

    # Basic information; anything that fails to resolve stays `None`
    try:
        result["shape"] = A_np.shape
        # report the tensor's own dtype (e.g. `torch.float32`), not numpy's
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Statistics are computed over a flat (1D) view of the data
    try:
        A_flat = np.ravel(A_np)
    except Exception:
        A_flat = A_np

    # Count NaNs; non-numeric dtypes make `np.isnan` raise and leave these `None`
    nan_mask = None
    try:
        nan_mask = np.isnan(A_flat)
        result["nan_count"] = np.sum(nan_mask)
        result["has_nans"] = result["nan_count"] > 0
        if result["size"] > 0:
            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
    except Exception:
        pass

    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    try:
        # NaN-aware reductions are only needed (and the mask only valid) with NaNs
        if result["has_nans"]:
            reducers = (np.nanmin, np.nanmax, np.nanmean, np.nanstd, np.nanmedian)
            A_hist = A_flat[~nan_mask]
        else:
            reducers = (np.min, np.max, np.mean, np.std, np.median)
            A_hist = A_flat

        for key, reduce_fn in zip(("min", "max", "mean", "std", "median"), reducers):
            result[key] = float(reduce_fn(A_flat))
        result["range"] = (result["min"], result["max"])

        # Histogram data for sparklines
        if A_hist.size > 0:
            try:
                if A_hist.dtype == np.bool_:
                    # np.histogram casts bool input to uint8 itself but emits a
                    # RuntimeWarning; casting up front yields the same counts
                    A_hist = A_hist.astype(np.uint8)
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except Exception:
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result


def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
     - `histogram : np.ndarray`
        Histogram counts (any 1D sequence of non-negative numbers)
     - `format : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`); unknown formats fall back to
        the ascii character set
     - `log_y : bool|None`
        Whether to use logarithmic y-scale. `None` for automatic detection
        (defaults to `None`)

    # Returns:
     - `tuple[str, bool]`
        Sparkline visualization and whether log scale was used
    """
    if histogram is None or len(histogram) == 0:
        return "", False

    counts: np.ndarray = np.asarray(histogram)

    chars: List[str] = (
        SPARK_CHARS[format] if format in SPARK_CHARS else SPARK_CHARS["ascii"]
    )

    if log_y is None:
        # Bin the counts into as many levels as there are sparkline characters.
        # If only the lowest and highest levels are populated, a linear scale
        # would hide everything in between, so switch to log scale.
        level_counts = np.histogram(counts, bins=len(chars))[0]
        log_y = not level_counts[1:-1].max() > 0

    # log1p keeps empty bins at 0 instead of -inf
    hist_data = np.log1p(counts) if log_y else counts

    # Scale so the tallest bin maps to the last (fullest) character
    peak = hist_data.max()
    if peak > 0:
        normalized = hist_data / peak * (len(chars) - 1)
    else:
        normalized = np.zeros_like(hist_data)

    spark: str = "".join(chars[round(val)] for val in normalized)
    return spark, log_y


DEFAULT_SETTINGS: Dict[str, Any] = dict(
    fmt="unicode",
    precision=2,
    stats=True,
    shape=True,
    dtype=True,
    device=True,
    requires_grad=True,
    sparkline=False,
    sparkline_bins=5,
    sparkline_logy=None,
    colored=False,
    as_list=False,
    eq_char="=",
)


def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Wrap `text` in the color registered under `color_key` in `colors`.

    In TeX mode the color command is applied as `\\textcolor{...}{text}`;
    otherwise the terminal code is prepended and `colors['reset']` appended.
    An empty color leaves `text` unchanged.
    """
    if using_tex:
        return f"{colors[color_key]}{{{text}}}" if colors[color_key] else text
    else:
        return (
            f"{colors[color_key]}{text}{colors['reset']}" if colors[color_key] else text
        )


def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Colorize dtype string with specific colors for torch and type names.

    A single `torch.` prefix is colored with the `"torch"` color; the type name
    is colored by family: bool, int (including uint), float, or generic dtype.
    """
    lowered: str = dtype_str.lower()
    if "bool" in lowered:
        color_key = "dtype_bool"
    elif "int" in lowered:
        color_key = "dtype_int"
    elif "float" in lowered:
        color_key = "dtype_float"
    else:
        color_key = "dtype"

    pieces: List[str] = dtype_str.split("torch.")
    if len(pieces) == 2:
        prefix: str = apply_color("torch", "torch", colors, using_tex)
        return f"{prefix}.{apply_color(pieces[1], color_key, colors, using_tex)}"
    return apply_color(dtype_str, color_key, colors, using_tex)


def format_shape_colored(shape_val, colors: Dict[str, str], using_tex: bool) -> str:
    """Format shape with proper coloring for both 1D and multi-D arrays.

    A 1D shape is rendered as the bare dimension (e.g. `5`); any other shape as
    a parenthesized, comma-separated list (e.g. `(2,3)`). Every dimension gets
    the `"shape"` color.
    """
    dims: List[str] = [
        apply_color(str(dim), "shape", colors, using_tex) for dim in shape_val
    ]
    if len(shape_val) == 1:
        return dims[0]
    return "(" + ",".join(dims) + ")"


def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format device string with CUDA highlighting.

    Devices whose name contains `cuda` use the `"device_cuda"` color, all
    others the `"device"` color.
    """
    color_key: str = "device_cuda" if "cuda" in device_str.lower() else "device"
    return apply_color(device_str, color_key, colors, using_tex)


class _UseDefaultType:
    """Sentinel type for "argument not passed": `array_summary` replaces any
    argument equal to `_USE_DEFAULT` with the matching `DEFAULT_SETTINGS` entry."""

    def __repr__(self) -> str:
        # keeps rendered signatures readable instead of showing an object address
        return "_USE_DEFAULT"


_USE_DEFAULT = _UseDefaultType()
@overload
def array_summary(
    array: Any,
    as_list: Literal[True],
    **kwargs,
) -> List[str]: ...
@overload
def array_summary(
    array: Any,
    as_list: Literal[False],
    **kwargs,
) -> str: ...
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: Optional[bool] = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Any argument that is not passed is read from `DEFAULT_SETTINGS` at call
    time; the defaults listed below are that dict's initial contents.

    # Parameters:
     - `array`
        array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`)
     - `precision : int`
        Decimal places (defaults to `2`)
     - `stats : bool`
        Whether to include statistical info (μ, σ, x̃) and the range
        (defaults to `True`)
     - `shape : bool`
        Whether to include shape info (defaults to `True`)
     - `dtype : bool`
        Whether to include dtype info (defaults to `True`)
     - `device : bool`
        Whether to include device info for torch tensors (defaults to `True`)
     - `requires_grad : bool`
        Whether to include requires_grad info for torch tensors (defaults to `True`)
     - `sparkline : bool`
        Whether to include a sparkline visualization (defaults to `False`)
     - `sparkline_bins : int`
        Number of histogram bins, i.e. characters, in the sparkline (defaults to `5`)
     - `sparkline_logy : bool|None`
        Whether to use logarithmic y-scale for sparkline; `None` detects it
        automatically (defaults to `None`)
     - `colored : bool`
        Whether to add color to output (defaults to `False`)
     - `eq_char : str`
        Separator between label and value for the NaN count, sparkline,
        shape and device entries (defaults to `"="`)
     - `as_list : bool`
        Whether to return as list of strings instead of joined string (defaults to `False`)

    # Returns:
     - `Union[str, List[str]]`
        Formatted statistical summary, either as string or list of strings
    """

    def _resolve(value: Any, key: str) -> Any:
        # substitute the configured default for any argument left unspecified
        return DEFAULT_SETTINGS[key] if value is _USE_DEFAULT else value

    fmt = _resolve(fmt, "fmt")
    precision = _resolve(precision, "precision")
    stats = _resolve(stats, "stats")
    shape = _resolve(shape, "shape")
    dtype = _resolve(dtype, "dtype")
    device = _resolve(device, "device")
    requires_grad = _resolve(requires_grad, "requires_grad")
    sparkline = _resolve(sparkline, "sparkline")
    sparkline_bins = _resolve(sparkline_bins, "sparkline_bins")
    sparkline_logy = _resolve(sparkline_logy, "sparkline_logy")
    colored = _resolve(colored, "colored")
    as_list = _resolve(as_list, "as_list")
    eq_char = _resolve(eq_char, "eq_char")

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Pick the color scheme: none when uncolored, else by output target
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    symbols: Dict[str, str] = SYMBOLS[fmt]

    def colorize(text: str, color_key: str) -> str:
        return apply_color(text, color_key, colors, using_tex)

    # `dtype` may be present but `None` if array_info could not resolve it
    dtype_str: str = array_data.get("dtype") or ""
    # "int" also matches unsigned dtypes such as "uint8"
    is_int_dtype: bool = any(
        int_type in dtype_str.lower() for int_type in ("int", "bool")
    )

    float_fmt: str = f".{precision}f"

    # Handle error status or empty array: show the status instead of statistics
    status = array_data["status"]
    if (
        status in ("empty array", "all NaN", "unknown")
        or str(status).startswith("error")
        or array_data["size"] == 0
    ):
        result_parts.append(colorize(f"{symbols['warning']} {status}", "warning"))
    else:
        # NaN warning goes first so it is not missed
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        if stats:
            for stat_key in ("mean", "std", "median"):
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    result_parts.append(f"{symbols[stat_key]}={stat_colored}")

            # Range (min, max); integer-like dtypes are shown without decimals
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                if is_int_dtype:
                    min_str: str = f"{int(min_val):d}"
                    max_str: str = f"{int(max_val):d}"
                else:
                    min_str = f"{min_val:{float_fmt}}"
                    max_str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = f"{symbols['range']}=[{min_colored},{max_colored}]"
                result_parts.append(range_str)

        if sparkline and array_data["histogram"] is not None:
            # the symbol reflects whether a log scale was actually used
            spark, used_log = generate_sparkline(
                array_data["histogram"],
                format=fmt,
                log_y=sparkline_logy,
            )
            if spark:
                spark_colored = colorize(spark, "sparkline")
                dist_symbol = (
                    symbols["distribution_log"] if used_log else symbols["distribution"]
                )
                result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    if shape and array_data["shape"]:
        shape_str = format_shape_colored(array_data["shape"], colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    if dtype and array_data["dtype"]:
        dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype={dtype_colored}")

    # Device and gradient info only exist for tensors
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored = format_device_colored(array_data["device"], colors, using_tex)
        result_parts.append(f"device{eq_char}{device_colored}")

    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    if as_list:
        return result_parts
    joinchar: str = r" \quad " if using_tex else " "
    return joinchar.join(result_parts)
COLORS: Dict[str, Dict[str, str]] =
{'latex': {'range': '\\textcolor{purple}', 'mean': '\\textcolor{teal}', 'std': '\\textcolor{orange}', 'median': '\\textcolor{green}', 'warning': '\\textcolor{red}', 'shape': '\\textcolor{magenta}', 'dtype': '\\textcolor{gray}', 'device': '\\textcolor{gray}', 'requires_grad': '\\textcolor{gray}', 'sparkline': '\\textcolor{blue}', 'torch': '\\textcolor{orange}', 'dtype_bool': '\\textcolor{gray}', 'dtype_int': '\\textcolor{blue}', 'dtype_float': '\\textcolor{red!70}', 'dtype_str': '\\textcolor{red}', 'device_cuda': '\\textcolor{green}', 'reset': ''}, 'terminal': {'range': '\x1b[35m', 'mean': '\x1b[36m', 'std': '\x1b[33m', 'median': '\x1b[32m', 'warning': '\x1b[31m', 'shape': '\x1b[95m', 'dtype': '\x1b[90m', 'device': '\x1b[90m', 'requires_grad': '\x1b[90m', 'sparkline': '\x1b[34m', 'torch': '\x1b[38;5;208m', 'dtype_bool': '\x1b[38;5;245m', 'dtype_int': '\x1b[38;5;39m', 'dtype_float': '\x1b[38;5;167m', 'device_cuda': '\x1b[38;5;76m', 'reset': '\x1b[0m'}, 'none': {'range': '', 'mean': '', 'std': '', 'median': '', 'warning': '', 'shape': '', 'dtype': '', 'device': '', 'requires_grad': '', 'sparkline': '', 'torch': '', 'dtype_bool': '', 'dtype_int': '', 'dtype_float': '', 'dtype_str': '', 'device_cuda': '', 'reset': ''}}
OutputFormat =
typing.Literal['unicode', 'latex', 'ascii']
SYMBOLS: Dict[Literal['unicode', 'latex', 'ascii'], Dict[str, str]] =
{'latex': {'range': '\\mathcal{R}', 'mean': '\\mu', 'std': '\\sigma', 'median': '\\tilde{x}', 'distribution': '\\mathbb{P}', 'distribution_log': '\\mathbb{P}_L', 'nan_values': '\\text{NANvals}', 'warning': '!!!', 'requires_grad': '\\nabla', 'true': '\\checkmark', 'false': '\\times'}, 'unicode': {'range': 'R', 'mean': 'μ', 'std': 'σ', 'median': 'x̃', 'distribution': 'ℙ', 'distribution_log': 'ℙ˪', 'nan_values': 'NANvals', 'warning': '🚨', 'requires_grad': '∇', 'true': '✓', 'false': '✗'}, 'ascii': {'range': 'range', 'mean': 'mean', 'std': 'std', 'median': 'med', 'distribution': 'dist', 'distribution_log': 'dist_log', 'nan_values': 'NANvals', 'warning': '!!!', 'requires_grad': 'requires_grad', 'true': '1', 'false': '0'}}
Symbols for different formats
SPARK_CHARS: Dict[Literal['unicode', 'latex', 'ascii'], List[str]] =
{'unicode': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'], 'ascii': [' ', '_', '.', '-', '~', '=', '#'], 'latex': [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█']}
characters for sparklines in different formats
def
array_info(A: Any, hist_bins: int = 5) -> Dict[str, Any]:
def _tensor_to_numpy(A: Any) -> np.ndarray:
    """Best-effort conversion of a torch-style tensor to a CPU numpy array.

    Moves CUDA tensors to the CPU, detaches them from the autograd graph and
    calls `.numpy()`. Returns an empty array when every conversion attempt fails.
    """
    try:
        is_cuda: bool = bool(getattr(A, "is_cuda", False))
    except Exception:
        is_cuda = False

    try:
        cpu_tensor = getattr(A, "cpu", lambda: A)() if is_cuda else A
        detached = getattr(cpu_tensor, "detach", lambda: cpu_tensor)()
    except Exception:
        return np.array([])

    try:
        return getattr(detached, "numpy", lambda: np.array([]))()
    except Exception:
        pass

    # `.numpy()` raises for dtypes numpy has no equivalent of (e.g.
    # `torch.bfloat16`); upcast to float32 so statistics can still be computed
    # instead of misreporting the tensor as empty
    try:
        return detached.float().numpy()
    except Exception:
        return np.array([])


def array_info(
    A: Any,
    hist_bins: int = 5,
) -> Dict[str, Any]:
    """Extract statistical information from an array-like object.

    Torch tensors are recognized by their class name (`Tensor`), so torch is
    never imported here. Tensors are moved to the CPU and detached before being
    converted to numpy; if conversion fails entirely, the data is treated as an
    empty array.

    # Parameters:
     - `A : array-like`
        Array to analyze (numpy array or torch tensor)
     - `hist_bins : int`
        Number of bins of the histogram computed for sparklines
        (defaults to `5`)

    # Returns:
     - `Dict[str, Any]`
        Dictionary containing raw statistical information with numeric values.
        Keys: `is_tensor`, `device`, `requires_grad`, `shape`, `dtype`, `size`,
        `has_nans`, `nan_count`, `nan_percent`, `min`, `max`, `range`, `mean`,
        `std`, `median`, `histogram`, `bins`, `status`. Values that could not
        be computed stay `None`. `status` is one of `"empty array"`,
        `"all NaN"`, `"ok"` or `"error: <message>"`.
    """
    result: Dict[str, Any] = {
        "is_tensor": None,
        "device": None,
        "requires_grad": None,
        "shape": None,
        "dtype": None,
        "size": None,
        "has_nans": None,
        "nan_count": None,
        "nan_percent": None,
        "min": None,
        "max": None,
        "range": None,
        "mean": None,
        "std": None,
        "median": None,
        "histogram": None,
        "bins": None,
        "status": None,
    }

    # Detect tensors by class name, which avoids importing torch
    result["is_tensor"] = type(A).__name__ == "Tensor"

    if result["is_tensor"]:
        try:
            result["device"] = str(getattr(A, "device", None))
        except Exception:
            pass

    # Convert to a numpy array for the calculations below
    A_np: np.ndarray
    try:
        A_np = _tensor_to_numpy(A) if result["is_tensor"] else np.asarray(A)
    except Exception:
        A_np = np.array([])

    # Basic information; anything that fails to resolve stays `None`
    try:
        result["shape"] = A_np.shape
        # report the tensor's own dtype (e.g. `torch.float32`), not numpy's
        result["dtype"] = str(A.dtype if result["is_tensor"] else A_np.dtype)
        result["size"] = A_np.size
        result["requires_grad"] = getattr(A, "requires_grad", None)
    except Exception:
        pass

    if result["size"] == 0:
        result["status"] = "empty array"
        return result

    # Statistics are computed over a flat (1D) view of the data
    try:
        A_flat = np.ravel(A_np)
    except Exception:
        A_flat = A_np

    # Count NaNs; non-numeric dtypes make `np.isnan` raise and leave these `None`
    nan_mask = None
    try:
        nan_mask = np.isnan(A_flat)
        result["nan_count"] = np.sum(nan_mask)
        result["has_nans"] = result["nan_count"] > 0
        if result["size"] > 0:
            result["nan_percent"] = (result["nan_count"] / result["size"]) * 100
    except Exception:
        pass

    if result["has_nans"] and result["nan_count"] == result["size"]:
        result["status"] = "all NaN"
        return result

    try:
        # NaN-aware reductions are only needed (and the mask only valid) with NaNs
        if result["has_nans"]:
            reducers = (np.nanmin, np.nanmax, np.nanmean, np.nanstd, np.nanmedian)
            A_hist = A_flat[~nan_mask]
        else:
            reducers = (np.min, np.max, np.mean, np.std, np.median)
            A_hist = A_flat

        for key, reduce_fn in zip(("min", "max", "mean", "std", "median"), reducers):
            result[key] = float(reduce_fn(A_flat))
        result["range"] = (result["min"], result["max"])

        # Histogram data for sparklines
        if A_hist.size > 0:
            try:
                if A_hist.dtype == np.bool_:
                    # np.histogram casts bool input to uint8 itself but emits a
                    # RuntimeWarning; casting up front yields the same counts
                    A_hist = A_hist.astype(np.uint8)
                hist, bins = np.histogram(A_hist, bins=hist_bins)
                result["histogram"] = hist
                result["bins"] = bins
            except Exception:
                pass

        result["status"] = "ok"
    except Exception as e:
        result["status"] = f"error: {str(e)}"

    return result
Extract statistical information from an array-like object.
Parameters:
A : array-like — Array to analyze (numpy array or torch tensor)
hist_bins : int — Number of bins of the histogram computed for sparklines (defaults to 5)
Returns:
Dict[str, Any] — Dictionary containing raw statistical information with numeric values
def
generate_sparkline( histogram: numpy.ndarray, format: Literal['unicode', 'latex', 'ascii'] = 'unicode', log_y: Optional[bool] = None) -> tuple[str, bool]:
def generate_sparkline(
    histogram: np.ndarray,
    format: Literal["unicode", "latex", "ascii"] = "unicode",
    log_y: Optional[bool] = None,
) -> tuple[str, bool]:
    """Generate a sparkline visualization of the histogram.

    # Parameters:
     - `histogram : np.ndarray`
        Histogram counts (any 1D sequence of non-negative numbers)
     - `format : Literal["unicode", "latex", "ascii"]`
        Output format (defaults to `"unicode"`); unknown formats fall back to
        the ascii character set
     - `log_y : bool|None`
        Whether to use logarithmic y-scale. `None` for automatic detection
        (defaults to `None`)

    # Returns:
     - `tuple[str, bool]`
        Sparkline visualization and whether log scale was used
    """
    if histogram is None or len(histogram) == 0:
        return "", False

    counts: np.ndarray = np.asarray(histogram)

    chars: List[str] = (
        SPARK_CHARS[format] if format in SPARK_CHARS else SPARK_CHARS["ascii"]
    )

    if log_y is None:
        # Bin the counts into as many levels as there are sparkline characters.
        # If only the lowest and highest levels are populated, a linear scale
        # would hide everything in between, so switch to log scale.
        level_counts = np.histogram(counts, bins=len(chars))[0]
        log_y = not level_counts[1:-1].max() > 0

    # log1p keeps empty bins at 0 instead of -inf
    hist_data = np.log1p(counts) if log_y else counts

    # Scale so the tallest bin maps to the last (fullest) character
    peak = hist_data.max()
    if peak > 0:
        normalized = hist_data / peak * (len(chars) - 1)
    else:
        normalized = np.zeros_like(hist_data)

    spark: str = "".join(chars[round(val)] for val in normalized)
    return spark, log_y
Generate a sparkline visualization of the histogram.
Parameters:
histogram : np.ndarray — Histogram data
format : Literal["unicode", "latex", "ascii"] — Output format (defaults to "unicode")
log_y : bool|None — Whether to use logarithmic y-scale. None for automatic detection (defaults to None)
Returns:
tuple[str, bool] — Sparkline visualization and whether log scale was used
DEFAULT_SETTINGS: Dict[str, Any] =
{'fmt': 'unicode', 'precision': 2, 'stats': True, 'shape': True, 'dtype': True, 'device': True, 'requires_grad': True, 'sparkline': False, 'sparkline_bins': 5, 'sparkline_logy': None, 'colored': False, 'as_list': False, 'eq_char': '='}
def
apply_color( text: str, color_key: str, colors: Dict[str, str], using_tex: bool) -> str:
def apply_color(
    text: str, color_key: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Wrap `text` in the color registered under `color_key` in `colors`.

    TeX output uses the color command with `text` as its braced argument;
    terminal output brackets `text` with the color code and `colors['reset']`.
    An empty color leaves `text` untouched.
    """
    code: str = colors[color_key]
    if not code:
        return text
    if using_tex:
        return f"{code}{{{text}}}"
    return f"{code}{text}{colors['reset']}"
def
colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
def colorize_dtype(dtype_str: str, colors: Dict[str, str], using_tex: bool) -> str:
    """Colorize dtype string with specific colors for torch and type names.

    A single `torch.` prefix gets the `"torch"` color; the type name is colored
    by family (bool, int, float, or the generic dtype color).
    """
    lowered: str = dtype_str.lower()
    if "bool" in lowered:
        family = "dtype_bool"
    elif "int" in lowered:
        family = "dtype_int"
    elif "float" in lowered:
        family = "dtype_float"
    else:
        family = "dtype"

    # exactly one "torch." occurrence splits into prefix and type name
    pieces = dtype_str.split("torch.")
    if len(pieces) == 2:
        torch_prefix = apply_color("torch", "torch", colors, using_tex)
        return f"{torch_prefix}.{apply_color(pieces[1], family, colors, using_tex)}"
    return apply_color(dtype_str, family, colors, using_tex)
Colorize dtype string with specific colors for torch and type names.
def
format_shape_colored(shape_val, colors: Dict[str, str], using_tex: bool) -> str:
def format_shape_colored(shape_val, colors: Dict[str, str], using_tex: bool) -> str:
    """Format shape with proper coloring for both 1D and multi-D arrays.

    A 1D shape is rendered as the bare dimension (e.g. `5`); any other shape as
    a parenthesized, comma-separated list (e.g. `(2,3)`). Every dimension gets
    the `"shape"` color via the shared `apply_color` helper.
    """
    dims = [apply_color(str(dim), "shape", colors, using_tex) for dim in shape_val]
    if len(shape_val) == 1:
        return dims[0]
    return "(" + ",".join(dims) + ")"
Format shape with proper coloring for both 1D and multi-D arrays.
def
format_device_colored(device_str: str, colors: Dict[str, str], using_tex: bool) -> str:
def format_device_colored(
    device_str: str, colors: Dict[str, str], using_tex: bool
) -> str:
    """Format device string with CUDA highlighting.

    Devices whose (lower-cased) name contains "cuda" use the `device_cuda`
    color; all others use the generic `device` color.

    # Parameters:
     - `device_str : str`
        device string, e.g. `"cpu"` or `"cuda:0"`
     - `colors : Dict[str, str]`
        color table passed through to `apply_color`
     - `using_tex : bool`
        whether to emit LaTeX-style coloring

    # Returns:
     - `str`
        the colored device string
    """
    color_key: str = "device_cuda" if "cuda" in device_str.lower() else "device"
    # reuse the module-level `apply_color` rather than a private copy of it
    return apply_color(device_str, color_key, colors, using_tex)
Format device string with CUDA highlighting.
def
array_summary( array, fmt: Literal['unicode', 'latex', 'ascii'] = <muutils.tensor_info._UseDefaultType object>, precision: int = <muutils.tensor_info._UseDefaultType object>, stats: bool = <muutils.tensor_info._UseDefaultType object>, shape: bool = <muutils.tensor_info._UseDefaultType object>, dtype: bool = <muutils.tensor_info._UseDefaultType object>, device: bool = <muutils.tensor_info._UseDefaultType object>, requires_grad: bool = <muutils.tensor_info._UseDefaultType object>, sparkline: bool = <muutils.tensor_info._UseDefaultType object>, sparkline_bins: int = <muutils.tensor_info._UseDefaultType object>, sparkline_logy: Optional[bool] = <muutils.tensor_info._UseDefaultType object>, colored: bool = <muutils.tensor_info._UseDefaultType object>, eq_char: str = <muutils.tensor_info._UseDefaultType object>, as_list: bool = <muutils.tensor_info._UseDefaultType object>) -> Union[str, List[str]]:
def array_summary(  # type: ignore[misc]
    array,
    fmt: OutputFormat = _USE_DEFAULT,  # type: ignore[assignment]
    precision: int = _USE_DEFAULT,  # type: ignore[assignment]
    stats: bool = _USE_DEFAULT,  # type: ignore[assignment]
    shape: bool = _USE_DEFAULT,  # type: ignore[assignment]
    dtype: bool = _USE_DEFAULT,  # type: ignore[assignment]
    device: bool = _USE_DEFAULT,  # type: ignore[assignment]
    requires_grad: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline: bool = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_bins: int = _USE_DEFAULT,  # type: ignore[assignment]
    sparkline_logy: Optional[bool] = _USE_DEFAULT,  # type: ignore[assignment]
    colored: bool = _USE_DEFAULT,  # type: ignore[assignment]
    eq_char: str = _USE_DEFAULT,  # type: ignore[assignment]
    as_list: bool = _USE_DEFAULT,  # type: ignore[assignment]
) -> Union[str, List[str]]:
    """Format array information into a readable summary.

    Any argument left unset falls back to the matching entry of
    `DEFAULT_SETTINGS`.

    # Parameters:
     - `array`
        array-like object (numpy array or torch tensor)
     - `fmt : Literal["unicode", "latex", "ascii"]`
        Output format; selects the symbol set, and `"latex"` also switches to
        LaTeX coloring, `\\%` escaping and `\\quad` separators
     - `precision : int`
        Decimal places for floating-point values
     - `stats : bool`
        Whether to include statistical info (μ, σ, x̃) and the value range
     - `shape : bool`
        Whether to include shape info
     - `dtype : bool`
        Whether to include dtype info
     - `device : bool`
        Whether to include device info for torch tensors
     - `requires_grad : bool`
        Whether to include requires_grad info for torch tensors
     - `sparkline : bool`
        Whether to include a sparkline visualization
     - `sparkline_bins : int`
        Number of histogram bins computed for the sparkline
     - `sparkline_logy : bool|None`
        Whether to use logarithmic y-scale for sparkline (`None` lets
        `generate_sparkline` decide)
     - `colored : bool`
        Whether to add color to output
     - `eq_char : str`
        Separator placed between each label and its value
     - `as_list : bool`
        Whether to return as list of strings instead of joined string

    # Returns:
     - `Union[str, List[str]]`
        Formatted statistical summary, either as string or list of strings
    """
    if fmt is _USE_DEFAULT:
        fmt = DEFAULT_SETTINGS["fmt"]
    if precision is _USE_DEFAULT:
        precision = DEFAULT_SETTINGS["precision"]
    if stats is _USE_DEFAULT:
        stats = DEFAULT_SETTINGS["stats"]
    if shape is _USE_DEFAULT:
        shape = DEFAULT_SETTINGS["shape"]
    if dtype is _USE_DEFAULT:
        dtype = DEFAULT_SETTINGS["dtype"]
    if device is _USE_DEFAULT:
        device = DEFAULT_SETTINGS["device"]
    if requires_grad is _USE_DEFAULT:
        requires_grad = DEFAULT_SETTINGS["requires_grad"]
    if sparkline is _USE_DEFAULT:
        sparkline = DEFAULT_SETTINGS["sparkline"]
    if sparkline_bins is _USE_DEFAULT:
        sparkline_bins = DEFAULT_SETTINGS["sparkline_bins"]
    if sparkline_logy is _USE_DEFAULT:
        sparkline_logy = DEFAULT_SETTINGS["sparkline_logy"]
    if colored is _USE_DEFAULT:
        colored = DEFAULT_SETTINGS["colored"]
    if as_list is _USE_DEFAULT:
        as_list = DEFAULT_SETTINGS["as_list"]
    if eq_char is _USE_DEFAULT:
        eq_char = DEFAULT_SETTINGS["eq_char"]

    array_data: Dict[str, Any] = array_info(array, hist_bins=sparkline_bins)
    result_parts: List[str] = []
    using_tex: bool = fmt == "latex"

    # Set color scheme based on format and colored flag
    colors: Dict[str, str]
    if colored:
        colors = COLORS["latex"] if using_tex else COLORS["terminal"]
    else:
        colors = COLORS["none"]

    # Get symbols for the current format
    symbols: Dict[str, str] = SYMBOLS[fmt]

    def colorize(text: str, color_key: str) -> str:
        # bind this call's color table/format onto the shared module helper
        return apply_color(text, color_key, colors, using_tex)

    # the dtype entry may be falsy (see the dtype section below); normalize to
    # "" so `.lower()` cannot fail on `None`
    dtype_str: str = array_data.get("dtype") or ""
    # integer-like dtypes get their range printed without decimals
    # ("uint" is already covered by the "int" substring)
    is_int_dtype: bool = any(
        int_type in dtype_str.lower() for int_type in ("int", "bool")
    )

    # Format string for numbers
    float_fmt: str = f".{precision}f"

    # Handle error status or empty array
    if (
        array_data["status"] in ["empty array", "all NaN", "unknown"]
        or array_data["size"] == 0
    ):
        status = array_data["status"]
        result_parts.append(colorize(symbols["warning"] + " " + status, "warning"))
    else:
        # Add NaN warning at the beginning if there are NaNs
        if array_data["has_nans"]:
            _percent: str = "\\%" if using_tex else "%"
            nan_str: str = f"{symbols['warning']} {symbols['nan_values']}{eq_char}{array_data['nan_count']} ({array_data['nan_percent']:.1f}{_percent})"
            result_parts.append(colorize(nan_str, "warning"))

        # Statistics
        if stats:
            for stat_key in ["mean", "std", "median"]:
                if array_data[stat_key] is not None:
                    stat_str: str = f"{array_data[stat_key]:{float_fmt}}"
                    stat_colored: str = colorize(stat_str, stat_key)
                    # use eq_char like every other labelled field
                    result_parts.append(
                        f"{symbols[stat_key]}{eq_char}{stat_colored}"
                    )

            # Range (min, max)
            if array_data["range"] is not None:
                min_val, max_val = array_data["range"]
                if is_int_dtype:
                    min_str: str = f"{int(min_val):d}"
                    max_str: str = f"{int(max_val):d}"
                else:
                    min_str = f"{min_val:{float_fmt}}"
                    max_str = f"{max_val:{float_fmt}}"
                min_colored: str = colorize(min_str, "range")
                max_colored: str = colorize(max_str, "range")
                range_str: str = (
                    f"{symbols['range']}{eq_char}[{min_colored},{max_colored}]"
                )
                result_parts.append(range_str)

        # Add sparkline if requested
        if sparkline and array_data["histogram"] is not None:
            # generate_sparkline reports whether a log y-scale was used so the
            # matching distribution symbol can be shown
            spark, used_log = generate_sparkline(
                array_data["histogram"],
                format=fmt,
                log_y=sparkline_logy,
            )
            if spark:
                spark_colored = colorize(spark, "sparkline")
                dist_symbol = (
                    symbols["distribution_log"] if used_log else symbols["distribution"]
                )
                result_parts.append(f"{dist_symbol}{eq_char}|{spark_colored}|")

    # Add shape if requested
    if shape and array_data["shape"]:
        shape_val = array_data["shape"]
        shape_str = format_shape_colored(shape_val, colors, using_tex)
        result_parts.append(f"shape{eq_char}{shape_str}")

    # Add dtype if requested
    if dtype and array_data["dtype"]:
        dtype_colored = colorize_dtype(array_data["dtype"], colors, using_tex)
        result_parts.append(f"dtype{eq_char}{dtype_colored}")

    # Add device if requested and it's a tensor with device info
    if device and array_data["is_tensor"] and array_data["device"]:
        device_colored = format_device_colored(array_data["device"], colors, using_tex)
        result_parts.append(f"device{eq_char}{device_colored}")

    # Add gradient info
    if requires_grad and array_data["is_tensor"]:
        bool_req_grad_symb: str = (
            symbols["true"] if array_data["requires_grad"] else symbols["false"]
        )
        result_parts.append(
            colorize(symbols["requires_grad"] + bool_req_grad_symb, "requires_grad")
        )

    # Return as list if requested, otherwise join with spaces
    if as_list:
        return result_parts
    joinchar: str = r" \quad " if using_tex else " "
    return joinchar.join(result_parts)
Format array information into a readable summary.
Parameters:
array — array-like object (numpy array or torch tensor).
precision : int — Decimal places (defaults to 2).
fmt : Literal["unicode", "latex", "ascii"] — Output format (defaults to the module's default format setting).
stats : bool — Whether to include statistical info (μ, σ, x̃) (defaults to True).
shape : bool — Whether to include shape info (defaults to True).
dtype : bool — Whether to include dtype info (defaults to True).
device : bool — Whether to include device info for torch tensors (defaults to True).
requires_grad : bool — Whether to include requires_grad info for torch tensors (defaults to True).
sparkline : bool — Whether to include a sparkline visualization (defaults to False).
sparkline_bins : int — Number of histogram bins used for the sparkline (defaults to 20).
sparkline_logy : bool|None — Whether to use a logarithmic y-scale for the sparkline (defaults to None).
colored : bool — Whether to add color to the output (defaults to False).
eq_char : str — Separator placed between each label and its value.
as_list : bool — Whether to return a list of strings instead of a single joined string (defaults to False).
Returns:
Union[str, List[str]] — Formatted statistical summary, either as a single string or as a list of strings.