Coverage for nilearn/plotting/tests/test_baseline_comparisons.py: 0%
180 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1"""
2Test if public plotting functions' output has changed.
4See the maintenance page of our documentation for more information
5https://nilearn.github.io/dev/maintenance.html#generating-new-baseline-figures-for-plotting-tests
6"""
8import numpy as np
9import pandas as pd
10import pytest
11from matplotlib import pyplot as plt
13from nilearn.datasets import (
14 load_fsaverage_data,
15 load_mni152_template,
16 load_sample_motor_activation_image,
17)
18from nilearn.glm.first_level.design_matrix import (
19 make_first_level_design_matrix,
20)
21from nilearn.glm.tests._testing import modulated_event_paradigm
22from nilearn.plotting import (
23 plot_anat,
24 plot_bland_altman,
25 plot_carpet,
26 plot_connectome,
27 plot_contrast_matrix,
28 plot_design_matrix,
29 plot_design_matrix_correlation,
30 plot_epi,
31 plot_event,
32 plot_glass_brain,
33 plot_img,
34 plot_img_comparison,
35 plot_matrix,
36 plot_prob_atlas,
37 plot_roi,
38 plot_stat_map,
39 plot_surf,
40 plot_surf_roi,
41 plot_surf_stat_map,
42)
43from nilearn.plotting.img_plotting import MNI152TEMPLATE
# These collections feed ``pytest.mark.parametrize``. They are ordered tuples
# rather than sets: a set of functions iterates in an order derived from the
# functions' default (identity-based) hashes, which can differ between
# interpreter runs and between pytest-xdist workers. pytest-xdist requires
# every worker to collect the same tests in the same order.
PLOTTING_FUNCS_3D = (
    plot_img,
    plot_anat,
    plot_stat_map,
    plot_roi,
    plot_epi,
    plot_glass_brain,
)

PLOTTING_FUNCS_4D = (plot_prob_atlas, plot_carpet)

SURFACE_FUNCS = (
    plot_surf,
    plot_surf_stat_map,
    plot_surf_roi,
)
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("plot_func", PLOTTING_FUNCS_3D)
def test_plot_functions_black_bg(plot_func, img_3d_mni):
    """Compare 3D plotting output rendered on a black background.

    The default (``black_bg=False``) is exercised by the other tests,
    so only the non-default value is checked here.
    """
    display = plot_func(img_3d_mni, black_bg=True)
    return display
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("plot_func", PLOTTING_FUNCS_3D)
def test_plot_functions_title(plot_func, img_3d_mni):
    """Compare 3D plotting output when a title is given.

    ``title=None`` is the default and is exercised by the other tests.
    """
    display = plot_func(img_3d_mni, title="foo")
    return display
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("plot_func", PLOTTING_FUNCS_3D)
def test_plot_functions_annotate(plot_func, img_3d_mni):
    """Compare 3D plotting output with annotations switched off.

    ``annotate=True`` is the default and is exercised by the other tests.
    """
    display = plot_func(img_3d_mni, annotate=False)
    return display
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize(
    "display_mode", ("x", "y", "z", "yx", "xz", "yz", "ortho")
)
def test_plot_stat_map_display_mode(display_mode):
    """Compare plot_stat_map output for several ``display_mode`` values.

    Only one plotting function is exercised to keep the suite fast.
    """
    stat_img = load_sample_motor_activation_image()
    return plot_stat_map(stat_img, display_mode=display_mode)
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("plot_func", PLOTTING_FUNCS_3D)
def test_plot_functions_no_colorbar(plot_func, img_3d_mni):
    """Compare 3D plotting output with the colorbar hidden.

    ``colorbar=True`` is the default and is exercised by the other tests.
    """
    display = plot_func(img_3d_mni, colorbar=False)
    return display
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("plot_func", PLOTTING_FUNCS_3D)
def test_plot_functions_colorbar_ticks(plot_func, img_3d_mni):
    """Compare 3D plotting output with a custom colorbar tick format."""
    display = plot_func(img_3d_mni, cbar_tick_format="%f")
    return display
