Coverage for nilearn/plotting/tests/test_displays.py: 0%
210 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2# vi: set ft=python sts=4 ts=4 sw=4 et:
4import matplotlib
5import matplotlib.pyplot as plt
6import numpy as np
7import pytest
8from nibabel import Nifti1Image
10from nilearn.datasets import load_mni152_template
11from nilearn.plotting.displays import (
12 BaseAxes,
13 LProjector,
14 LRProjector,
15 LYRProjector,
16 LYRZProjector,
17 LZRProjector,
18 LZRYProjector,
19 MosaicSlicer,
20 OrthoProjector,
21 OrthoSlicer,
22 RProjector,
23 TiledSlicer,
24 XProjector,
25 XSlicer,
26 XZProjector,
27 XZSlicer,
28 YProjector,
29 YSlicer,
30 YXProjector,
31 YXSlicer,
32 YZProjector,
33 YZSlicer,
34 ZProjector,
35 ZSlicer,
36)
SLICER_KEYS = ["ortho", "tiled", "x", "y", "z", "yx", "yz", "mosaic", "xz"]
SLICERS = [
    OrthoSlicer,
    TiledSlicer,
    XSlicer,
    YSlicer,
    ZSlicer,
    YXSlicer,
    YZSlicer,
    MosaicSlicer,
    XZSlicer,
]
# PROJECTOR_KEYS[i] is the display mode of PROJECTORS[i]. The parametrized
# tests pair the two lists with ``zip``, which silently truncates to the
# shorter list, so both lists must keep the same length and order.
PROJECTOR_KEYS = [
    "ortho",
    "xz",
    "yz",
    "yx",
    "x",
    "y",
    "z",
    "lzry",
    "lyrz",
    "lyr",
    "lzr",
    "lr",
    "l",
    "r",
]
PROJECTORS = [
    OrthoProjector,
    XZProjector,
    YZProjector,
    YXProjector,
    XProjector,
    YProjector,
    ZProjector,
    LZRYProjector,
    LYRZProjector,
    LYRProjector,
    LZRProjector,
    LRProjector,
    LProjector,
    RProjector,
]
# Fail loudly at collection time instead of silently skipping displays.
assert len(SLICER_KEYS) == len(SLICERS)
assert len(PROJECTOR_KEYS) == len(PROJECTORS)


def test_base_axes_exceptions():
    """Tests for exceptions raised by class ``BaseAxes``."""
    axes = BaseAxes(None, "foo", 3)
    # Constructor doesn't raise for invalid direction
    assert axes.direction == "foo"
    assert axes.coord == 3
    with pytest.raises(
        NotImplementedError, match="'transform_to_2d' needs to be"
    ):
        axes.transform_to_2d(None, None)
    with pytest.raises(NotImplementedError, match="'draw_position' should be"):
        axes.draw_position(None, None)
    with pytest.raises(ValueError, match="Invalid value for direction"):
        axes.draw_2d(None, None, None)


def test_cut_axes_exception(affine_eye):
    """Tests for exceptions raised by class ``CutAxes``."""
    from nilearn.plotting.displays import CutAxes

    axes = CutAxes(None, "foo", 2)
    assert axes.direction == "foo"
    assert axes.coord == 2
    with pytest.raises(ValueError, match="Invalid value for direction"):
        axes.transform_to_2d(None, affine_eye)


def test_glass_brain_axes():
    """Tests for class ``GlassBrainAxes``."""
    from nilearn.plotting.displays import GlassBrainAxes

    ax = plt.subplot(111)
    axes = GlassBrainAxes(ax, "r", 2)
    axes._add_markers(np.array([[0, 0, 0]]), "g", [10])
    line_coords = [np.array([[0, 0, 0], [1, 1, 1]])]
    line_values = np.array([1, 0, 6])
    with pytest.raises(
        ValueError, match="If vmax is set to a non-positive number "
    ):
        axes._add_lines(line_coords, line_values, None, vmin=None, vmax=-10)
    axes._add_lines(line_coords, line_values, None, vmin=None, vmax=10)
    with pytest.raises(
        ValueError, match="If vmin is set to a non-negative number "
    ):
        axes._add_lines(line_coords, line_values, None, vmin=10, vmax=None)
    axes._add_lines(line_coords, line_values, None, vmin=-10, vmax=None)
    axes._add_lines(line_coords, line_values, None, vmin=-10, vmax=-5)


