Coverage for nilearn/plotting/matrix/matrix_plotting.py: 0%
215 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
1"""Miscellaneous matrix plotting utilities."""
3import matplotlib.patches as mpatches
4import matplotlib.pyplot as plt
5import numpy as np
6import pandas as pd
7from mpl_toolkits.axes_grid1 import make_axes_locatable
9from nilearn import DEFAULT_DIVERGING_CMAP
10from nilearn._utils import (
11 constrained_layout_kwargs,
12 fill_doc,
13 rename_parameters,
14)
15from nilearn._utils.glm import check_and_load_tables
16from nilearn.glm.first_level import check_design_matrix
17from nilearn.glm.first_level.experimental_paradigm import check_events
18from nilearn.plotting._utils import save_figure_if_needed
19from nilearn.plotting.matrix._utils import (
20 mask_matrix,
21 pad_contrast_matrix,
22 reorder_matrix,
23 sanitize_labels,
24 sanitize_reorder,
25 sanitize_tri,
26)
29def _configure_axis(
30 axes, labels, label_size, x_label_rotation, y_label_rotation
31):
32 """Help for plot_matrix."""
33 if not labels:
34 axes.xaxis.set_major_formatter(plt.NullFormatter())
35 axes.yaxis.set_major_formatter(plt.NullFormatter())
36 else:
37 axes.set_xticks(np.arange(len(labels)))
38 axes.set_xticklabels(labels, size=label_size)
39 for label in axes.get_xticklabels():
40 label.set_ha("right")
41 label.set_rotation(x_label_rotation)
42 axes.set_yticks(np.arange(len(labels)))
43 axes.set_yticklabels(labels, size=label_size)
44 for label in axes.get_yticklabels():
45 label.set_ha("right")
46 label.set_va("top")
47 label.set_rotation(y_label_rotation)
50def _configure_grid(axes, tri, size):
51 """Help for plot_matrix."""
52 # Different grids for different layouts
53 if tri == "lower":
54 for i in range(size):
55 # Correct for weird mis-sizing
56 i = 1.001 * i
57 axes.plot([i + 0.5, i + 0.5], [size - 0.5, i + 0.5], color="gray")
58 axes.plot([i + 0.5, -0.5], [i + 0.5, i + 0.5], color="gray")
59 elif tri == "diag":
60 for i in range(size):
61 # Correct for weird mis-sizing
62 i = 1.001 * i
63 axes.plot([i + 0.5, i + 0.5], [size - 0.5, i - 0.5], color="gray")
64 axes.plot([i + 0.5, -0.5], [i - 0.5, i - 0.5], color="gray")
65 else:
66 for i in range(size):
67 # Correct for weird mis-sizing
68 i = 1.001 * i
69 axes.plot([i + 0.5, i + 0.5], [size - 0.5, -0.5], color="gray")
70 axes.plot([size - 0.5, -0.5], [i + 0.5, i + 0.5], color="gray")
73def _fit_axes(axes):
74 """Help for plot_matrix.
76 This function redimensions the given axes to have
77 labels fitting.
78 """
79 fig = axes.get_figure()
80 renderer = fig.canvas.get_renderer()
81 ylabel_width = (
82 axes.yaxis.get_tightbbox(renderer)
83 .transformed(axes.figure.transFigure.inverted())
84 .width
85 )
86 if axes.get_position().xmin < 1.1 * ylabel_width:
87 # we need to move it over
88 new_position = axes.get_position()
89 new_position.x0 = 1.1 * ylabel_width # pad a little
90 axes.set_position(new_position)
92 xlabel_height = (
93 axes.xaxis.get_tightbbox(renderer)
94 .transformed(axes.figure.transFigure.inverted())
95 .height
96 )
97 if axes.get_position().ymin < 1.1 * xlabel_height:
98 # we need to move it over
99 new_position = axes.get_position()
100 new_position.y0 = 1.1 * xlabel_height # pad a little
101 axes.set_position(new_position)
104def _sanitize_figure_and_axes(figure, axes):
105 """Help for plot_matrix."""
106 if axes is not None and figure is not None:
107 raise ValueError(
108 "Parameters figure and axes cannot be specified together. "
109 f"You gave 'figure={figure}, axes={axes}'."
110 )
111 if figure is not None:
112 if isinstance(figure, plt.Figure):
113 fig = figure
114 if hasattr(fig, "set_layout_engine"): # can be removed w/mpl 3.5
115 fig.set_layout_engine("constrained")
116 else:
117 fig = plt.figure(figsize=figure, **constrained_layout_kwargs())
118 axes = plt.gca()
119 own_fig = True
120 elif axes is None:
121 fig, axes = plt.subplots(
122 1,
123 1,
124 figsize=(7, 5),
125 **constrained_layout_kwargs(),
126 )
127 own_fig = True
128 else:
129 fig = axes.figure
130 own_fig = False
131 return fig, axes, own_fig
def _sanitize_inputs_plot_matrix(
    mat_shape, tri, labels, reorder, figure, axes
):
    """Help for plot_matrix: validate and normalize its inputs.

    ``tri`` is checked with ``sanitize_tri``. ``labels`` and ``reorder``
    are replaced by their sanitized versions, and the figure/axes pair is
    resolved by ``_sanitize_figure_and_axes``.

    Returns
    -------
    :obj:`tuple`
        ``(labels, reorder, fig, axes, own_fig)``.
    """
    sanitize_tri(tri)
    clean_labels = sanitize_labels(mat_shape, labels)
    clean_reorder = sanitize_reorder(reorder)
    return (
        clean_labels,
        clean_reorder,
        *_sanitize_figure_and_axes(figure, axes),
    )
