Coverage for nilearn/plotting/tests/test_img_comparisons.py: 0%

95 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1"""Tests for nilearn.plotting.img_comparison.""" 

2 

3import matplotlib.pyplot as plt 

4import numpy as np 

5import pytest 

6from nibabel import Nifti1Image 

7 

8from nilearn._utils.data_gen import generate_fake_fmri 

9from nilearn.conftest import _affine_mni, _img_mask_mni, _make_surface_mask 

10from nilearn.image import iter_img 

11from nilearn.maskers import NiftiMasker, SurfaceMasker 

12from nilearn.plotting import plot_bland_altman, plot_img_comparison 

13 

14# ruff: noqa: ARG001 

15 

16 

def _mask():
    """Build a small binary 3D mask image using the MNI test affine.

    The image has shape (7, 7, 3): voxels in the slab
    ``[1:-1, 2:-1, 1:]`` are set to 1 and all other voxels are 0.
    """
    mask_data = np.zeros((7, 7, 3))
    mask_data[1:-1, 2:-1, 1:] = 1
    return Nifti1Image(mask_data, _affine_mni())

22 

23 

def test_deprecation_function_moved(matplotlib_pyplot, img_3d_mni):
    """Calling plot_img_comparison from img_plotting must warn it moved.

    The old import location should still work but emit a
    DeprecationWarning whose message mentions "moved".
    """
    from nilearn.plotting.img_plotting import (
        plot_img_comparison as deprecated_plot_img_comparison,
    )

    with pytest.warns(DeprecationWarning, match="moved"):
        deprecated_plot_img_comparison(
            img_3d_mni, img_3d_mni, plot_hist=False
        )

33 

34 

@pytest.mark.parametrize(
    "masker",
    [
        None,
        _mask(),
        NiftiMasker(mask_img=_img_mask_mni()),
        NiftiMasker(mask_img=_img_mask_mni()).fit(),
    ],
)
def test_plot_img_comparison_masker(matplotlib_pyplot, img_3d_mni, masker):
    """Check plot_img_comparison accepts each supported masker form.

    ``masker`` is either None, a mask image, an unfitted NiftiMasker or a
    fitted NiftiMasker; comparing an image with itself must not raise.
    """
    ref_img = src_img = img_3d_mni
    plot_img_comparison(
        ref_img,
        src_img,
        masker=masker,
        plot_hist=False,
    )

52 

53 

@pytest.mark.parametrize(
    "masker",
    [
        None,
        _make_surface_mask(),
        SurfaceMasker(mask_img=_make_surface_mask()),
        SurfaceMasker(mask_img=_make_surface_mask()).fit(),
    ],
)
def test_plot_img_comparison_surface(matplotlib_pyplot, surf_img_1d, masker):
    """Check plot_img_comparison on surface data.

    A single reference surface image is compared against a list of two
    surface images, for each supported form of surface masker.
    """
    src_imgs = [surf_img_1d] * 2
    plot_img_comparison(
        surf_img_1d,
        src_imgs,
        masker=masker,
        plot_hist=False,
    )

68 

69 

def test_plot_img_comparison_error(surf_img_1d, img_3d_mni):
    """Raise TypeError for inputs that are not comparable images.

    - a set is neither an image nor a list of images;
    - a surface image cannot be compared with a volume image.
    """
    with pytest.raises(TypeError, match="must both be list of 3D"):
        plot_img_comparison(
            surf_img_1d,
            {surf_img_1d},
        )

    with pytest.raises(TypeError, match="must both be list of only"):
        plot_img_comparison(
            surf_img_1d,
            img_3d_mni,
        )

77 

78 

@pytest.mark.timeout(0)
def test_plot_img_comparison(matplotlib_pyplot, rng, tmp_path):
    """Check correlations and figure content of plot_img_comparison.

    Two lists of 3D images are compared on a pair of user-supplied axes.
    The first target image is replaced by the first query image, so the
    first returned correlation must be 1.
    """
    _, axes = plt.subplots(2, 1)
    axes = axes.ravel()

    length = 2

    query_images, mask_img = generate_fake_fmri(
        random_state=rng, shape=(2, 3, 4), length=length
    )
    # plot_img_comparison doesn't handle 4d images ATM
    query_images = list(iter_img(query_images))

    target_images, _ = generate_fake_fmri(
        random_state=rng, shape=(4, 5, 6), length=length
    )
    target_images = list(iter_img(target_images))
    # identical first pair -> perfect correlation expected below
    target_images[0] = query_images[0]

    masker = NiftiMasker(mask_img).fit()

    correlations = plot_img_comparison(
        target_images,
        query_images,
        masker,
        axes=axes,
        src_label="query",
        output_dir=tmp_path,
        colorbar=False,
    )

    assert len(correlations) == len(query_images)
    assert correlations[0] == pytest.approx(1.0)

    # one collection per image pair, all drawn on the same first axes
    ax_0, ax_1 = axes
    assert len(ax_0.collections) == length
    # The first collection must have drawn elements.
    # NOTE(review): the previous form `len(get_edgecolors() == n_voxels)`
    # only measured the length of a boolean array, i.e. it checked exactly
    # this non-emptiness and never compared against the voxel count.
    assert len(ax_0.collections[0].get_edgecolors()) > 0
    assert ax_0.get_ylabel() == "query"
    assert ax_0.get_xlabel() == "image set 1"

    # one dashed line per image pair
    assert len(ax_0.lines) == length
    assert ax_0.lines[0].get_linestyle() == "--"
    assert ax_1.get_title() == "Histogram of imgs values"
    gridsize = 100
    # presumably one reference and one source histogram of `gridsize`
    # bars per image pair
    assert len(ax_1.patches) == length * 2 * gridsize

