Coverage for nilearn/plotting/matrix/matrix_plotting.py: 0%

215 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1"""Miscellaneous matrix plotting utilities.""" 

2 

3import matplotlib.patches as mpatches 

4import matplotlib.pyplot as plt 

5import numpy as np 

6import pandas as pd 

7from mpl_toolkits.axes_grid1 import make_axes_locatable 

8 

9from nilearn import DEFAULT_DIVERGING_CMAP 

10from nilearn._utils import ( 

11 constrained_layout_kwargs, 

12 fill_doc, 

13 rename_parameters, 

14) 

15from nilearn._utils.glm import check_and_load_tables 

16from nilearn.glm.first_level import check_design_matrix 

17from nilearn.glm.first_level.experimental_paradigm import check_events 

18from nilearn.plotting._utils import save_figure_if_needed 

19from nilearn.plotting.matrix._utils import ( 

20 mask_matrix, 

21 pad_contrast_matrix, 

22 reorder_matrix, 

23 sanitize_labels, 

24 sanitize_reorder, 

25 sanitize_tri, 

26) 

27 

28 

29def _configure_axis( 

30 axes, labels, label_size, x_label_rotation, y_label_rotation 

31): 

32 """Help for plot_matrix.""" 

33 if not labels: 

34 axes.xaxis.set_major_formatter(plt.NullFormatter()) 

35 axes.yaxis.set_major_formatter(plt.NullFormatter()) 

36 else: 

37 axes.set_xticks(np.arange(len(labels))) 

38 axes.set_xticklabels(labels, size=label_size) 

39 for label in axes.get_xticklabels(): 

40 label.set_ha("right") 

41 label.set_rotation(x_label_rotation) 

42 axes.set_yticks(np.arange(len(labels))) 

43 axes.set_yticklabels(labels, size=label_size) 

44 for label in axes.get_yticklabels(): 

45 label.set_ha("right") 

46 label.set_va("top") 

47 label.set_rotation(y_label_rotation) 

48 

49 

50def _configure_grid(axes, tri, size): 

51 """Help for plot_matrix.""" 

52 # Different grids for different layouts 

53 if tri == "lower": 

54 for i in range(size): 

55 # Correct for weird mis-sizing 

56 i = 1.001 * i 

57 axes.plot([i + 0.5, i + 0.5], [size - 0.5, i + 0.5], color="gray") 

58 axes.plot([i + 0.5, -0.5], [i + 0.5, i + 0.5], color="gray") 

59 elif tri == "diag": 

60 for i in range(size): 

61 # Correct for weird mis-sizing 

62 i = 1.001 * i 

63 axes.plot([i + 0.5, i + 0.5], [size - 0.5, i - 0.5], color="gray") 

64 axes.plot([i + 0.5, -0.5], [i - 0.5, i - 0.5], color="gray") 

65 else: 

66 for i in range(size): 

67 # Correct for weird mis-sizing 

68 i = 1.001 * i 

69 axes.plot([i + 0.5, i + 0.5], [size - 0.5, -0.5], color="gray") 

70 axes.plot([size - 0.5, -0.5], [i + 0.5, i + 0.5], color="gray") 

71 

72 

def _fit_axes(axes):
    """Move the axes edges inwards so the tick labels fit (helper for plot_matrix).

    The tick labels are assumed to rest against the left and bottom edges
    of the figure. If the y tick labels are wider than the space left of
    the axes, the left edge of the axes is moved to 1.1 times the label
    width. Likewise, the bottom edge is moved to 1.1 times the x tick label
    height. All quantities are measured in figure-fraction coordinates.

    Parameters
    ----------
    axes : :class:`matplotlib.axes.Axes`
        Axes to resize in place.
    """
    renderer = axes.get_figure().canvas.get_renderer()

    # (axis whose labels must fit, bbox extent to measure,
    #  position bound to compare against, position edge to move)
    axis_specs = (
        (axes.yaxis, "width", "xmin", "x0"),
        (axes.xaxis, "height", "ymin", "y0"),
    )
    for axis, extent, lower_bound, edge in axis_specs:
        label_bbox = axis.get_tightbbox(renderer).transformed(
            axes.figure.transFigure.inverted()
        )
        # Pad the label footprint a little.
        required = 1.1 * getattr(label_bbox, extent)
        if getattr(axes.get_position(), lower_bound) < required:
            shifted = axes.get_position()
            setattr(shifted, edge, required)
            axes.set_position(shifted)