@fill_doc
def plot_matrix(
    mat,
    title=None,
    labels=None,
    figure=None,
    axes=None,
    colorbar=True,
    cmap=DEFAULT_DIVERGING_CMAP,
    tri="full",
    auto_fit=True,
    grid=False,
    reorder=False,
    **kwargs,
):
    """Plot the given matrix.

    Parameters
    ----------
    mat : 2-D :class:`numpy.ndarray`
        Matrix to be plotted.
    %(title)s
    labels : :obj:`list`, or :class:`numpy.ndarray` of :obj:`str`,\
    or False, or None, default=None
        The label of each row and column. Needs to be the same
        length as rows/columns of mat. If False, None, or an
        empty list, no labels are plotted.

    figure : :class:`matplotlib.figure.Figure`, figsize :obj:`tuple`,\
    or None, default=None
        Sets the figure used. This argument can be either an existing
        figure, or a pair (width, height) that gives the size of a
        newly-created figure.

        .. note::

            Specifying both axes and figure is not allowed.

    axes : None or :class:`matplotlib.axes.Axes`, default=None
        Axes instance to be plotted on. Creates a new one if None.

        .. note::

            Specifying both axes and figure is not allowed.

    %(colorbar)s
        Default=True.

    %(cmap)s
        default="RdBu_r"

    tri : {'full', 'lower', 'diag'}, default='full'
        Which triangular part of the matrix to plot:

        - 'lower': Plot the lower part
        - 'diag': Plot the lower part with the diagonal
        - 'full': Plot the full matrix

    auto_fit : :obj:`bool`, default=True
        If auto_fit is True, the axes are dimensioned to give room
        for the labels. This assumes that the labels are resting
        against the bottom and left edges of the figure.

    grid : color or False, default=False
        If not False, a grid is plotted to separate rows and columns.
        The grid lines use the given color. When ``grid`` is not a valid
        matplotlib color (e.g. ``True``), they are drawn in gray.

    reorder : :obj:`bool` or {'single', 'complete', 'average'}, default=False
        If not False, reorders the matrix into blocks of clusters.
        Accepted linkage options for the clustering are 'single',
        'complete', and 'average'. True defaults to average linkage.

        .. note::
            This option is only available with SciPy >= 1.0.0.

        .. versionadded:: 0.4.1

    kwargs : extra keyword arguments, optional
        Extra keyword arguments are sent to
        :meth:`matplotlib.axes.Axes.imshow`.

    Returns
    -------
    display : :class:`matplotlib.image.AxesImage`
        The image created by :meth:`matplotlib.axes.Axes.imshow`.

    """
    labels, reorder, fig, axes, _ = _sanitize_inputs_plot_matrix(
        mat.shape, tri, labels, reorder, figure, axes
    )
    if reorder:
        mat, labels = reorder_matrix(mat, labels, reorder)
    if tri != "full":
        mat = mask_matrix(mat, tri)
    display = axes.imshow(
        mat, aspect="equal", interpolation="nearest", cmap=cmap, **kwargs
    )
    axes.set_autoscale_on(False)
    # Remember the image's y limits so they can be restored once ticks and
    # grid lines have been added.
    ymin, ymax = axes.get_ylim()
    _configure_axis(
        axes,
        labels,
        label_size="x-small",
        x_label_rotation=50,
        y_label_rotation=10,
    )
    if grid is not False:
        from matplotlib.colors import is_color_like

        # _configure_grid always draws in gray: recolor the lines it adds
        # when ``grid`` is an actual color, as documented above.
        n_lines_before = len(axes.lines)
        _configure_grid(axes, tri, len(mat))
        if not isinstance(grid, bool) and is_color_like(grid):
            for line in axes.lines[n_lines_before:]:
                line.set_color(grid)
    axes.set_ylim(ymin, ymax)
    if auto_fit and labels:
        _fit_axes(axes)
    if colorbar:
        divider = make_axes_locatable(axes)
        cax = divider.append_axes("right", size="5%", pad=0.05)

        fig.colorbar(display, cax=cax)

    if title is not None:
        axes.set_title(title, size=16)

    return display
