Coverage for nilearn/plotting/displays/_axes.py: 0%
202 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1import numbers
2import warnings
4import matplotlib.pyplot as plt
5import numpy as np
6from matplotlib.colors import Normalize
7from matplotlib.font_manager import FontProperties
8from matplotlib.lines import Line2D
9from matplotlib.patches import FancyArrow
10from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
12from nilearn._utils import fill_doc
13from nilearn._utils.logger import find_stack_level
14from nilearn.image import coord_transform
15from nilearn.plotting.displays._utils import coords_3d_to_2d
16from nilearn.plotting.glass_brain import plot_brain_schematics
@fill_doc
class BaseAxes:
    """An MPL axis-like object that displays a 2D view of 3D volumes.

    Parameters
    ----------
    %(ax)s
    direction : {'x', 'y', 'z', 'l', 'r'}
        The directions of the view.
        ``'l'`` and ``'r'`` are laid out like ``'x'``; for ``'l'`` the
        horizontal axis is additionally inverted when drawing.

    coord : :obj:`float`
        The coordinate along the direction of the cut.
    %(radiological)s
    """

    def __init__(self, ax, direction, coord, radiological=False):
        self.ax = ax
        self.direction = direction
        self.coord = coord
        # (xmin, xmax, ymin, ymax) of every object drawn on these axes; the
        # axes limits are kept large enough to enclose all of them.
        self._object_bounds = []
        # Transposed shape of the last image drawn by ``draw_2d``.
        self.shape = None
        self.radiological = radiological

    def transform_to_2d(self, data, affine):
        """Transform a 3D volume into a 2D image.

        Not implemented in the base class.

        Raises
        ------
        NotImplementedError
            Always; derived classes must override this method.
        """
        raise NotImplementedError(
            "'transform_to_2d' needs to be implemented in derived classes"
        )

    def add_object_bounds(self, bounds):
        """Register the bounds of an object and rescale the axes if needed.

        Parameters
        ----------
        bounds : :obj:`tuple` of 4 :obj:`float`
            ``(xmin, xmax, ymin, ymax)`` of the object.
        """
        old_object_bounds = self.get_object_bounds()
        self._object_bounds.append(bounds)
        new_object_bounds = self.get_object_bounds()

        # Only touch the axes limits when the enclosing box changed.
        if new_object_bounds != old_object_bounds:
            self.ax.axis(new_object_bounds)

    def draw_2d(
        self,
        data_2d,
        data_bounds,
        bounding_box,
        type="imshow",
        transparency=None,
        **kwargs,
    ):
        """Draw a 2D image on the axes.

        Parameters
        ----------
        data_2d : 2D :class:`numpy.ndarray`
            Image to draw. A copy is handed to matplotlib so that later
            changes to ``data_2d`` do not alter the drawn image.

        data_bounds : sequence of 3 ``(min, max)`` pairs
            Extent of the data along each of the 3 axes. The two pairs
            relevant for ``self.direction`` give the image extent.

        bounding_box : sequence of 3 ``(min, max)`` pairs
            Same layout as ``data_bounds``. The relevant pairs are
            registered with :meth:`add_object_bounds` to rescale the axes.

        type : :obj:`str`, default="imshow"
            Name of the :class:`matplotlib.axes.Axes` method used to draw.

        transparency : :obj:`float` or None, default=None
            Passed to matplotlib as ``alpha``. It always overrides an
            ``alpha`` given in ``kwargs``, and a warning is emitted in
            that case.

        **kwargs
            Passed to the drawing method. ``origin`` is forced to
            ``"upper"``.

        Returns
        -------
        The artist returned by the matplotlib drawing method.

        Raises
        ------
        ValueError
            If ``self.direction`` is not one of 'x', 'y', 'z', 'l', 'r'.
        """
        kwargs["origin"] = "upper"

        if "alpha" in kwargs:
            warnings.warn(
                f"{kwargs['alpha']=} detected in parameters.\n"
                f"Overriding with {transparency=}.\n"
                "To suppress this warning pass "
                "your 'alpha' value "
                "via the 'transparency' parameter.",
                stacklevel=find_stack_level(),
            )
        kwargs["alpha"] = transparency

        # Pick the two axes of the 3D bounds that map to the horizontal and
        # vertical axes of this 2D view.
        if self.direction == "y":
            (xmin, xmax), (_, _), (zmin, zmax) = data_bounds
            (xmin_, xmax_), (_, _), (zmin_, zmax_) = bounding_box
        elif self.direction in ("x", "l", "r"):
            (_, _), (xmin, xmax), (zmin, zmax) = data_bounds
            (_, _), (xmin_, xmax_), (zmin_, zmax_) = bounding_box
        elif self.direction == "z":
            (xmin, xmax), (zmin, zmax), (_, _) = data_bounds
            (xmin_, xmax_), (zmin_, zmax_), (_, _) = bounding_box
        else:
            raise ValueError(f"Invalid value for direction {self.direction}")
        ax = self.ax
        # Here we need to do a copy to avoid having the image changing as
        # we change the data
        im = getattr(ax, type)(
            data_2d.copy(), extent=(xmin, xmax, zmin, zmax), **kwargs
        )

        self.add_object_bounds((xmin_, xmax_, zmin_, zmax_))
        self.shape = data_2d.T.shape
        # The bounds of the object do not take into account a possible
        # inversion of the axis. As such, we check that the axis is properly
        # inverted when direction is left
        if self.direction == "l" and not (ax.get_xlim()[0] > ax.get_xlim()[1]):
            ax.invert_xaxis()
        return im

    def get_object_bounds(self):
        """Return the box enclosing all objects drawn on these axes.

        Returns
        -------
        :obj:`tuple` of 4 :obj:`float`
            ``(xmin, xmax, ymin, ymax)``. Each registered bound may be given
            in either order along an axis; the enclosing box is always
            ordered. A tiny box around the origin is returned when nothing
            has been drawn yet.
        """
        if not self._object_bounds:
            # Nothing plotted yet
            return -0.01, 0.01, -0.01, 0.01
        xmins, xmaxs, ymins, ymaxs = np.array(self._object_bounds).T
        # Take min/max over both columns: bounds may be stored reversed.
        xmax = max(xmaxs.max(), xmins.max())
        xmin = min(xmins.min(), xmaxs.min())
        ymax = max(ymaxs.max(), ymins.max())
        ymin = min(ymins.min(), ymaxs.min())

        return xmin, xmax, ymin, ymax

    def draw_left_right(self, size, bg_color, **kwargs):
        """Draw the annotation "L" for left, and "R" for right.

        Nothing is drawn for the 'x', 'l' and 'r' directions. With
        ``radiological=True`` the two letters are swapped.

        Parameters
        ----------
        size : :obj:`float`
            Size of the text areas.

        bg_color : matplotlib color: :obj:`str` or (r, g, b) value
            The background color for both text areas.

        **kwargs
            Extra keyword arguments passed to
            :meth:`matplotlib.axes.Axes.text`.
        """
        if self.direction in ("x", "l", "r"):
            return
        ax = self.ax
        if self.radiological:
            annotation_on_left, annotation_on_right = "R", "L"
        else:
            annotation_on_left, annotation_on_right = "L", "R"
        # Both boxes use alpha=1 so that the two labels render identically.
        bbox = {
            "boxstyle": "square,pad=0",
            "ec": bg_color,
            "fc": bg_color,
            "alpha": 1,
        }
        for x_position, annotation, alignment in (
            (0.1, annotation_on_left, "left"),
            (0.9, annotation_on_right, "right"),
        ):
            ax.text(
                x_position,
                0.95,
                annotation,
                transform=ax.transAxes,
                horizontalalignment=alignment,
                verticalalignment="top",
                size=size,
                bbox=dict(bbox),
                **kwargs,
            )

    def draw_scale_bar(
        self,
        bg_color,
        size=5.0,
        units="cm",
        fontproperties=None,
        frameon=False,
        loc=4,
        pad=0.1,
        borderpad=0.5,
        sep=5,
        size_vertical=0,
        label_top=False,
        color="black",
        fontsize=None,
        **kwargs,
    ):
        """Add a scale bar annotation to the display.

        Parameters
        ----------
        bg_color : matplotlib color: :obj:`str` or (r, g, b) value
            The background color of the scale bar annotation.

        size : :obj:`float`, default=5.0
            Horizontal length of the scale bar, given in `units`.

        units : :obj:`str`, default='cm'
            Physical units of the scale bar (`'cm'` or `'mm'`).
            Any value other than `'cm'` is treated as millimeters.

        fontproperties : :class:`~matplotlib.font_manager.FontProperties`\
        or :obj:`dict`, optional
            Font properties for the label text. A :obj:`dict` is converted
            to :class:`~matplotlib.font_manager.FontProperties`. A
            :class:`~matplotlib.font_manager.FontProperties` object is
            copied (not modified) when ``fontsize`` is given.

        frameon : :obj:`bool`, default=False
            Whether the scale bar is plotted with a border.

        loc : :obj:`int`, default=4
            Location of this scale bar.
            Valid location codes are documented in
            :class:`~mpl_toolkits.axes_grid1.anchored_artists.AnchoredSizeBar`

        pad : :obj:`int` or :obj:`float`, default=0.1
            Padding around the label and scale bar, in fraction of the font
            size.

        borderpad : :obj:`int` or :obj:`float`, default=0.5
            Border padding, in fraction of the font size.

        sep : :obj:`int` or :obj:`float`, default=5
            Separation between the label and the scale bar, in points.

        size_vertical : :obj:`int` or :obj:`float`, default=0
            Vertical length of the size bar, given in `units`.

        label_top : :obj:`bool`, default=False
            If ``True``, the label will be over the scale bar.

        color : :obj:`str`, default='black'
            Color for the scale bar and label.

        fontsize : :obj:`int`, optional
            Label font size (overwrites the size passed in through the
            ``fontproperties`` argument).

        **kwargs :
            Keyworded arguments to pass to
            :class:`~matplotlib.offsetbox.AnchoredOffsetbox`.

        """
        axis = self.ax
        if fontproperties is None or isinstance(fontproperties, dict):
            # A dict has no ``set_size`` method: build a FontProperties
            # from it, as matplotlib itself would.
            fontproperties = FontProperties(**(fontproperties or {}))
        elif fontsize:
            # Do not mutate the FontProperties object owned by the caller.
            fontproperties = fontproperties.copy()
        if fontsize:
            fontproperties.set_size(fontsize)
        # The bar is drawn in data coordinates, which are in millimeters.
        width_mm = size * 10 if units == "cm" else size

        anchor_size_bar = AnchoredSizeBar(
            axis.transData,
            width_mm,
            f"{size:g}{units}",
            fontproperties=fontproperties,
            frameon=frameon,
            loc=loc,
            pad=pad,
            borderpad=borderpad,
            sep=sep,
            size_vertical=size_vertical,
            label_top=label_top,
            color=color,
            **kwargs,
        )

        if frameon:
            anchor_size_bar.patch.set_facecolor(bg_color)
            anchor_size_bar.patch.set_edgecolor("none")
        axis.add_artist(anchor_size_bar)

    def draw_position(self, size, bg_color, **kwargs):
        """``draw_position`` is not implemented in base class and \
        should be implemented in derived classes.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError(
            "'draw_position' should be implemented in derived classes"
        )


@fill_doc
class CutAxes(BaseAxes):
    """An MPL axis-like object that displays a cut of 3D volumes.

    Parameters
    ----------
    %(ax)s
    direction : {'x', 'y', 'z'}
        The directions of the view.

    coord : :obj:`float`
        The coordinate along the direction of the cut.
    %(radiological)s
    """

    def transform_to_2d(self, data, affine):
        """Cut the 3D volume into a 2D slice.

        The slice index is obtained by mapping the world point whose
        coordinate along ``self.direction`` is ``self.coord`` (the other two
        being 0) through the inverse of ``affine``, then rounding to the
        nearest voxel.

        Parameters
        ----------
        data : 3D :class:`~numpy.ndarray`
            The 3D volume to cut.

        affine : 4x4 :class:`~numpy.ndarray`
            The affine of the volume.

        Returns
        -------
        2D :class:`~numpy.ndarray`
            The slice, rotated with :func:`numpy.rot90`.

        Raises
        ------
        ValueError
            If ``self.direction`` is not one of 'x', 'y' or 'z'.
        """
        if self.direction not in ["x", "y", "z"]:
            raise ValueError(f"Invalid value for direction {self.direction}")
        axis = "xyz".index(self.direction)
        coords = [0, 0, 0]
        coords[axis] = self.coord
        voxel_indices = [
            int(np.round(c))
            for c in coord_transform(
                coords[0], coords[1], coords[2], np.linalg.inv(affine)
            )
        ]
        # Select a single index along the cut axis, keep the other two.
        slicer = [slice(None)] * 3
        slicer[axis] = voxel_indices[axis]
        return np.rot90(data[tuple(slicer)])

    def draw_position(self, size, bg_color, decimals=False, **kwargs):
        """Draw the cut coordinate in the lower left corner of the axes.

        Parameters
        ----------
        size : :obj:`float`
            Size of the text area.

        bg_color : matplotlib color: :obj:`str` or (r, g, b) value
            The background color for text area.

        decimals : :obj:`int` or :obj:`bool`, default=False
            Number of decimal places used to display the coordinate.
            If ``False`` (or 0), integer formatting is used, which
            truncates the coordinate.

        **kwargs
            Extra keyword arguments passed to
            :meth:`matplotlib.axes.Axes.text`.
        """
        if decimals:
            # int() turns the value into a valid precision; interpolating
            # e.g. ``True`` directly would build an invalid format string.
            text = f"%s=%.{int(decimals)}f"
            coord = float(self.coord)
        else:
            text = "%s=%i"
            coord = self.coord
        ax = self.ax
        ax.text(
            0,
            0,
            text % (self.direction, coord),
            transform=ax.transAxes,
            horizontalalignment="left",
            verticalalignment="bottom",
            size=size,
            bbox={
                "boxstyle": "square,pad=0",
                "ec": bg_color,
                "fc": bg_color,
                "alpha": 1,
            },
            **kwargs,
        )


@fill_doc
class GlassBrainAxes(BaseAxes):
    """An MPL axis-like object that displays a 2D projection of 3D volumes.

    The projection is drawn together with a schematic view of the brain.

    Parameters
    ----------
    %(ax)s
    direction : {'x', 'y', 'z', 'l', 'r'}
        The directions of the view.
        ``'l'`` and ``'r'`` project only the voxels on one side of the
        world plane x=0 (presumably the left / right hemisphere for
        RAS-oriented images).

    coord : :obj:`float`
        The coordinate along the direction of the cut.

    plot_abs : :obj:`bool`, default=True
        If set to ``True`` the absolute value of the data will be considered.
    %(radiological)s
    kwargs
        Passed to
        :func:`~nilearn.plotting.glass_brain.plot_brain_schematics`
        when ``ax`` is not None.
    """

    def __init__(
        self, ax, direction, coord, plot_abs=True, radiological=False, **kwargs
    ):
        super().__init__(ax, direction, coord, radiological=radiological)
        self._plot_abs = plot_abs
        if ax is not None:
            # Draw the brain schematics and rescale the axes to enclose them.
            object_bounds = plot_brain_schematics(ax, direction, **kwargs)
            self.add_object_bounds(object_bounds)

    def transform_to_2d(self, data, affine):
        """Project the 3D volume along the viewing axis (maximum intensity).

        For ``'l'`` and ``'r'`` only the voxels before (respectively from)
        the voxel index of the world plane x=0 along the first array axis
        are projected. If that selection is empty, the whole volume is used.

        Parameters
        ----------
        data : 3D :class:`numpy.ndarray`
            The 3D volume.

        affine : 4x4 :class:`numpy.ndarray`
            The affine of the volume.

        Returns
        -------
        2D :class:`numpy.ndarray`
            The projection, rotated with :func:`numpy.rot90`. With
            ``plot_abs=True`` it holds the maximum absolute values,
            otherwise the signed values with the largest absolute value.
        """
        if self.direction in ("x", "l", "r"):
            max_axis = 0
        else:
            max_axis = "xyz".index(self.direction)

        # set unselected brain hemisphere activations to 0
        if self.direction in ("l", "r"):
            # Voxel coordinates of the world origin; keep the first one.
            x_center = np.dot(np.linalg.inv(affine), np.array([0, 0, 0, 1]))[0]
            # Clip at 0: a negative index would make the slice wrap around
            # from the end of the array instead of selecting nothing / all.
            split = max(int(x_center), 0)
            if self.direction == "l":
                data_selection = data[:split, :, :]
            else:
                data_selection = data[split:, :, :]
        else:
            data_selection = data

        # We need to make sure data_selection is not empty in the x axis
        # This should be the case since we expect images in MNI space
        if data_selection.shape[0] == 0:
            data_selection = data

        if self._plot_abs:
            maximum_intensity_data = np.abs(data_selection).max(axis=max_axis)
        else:
            # get the shape of the array we are projecting to
            new_shape = list(data_selection.shape)
            del new_shape[max_axis]

            # generate a 3D indexing array that points to max abs value in the
            # current projection
            a1, a2 = np.indices(new_shape)
            inds = [a1, a2]
            inds.insert(max_axis, np.abs(data_selection).argmax(axis=max_axis))

            # take the values where the absolute value of the projection
            # is the highest
            maximum_intensity_data = data_selection[tuple(inds)]

        # This work around can be removed bumping matplotlib > 2.1.0. See #1815
        # in nilearn for the invention of this work around
        if (
            self.direction == "l"
            and data_selection.min() is np.ma.masked
            and not (self.ax.get_xlim()[0] > self.ax.get_xlim()[1])
        ):
            self.ax.invert_xaxis()

        return np.rot90(maximum_intensity_data)

    def draw_position(self, size, bg_color, **kwargs):
        """Not implemented as it does not make sense to draw crosses for \
        the position of the cuts \
        since we are taking the max along one axis.
        """
        pass

    def _add_markers(self, marker_coords, marker_color, marker_size, **kwargs):
        """Plot markers.

        In the case of 'l' and 'r' directions (for hemispheric projections),
        markers in the coordinate x == 0 are included in both hemispheres.

        Parameters
        ----------
        marker_coords : :class:`numpy.ndarray` of shape (n_markers, 3)
            3D coordinates of the markers.

        marker_color : :obj:`str` or array_like
            A single matplotlib color string, or one color per marker.

        marker_size : number or array_like
            A single size, or one size per marker.

        **kwargs
            Passed to :meth:`matplotlib.axes.Axes.scatter`. ``marker``
            defaults to ``"o"`` and ``zorder`` to 1000.
        """
        marker_coords_2d = coords_3d_to_2d(marker_coords, self.direction)
        xdata, ydata = marker_coords_2d.T

        # Allow markers only in their respective hemisphere when appropriate
        if self.direction in ("l", "r"):
            if not isinstance(marker_color, (str, np.ndarray)):
                marker_color = np.asarray(marker_color)
            xcoords, _, _ = np.asarray(marker_coords).T
            relevant_coords = [
                cidx
                for cidx, xc in enumerate(xcoords)
                if (self.direction == "r" and xc >= 0)
                or (self.direction == "l" and xc <= 0)
            ]
            xdata = xdata[relevant_coords]
            ydata = ydata[relevant_coords]
            # if marker_color is string for example 'red' or 'blue', then
            # we pass marker_color as it is to matplotlib scatter without
            # making any selection in 'l' or 'r' color.
            # More likely that user wants to display all nodes to be in
            # same color.
            if not isinstance(marker_color, str) and len(marker_color) != 1:
                marker_color = marker_color[relevant_coords]

            if not isinstance(marker_size, numbers.Number):
                marker_size = np.asarray(marker_size)[relevant_coords]

        kwargs.setdefault("marker", "o")
        kwargs.setdefault("zorder", 1000)

        self.ax.scatter(xdata, ydata, s=marker_size, c=marker_color, **kwargs)

    def _add_lines(
        self,
        line_coords,
        line_values,
        cmap,
        vmin=None,
        vmax=None,
        directed=False,
        **kwargs,
    ):
        """Plot lines.

        Also stores ``cmap`` and the resulting normalization as
        ``self.cmap`` and ``self.norm`` (used for the colorbar).

        Parameters
        ----------
        line_coords : :obj:`list` of :class:`numpy.ndarray` of shape (2, 3)
            3D coordinates of lines start points and end points.

        line_values : array_like
            Values of the lines.

        %(cmap)s
            Colormap used to map ``line_values`` to a color.

        vmin, vmax : :obj:`float`, optional
            If not ``None``, either or both of these values will be used
            as the minimum and maximum values to color lines. If ``None`` are
            supplied the maximum absolute value within the given threshold
            will be used as minimum (multiplied by -1) and maximum
            coloring levels.

        directed : :obj:`bool`, default=False
            Add arrows instead of lines if set to ``True``.
            Use this when plotting directed graphs for example.

        kwargs : :obj:`dict`
            Additional arguments to pass to :class:`~matplotlib.lines.Line2D`.

        Raises
        ------
        ValueError
            If only ``vmax`` is given and is non-positive, or only ``vmin``
            is given and is non-negative.
        """
        # colormap for colorbar
        self.cmap = cmap
        if vmin is None and vmax is None:
            # An empty set of lines yields a degenerate (0, 0) range instead
            # of failing on the reduction of an empty array.
            abs_line_values_max = (
                np.abs(line_values).max() if np.size(line_values) else 0
            )
            vmin = -abs_line_values_max
            vmax = abs_line_values_max
        elif vmin is None:
            if vmax > 0:
                vmin = -vmax
            else:
                raise ValueError(
                    "If vmax is set to a non-positive number "
                    "then vmin needs to be specified"
                )
        elif vmax is None:
            if vmin < 0:
                vmax = -vmin
            else:
                raise ValueError(
                    "If vmin is set to a non-negative number "
                    "then vmax needs to be specified"
                )
        norm = Normalize(vmin=vmin, vmax=vmax)
        # normalization useful for colorbar
        self.norm = norm
        abs_norm = Normalize(vmin=0, vmax=max(abs(vmax), abs(vmin)))
        value_to_color = plt.cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba

        # Allow lines only in their respective hemisphere when appropriate
        if self.direction in ("l", "r"):
            relevant_lines = [
                lidx
                for lidx, line in enumerate(line_coords)
                if (
                    self.direction == "r"
                    and line[0, 0] >= 0
                    and line[1, 0] >= 0
                )
                or (
                    self.direction == "l" and line[0, 0] < 0 and line[1, 0] < 0
                )
            ]
            line_coords = np.array(line_coords)[relevant_lines]
            # asarray so that list input (documented as array_like) can be
            # indexed with a list of indices.
            line_values = np.asarray(line_values)[relevant_lines]

        for start_end_point_3d, line_value in zip(line_coords, line_values):
            start_end_point_2d = coords_3d_to_2d(
                start_end_point_3d, self.direction
            )

            color = value_to_color(line_value)
            abs_line_value = abs(line_value)
            linewidth = 1 + 2 * abs_norm(abs_line_value)
            # Hacky way to put the strongest connections on top of the weakest
            # note sign does not matter hence using 'abs'
            zorder = 10 + 10 * abs_norm(abs_line_value)
            this_kwargs = {
                "color": color,
                "linewidth": linewidth,
                "zorder": zorder,
            }
            # kwargs should have priority over this_kwargs so that the
            # user can override the default logic
            this_kwargs.update(kwargs)
            xdata, ydata = start_end_point_2d.T
            # If directed is True, add an arrow
            if directed:
                dx = xdata[1] - xdata[0]
                dy = ydata[1] - ydata[0]
                # Hack to avoid empty arrows to crash with
                # matplotlib versions older than 3.1
                # This can be removed once support for
                # matplotlib pre 3.1 has been dropped.
                if dx == dy == 0:
                    arrow = FancyArrow(xdata[0], ydata[0], dx, dy)
                else:
                    arrow = FancyArrow(
                        xdata[0],
                        ydata[0],
                        dx,
                        dy,
                        length_includes_head=True,
                        width=linewidth,
                        head_width=3 * linewidth,
                        **this_kwargs,
                    )
                self.ax.add_patch(arrow)
            # Otherwise a line
            else:
                line = Line2D(xdata, ydata, **this_kwargs)
                self.ax.add_line(line)