102 

103 

104def _sanitize_figure_and_axes(figure, axes): 

105 """Help for plot_matrix.""" 

106 if axes is not None and figure is not None: 

107 raise ValueError( 

108 "Parameters figure and axes cannot be specified together. " 

109 f"You gave 'figure={figure}, axes={axes}'." 

110 ) 

111 if figure is not None: 

112 if isinstance(figure, plt.Figure): 

113 fig = figure 

114 if hasattr(fig, "set_layout_engine"): # can be removed w/mpl 3.5 

115 fig.set_layout_engine("constrained") 

116 else: 

117 fig = plt.figure(figsize=figure, **constrained_layout_kwargs()) 

118 axes = plt.gca() 

119 own_fig = True 

120 elif axes is None: 

121 fig, axes = plt.subplots( 

122 1, 

123 1, 

124 figsize=(7, 5), 

125 **constrained_layout_kwargs(), 

126 ) 

127 own_fig = True 

128 else: 

129 fig = axes.figure 

130 own_fig = False 

131 return fig, axes, own_fig 

132 

133 

def _sanitize_inputs_plot_matrix(
    mat_shape, tri, labels, reorder, figure, axes
):
    """Validate and normalize the inputs of plot_matrix.

    ``tri`` is checked with ``sanitize_tri``, ``labels`` and ``reorder`` are
    normalized by ``sanitize_labels`` and ``sanitize_reorder``, and the
    figure/axes pair is resolved by ``_sanitize_figure_and_axes``, in that
    order.

    Returns
    -------
    :obj:`tuple`
        ``(labels, reorder, fig, axes, own_fig)``.
    """
    sanitize_tri(tri)
    clean_labels = sanitize_labels(mat_shape, labels)
    clean_reorder = sanitize_reorder(reorder)
    return (
        clean_labels,
        clean_reorder,
        *_sanitize_figure_and_axes(figure, axes),
    )

146 

147 