@fill_doc
@rename_parameters({"ax": "axes"}, end_version="0.13.0")
def plot_contrast_matrix(
    contrast_def, design_matrix, colorbar=True, axes=None, output_file=None
):
    """Create plot for :term:`contrast` definition.

    Parameters
    ----------
    contrast_def : :obj:`str` or :class:`numpy.ndarray` of shape[1] <= n_col \
        where ``n_col`` is the number of columns of the design matrix.
        The string can be a formula compatible
        with :meth:`pandas.DataFrame.eval`.
        Basically one can use the name of the conditions
        as they appear in the design matrix of the fitted model
        combined with operators +-
        and combined with numbers with operators +-`*`/.

    design_matrix : :class:`pandas.DataFrame`
        Design matrix to use.

    %(colorbar)s
        Default=True.

    axes : :class:`matplotlib.axes.Axes` or None, default=None
        Axis on which to plot the figure.
        If None, a new figure will be created.

    %(output_file)s

    Returns
    -------
    axes : :class:`matplotlib.axes.Axes`
        The axes used for plotting.

    """
    contrast_def = pad_contrast_matrix(contrast_def, design_matrix)
    con_matrix = np.array(contrast_def, ndmin=2)

    column_names = design_matrix.columns.tolist()
    longest_name = np.max([len(str(name)) for name in column_names])
    n_columns = len(column_names)

    if axes is None:
        # Width grows with the number of regressors, height with the
        # number of contrast rows and the length of the longest label.
        fig_width = 0.4 * n_columns
        fig_height = 1 + 0.5 * con_matrix.shape[0] + 0.04 * longest_name
        _, axes = plt.subplots(
            figsize=(fig_width, fig_height),
            **constrained_layout_kwargs(),
        )

    # Symmetric color limits so that 0 sits at the middle of the gray scale.
    bound = np.max(np.abs(contrast_def))
    image = axes.matshow(
        con_matrix, aspect="equal", cmap="gray", vmin=-bound, vmax=bound
    )

    axes.set_label("conditions")
    axes.set_ylabel("")
    axes.set_yticks(())

    axes.xaxis.set(ticks=np.arange(n_columns))
    axes.set_xticklabels(column_names, rotation=50, ha="left")

    if colorbar:
        axes.figure.colorbar(image, fraction=0.025, pad=0.04)

    # NOTE(review): the return value comes from save_figure_if_needed; it
    # may differ from ``axes`` when ``output_file`` is given — confirm.
    return save_figure_if_needed(axes, output_file)
