Coverage for nilearn/maskers/tests/test_multi_nifti_labels_masker.py: 0%

155 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1"""Test the multi_nifti_labels_masker module.""" 

2 

3import numpy as np 

4import pytest 

5from nibabel import Nifti1Image 

6from numpy.testing import assert_almost_equal, assert_array_equal 

7from sklearn.utils.estimator_checks import parametrize_with_checks 

8 

9from nilearn._utils.data_gen import ( 

10 generate_fake_fmri, 

11 generate_labeled_regions, 

12) 

13from nilearn._utils.estimator_checks import ( 

14 check_estimator, 

15 nilearn_check_estimator, 

16 return_expected_failed_checks, 

17) 

18from nilearn._utils.tags import SKLEARN_LT_1_6 

19from nilearn.conftest import _img_labels 

20from nilearn.image import get_data 

21from nilearn.maskers import MultiNiftiLabelsMasker 

22 

# Estimator instances handed to the scikit-learn compliance checks below.
ESTIMATORS_TO_CHECK = [MultiNiftiLabelsMasker()]

# Exactly one family of compliance tests is defined, depending on the
# SKLEARN_LT_1_6 flag (presumably "installed scikit-learn is older than 1.6"
# -- confirm in nilearn._utils.tags).
if not SKLEARN_LT_1_6:

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Run one scikit-learn estimator check on the masker."""
        check(estimator)

else:

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Run one scikit-learn check that the masker is expected to pass."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Run one scikit-learn check that the masker is expected to fail."""
        check(estimator)

53 

54 

