Coverage for nilearn/plotting/displays/_slicers.py: 0%

719 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1import collections 

2import contextlib 

3import numbers 

4import warnings 

5from typing import ClassVar 

6 

7import matplotlib.pyplot as plt 

8import numpy as np 

9from matplotlib.colorbar import ColorbarBase 

10from matplotlib.colors import LinearSegmentedColormap, ListedColormap 

11from matplotlib.transforms import Bbox 

12 

13from nilearn._utils import check_niimg_3d, fill_doc 

14from nilearn._utils.logger import find_stack_level 

15from nilearn._utils.niimg import is_binary_niimg, safe_get_data 

16from nilearn._utils.niimg_conversions import _check_fov 

17from nilearn._utils.param_validation import check_params 

18from nilearn.image import get_data, new_img_like, reorder_img 

19from nilearn.image.resampling import get_bounds, get_mask_bounds, resample_img 

20from nilearn.plotting._utils import ( 

21 check_threshold_not_negative, 

22 get_cbar_ticks, 

23) 

24from nilearn.plotting.displays import CutAxes 

25from nilearn.plotting.displays._utils import ( 

26 coords_3d_to_2d, 

27 get_create_display_fun, 

28) 

29from nilearn.plotting.displays.edge_detect import edge_map 

30from nilearn.plotting.find_cuts import find_cut_slices, find_xyz_cut_coords 

31from nilearn.typing import NiimgLike 

32 

33 