def test_get_index_from_direction_exception():
    """Tests that a ValueError is raised when an invalid direction \
    is given to function ``_get_index_from_direction``.
    """
    from nilearn.plotting.displays._utils import _get_index_from_direction

    with pytest.raises(ValueError, match="foo is not a valid direction."):
        _get_index_from_direction("foo")


@pytest.fixture
def img():
    """Image used for testing."""
    return load_mni152_template(resolution=2)


@pytest.fixture
def cut_coords(name, display):
    """Select cut coords matching the views drawn by ``display``.

    ``name`` and ``display`` are the display-mode key and display class
    that the requesting test is parametrized with. ``display`` is needed
    because the keys "x", "y" and "z" are shared by slicers and projectors.

    - Projectors get one coordinate per view: three for "ortho", otherwise
      one per letter of the key (e.g. "lzry" -> 4, "l" -> 1).
    - ``MosaicSlicer`` ("mosaic") gets an integer number of cuts.
    - The "yx", "yz" and "xz" slicers get two coordinates; every other
      slicer gets three.
    """
    if display in PROJECTORS:
        n_views = 3 if name == "ortho" else len(name)
        return (0,) * n_views
    if name == "mosaic":
        return 3
    if name in ["yx", "yz", "xz"]:
        return (0,) * 2
    return (0,) * 3
@pytest.mark.parametrize("display,name", zip(SLICERS, SLICER_KEYS))
def test_display_basics_slicers(display, name, img, cut_coords):
    """Smoke-test every slicer display.

    Each slicer is built, given an overlay and a title, checked for its
    stored ``cut_coords`` and frame axes, and finally closed.
    """
    slicer = display(cut_coords=cut_coords)
    slicer.add_overlay(img, cmap="gray")
    slicer.title(f"display mode is {name}")
    # MosaicSlicer receives an integer cut count, so its stored cut_coords
    # are not compared with the input.
    if name != "mosaic":
        assert slicer.cut_coords == cut_coords
    assert isinstance(slicer.frame_axes, matplotlib.axes.Axes)
    slicer.close()
@pytest.mark.parametrize("display,name", zip(PROJECTORS, PROJECTOR_KEYS))
def test_display_basics_projectors(display, name, img, cut_coords):
    """Basic smoke tests for all displays (projectors).

    Each object is instantiated, ``add_overlay``, ``title``,
    and ``close`` are then called. The ``cut_coords`` given to the
    constructor must be stored unchanged.
    """
    display = display(cut_coords=cut_coords)
    display.add_overlay(img, cmap="gray")
    display.title(f"display mode is {name}")
    # No projector key is "mosaic", so the round-trip check applies to
    # every projector (a "mosaic" guard here would be dead code).
    assert display.cut_coords == cut_coords
    assert isinstance(display.frame_axes, matplotlib.axes.Axes)
    display.close()
@pytest.mark.parametrize(
    "slicer", [XSlicer, YSlicer, ZSlicer, YXSlicer, YZSlicer, XZSlicer]
)
def test_stacked_slicer(slicer, img, tmp_path):
    """Save a figure produced by each stacked slicer to a file."""
    one_direction = slicer in [XSlicer, YSlicer, ZSlicer]
    cuts = 3 if one_direction else (3, 3)
    display = slicer.init_with_figure(img=img, cut_coords=cuts)
    display.add_overlay(img, cmap="gray")
    # Saving forces a layout, which exercises the axes-locator code.
    display.savefig(tmp_path / "out.png")
    display.close()