@fill_doc
def plot_matrix(
    mat,
    title=None,
    labels=None,
    figure=None,
    axes=None,
    colorbar=True,
    cmap=DEFAULT_DIVERGING_CMAP,
    tri="full",
    auto_fit=True,
    grid=False,
    reorder=False,
    **kwargs,
):
    """Plot the given matrix.

    Parameters
    ----------
    mat : 2-D :class:`numpy.ndarray`
        Matrix to be plotted.
    %(title)s
    labels : :obj:`list`, or :class:`numpy.ndarray` of :obj:`str`,\
    or False, or None, default=None
        The label of each row and column. Needs to be the same
        length as rows/columns of mat. If False, None, or an
        empty list, no labels are plotted.

    figure : :class:`matplotlib.figure.Figure`, figsize :obj:`tuple`,\
    or None, default=None
        Sets the figure used. This argument can be either an existing
        figure, or a pair (width, height) that gives the size of a
        newly-created figure.

        .. note::

            Specifying both axes and figure is not allowed.

    axes : None or :class:`matplotlib.axes.Axes`, default=None
        Axes instance to be plotted on. Creates a new one if None.

        .. note::

            Specifying both axes and figure is not allowed.

    %(colorbar)s
        Default=True.

    %(cmap)s
        default="RdBu_r"

    tri : {'full', 'lower', 'diag'}, default='full'
        Which triangular part of the matrix to plot:

        - 'lower': Plot the lower part
        - 'diag': Plot the lower part with the diagonal
        - 'full': Plot the full matrix


    auto_fit : :obj:`bool`, default=True
        If auto_fit is True, the axes are dimensioned to give room
        for the labels. This assumes that the labels are resting
        against the bottom and left edges of the figure.

    grid : color or False, default=False
        If not False, a grid is plotted to separate rows and columns.
        The grid is drawn with the given color when ``grid`` is a valid
        matplotlib color, and in gray otherwise (e.g. ``grid=True``).

    reorder : :obj:`bool` or {'single', 'complete', 'average'}, default=False
        If not False, reorders the matrix into blocks of clusters.
        Accepted linkage options for the clustering are 'single',
        'complete', and 'average'. True defaults to average linkage.

        .. note::
            This option is only available with SciPy >= 1.0.0.

        .. versionadded:: 0.4.1

    kwargs : extra keyword arguments, optional
        Extra keyword arguments are sent to
        :meth:`matplotlib.axes.Axes.imshow`.

    Returns
    -------
    display : :class:`matplotlib.image.AxesImage`
        The image created by :meth:`matplotlib.axes.Axes.imshow`.

    """
    from matplotlib.colors import is_color_like

    labels, reorder, fig, axes, _ = _sanitize_inputs_plot_matrix(
        mat.shape, tri, labels, reorder, figure, axes
    )
    if reorder:
        mat, labels = reorder_matrix(mat, labels, reorder)
    if tri != "full":
        mat = mask_matrix(mat, tri)
    display = axes.imshow(
        mat, aspect="equal", interpolation="nearest", cmap=cmap, **kwargs
    )
    axes.set_autoscale_on(False)
    # Remember the y limits so that drawing the grid cannot change them.
    ymin, ymax = axes.get_ylim()
    _configure_axis(
        axes,
        labels,
        label_size="x-small",
        x_label_rotation=50,
        y_label_rotation=10,
    )
    if grid is not False:
        n_lines_before = len(axes.lines)
        _configure_grid(axes, tri, len(mat))
        if is_color_like(grid):
            # _configure_grid always draws gray lines: recolor the lines it
            # just added with the color requested through ``grid``.
            for grid_line in list(axes.lines)[n_lines_before:]:
                grid_line.set_color(grid)
    axes.set_ylim(ymin, ymax)
    if auto_fit and labels:
        _fit_axes(axes)
    if colorbar:
        divider = make_axes_locatable(axes)
        cax = divider.append_axes("right", size="5%", pad=0.05)

        fig.colorbar(display, cax=cax)

    if title is not None:
        axes.set_title(title, size=16)

    return display

269 

270 