@fill_doc
@rename_parameters({"ax": "axes"}, end_version="0.13.0")
def plot_design_matrix(
    design_matrix,
    rescale=True,
    axes=None,
    output_file=None,
):
    """Plot a design matrix.

    Parameters
    ----------
    design_matrix : :class:`pandas.DataFrame` or \
    :obj:`str` or :obj:`pathlib.Path` to a TSV event file
        Describes a design matrix.

    rescale : :obj:`bool`, default=True
        Rescale columns magnitude for visualization or not.

    axes : :class:`matplotlib.axes.Axes` or None, default=None
        Handle to axes onto which we will draw the design matrix.
        If None, a new figure is created.

    %(output_file)s

    Returns
    -------
    axes : :class:`matplotlib.axes.Axes`
        The axes used for plotting.

    """
    design_matrix = check_and_load_tables(design_matrix, "design_matrix")[0]

    _, values, column_names = check_design_matrix(design_matrix)

    if rescale:
        # Normalize each column by its Euclidean norm for better
        # visualization; the 1e-12 floor avoids dividing by zero.
        column_norms = np.sqrt(np.sum(values**2, 0))
        values = values / np.maximum(1.0e-12, column_norms)

    if axes is None:
        longest_name = np.max([len(str(name)) for name in column_names])
        height = 1 + 0.1 * values.shape[0] + 0.04 * longest_name
        # Clamp the figure height to the [3, 10] range.
        height = min(max(height, 3), 10)
        _, axes = plt.subplots(
            figsize=(1 + 0.23 * len(column_names), height),
            **constrained_layout_kwargs(),
        )

    axes.imshow(values, interpolation="nearest", aspect="auto")
    axes.set_label("conditions")
    axes.set_ylabel("scan number")

    axes.set_xticks(range(len(column_names)))
    axes.set_xticklabels(column_names, rotation=60, ha="left")
    # Ticks go on top so the plot reads like the corresponding dataframe.
    axes.xaxis.tick_top()

    return save_figure_if_needed(axes, output_file)
@fill_doc
def plot_event(model_event, cmap=None, output_file=None, **fig_kwargs):
    """Create plot for event visualization.

    .. warning::

        Events with a duration of 0 seconds will be plotted
        by a 'delta function'.

    Parameters
    ----------
    model_event : :class:`pandas.DataFrame`, \
    :obj:`str` or :obj:`pathlib.Path` to a TSV event file, \
    or a :obj:`list` or :obj:`tuple` \
    of :class:`pandas.DataFrame`, \
    :obj:`str` or :obj:`pathlib.Path` to a TSV event file.
        The :class:`pandas.DataFrame` must have three columns:
        ``trial_type`` with event name, ``onset`` and ``duration``.
        See :func:`~nilearn.glm.first_level.make_first_level_design_matrix`
        for details on the required content of events dataframes.

        .. note::

            The :class:`pandas.DataFrame` can also be obtained
            from :func:`nilearn.glm.first_level.first_level_from_bids`.

    %(cmap)s
        Defaults to ``"tab20"`` when None. Listed colormaps (those with a
        ``colors`` attribute) give one lookup-table entry per event type.
        Other colormaps are sampled at evenly spaced points.

    %(output_file)s

    **fig_kwargs : extra keyword arguments, optional
        Extra arguments passed to :func:`matplotlib.pyplot.subplots`.

    Returns
    -------
    figure : :class:`matplotlib.figure.Figure`
        Plot Figure object.

    Raises
    ------
    ValueError
        If there are more event types than colors in the colormap.

    """
    model_event = check_and_load_tables(model_event, "model_event")
    # Build a new list instead of assigning into the sequence returned by
    # check_and_load_tables, which may be the caller's own list.
    model_event = [check_events(event) for event in model_event]

    n_runs = len(model_event)
    if "layout" not in fig_kwargs and "constrained_layout" not in fig_kwargs:
        fig_kwargs.update(**constrained_layout_kwargs())
    figure, axes = plt.subplots(1, 1, **fig_kwargs)

    # input validation
    if cmap is None:
        cmap = "tab20"
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    event_labels = np.unique(
        pd.concat(event["trial_type"] for event in model_event)
    )
    n_labels = len(event_labels)

    if n_labels > cmap.N:
        plt.close(fig=figure)
        raise ValueError(
            "The number of event types is greater than "
            f"colors in colormap ({n_labels} > {cmap.N}). "
            "Use a different colormap."
        )

    if hasattr(cmap, "colors"):
        # Listed colormap: integer inputs index its lookup table directly.
        palette = [cmap(idx) for idx in range(n_labels)]
    else:
        # Continuous colormaps have no ``colors`` attribute: sample them at
        # evenly spaced points so event types stay distinguishable.
        palette = [tuple(rgba) for rgba in cmap(np.linspace(0, 1, n_labels))]
    label_colors = dict(zip(event_labels, palette))

    height = 0.5
    event_ends = []
    for idx_run, event_df in enumerate(model_event):
        for _, event in event_df.iterrows():
            modulation = event["modulation"] if "modulation" in event else 1.0

            ymin = (idx_run + 0.25) / n_runs
            ymax = (idx_run + 0.25 + height * modulation) / n_runs

            event_onset = event["onset"]
            event_end = event["onset"] + event["duration"]
            event_ends.append(event_end)

            color = label_colors[event["trial_type"]]

            if event["duration"] != 0:
                axes.axvspan(
                    xmin=event_onset,
                    xmax=event_end,
                    ymin=ymin,
                    ymax=ymax,
                    facecolor=color,
                )
            else:
                # events with 0 duration are plotted as lines
                axes.axvline(
                    event_onset,
                    ymin=ymin,
                    ymax=ymax,
                    color=color,
                )

    handles = [
        mpatches.Patch(color=color, label=label)
        for label, color in label_colors.items()
    ]
    axes.legend(handles=handles, ncol=4)

    axes.set_xlabel("Time (sec.)")
    axes.set_ylabel("Runs")
    axes.set_ylim(0, n_runs)
    # ``default`` keeps an empty event table from crashing ``max``.
    axes.set_xlim(-1, max(event_ends, default=0) + 1)
    axes.set_yticks(np.arange(n_runs) + 0.5)
    axes.set_yticklabels(np.arange(n_runs) + 1)

    return save_figure_if_needed(figure, output_file)