@fill_doc
class BaseSlicer:
    """BaseSlicer implementation whose main purpose is to auto adjust \
    the axes size to the data with different layout of cuts.

    It creates 3 linked axes for plotting orthogonal cuts.

    Attributes
    ----------
    cut_coords : 3 :obj:`tuple` of :obj:`int`
        The cut position, in world space.

    frame_axes : :class:`matplotlib.axes.Axes`, optional
        The matplotlib axes that will be subdivided in 3.

    black_bg : :obj:`bool`, default=False
        If ``True``, the background of the figure will be put to
        black. If you wish to save figures with a black background,
        you will need to pass ``facecolor='k', edgecolor='k'``
        to :func:`~matplotlib.pyplot.savefig`.

    brain_color : :obj:`tuple`, default=(0.5, 0.5, 0.5)
        The brain color to use as the background color (e.g., for
        transparent colorbars).
    """

    # This actually encodes the figsize for a single cut axes only
    _default_figsize: ClassVar[list[float]] = [2.2, 2.6]
    _axes_class = CutAxes

    def __init__(
        self,
        cut_coords,
        axes=None,
        black_bg=False,
        brain_color=(0.5, 0.5, 0.5),
        **kwargs,
    ):
        # Store display settings, derive colorbar geometry from the frame
        # axes position, then let the subclass lay out its cut axes.
        self.cut_coords = cut_coords
        if axes is None:
            axes = plt.axes((0.0, 0.0, 1.0, 1.0))
            axes.axis("off")
        self.frame_axes = axes
        axes.set_zorder(1)
        bb = axes.get_position()
        self.rect = (bb.x0, bb.y0, bb.x1, bb.y1)
        self._black_bg = black_bg
        self._brain_color = brain_color
        self._colorbar = False
        self._colorbar_width = 0.05 * bb.width
        self._cbar_tick_format = "%.2g"
        self._colorbar_margin = {
            "left": 0.25 * bb.width,
            "right": 0.02 * bb.width,
            "top": 0.05 * bb.height,
            "bottom": 0.05 * bb.height,
        }
        # ``_init_axes`` is not defined on BaseSlicer; subclasses
        # (e.g. OrthoSlicer) implement it.
        self._init_axes(**kwargs)

    @property
    def brain_color(self):
        """Return brain color."""
        return self._brain_color

    @property
    def black_bg(self):
        """Return black background."""
        return self._black_bg

    @staticmethod
    def find_cut_coords(img=None, threshold=None, cut_coords=None):
        """Act as placeholder and is not implemented in the base class \
        and has to be implemented in derived classes.
        """
        # Implement this as a staticmethod or a classmethod when
        # subclassing
        raise NotImplementedError

    @classmethod
    @fill_doc  # the fill_doc decorator must be last applied
    def init_with_figure(
        cls,
        img,
        threshold=None,
        cut_coords=None,
        figure=None,
        axes=None,
        black_bg=False,
        leave_space=False,
        colorbar=False,
        brain_color=(0.5, 0.5, 0.5),
        **kwargs,
    ):
        """Initialize the slicer with an image.

        Parameters
        ----------
        %(img)s

        %(threshold)s

        cut_coords : 3 :obj:`tuple` of :obj:`int`
            The cut position, in world space.

        figure : :class:`matplotlib.figure.Figure` or ``None``, optional
            The figure to draw in. If ``None`` and ``axes`` is a
            :class:`matplotlib.axes.Axes`, the figure of ``axes`` is used.
            If the result is not a :class:`matplotlib.figure.Figure`,
            :func:`matplotlib.pyplot.figure` is called with ``figure`` as
            its first argument to obtain one.

        axes : :class:`matplotlib.axes.Axes`, optional
            The axes that will be subdivided in 3.

        black_bg : :obj:`bool`, default=False
            If ``True``, the background of the figure will be put to
            black. If you wish to save figures with a black background,
            you will need to pass ``facecolor='k', edgecolor='k'``
            to :func:`matplotlib.pyplot.savefig`.

        leave_space : :obj:`bool`, default=False
            If ``True``, the default axes rectangle starts at 0.3 of the
            figure width, leaving room on the left, and a newly created
            figure is made wider.

        colorbar : :obj:`bool`, default=False
            If ``True``, a newly created figure is made wider to leave
            room for a colorbar.

        brain_color : :obj:`tuple`, default=(0.5, 0.5, 0.5)
            The brain color to use as the background color (e.g., for
            transparent colorbars).

        kwargs : :obj:`dict`
            Extra keyword arguments passed to the class constructor.

        Returns
        -------
        slicer : instance of ``cls``
            The slicer built on the prepared figure and axes.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number
        """
        # Must stay the first statement: it validates the call's locals.
        check_params(locals())
        check_threshold_not_negative(threshold)

        # deal with "fake" 4D images
        if img is not None and img is not False:
            img = check_niimg_3d(img)

        cut_coords = cls.find_cut_coords(img, threshold, cut_coords)

        if isinstance(axes, plt.Axes) and figure is None:
            figure = axes.figure

        if not isinstance(figure, plt.Figure):
            # Make sure that we have a figure
            figsize = cls._default_figsize[:]

            # Adjust for the number of axes
            figsize[0] *= len(cut_coords)

            # Make space for the colorbar
            if colorbar:
                figsize[0] += 0.7

            facecolor = "k" if black_bg else "w"

            if leave_space:
                figsize[0] += 3.4
            figure = plt.figure(figure, figsize=figsize, facecolor=facecolor)
        if isinstance(axes, plt.Axes):
            assert axes.figure is figure, (
                "The axes passed are not in the figure"
            )

        if axes is None:
            axes = [0.3, 0, 0.7, 1.0] if leave_space else [0.0, 0.0, 1.0, 1.0]
        if isinstance(axes, collections.abc.Sequence):
            axes = figure.add_axes(axes)
        # People forget to turn their axis off, or to set the zorder, and
        # then they cannot see their slicer
        axes.axis("off")
        return cls(cut_coords, axes, black_bg, brain_color, **kwargs)

    def title(
        self,
        text,
        x=0.01,
        y=0.99,
        size=15,
        color=None,
        bgcolor=None,
        alpha=1,
        **kwargs,
    ):
        """Write a title to the view.

        Parameters
        ----------
        text : :obj:`str`
            The text of the title.

        x : :obj:`float`, default=0.01
            The horizontal position of the title on the frame in
            fraction of the frame width.

        y : :obj:`float`, default=0.99
            The vertical position of the title on the frame in
            fraction of the frame height.

        size : :obj:`int`, default=15
            The size of the title text.

        color : matplotlib color specifier, optional
            The color of the font of the title.

        bgcolor : matplotlib color specifier, optional
            The color of the background of the title.

        alpha : :obj:`float`, default=1
            The alpha value for the background.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to matplotlib's text
            function.
        """
        if color is None:
            color = "k" if self._black_bg else "w"
        if bgcolor is None:
            bgcolor = "w" if self._black_bg else "k"
        if hasattr(self, "_cut_displayed"):
            # Adapt to the case of mosaic plotting
            if isinstance(self.cut_coords, dict):
                first_axe = self._cut_displayed[-1]
                first_axe = (first_axe, self.cut_coords[first_axe][0])
            else:
                first_axe = self._cut_displayed[0]
        else:
            first_axe = self.cut_coords[0]
        ax = self.axes[first_axe].ax
        ax.text(
            x,
            y,
            text,
            transform=self.frame_axes.transAxes,
            horizontalalignment="left",
            verticalalignment="top",
            size=size,
            color=color,
            bbox={
                "boxstyle": "square,pad=.3",
                "ec": bgcolor,
                "fc": bgcolor,
                "alpha": alpha,
            },
            zorder=1000,
            **kwargs,
        )
        ax.set_zorder(1000)

    @fill_doc
    def add_overlay(
        self,
        img,
        threshold=1e-6,
        colorbar=False,
        cbar_tick_format="%.2g",
        cbar_vmin=None,
        cbar_vmax=None,
        transparency=None,
        transparency_range=None,
        **kwargs,
    ):
        """Plot a 3D map in all the views.

        Parameters
        ----------
        %(img)s
            If it is a masked array, only the non-masked part will be plotted.

        threshold : :obj:`int` or :obj:`float` or ``None``, default=1e-6
            Threshold to apply:

            - If ``None`` is given, the maps are not thresholded.
            - If number is given, it must be non-negative. The specified
              value is used to threshold the image: values below the
              threshold (in absolute value) are plotted as transparent.

        colorbar : :obj:`bool`, default=False
            If ``True``, display a colorbar on the right of the plots.

        cbar_tick_format : str, default="%%.2g" (scientific notation)
            Controls how to format the tick labels of the colorbar.
            Ex: use "%%i" to display as integers.

        cbar_vmin : :obj:`float`, optional
            Minimal value for the colorbar. If None, the minimal value
            is computed based on the data.

        cbar_vmax : :obj:`float`, optional
            Maximal value for the colorbar. If None, the maximal value
            is computed based on the data.

        %(transparency)s

        %(transparency_range)s

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to function
            :func:`~matplotlib.pyplot.imshow`.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number, or if a
            colorbar is requested while the figure already has one.
        """
        check_threshold_not_negative(threshold)

        if colorbar and self._colorbar:
            raise ValueError(
                "This figure already has an overlay with a colorbar."
            )

        self._colorbar = colorbar
        self._cbar_tick_format = cbar_tick_format

        img = check_niimg_3d(img)

        # Make sure that add_overlay shows consistent default behavior
        # with plot_stat_map
        kwargs.setdefault("interpolation", "nearest")
        ims = self._map_show(
            img,
            type="imshow",
            threshold=threshold,
            transparency=transparency,
            transparency_range=transparency_range,
            **kwargs,
        )

        # `ims` can be empty in some corner cases,
        # look at test_img_plotting.test_outlier_cut_coords.
        if colorbar and ims:
            self._show_colorbar(
                ims[0].cmap, ims[0].norm, cbar_vmin, cbar_vmax, threshold
            )

        plt.draw_if_interactive()

    @fill_doc
    def add_contours(self, img, threshold=1e-6, filled=False, **kwargs):
        """Contour a 3D map in all the views.

        Parameters
        ----------
        %(img)s
            Provides image to plot.

        threshold : :obj:`int` or :obj:`float` or ``None``, default=1e-6
            Threshold to apply:

            - If ``None`` is given, the maps are not thresholded.
            - If number is given, it must be non-negative. The specified
              value is used to threshold the image: values below the
              threshold (in absolute value) are plotted as transparent.

            Only used when ``filled=True``.

        filled : :obj:`bool`, default=False
            If ``filled=True``, contours are displayed with color fillings.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to function
            :func:`~matplotlib.pyplot.contour`, or function
            :func:`~matplotlib.pyplot.contourf`.
            Useful, arguments are typical "levels", which is a
            list of values to use for plotting a contour or contour
            fillings (if ``filled=True``), and
            "colors", which is one color or a list of colors for
            these contours. A ``levels`` sequence passed by the caller
            is never modified.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number

        Notes
        -----
        If colors are not specified, default coloring choices
        (from matplotlib) for contours and contour_fillings can be
        different.

        """
        if not filled:
            threshold = None
        else:
            check_threshold_not_negative(threshold)

        self._map_show(img, type="contour", threshold=threshold, **kwargs)
        if filled:
            levels = kwargs.get("levels")
            if levels is not None and np.ndim(levels) == 1 and len(levels) <= 1:
                # contour fillings levels should be given as (lower, upper).
                # Build a new list rather than appending in place: the
                # caller's sequence must not change, and tuples / arrays
                # have no ``append``.
                kwargs["levels"] = [*levels, np.inf]

            self._map_show(img, type="contourf", threshold=threshold, **kwargs)

        plt.draw_if_interactive()

    def _map_show(
        self,
        img,
        type="imshow",
        resampling_interpolation="continuous",
        threshold=None,
        transparency=None,
        transparency_range=None,
        **kwargs,
    ):
        """Cut ``img`` along every display axes and draw the 2D slices.

        Parameters
        ----------
        img : Niimg-like object
            The 3D image to display.

        type : {"imshow", "contour", "contourf"}, default="imshow"
            Kind of drawing, forwarded to ``draw_2d`` of each display axes.

        resampling_interpolation : :obj:`str`, default="continuous"
            Interpolation used if ``img`` has to be resampled by
            ``reorder_img``; forced to ``"nearest"`` for binary images.

        threshold : :obj:`float` or ``None``, default=None
            Values whose absolute value is at most ``threshold`` are masked.

        transparency : ``None``, number, Niimg-like or array, default=None
            See ``_sanitize_transparency``.

        transparency_range : sequence of 2 numbers or ``None``, default=None
            See ``_sanitize_transparency``.

        kwargs : :obj:`dict`
            Forwarded to ``draw_2d``. Missing ``vmin`` / ``vmax`` are
            computed from the cut data.

        Returns
        -------
        ims : :obj:`list`
            The objects returned by ``draw_2d``, one per display axes whose
            cut intersects the data. Empty if no cut intersects the data.
        """
        # In the special case where the affine of img is not diagonal,
        # the function `reorder_img` will trigger a resampling
        # of the provided image with a continuous interpolation
        # since this is the default value here. In the special
        # case where this image is binary, such as when this function
        # is called from `add_contours`, continuous interpolation
        # does not make sense and we turn to nearest interpolation instead.

        if is_binary_niimg(img):
            resampling_interpolation = "nearest"

        # Image reordering should be done before sanitizing transparency
        img = reorder_img(
            img, resample=resampling_interpolation, copy_header=True
        )

        transparency, transparency_affine = self._sanitize_transparency(
            img,
            transparency,
            transparency_range,
            resampling_interpolation,
        )

        affine = img.affine

        if threshold is not None:
            threshold = float(threshold)
            data = safe_get_data(img, ensure_finite=True)
            data = self._threshold(data, threshold, None, None)
            img = new_img_like(img, data, affine)

        data = safe_get_data(img, ensure_finite=True)
        data_bounds = get_bounds(data.shape, affine)
        (xmin, xmax), (ymin, ymax), (zmin, zmax) = data_bounds

        xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = (
            xmin,
            xmax,
            ymin,
            ymax,
            zmin,
            zmax,
        )

        # Compute tight bounds
        if type in ("contour", "contourf"):
            # Define a pseudo threshold to have a tight bounding box
            thr = (
                0.9 * np.min(np.abs(kwargs["levels"]))
                if "levels" in kwargs
                else 1e-6
            )
            not_mask = np.logical_or(data > thr, data < -thr)
            xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = get_mask_bounds(
                new_img_like(img, not_mask, affine)
            )
        elif hasattr(data, "mask") and isinstance(data.mask, np.ndarray):
            not_mask = np.logical_not(data.mask)
            xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = get_mask_bounds(
                new_img_like(img, not_mask, affine)
            )

        data_2d_list = []
        transparency_list = []
        for display_ax in self.axes.values():
            if transparency is None or isinstance(transparency, (float, int)):
                transparency_2d = transparency

            try:
                data_2d = display_ax.transform_to_2d(data, affine)
                if isinstance(transparency, np.ndarray):
                    transparency_2d = display_ax.transform_to_2d(
                        transparency, transparency_affine
                    )
            except IndexError:
                # We are cutting outside the indices of the data
                data_2d = None
                transparency_2d = None

            data_2d_list.append(data_2d)
            transparency_list.append(transparency_2d)

        # When no cut intersects the data there is nothing to draw: skip
        # the automatic vmin / vmax computation instead of reducing an
        # empty list (which raises a ValueError).
        valid_data_2d = [d for d in data_2d_list if d is not None]
        if valid_data_2d:
            if kwargs.get("vmin") is None:
                kwargs["vmin"] = np.ma.min([d.min() for d in valid_data_2d])
            if kwargs.get("vmax") is None:
                kwargs["vmax"] = np.ma.max([d.max() for d in valid_data_2d])

        bounding_box = (xmin_, xmax_), (ymin_, ymax_), (zmin_, zmax_)
        ims = []
        to_iterate_over = zip(
            self.axes.values(), data_2d_list, transparency_list
        )
        threshold = float(threshold) if threshold else None
        for display_ax, data_2d, transparency_2d in to_iterate_over:
            # If data_2d is completely masked, then there is nothing to
            # plot. Hence, no point to do imshow().
            if data_2d is not None:
                data_2d = self._threshold(
                    data_2d,
                    threshold,
                    vmin=float(kwargs.get("vmin")),
                    vmax=float(kwargs.get("vmax")),
                )

                im = display_ax.draw_2d(
                    data_2d,
                    data_bounds,
                    bounding_box,
                    type=type,
                    transparency=transparency_2d,
                    **kwargs,
                )
                ims.append(im)
        return ims

    def _sanitize_transparency(
        self, img, transparency, transparency_range, resampling_interpolation
    ):
        """Return transparency as None, float or an array.

        Parameters
        ----------
        img : Niimg-like object
            The (already reordered) image being displayed. A transparency
            image whose field of view differs from it is resampled onto it.

        transparency : ``None``, :obj:`float`, :obj:`int`, Niimg-like or \
            :class:`numpy.ndarray`
            Requested transparency. Numbers are clipped to [0, 1] with a
            warning; images / arrays are rescaled using
            ``transparency_range``.

        transparency_range : sequence of 2 numbers or ``None``
            Only used for array-like transparency: absolute values are
            clipped to this range (itself clamped to the data range) and
            rescaled to [0, 1]. Defaults to ``[0, max(abs(transparency))]``.
            The caller's sequence is never modified.

        resampling_interpolation : :obj:`str`
            Interpolation used to reorder / resample a transparency image;
            forced to ``"nearest"`` for binary images.

        Returns
        -------
        transparency : None, float or np.ndarray

        transparency_affine : None or np.ndarray
            Affine of the transparency image when one was given.

        Raises
        ------
        ValueError
            If ``transparency_range`` does not have 2 elements, or if its
            lower bound is not below its upper bound after clamping.
        """
        transparency_affine = None
        if isinstance(transparency, NiimgLike):
            transparency = check_niimg_3d(transparency, dtype="auto")
            if is_binary_niimg(transparency):
                resampling_interpolation = "nearest"
            transparency = reorder_img(
                transparency,
                resample=resampling_interpolation,
                copy_header=True,
            )
            if not _check_fov(transparency, img.affine, img.shape[:3]):
                warnings.warn(
                    "resampling transparency image to data image...",
                    stacklevel=find_stack_level(),
                )
                transparency = resample_img(
                    transparency,
                    img.affine,
                    img.shape,
                    force_resample=True,
                    copy_header=True,
                    interpolation=resampling_interpolation,
                )

            transparency_affine = transparency.affine
            transparency = safe_get_data(transparency, ensure_finite=True)

        assert transparency is None or isinstance(
            transparency, (int, float, np.ndarray)
        )

        if isinstance(transparency, (float, int)):
            transparency = float(transparency)
            base_warning_message = (
                "'transparency' must be in the interval [0, 1]. "
            )
            if transparency > 1.0:
                warnings.warn(
                    f"{base_warning_message} Setting it to 1.0.",
                    stacklevel=find_stack_level(),
                )
                transparency = 1.0
            if transparency < 0:
                warnings.warn(
                    f"{base_warning_message} Setting it to 0.0.",
                    stacklevel=find_stack_level(),
                )
                transparency = 0.0

        elif isinstance(transparency, np.ndarray):
            transparency = np.abs(transparency)

            if transparency_range is None:
                transparency_range = [0.0, np.max(transparency)]

            error_msg = (
                "'transparency_range' must be "
                "a list or tuple of 2 non-negative numbers "
                "with 'first value < second value'."
            )

            if len(transparency_range) != 2:
                raise ValueError(f"{error_msg} Got '{transparency_range}'.")

            # Clamp to the data range in a fresh list: the caller's
            # sequence must not be mutated, and tuples (documented as
            # valid above) do not support item assignment.
            transparency_range = [
                max(transparency_range[0], np.min(transparency)),
                min(transparency_range[1], np.max(transparency)),
            ]

            if transparency_range[0] >= transparency_range[1]:
                raise ValueError(f"{error_msg} Got '{transparency_range}'.")

            # make sure that 0 <= transparency <= 1
            # taking into account the requested transparency_range
            transparency = np.clip(
                transparency, transparency_range[0], transparency_range[1]
            )
            transparency = (transparency - transparency_range[0]) / (
                transparency_range[1] - transparency_range[0]
            )

        return transparency, transparency_affine

    @classmethod
    def _threshold(cls, data, threshold=None, vmin=None, vmax=None):
        """Threshold the data.

        Parameters
        ----------
        data : :class:`numpy.ndarray`
            Data to be thresholded.

        threshold : :obj:`int` or :obj:`float` or ``None``, default=None
            If not ``None``, values whose absolute value is lower than or
            equal to ``threshold`` are masked. Must be non-negative.

        vmin : :obj:`float` or ``None``, default=None
            Only used when ``threshold`` is not ``None``: if
            ``vmin >= -threshold``, values below ``vmin`` are also masked.

        vmax : :obj:`float` or ``None``, default=None
            Only used when ``threshold`` is not ``None``: if
            ``vmax <= threshold``, values above ``vmax`` are also masked.

        Returns
        -------
        data : :class:`numpy.ndarray` or :class:`numpy.ma.MaskedArray`
            ``data`` itself if ``threshold`` is ``None``, otherwise a masked
            array built with ``copy=False``.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number
        """
        # Must stay the first statement: it validates the call's locals.
        check_params(locals())
        check_threshold_not_negative(threshold)

        if threshold is not None:
            data = np.ma.masked_where(
                np.abs(data) <= threshold,
                data,
                copy=False,
            )

            if (vmin is not None) and (vmin >= -threshold):
                data = np.ma.masked_where(data < vmin, data, copy=False)
            if (vmax is not None) and (vmax <= threshold):
                data = np.ma.masked_where(data > vmax, data, copy=False)

        return data

    @fill_doc
    def _show_colorbar(
        self, cmap, norm, cbar_vmin=None, cbar_vmax=None, threshold=None
    ):
        """Display the colorbar.

        Parameters
        ----------
        %(cmap)s
        norm : :class:`~matplotlib.colors.Normalize`
            This object is typically found as the ``norm`` attribute of
            :class:`~matplotlib.image.AxesImage`.

        cbar_vmin : :obj:`float`, optional
            Minimal value for the colorbar. If None, the minimal value
            is computed based on the data.

        cbar_vmax : :obj:`float`, optional
            Maximal value for the colorbar. If None, the maximal value
            is computed based on the data.

        threshold : :obj:`float` or ``None``, optional
            The absolute value at which the colorbar is thresholded.
        """
        offset = 0 if threshold is None else threshold
        offset = min(offset, norm.vmax)

        cbar_vmin = cbar_vmin if cbar_vmin is not None else norm.vmin
        cbar_vmax = cbar_vmax if cbar_vmax is not None else norm.vmax

        # create new axis for the colorbar
        figure = self.frame_axes.figure
        _, y0, x1, y1 = self.rect
        height = y1 - y0
        x_adjusted_width = self._colorbar_width / len(self.axes)
        x_adjusted_margin = self._colorbar_margin["right"] / len(self.axes)
        lt_wid_top_ht = [
            x1 - (x_adjusted_width + x_adjusted_margin),
            y0 + self._colorbar_margin["top"],
            x_adjusted_width,
            height
            - (self._colorbar_margin["top"] + self._colorbar_margin["bottom"]),
        ]
        self._colorbar_ax = figure.add_axes(lt_wid_top_ht)
        self._colorbar_ax.set_facecolor("w")

        our_cmap = plt.get_cmap(cmap)
        # edge case where the data has a single value
        # yields a cryptic matplotlib error message
        # when trying to plot the color bar
        n_ticks = 5 if cbar_vmin != cbar_vmax else 1
        ticks = get_cbar_ticks(cbar_vmin, cbar_vmax, offset, n_ticks)
        bounds = np.linspace(cbar_vmin, cbar_vmax, our_cmap.N)

        # some colormap hacking: make the thresholded band [-offset, offset]
        # transparent in the colorbar
        cmaplist = [our_cmap(i) for i in range(our_cmap.N)]
        transparent_start = int(norm(-offset, clip=True) * (our_cmap.N - 1))
        transparent_stop = int(norm(offset, clip=True) * (our_cmap.N - 1))
        for i in range(transparent_start, transparent_stop):
            cmaplist[i] = (*self._brain_color, 0.0)  # transparent
        if cbar_vmin == cbar_vmax:  # len(np.unique(data)) == 1 ?
            return
        else:
            our_cmap = LinearSegmentedColormap.from_list(
                "Custom cmap", cmaplist, our_cmap.N
            )
        self._cbar = ColorbarBase(
            self._colorbar_ax,
            ticks=ticks,
            norm=norm,
            orientation="vertical",
            cmap=our_cmap,
            boundaries=bounds,
            spacing="proportional",
            format=self._cbar_tick_format,
        )
        self._cbar.ax.set_facecolor(self._brain_color)

        self._colorbar_ax.yaxis.tick_left()
        tick_color = "w" if self._black_bg else "k"
        outline_color = "w" if self._black_bg else "k"

        for tick in self._colorbar_ax.yaxis.get_ticklabels():
            tick.set_color(tick_color)
        self._colorbar_ax.yaxis.set_tick_params(width=0)
        self._cbar.outline.set_edgecolor(outline_color)

    @fill_doc
    def add_edges(self, img, color="r"):
        """Plot the edges of a 3D map in all the views.

        Parameters
        ----------
        %(img)s
            The 3D map to be plotted.
            If it is a masked array, only the non-masked part will be plotted.

        color : matplotlib color: :obj:`str` or (r, g, b) value, default='r'
            The color used to display the edge map.

        """
        img = reorder_img(img, resample="continuous", copy_header=True)
        data = get_data(img)
        affine = img.affine
        single_color_cmap = ListedColormap([color])
        data_bounds = get_bounds(data.shape, img.affine)

        # For each ax, cut the data and plot it
        for display_ax in self.axes.values():
            try:
                data_2d = display_ax.transform_to_2d(data, affine)
                edge_mask = edge_map(data_2d)
            except IndexError:
                # We are cutting outside the indices of the data
                continue
            display_ax.draw_2d(
                edge_mask,
                data_bounds,
                data_bounds,
                type="imshow",
                cmap=single_color_cmap,
            )

        plt.draw_if_interactive()

    @staticmethod
    def _has_one_value_per_marker(values, n_markers):
        """Return True if ``values`` is a sequence with one item per marker.

        Strings and scalar numbers are single values (one color name, one
        marker size) and are never considered per-marker.
        """
        if isinstance(values, (str, numbers.Number)):
            return False
        try:
            return len(values) == n_markers
        except TypeError:
            return False

    def add_markers(
        self, marker_coords, marker_color="r", marker_size=30, **kwargs
    ):
        """Add markers to the plot.

        Parameters
        ----------
        marker_coords : :class:`~numpy.ndarray` of shape ``(n_markers, 3)``
            Coordinates of the markers to plot. For each slice, only markers
            that are within 2 millimeters of the slice are plotted.

        marker_color : pyplot compatible color or \
            :obj:`list` of shape ``(n_markers,)``, default='r'
            List of colors for each marker
            that can be string or matplotlib colors.

        marker_size : :obj:`float` or \
            :obj:`list` of :obj:`float` of shape ``(n_markers,)``, \
            default=30
            Size in pixel for each marker.

        kwargs : :obj:`dict`
            Extra keyword arguments passed to the ``scatter`` method of
            each display axes. ``marker`` defaults to ``"o"`` and
            ``zorder`` to ``1000``.

        Notes
        -----
        On ``"l"`` / ``"r"`` views only markers of the matching hemisphere
        are drawn. Per-marker colors and sizes are filtered with the same
        masks as the coordinates, so each drawn marker keeps its own color
        and size.
        """
        from matplotlib.colors import is_color_like

        defaults = {"marker": "o", "zorder": 1000}
        marker_coords = np.asanyarray(marker_coords)
        for k, v in defaults.items():
            kwargs.setdefault(k, v)

        n_markers = len(marker_coords)
        # A single color (e.g. an RGB tuple) must not be split per marker.
        color_per_marker = not is_color_like(
            marker_color
        ) and self._has_one_value_per_marker(marker_color, n_markers)
        size_per_marker = self._has_one_value_per_marker(
            marker_size, n_markers
        )

        for display_ax in self.axes.values():
            direction = display_ax.direction
            coord = display_ax.coord
            marker_coords_2d, third_d = coords_3d_to_2d(
                marker_coords, direction, return_direction=True
            )
            xdata, ydata = marker_coords_2d.T

            # Boolean mask of the markers drawn on this display axes; it is
            # applied to coordinates, colors and sizes alike.
            keep = np.ones(len(xdata), dtype=bool)

            # Allow markers only in their respective hemisphere
            # when appropriate
            if direction in ("l", "r"):
                xcoords, *_ = marker_coords.T
                keep &= (xcoords >= 0) if direction == "r" else (xcoords <= 0)

            # Check if coord has integer represents a cut in direction
            # to follow the heuristic. If no foreground image is given
            # coordinate is empty or None. This case is valid for plotting
            # markers on glass brain without any foreground image.
            if isinstance(coord, numbers.Number):
                # Heuristic that plots only markers that are 2mm away
                # from the current slice.
                # XXX: should we keep this heuristic?
                keep &= np.abs(third_d - coord) <= 2.0

            marker_color_ = marker_color
            marker_size_ = marker_size
            if not keep.all():
                xdata = xdata[keep]
                ydata = ydata[keep]
                if color_per_marker:
                    marker_color_ = (
                        marker_color[keep]
                        if isinstance(marker_color, np.ndarray)
                        else [
                            color
                            for color, kept in zip(marker_color, keep)
                            if kept
                        ]
                    )
                if size_per_marker:
                    marker_size_ = np.asarray(marker_size)[keep]

            display_ax.ax.scatter(
                xdata, ydata, s=marker_size_, c=marker_color_, **kwargs
            )

    def annotate(
        self,
        left_right=True,
        positions=True,
        scalebar=False,
        size=12,
        scale_size=5.0,
        scale_units="cm",
        scale_loc=4,
        decimals=0,
        **kwargs,
    ):
        """Add annotations to the plot.

        Parameters
        ----------
        left_right : :obj:`bool`, default=True
            If ``True``, annotations indicating which side
            is left and which side is right are drawn.

        positions : :obj:`bool`, default=True
            If ``True``, annotations indicating the
            positions of the cuts are drawn.

        scalebar : :obj:`bool`, default=False
            If ``True``, cuts are annotated with a reference scale bar.
            For finer control of the scale bar, please check out
            the ``draw_scale_bar`` method on the axes in "axes" attribute
            of this object.

        size : :obj:`int`, default=12
            The size of the text used.

        scale_size : :obj:`int` or :obj:`float`, default=5.0
            The length of the scalebar, in units of ``scale_units``.

        scale_units : {'cm', 'mm'}, default='cm'
            The units for the ``scalebar``.

        scale_loc : :obj:`int`, default=4
            The positioning for the scalebar.
            Valid location codes are:

            - 1: "upper right"
            - 2: "upper left"
            - 3: "lower left"
            - 4: "lower right"
            - 5: "right"
            - 6: "center left"
            - 7: "center right"
            - 8: "lower center"
            - 9: "upper center"
            - 10: "center"

        decimals : :obj:`int`, default=0
            Number of decimal places on slice position annotation. If zero,
            the slice position is integer without decimal point.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to matplotlib's text
            function.
        """
        # Copy so that adding a default color does not leak to the caller.
        kwargs = kwargs.copy()
        if "color" not in kwargs:
            kwargs["color"] = "w" if self._black_bg else "k"
        bg_color = "k" if self._black_bg else "w"

        if left_right:
            for display_axis in self.axes.values():
                display_axis.draw_left_right(
                    size=size, bg_color=bg_color, **kwargs
                )

        if positions:
            for display_axis in self.axes.values():
                display_axis.draw_position(
                    size=size, bg_color=bg_color, decimals=decimals, **kwargs
                )

        if scalebar:
            for display_axis in self.axes.values():
                display_axis.draw_scale_bar(
                    bg_color=bg_color,
                    fontsize=size,
                    size=scale_size,
                    units=scale_units,
                    loc=scale_loc,
                    **kwargs,
                )

    def close(self):
        """Close the figure.

        This is necessary to avoid leaking memory.
        """
        plt.close(self.frame_axes.figure.number)

    def savefig(self, filename, dpi=None, **kwargs):
        """Save the figure to a file.

        Parameters
        ----------
        filename : :obj:`str`
            The file name to save to. Its extension determines the
            file type, typically '.png', '.svg' or '.pdf'.

        dpi : ``None`` or scalar, default=None
            The resolution in dots per inch.

        kwargs : :obj:`dict`
            Extra keyword arguments passed to the figure's ``savefig``.
            The face and edge colors always follow ``black_bg``.

        """
        facecolor = edgecolor = "k" if self._black_bg else "w"
        self.frame_axes.figure.savefig(
            filename,
            dpi=dpi,
            facecolor=facecolor,
            edgecolor=edgecolor,
            **kwargs,
        )