@fill_doc
@rename_parameters({"ax": "axes"}, end_version="0.13.0")
def plot_contrast_matrix(
    contrast_def, design_matrix, colorbar=True, axes=None, output_file=None
):
    """Create plot for :term:`contrast` definition.

    Parameters
    ----------
    contrast_def : :obj:`str` or :class:`numpy.ndarray` of shape[1] <= n_col \
        where ``n_col`` is the number of columns of the design matrix.
        The string can be a formula compatible
        with :meth:`pandas.DataFrame.eval`.
        Basically one can use the name of the conditions
        as they appear in the design matrix of the fitted model
        combined with operators +-
        and combined with numbers with operators +-`*`/.

    design_matrix : :class:`pandas.DataFrame`
        Design matrix to use.

    %(colorbar)s
        Default=True.

    axes : :class:`matplotlib.axes.Axes` or None, default=None
        Axis on which to plot the figure.
        If None, a new figure will be created, sized from the number of
        design-matrix columns and the length of their names.

    %(output_file)s

    Returns
    -------
    axes : :class:`matplotlib.axes.Axes`
        Axes on which the contrast matrix was drawn, as handed back by
        ``save_figure_if_needed`` together with ``output_file``.

    """
    padded_contrast = pad_contrast_matrix(contrast_def, design_matrix)
    con_matrix = np.array(padded_contrast, ndmin=2)

    column_names = design_matrix.columns.tolist()
    longest_name = np.max([len(str(name)) for name in column_names])
    n_columns = len(column_names)

    if axes is None:
        fig_width = 0.4 * n_columns
        fig_height = 1 + 0.5 * con_matrix.shape[0] + 0.04 * longest_name
        _, axes = plt.subplots(
            figsize=(fig_width, fig_height),
            **constrained_layout_kwargs(),
        )

    # Gray scale symmetric around zero, so null weights sit mid-scale.
    bound = np.max(np.abs(padded_contrast))
    image = axes.matshow(
        con_matrix, aspect="equal", cmap="gray", vmin=-bound, vmax=bound
    )

    axes.set_label("conditions")
    axes.set_ylabel("")
    axes.set_yticks(())

    axes.xaxis.set(ticks=np.arange(n_columns))
    axes.set_xticklabels(column_names, rotation=50, ha="left")

    if colorbar:
        axes.figure.colorbar(image, fraction=0.025, pad=0.04)

    return save_figure_if_needed(axes, output_file)

340 

341 

@fill_doc
@rename_parameters({"ax": "axes"}, end_version="0.13.0")
def plot_design_matrix(
    design_matrix,
    rescale=True,
    axes=None,
    output_file=None,
):
    """Plot a design matrix.

    Parameters
    ----------
    design_matrix : :class:`pandas.DataFrame` or \
    :obj:`str` or :obj:`pathlib.Path` to a TSV file
        Describes a design matrix.

    rescale : :obj:`bool`, default=True
        Rescale columns magnitude for visualization or not.
        When True, each column is divided by its L2 norm.

    axes : :class:`matplotlib.axes.Axes` or None, default=None
        Handle to axes onto which we will draw the design matrix.
        If None, a new figure is created whose height (clamped between
        3 and 10) grows with the number of scans.

    %(output_file)s

    Returns
    -------
    axes : :class:`matplotlib.axes.Axes`
        The axes used for plotting.

    """
    design_matrix = check_and_load_tables(design_matrix, "design_matrix")[0]
    _, values, names = check_design_matrix(design_matrix)

    if rescale:
        # Normalize each column by its L2 norm (floored at 1e-12 to avoid a
        # division by zero) so all regressors are visible on a common scale.
        values = values / np.maximum(1.0e-12, np.sqrt(np.sum(values**2, 0)))

    if axes is None:
        longest_name = np.max([len(str(name)) for name in names])
        unclamped_height = 1 + 0.1 * values.shape[0] + 0.04 * longest_name
        fig_height = min(max(unclamped_height, 3), 10)
        _, axes = plt.subplots(
            figsize=(1 + 0.23 * len(names), fig_height),
            **constrained_layout_kwargs(),
        )

    axes.imshow(values, interpolation="nearest", aspect="auto")
    axes.set_label("conditions")
    axes.set_ylabel("scan number")

    axes.set_xticks(range(len(names)))
    axes.set_xticklabels(names, rotation=60, ha="left")
    # Column names on top, mirroring how the matching DataFrame is displayed.
    axes.xaxis.tick_top()

    return save_figure_if_needed(axes, output_file)

401 

402 