@fill_doc
def plot_design_matrix_correlation(
    design_matrix,
    tri="full",
    cmap=DEFAULT_DIVERGING_CMAP,
    colorbar=True,
    output_file=None,
    **kwargs,
):
    """Compute and plot the correlation between regressors of a design matrix.

    The drift and constant regressors are omitted from the plot.

    .. versionadded:: 0.11.0

    Parameters
    ----------
    design_matrix : :obj:`pandas.DataFrame`, :obj:`str` \
    or :obj:`pathlib.Path` to a TSV file
        Design matrix whose correlation matrix you want to plot.

    tri : {"full", "diag"}, default="full"
        Which triangular part of the matrix to plot:

        - ``"diag"``: Plot the lower part with the diagonal
        - ``"full"``: Plot the full matrix

    %(cmap)s
        default="RdBu_r"

        This must be a diverging colormap as the correlation matrix
        will be centered on 0.
        The allowed colormaps are:

        - ``"bwr"``
        - ``"RdBu_r"``
        - ``"seismic_r"``

    %(colorbar)s
        Default=True.

    %(output_file)s

    kwargs : extra keyword arguments, optional
        Extra keyword arguments are sent to
        :func:`nilearn.plotting.plot_matrix`

    Returns
    -------
    display : :class:`matplotlib.image.AxesImage`
        The image drawn by :func:`nilearn.plotting.plot_matrix`.

    Raises
    ------
    ValueError
        If ``cmap`` or ``tri`` is not allowed, or if no regressor is left
        once the drift and constant regressors are removed.
    """
    design_matrix = check_and_load_tables(design_matrix, "design_matrix")[0]

    check_design_matrix(design_matrix)

    # Validate the plotting options before doing any work.
    ALLOWED_CMAP = ["RdBu_r", "bwr", "seismic_r"]
    cmap_name = cmap if isinstance(cmap, str) else cmap.name
    if cmap_name not in ALLOWED_CMAP:
        raise ValueError(f"cmap must be one of {ALLOWED_CMAP}")

    sanitize_tri(tri, allowed_values=("full", "diag"))

    # Column names need not be strings (e.g. integer labels), so convert
    # them before testing the "drift_" prefix.
    columns_to_drop = ["intercept", "constant"]
    columns_to_drop.extend(
        col for col in design_matrix.columns if str(col).startswith("drift_")
    )
    design_matrix = design_matrix.drop(
        columns=columns_to_drop, errors="ignore"
    )

    if len(design_matrix.columns) == 0:
        raise ValueError(
            "Nothing left to plot after "
            "removing drift and constant regressors."
        )

    mat = design_matrix.corr().to_numpy()
    if len(mat) > 1:
        # find the second-largest value in each row
        # to omit values on the diagonal that will always be == 1
        second_largest = np.partition(mat, -2, axis=1)[:, -2]
        vmax = max(abs(mat.min()), max(second_largest))
    else:
        vmax = max(mat.min(), mat.max(), key=abs)

    display = plot_matrix(
        mat,
        tri=tri,
        cmap=cmap,
        vmax=vmax,
        vmin=vmax * -1,
        labels=design_matrix.columns.to_list(),
        colorbar=colorbar,
        **kwargs,
    )

    return save_figure_if_needed(display, output_file)