@pytest.mark.mpl_image_compare(tolerance=5)
@pytest.mark.parametrize("plot_func", PLOTTING_FUNCS_3D)
@pytest.mark.parametrize("vmin", (-1, 1))
def test_plot_functions_vmin(plot_func, vmin):
    """Compare 3D plotting output for different ``vmin`` values."""
    stat_img = load_sample_motor_activation_image()
    return plot_func(stat_img, vmin=vmin)
@pytest.mark.mpl_image_compare(tolerance=5)
@pytest.mark.parametrize("plot_func", PLOTTING_FUNCS_3D)
@pytest.mark.parametrize("vmax", (2, 3))
def test_plot_functions_vmax(plot_func, vmax):
    """Compare 3D plotting output for different ``vmax`` values."""
    stat_img = load_sample_motor_activation_image()
    return plot_func(stat_img, vmax=vmax)
@pytest.mark.mpl_image_compare(tolerance=5)
@pytest.mark.parametrize("plotting_func", PLOTTING_FUNCS_3D)
def test_plotting_functions_radiological_view(plotting_func):
    """Compare 3D plotting output in radiological orientation.

    ``radiological=False`` is the default and is exercised by the other
    tests. Besides the image comparison, the ``"y"`` axes of the returned
    display must report the requested orientation.
    """
    stat_img = load_sample_motor_activation_image()
    display = plotting_func(stat_img, radiological=True)
    assert display.axes.get("y").radiological is True
    return display
@pytest.mark.mpl_image_compare
def test_plot_carpet_default_params(img_4d_mni, img_3d_ones_mni):
    """Compare plot_carpet output on a 4D image, given only a mask."""
    display = plot_carpet(img_4d_mni, mask_img=img_3d_ones_mni)
    return display
@pytest.mark.timeout(0)
@pytest.mark.mpl_image_compare
def test_plot_prob_atlas_default_params(img_3d_mni, img_4d_mni):
    """Compare plot_prob_atlas output, given only a background image.

    The timeout is disabled (``timeout(0)``) for this test.
    """
    display = plot_prob_atlas(img_4d_mni, bg_img=img_3d_mni)
    return display
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("anat_img", (False, MNI152TEMPLATE))
def test_plot_anat_mni(anat_img):
    """Compare plot_anat output for ``anat_img=False`` and the MNI template."""
    display = plot_anat(anat_img=anat_img)
    return display
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("colorbar", (True, False))
def test_plot_connectome_colorbar(colorbar, adjacency, node_coords):
    """Compare plot_connectome output with and without a colorbar.

    All other parameters are left at their defaults.
    """
    display = plot_connectome(adjacency, node_coords, colorbar=colorbar)
    return display
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize(
    "node_color",
    (["green", "blue", "k", "cyan"], np.array(["red"]), ["red"], "green"),
)
def test_plot_connectome_node_colors(
    node_color, node_coords, adjacency, params_plot_connectome
):
    """Compare plot_connectome output for several ``node_color`` inputs.

    The inputs cover a per-node list, a one-element array, a one-element
    list and a single color string.
    """
    display = plot_connectome(
        adjacency, node_coords, node_color=node_color, **params_plot_connectome
    )
    return display
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("alpha", (0.0, 0.7, 1.0))
def test_plot_connectome_alpha(alpha, adjacency, node_coords):
    """Compare plot_connectome output across a range of ``alpha`` values."""
    display = plot_connectome(adjacency, node_coords, alpha=alpha)
    return display
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize(
    "display_mode",
    (
        "ortho", "x", "y", "z", "xz", "yx", "yz",
        "l", "r", "lr", "lzr", "lyr", "lzry", "lyrz",
    ),
)
def test_plot_connectome_display_mode(
    display_mode, node_coords, adjacency, params_plot_connectome
):
    """Compare plot_connectome output for every tested ``display_mode``."""
    display = plot_connectome(
        adjacency,
        node_coords,
        display_mode=display_mode,
        **params_plot_connectome,
    )
    return display