@pytest.mark.parametrize("slicer", [OrthoSlicer, TiledSlicer, MosaicSlicer])
def test_slicer_save_to_file(slicer, img, tmp_path):
    """Save Ortho/Tiled/Mosaic slicer figures (with colorbar) to a file."""
    # MosaicSlicer is given no explicit cuts; the others cut at the origin.
    coords = None if slicer is MosaicSlicer else (0, 0, 0)
    display = slicer.init_with_figure(
        img=img, cut_coords=coords, colorbar=True
    )
    display.add_overlay(img, cmap="gray", colorbar=True)
    # Appearance defaults: grey brain, no black background.
    assert display.brain_color == (0.5, 0.5, 0.5)
    assert not display.black_bg
    # Saving forces a layout, which exercises the axes-locator code.
    display.savefig(tmp_path / "out.png")
    display.close()
@pytest.mark.parametrize("cut_coords", [2, 4])
def test_mosaic_slicer_integer_cut_coords(cut_coords, img):
    """An integer ``cut_coords`` yields that many cuts in each direction."""
    mosaic = MosaicSlicer.init_with_figure(img=img, cut_coords=cut_coords)
    mosaic.add_overlay(img, cmap="gray", colorbar=True)
    mosaic.title("mosaic mode")
    for direction in "xyz":
        assert direction in mosaic.cut_coords
        assert len(mosaic.cut_coords[direction]) == cut_coords
    mosaic.close()
@pytest.mark.parametrize("cut_coords", [(4, 5, 2), (1, 1, 1)])
def test_mosaic_slicer_tuple_cut_coords(cut_coords, img):
    """A 3-tuple ``cut_coords`` gives per-direction cut counts (x, y, z)."""
    mosaic = MosaicSlicer.init_with_figure(img=img, cut_coords=cut_coords)
    mosaic.add_overlay(img, cmap="gray", colorbar=True)
    mosaic.title("Showing mosaic mode")
    for direction, n_cuts in zip("xyz", cut_coords):
        assert len(mosaic.cut_coords[direction]) == n_cuts
    mosaic.close()
@pytest.mark.parametrize("empty_img", [None, False])
@pytest.mark.parametrize("cut_coords", [None, 5, (1, 1, 1)])
def test_mosaic_slicer_img_none_false(cut_coords, empty_img, img):
    """Tests for MosaicSlicer when img is ``None`` or ``False`` \
    while initializing the figure.

    Both "no image" values are exercised; the real image is only added
    afterwards as an overlay.
    """
    slicer = MosaicSlicer.init_with_figure(
        img=empty_img, cut_coords=cut_coords
    )
    slicer.add_overlay(img, cmap="gray", colorbar=True)
    slicer.close()
@pytest.mark.parametrize("cut_coords", [(5, 4), (1, 2, 3, 4)])
def test_mosaic_slicer_wrong_inputs(cut_coords):
    """Tests that providing wrong inputs raises a ``ValueError``."""
    with pytest.raises(
        ValueError,
        match=(
            "The number cut_coords passed does not "
            "match the display_mode. Mosaic plotting "
            "expects tuple of length 3."
        ),
    ):
        MosaicSlicer.init_with_figure(img=None, cut_coords=cut_coords)
    # NOTE(review): only ``init_with_figure`` is checked here. A statement
    # placed after it inside the same ``pytest.raises`` block would never
    # run. If the plain ``MosaicSlicer(...)`` constructor should also be
    # covered, confirm which exception it raises and give it its own
    # ``pytest.raises`` block.
@pytest.fixture
def expected_cuts(cut_coords):
    """Cuts that ``test_demo_mosaic_slicer`` expects for ``cut_coords``."""
    if cut_coords == (1, 1, 1):
        expected = {"x": [-40.0], "y": [-30.0], "z": [-30.0]}
    elif cut_coords == 5:
        expected = {
            "x": [-40.0, -20.0, 0.0, 20.0, 40.0],
            "y": [-30.0, -15.0, 0.0, 15.0, 30.0],
            "z": [-30.0, -3.75, 22.5, 48.75, 75.0],
        }
    else:
        # An explicit dict of cuts is expected back unchanged.
        expected = {"x": [10, 20], "y": [30, 40], "z": [15, 16]}
    return expected