1010 

@fill_doc
class OrthoSlicer(BaseSlicer):
    """Class to create 3 linked axes for plotting orthogonal \
    cuts of 3D maps.

    This visualization mode can be activated
    from Nilearn plotting functions, like
    :func:`~nilearn.plotting.plot_img`, by setting
    ``display_mode='ortho'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the OrthoSlicer class
        display = plot_img(img, display_mode="ortho")


    Attributes
    ----------
    cut_coords : :obj:`list`
        The cut coordinates, one per displayed direction, in alphabetical
        order of direction.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used to plot each view, keyed by direction
        ('x', 'y' or 'z').

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    Notes
    -----
    The extent of the different axes are adjusted to fit the data
    best in the viewing area.

    See Also
    --------
    nilearn.plotting.displays.MosaicSlicer : Three cuts are performed \
    along multiple rows and columns.
    nilearn.plotting.displays.TiledSlicer : Three cuts are performed \
    and arranged in a 2x2 grid.

    """

    _cut_displayed: ClassVar[str] = "yxz"
    _axes_class = CutAxes
    _default_figsize: ClassVar[list[float]] = [2.2, 3.5]

    @classmethod
    @fill_doc  # fill_doc must wrap the plain function, below @classmethod
    def find_cut_coords(cls, img=None, threshold=None, cut_coords=None):
        """Instantiate the slicer and find cut coordinates.

        Parameters
        ----------
        %(img)s
        threshold : :obj:`int` or :obj:`float` or ``None``, default=None
            Threshold to apply:

            - If ``None`` is given, the maps are not thresholded.
            - If number is given, it must be non-negative. The specified
              value is used to threshold the image: values below the
              threshold (in absolute value) are plotted as transparent.

        cut_coords : 3 :obj:`tuple` of :obj:`int`
            The cut position, in world space.

        Returns
        -------
        cut_coords : :obj:`list` or the input ``cut_coords``
            When ``cut_coords`` is ``None``, one coordinate per displayed
            direction, in alphabetical order of direction (x, y, z).
            Otherwise ``cut_coords`` is returned unchanged.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number
        """
        if cut_coords is None:
            if img is None or img is False:
                cut_coords = (0, 0, 0)
            else:
                cut_coords = find_xyz_cut_coords(
                    img, activation_threshold=threshold
                )
            # Keep only the displayed directions, in alphabetical order.
            cut_coords = [
                cut_coords["xyz".find(c)] for c in sorted(cls._cut_displayed)
            ]
        return cut_coords

    def _init_axes(self, **kwargs):
        """Create one axes per displayed direction, side by side.

        Parameters
        ----------
        kwargs : :obj:`dict`
            Additional arguments passed to ``self._axes_class``.

        Raises
        ------
        ValueError
            If the number of cut coordinates does not match the number
            of displayed directions.
        """
        cut_coords = self.cut_coords
        if len(cut_coords) != len(self._cut_displayed):
            raise ValueError(
                "The number cut_coords passed does not match the display_mode"
            )
        x0, y0, x1, y1 = self.rect
        facecolor = "k" if self._black_bg else "w"
        # Every view lives in the same figure: look it up once.
        fh = self.frame_axes.get_figure()
        # cut_coords are stored in alphabetical direction order, while the
        # axes are created in display order.
        sorted_directions = sorted(self._cut_displayed)
        # Create our axes:
        self.axes = {}
        for index, direction in enumerate(self._cut_displayed):
            # Initial placement only: ``_locator`` recomputes the position
            # when the figure is drawn.
            ax = fh.add_axes(
                [0.3 * index * (x1 - x0) + x0, y0, 0.3 * (x1 - x0), y1 - y0],
                aspect="equal",
            )
            ax.set_facecolor(facecolor)

            ax.axis("off")
            coord = cut_coords[sorted_directions.index(direction)]
            display_ax = self._axes_class(ax, direction, coord, **kwargs)
            self.axes[direction] = display_ax
            ax.set_axes_locator(self._locator)

        if self._black_bg:
            for ax in self.axes.values():
                ax.ax.imshow(
                    np.zeros((2, 2, 3)),
                    extent=[-5000, 5000, -5000, 5000],
                    zorder=-500,
                    aspect="equal",
                )

            # To have a black background in PDF, we need to create a
            # patch in black for the background
            self.frame_axes.imshow(
                np.zeros((2, 2, 3)),
                extent=[-5000, 5000, -5000, 5000],
                zorder=-500,
                aspect="auto",
            )
            self.frame_axes.set_zorder(-1000)

    def _locator(
        self,
        axes,
        renderer,  # noqa: ARG002
    ):
        """Adjust the size of the axes.

        The locator function used by matplotlib to position axes.

        The views are laid out left to right in ``_cut_displayed`` order.
        Each one gets a share of the available width proportional to the
        horizontal extent returned by its ``get_object_bounds``.

        ``renderer`` is required to match the matplotlib API.
        """
        x0, y0, x1, y1 = self.rect
        # A dummy axes, for the situation in which we are not plotting
        # all three (x, y, z) cuts
        dummy_ax = self._axes_class(None, None, None)
        width_dict = {dummy_ax.ax: 0}
        display_ax_dict = self.axes

        if self._colorbar:
            # Reserve room on the right for the colorbar and its margins.
            adjusted_width = self._colorbar_width / len(self.axes)
            right_margin = self._colorbar_margin["right"] / len(self.axes)
            ticks_margin = self._colorbar_margin["left"] / len(self.axes)
            x1 = x1 - (adjusted_width + ticks_margin + right_margin)

        for display_ax in display_ax_dict.values():
            bounds = display_ax.get_object_bounds()
            if not bounds:
                # This happens if the call to _map_show was not
                # successful. As it happens asynchronously (during a
                # refresh of the figure) we capture the problem and
                # ignore it: it only adds a non informative traceback
                bounds = [0, 1, 0, 1]
            xmin, xmax, _, _ = bounds
            width_dict[display_ax.ax] = xmax - xmin

        total_width = float(sum(width_dict.values()))
        for ax, width in width_dict.items():
            width_dict[ax] = width / total_width * (x1 - x0)

        direction_ax = [
            display_ax_dict.get(d, dummy_ax).ax for d in self._cut_displayed
        ]
        left_dict = {}
        for idx, ax in enumerate(direction_ax):
            # Each view starts where the previous ones end.
            left_dict[ax] = x0
            for prev_ax in direction_ax[:idx]:
                left_dict[ax] += width_dict[prev_ax]

        return Bbox(
            [[left_dict[axes], y0], [left_dict[axes] + width_dict[axes], y1]]
        )

    def draw_cross(self, cut_coords=None, **kwargs):
        """Draw a crossbar on the plot to show where the cut is performed.

        Parameters
        ----------
        cut_coords : :obj:`tuple` of :obj:`float`, optional
            The position of the cross to draw: one coordinate per
            displayed direction, in alphabetical order of direction.
            If ``None`` is passed, the slicer's cut coordinates are used.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to
            :func:`~matplotlib.pyplot.axvline` and
            :func:`~matplotlib.pyplot.axhline`.
        """
        if cut_coords is None:
            cut_coords = self.cut_coords
        sorted_directions = sorted(self._cut_displayed)
        coords = {}
        for direction in "xyz":
            coord = None
            if direction in self._cut_displayed:
                coord = cut_coords[sorted_directions.index(direction)]
            coords[direction] = coord
        x, y, z = coords["x"], coords["y"], coords["z"]

        kwargs = kwargs.copy()
        if "color" not in kwargs:
            kwargs["color"] = ".8" if self._black_bg else "k"
        if "y" in self.axes:
            ax = self.axes["y"].ax
            if x is not None:
                ax.axvline(x, ymin=0.05, ymax=0.95, **kwargs)
            if z is not None:
                ax.axhline(z, **kwargs)

        if "x" in self.axes:
            ax = self.axes["x"].ax
            if y is not None:
                ax.axvline(y, ymin=0.05, ymax=0.95, **kwargs)
            if z is not None:
                ax.axhline(z, xmax=0.95, **kwargs)

        if "z" in self.axes:
            ax = self.axes["z"].ax
            if x is not None:
                ax.axvline(x, ymin=0.05, ymax=0.95, **kwargs)
            if y is not None:
                ax.axhline(y, **kwargs)