@pytest.mark.mpl_image_compare
def test_plot_connectome_node_and_edge_kwargs(adjacency, node_coords):
    """Compare plot_connectome output with custom node and edge styling.

    This covers ``node_kwargs``, ``edge_kwargs`` and ``edge_cmap`` together
    with a percentile edge threshold and per-node sizes.
    """
    styling = {
        "edge_threshold": "70%",
        "node_size": [10, 20, 30, 40],
        "node_color": np.zeros((4, 3)),
        "edge_cmap": "RdBu",
        "colorbar": True,
        "node_kwargs": {"marker": "v"},
        "edge_kwargs": {"linewidth": 4},
    }
    return plot_connectome(adjacency, node_coords, **styling)
266# ---------------------- surface plotting -------------------------------
@pytest.mark.mpl_image_compare(tolerance=5)
@pytest.mark.parametrize("plot_func", SURFACE_FUNCS)
@pytest.mark.parametrize(
    "view", ("anterior", "posterior", "dorsal", "ventral")
)
@pytest.mark.parametrize("hemi", ("left", "right", "both"))
def test_plot_surf_surface(plot_func, view, hemi):
    """Compare matplotlib surface plots across views and hemispheres."""
    fsaverage_data = load_fsaverage_data()
    options = {
        "engine": "matplotlib",
        "view": view,
        "hemi": hemi,
        "title": f"{view=}, {hemi=}",
    }
    return plot_func(fsaverage_data.mesh, fsaverage_data, **options)
@pytest.mark.mpl_image_compare(tolerance=5)
@pytest.mark.parametrize("plot_func", SURFACE_FUNCS)
@pytest.mark.parametrize("colorbar", (True, False))
@pytest.mark.parametrize("cbar_tick_format", ("auto", "%f"))
def test_plot_surf_surface_colorbar(plot_func, colorbar, cbar_tick_format):
    """Compare matplotlib surface plots with colorbar on/off and tick formats."""
    fsaverage_data = load_fsaverage_data()
    options = {
        "engine": "matplotlib",
        "colorbar": colorbar,
        "cbar_tick_format": cbar_tick_format,
    }
    return plot_func(fsaverage_data.mesh, fsaverage_data, **options)
310# ---------------------- design matrix plotting -------------------------------
@pytest.mark.mpl_image_compare
def test_plot_event_duration_0():
    """Compare plot_event output for a paradigm with zero-duration events."""
    paradigm = modulated_event_paradigm()
    return plot_event(paradigm)
@pytest.mark.mpl_image_compare
def test_plot_event_x_lim(rng):
    """Test that the x-axis limit extends past the end of the last event.

    Regression test for https://github.com/nilearn/nilearn/issues/4907
    """
    trial_types = ["foo", "bar", "baz"]
    # Size the random columns from trial_types so every DataFrame column
    # keeps the same length if the list of trial types is edited.
    n_events = len(trial_types)

    n_runs = 3

    events = [
        pd.DataFrame(
            {
                "trial_type": trial_types,
                "onset": rng.random((n_events,)) * 5,
                "duration": rng.uniform(size=(n_events,)) * 2 + 1,
            }
        )
        for _ in range(n_runs)
    ]

    return plot_event(events)
@pytest.fixture
def matrix_to_plot(rng):
    """Return a 50 x 50 random matrix, scaled by 10 and shifted by -5.

    Assuming ``rng.random`` draws from [0, 1), values lie in [-5, 5).
    """
    values = rng.random((50, 50))
    return 10 * values - 5
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("colorbar", (True, False))
def test_plot_matrix_colorbar(matrix_to_plot, colorbar):
    """Compare plot_matrix output with and without a colorbar."""
    figure = plot_matrix(matrix_to_plot, colorbar=colorbar).get_figure()
    return figure
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize(
    "labels", ([], np.array(list(map(str, range(50)))), None)
)
def test_plot_matrix_labels(matrix_to_plot, labels):
    """Compare plot_matrix output for empty, explicit and missing labels."""
    figure = plot_matrix(matrix_to_plot, labels=labels).get_figure()
    return figure
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("tri", ["full", "lower", "diag"])
def test_plot_matrix_grid(matrix_to_plot, tri):
    """Test plotting the full matrix or only its lower triangle.

    ``tri="full"`` draws the whole matrix. ``"lower"`` and ``"diag"`` are
    the lower-triangle variants; no upper-triangle option is tested.
    """
    ax = plot_matrix(matrix_to_plot, tri=tri)

    return ax.get_figure()
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("tri", ("full", "diag"))
def test_plot_design_matrix_correlation(tri):
    """Compare design-matrix correlation plots: full matrix vs lower half."""
    frame_times = np.linspace(0.0, 127.0, 128)
    design = make_first_level_design_matrix(
        frame_times, events=modulated_event_paradigm()
    )
    axes = plot_design_matrix_correlation(design, tri=tri)
    return axes.get_figure()
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("colorbar", (True, False))
def test_plot_design_matrix_correlation_colorbar(colorbar):
    """Compare design-matrix correlation plots with and without a colorbar."""
    frame_times = np.linspace(0.0, 127.0, 128)
    design = make_first_level_design_matrix(
        frame_times, events=modulated_event_paradigm()
    )
    axes = plot_design_matrix_correlation(design, colorbar=colorbar)
    return axes.get_figure()
