Coverage for nilearn/decoding/tests/test_searchlight.py: 0%

133 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1"""Test the searchlight module.""" 

2 

3import numpy as np 

4import pytest 

5from nibabel import Nifti1Image 

6from sklearn.model_selection import ( 

7 KFold, 

8 LeaveOneGroupOut, 

9) 

10from sklearn.utils.estimator_checks import parametrize_with_checks 

11 

12from nilearn._utils.estimator_checks import ( 

13 check_estimator, 

14 nilearn_check_estimator, 

15 return_expected_failed_checks, 

16) 

17from nilearn._utils.tags import SKLEARN_LT_1_6 

18from nilearn.conftest import _rng 

19from nilearn.decoding import searchlight 

20 

# Estimator instances submitted to the scikit-learn compliance checks below.
ESTIMATOR_TO_CHECK = [searchlight.SearchLight()]

if SKLEARN_LT_1_6:
    # NOTE(review): the flag name suggests scikit-learn older than 1.6 --
    # confirm in nilearn._utils.tags. On this branch the (estimator, check,
    # name) triplets come from nilearn's own ``check_estimator`` generator.

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATOR_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Run every check yielded as valid on the estimator."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATOR_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Run the checks yielded with ``valid=False``; they are marked xfail."""
        check(estimator)