1243 

1244 

class TiledSlicer(BaseSlicer):
    """A class to create 3 axes for plotting orthogonal \
    cuts of 3D maps, organized in a 2x2 grid.

    This visualization mode can be activated from Nilearn plotting functions,
    like :func:`~nilearn.plotting.plot_img`, by setting
    ``display_mode='tiled'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the TiledSlicer class
        display = plot_img(img, display_mode="tiled")

    Attributes
    ----------
    cut_coords : :obj:`list`
        The cut coordinates.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The 3 axes used to plot each view, keyed by direction
        ('x', 'y' or 'z').

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    Notes
    -----
    The extent of the different axes are adjusted to fit the data
    best in the viewing area. The 'y' view is placed top-left, the 'x'
    view top-right and the 'z' view bottom-left.

    See Also
    --------
    nilearn.plotting.displays.MosaicSlicer : Three cuts are performed \
    along multiple rows and columns.
    nilearn.plotting.displays.OrthoSlicer : Three cuts are performed \
    in orthogonal directions and displayed side by side.

    """

    _cut_displayed: ClassVar[str] = "yxz"
    _axes_class = CutAxes
    _default_figsize: ClassVar[list[float]] = [2.0, 7.6]

    @classmethod
    def find_cut_coords(cls, img=None, threshold=None, cut_coords=None):
        """Instantiate the slicer and find cut coordinates.

        Parameters
        ----------
        img : 3D :class:`~nibabel.nifti1.Nifti1Image`
            The brain map.

        threshold : :obj:`float`, optional
            The lower threshold to the positive activation.
            If ``None``, the activation threshold is computed using the
            80% percentile of the absolute value of the map.

        cut_coords : :obj:`list` of :obj:`float`, optional
            xyz world coordinates of cuts.

        Returns
        -------
        cut_coords : :obj:`list` of :obj:`float`
            xyz world coordinates of cuts. A non-``None`` input is
            returned unchanged.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number
        """
        if cut_coords is None:
            if img is None or img is False:
                cut_coords = (0, 0, 0)
            else:
                cut_coords = find_xyz_cut_coords(
                    img, activation_threshold=threshold
                )
            # Keep only the displayed directions, in alphabetical order.
            cut_coords = [
                cut_coords["xyz".find(c)] for c in sorted(cls._cut_displayed)
            ]

        return cut_coords

    def _find_initial_axes_coord(self, index):
        """Find coordinates for initial axes placement for xyz cuts.

        Parameters
        ----------
        index : :obj:`int`
            Index corresponding to current cut 'x', 'y' or 'z'
            (0, 1 or 2).

        Returns
        -------
        [coord1, coord2, coord3, coord4] : :obj:`list` of :obj:`int`
            x0, y0, x1, y1 coordinates used by matplotlib
            to position axes in figure.

        Raises
        ------
        ValueError
            If ``index`` is not 0, 1 or 2.
        """
        # NOTE(review): ``Figure.add_axes`` reads this list as
        # [left, bottom, width, height]; the final placement is computed
        # by ``_locator`` at draw time anyway.
        rect_x0, rect_y0, rect_x1, rect_y1 = self.rect

        if index == 0:
            coord1 = rect_x1 - rect_x0
            coord2 = 0.5 * (rect_y1 - rect_y0) + rect_y0
            coord3 = 0.5 * (rect_x1 - rect_x0) + rect_x0
            coord4 = rect_y1 - rect_y0
        elif index == 1:
            coord1 = 0.5 * (rect_x1 - rect_x0) + rect_x0
            coord2 = 0.5 * (rect_y1 - rect_y0) + rect_y0
            coord3 = rect_x1 - rect_x0
            coord4 = rect_y1 - rect_y0
        elif index == 2:
            coord1 = rect_x1 - rect_x0
            coord2 = rect_y1 - rect_y0
            coord3 = 0.5 * (rect_x1 - rect_x0) + rect_x0
            coord4 = 0.5 * (rect_y1 - rect_y0) + rect_y0
        else:
            # Previously this fell through to an UnboundLocalError.
            raise ValueError(
                f"index must be 0, 1 or 2 for a tiled display, got {index}."
            )
        return [coord1, coord2, coord3, coord4]

    def _init_axes(self, **kwargs):
        """Initialize and place axes for display of 'xyz' cuts.

        Parameters
        ----------
        kwargs : :obj:`dict`
            Additional arguments to pass to ``self._axes_class``.

        Raises
        ------
        ValueError
            If the number of cut coordinates does not match the number
            of displayed directions.
        """
        cut_coords = self.cut_coords
        if len(cut_coords) != len(self._cut_displayed):
            raise ValueError(
                "The number cut_coords passed does not match the display_mode"
            )

        facecolor = "k" if self._black_bg else "w"
        # Every view lives in the same figure: look it up once.
        fh = self.frame_axes.get_figure()
        # cut_coords are stored in alphabetical direction order.
        sorted_directions = sorted(self._cut_displayed)

        self.axes = {}
        for index, direction in enumerate(self._cut_displayed):
            axes_coords = self._find_initial_axes_coord(index)
            ax = fh.add_axes(axes_coords, aspect="equal")

            ax.set_facecolor(facecolor)

            ax.axis("off")
            coord = cut_coords[sorted_directions.index(direction)]
            display_ax = self._axes_class(ax, direction, coord, **kwargs)
            self.axes[direction] = display_ax
            ax.set_axes_locator(self._locator)

    def _adjust_width_height(
        self, width_dict, height_dict, rect_x0, rect_y0, rect_x1, rect_y1
    ):
        """Adjust absolute image width and height to ratios.

        The top row holds the 'y' and 'x' views, so the total width is the
        sum of their widths. The left column holds the 'y' and 'z' views,
        so the total height is the sum of their heights.

        Parameters
        ----------
        width_dict : :obj:`dict`
            Width of image cuts displayed in axes. Modified in place.

        height_dict : :obj:`dict`
            Height of image cuts displayed in axes. Modified in place.

        rect_x0, rect_y0, rect_x1, rect_y1 : :obj:`float`
            Matplotlib figure boundaries.

        Returns
        -------
        width_dict : :obj:`dict`
            Width ratios of image cuts for optimal positioning of axes.

        height_dict : :obj:`dict`
            Height ratios of image cuts for optimal positioning of axes.
        """
        total_height = 0
        total_width = 0

        if "y" in self.axes:
            ax = self.axes["y"].ax
            total_height += height_dict[ax]
            total_width += width_dict[ax]

        if "x" in self.axes:
            ax = self.axes["x"].ax
            total_width = total_width + width_dict[ax]

        if "z" in self.axes:
            ax = self.axes["z"].ax
            total_height = total_height + height_dict[ax]

        for ax, width in width_dict.items():
            width_dict[ax] = width / total_width * (rect_x1 - rect_x0)

        for ax, height in height_dict.items():
            height_dict[ax] = height / total_height * (rect_y1 - rect_y0)

        return (width_dict, height_dict)

    def _find_axes_coord(
        self,
        rel_width_dict,
        rel_height_dict,
        rect_x0,
        rect_y0,
        rect_x1,
        rect_y1,
    ):
        """Compute the corners of each axes from their relative sizes.

        The 'y' view is anchored top-left, the 'x' view top-right and the
        'z' view bottom-left of the figure rectangle.

        Parameters
        ----------
        rel_width_dict : :obj:`dict`
            Width ratios of image cuts for optimal positioning of axes.

        rel_height_dict : :obj:`dict`
            Height ratios of image cuts for optimal positioning of axes.

        rect_x0, rect_y0, rect_x1, rect_y1 : :obj:`float`
            Matplotlib figure boundaries.

        Returns
        -------
        coord1, coord2, coord3, coord4 : :obj:`dict`
            x0, y0, x1, y1 coordinates per axes used by matplotlib
            to position axes in figure.
        """
        coord1 = {}
        coord2 = {}
        coord3 = {}
        coord4 = {}

        if "y" in self.axes:
            ax = self.axes["y"].ax
            coord1[ax] = rect_x0
            coord2[ax] = (rect_y1) - rel_height_dict[ax]
            coord3[ax] = rect_x0 + rel_width_dict[ax]
            coord4[ax] = rect_y1

        if "x" in self.axes:
            ax = self.axes["x"].ax
            coord1[ax] = (rect_x1) - rel_width_dict[ax]
            coord2[ax] = (rect_y1) - rel_height_dict[ax]
            coord3[ax] = rect_x1
            coord4[ax] = rect_y1

        if "z" in self.axes:
            ax = self.axes["z"].ax
            coord1[ax] = rect_x0
            coord2[ax] = rect_y0
            coord3[ax] = rect_x0 + rel_width_dict[ax]
            coord4[ax] = rect_y0 + rel_height_dict[ax]

        return (coord1, coord2, coord3, coord4)

    def _locator(
        self,
        axes,
        renderer,  # noqa: ARG002
    ):
        """Adjust the size of the axes.

        The locator function used by matplotlib to position axes.

        Here we put the logic used to adjust the size of the axes.

        ``renderer`` is required to match the matplotlib API.
        """
        rect_x0, rect_y0, rect_x1, rect_y1 = self.rect

        # A dummy axes, for the situation in which we are not plotting
        # all three (x, y, z) cuts
        dummy_ax = self._axes_class(None, None, None)
        width_dict = {dummy_ax.ax: 0}
        height_dict = {dummy_ax.ax: 0}
        display_ax_dict = self.axes

        if self._colorbar:
            # Reserve room on the right for the colorbar and its margins.
            adjusted_width = self._colorbar_width / len(self.axes)
            right_margin = self._colorbar_margin["right"] / len(self.axes)
            ticks_margin = self._colorbar_margin["left"] / len(self.axes)
            rect_x1 = rect_x1 - (adjusted_width + ticks_margin + right_margin)

        for display_ax in display_ax_dict.values():
            bounds = display_ax.get_object_bounds()
            if not bounds:
                # This happens if the call to _map_show was not
                # successful. As it happens asynchronously (during a
                # refresh of the figure) we capture the problem and
                # ignore it: it only adds a non informative traceback
                bounds = [0, 1, 0, 1]
            xmin, xmax, ymin, ymax = bounds
            width_dict[display_ax.ax] = xmax - xmin
            height_dict[display_ax.ax] = ymax - ymin

        # relative image height and width
        rel_width_dict, rel_height_dict = self._adjust_width_height(
            width_dict, height_dict, rect_x0, rect_y0, rect_x1, rect_y1
        )

        coord1, coord2, coord3, coord4 = self._find_axes_coord(
            rel_width_dict, rel_height_dict, rect_x0, rect_y0, rect_x1, rect_y1
        )

        return Bbox(
            [[coord1[axes], coord2[axes]], [coord3[axes], coord4[axes]]]
        )

    def draw_cross(self, cut_coords=None, **kwargs):
        """Draw a crossbar on the plot to show where the cut is performed.

        Parameters
        ----------
        cut_coords : 3-:obj:`tuple` of :obj:`float`, optional
            The position of the cross to draw, in alphabetical order of
            direction (x, y, z). If ``None`` is passed, the
            ``TiledSlicer``'s cut coordinates are used.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to
            :func:`~matplotlib.pyplot.axvline` and
            :func:`~matplotlib.pyplot.axhline`.
        """
        if cut_coords is None:
            cut_coords = self.cut_coords
        sorted_cuts = sorted(self._cut_displayed)
        coords = {}
        for direction in "xyz":
            coord_ = None
            if direction in self._cut_displayed:
                coord_ = cut_coords[sorted_cuts.index(direction)]
            coords[direction] = coord_
        x, y, z = coords["x"], coords["y"], coords["z"]

        kwargs = kwargs.copy()
        if "color" not in kwargs:
            # A plain dict assignment cannot raise KeyError, so no
            # exception suppression is needed here.
            kwargs["color"] = ".8" if self._black_bg else "k"

        if "y" in self.axes:
            ax = self.axes["y"].ax
            if x is not None:
                ax.axvline(x, **kwargs)
            if z is not None:
                ax.axhline(z, **kwargs)

        if "x" in self.axes:
            ax = self.axes["x"].ax
            if y is not None:
                ax.axvline(y, **kwargs)
            if z is not None:
                ax.axhline(z, **kwargs)

        if "z" in self.axes:
            ax = self.axes["z"].ax
            if x is not None:
                ax.axvline(x, **kwargs)
            if y is not None:
                ax.axhline(y, **kwargs)

