Coverage for nilearn/plotting/tests/test_img_comparisons.py: 0%
95 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1"""Tests for nilearn.plotting.img_comparison."""
3import matplotlib.pyplot as plt
4import numpy as np
5import pytest
6from nibabel import Nifti1Image
8from nilearn._utils.data_gen import generate_fake_fmri
9from nilearn.conftest import _affine_mni, _img_mask_mni, _make_surface_mask
10from nilearn.image import iter_img
11from nilearn.maskers import NiftiMasker, SurfaceMasker
12from nilearn.plotting import plot_bland_altman, plot_img_comparison
14# ruff: noqa: ARG001
def _mask():
    """Build a small binary mask image used to parametrize masker tests.

    The volume has shape (7, 7, 3); voxels in ``[1:-1, 2:-1, 1:]`` are set
    to 1 and all others to 0. The affine comes from ``_affine_mni()``.
    """
    mask_data = np.zeros((7, 7, 3))
    mask_data[1:-1, 2:-1, 1:] = 1
    return Nifti1Image(mask_data, _affine_mni())
def test_deprecation_function_moved(matplotlib_pyplot, img_3d_mni):
    """Importing plot_img_comparison from img_plotting warns it has moved."""
    from nilearn.plotting.img_plotting import (
        plot_img_comparison as deprecated_plot_img_comparison,
    )

    with pytest.warns(DeprecationWarning, match="moved"):
        deprecated_plot_img_comparison(
            img_3d_mni, img_3d_mni, plot_hist=False
        )
@pytest.mark.parametrize(
    "masker",
    [
        None,
        _mask(),
        NiftiMasker(mask_img=_img_mask_mni()),
        NiftiMasker(mask_img=_img_mask_mni()).fit(),
    ],
)
def test_plot_img_comparison_masker(matplotlib_pyplot, img_3d_mni, masker):
    """Run plot_img_comparison on volumes with each kind of ``masker``.

    Covers no masker, a mask image, an unfitted NiftiMasker and an
    already-fitted NiftiMasker.
    """
    ref_and_src = (img_3d_mni, img_3d_mni)
    plot_img_comparison(*ref_and_src, masker=masker, plot_hist=False)
@pytest.mark.parametrize(
    "masker",
    [
        None,
        _make_surface_mask(),
        SurfaceMasker(mask_img=_make_surface_mask()),
        SurfaceMasker(mask_img=_make_surface_mask()).fit(),
    ],
)
def test_plot_img_comparison_surface(matplotlib_pyplot, surf_img_1d, masker):
    """Compare a single surface image against a list of two surface images.

    Each kind of surface ``masker`` is exercised: none, a mask image,
    an unfitted SurfaceMasker and an already-fitted SurfaceMasker.
    """
    src_imgs = [surf_img_1d] * 2
    plot_img_comparison(
        surf_img_1d,
        src_imgs,
        masker=masker,
        plot_hist=False,
    )
def test_plot_img_comparison_error(surf_img_1d, img_3d_mni):
    """Raise TypeError for inputs that are not images or lists of images.

    A set of images is rejected, and so is mixing a surface image with a
    volume image.
    """
    invalid_calls = [
        ((surf_img_1d, {surf_img_1d}), "must both be list of 3D"),
        ((surf_img_1d, img_3d_mni), "must both be list of only"),
    ]
    for args, expected_msg in invalid_calls:
        with pytest.raises(TypeError, match=expected_msg):
            plot_img_comparison(*args)
@pytest.mark.timeout(0)
def test_plot_img_comparison(matplotlib_pyplot, rng, tmp_path):
    """Check returned correlations and the content of the provided axes.

    The first target image is replaced by the first query image, so the
    first correlation must be 1. Every image pair is drawn on the same
    two axes passed through ``axes``.
    """
    _, axes = plt.subplots(2, 1)
    axes = axes.ravel()

    length = 2

    query_images, mask_img = generate_fake_fmri(
        random_state=rng, shape=(2, 3, 4), length=length
    )
    # plot_img_comparison doesn't handle 4d images ATM
    query_images = list(iter_img(query_images))

    # NOTE(review): targets have a different shape from the mask; this
    # presumably relies on the masker resampling them — confirm.
    target_images, _ = generate_fake_fmri(
        random_state=rng, shape=(4, 5, 6), length=length
    )
    target_images = list(iter_img(target_images))
    target_images[0] = query_images[0]

    masker = NiftiMasker(mask_img).fit()

    correlations = plot_img_comparison(
        target_images,
        query_images,
        masker,
        axes=axes,
        src_label="query",
        output_dir=tmp_path,
        colorbar=False,
    )

    assert len(correlations) == len(query_images)
    assert correlations[0] == pytest.approx(1.0)

    ax_0, ax_1 = axes

    # One collection per image pair (``length`` pairs).
    assert len(ax_0.collections) == length
    # Only check that the first pair actually drew something.
    # NOTE(review): the former assertion
    # ``len(edgecolors == n_voxels)`` had its closing parenthesis misplaced
    # and never compared counts; confirm the expected element count before
    # tightening this check.
    assert len(ax_0.collections[0].get_edgecolors()) > 0
    assert ax_0.get_ylabel() == "query"
    # ref_label is not passed; "image set 1" is presumably its default.
    assert ax_0.get_xlabel() == "image set 1"

    # One dashed line per image pair.
    assert len(ax_0.lines) == length
    assert ax_0.lines[0].get_linestyle() == "--"

    # Two histograms per image pair, each with ``gridsize`` bars.
    assert ax_1.get_title() == "Histogram of imgs values"
    gridsize = 100
    assert len(ax_1.patches) == length * 2 * gridsize