# NOTE(review): timeout(0) presumably disables the per-test time limit
# (pytest-timeout convention) -- confirm against the project's pytest config.
@pytest.mark.timeout(0)
@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(
        estimators=[MultiNiftiLabelsMasker(labels_img=_img_labels())]
    ),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Run one nilearn-specific estimator check on a masker with labels.

    ``name`` is part of the parametrization only and is not used in the body.
    """
    check(estimator)

65 

66 

def test_multi_nifti_labels_masker(
    affine_eye, n_regions, shape_3d_default, length, img_labels
):
    """Exercise fit, fit_transform and inverse_transform without resampling.

    The data are generated with ``shape_3d_default`` and ``affine_eye``;
    ``img_labels`` is assumed to live on the same grid (fixture not visible
    here), so none of the calls below is expected to raise.
    """
    expected_shape = (length, n_regions)
    fmri_img, mask_img = generate_fake_fmri(
        shape_3d_default, affine=affine_eye, length=length
    )

    # Labels only, single 4D image: must go through without any exception.
    masker = MultiNiftiLabelsMasker(img_labels, resampling_target=None)
    signals = masker.fit_transform(fmri_img)

    assert signals.shape == expected_shape

    # A masker fitted without data must still be able to invert signals.
    masker = MultiNiftiLabelsMasker(img_labels, resampling_target=None)
    masker.fit()
    masker.inverse_transform(signals)

    # Labels combined with an explicit mask image.
    masker = MultiNiftiLabelsMasker(
        img_labels, mask_img=mask_img, resampling_target=None
    )
    signals = masker.fit_transform(fmri_img)

    assert signals.shape == expected_shape

    # A list of 4D images is accepted as well (one signal array per image).
    multi_subject_imgs = [fmri_img, fmri_img]
    per_subject_signals = masker.fit_transform(multi_subject_imgs)

    for subject_signals in per_subject_signals:
        assert subject_signals.shape == expected_shape

    # Same list input, this time without a mask.
    masker = MultiNiftiLabelsMasker(img_labels, resampling_target=None)
    per_subject_signals = masker.fit_transform(multi_subject_imgs)

    for subject_signals in per_subject_signals:
        assert subject_signals.shape == expected_shape

    # Smoke test: each subject's signals map back onto the input grid.
    for subject_signals in per_subject_signals:
        reconstructed_img = masker.inverse_transform(subject_signals)

        assert reconstructed_img.shape == fmri_img.shape
        assert_almost_equal(reconstructed_img.affine, fmri_img.affine)

113 

114 

def test_multi_nifti_labels_masker_errors(
    affine_eye, shape_3d_default, length, img_labels
):
    """Check that grid mismatches are reported when resampling is disabled.

    Two mismatching data sets are generated: one that differs from the
    default grid by its affine only, one that differs by its shape only.
    With ``resampling_target=None`` both the images passed to ``transform``
    and the mask passed at construction must trigger a ``ValueError``.
    """
    mismatched_shape = (12, 10, 14)
    mismatched_affine = np.diag((1, 2, 3, 1))

    # Default shape, non-identity affine.
    fmri_bad_affine, mask_bad_affine = generate_fake_fmri(
        shape_3d_default, affine=mismatched_affine, length=length
    )
    # Identity affine, non-default shape.
    fmri_bad_shape, mask_bad_shape = generate_fake_fmri(
        mismatched_shape, affine=affine_eye, length=length
    )

    # Mismatch between the fitted labels and the images to transform.
    masker = MultiNiftiLabelsMasker(img_labels, resampling_target=None)
    masker.fit()

    with pytest.raises(
        ValueError, match="Images have different affine matrices."
    ):
        masker.transform(fmri_bad_affine)

    with pytest.raises(ValueError, match="Images have incompatible shapes."):
        masker.transform(fmri_bad_shape)

    # Mismatch between the labels and the mask is detected at fit time,
    # first for the affine, then for the shape.
    for bad_mask_img in (mask_bad_affine, mask_bad_shape):
        masker = MultiNiftiLabelsMasker(
            img_labels, mask_img=bad_mask_img, resampling_target=None
        )

        with pytest.raises(
            ValueError, match="Following field of view errors were detected"
        ):
            masker.fit()

158 

159 

def test_multi_nifti_labels_masker_errors_strategy(img_labels):
    """Check that fitting with an unknown ``strategy`` raises a ValueError."""
    # The masker is built outside the ``raises`` block: only ``fit`` is
    # expected to raise.
    bad_strategy_masker = MultiNiftiLabelsMasker(
        img_labels, strategy="TESTRAISE"
    )
    with pytest.raises(ValueError, match="Invalid strategy 'TESTRAISE'"):
        bad_strategy_masker.fit()

165 

166 

@pytest.mark.parametrize("resampling_target", ["mask", "invalid"])
def test_multi_nifti_labels_masker_errors_resampling(
    img_labels, resampling_target
):
    """Check that unsupported ``resampling_target`` values fail at fit time.

    Both ``"mask"`` and an arbitrary string are expected to be rejected
    with the same ``ValueError`` message.
    """
    bad_target_masker = MultiNiftiLabelsMasker(
        img_labels, resampling_target=resampling_target
    )
    with pytest.raises(
        ValueError, match="invalid value for 'resampling_target' parameter"
    ):
        bad_target_masker.fit()

180 

181 

@pytest.mark.timeout(0)
def test_multi_nifti_labels_masker_reduction_strategies(affine_eye):
    """Check the reduction strategies of MultiNiftiLabelsMasker.

    1. every supported strategy returns the value computed by the
       matching numpy reduction on the labelled voxels
    2. the default strategy is "mean" (backwards compatibility)
    """
    voxel_values = [-2.0, -1.0, 0.0, 1.0, 2]

    # Volume of shape (1, 2, 5): two identical rows of values. Only the
    # second row carries label 1; the first row is labelled 0 (presumably
    # treated as background by the masker -- confirm).
    img = Nifti1Image(np.array([[voxel_values, voxel_values]]), affine_eye)
    labels = Nifti1Image(
        np.array([[[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]]], dtype=np.int8),
        affine_eye,
    )

    # Reference reduction for each strategy name accepted by the masker.
    reference_reducers = {
        "mean": np.mean,
        "median": np.median,
        "sum": np.sum,
        "minimum": np.min,
        "maximum": np.max,
        "standard_deviation": np.std,
        "variance": np.var,
    }
    expected_results = {
        strategy: reducer(voxel_values)
        for strategy, reducer in reference_reducers.items()
    }

    for strategy, expected_result in expected_results.items():
        masker = MultiNiftiLabelsMasker(labels, strategy=strategy)
        # The same 3D image is passed twice in a list to mimic several
        # subjects; one result is returned per image.
        for result in masker.fit_transform([img, img]):
            assert result.squeeze() == expected_result

    assert MultiNiftiLabelsMasker(labels).strategy == "mean"

220 

221 

def test_multi_nifti_labels_masker_resampling(
    affine_eye, n_regions, length, img_labels
):
    """Check resampling to the labels grid with a list of input images.

    Data and mask are generated on grids of different shapes (same affine).
    With ``resampling_target="labels"`` the fitted labels must keep their
    own grid, the fitted mask must end up on that grid, and the inverse
    transform must produce images on it too.
    """
    data_shape = (10, 11, 12)
    mask_shape = (16, 17, 18)

    single_fmri_img, _ = generate_fake_fmri(
        data_shape, affine=affine_eye, length=length
    )
    _, mask_img = generate_fake_fmri(
        mask_shape, affine=affine_eye, length=length
    )

    masker = MultiNiftiLabelsMasker(
        img_labels, mask_img=mask_img, resampling_target="labels"
    )

    # Multi-subject input: the same 4D image twice.
    per_subject_signals = masker.fit_transform(
        [single_fmri_img, single_fmri_img]
    )

    # The labels image is left on its original grid...
    assert_almost_equal(masker.labels_img_.affine, img_labels.affine)
    assert masker.labels_img_.shape == img_labels.shape

    # ...and the mask has been brought onto that grid.
    assert_almost_equal(masker.mask_img_.affine, masker.labels_img_.affine)
    assert masker.mask_img_.shape == masker.labels_img_.shape[:3]

    for subject_signals in per_subject_signals:
        assert subject_signals.shape == (length, n_regions)

        reconstructed_img = masker.inverse_transform(subject_signals)
        assert_almost_equal(
            reconstructed_img.affine, masker.labels_img_.affine
        )
        assert reconstructed_img.shape == (
            masker.labels_img_.shape[:3] + (length,)
        )

260 

261 

def test_multi_nifti_labels_masker_resampling_clipped_labels(
    affine_eye, n_regions, length, img_labels, img_fmri
):
    """Check resampling to labels when the mask is smaller than the atlas.

    The mask is generated on an (8, 9, 10) grid and is not expected to
    contain every label, so resampling takes place and shapes matter.
    """
    mask_shape = (8, 9, 10)

    _, mask_img = generate_fake_fmri(
        mask_shape, affine=affine_eye, length=length
    )

    masker = MultiNiftiLabelsMasker(
        img_labels, mask_img=mask_img, resampling_target="labels"
    )

    # Multi-subject input: the same fixture image twice.
    per_subject_signals = masker.fit_transform([img_fmri, img_fmri])

    # Labels keep their own grid; the mask is brought onto it.
    assert_almost_equal(masker.labels_img_.affine, img_labels.affine)
    assert masker.labels_img_.shape == img_labels.shape
    assert_almost_equal(masker.mask_img_.affine, masker.labels_img_.affine)
    assert masker.mask_img_.shape == masker.labels_img_.shape[:3]

    # The fitted labels image still holds 0 plus all ``n_regions`` labels.
    fitted_label_values = np.unique(get_data(masker.labels_img_))
    assert fitted_label_values[0] == 0
    assert len(fitted_label_values) - 1 == n_regions

    for subject_signals in per_subject_signals:
        assert subject_signals.shape == (length, n_regions)
        # Clipped regions yield a constant signal (zero variance); at least
        # one region must remain with a non-constant signal.
        n_constant_regions = (subject_signals.var(axis=0) == 0).sum()
        assert n_constant_regions < n_regions

        reconstructed_img = masker.inverse_transform(subject_signals)

        assert_almost_equal(
            reconstructed_img.affine, masker.labels_img_.affine
        )
        assert reconstructed_img.shape == (
            masker.labels_img_.shape[:3] + (length,)
        )

303 

304 

def test_multi_nifti_labels_masker_atlas_data_different_fov(
    affine_eye, img_labels, length
):
    """Check that the atlas follows the data grid when fields of view differ.

    The data use a (5, 5, 6) grid with 2-unit voxels while the mask uses
    an (8, 9, 10) grid with the identity affine. After ``fit_transform``
    the fitted labels image must carry the affine of the data. This relies
    on the masker's default ``resampling_target`` (not set explicitly here).
    """
    mask_shape = (8, 9, 10)
    data_shape = (5, 5, 6)
    # Scaling of 2 along the three spatial axes, homogeneous term left at 1.
    data_affine = np.diag((2.0, 2.0, 2.0, 1.0))

    _, mask_img = generate_fake_fmri(
        mask_shape, affine=affine_eye, length=length
    )
    fmri_img, _ = generate_fake_fmri(
        data_shape, affine=data_affine, length=length
    )

    masker = MultiNiftiLabelsMasker(img_labels, mask_img=mask_img)
    masker.fit_transform(fmri_img)

    assert_array_equal(masker.labels_img_.affine, data_affine)

327 

328 

def test_multi_nifti_labels_masker_resampling_target():
    """Check label/signal consistency for resampling targets 'data', 'labels'.

    For each target, the number of non-zero labels left in the fitted
    labels image must equal the number of extracted signals (second
    dimension of the transformed array), and transforming the same data
    twice must give the same reconstructed image.

    These checks were added following issue #1673 in Nilearn.
    """
    data_shape = (13, 11, 12)
    # NOTE(review): scaling the whole identity also sets the homogeneous
    # term to 2; kept as is because the expected warning may depend on it.
    data_affine = np.eye(4) * 2

    fmri_img, _ = generate_fake_fmri(data_shape, affine=data_affine, length=21)
    labels_img = generate_labeled_regions(
        (9, 8, 6), affine=np.eye(4), n_regions=10
    )

    removed_labels_warning = (
        "After resampling the label image "
        "to the data image, the following "
        "labels were removed"
    )

    for resampling_target in ("data", "labels"):
        masker = MultiNiftiLabelsMasker(
            labels_img=labels_img, resampling_target=resampling_target
        )
        if resampling_target != "data":
            signals = masker.fit_transform(fmri_img)
        else:
            # Resampling the atlas onto the data grid is expected to drop
            # some labels and to warn about it.
            with pytest.warns(UserWarning, match=removed_labels_warning):
                signals = masker.fit_transform(fmri_img)

        # One signal per label still present (label 0 excluded).
        remaining_labels = np.unique(get_data(masker.labels_img_))
        assert len(remaining_labels) - 1 == signals.shape[1]

        first_reconstruction = masker.inverse_transform(signals)

        # Compressing the image a second time should yield an image with
        # the same data as the first reconstruction.
        signals_again = masker.fit_transform(fmri_img)
        second_reconstruction = masker.inverse_transform(signals_again)

        assert_array_equal(
            get_data(first_reconstruction), get_data(second_reconstruction)
        )