1603 

1604 

class BaseStackedSlicer(BaseSlicer):
    """A class to create linked axes for plotting stacked cuts of 3D maps.

    All cuts are taken along a single direction, given by the
    ``_direction`` class attribute ('x', 'y' or 'z') that subclasses must
    define, and are displayed side by side.

    Attributes
    ----------
    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used to plot each view, keyed by cut coordinate.

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    Notes
    -----
    The extent of the different axes are adjusted to fit the data
    best in the viewing area.
    """

    @classmethod
    def find_cut_coords(
        cls,
        img=None,
        threshold=None,  # noqa: ARG003
        cut_coords=None,
    ):
        """Instantiate the slicer and find cut coordinates.

        Parameters
        ----------
        img : 3D :class:`~nibabel.nifti1.Nifti1Image`, optional
            The brain map. If ``None`` (or ``False``), cuts are spread
            evenly over a fixed default range along ``cls._direction``.

        threshold : :obj:`float`, optional
            Not used by stacked slicers; accepted so that every slicer
            shares the same ``find_cut_coords`` signature.

        cut_coords : :obj:`int` or :obj:`list` of :obj:`float`, optional
            Either the number of cuts to compute, or the world
            coordinates of the cuts along ``cls._direction``.
            Defaults to 7 cuts.

        Returns
        -------
        cut_coords : sequence of :obj:`float`
            World coordinates of the cuts along ``cls._direction``.
            A sequence passed as ``cut_coords`` is returned unchanged.
        """
        if cut_coords is None:
            cut_coords = 7

        if img is None or img is False:
            # Default range along x, y and z used when there is no image
            # to derive the cuts from.
            bounds = ((-40, 40), (-30, 30), (-30, 75))
            lower, upper = bounds["xyz".index(cls._direction)]
            if isinstance(cut_coords, numbers.Number):
                cut_coords = np.linspace(lower, upper, cut_coords).tolist()
        elif not isinstance(
            cut_coords, collections.abc.Sequence
        ) and isinstance(cut_coords, numbers.Number):
            cut_coords = find_cut_slices(
                img, direction=cls._direction, n_cuts=cut_coords
            )

        return cut_coords

    def _init_axes(self, **kwargs):
        """Create one axes per cut coordinate, side by side.

        Parameters
        ----------
        kwargs : :obj:`dict`
            Additional arguments passed to ``self._axes_class``.
        """
        x0, y0, x1, y1 = self.rect
        # Every view lives in the same figure: look it up once.
        fh = self.frame_axes.get_figure()
        # Create our axes:
        self.axes = {}
        fraction = 1.0 / len(self.cut_coords)
        for index, raw_coord in enumerate(self.cut_coords):
            coord = float(raw_coord)
            # Initial placement only: ``_locator`` recomputes the position
            # when the figure is drawn.
            ax = fh.add_axes(
                [
                    fraction * index * (x1 - x0) + x0,
                    y0,
                    fraction * (x1 - x0),
                    y1 - y0,
                ]
            )
            ax.axis("off")
            display_ax = self._axes_class(ax, self._direction, coord, **kwargs)
            # NOTE(review): axes are keyed by coordinate, so a repeated
            # coordinate replaces the earlier entry in ``self.axes``.
            self.axes[coord] = display_ax
            ax.set_axes_locator(self._locator)

        if self._black_bg:
            for ax in self.axes.values():
                ax.ax.imshow(
                    np.zeros((2, 2, 3)),
                    extent=[-5000, 5000, -5000, 5000],
                    zorder=-500,
                    aspect="equal",
                )

            # To have a black background in PDF, we need to create a
            # patch in black for the background
            self.frame_axes.imshow(
                np.zeros((2, 2, 3)),
                extent=[-5000, 5000, -5000, 5000],
                zorder=-500,
                aspect="auto",
            )
            self.frame_axes.set_zorder(-1000)

    def _locator(
        self,
        axes,
        renderer,  # noqa: ARG002
    ):
        """Adjust the size of the axes.

        The locator function used by matplotlib to position axes.

        The views are laid out left to right in the insertion order of
        ``self.axes``. Each one gets a share of the available width
        proportional to the horizontal extent returned by its
        ``get_object_bounds``.

        ``renderer`` is required to match the matplotlib API.
        """
        x0, y0, x1, y1 = self.rect
        width_dict = {}
        display_ax_dict = self.axes

        if self._colorbar:
            # Reserve room on the right for the colorbar and its margins.
            adjusted_width = self._colorbar_width / len(self.axes)
            right_margin = self._colorbar_margin["right"] / len(self.axes)
            ticks_margin = self._colorbar_margin["left"] / len(self.axes)
            x1 = x1 - (adjusted_width + right_margin + ticks_margin)

        for display_ax in display_ax_dict.values():
            bounds = display_ax.get_object_bounds()
            if not bounds:
                # This happens if the call to _map_show was not
                # successful. As it happens asynchronously (during a
                # refresh of the figure) we capture the problem and
                # ignore it: it only adds a non informative traceback
                bounds = [0, 1, 0, 1]
            xmin, xmax, _, _ = bounds
            width_dict[display_ax.ax] = xmax - xmin
        total_width = float(sum(width_dict.values()))
        for ax, width in width_dict.items():
            width_dict[ax] = width / total_width * (x1 - x0)
        left_dict = {}
        left = float(x0)
        for display_ax in display_ax_dict.values():
            left_dict[display_ax.ax] = left
            this_width = width_dict[display_ax.ax]
            left += this_width
        return Bbox(
            [[left_dict[axes], y0], [left_dict[axes] + width_dict[axes], y1]]
        )

    def draw_cross(self, cut_coords=None, **kwargs):
        """Do nothing: stacked slicers do not draw a cross.

        All cuts of a stacked slicer are parallel, so there is no
        orthogonal view in which to mark the cut positions. The method
        exists so that every slicer exposes the same interface.

        Parameters
        ----------
        cut_coords : :obj:`tuple` of :obj:`float`, optional
            Ignored.

        kwargs : :obj:`dict`
            Ignored.
        """