@pytest.mark.timeout(0)
def test_plot_img_comparison_without_plot(matplotlib_pyplot, rng):
    """Correlations must be identical with and without histogram plotting.

    The same image sets are compared once with ``plot_hist=True`` and once
    with ``plot_hist=False``; both calls must return the same correlations.
    """
    # No axes are created here: neither call below receives ``axes``.
    query_images, mask_img = generate_fake_fmri(
        random_state=rng, shape=(2, 3, 4), length=2
    )
    # plot_img_comparison doesn't handle 4d images ATM
    query_images = list(iter_img(query_images))

    target_images, _ = generate_fake_fmri(
        random_state=rng, shape=(2, 3, 4), length=2
    )
    target_images = list(iter_img(target_images))
    target_images[0] = query_images[0]

    masker = NiftiMasker(mask_img).fit()

    correlations_with_hist = plot_img_comparison(
        target_images, query_images, masker, plot_hist=True, colorbar=False
    )

    correlations_without_hist = plot_img_comparison(
        target_images, query_images, masker, plot_hist=False
    )

    assert np.allclose(correlations_with_hist, correlations_without_hist)
@pytest.mark.parametrize(
    "masker",
    [
        None,
        _mask(),
        NiftiMasker(mask_img=_img_mask_mni()),
        NiftiMasker(mask_img=_img_mask_mni()).fit(),
    ],
)
def test_plot_bland_altman(
    matplotlib_pyplot, tmp_path, img_3d_mni, img_3d_mni_as_file, masker
):
    """Test Bland-Altman plot with each kind of volume ``masker``.

    Also pass explicit values for the labels, title, grid size, axis
    limits (``lims``), colorbar and output file, and check that the
    output file is written.

    The two inputs are the same image, given once as a Nifti image and
    once as a path, so both input kinds are covered.
    """
    output_file = tmp_path / "spam.jpg"

    plot_bland_altman(
        img_3d_mni,
        img_3d_mni_as_file,
        masker=masker,
        ref_label="image set 1",
        src_label="image set 2",
        title="cheese shop",
        gridsize=10,
        lims=[-1, 5, -2, 3],
        colorbar=False,
        output_file=output_file,
    )

    assert output_file.is_file()
@pytest.mark.parametrize(
    "masker",
    [
        None,
        _make_surface_mask(),
        SurfaceMasker(mask_img=_make_surface_mask()),
        SurfaceMasker(mask_img=_make_surface_mask()).fit(),
    ],
)
def test_plot_bland_altman_surface(matplotlib_pyplot, surf_img_1d, masker):
    """Bland-Altman plot of a surface image against itself.

    Each kind of surface ``masker`` is exercised, and ``gridsize`` is
    given as a tuple to cover that input form.
    """
    grid_shape = (10, 80)
    plot_bland_altman(
        surf_img_1d,
        surf_img_1d,
        masker=masker,
        gridsize=grid_shape,
    )
@pytest.mark.timeout(0)
def test_plot_bland_altman_errors(
    surf_img_1d, surf_mask_1d, img_3d_rand_eye, img_3d_ones_eye
):
    """Check the TypeErrors raised by plot_bland_altman on invalid input.

    - both inputs must be niimg-like, or both must be surface images
    - the masker must be a valid type, compatible with the image kind
    - ``lims`` must hold exactly 4 values, none of them 0
    """
    # Inputs that are not images, or that mix surface and volume data.
    input_msg = "'ref_img' and 'src_img' must both be"
    for ref, src in [(1, "foo"), (surf_img_1d, img_3d_rand_eye)]:
        with pytest.raises(TypeError, match=input_msg):
            plot_bland_altman(ref, src)

    with pytest.raises(TypeError, match="Mask should be of type:"):
        plot_bland_altman(img_3d_rand_eye, img_3d_rand_eye, masker=1)

    # A surface masker/mask with volume images, or a volume masker/mask
    # with surface images.
    incompatible_msg = "Mask and images to fit must be of compatible types."
    incompatible_pairs = [
        (img_3d_rand_eye, SurfaceMasker()),
        (img_3d_rand_eye, surf_mask_1d),
        (surf_img_1d, NiftiMasker()),
        (surf_img_1d, img_3d_ones_eye),
    ]
    for img, bad_masker in incompatible_pairs:
        with pytest.raises(TypeError, match=incompatible_msg):
            plot_bland_altman(img, img, masker=bad_masker)

    bad_lims_cases = [
        ([-1], "'lims' must be a list or tuple of length == 4"),
        ([0, 1, -2, 0], "with all values different from 0."),
    ]
    for bad_lims, lims_msg in bad_lims_cases:
        with pytest.raises(TypeError, match=lims_msg):
            plot_bland_altman(img_3d_rand_eye, img_3d_rand_eye, lims=bad_lims)