Coverage for nilearn/plotting/displays/_slicers.py: 0%
719 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1import collections
2import contextlib
3import numbers
4import warnings
5from typing import ClassVar
7import matplotlib.pyplot as plt
8import numpy as np
9from matplotlib.colorbar import ColorbarBase
10from matplotlib.colors import LinearSegmentedColormap, ListedColormap
11from matplotlib.transforms import Bbox
13from nilearn._utils import check_niimg_3d, fill_doc
14from nilearn._utils.logger import find_stack_level
15from nilearn._utils.niimg import is_binary_niimg, safe_get_data
16from nilearn._utils.niimg_conversions import _check_fov
17from nilearn._utils.param_validation import check_params
18from nilearn.image import get_data, new_img_like, reorder_img
19from nilearn.image.resampling import get_bounds, get_mask_bounds, resample_img
20from nilearn.plotting._utils import (
21 check_threshold_not_negative,
22 get_cbar_ticks,
23)
24from nilearn.plotting.displays import CutAxes
25from nilearn.plotting.displays._utils import (
26 coords_3d_to_2d,
27 get_create_display_fun,
28)
29from nilearn.plotting.displays.edge_detect import edge_map
30from nilearn.plotting.find_cuts import find_cut_slices, find_xyz_cut_coords
31from nilearn.typing import NiimgLike
@fill_doc
class BaseSlicer:
    """BaseSlicer implementation whose main purpose is to auto adjust \
    the axes size to the data with different layout of cuts.

    It creates 3 linked axes for plotting orthogonal cuts.

    Attributes
    ----------
    cut_coords : 3 :obj:`tuple` of :obj:`int`
        The cut position, in world space.

    frame_axes : :class:`matplotlib.axes.Axes`, optional
        The matplotlib axes that will be subdivided in 3.

    black_bg : :obj:`bool`, default=False
        If ``True``, the background of the figure will be put to
        black. If you wish to save figures with a black background,
        you will need to pass ``facecolor='k', edgecolor='k'``
        to :func:`~matplotlib.pyplot.savefig`.

    brain_color : :obj:`tuple`, default=(0.5, 0.5, 0.5)
        The brain color to use as the background color (e.g., for
        transparent colorbars).
    """

    # This actually encodes the figsize for only one axe
    _default_figsize: ClassVar[list[float]] = [2.2, 2.6]
    _axes_class = CutAxes

    def __init__(
        self,
        cut_coords,
        axes=None,
        black_bg=False,
        brain_color=(0.5, 0.5, 0.5),
        **kwargs,
    ):
        self.cut_coords = cut_coords
        if axes is None:
            # No host axes given: take over the whole current figure.
            axes = plt.axes((0.0, 0.0, 1.0, 1.0))
            axes.axis("off")
        self.frame_axes = axes
        axes.set_zorder(1)
        bb = axes.get_position()
        # Frame rectangle stored as (x0, y0, x1, y1), i.e. corner
        # coordinates rather than (left, bottom, width, height).
        self.rect = (bb.x0, bb.y0, bb.x1, bb.y1)
        self._black_bg = black_bg
        self._brain_color = brain_color
        # Colorbar geometry is expressed relative to the frame size.
        self._colorbar = False
        self._colorbar_width = 0.05 * bb.width
        self._cbar_tick_format = "%.2g"
        self._colorbar_margin = {
            "left": 0.25 * bb.width,
            "right": 0.02 * bb.width,
            "top": 0.05 * bb.height,
            "bottom": 0.05 * bb.height,
        }
        # ``_init_axes`` is not defined on this base class; concrete
        # slicers (e.g. OrthoSlicer) provide it and receive the extra
        # keyword arguments.
        self._init_axes(**kwargs)

    @property
    def brain_color(self):
        """Return brain color."""
        return self._brain_color

    @property
    def black_bg(self):
        """Return black background."""
        return self._black_bg

    @staticmethod
    def find_cut_coords(img=None, threshold=None, cut_coords=None):
        """Act as placeholder and is not implemented in the base class \
        and has to be implemented in derived classes.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        # Implement this as a staticmethod or a classmethod when
        # subclassing
        raise NotImplementedError

    @classmethod
    @fill_doc  # the fill_doc decorator must be last applied
    def init_with_figure(
        cls,
        img,
        threshold=None,
        cut_coords=None,
        figure=None,
        axes=None,
        black_bg=False,
        leave_space=False,
        colorbar=False,
        brain_color=(0.5, 0.5, 0.5),
        **kwargs,
    ):
        """Initialize the slicer with an image.

        Parameters
        ----------
        %(img)s

        %(threshold)s

        cut_coords : 3 :obj:`tuple` of :obj:`int`
            The cut position, in world space.

        figure : :class:`matplotlib.figure.Figure` or figure identifier, \
                 optional
            Figure to draw in. When it is not a
            :class:`~matplotlib.figure.Figure` instance, it is forwarded
            to :func:`matplotlib.pyplot.figure` to create or select one.

        axes : :class:`matplotlib.axes.Axes` or sequence of 4 floats, \
               optional
            The axes that will be subdivided in 3. A sequence is passed
            to ``figure.add_axes`` to create the axes.

        black_bg : :obj:`bool`, default=False
            If ``True``, the background of the figure will be put to
            black. If you wish to save figures with a black background,
            you will need to pass ``facecolor='k', edgecolor='k'``
            to :func:`matplotlib.pyplot.savefig`.

        leave_space : :obj:`bool`, default=False
            If ``True``, the figure is widened and the default axes only
            use the right 70 percent of the figure width.

        colorbar : :obj:`bool`, default=False
            If ``True``, the default figure is widened to make room for
            a colorbar.

        brain_color : :obj:`tuple`, default=(0.5, 0.5, 0.5)
            The brain color to use as the background color (e.g., for
            transparent colorbars).

        Raises
        ------
        ValueError
            if the specified threshold is a negative number
        """
        check_params(locals())
        check_threshold_not_negative(threshold)

        # deal with "fake" 4D images
        if img is not None and img is not False:
            img = check_niimg_3d(img)

        cut_coords = cls.find_cut_coords(img, threshold, cut_coords)

        if isinstance(axes, plt.Axes) and figure is None:
            figure = axes.figure

        if not isinstance(figure, plt.Figure):
            # Make sure that we have a figure
            # (slice copy so the class-level default is never mutated)
            figsize = cls._default_figsize[:]

            # Adjust for the number of axes
            figsize[0] *= len(cut_coords)

            # Make space for the colorbar
            if colorbar:
                figsize[0] += 0.7

            facecolor = "k" if black_bg else "w"

            if leave_space:
                figsize[0] += 3.4
            figure = plt.figure(figure, figsize=figsize, facecolor=facecolor)
        if isinstance(axes, plt.Axes):
            assert axes.figure is figure, (
                "The axes passed are not in the figure"
            )

        if axes is None:
            axes = [0.3, 0, 0.7, 1.0] if leave_space else [0.0, 0.0, 1.0, 1.0]
        if isinstance(axes, collections.abc.Sequence):
            axes = figure.add_axes(axes)
        # People forget to turn their axis off, or to set the zorder, and
        # then they cannot see their slicer
        axes.axis("off")
        return cls(cut_coords, axes, black_bg, brain_color, **kwargs)

    def title(
        self,
        text,
        x=0.01,
        y=0.99,
        size=15,
        color=None,
        bgcolor=None,
        alpha=1,
        **kwargs,
    ):
        """Write a title to the view.

        Parameters
        ----------
        text : :obj:`str`
            The text of the title.

        x : :obj:`float`, default=0.01
            The horizontal position of the title on the frame in
            fraction of the frame width.

        y : :obj:`float`, default=0.99
            The vertical position of the title on the frame in
            fraction of the frame height.

        size : :obj:`int`, default=15
            The size of the title text.

        color : matplotlib color specifier, optional
            The color of the font of the title.

        bgcolor : matplotlib color specifier, optional
            The color of the background of the title.

        alpha : :obj:`float`, default=1
            The alpha value for the background.

        kwargs :
            Extra keyword arguments are passed to matplotlib's text
            function.
        """
        if color is None:
            color = "k" if self._black_bg else "w"
        if bgcolor is None:
            bgcolor = "w" if self._black_bg else "k"
        # Pick the key of the axes that will carry the title text.
        if hasattr(self, "_cut_displayed"):
            # Adapt to the case of mosaic plotting
            if isinstance(self.cut_coords, dict):
                first_axe = self._cut_displayed[-1]
                first_axe = (first_axe, self.cut_coords[first_axe][0])
            else:
                first_axe = self._cut_displayed[0]
        else:
            first_axe = self.cut_coords[0]
        ax = self.axes[first_axe].ax
        # The text is positioned in the coordinate system of the frame
        # axes, not of the axes it is attached to.
        ax.text(
            x,
            y,
            text,
            transform=self.frame_axes.transAxes,
            horizontalalignment="left",
            verticalalignment="top",
            size=size,
            color=color,
            bbox={
                "boxstyle": "square,pad=.3",
                "ec": bgcolor,
                "fc": bgcolor,
                "alpha": alpha,
            },
            zorder=1000,
            **kwargs,
        )
        ax.set_zorder(1000)

    @fill_doc
    def add_overlay(
        self,
        img,
        threshold=1e-6,
        colorbar=False,
        cbar_tick_format="%.2g",
        cbar_vmin=None,
        cbar_vmax=None,
        transparency=None,
        transparency_range=None,
        **kwargs,
    ):
        """Plot a 3D map in all the views.

        Parameters
        ----------
        %(img)s
            If it is a masked array, only the non-masked part will be plotted.

        threshold : :obj:`int` or :obj:`float` or ``None``, default=1e-6
            Threshold to apply:

            - If ``None`` is given, the maps are not thresholded.
            - If number is given, it must be non-negative. The specified
              value is used to threshold the image: values below the
              threshold (in absolute value) are plotted as transparent.

        cbar_tick_format : str, default="%%.2g" (scientific notation)
            Controls how to format the tick labels of the colorbar.
            Ex: use "%%i" to display as integers.

        colorbar : :obj:`bool`, default=False
            If ``True``, display a colorbar on the right of the plots.

        cbar_vmin : :obj:`float`, optional
            Minimal value for the colorbar. If None, the minimal value
            is computed based on the data.

        cbar_vmax : :obj:`float`, optional
            Maximal value for the colorbar. If None, the maximal value
            is computed based on the data.

        %(transparency)s

        %(transparency_range)s

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to function
            :func:`~matplotlib.pyplot.imshow`.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number,
            or if a colorbar is requested while the figure already has
            an overlay with a colorbar.
        """
        check_threshold_not_negative(threshold)

        if colorbar and self._colorbar:
            raise ValueError(
                "This figure already has an overlay with a colorbar."
            )

        # Once a colorbar has been requested the flag must stay set: a
        # later overlay added with ``colorbar=False`` should neither make
        # the layout (e.g. ``OrthoSlicer._locator``) stop reserving room
        # for the existing colorbar nor re-enable adding a second one.
        self._colorbar = self._colorbar or colorbar
        self._cbar_tick_format = cbar_tick_format

        img = check_niimg_3d(img)

        # Make sure that add_overlay shows consistent default behavior
        # with plot_stat_map
        kwargs.setdefault("interpolation", "nearest")
        ims = self._map_show(
            img,
            type="imshow",
            threshold=threshold,
            transparency=transparency,
            transparency_range=transparency_range,
            **kwargs,
        )

        # `ims` can be empty in some corner cases,
        # look at test_img_plotting.test_outlier_cut_coords.
        if colorbar and ims:
            self._show_colorbar(
                ims[0].cmap, ims[0].norm, cbar_vmin, cbar_vmax, threshold
            )

        plt.draw_if_interactive()

    @fill_doc
    def add_contours(self, img, threshold=1e-6, filled=False, **kwargs):
        """Contour a 3D map in all the views.

        Parameters
        ----------
        %(img)s
            Provides image to plot.

        threshold : :obj:`int` or :obj:`float` or ``None``, default=1e-6
            Threshold to apply:

            - If ``None`` is given, the maps are not thresholded.
            - If number is given, it must be non-negative. The specified
              value is used to threshold the image: values below the
              threshold (in absolute value) are plotted as transparent.

            The threshold is only used when ``filled=True``.

        filled : :obj:`bool`, default=False
            If ``filled=True``, contours are displayed with color fillings.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to function
            :func:`~matplotlib.pyplot.contour`, or function
            :func:`~matplotlib.pyplot.contourf`.
            Useful arguments are typically "levels", which is a
            list of values to use for plotting a contour or contour
            fillings (if ``filled=True``), and
            "colors", which is one color or a list of colors for
            these contours.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number

        Notes
        -----
        If colors are not specified, default coloring choices
        (from matplotlib) for contours and contour_fillings can be
        different.

        """
        if not filled:
            # Plain contour lines are never thresholded.
            threshold = None
        else:
            check_threshold_not_negative(threshold)

        self._map_show(img, type="contour", threshold=threshold, **kwargs)
        if filled:
            if "levels" in kwargs and len(kwargs["levels"]) <= 1:
                # contour fillings levels
                # should be given as (lower, upper).
                # Build a new list instead of appending in place so the
                # caller's ``levels`` object is left untouched and any
                # sequence type (tuple, array, ...) is accepted.
                kwargs["levels"] = [*kwargs["levels"], np.inf]

            self._map_show(img, type="contourf", threshold=threshold, **kwargs)

        plt.draw_if_interactive()

    def _map_show(
        self,
        img,
        type="imshow",
        resampling_interpolation="continuous",
        threshold=None,
        transparency=None,
        transparency_range=None,
        **kwargs,
    ):
        """Cut ``img`` along every axes of the slicer and draw the cuts.

        Parameters
        ----------
        img : Niimg-like object
            The image to display.

        type : {"imshow", "contour", "contourf"}, default="imshow"
            Drawing mode forwarded to the ``draw_2d`` method of each axes.

        resampling_interpolation : :obj:`str`, default="continuous"
            Interpolation used if reordering the image resamples it.
            Forced to ``"nearest"`` for binary images.

        threshold : :obj:`float` or ``None``, default=None
            Values whose absolute value is at or below the threshold
            are masked.

        transparency, transparency_range
            See ``_sanitize_transparency``.

        kwargs : :obj:`dict`
            Forwarded to ``draw_2d``. ``vmin`` and ``vmax`` are filled in
            from the cut data when missing or ``None``.

        Returns
        -------
        :obj:`list`
            What ``draw_2d`` returned for each axes whose cut fell
            inside the data.
        """
        # In the special case where the affine of img is not diagonal,
        # the function `reorder_img` will trigger a resampling
        # of the provided image with a continuous interpolation
        # since this is the default value here. In the special
        # case where this image is binary, such as when this function
        # is called from `add_contours`, continuous interpolation
        # does not make sense and we turn to nearest interpolation instead.

        if is_binary_niimg(img):
            resampling_interpolation = "nearest"

        # Image reordering should be done before sanitizing transparency
        img = reorder_img(
            img, resample=resampling_interpolation, copy_header=True
        )

        transparency, transparency_affine = self._sanitize_transparency(
            img,
            transparency,
            transparency_range,
            resampling_interpolation,
        )

        affine = img.affine

        if threshold is not None:
            threshold = float(threshold)
            data = safe_get_data(img, ensure_finite=True)
            data = self._threshold(data, threshold, None, None)
            img = new_img_like(img, data, affine)

        data = safe_get_data(img, ensure_finite=True)
        data_bounds = get_bounds(data.shape, affine)
        (xmin, xmax), (ymin, ymax), (zmin, zmax) = data_bounds

        # By default the bounding box is the full extent of the data; it
        # is tightened below when a mask of "interesting" voxels exists.
        xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = (
            xmin,
            xmax,
            ymin,
            ymax,
            zmin,
            zmax,
        )

        # Compute tight bounds
        if type in ("contour", "contourf"):
            # Define a pseudo threshold to have a tight bounding box
            thr = (
                0.9 * np.min(np.abs(kwargs["levels"]))
                if "levels" in kwargs
                else 1e-6
            )
            not_mask = np.logical_or(data > thr, data < -thr)
            xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = get_mask_bounds(
                new_img_like(img, not_mask, affine)
            )
        elif hasattr(data, "mask") and isinstance(data.mask, np.ndarray):
            not_mask = np.logical_not(data.mask)
            xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = get_mask_bounds(
                new_img_like(img, not_mask, affine)
            )

        # First pass: extract the 2D cut (and matching transparency cut)
        # for every axes, so that vmin/vmax can be computed over all cuts.
        data_2d_list = []
        transparency_list = []
        for display_ax in self.axes.values():
            if transparency is None or isinstance(transparency, (float, int)):
                # Scalar (or absent) transparency applies as is.
                transparency_2d = transparency

            try:
                data_2d = display_ax.transform_to_2d(data, affine)
                if isinstance(transparency, np.ndarray):
                    transparency_2d = display_ax.transform_to_2d(
                        transparency, transparency_affine
                    )
            except IndexError:
                # We are cutting outside the indices of the data
                data_2d = None
                transparency_2d = None

            data_2d_list.append(data_2d)
            transparency_list.append(transparency_2d)

        if kwargs.get("vmin") is None:
            kwargs["vmin"] = np.ma.min(
                [d.min() for d in data_2d_list if d is not None]
            )
        if kwargs.get("vmax") is None:
            kwargs["vmax"] = np.ma.max(
                [d.max() for d in data_2d_list if d is not None]
            )

        # Second pass: threshold each cut and draw it.
        bounding_box = (xmin_, xmax_), (ymin_, ymax_), (zmin_, zmax_)
        ims = []
        to_iterate_over = zip(
            self.axes.values(), data_2d_list, transparency_list
        )
        # Note: a threshold of 0 is falsy and therefore becomes None here.
        threshold = float(threshold) if threshold else None
        for display_ax, data_2d, transparency_2d in to_iterate_over:
            # If data_2d is completely masked, then there is nothing to
            # plot. Hence, no point to do imshow().
            if data_2d is not None:
                data_2d = self._threshold(
                    data_2d,
                    threshold,
                    vmin=float(kwargs.get("vmin")),
                    vmax=float(kwargs.get("vmax")),
                )

                im = display_ax.draw_2d(
                    data_2d,
                    data_bounds,
                    bounding_box,
                    type=type,
                    transparency=transparency_2d,
                    **kwargs,
                )
                ims.append(im)
        return ims

    def _sanitize_transparency(
        self, img, transparency, transparency_range, resampling_interpolation
    ):
        """Return transparency as None, float or an array.

        Parameters
        ----------
        img : Niimg-like object
            Data image; a transparency image is resampled to its affine
            and shape when their fields of view differ.

        transparency : None, :obj:`int`, :obj:`float` or Niimg-like object
            Requested transparency. Scalars outside [0, 1] are clipped
            with a warning. Images are converted to an array of absolute
            values rescaled to [0, 1] over ``transparency_range``.

        transparency_range : None or :obj:`list` / :obj:`tuple` of 2 numbers
            Only used for image transparency. Defaults to
            ``[0, max(abs(transparency))]`` and is clamped to the range
            of the transparency values. The object passed by the caller
            is never modified.

        resampling_interpolation : :obj:`str`
            Interpolation used to reorder / resample a transparency
            image. Forced to ``"nearest"`` for binary images.

        Returns
        -------
        transparency: None, float or np.ndarray

        transparency_affine: None or np.ndarray

        Raises
        ------
        ValueError
            If ``transparency_range`` does not have 2 elements or if,
            after clamping, its first value is not strictly smaller
            than its second value.
        """
        transparency_affine = None
        if isinstance(transparency, NiimgLike):
            transparency = check_niimg_3d(transparency, dtype="auto")
            if is_binary_niimg(transparency):
                resampling_interpolation = "nearest"
            transparency = reorder_img(
                transparency,
                resample=resampling_interpolation,
                copy_header=True,
            )
            if not _check_fov(transparency, img.affine, img.shape[:3]):
                warnings.warn(
                    "resampling transparency image to data image...",
                    stacklevel=find_stack_level(),
                )
                transparency = resample_img(
                    transparency,
                    img.affine,
                    img.shape,
                    force_resample=True,
                    copy_header=True,
                    interpolation=resampling_interpolation,
                )

            transparency_affine = transparency.affine
            transparency = safe_get_data(transparency, ensure_finite=True)

        assert transparency is None or isinstance(
            transparency, (int, float, np.ndarray)
        )

        if isinstance(transparency, (float, int)):
            transparency = float(transparency)
            base_warning_message = (
                "'transparency' must be in the interval [0, 1]. "
            )
            if transparency > 1.0:
                warnings.warn(
                    f"{base_warning_message} Setting it to 1.0.",
                    stacklevel=find_stack_level(),
                )
                transparency = 1.0
            if transparency < 0:
                warnings.warn(
                    f"{base_warning_message} Setting it to 0.0.",
                    stacklevel=find_stack_level(),
                )
                transparency = 0.0

        elif isinstance(transparency, np.ndarray):
            transparency = np.abs(transparency)

            if transparency_range is None:
                transparency_range = [0.0, np.max(transparency)]

            error_msg = (
                "'transparency_range' must be "
                "a list or tuple of 2 non-negative numbers "
                "with 'first value < second value'."
            )

            if len(transparency_range) != 2:
                raise ValueError(f"{error_msg} Got '{transparency_range}'.")

            # Work on a private list: item assignment below would fail on
            # a tuple (which the error message advertises as valid) and
            # would otherwise silently modify the caller's list.
            transparency_range = list(transparency_range)

            # Clamp the requested range to the range of the values.
            transparency_range[1] = min(
                transparency_range[1], np.max(transparency)
            )
            transparency_range[0] = max(
                transparency_range[0], np.min(transparency)
            )

            if transparency_range[0] >= transparency_range[1]:
                raise ValueError(f"{error_msg} Got '{transparency_range}'.")

            # make sure that 0 <= transparency <= 1
            # taking into account the requested transparency_range
            transparency = np.clip(
                transparency, transparency_range[0], transparency_range[1]
            )
            transparency = (transparency - transparency_range[0]) / (
                transparency_range[1] - transparency_range[0]
            )

        return transparency, transparency_affine

    @classmethod
    def _threshold(cls, data, threshold=None, vmin=None, vmax=None):
        """Threshold the data.

        Parameters
        ----------
        data : ndarray
            data to be thresholded

        threshold : :obj:`int` or :obj:`float` or ``None``, default=None
            If ``None``, ``data`` is returned unchanged. Otherwise it
            must be non-negative, and values whose absolute value is
            lower than or equal to it are masked.

        vmin : :obj:`float` or ``None``, default=None
            Only used when ``threshold`` is given: if ``vmin`` is greater
            than or equal to ``-threshold``, values below ``vmin`` are
            masked as well.

        vmax : :obj:`float` or ``None``, default=None
            Only used when ``threshold`` is given: if ``vmax`` is lower
            than or equal to ``threshold``, values above ``vmax`` are
            masked as well.

        Returns
        -------
        ndarray or masked array
            ``data`` itself when ``threshold`` is ``None``, a masked
            array otherwise.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number
        """
        check_params(locals())
        check_threshold_not_negative(threshold)

        if threshold is not None:
            data = np.ma.masked_where(
                np.abs(data) <= threshold,
                data,
                copy=False,
            )

            if (vmin is not None) and (vmin >= -threshold):
                data = np.ma.masked_where(data < vmin, data, copy=False)
            if (vmax is not None) and (vmax <= threshold):
                data = np.ma.masked_where(data > vmax, data, copy=False)

        return data

    @fill_doc
    def _show_colorbar(
        self, cmap, norm, cbar_vmin=None, cbar_vmax=None, threshold=None
    ):
        """Display the colorbar.

        The colorbar axes is added on the right of the frame. Nothing is
        drawn in it when ``cbar_vmin`` equals ``cbar_vmax``.

        Parameters
        ----------
        %(cmap)s
        norm : :class:`~matplotlib.colors.Normalize`
            This object is typically found as the ``norm`` attribute of
            :class:`~matplotlib.image.AxesImage`.

        threshold : :obj:`float` or ``None``, optional
            The absolute value at which the colorbar is thresholded.

        cbar_vmin : :obj:`float`, optional
            Minimal value for the colorbar. If None, the minimal value
            is computed based on the data.

        cbar_vmax : :obj:`float`, optional
            Maximal value for the colorbar. If None, the maximal value
            is computed based on the data.
        """
        # Half-width of the band around 0 that is made transparent.
        offset = 0 if threshold is None else threshold
        offset = min(offset, norm.vmax)

        cbar_vmin = cbar_vmin if cbar_vmin is not None else norm.vmin
        cbar_vmax = cbar_vmax if cbar_vmax is not None else norm.vmax

        # create new axis for the colorbar
        figure = self.frame_axes.figure
        _, y0, x1, y1 = self.rect
        height = y1 - y0
        x_adjusted_width = self._colorbar_width / len(self.axes)
        x_adjusted_margin = self._colorbar_margin["right"] / len(self.axes)
        # [left, bottom, width, height] of the colorbar axes
        lt_wid_top_ht = [
            x1 - (x_adjusted_width + x_adjusted_margin),
            y0 + self._colorbar_margin["top"],
            x_adjusted_width,
            height
            - (self._colorbar_margin["top"] + self._colorbar_margin["bottom"]),
        ]
        self._colorbar_ax = figure.add_axes(lt_wid_top_ht)
        self._colorbar_ax.set_facecolor("w")

        our_cmap = plt.get_cmap(cmap)
        # edge case where the data has a single value
        # yields a cryptic matplotlib error message
        # when trying to plot the color bar
        n_ticks = 5 if cbar_vmin != cbar_vmax else 1
        ticks = get_cbar_ticks(cbar_vmin, cbar_vmax, offset, n_ticks)
        bounds = np.linspace(cbar_vmin, cbar_vmax, our_cmap.N)

        # some colormap hacking
        # Colormap entries mapped to [-offset, offset) are replaced by a
        # fully transparent version of the brain color.
        cmaplist = [our_cmap(i) for i in range(our_cmap.N)]
        transparent_start = int(norm(-offset, clip=True) * (our_cmap.N - 1))
        transparent_stop = int(norm(offset, clip=True) * (our_cmap.N - 1))
        for i in range(transparent_start, transparent_stop):
            cmaplist[i] = (*self._brain_color, 0.0)  # transparent
        if cbar_vmin == cbar_vmax:  # len(np.unique(data)) == 1 ?
            return
        else:
            our_cmap = LinearSegmentedColormap.from_list(
                "Custom cmap", cmaplist, our_cmap.N
            )
        self._cbar = ColorbarBase(
            self._colorbar_ax,
            ticks=ticks,
            norm=norm,
            orientation="vertical",
            cmap=our_cmap,
            boundaries=bounds,
            spacing="proportional",
            format=self._cbar_tick_format,
        )
        # The brain color shows through the transparent colormap entries.
        self._cbar.ax.set_facecolor(self._brain_color)

        self._colorbar_ax.yaxis.tick_left()
        tick_color = "w" if self._black_bg else "k"
        outline_color = "w" if self._black_bg else "k"

        for tick in self._colorbar_ax.yaxis.get_ticklabels():
            tick.set_color(tick_color)
        self._colorbar_ax.yaxis.set_tick_params(width=0)
        self._cbar.outline.set_edgecolor(outline_color)

    @fill_doc
    def add_edges(self, img, color="r"):
        """Plot the edges of a 3D map in all the views.

        Parameters
        ----------
        %(img)s
            The 3D map to be plotted.
            If it is a masked array, only the non-masked part will be plotted.

        color : matplotlib color: :obj:`str` or (r, g, b) value, default='r'
            The color used to display the edge map.

        """
        img = reorder_img(img, resample="continuous", copy_header=True)
        data = get_data(img)
        affine = img.affine
        single_color_cmap = ListedColormap([color])
        data_bounds = get_bounds(data.shape, img.affine)

        # For each ax, cut the data and plot it
        for display_ax in self.axes.values():
            try:
                data_2d = display_ax.transform_to_2d(data, affine)
                edge_mask = edge_map(data_2d)
            except IndexError:
                # We are cutting outside the indices of the data
                continue
            display_ax.draw_2d(
                edge_mask,
                data_bounds,
                data_bounds,
                type="imshow",
                cmap=single_color_cmap,
            )

        plt.draw_if_interactive()

    def add_markers(
        self, marker_coords, marker_color="r", marker_size=30, **kwargs
    ):
        """Add markers to the plot.

        Parameters
        ----------
        marker_coords : :class:`~numpy.ndarray` of shape ``(n_markers, 3)``
            Coordinates of the markers to plot. For each slice, only markers
            that are at most 2 millimeters away from the slice are plotted.

        marker_color : pyplot compatible color or \
                     :obj:`list` of shape ``(n_markers,)``, default='r'
            List of colors for each marker
            that can be string or matplotlib colors.

        marker_size : :obj:`float` or \
                    :obj:`list` of :obj:`float` of shape ``(n_markers,)``, \
                    default=30
            Size in pixel for each marker.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to the ``scatter`` method
            of the matplotlib axes. ``marker`` defaults to ``"o"`` and
            ``zorder`` to ``1000``.
        """
        defaults = {"marker": "o", "zorder": 1000}
        marker_coords = np.asanyarray(marker_coords)
        for k, v in defaults.items():
            kwargs.setdefault(k, v)

        for display_ax in self.axes.values():
            direction = display_ax.direction
            coord = display_ax.coord
            marker_coords_2d, third_d = coords_3d_to_2d(
                marker_coords, direction, return_direction=True
            )
            xdata, ydata = marker_coords_2d.T
            # Allow markers only in their respective hemisphere
            # when appropriate
            marker_color_ = marker_color
            marker_size_ = marker_size
            # Tuple membership on purpose: ``direction in ("lr")`` would
            # be a substring test on the string "lr" and would also match
            # "" or "lr", leaving ``relevant_coords`` undefined below.
            if direction in ("l", "r"):
                if not isinstance(marker_color, str) and not isinstance(
                    marker_color, np.ndarray
                ):
                    marker_color_ = np.asarray(marker_color)
                xcoords, *_ = marker_coords.T
                if direction == "r":
                    relevant_coords = xcoords >= 0
                elif direction == "l":
                    relevant_coords = xcoords <= 0
                xdata = xdata[relevant_coords]
                ydata = ydata[relevant_coords]
                if (
                    not isinstance(marker_color, str)
                    and len(marker_color) != 1
                ):
                    marker_color_ = marker_color_[relevant_coords]
                if not isinstance(marker_size, numbers.Number):
                    marker_size_ = np.asarray(marker_size_)[relevant_coords]

            # Check if coord has integer represents a cut in direction
            # to follow the heuristic. If no foreground image is given
            # coordinate is empty or None. This case is valid for plotting
            # markers on glass brain without any foreground image.
            if isinstance(coord, numbers.Number):
                # Heuristic that plots only markers that are 2mm away
                # from the current slice.
                # XXX: should we keep this heuristic?
                mask = np.abs(third_d - coord) <= 2.0
                xdata = xdata[mask]
                ydata = ydata[mask]
            display_ax.ax.scatter(
                xdata, ydata, s=marker_size_, c=marker_color_, **kwargs
            )

    def annotate(
        self,
        left_right=True,
        positions=True,
        scalebar=False,
        size=12,
        scale_size=5.0,
        scale_units="cm",
        scale_loc=4,
        decimals=0,
        **kwargs,
    ):
        """Add annotations to the plot.

        Parameters
        ----------
        left_right : :obj:`bool`, default=True
            If ``True``, annotations indicating which side
            is left and which side is right are drawn.

        positions : :obj:`bool`, default=True
            If ``True``, annotations indicating the
            positions of the cuts are drawn.

        scalebar : :obj:`bool`, default=False
            If ``True``, cuts are annotated with a reference scale bar.
            For finer control of the scale bar, please check out
            the ``draw_scale_bar`` method on the axes in "axes" attribute
            of this object.

        size : :obj:`int`, default=12
            The size of the text used.

        scale_size : :obj:`int` or :obj:`float`, default=5.0
            The length of the scalebar, in units of ``scale_units``.

        scale_units : {'cm', 'mm'}, default='cm'
            The units for the ``scalebar``.

        scale_loc : :obj:`int`, default=4
            The positioning for the scalebar.
            Valid location codes are:

            - 1: "upper right"
            - 2: "upper left"
            - 3: "lower left"
            - 4: "lower right"
            - 5: "right"
            - 6: "center left"
            - 7: "center right"
            - 8: "lower center"
            - 9: "upper center"
            - 10: "center"

        decimals : :obj:`int`, default=0
            Number of decimal places on slice position annotation. If zero,
            the slice position is integer without decimal point.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to matplotlib's text
            function.
        """
        kwargs = kwargs.copy()
        # Default text color contrasts with the figure background.
        if "color" not in kwargs:
            kwargs["color"] = "w" if self._black_bg else "k"
        bg_color = "k" if self._black_bg else "w"

        if left_right:
            for display_axis in self.axes.values():
                display_axis.draw_left_right(
                    size=size, bg_color=bg_color, **kwargs
                )

        if positions:
            for display_axis in self.axes.values():
                display_axis.draw_position(
                    size=size, bg_color=bg_color, decimals=decimals, **kwargs
                )

        if scalebar:
            axes = self.axes.values()
            for display_axis in axes:
                display_axis.draw_scale_bar(
                    bg_color=bg_color,
                    fontsize=size,
                    size=scale_size,
                    units=scale_units,
                    loc=scale_loc,
                    **kwargs,
                )

    def close(self):
        """Close the figure.

        This is necessary to avoid leaking memory.
        """
        plt.close(self.frame_axes.figure.number)

    def savefig(self, filename, dpi=None, **kwargs):
        """Save the figure to a file.

        The face and edge colors are set to black or white according to
        the ``black_bg`` setting of the slicer.

        Parameters
        ----------
        filename : :obj:`str`
            The file name to save to. Its extension determines the
            file type, typically '.png', '.svg' or '.pdf'.

        dpi : ``None`` or scalar, default=None
            The resolution in dots per inch.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to the ``savefig`` method
            of the figure.

        """
        facecolor = edgecolor = "k" if self._black_bg else "w"
        self.frame_axes.figure.savefig(
            filename,
            dpi=dpi,
            facecolor=facecolor,
            edgecolor=edgecolor,
            **kwargs,
        )
1011@fill_doc
1012class OrthoSlicer(BaseSlicer):
1013 """Class to create 3 linked axes for plotting orthogonal \
1014 cuts of 3D maps.
1016 This visualization mode can be activated
1017 from Nilearn plotting functions, like
1018 :func:`~nilearn.plotting.plot_img`, by setting
1019 ``display_mode='ortho'``:
1021 .. code-block:: python
1023 from nilearn.datasets import load_mni152_template
1024 from nilearn.plotting import plot_img
1026 img = load_mni152_template()
1027 # display is an instance of the OrthoSlicer class
1028 display = plot_img(img, display_mode="ortho")
1031 Attributes
1032 ----------
1033 cut_coords : :obj:`list`
1034 The cut coordinates.
1036 axes : :obj:`dict` of :class:`~matplotlib.axes.Axes`
1037 The 3 axes used to plot each view.
1039 frame_axes : :class:`~matplotlib.axes.Axes`
1040 The axes framing the whole set of views.
1042 Notes
1043 -----
1044 The extent of the different axes are adjusted to fit the data
1045 best in the viewing area.
1047 See Also
1048 --------
1049 nilearn.plotting.displays.MosaicSlicer : Three cuts are performed \
1050 along multiple rows and columns.
1051 nilearn.plotting.displays.TiledSlicer : Three cuts are performed \
1052 and arranged in a 2x2 grid.
1054 """
1056 _cut_displayed: ClassVar[str] = "yxz"
1057 _axes_class = CutAxes
1058 _default_figsize: ClassVar[list[float]] = [2.2, 3.5]
@classmethod
@fill_doc  # the fill_doc decorator must be last applied
def find_cut_coords(cls, img=None, threshold=None, cut_coords=None):
    """Instantiate the slicer and find cut coordinates.

    Cut coordinates given by the caller are returned untouched.
    Otherwise they are computed from the image, or set to the origin
    when there is no image, and then picked in the sorted order of
    the directions listed in ``_cut_displayed``.

    Parameters
    ----------
    %(img)s
    threshold : :obj:`int` or :obj:`float` or ``None``, default=None
        Threshold to apply:

        - If ``None`` is given, the maps are not thresholded.
        - If number is given, it must be non-negative. The specified
          value is used to threshold the image: values below the
          threshold (in absolute value) are plotted as transparent.

    cut_coords : 3 :obj:`tuple` of :obj:`int`
        The cut position, in world space.

    Raises
    ------
    ValueError
        if the specified threshold is a negative number
    """
    # Explicit coordinates win: nothing to compute.
    if cut_coords is not None:
        return cut_coords

    no_image = img is None or img is False
    if no_image:
        xyz_coords = (0, 0, 0)
    else:
        xyz_coords = find_xyz_cut_coords(img, activation_threshold=threshold)

    # Keep only the displayed directions, in alphabetical order.
    selected = []
    for axis_name in sorted(cls._cut_displayed):
        selected.append(xyz_coords["xyz".find(axis_name)])
    return selected
def _init_axes(self, **kwargs):
    """Create one display axes per direction listed in ``_cut_displayed``.

    Each axes initially takes 30 percent of the frame width and is then
    positioned by ``self._locator``. The created display axes are stored
    in ``self.axes``, keyed by direction. Extra keyword arguments are
    forwarded to the axes class.

    Raises
    ------
    ValueError
        If the number of cut coordinates differs from the number of
        displayed directions.
    """
    if len(self.cut_coords) != len(self._cut_displayed):
        raise ValueError(
            "The number cut_coords passed does not match the display_mode"
        )

    x0, y0, x1, y1 = self.rect
    background = "k" if self._black_bg else "w"
    # cut_coords are ordered like the sorted directions.
    ordered_directions = sorted(self._cut_displayed)

    # Create our axes:
    self.axes = {}
    for position, direction in enumerate(self._cut_displayed):
        figure = self.frame_axes.get_figure()
        left = 0.3 * position * (x1 - x0) + x0
        mpl_ax = figure.add_axes(
            [left, y0, 0.3 * (x1 - x0), y1 - y0],
            aspect="equal",
        )
        mpl_ax.set_facecolor(background)
        mpl_ax.axis("off")

        cut_position = self.cut_coords[ordered_directions.index(direction)]
        self.axes[direction] = self._axes_class(
            mpl_ax, direction, cut_position, **kwargs
        )
        mpl_ax.set_axes_locator(self._locator)

    if not self._black_bg:
        return

    # Paint a large black image far behind the data of every cut axes.
    for display_ax in self.axes.values():
        display_ax.ax.imshow(
            np.zeros((2, 2, 3)),
            extent=[-5000, 5000, -5000, 5000],
            zorder=-500,
            aspect="equal",
        )

    # To have a black background in PDF, we need to create a
    # patch in black for the background
    self.frame_axes.imshow(
        np.zeros((2, 2, 3)),
        extent=[-5000, 5000, -5000, 5000],
        zorder=-500,
        aspect="auto",
    )
    self.frame_axes.set_zorder(-1000)
def _locator(
    self,
    axes,
    renderer,  # noqa: ARG002
):
    """Adjust the size of the axes.

    The locator function used by matplotlib to position axes.

    The horizontal span of ``self.rect`` (minus the room reserved for
    the colorbar, when there is one) is shared between the displayed
    axes in proportion to the width of their object bounds. The
    vertical span is used in full.

    ``renderer`` is required to match the matplotlib API.
    """
    x0, y0, x1, y1 = self.rect
    # A dummy axes, for the situation in which we are not plotting
    # all three (x, y, z) cuts
    placeholder = self._axes_class(None, None, None)
    shown = self.axes

    if self._colorbar:
        n_axes = len(shown)
        adjusted_width = self._colorbar_width / n_axes
        right_margin = self._colorbar_margin["right"] / n_axes
        ticks_margin = self._colorbar_margin["left"] / n_axes
        x1 = x1 - (adjusted_width + ticks_margin + right_margin)

    raw_widths = {placeholder.ax: 0}
    for cut_ax in shown.values():
        # Empty bounds happen if the call to _map_show was not
        # successful. As it happens asynchronously (during a refresh
        # of the figure) we fall back on a unit box instead of
        # raising: it would only add a non informative traceback
        extent = cut_ax.get_object_bounds() or [0, 1, 0, 1]
        xmin, xmax, _, _ = extent
        raw_widths[cut_ax.ax] = xmax - xmin

    total = float(sum(raw_widths.values()))
    widths = {
        ax: raw / total * (x1 - x0) for ax, raw in raw_widths.items()
    }

    # Lay the axes out left to right, in display order.
    lefts = {}
    cursor = x0
    for direction in self._cut_displayed:
        ax = shown.get(direction, placeholder).ax
        lefts[ax] = cursor
        cursor += widths[ax]

    return Bbox([[lefts[axes], y0], [lefts[axes] + widths[axes], y1]])
def draw_cross(self, cut_coords=None, **kwargs):
    """Draw a crossbar on the plot to show where the cut is performed.

    Parameters
    ----------
    cut_coords : 3-:obj:`tuple` of :obj:`float`, optional
        The position of the cross to draw. If ``None`` is passed, the
        ``OrthoSlicer``'s cut coordinates are used.

    kwargs : :obj:`dict`
        Extra keyword arguments are passed to the ``axvline`` and
        ``axhline`` calls made on each displayed axes. When ``color``
        is absent, ``".8"`` is used on a black background and ``"k"``
        otherwise.
    """
    if cut_coords is None:
        cut_coords = self.cut_coords

    # World coordinate per direction; None for directions not displayed.
    # cut_coords is indexed by alphabetical rank of the direction.
    position = {}
    for direction in "xyz":
        if direction in self._cut_displayed:
            rank = sorted(self._cut_displayed).index(direction)
            position[direction] = cut_coords[rank]
        else:
            position[direction] = None

    line_style = kwargs.copy()
    if "color" not in line_style:
        line_style["color"] = ".8" if self._black_bg else "k"

    # (view, direction of the vertical line, direction of the
    #  horizontal line, extra keyword arguments of the horizontal line)
    layout = (
        ("y", "x", "z", {}),
        ("x", "y", "z", {"xmax": 0.95}),
        ("z", "x", "y", {}),
    )
    for view, vertical, horizontal, horizontal_extra in layout:
        if view not in self.axes:
            continue
        mpl_ax = self.axes[view].ax
        if position[vertical] is not None:
            mpl_ax.axvline(
                position[vertical], ymin=0.05, ymax=0.95, **line_style
            )
        if position[horizontal] is not None:
            mpl_ax.axhline(
                position[horizontal], **horizontal_extra, **line_style
            )
class TiledSlicer(BaseSlicer):
    """A class to create 3 axes for plotting orthogonal \
    cuts of 3D maps, organized in a 2x2 grid.

    This visualization mode can be activated from Nilearn plotting functions,
    like :func:`~nilearn.plotting.plot_img`, by setting
    ``display_mode='tiled'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the TiledSlicer class
        display = plot_img(img, display_mode="tiled")

    Attributes
    ----------
    cut_coords : :obj:`list`
        The cut coordinates.

    axes : :obj:`dict` of :class:`~matplotlib.axes.Axes`
        The 3 axes used to plot each view.

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    Notes
    -----
    The extent of the different axes are adjusted to fit the data
    best in the viewing area.

    See Also
    --------
    nilearn.plotting.displays.MosaicSlicer : Three cuts are performed \
    along multiple rows and columns.
    nilearn.plotting.displays.OrthoSlicer : Three cuts are performed \
    in orthogonal directions and displayed side by side.

    """

    _cut_displayed: ClassVar[str] = "yxz"
    _axes_class = CutAxes
    _default_figsize: ClassVar[list[float]] = [2.0, 7.6]

    @classmethod
    def find_cut_coords(cls, img=None, threshold=None, cut_coords=None):
        """Instantiate the slicer and find cut coordinates.

        Parameters
        ----------
        img : 3D :class:`~nibabel.nifti1.Nifti1Image`
            The brain map. ``None`` or ``False`` means that there is no
            image: the cuts are then placed at ``(0, 0, 0)``.

        threshold : :obj:`float`, optional
            Forwarded to ``find_xyz_cut_coords`` as
            ``activation_threshold`` when the coordinates are computed
            from ``img``.

        cut_coords : :obj:`list` of :obj:`float`, optional
            xyz world coordinates of cuts. Any value other than ``None``
            is returned unchanged.

        Returns
        -------
        cut_coords : :obj:`list` of :obj:`float`
            xyz world coordinates of cuts.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number
        """
        if cut_coords is None:
            if img is None or img is False:
                cut_coords = (0, 0, 0)
            else:
                cut_coords = find_xyz_cut_coords(
                    img, activation_threshold=threshold
                )
            # One coordinate per displayed direction, in alphabetical
            # order of the directions.
            cut_coords = [
                cut_coords["xyz".find(c)] for c in sorted(cls._cut_displayed)
            ]

        return cut_coords

    def _find_initial_axes_coord(self, index):
        """Find coordinates for initial axes placement for xyz cuts.

        Parameters
        ----------
        index : :obj:`int`
            Position (0, 1 or 2) of the current cut within
            ``self._cut_displayed``.

        Returns
        -------
        [coord1, coord2, coord3, coord4] : :obj:`list` of :obj:`float`
            The four values handed to ``add_axes`` by ``_init_axes``.
            They only set the initial placement.
            NOTE(review): ``add_axes`` presumably reads them as left,
            bottom, width, height rather than x0, y0, x1, y1; confirm
            that these values are intended.

        Raises
        ------
        ValueError
            If ``index`` is not 0, 1 or 2.
        """
        rect_x0, rect_y0, rect_x1, rect_y1 = self.rect
        full_width = rect_x1 - rect_x0
        full_height = rect_y1 - rect_y0
        mid_x = 0.5 * full_width + rect_x0
        mid_y = 0.5 * full_height + rect_y0

        if index == 0:
            return [full_width, mid_y, mid_x, full_height]
        if index == 1:
            return [mid_x, mid_y, full_width, full_height]
        if index == 2:
            return [full_width, full_height, mid_x, mid_y]
        # Previously an unexpected index surfaced as an UnboundLocalError.
        raise ValueError(f"index must be 0, 1 or 2, got {index!r}")

    def _init_axes(self, **kwargs):
        """Initialize and place axes for display of 'xyz' cuts.

        Parameters
        ----------
        kwargs : :obj:`dict`
            Additional arguments to pass to ``self._axes_class``.

        Raises
        ------
        ValueError
            If ``self.cut_coords`` and ``self._cut_displayed`` do not
            have the same length.
        """
        cut_coords = self.cut_coords
        if len(cut_coords) != len(self._cut_displayed):
            raise ValueError(
                "The number cut_coords passed does not match the display_mode"
            )

        facecolor = "k" if self._black_bg else "w"

        self.axes = {}
        for index, direction in enumerate(self._cut_displayed):
            fh = self.frame_axes.get_figure()
            axes_coords = self._find_initial_axes_coord(index)
            ax = fh.add_axes(axes_coords, aspect="equal")

            ax.set_facecolor(facecolor)

            ax.axis("off")
            # cut_coords is indexed by alphabetical rank of the direction
            coord = self.cut_coords[
                sorted(self._cut_displayed).index(direction)
            ]
            display_ax = self._axes_class(ax, direction, coord, **kwargs)
            self.axes[direction] = display_ax
            ax.set_axes_locator(self._locator)

    def _adjust_width_height(
        self, width_dict, height_dict, rect_x0, rect_y0, rect_x1, rect_y1
    ):
        """Adjust absolute image width and height to ratios.

        Both dictionaries are rescaled in place and returned.

        Parameters
        ----------
        width_dict : :obj:`dict`
            Width of image cuts displayed in axes.

        height_dict : :obj:`dict`
            Height of image cuts displayed in axes.

        rect_x0, rect_y0, rect_x1, rect_y1 : :obj:`float`
            Matplotlib figure boundaries.

        Returns
        -------
        width_dict : :obj:`dict`
            Width ratios of image cuts for optimal positioning of axes.

        height_dict : :obj:`dict`
            Height ratios of image cuts for optimal positioning of axes.
        """
        total_height = 0
        total_width = 0

        # The 'y' view contributes to both totals, the 'x' view only to
        # the total width and the 'z' view only to the total height.
        if "y" in self.axes:
            ax = self.axes["y"].ax
            total_height += height_dict[ax]
            total_width += width_dict[ax]

        if "x" in self.axes:
            total_width += width_dict[self.axes["x"].ax]

        if "z" in self.axes:
            total_height += height_dict[self.axes["z"].ax]

        for ax, width in width_dict.items():
            width_dict[ax] = width / total_width * (rect_x1 - rect_x0)

        for ax, height in height_dict.items():
            height_dict[ax] = height / total_height * (rect_y1 - rect_y0)

        return (width_dict, height_dict)

    def _find_axes_coord(
        self,
        rel_width_dict,
        rel_height_dict,
        rect_x0,
        rect_y0,
        rect_x1,
        rect_y1,
    ):
        """Find the final corner coordinates of the axes of the xyz cuts.

        The 'y' view is anchored to the top-left corner, the 'x' view to
        the top-right corner and the 'z' view to the bottom-left corner.

        Parameters
        ----------
        rel_width_dict : :obj:`dict`
            Width ratios of image cuts for optimal positioning of axes.

        rel_height_dict : :obj:`dict`
            Height ratios of image cuts for optimal positioning of axes.

        rect_x0, rect_y0, rect_x1, rect_y1 : :obj:`float`
            Matplotlib figure boundaries.

        Returns
        -------
        coord1, coord2, coord3, coord4 : :obj:`dict`
            x0, y0, x1, y1 coordinates per axes used by matplotlib
            to position axes in figure.
        """
        coord1 = {}
        coord2 = {}
        coord3 = {}
        coord4 = {}

        if "y" in self.axes:
            ax = self.axes["y"].ax
            coord1[ax] = rect_x0
            coord2[ax] = (rect_y1) - rel_height_dict[ax]
            coord3[ax] = rect_x0 + rel_width_dict[ax]
            coord4[ax] = rect_y1

        if "x" in self.axes:
            ax = self.axes["x"].ax
            coord1[ax] = (rect_x1) - rel_width_dict[ax]
            coord2[ax] = (rect_y1) - rel_height_dict[ax]
            coord3[ax] = rect_x1
            coord4[ax] = rect_y1

        if "z" in self.axes:
            ax = self.axes["z"].ax
            coord1[ax] = rect_x0
            coord2[ax] = rect_y0
            coord3[ax] = rect_x0 + rel_width_dict[ax]
            coord4[ax] = rect_y0 + rel_height_dict[ax]

        return (coord1, coord2, coord3, coord4)

    def _locator(
        self,
        axes,
        renderer,  # noqa: ARG002
    ):
        """Adjust the size of the axes.

        The locator function used by matplotlib to position axes.

        Here we put the logic used to adjust the size of the axes.

        ``renderer`` is required to match the matplotlib API.
        """
        rect_x0, rect_y0, rect_x1, rect_y1 = self.rect

        # A dummy axes, for the situation in which we are not plotting
        # all three (x, y, z) cuts
        dummy_ax = self._axes_class(None, None, None)
        width_dict = {dummy_ax.ax: 0}
        height_dict = {dummy_ax.ax: 0}
        display_ax_dict = self.axes

        if self._colorbar:
            # Reserve room on the right for the colorbar and its margins
            adjusted_width = self._colorbar_width / len(self.axes)
            right_margin = self._colorbar_margin["right"] / len(self.axes)
            ticks_margin = self._colorbar_margin["left"] / len(self.axes)
            rect_x1 = rect_x1 - (adjusted_width + ticks_margin + right_margin)

        for display_ax in display_ax_dict.values():
            bounds = display_ax.get_object_bounds()
            if not bounds:
                # This happens if the call to _map_show was not
                # successful. As it happens asynchronously (during a
                # refresh of the figure) we capture the problem and
                # ignore it: it only adds a non informative traceback
                bounds = [0, 1, 0, 1]
            xmin, xmax, ymin, ymax = bounds
            width_dict[display_ax.ax] = xmax - xmin
            height_dict[display_ax.ax] = ymax - ymin

        # relative image height and width
        rel_width_dict, rel_height_dict = self._adjust_width_height(
            width_dict, height_dict, rect_x0, rect_y0, rect_x1, rect_y1
        )

        coord1, coord2, coord3, coord4 = self._find_axes_coord(
            rel_width_dict, rel_height_dict, rect_x0, rect_y0, rect_x1, rect_y1
        )

        return Bbox(
            [[coord1[axes], coord2[axes]], [coord3[axes], coord4[axes]]]
        )

    def draw_cross(self, cut_coords=None, **kwargs):
        """Draw a crossbar on the plot to show where the cut is performed.

        Parameters
        ----------
        cut_coords : 3-:obj:`tuple` of :obj:`float`, optional
            The position of the cross to draw. If ``None`` is passed, the
            ``TiledSlicer``'s cut coordinates are used.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to the ``axvline`` and
            ``axhline`` calls made on each displayed axes. When
            ``color`` is absent, ``".8"`` is used on a black background
            and ``"k"`` otherwise.
        """
        if cut_coords is None:
            cut_coords = self.cut_coords

        # World coordinate per direction; None when it is not displayed.
        sorted_cuts = sorted(self._cut_displayed)
        coords = {}
        for direction in "xyz":
            coord_ = None
            if direction in self._cut_displayed:
                coord_ = cut_coords[sorted_cuts.index(direction)]
            coords[direction] = coord_
        x, y, z = coords["x"], coords["y"], coords["z"]

        kwargs = kwargs.copy()
        if "color" not in kwargs:
            kwargs["color"] = ".8" if self._black_bg else "k"

        if "y" in self.axes:
            ax = self.axes["y"].ax
            if x is not None:
                ax.axvline(x, **kwargs)
            if z is not None:
                ax.axhline(z, **kwargs)

        if "x" in self.axes:
            ax = self.axes["x"].ax
            if y is not None:
                ax.axvline(y, **kwargs)
            if z is not None:
                ax.axhline(z, **kwargs)

        if "z" in self.axes:
            ax = self.axes["z"].ax
            if x is not None:
                ax.axvline(x, **kwargs)
            if y is not None:
                ax.axhline(y, **kwargs)
class BaseStackedSlicer(BaseSlicer):
    """A class to create linked axes for plotting stacked cuts of 2D maps.

    All the cuts are taken along the single direction ``cls._direction``
    and laid out left to right.

    Attributes
    ----------
    axes : :obj:`dict` of :class:`~matplotlib.axes.Axes`
        The axes used to plot each view.

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    Notes
    -----
    The extent of the different axes are adjusted to fit the data
    best in the viewing area.
    """

    @classmethod
    def find_cut_coords(
        cls,
        img=None,
        threshold=None,  # noqa: ARG003
        cut_coords=None,
    ):
        """Instantiate the slicer and find cut coordinates.

        Parameters
        ----------
        img : 3D :class:`~nibabel.nifti1.Nifti1Image`
            The brain map. With ``None`` or ``False``, a requested
            number of cuts is spread evenly over fixed default bounds.

        threshold : :obj:`float`, optional
            Not used by this implementation.

        cut_coords : :obj:`list` of :obj:`float` or number, optional
            Either the world coordinates of the cuts, returned as they
            are, or a number of cuts to compute. ``None`` stands
            for 7 cuts.

        Returns
        -------
        cut_coords : :obj:`list` of :obj:`float`
            World coordinates of the cuts along ``cls._direction``.
        """
        requested = 7 if cut_coords is None else cut_coords
        wants_count = isinstance(requested, numbers.Number)

        if img is None or img is False:
            default_bounds = ((-40, 40), (-30, 30), (-30, 75))
            lower, upper = default_bounds["xyz".index(cls._direction)]
            if wants_count:
                return np.linspace(lower, upper, requested).tolist()
            return requested

        if wants_count and not isinstance(
            requested, collections.abc.Sequence
        ):
            return find_cut_slices(
                img, direction=cls._direction, n_cuts=requested
            )

        return requested

    def _init_axes(self, **kwargs):
        """Create one display axes per entry of ``self.cut_coords``.

        Parameters
        ----------
        kwargs : :obj:`dict`
            Additional arguments to pass to ``self._axes_class``.
        """
        x0, y0, x1, y1 = self.rect
        self.axes = {}
        # Initial placement only: an equal share of the width per cut.
        share = 1.0 / len(self.cut_coords)
        for position, raw_coord in enumerate(self.cut_coords):
            coord = float(raw_coord)
            figure = self.frame_axes.get_figure()
            mpl_ax = figure.add_axes(
                [
                    share * position * (x1 - x0) + x0,
                    y0,
                    share * (x1 - x0),
                    y1 - y0,
                ]
            )
            mpl_ax.axis("off")
            self.axes[coord] = self._axes_class(
                mpl_ax, self._direction, coord, **kwargs
            )
            mpl_ax.set_axes_locator(self._locator)

        if not self._black_bg:
            return

        for cut_ax in self.axes.values():
            cut_ax.ax.imshow(
                np.zeros((2, 2, 3)),
                extent=[-5000, 5000, -5000, 5000],
                zorder=-500,
                aspect="equal",
            )

        # To have a black background in PDF, we need to create a
        # patch in black for the background
        self.frame_axes.imshow(
            np.zeros((2, 2, 3)),
            extent=[-5000, 5000, -5000, 5000],
            zorder=-500,
            aspect="auto",
        )
        self.frame_axes.set_zorder(-1000)

    def _locator(
        self,
        axes,
        renderer,  # noqa: ARG002
    ):
        """Adjust the size of the axes.

        The locator function used by matplotlib to position axes.

        The horizontal span of ``self.rect`` (minus the room reserved
        for the colorbar, when there is one) is shared between the axes
        in proportion to the width of their object bounds.

        ``renderer`` is required to match the matplotlib API.
        """
        x0, y0, x1, y1 = self.rect

        if self._colorbar:
            n_axes = len(self.axes)
            adjusted_width = self._colorbar_width / n_axes
            right_margin = self._colorbar_margin["right"] / n_axes
            ticks_margin = self._colorbar_margin["left"] / n_axes
            x1 = x1 - (adjusted_width + right_margin + ticks_margin)

        raw_widths = {}
        for cut_ax in self.axes.values():
            # Empty bounds happen if the call to _map_show was not
            # successful. As it happens asynchronously (during a
            # refresh of the figure) we fall back on a unit box instead
            # of raising: it would only add a non informative traceback
            extent = cut_ax.get_object_bounds() or [0, 1, 0, 1]
            xmin, xmax, _, _ = extent
            raw_widths[cut_ax.ax] = xmax - xmin

        total = float(sum(raw_widths.values()))
        widths = {
            ax: raw / total * (x1 - x0) for ax, raw in raw_widths.items()
        }

        # Lay the axes out left to right, in insertion order.
        lefts = {}
        cursor = float(x0)
        for cut_ax in self.axes.values():
            lefts[cut_ax.ax] = cursor
            cursor += widths[cut_ax.ax]

        return Bbox([[lefts[axes], y0], [lefts[axes] + widths[axes], y1]])

    def draw_cross(self, cut_coords=None, **kwargs):
        """Do nothing: no crossbar is drawn on stacked cuts.

        The method only exists so that stacked slicers accept the same
        call as the other slicers.

        Parameters
        ----------
        cut_coords : 3-:obj:`tuple` of :obj:`float`, optional
            Ignored.

        kwargs : :obj:`dict`
            Ignored.
        """
        return None
class XSlicer(BaseStackedSlicer):
    """Stacked slicer cutting along the 'x' direction (sagittal view).

    This is the display used by Nilearn plotting functions such as
    :func:`nilearn.plotting.plot_img` when ``display_mode='x'`` is
    requested:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the XSlicer class
        display = plot_img(img, display_mode="x")

    Attributes
    ----------
    cut_coords : 1D :class:`~numpy.ndarray`
        The cut coordinates.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting.

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    See Also
    --------
    nilearn.plotting.displays.YSlicer : Coronal view
    nilearn.plotting.displays.ZSlicer : Axial view

    """

    _default_figsize: ClassVar[list[float]] = [2.6, 2.3]
    _direction: ClassVar[str] = "x"
class YSlicer(BaseStackedSlicer):
    """Stacked slicer cutting along the 'y' direction (coronal view).

    This is the display used by Nilearn plotting functions such as
    :func:`nilearn.plotting.plot_img` when ``display_mode='y'`` is
    requested:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the YSlicer class
        display = plot_img(img, display_mode="y")

    Attributes
    ----------
    cut_coords : 1D :class:`~numpy.ndarray`
        The cut coordinates.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting.

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    See Also
    --------
    nilearn.plotting.displays.XSlicer : Sagittal view
    nilearn.plotting.displays.ZSlicer : Axial view

    """

    _default_figsize: ClassVar[list[float]] = [2.2, 3.0]
    _direction: ClassVar[str] = "y"
class ZSlicer(BaseStackedSlicer):
    """Stacked slicer cutting along the 'z' direction (axial view).

    This is the display used by Nilearn plotting functions such as
    :func:`nilearn.plotting.plot_img` when ``display_mode='z'`` is
    requested:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the ZSlicer class
        display = plot_img(img, display_mode="z")

    Attributes
    ----------
    cut_coords : 1D :class:`~numpy.ndarray`
        The cut coordinates.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting.

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    See Also
    --------
    nilearn.plotting.displays.XSlicer : Sagittal view
    nilearn.plotting.displays.YSlicer : Coronal view

    """

    _default_figsize: ClassVar[list[float]] = [2.2, 3.2]
    _direction: ClassVar[str] = "z"
class XZSlicer(OrthoSlicer):
    """The ``XZSlicer`` class combines sagittal and axial views \
    on the same figure with plotting functions of Nilearn like \
    :func:`nilearn.plotting.plot_img`.

    This visualization mode
    can be activated by setting ``display_mode='xz'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the XZSlicer class
        display = plot_img(img, display_mode="xz")

    Attributes
    ----------
    cut_coords : :obj:`list` of :obj:`float`
        The cut coordinates.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting in each direction ('x' and 'z' here).

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    See Also
    --------
    nilearn.plotting.displays.YXSlicer : Coronal + Sagittal views
    nilearn.plotting.displays.YZSlicer : Coronal + Axial views

    """

    # Directions displayed, in left-to-right order. Annotated as a class
    # variable like the other slicers of this module.
    _cut_displayed: ClassVar[str] = "xz"
class YXSlicer(OrthoSlicer):
    """The ``YXSlicer`` class combines coronal and sagittal views \
    on the same figure with plotting functions of Nilearn like \
    :func:`nilearn.plotting.plot_img`.

    This visualization mode
    can be activated by setting ``display_mode='yx'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the YXSlicer class
        display = plot_img(img, display_mode="yx")

    Attributes
    ----------
    cut_coords : :obj:`list` of :obj:`float`
        The cut coordinates.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting in each direction ('x' and 'y' here).

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    See Also
    --------
    nilearn.plotting.displays.XZSlicer : Sagittal + Axial views
    nilearn.plotting.displays.YZSlicer : Coronal + Axial views

    """

    # Directions displayed, in left-to-right order. Annotated as a class
    # variable like the other slicers of this module.
    _cut_displayed: ClassVar[str] = "yx"
class YZSlicer(OrthoSlicer):
    """Slicer showing a coronal and an axial view on the same figure.

    This is the display used by Nilearn plotting functions such as
    :func:`nilearn.plotting.plot_img` when ``display_mode='yz'`` is
    requested:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the YZSlicer class
        display = plot_img(img, display_mode="yz")

    Attributes
    ----------
    cut_coords : :obj:`list` of :obj:`float`
        The cut coordinates.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting in each direction ('y' and 'z' here).

    frame_axes : :class:`~matplotlib.axes.Axes`
        The axes framing the whole set of views.

    See Also
    --------
    nilearn.plotting.displays.XZSlicer : Sagittal + Axial views
    nilearn.plotting.displays.YXSlicer : Coronal + Sagittal views

    """

    _default_figsize: ClassVar[list[float]] = [2.2, 3.0]
    _cut_displayed: ClassVar[str] = "yz"
2000class MosaicSlicer(BaseSlicer):
2001 """A class to create 3 :class:`~matplotlib.axes.Axes` for \
2002 plotting cuts of 3D maps, in multiple rows and columns.
2004 This visualization mode can be activated from Nilearn plotting
2005 functions, like :func:`~nilearn.plotting.plot_img`, by setting
2006 ``display_mode='mosaic'``.
2008 .. code-block:: python
2010 from nilearn.datasets import load_mni152_template
2011 from nilearn.plotting import plot_img
2013 img = load_mni152_template()
2014 # display is an instance of the MosaicSlicer class
2015 display = plot_img(img, display_mode="mosaic")
2017 Attributes
2018 ----------
2019 cut_coords : :obj:`dict` <:obj:`str`: 1D :class:`~numpy.ndarray`>
2020 The cut coordinates in a dictionary. The keys are the directions
2021 ('x', 'y', 'z'), and the values are arrays holding the cut
2022 coordinates.
2024 axes : :obj:`dict` of :class:`~matplotlib.axes.Axes`
2025 The 3 axes used to plot multiple views.
2027 frame_axes : :class:`~matplotlib.axes.Axes`
2028 The axes framing the whole set of views.
2030 See Also
2031 --------
2032 nilearn.plotting.displays.TiledSlicer : Three cuts are performed \
2033 in orthogonal directions.
2034 nilearn.plotting.displays.OrthoSlicer : Three cuts are performed \
2035 and arranged in a 2x2 grid.
2037 """
2039 _cut_displayed: ClassVar[str] = "yxz"
2040 _axes_class: ClassVar[CutAxes] = CutAxes # type: ignore[assignment, misc]
2041 _default_figsize: ClassVar[list[float]] = [4.0, 5.0]
2043 @classmethod
def find_cut_coords(
    cls,
    img=None,
    threshold=None,  # noqa: ARG003
    cut_coords=None,
):
    """Instantiate the slicer and find cut coordinates for mosaic plotting.

    Parameters
    ----------
    img : 3D :class:`~nibabel.nifti1.Nifti1Image`, optional
        The brain image.

    threshold : :obj:`float`, optional
        Not used by this implementation.

    cut_coords : :obj:`list` / :obj:`tuple` of 3 values or number,\
    optional
        A single number is used for each of the three directions; a
        sequence must hold one value per displayed direction, in
        x, y, z order. If ``cut_coords`` is not provided, 7 is used.
        The values are handed to ``cls._find_cut_coords``.

    Returns
    -------
    cut_coords : :obj:`dict`
        xyz world coordinates of cuts in a direction.
        Each key denotes the direction.

    Raises
    ------
    ValueError
        If a sequence of the wrong length is passed.
    """
    requested = 7 if cut_coords is None else cut_coords

    is_single_number = isinstance(
        requested, numbers.Number
    ) and not isinstance(requested, collections.abc.Sequence)

    if is_single_number:
        per_direction = [requested] * 3
    elif len(requested) != len(cls._cut_displayed):
        raise ValueError(
            "The number cut_coords passed does not"
            " match the display_mode. Mosaic plotting "
            "expects tuple of length 3."
        )
    else:
        per_direction = [
            requested["xyz".find(axis)]
            for axis in sorted(cls._cut_displayed)
        ]

    return cls._find_cut_coords(img, per_direction, cls._cut_displayed)
2094 @staticmethod
def _find_cut_coords(img, cut_coords, cut_displayed):
    """Find slicing positions along each displayed direction.

    Helper of ``find_cut_coords``.

    Parameters
    ----------
    img : 3D :class:`~nibabel.nifti1.Nifti1Image`
        The brain image. With ``None`` or ``False`` the cuts are spread
        evenly over fixed default bounds.

    cut_coords : :obj:`list` / :obj:`tuple` of 3 :obj:`int`
        Values paired with the alphabetically sorted directions and
        used as a number of cuts for each of them.

    cut_displayed : :obj:`str`
        Sectional directions 'yxz'.

    Returns
    -------
    coords : :obj:`dict`
        For each direction, the computed cut positions.
    """
    paired = zip(sorted(cut_displayed), cut_coords)

    if img is None or img is False:
        default_bounds = ((-40, 40), (-30, 30), (-30, 75))
        positions = {}
        for direction, n_cuts in paired:
            lower, upper = default_bounds["xyz".index(direction)]
            positions[direction] = np.linspace(lower, upper, n_cuts).tolist()
        return positions

    return {
        direction: find_cut_slices(img, direction=direction, n_cuts=n_cuts)
        for direction, n_cuts in paired
    }
def _init_axes(self, **kwargs):
    """Initialize and place axes for display of 'xyz' multiple cuts.

    One row is created per direction of ``self._cut_displayed`` and one
    display axes per cut of that direction. The display axes are stored
    in ``self.axes`` under the key ``(direction, coord)``.

    Also adapts the width of the color bar relative to the axes.

    Parameters
    ----------
    kwargs : :obj:`dict`
        Additional arguments to pass to ``self._axes_class``.

    Raises
    ------
    ValueError
        If the number of entries of ``self.cut_coords`` does not match
        the number of displayed directions.
    """
    if not isinstance(self.cut_coords, dict):
        self.cut_coords = self.find_cut_coords(cut_coords=self.cut_coords)

    n_rows = len(self.cut_coords)
    if n_rows != len(self._cut_displayed):
        raise ValueError(
            "The number cut_coords passed does not match the mosaic mode"
        )
    x0, y0, x1, y1 = self.rect

    self.axes = {}
    # share of the height given to each row of cuts
    row_share = y1 / n_rows
    for row, direction in enumerate(self._cut_displayed):
        row_coords = self.cut_coords[direction]
        # share of the row given to each cut of this direction
        col_share = 1.0 / len(row_coords)

        row_left = x0
        row_bottom = row_share * row * (y1 - y0) + y0
        row_right = x1
        figure = self.frame_axes.get_figure()
        current_ax = figure.add_axes(
            [row_left, row_bottom, row_right, row_share * (y1 - y0)]
        )
        current_ax.axis("off")

        for col, raw_coord in enumerate(row_coords):
            coord = float(raw_coord)
            # sub axes within the row, added to the same figure
            figure = current_ax.get_figure()
            current_ax = figure.add_axes(
                [
                    col_share * col * (row_right - row_left) + row_left,
                    row_bottom,
                    col_share * (row_right - row_left),
                    row_share,
                ]
            )
            current_ax.axis("off")
            self.axes[(direction, coord)] = self._axes_class(
                current_ax, direction, coord, **kwargs
            )
            current_ax.set_axes_locator(self._locator)

    # increase color bar width to adapt to the number of cuts
    # (of the last direction processed)
    # see issue https://github.com/nilearn/nilearn/pull/4284
    self._colorbar_width *= len(row_coords) ** 1.1
def _locator(
    self,
    axes,
    renderer,  # noqa: ARG002
):
    """Adjust the size of the axes.

    Locator function used by matplotlib to position axes.

    Each direction of ``self._cut_displayed`` gets its own row. Within
    a row, the horizontal span of ``self.rect`` (minus the room
    reserved for the colorbar, when there is one) is shared between the
    cuts in proportion to the width of their object bounds.

    ``renderer`` is required to match the matplotlib API.
    """
    x0, y0, x1, y1 = self.rect

    if self._colorbar:
        n_axes = len(self.axes)
        adjusted_width = self._colorbar_width / n_axes
        right_margin = self._colorbar_margin["right"] / n_axes
        ticks_margin = self._colorbar_margin["left"] / n_axes
        x1 = x1 - (adjusted_width + right_margin + ticks_margin)

    # the height is divided between the cut directions 'y', 'x', 'z'
    row_share = y1 / len(self._cut_displayed)

    widths = {}
    lefts = {}
    bottoms = {}
    tops = {}
    for row, direction in enumerate(self._cut_displayed):
        members = [
            cut_ax
            for cut_ax in self.axes.values()
            if direction == cut_ax.direction
        ]

        # widths of the cuts of this row, rescaled to fill the row
        raw_widths = {}
        for cut_ax in members:
            # Empty bounds happen if the call to _map_show was not
            # successful. As it happens asynchronously (during a
            # refresh of the figure) we fall back on a unit box instead
            # of raising: it would only add a non informative traceback
            extent = cut_ax.get_object_bounds() or [0, 1, 0, 1]
            xmin, xmax, _, _ = extent
            raw_widths[cut_ax.ax] = xmax - xmin
        total = float(sum(raw_widths.values()))
        for ax, raw in raw_widths.items():
            widths[ax] = raw / total * (x1 - x0)

        # lay the cuts of this row out left to right
        cursor = float(x0)
        row_top = row_share + row_share * row
        for cut_ax in members:
            lefts[cut_ax.ax] = cursor
            cursor += widths[cut_ax.ax]
            bottoms[cut_ax.ax] = row_share * row * (y1 - y0)
            tops[cut_ax.ax] = row_top

    return Bbox(
        [
            [lefts[axes], bottoms[axes]],
            [lefts[axes] + widths[axes], tops[axes]],
        ]
    )
def draw_cross(self, cut_coords=None, **kwargs):
    """Draw a crossbar on the plot to show where the cut is performed.

    This implementation is a no-op: nothing is drawn, both arguments
    are ignored and ``None`` is returned.

    Parameters
    ----------
    cut_coords : 3-:obj:`tuple` of :obj:`float`, optional
        The position of the cross to draw.
        Accepted but unused by this implementation.

    kwargs : :obj:`dict`
        Extra keyword arguments.
        Accepted but unused by this implementation.
    """
# Registry mapping each supported ``display_mode`` string to the slicer
# class implementing it. ``get_slicer`` hands this mapping to
# ``get_create_display_fun`` to resolve a requested display mode.
SLICERS = {
    "ortho": OrthoSlicer,
    "tiled": TiledSlicer,
    "mosaic": MosaicSlicer,
    "xz": XZSlicer,
    "yz": YZSlicer,
    "yx": YXSlicer,
    "x": XSlicer,
    "y": YSlicer,
    "z": ZSlicer,
}
def get_slicer(display_mode):
    """Retrieve a slicer from a given display mode.

    The lookup is delegated to ``get_create_display_fun`` using the
    ``SLICERS`` mapping.

    Parameters
    ----------
    display_mode : :obj:`str`
        The desired display mode.
        Possible options are:

        - "ortho": Three cuts are performed in orthogonal directions.
        - "tiled": Three cuts are performed and arranged in a 2x2 grid.
        - "mosaic": Three cuts are performed along multiple rows and columns.
        - "x": Sagittal
        - "y": Coronal
        - "z": Axial
        - "xz": Sagittal + Axial
        - "yz": Coronal + Axial
        - "yx": Coronal + Sagittal

    Returns
    -------
    slicer : An instance of one of the subclasses of\
             :class:`~nilearn.plotting.displays.BaseSlicer`

        The slicer corresponding to the requested display mode:

        - "ortho": Returns an
          :class:`~nilearn.plotting.displays.OrthoSlicer`.
        - "tiled": Returns a
          :class:`~nilearn.plotting.displays.TiledSlicer`.
        - "mosaic": Returns a
          :class:`~nilearn.plotting.displays.MosaicSlicer`.
        - "xz": Returns a
          :class:`~nilearn.plotting.displays.XZSlicer`.
        - "yz": Returns a
          :class:`~nilearn.plotting.displays.YZSlicer`.
        - "yx": Returns a
          :class:`~nilearn.plotting.displays.YXSlicer`.
        - "x": Returns a
          :class:`~nilearn.plotting.displays.XSlicer`.
        - "y": Returns a
          :class:`~nilearn.plotting.displays.YSlicer`.
        - "z": Returns a
          :class:`~nilearn.plotting.displays.ZSlicer`.

    """
    return get_create_display_fun(display_mode, SLICERS)