1766 

1767 

class XSlicer(BaseStackedSlicer):
    """Sagittal display: a row of parallel cuts along the x axis.

    ``XSlicer`` is the display built by Nilearn plotting functions such
    as :func:`nilearn.plotting.plot_img` when ``display_mode='x'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # ``display`` is an XSlicer instance
        display = plot_img(img, display_mode="x")

    Attributes
    ----------
    cut_coords : 1D :class:`~numpy.ndarray`
        The cut coordinates.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting.

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    See Also
    --------
    nilearn.plotting.displays.YSlicer : Coronal view
    nilearn.plotting.displays.ZSlicer : Axial view

    """

    _default_figsize: ClassVar[list[float]] = [2.6, 2.3]
    _direction: ClassVar[str] = "x"

1805 

1806 

class YSlicer(BaseStackedSlicer):
    """Coronal display: a row of parallel cuts along the y axis.

    ``YSlicer`` is the display built by Nilearn plotting functions such
    as :func:`nilearn.plotting.plot_img` when ``display_mode='y'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # ``display`` is a YSlicer instance
        display = plot_img(img, display_mode="y")

    Attributes
    ----------
    cut_coords : 1D :class:`~numpy.ndarray`
        The cut coordinates.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting.

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    See Also
    --------
    nilearn.plotting.displays.XSlicer : Sagittal view
    nilearn.plotting.displays.ZSlicer : Axial view

    """

    _default_figsize: ClassVar[list[float]] = [2.2, 3.0]
    _direction: ClassVar[str] = "y"