@fill_doc
def plot_event(model_event, cmap=None, output_file=None, **fig_kwargs):
    """Create plot for event visualization.

    .. warning::

        Events with a duration of 0 seconds will be plotted
        by a 'delta function'.

    Parameters
    ----------
    model_event : :class:`pandas.DataFrame`, \
    :obj:`str` or :obj:`pathlib.Path` to a TSV event file, \
    or a :obj:`list` or :obj:`tuple` \
    of :class:`pandas.DataFrame`, \
    :obj:`str` or :obj:`pathlib.Path` to a TSV event file.
        The :class:`pandas.DataFrame` must have three columns:
        ``trial_type`` with event name, ``onset`` and ``duration``.
        See :func:`~nilearn.glm.first_level.make_first_level_design_matrix`
        for details on the required content of events dataframes.

        .. note::

            The :class:`pandas.DataFrame` can also be obtained
            from :func:`nilearn.glm.first_level.first_level_from_bids`.

    %(cmap)s
        Defaults to ``"tab20"`` when None. Listed colormaps give one color
        per event type, in order; other (continuous) colormaps are sampled
        evenly from one end to the other.

    %(output_file)s

    **fig_kwargs : extra keyword arguments, optional
        Extra arguments passed to :func:`matplotlib.pyplot.subplots`.
        Constrained layout is used unless ``layout`` or
        ``constrained_layout`` is given.

    Returns
    -------
    figure : :class:`matplotlib.figure.Figure`
        Plot Figure object.

    Raises
    ------
    ValueError
        If there are more event types than colors in the colormap.

    """
    model_event = [
        check_events(event)
        for event in check_and_load_tables(model_event, "model_event")
    ]
    n_runs = len(model_event)

    # Resolve and validate the colormap BEFORE creating a figure, so that an
    # invalid ``cmap`` (unknown name, too few colors) leaves no open figure.
    if cmap is None:
        cmap = "tab20"
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    event_labels = np.unique(
        pd.concat(event["trial_type"] for event in model_event)
    )

    if len(event_labels) > cmap.N:
        raise ValueError(
            "The number of event types is greater than "
            f"colors in colormap ({len(event_labels)} > {cmap.N}). "
            "Use a different colormap."
        )

    if hasattr(cmap, "colors"):
        # Listed colormaps: take their discrete colors in order.
        label_colors = {
            label: cmap.colors[idx] for idx, label in enumerate(event_labels)
        }
    else:
        # Continuous colormaps (e.g. LinearSegmentedColormap) have no
        # ``colors`` attribute: sample them evenly instead of crashing.
        label_colors = {
            label: cmap(position)
            for label, position in zip(
                event_labels, np.linspace(0, 1, len(event_labels))
            )
        }

    if "layout" not in fig_kwargs and "constrained_layout" not in fig_kwargs:
        fig_kwargs.update(**constrained_layout_kwargs())
    figure, axes = plt.subplots(1, 1, **fig_kwargs)

    # Each run occupies one unit on the y axis (ylim is set to (0, n_runs)).
    # Bars start a quarter unit above the run baseline and are
    # ``height * modulation`` tall; axvspan/axvline take y extents as axes
    # fractions, hence the division by ``n_runs``.
    height = 0.5
    event_ends = []
    for idx_run, event_df in enumerate(model_event):
        for _, event in event_df.iterrows():
            modulation = 1.0
            if "modulation" in event:
                modulation = event["modulation"]

            ymin = (idx_run + 0.25) / n_runs
            ymax = (idx_run + 0.25 + height * modulation) / n_runs

            event_onset = event["onset"]
            event_end = event["onset"] + event["duration"]

            event_ends.append(event_end)

            color = label_colors[event["trial_type"]]

            if event["duration"] != 0:
                axes.axvspan(
                    xmin=event_onset,
                    xmax=event_end,
                    ymin=ymin,
                    ymax=ymax,
                    facecolor=color,
                )

            # events with 0 duration are plotted as lines
            else:
                axes.axvline(
                    event_onset,
                    ymin=ymin,
                    ymax=ymax,
                    color=color,
                )

    handles = [
        mpatches.Patch(color=color, label=label)
        for label, color in label_colors.items()
    ]
    axes.legend(handles=handles, ncol=4)

    axes.set_xlabel("Time (sec.)")
    axes.set_ylabel("Runs")
    axes.set_ylim(0, n_runs)
    # ``default`` keeps a set of runs without any event from crashing max().
    axes.set_xlim(-1, max(event_ends, default=0) + 1)
    axes.set_yticks(np.arange(n_runs) + 0.5)
    axes.set_yticklabels(np.arange(n_runs) + 1)

    return save_figure_if_needed(figure, output_file)