else:
    # On this branch scikit-learn's ``parametrize_with_checks`` drives the
    # checks, with ``return_expected_failed_checks`` passed as the callable
    # for ``expected_failed_checks``.

    @parametrize_with_checks(
        estimators=ESTIMATOR_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Run each scikit-learn compliance check on the estimator."""
        check(estimator)

51 

52 

@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(
        estimators=[
            searchlight.SearchLight(
                # all-ones 5x5x5 uint8 mask with an identity affine
                mask_img=Nifti1Image(
                    np.ones((5, 5, 5), dtype="uint8"), np.eye(4)
                )
            )
        ]
    ),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Run each nilearn-specific estimator check on a masked SearchLight."""
    check(estimator)

68 

69 

def _make_searchlight_test_data(frames):
    """Build the toy dataset shared by the searchlight tests.

    Parameters
    ----------
    frames : int
        Number of volumes along the 4th axis of the generated image.

    Returns
    -------
    data_img : Nifti1Image
        Random 5x5x5xframes image in which voxel (2, 2, 2) is 0 for the
        first class and 2 for the second one.
    cond : numpy.ndarray of bool
        Label per frame: False for the first ``frames // 2`` frames,
        True for the rest.
    mask_img : Nifti1Image
        All-ones 5x5x5 uint8 mask with an identity affine.
    """
    grid = (5, 5, 5)

    # Labels split at the midpoint (balanced when ``frames`` is even).
    cond = np.arange(frames, dtype=int) >= (frames // 2)

    volumes = _rng().random((*grid, frames))
    # The central voxel carries the whole signal: 2 where cond is True,
    # 0 elsewhere.
    volumes[2, 2, 2, :] = np.where(cond, 2, 0)

    mask_img = Nifti1Image(np.ones(grid, dtype="uint8"), np.eye(4))
    data_img = Nifti1Image(volumes, np.eye(4))
    return data_img, cond, mask_img

82 

83 

def define_cross_validation():
    """Return the cross-validation settings shared by the tests.

    Returns
    -------
    tuple
        ``(cv, n_jobs)``: a 4-split ``KFold`` splitter and ``n_jobs`` set
        to 1.
    """
    return KFold(n_splits=4), 1

89 

90 

def test_searchlight_no_mask():
    """Check that fitting with a non-image ``mask_img`` raises a TypeError."""
    data_img, cond, _ = _make_searchlight_test_data(frames=30)

    # An integer is passed as the mask instead of an image.
    estimator = searchlight.SearchLight(mask_img=1)

    with pytest.raises(
        TypeError, match="input should be a NiftiLike object"
    ):
        estimator.fit(data_img, y=cond)

102 

103 

def test_searchlight_small_radius():
    """Check that with radius 0.5 only the signal voxel scores 1."""
    data_img, cond, mask_img = _make_searchlight_test_data(frames=30)
    cv, n_jobs = define_cross_validation()

    # Small radius: only one pixel is selected per sphere.
    model = searchlight.SearchLight(
        mask_img,
        process_mask_img=mask_img,
        radius=0.5,
        n_jobs=n_jobs,
        scoring="accuracy",
        cv=cv,
        verbose=1,
    )
    model.fit(data_img, y=cond)

    n_perfect = np.count_nonzero(model.scores_ == 1)
    assert n_perfect == 1
    assert model.scores_[2, 2, 2] == 1.0

123 

124 

def test_searchlight_mask_far_from_signal(affine_eye):
    """Check that no voxel scores 1 when only voxel (0, 0, 0) is processed."""
    data_img, cond, mask_img = _make_searchlight_test_data(frames=30)
    cv, n_jobs = define_cross_validation()

    # Process mask restricted to a single corner voxel, away from the
    # informative voxel at (2, 2, 2).
    corner_only = np.zeros((5, 5, 5), dtype=bool)
    corner_only[0, 0, 0] = True
    process_mask_img = Nifti1Image(corner_only.astype("uint8"), affine_eye)

    model = searchlight.SearchLight(
        mask_img,
        process_mask_img=process_mask_img,
        radius=0.5,
        n_jobs=n_jobs,
        scoring="accuracy",
        cv=cv,
    )
    model.fit(data_img, y=cond)

    assert np.count_nonzero(model.scores_ == 1) == 0

144 

145 

def test_searchlight_medium_radius():
    """Check that with radius 1 the signal voxel and its 6 neighbors score 1."""
    data_img, cond, mask_img = _make_searchlight_test_data(frames=30)
    cv, n_jobs = define_cross_validation()

    model = searchlight.SearchLight(
        mask_img,
        process_mask_img=mask_img,
        radius=1,
        n_jobs=n_jobs,
        scoring="accuracy",
        cv=cv,
    )
    model.fit(data_img, cond)

    assert np.count_nonzero(model.scores_ == 1) == 7

    # The signal voxel first, then the six voxels sharing a face with it.
    expected_perfect = [
        (2, 2, 2),
        (1, 2, 2),
        (2, 1, 2),
        (2, 2, 1),
        (3, 2, 2),
        (2, 3, 2),
        (2, 2, 3),
    ]
    for voxel in expected_perfect:
        assert model.scores_[voxel] == 1.0

169 

170 

def test_searchlight_large_radius():
    """Check that with radius 2 exactly 33 voxels score 1."""
    data_img, cond, mask_img = _make_searchlight_test_data(frames=30)
    cv, n_jobs = define_cross_validation()

    model = searchlight.SearchLight(
        mask_img,
        process_mask_img=mask_img,
        radius=2,
        n_jobs=n_jobs,
        scoring="accuracy",
        cv=cv,
    )
    model.fit(data_img, cond)

    n_perfect = np.count_nonzero(model.scores_ == 1)
    assert n_perfect == 33
    assert model.scores_[2, 2, 2] == 1.0

188 

189 

def test_searchlight_group_cross_validation(rng):
    """Check searchlight scores when splitting with ``LeaveOneGroupOut``."""
    frames = 30
    data_img, cond, mask_img = _make_searchlight_test_data(frames)
    # Only n_jobs is needed here; the splitter is LeaveOneGroupOut instead.
    _, n_jobs = define_cross_validation()

    # Two shuffled boolean group labels, one per frame.
    unshuffled_groups = np.arange(frames, dtype=int) > (frames // 2)
    groups = rng.permutation(unshuffled_groups)

    model = searchlight.SearchLight(
        mask_img,
        process_mask_img=mask_img,
        radius=1,
        n_jobs=n_jobs,
        scoring="accuracy",
        cv=LeaveOneGroupOut(),
    )
    model.fit(data_img, y=cond, groups=groups)

    assert np.count_nonzero(model.scores_ == 1) == 7
    assert model.scores_[2, 2, 2] == 1.0

209 

210 

def test_searchlight_group_cross_validation_with_extra_group_variable(
    rng,
    affine_eye,
):
    """Check fitting with ``groups`` passed alongside a ``KFold`` splitter.

    Also checks that fitting on a list of 3D images runs without error.
    """
    frames = 30
    data_4d_img, cond, mask_img = _make_searchlight_test_data(frames)
    cv, n_jobs = define_cross_validation()

    unshuffled_groups = np.arange(frames, dtype=int) > (frames // 2)
    groups = rng.permutation(unshuffled_groups)

    grouped_sl = searchlight.SearchLight(
        mask_img,
        process_mask_img=mask_img,
        radius=1,
        n_jobs=n_jobs,
        scoring="accuracy",
        cv=cv,
    )
    grouped_sl.fit(data_4d_img, y=cond, groups=groups)

    assert np.count_nonzero(grouped_sl.scores_ == 1) == 7
    assert grouped_sl.scores_[2, 2, 2] == 1.0

    # Check whether searchlight works on a list of 3D images:
    # the same random volume repeated 12 times, with alternating labels.
    volume_img = Nifti1Image(rng.random((5, 5, 5)), affine=affine_eye)
    imgs = [volume_img] * 12
    y = [0, 1] * 6

    list_sl = searchlight.SearchLight(mask_img)
    list_sl.fit(imgs, y)

245 

246 

def test_mask_img_dimension_mismatch():
    """Check fitting when the mask is smaller than the data image.

    The mask is 4x4x4 while the data image is 5x5x5. ``fit`` is expected
    to complete and ``scores_`` to have the shape of the mask.
    """
    data_img, cond, _ = _make_searchlight_test_data(frames=20)

    small_mask = np.ones((4, 4, 4), dtype="uint8")
    invalid_mask_img = Nifti1Image(small_mask, np.eye(4))

    model = searchlight.SearchLight(invalid_mask_img, radius=1.0)

    # Must not raise despite the shape mismatch.
    model.fit(data_img, y=cond)

    assert model.scores_ is not None
    assert model.scores_.shape == invalid_mask_img.shape

267 

268 

def test_transform_applies_mask_correctly():
    """Check that ``transform()`` on fitted data returns a 5x5x5 result."""
    data_img, cond, mask_img = _make_searchlight_test_data(frames=20)

    model = searchlight.SearchLight(mask_img, radius=1.0)
    model.fit(data_img, y=cond)

    # Fitted attributes must be populated before transforming.
    assert model.scores_ is not None
    assert model.process_mask_ is not None

    # Transform the same data the model was fitted on.
    transformed_scores = model.transform(data_img)

    assert transformed_scores is not None
    assert transformed_scores.shape == (5, 5, 5)
    assert transformed_scores.size > 0

287 

288 

def test_process_mask_shape_mismatch():
    """Check fitting when the process mask is smaller than the data image.

    The process mask is 4x4x4 while the data image and main mask are
    5x5x5. ``fit`` is expected to complete and ``scores_`` to have the
    shape of the process mask.
    """
    data_img, cond, mask_img = _make_searchlight_test_data(frames=20)

    small_mask = np.ones((4, 4, 4), dtype="uint8")
    process_mask_img = Nifti1Image(small_mask, np.eye(4))

    model = searchlight.SearchLight(
        mask_img=mask_img, process_mask_img=process_mask_img, radius=1.0
    )

    # Must not raise; scores may be only partially populated.
    model.fit(data_img, y=cond)

    assert model.scores_ is not None
    assert model.scores_.shape == process_mask_img.shape