1844 

1845 

class ZSlicer(BaseStackedSlicer):
    """Axial display: a row of parallel cuts along the z axis.

    ``ZSlicer`` is the display built by Nilearn plotting functions such
    as :func:`nilearn.plotting.plot_img` when ``display_mode='z'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # ``display`` is a ZSlicer instance
        display = plot_img(img, display_mode="z")

    Attributes
    ----------
    cut_coords : 1D :class:`~numpy.ndarray`
        The cut coordinates.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting.

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    See Also
    --------
    nilearn.plotting.displays.XSlicer : Sagittal view
    nilearn.plotting.displays.YSlicer : Coronal view

    """

    _default_figsize: ClassVar[list[float]] = [2.2, 3.2]
    _direction: ClassVar[str] = "z"

1883 

1884 

class XZSlicer(OrthoSlicer):
    """The ``XZSlicer`` class enables to combine sagittal and axial views \
    on the same figure with plotting functions of Nilearn like \
    :func:`nilearn.plotting.plot_img`.

    This visualization mode
    can be activated by setting ``display_mode='xz'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the XZSlicer class
        display = plot_img(img, display_mode="xz")

    Attributes
    ----------
    cut_coords : :obj:`list` of :obj:`float`
        The cut coordinates.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting in each direction ('x' and 'z' here).

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    Notes
    -----
    The default figure size is inherited from
    :class:`~nilearn.plotting.displays.OrthoSlicer`.

    See Also
    --------
    nilearn.plotting.displays.YXSlicer : Coronal + Sagittal views
    nilearn.plotting.displays.YZSlicer : Coronal + Axial views

    """

    # Annotated as ClassVar, consistently with OrthoSlicer and YZSlicer.
    _cut_displayed: ClassVar[str] = "xz"

1921 

1922 

class YXSlicer(OrthoSlicer):
    """The ``YXSlicer`` class enables to combine coronal and sagittal views \
    on the same figure with plotting functions of Nilearn like \
    :func:`nilearn.plotting.plot_img`.

    This visualization mode
    can be activated by setting ``display_mode='yx'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the YXSlicer class
        display = plot_img(img, display_mode="yx")

    Attributes
    ----------
    cut_coords : :obj:`list` of :obj:`float`
        The cut coordinates.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting in each direction ('x' and 'y' here).

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    Notes
    -----
    The default figure size is inherited from
    :class:`~nilearn.plotting.displays.OrthoSlicer`.

    See Also
    --------
    nilearn.plotting.displays.XZSlicer : Sagittal + Axial views
    nilearn.plotting.displays.YZSlicer : Coronal + Axial views

    """

    # Annotated as ClassVar, consistently with OrthoSlicer and YZSlicer.
    _cut_displayed: ClassVar[str] = "yx"

1959 

1960 

class YZSlicer(OrthoSlicer):
    """Two-view display combining a coronal and an axial cut.

    ``YZSlicer`` is the display built by Nilearn plotting functions such
    as :func:`nilearn.plotting.plot_img` when ``display_mode='yz'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # ``display`` is a YZSlicer instance
        display = plot_img(img, display_mode="yz")

    Attributes
    ----------
    cut_coords : :obj:`list` of :obj:`float`
        The cut coordinates.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting in each direction ('y' and 'z' here).

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    See Also
    --------
    nilearn.plotting.displays.XZSlicer : Sagittal + Axial views
    nilearn.plotting.displays.YXSlicer : Coronal + Sagittal views

    """

    _default_figsize: ClassVar[list[float]] = [2.2, 3.0]
    _cut_displayed: ClassVar[str] = "yz"

1998 

1999 