130 

131 

@pytest.mark.timeout(0)
def test_plot_img_comparison_without_plot(matplotlib_pyplot, rng):
    """Check that plot_hist does not change the returned correlations.

    plot_img_comparison is called on the same inputs once with
    ``plot_hist=True`` and once with ``plot_hist=False``; both calls must
    return the same correlation values.
    """
    query_images, mask_img = generate_fake_fmri(
        random_state=rng, shape=(2, 3, 4), length=2
    )
    # plot_img_comparison doesn't handle 4d images ATM
    query_images = list(iter_img(query_images))

    target_images, _ = generate_fake_fmri(
        random_state=rng, shape=(2, 3, 4), length=2
    )
    target_images = list(iter_img(target_images))
    target_images[0] = query_images[0]

    masker = NiftiMasker(mask_img).fit()

    correlations_with_plot = plot_img_comparison(
        target_images, query_images, masker, plot_hist=True, colorbar=False
    )

    correlations_without_plot = plot_img_comparison(
        target_images, query_images, masker, plot_hist=False
    )

    assert np.allclose(correlations_with_plot, correlations_without_plot)

161 

162 

@pytest.mark.parametrize(
    "masker",
    [
        None,
        _mask(),
        NiftiMasker(mask_img=_img_mask_mni()),
        NiftiMasker(mask_img=_img_mask_mni()).fit(),
    ],
)
def test_plot_bland_altman(
    matplotlib_pyplot, tmp_path, img_3d_mni, img_3d_mni_as_file, masker
):
    """Check the Bland-Altman plot with different masker values.

    Also exercises non-default values for the labels, title, grid size,
    axis limits (``lims``) and ``output_file``, and checks that the input
    images may be given either as a Nifti image or as a file path.
    """
    output_file = tmp_path / "spam.jpg"

    plot_bland_altman(
        img_3d_mni,
        img_3d_mni_as_file,
        masker=masker,
        ref_label="image set 1",
        src_label="image set 2",
        title="cheese shop",
        gridsize=10,
        output_file=output_file,
        lims=[-1, 5, -2, 3],
        colorbar=False,
    )

    assert output_file.is_file()

199 

200 

@pytest.mark.parametrize(
    "masker",
    [
        None,
        _make_surface_mask(),
        SurfaceMasker(mask_img=_make_surface_mask()),
        SurfaceMasker(mask_img=_make_surface_mask()).fit(),
    ],
)
def test_plot_bland_altman_surface(matplotlib_pyplot, surf_img_1d, masker):
    """Check the Bland-Altman plot on a pair of surface images.

    ``gridsize`` is passed as a tuple to cover that input form.
    """
    tuple_gridsize = (10, 80)
    plot_bland_altman(
        surf_img_1d,
        surf_img_1d,
        masker=masker,
        gridsize=tuple_gridsize,
    )

217 ) 

218 

219 

@pytest.mark.timeout(0)
def test_plot_bland_altman_errors(
    surf_img_1d, surf_mask_1d, img_3d_rand_eye, img_3d_ones_eye
):
    """Check that plot_bland_altman rejects invalid inputs with TypeError.

    Covered cases:

    - reference and source must both be niimg-like or both be surface;
    - the masker must be of a valid type;
    - the masker must match the image type (volume vs surface);
    - ``lims`` must have length 4 with all values different from 0.
    """
    same_kind_msg = "'ref_img' and 'src_img' must both be"
    for ref, src in [(1, "foo"), (surf_img_1d, img_3d_rand_eye)]:
        with pytest.raises(TypeError, match=same_kind_msg):
            plot_bland_altman(ref, src)

    with pytest.raises(TypeError, match="Mask should be of type:"):
        plot_bland_altman(img_3d_rand_eye, img_3d_rand_eye, masker=1)

    compat_msg = "Mask and images to fit must be of compatible types."
    # volume images paired with a surface masker / surface mask
    with pytest.raises(TypeError, match=compat_msg):
        plot_bland_altman(
            img_3d_rand_eye, img_3d_rand_eye, masker=SurfaceMasker()
        )
    with pytest.raises(TypeError, match=compat_msg):
        plot_bland_altman(
            img_3d_rand_eye, img_3d_rand_eye, masker=surf_mask_1d
        )
    # surface images paired with a volume masker / volume mask
    with pytest.raises(TypeError, match=compat_msg):
        plot_bland_altman(surf_img_1d, surf_img_1d, masker=NiftiMasker())
    with pytest.raises(TypeError, match=compat_msg):
        plot_bland_altman(surf_img_1d, surf_img_1d, masker=img_3d_ones_eye)

    with pytest.raises(
        TypeError, match="'lims' must be a list or tuple of length == 4"
    ):
        plot_bland_altman(img_3d_rand_eye, img_3d_rand_eye, lims=[-1])

    with pytest.raises(TypeError, match="with all values different from 0."):
        plot_bland_altman(
            img_3d_rand_eye, img_3d_rand_eye, lims=[0, 1, -2, 0]
        )