@pytest.mark.mpl_image_compare
def test_plot_design_matrix():
    """Compare plot_design_matrix output for a cubic polynomial drift model."""
    frame_times = np.linspace(0.0, 127.0, 128)
    design = make_first_level_design_matrix(
        frame_times, drift_model="polynomial", drift_order=3
    )
    return plot_design_matrix(design).get_figure()
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize(
    "contrast",
    (np.array([[1, 0, 0, 1], [0, -2, 1, 0]]), np.array([1, 0, 0, -1])),
)
def test_plot_contrast_matrix(contrast):
    """Compare plot_contrast_matrix output for an F (2D) and a T (1D) contrast."""
    frame_times = np.linspace(0.0, 127.0, 128)
    design = make_first_level_design_matrix(
        frame_times, drift_model="polynomial", drift_order=3
    )
    return plot_contrast_matrix(contrast, design).get_figure()
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("colorbar", (True, False))
def test_plot_contrast_matrix_colorbar(colorbar):
    """Compare plot_contrast_matrix output with and without a colorbar."""
    frame_times = np.linspace(0.0, 127.0, 128)
    design = make_first_level_design_matrix(
        frame_times, drift_model="polynomial", drift_order=3
    )
    f_contrast = np.array([[1, 0, 0, 1], [0, -2, 1, 0]])
    axes = plot_contrast_matrix(f_contrast, design, colorbar=colorbar)
    return axes.get_figure()
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("fn", (plot_stat_map, plot_img, plot_glass_brain))
def test_plot_with_transparency(fn):
    """Compare plots drawn with a constant ``transparency`` of 0.5."""
    stat_img = load_sample_motor_activation_image()
    return fn(stat_img, transparency=0.5, cmap="cold_hot")
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("fn", (plot_stat_map, plot_img, plot_glass_brain))
@pytest.mark.parametrize("transparency_range", (None, [0, 2], [2, 4]))
def test_plot_with_transparency_range(fn, transparency_range):
    """Compare plots whose alpha layer comes from an image plus a range.

    The activation map also serves as its own transparency image.
    """
    motor_img = load_sample_motor_activation_image()
    return fn(
        motor_img,
        transparency=motor_img,
        transparency_range=transparency_range,
        cmap="cold_hot",
    )
# Ordered tuple (not a set) so the parametrized tests are collected in the
# same order on every run and on every pytest-xdist worker.
IMG_COMPARISON_FUNCS = (plot_img_comparison, plot_bland_altman)
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("plot_func", IMG_COMPARISON_FUNCS)
def test_img_comparison_default(plot_func):
    """Compare image-comparison plots produced with default arguments."""
    reference = load_mni152_template()
    plot_func(reference, load_sample_motor_activation_image())
    # plot_img_comparison does not return a figure, so grab the current one.
    return plt.gcf()
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("plot_func", IMG_COMPARISON_FUNCS)
@pytest.mark.parametrize("colorbar", (True, False))
def test_img_comparison_colorbar(plot_func, colorbar):
    """Compare image-comparison plots with and without a colorbar."""
    reference = load_mni152_template()
    plot_func(
        reference, load_sample_motor_activation_image(), colorbar=colorbar
    )
    # plot_img_comparison does not return a figure, so grab the current one.
    return plt.gcf()