class MosaicSlicer(BaseSlicer):
    """A class to create 3 :class:`~matplotlib.axes.Axes` for \
    plotting cuts of 3D maps, in multiple rows and columns.

    This visualization mode can be activated from Nilearn plotting
    functions, like :func:`~nilearn.plotting.plot_img`, by setting
    ``display_mode='mosaic'``.

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the MosaicSlicer class
        display = plot_img(img, display_mode="mosaic")

    Attributes
    ----------
    cut_coords : :obj:`dict` <:obj:`str`: sequence of :obj:`float`>
        The cut coordinates in a dictionary. The keys are the directions
        ('x', 'y', 'z'), and the values hold the cut coordinates along
        that direction.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used to plot the views, one per cut, keyed by
        ``(direction, coordinate)``.

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    See Also
    --------
    nilearn.plotting.displays.TiledSlicer : Three cuts are performed \
        and arranged in a 2x2 grid.
    nilearn.plotting.displays.OrthoSlicer : Three cuts are performed \
        in orthogonal directions.

    """

    _cut_displayed: ClassVar[str] = "yxz"
    _axes_class: ClassVar[CutAxes] = CutAxes  # type: ignore[assignment, misc]
    _default_figsize: ClassVar[list[float]] = [4.0, 5.0]

    @classmethod
    def find_cut_coords(
        cls,
        img=None,
        threshold=None,  # noqa: ARG003
        cut_coords=None,
    ):
        """Find cut coordinates for mosaic plotting.

        Parameters
        ----------
        img : 3D :class:`~nibabel.nifti1.Nifti1Image`, optional
            The brain image. If not provided, cuts are spread evenly over
            fixed default bounds.

        threshold : :obj:`float`, optional
            Not used by this slicer; accepted for signature compatibility
            with the other slicers.

        cut_coords : :obj:`int` or :obj:`list` / :obj:`tuple` of 3 \
                     :obj:`int`, optional
            Number of cuts per direction. A single number is used for all
            three directions; a sequence of 3 values gives the number of
            cuts along x, y and z respectively. If ``None``, 7 cuts are
            computed in each direction.

        Returns
        -------
        cut_coords : :obj:`dict`
            xyz world coordinates of cuts in a direction.
            Each key denotes the direction.

        Raises
        ------
        ValueError
            If ``cut_coords`` is a sequence whose length does not match the
            number of displayed directions.
        """
        if cut_coords is None:
            cut_coords = 7

        if isinstance(cut_coords, numbers.Number):
            # same number of cuts in every direction
            cut_coords = [cut_coords] * 3
        elif len(cut_coords) == len(cls._cut_displayed):
            # reorder to match sorted(cls._cut_displayed), i.e. x, y, z
            cut_coords = [
                cut_coords["xyz".find(c)] for c in sorted(cls._cut_displayed)
            ]
        else:
            raise ValueError(
                "The number cut_coords passed does not"
                " match the display_mode. Mosaic plotting "
                "expects tuple of length 3."
            )
        return cls._find_cut_coords(img, cut_coords, cls._cut_displayed)

    @staticmethod
    def _find_cut_coords(img, cut_coords, cut_displayed):
        """Find slicing positions along each displayed direction.

        Helper for :meth:`find_cut_coords`.

        Parameters
        ----------
        img : 3D :class:`~nibabel.nifti1.Nifti1Image` or None
            The brain image. If ``None`` or ``False``, the cuts are spread
            evenly over fixed default bounds instead.

        cut_coords : :obj:`list` / :obj:`tuple` of 3 :obj:`int`
            Number of cuts for each direction, in the order of
            ``sorted(cut_displayed)`` (x, y, z).

        cut_displayed : :obj:`str`
            Sectional directions, e.g. 'yxz'.

        Returns
        -------
        cut_coords : :obj:`dict`
            Maps each direction to its cut positions. Values are lists
            when ``img`` is not provided, otherwise the output of
            :func:`~nilearn.plotting.find_cut_slices`.
        """
        coords = {}
        if img is None or img is False:
            # default (lower, upper) bounds for the x, y and z axes
            bounds = ((-40, 40), (-30, 30), (-30, 75))
            for direction, n_cuts in zip(sorted(cut_displayed), cut_coords):
                lower, upper = bounds["xyz".index(direction)]
                coords[direction] = np.linspace(lower, upper, n_cuts).tolist()
        else:
            for direction, n_cuts in zip(sorted(cut_displayed), cut_coords):
                coords[direction] = find_cut_slices(
                    img, direction=direction, n_cuts=n_cuts
                )
        return coords

    def _init_axes(self, **kwargs):
        """Initialize and place axes for display of 'xyz' multiple cuts.

        One horizontal band of ``self.rect`` is used per direction, and each
        band is split into one sub-axes per cut coordinate. Also adapts the
        width of the color bar relative to the number of cuts.

        Parameters
        ----------
        kwargs : :obj:`dict`
            Additional arguments to pass to ``self._axes_class``.

        Raises
        ------
        ValueError
            If ``self.cut_coords`` does not hold one entry per direction.
        """
        if not isinstance(self.cut_coords, dict):
            self.cut_coords = self.find_cut_coords(cut_coords=self.cut_coords)

        if len(self.cut_coords) != len(self._cut_displayed):
            raise ValueError(
                "The number cut_coords passed does not match the mosaic mode"
            )

        # self.rect holds corners (x0, y0, x1, y1) while add_axes expects
        # (left, bottom, width, height): convert explicitly.
        x0, y0, x1, y1 = self.rect
        frame_width = x1 - x0
        row_height = (y1 - y0) / len(self.cut_coords)
        figure = self.frame_axes.get_figure()

        self.axes = {}
        for index, direction in enumerate(self._cut_displayed):
            coords = self.cut_coords[direction]
            row_bottom = y0 + index * row_height

            # invisible axes spanning the whole row of this direction
            row_ax = figure.add_axes([x0, row_bottom, frame_width, row_height])
            row_ax.axis("off")

            # portion of the row allotted to each cut of this direction
            col_width = frame_width / len(coords)
            for index_c, coord in enumerate(coords):
                coord = float(coord)
                ax = figure.add_axes(
                    [
                        x0 + index_c * col_width,
                        row_bottom,
                        col_width,
                        row_height,
                    ]
                )
                ax.axis("off")
                display_ax = self._axes_class(ax, direction, coord, **kwargs)
                self.axes[(direction, coord)] = display_ax
                # final placement is recomputed at draw time by _locator
                ax.set_axes_locator(self._locator)

        # increase color bar width to adapt to the number of cuts
        # (``coords`` here are those of the last displayed direction)
        # see issue https://github.com/nilearn/nilearn/pull/4284
        self._colorbar_width *= len(coords) ** 1.1

    def _locator(
        self,
        axes,
        renderer,  # noqa: ARG002
    ):
        """Adjust the size of the axes.

        Locator function used by matplotlib to position axes.

        Each direction gets one horizontal band of ``self.rect``; within a
        band, the axes are laid out left to right with widths proportional
        to the horizontal extent of the objects they display.

        ``renderer`` is required to match the matplotlib API.
        """
        x0, y0, x1, y1 = self.rect
        display_ax_dict = self.axes

        if self._colorbar:
            # reserve room on the right for the colorbar and its ticks
            n_axes = len(self.axes)
            adjusted_width = self._colorbar_width / n_axes
            right_margin = self._colorbar_margin["right"] / n_axes
            ticks_margin = self._colorbar_margin["left"] / n_axes
            x1 = x1 - (adjusted_width + right_margin + ticks_margin)

        # capture widths for each axes for anchoring Bbox
        width_dict = {}
        for direction in self._cut_displayed:
            this_width = {}
            for display_ax in display_ax_dict.values():
                if direction == display_ax.direction:
                    bounds = display_ax.get_object_bounds()
                    if not bounds:
                        # This happens if the call to _map_show was not
                        # successful. As it happens asynchronously (during a
                        # refresh of the figure) we capture the problem and
                        # ignore it: it only adds a non informative traceback
                        bounds = [0, 1, 0, 1]
                    xmin, xmax, _, _ = bounds
                    this_width[display_ax.ax] = xmax - xmin
            total_width = float(sum(this_width.values()))
            for ax, w in this_width.items():
                width_dict[ax] = w / total_width * (x1 - x0)

        # vertical extent of each direction's band, offset from y0
        row_height = (y1 - y0) / len(self._cut_displayed)
        left_dict = {}
        bottom_dict = {}
        top_dict = {}
        for index, direction in enumerate(self._cut_displayed):
            left = float(x0)
            bottom = y0 + row_height * index
            top = bottom + row_height
            for display_ax in display_ax_dict.values():
                if direction == display_ax.direction:
                    left_dict[display_ax.ax] = left
                    left += width_dict[display_ax.ax]
                    bottom_dict[display_ax.ax] = bottom
                    top_dict[display_ax.ax] = top
        return Bbox(
            [
                [left_dict[axes], bottom_dict[axes]],
                [left_dict[axes] + width_dict[axes], top_dict[axes]],
            ]
        )

    def draw_cross(self, cut_coords=None, **kwargs):
        """Draw a crossbar on the plot to show where the cut is performed.

        Not supported for mosaic displays: this method does nothing.

        Parameters
        ----------
        cut_coords : 3-:obj:`tuple` of :obj:`float`, optional
            Ignored.

        kwargs : :obj:`dict`
            Ignored.
        """

2267 

2268 

# Registry mapping each ``display_mode`` name to the slicer class that
# implements it (insertion order: multi-view modes, then single views).
SLICERS = dict(
    ortho=OrthoSlicer,
    tiled=TiledSlicer,
    mosaic=MosaicSlicer,
    xz=XZSlicer,
    yz=YZSlicer,
    yx=YXSlicer,
    x=XSlicer,
    y=YSlicer,
    z=ZSlicer,
)

2280 

2281 

def get_slicer(display_mode):
    """Retrieve a slicer from a given display mode.

    Parameters
    ----------
    display_mode : :obj:`str`
        The desired display mode.
        Possible options are:

        - "ortho": Three cuts are performed in orthogonal directions.
        - "tiled": Three cuts are performed and arranged in a 2x2 grid.
        - "mosaic": Three cuts are performed along multiple rows and columns.
        - "x": Sagittal
        - "y": Coronal
        - "z": Axial
        - "xz": Sagittal + Axial
        - "yz": Coronal + Axial
        - "yx": Coronal + Sagittal

    Returns
    -------
    slicer : An instance of one of the subclasses of\
        :class:`~nilearn.plotting.displays.BaseSlicer`

        The slicer corresponding to the requested display mode:

        - "ortho": Returns an
          :class:`~nilearn.plotting.displays.OrthoSlicer`.
        - "tiled": Returns a
          :class:`~nilearn.plotting.displays.TiledSlicer`.
        - "mosaic": Returns a
          :class:`~nilearn.plotting.displays.MosaicSlicer`.
        - "xz": Returns a
          :class:`~nilearn.plotting.displays.XZSlicer`.
        - "yz": Returns a
          :class:`~nilearn.plotting.displays.YZSlicer`.
        - "yx": Returns a
          :class:`~nilearn.plotting.displays.YXSlicer`.
        - "x": Returns a
          :class:`~nilearn.plotting.displays.XSlicer`.
        - "y": Returns a
          :class:`~nilearn.plotting.displays.YSlicer`.
        - "z": Returns a
          :class:`~nilearn.plotting.displays.ZSlicer`.

    """
    # Lookup (and handling of unknown modes) is delegated to the shared
    # helper, using the module-level ``SLICERS`` registry.
    return get_create_display_fun(display_mode, SLICERS)