@pytest.mark.parametrize(
    "cut_coords", [(1, 1, 1), 5, {"x": [10, 20], "y": [30, 40], "z": [15, 16]}]
)
def test_demo_mosaic_slicer(cut_coords, img, expected_cuts):
    """Check the cuts computed by the MosaicSlicer constructor."""
    mosaic = MosaicSlicer(cut_coords=cut_coords)
    mosaic.add_overlay(img, cmap="gray")
    computed = mosaic.cut_coords
    assert computed == expected_cuts
    mosaic.close()
@pytest.mark.parametrize("projector", PROJECTORS)
def test_projectors_basic(projector, img, tmp_path):
    """Build each projector from a figure, overlay an image and save it."""
    display = projector.init_with_figure(img=img)
    display.add_overlay(img, cmap="gray")
    display.savefig(tmp_path / "out.png")
    display.close()
def test_contour_fillings_levels_in_add_contours(img):
    """Exercise ``OrthoSlicer.add_contours`` with several ``levels`` setups.

    Every call uses ``filled=True``; levels should be at least 2.
    """
    slicer = OrthoSlicer(cut_coords=(0, 0, 0))
    contour_kwargs = [
        # A single level: the upper level is forced to be inf.
        {"colors": "r", "alpha": 0.2, "levels": [0.0]},
        # Two increasing levels: appending inf is skipped.
        {"colors": "b", "alpha": 0.1, "levels": [0.0, 0.2]},
        # No colors/alpha: matplotlib chooses the defaults.
        {"levels": [0.0, 0.2]},
        # Single level again. vmin is not needed but is added because of
        # a matplotlib 3.8.0rc1 bug, see
        # https://github.com/matplotlib/matplotlib/issues/26531
        {"levels": [0.0], "vmin": 0.0},
        # No levels: matplotlib's default levels are used.
        {},
    ]
    for kwargs in contour_kwargs:
        slicer.add_contours(img, filled=True, **kwargs)
    slicer.close()
def test_user_given_cmap_with_colorbar(img):
    """A colormap given by name works together with a colorbar."""
    slicer = OrthoSlicer(cut_coords=(0, 0, 0))
    slicer.add_overlay(img, cmap="Paired", colorbar=True)
    slicer.close()
@pytest.mark.parametrize("display", [OrthoSlicer, LYRZProjector])
def test_data_complete_mask(affine_eye, display):
    """Overlay an image whose data ends up completely masked.

    Guards the patch for a plotting issue that appeared with matplotlib
    2.1.0 (see similar matplotlib issue #9280).
    """
    # All-zero data, so the overlay is entirely masked.
    empty_img = Nifti1Image(np.zeros((10, 20, 30)), affine_eye)
    n_cuts = 3 if display is OrthoSlicer else 4
    instance = display(cut_coords=(0,) * n_cuts)
    instance.add_overlay(empty_img)
    instance.close()
def test_add_markers_cut_coords_is_none():
    """``add_markers`` works when every cut coordinate is ``None``.

    This is the configuration used to place coords on a glass brain.
    """
    slicer = OrthoSlicer(cut_coords=(None, None, None))
    marker_coords = [(0, 0, 2)]
    slicer.add_markers(marker_coords)
    slicer.close()
def test_annotations():
    """Exercise ``display.annotate()``, including scale-bar keywords."""
    slicer = OrthoSlicer(cut_coords=(None, None, None))
    # Left/right labels only, without position text.
    slicer.annotate(size=10, left_right=True, positions=False)
    # Scale bar with explicit size, units, location and frame.
    scalebar_options = {
        "scalebar": True,
        "scale_size": 2.5,
        "scale_units": "cm",
        "scale_loc": 3,
        "frameon": True,
    }
    slicer.annotate(
        size=12, left_right=False, positions=False, **scalebar_options
    )
    slicer.close()
def test_position_annotation_with_decimals():
    """Position annotations can be rendered with two decimals."""
    slicer = OrthoSlicer(cut_coords=(0, 0, 0))
    slicer.annotate(positions=True, decimals=2)
    slicer.close()