523 

524 

@fill_doc
def plot_design_matrix_correlation(
    design_matrix,
    tri="full",
    cmap=DEFAULT_DIVERGING_CMAP,
    colorbar=True,
    output_file=None,
    **kwargs,
):
    """Compute and plot the correlation between regressors of a design matrix.

    The drift and constant regressors are omitted from the plot.

    .. versionadded:: 0.11.0

    Parameters
    ----------
    design_matrix : :obj:`pandas.DataFrame`, :obj:`str` or \
    :obj:`pathlib.Path`
        Design matrix whose correlation matrix you want to plot.
        Columns named ``"intercept"`` or ``"constant"``, and string-named
        columns starting with ``"drift_"``, are dropped before plotting.

    tri : {"full", "diag"}, default="full"
        Which triangular part of the matrix to plot:

        - ``"diag"``: Plot the lower part with the diagonal
        - ``"full"``: Plot the full matrix

    %(cmap)s
        default="RdBu_r"

        This must be a diverging colormap as the correlation matrix
        will be centered on 0.
        The allowed colormaps are:

        - ``"bwr"``
        - ``"RdBu_r"``
        - ``"seismic_r"``

    %(output_file)s

    kwargs : extra keyword arguments, optional
        Extra keyword arguments are sent to
        :func:`nilearn.plotting.plot_matrix`

    Returns
    -------
    display : :class:`matplotlib.image.AxesImage`
        Image returned by :func:`nilearn.plotting.plot_matrix`, as handed
        back by ``save_figure_if_needed`` together with ``output_file``.

    Raises
    ------
    ValueError
        If ``cmap`` is not one of the allowed colormaps, or if no column
        is left once drift and constant regressors are removed.
    """
    design_matrix = check_and_load_tables(design_matrix, "design_matrix")[0]

    check_design_matrix(design_matrix)

    ALLOWED_CMAP = ["RdBu_r", "bwr", "seismic_r"]
    cmap_name = cmap if isinstance(cmap, str) else cmap.name
    if cmap_name not in ALLOWED_CMAP:
        raise ValueError(f"cmap must be one of {ALLOWED_CMAP}")

    columns_to_drop = ["intercept", "constant"]
    # Column labels are not guaranteed to be strings (a DataFrame built from
    # a bare array has integer labels), so guard the ``startswith`` call.
    columns_to_drop.extend(
        col
        for col in design_matrix.columns
        if isinstance(col, str) and col.startswith("drift_")
    )
    design_matrix = design_matrix.drop(
        columns=columns_to_drop, errors="ignore"
    )

    if len(design_matrix.columns) == 0:
        raise ValueError(
            "Nothing left to plot after "
            "removing drift and constant regressors."
        )

    sanitize_tri(tri, allowed_values=("full", "diag"))

    mat = design_matrix.corr().to_numpy()

    vmax = max(mat.min(), mat.max(), key=abs)
    if len(mat) > 1:
        # Use the second-largest value of each row to ignore the diagonal,
        # which is always 1 and would otherwise dominate the color range.
        second_largest = np.partition(mat, -2, axis=1)[:, -2]
        vmax = max(abs(mat.min().min()), max(second_largest))

    col_labels = design_matrix.columns
    display = plot_matrix(
        mat,
        tri=tri,
        cmap=cmap,
        vmax=vmax,
        vmin=vmax * -1,
        labels=col_labels.to_list(),
        colorbar=colorbar,
        **kwargs,
    )

    return save_figure_if_needed(display, output_file)