@pytest.mark.parametrize("node_color", ["red", ["red", "blue"]])
def test_add_graph_with_node_color_as_string(node_color):
    """``add_graph`` accepts a single color or one color per node."""
    projector = LZRYProjector(cut_coords=(0, 0, 0, 0))
    # Two nodes joined by a single edge of weight 3.
    adjacency = np.array([[0, 3], [3, 0]])
    coords = [[-53.60, -62.80, 36.64], [23.87, 0.31, 69.42]]
    projector.add_graph(adjacency, coords, node_color=node_color)
    projector.close()
@pytest.mark.parametrize(
    "threshold,vmin,vmax,expected_results",
    [
        (None, None, None, [[-2, -1, 0], [0, 1, 2]]),
        (0.5, None, None, [[-2, -1, np.nan], [np.nan, 1, 2]]),
        (1, 0, None, [[np.nan, np.nan, np.nan], [np.nan, np.nan, 2]]),
        (1, None, 1, [[-2, np.nan, np.nan], [np.nan, np.nan, np.nan]]),
        (0, 0, 0, [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]]),
    ],
)
def test_threshold(threshold, vmin, vmax, expected_results):
    """Compare ``OrthoSlicer._threshold`` with hand-computed results.

    NaN entries in ``expected_results`` mark values expected to be masked.
    """
    data = np.array([[-2, -1, 0], [0, 1, 2]], dtype=float)
    expected = np.ma.masked_invalid(expected_results)
    thresholded = OrthoSlicer._threshold(data, threshold, vmin, vmax)
    assert np.ma.allequal(thresholded, expected)
@pytest.mark.parametrize("transparency", [None, 0, 0.5, 1])
@pytest.mark.parametrize("display,name", zip(SLICERS, SLICER_KEYS))
def test_display_slicers_transparency(
    display, img, name, cut_coords, transparency
):
    """Test several valid transparency values.

    Also make sure warning is thrown that alpha value is overridden.
    The display is closed at the end, like in the other display tests,
    so figures do not accumulate across parametrized runs.
    """
    display = display(cut_coords=cut_coords)
    with pytest.warns(UserWarning, match="Overriding with"):
        display.add_overlay(
            img, cmap=plt.cm.gray, transparency=transparency, alpha=0.5
        )
    display.title(f"display mode is {name}")
    display.close()
@pytest.mark.parametrize("transparency", [-2, 10])
@pytest.mark.parametrize("display,name", zip(SLICERS, SLICER_KEYS))
def test_display_slicers_transparency_warning(
    display, img, name, cut_coords, transparency
):
    """Test several invalid transparency values throw warnings.

    The display is closed at the end so figures do not accumulate.
    """
    display = display(cut_coords=cut_coords)
    with pytest.warns(UserWarning, match="Setting it to"):
        display.add_overlay(img, cmap=plt.cm.gray, transparency=transparency)
    display.title(f"display mode is {name}")
    display.close()
@pytest.mark.parametrize("transparency", [None, 0, 0.5, 1])
@pytest.mark.parametrize("display,name", zip(PROJECTORS, PROJECTOR_KEYS))
def test_display_projectors_transparency(
    display, img, name, cut_coords, transparency
):
    """Test several valid transparency values.

    Also make sure warning is thrown that alpha value is overridden.
    The display is closed at the end, like in the other display tests,
    so figures do not accumulate across parametrized runs.
    """
    display = display(cut_coords=cut_coords)
    with pytest.warns(UserWarning, match="Overriding with"):
        display.add_overlay(
            img, cmap=plt.cm.gray, transparency=transparency, alpha=0.5
        )
    display.title(f"display mode is {name}")
    display.close()
@pytest.mark.parametrize("transparency", [-2, 10])
@pytest.mark.parametrize("display,name", zip(PROJECTORS, PROJECTOR_KEYS))
def test_display_projectors_transparency_warning(
    display, img, name, cut_coords, transparency
):
    """Test several invalid transparency values throw warnings.

    The display is closed at the end so figures do not accumulate.
    """
    display = display(cut_coords=cut_coords)
    with pytest.warns(UserWarning, match="Setting it to"):
        display.add_overlay(img, cmap=plt.cm.gray, transparency=transparency)
    display.title(f"display mode is {name}")
    display.close()