Coverage for nilearn/_utils/param_validation.py: 11%

118 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1"""Utilities to check for valid parameters.""" 

2 

3import numbers 

4import sys 

5import warnings 

6 

7import numpy as np 

8from sklearn.feature_selection import SelectPercentile, f_classif, f_regression 

9 

10import nilearn.typing as nilearn_typing 

11from nilearn._utils import logger 

12from nilearn._utils.logger import find_stack_level 

13from nilearn._utils.niimg import _get_data 

14 

15# Volume of a standard (MNI152) brain mask in mm^3 

16MNI152_BRAIN_VOLUME = 1827243.0 

17 

18 

def check_threshold(
    threshold, data, percentile_func, name="threshold", two_sided=True
):
    """Validate a threshold and resolve percentile strings to scores.

    If ``threshold`` is a string such as ``"25.3%"``, the corresponding
    score is computed on ``data`` with ``percentile_func``.

    Parameters
    ----------
    threshold : :obj:`float` or :obj:`str`
        Threshold that is used to set certain data values to zero.
        A float should lie within the range of minimum and maximum
        intensity of the data and, when ``two_sided`` is True,
        cannot be negative.
        A string should be a number within the range "0%" to "100%"
        followed by a percent sign.

    data : ndarray
        An array of the input masked data.

    percentile_func : function {scoreatpercentile, fast_abs_percentile}
        Percentile function, for example ``scipy.stats.scoreatpercentile``,
        used to compute the score on the data.

    name : :obj:`str`, default='threshold'
        Name of the threshold, used only to build precise error messages.

    two_sided : :obj:`bool`, default=True
        Whether the thresholding should yield both positive and negative
        part of the maps.

        .. versionadded:: 0.11.2dev

    Returns
    -------
    threshold : :obj:`float`
        The percentile score computed on the data, or ``threshold``
        unchanged when it was not a percentile string.

    Raises
    ------
    ValueError
        If threshold is a string that is not a non-negative number followed
        by the percent sign, or a negative float while ``two_sided`` is
        True.
    TypeError
        If threshold is neither a number nor a string in correct percentile
        format.
    """
    is_percentile = False
    if isinstance(threshold, str):
        message = (
            f'If "{name}" is given as string it '
            "should be a number followed by the percent "
            'sign, e.g. "25.3%"'
        )
        if not threshold.endswith("%"):
            raise ValueError(message)
        try:
            threshold = float(threshold[:-1])
        except ValueError as exc:
            # Attach our hint to the original float-parsing error.
            exc.args += (message,)
            raise
        is_percentile = True
    elif not isinstance(threshold, numbers.Real):
        raise TypeError(
            f"{name} should be either a number "
            "or a string finishing with a percent sign"
        )

    if threshold < 0:
        # Negative thresholds are only meaningful one-sided, and never
        # as percentiles.
        if two_sided:
            raise ValueError(
                f'"{name}" should not be a negative value when two_sided=True.'
            )
        if is_percentile:
            raise ValueError(
                f'"{name}" should not be a negative percentile value.'
            )
        non_positive = np.extract(data <= 0, data)
        lower_bound = non_positive.min()
        if threshold < lower_bound:
            warnings.warn(
                f"The given float value must not be less than "
                f"{lower_bound}. But, you have given "
                f"threshold={threshold}.",
                category=UserWarning,
                stacklevel=find_stack_level(),
            )
        return threshold

    # Non-negative threshold: restrict data to the relevant sign(s).
    data = abs(data) if two_sided else np.extract(data >= 0, data)
    if is_percentile:
        return percentile_func(data, threshold)

    upper_bound = data.max()
    if threshold > upper_bound:
        warnings.warn(
            f"The given float value must not exceed {upper_bound}. "
            f"But, you have given threshold={threshold}.",
            category=UserWarning,
            stacklevel=find_stack_level(),
        )
    return threshold

126 

127 

128def _get_mask_extent(mask_img): 

129 """Compute the extent of the provided brain mask. 

130 The extent is the volume of the mask in mm^3 if mask_img is a Nifti1Image 

131 or the number of vertices if mask_img is a SurfaceImage. 

132 

133 Parameters 

134 ---------- 

135 mask_img : Nifti1Image or SurfaceImage 

136 The Nifti1Image whose voxel dimensions or the SurfaceImage whose 

137 number of vertices are to be computed. 

138 

139 Returns 

140 ------- 

141 mask_extent : float 

142 The computed volume in mm^3 (if mask_img is a Nifti1Image) or the 

143 number of vertices (if mask_img is a SurfaceImage). 

144 

145 """ 

146 if not hasattr(mask_img, "affine"): 

147 # sum number of True values in both hemispheres 

148 return ( 

149 mask_img.data.parts["left"].sum() 

150 + mask_img.data.parts["right"].sum() 

151 ) 

152 affine = mask_img.affine 

153 prod_vox_dims = 1.0 * np.abs(np.linalg.det(affine[:3, :3])) 

154 return prod_vox_dims * _get_data(mask_img).astype(bool).sum() 

155 

156 

def adjust_screening_percentile(
    screening_percentile,
    mask_img,
    verbose=0,
    mesh_n_vertices=None,
):
    """Rescale the screening percentile with respect to a reference brain.

    The reference is the MNI152 template volume for volume data, or the
    number of vertices of the provided standard brain mesh for surface
    data.

    Parameters
    ----------
    screening_percentile : float in the interval [0, 100]
        Percentile value for ANOVA univariate feature selection. A value
        of 100 means 'keep all features'. This percentile is expressed
        w.r.t the volume of either a standard (MNI152) brain (if mask_img
        is a 3D volume) or the number of vertices in the standard brain
        mesh (if mask_img is a SurfaceImage). This means that the
        `screening_percentile` is corrected at runtime by premultiplying
        it with the ratio of the volume of the mask of the data and volume
        of the standard brain.

    mask_img : Nifti1Image or SurfaceImage
        The Nifti1Image whose voxel dimensions or the SurfaceImage whose
        number of vertices are to be computed.

    %(verbose0)s

    mesh_n_vertices : int, default=None
        Number of vertices of the reference brain mesh, eg., fsaverage5
        or fsaverage7 etc.. If provided, the screening percentile will be
        adjusted according to the number of vertices.

    Returns
    -------
    screening_percentile : float in the interval [0, 100]
        Percentile value for ANOVA univariate feature selection.

    """
    requested_percentile = screening_percentile
    # Extent of the data mask: volume in mm^3 or number of vertices.
    mask_extent = _get_mask_extent(mask_img)
    is_surface = hasattr(mask_img, "mesh")
    # Reference extent: number of vertices of the standard mesh for
    # surfaces, volume of the MNI152 brain template otherwise.
    reference_extent = (
        MNI152_BRAIN_VOLUME if mesh_n_vertices is None else mesh_n_vertices
    )

    if mask_extent > 1.1 * reference_extent:
        warnings.warn(
            "Brain mask is bigger than the standard "
            "human brain. This object is probably not tuned to "
            "be used on such data.",
            stacklevel=find_stack_level(),
        )
    elif mask_extent < 0.005 * reference_extent:
        warnings.warn(
            "Brain mask is smaller than .5% of the size of the standard "
            "human brain. This object is probably not tuned to "
            "be used on such data.",
            stacklevel=find_stack_level(),
        )

    # A percentile of exactly 100 means "keep everything": no correction.
    if screening_percentile < 100.0:
        screening_percentile = min(
            screening_percentile * (reference_extent / mask_extent), 100.0
        )

    if is_surface:
        log_mask = f"Mask n_vertices = {mask_extent:g}"
        log_ref = f"Reference mesh n_vertices = {reference_extent:g}"
    else:
        log_mask = (
            f"Mask volume = {mask_extent:g}mm^3 = {mask_extent / 1000.0:g}cm^3"
        )
        log_ref = f"Standard brain volume = {MNI152_BRAIN_VOLUME:g}mm^3"

    for message in (
        log_mask,
        log_ref,
        f"Original screening-percentile: {requested_percentile:g}",
        f"Corrected screening-percentile: {screening_percentile:g}",
    ):
        logger.log(
            message,
            verbose=verbose,
            msg_level=1,
        )
    return screening_percentile

258 

259 

def check_feature_screening(
    screening_percentile,
    mask_img,
    is_classification,
    verbose=0,
    mesh_n_vertices=None,
):
    """Check feature screening method.

    Turns floats between 1 and 100 into SelectPercentile objects.

    Parameters
    ----------
    screening_percentile : float in the interval [0, 100]
        Percentile value for :term:`ANOVA` univariate feature selection.
        A value of 100 means 'keep all features'.
        This percentile is expressed
        w.r.t the volume of a standard (MNI152) brain, and so is corrected
        at runtime by premultiplying it with the ratio of the volume of the
        mask of the data and volume of a standard brain.

    mask_img : nibabel image object
        Input image whose :term:`voxel` dimensions are to be computed.

    is_classification : bool
        If is_classification is True, it indicates that a classification
        task is performed. Otherwise, a regression task is performed.

    %(verbose0)s

    mesh_n_vertices : int, default=None
        Number of vertices of the reference mesh, eg., fsaverage5 or
        fsaverage7 etc.. If provided, the screening percentile will be
        adjusted according to the number of vertices.

    Returns
    -------
    selector : SelectPercentile instance
        Used to perform the :term:`ANOVA` univariate feature selection.

    """
    # 100 (or None) means "keep everything": no selector is needed.
    if screening_percentile is None or screening_percentile == 100:
        return None
    if not 0.0 <= screening_percentile <= 100.0:
        raise ValueError(
            "screening_percentile should be in the interval"
            f" [0, 100], got {screening_percentile:g}"
        )

    score_func = f_classif if is_classification else f_regression
    # Correct screening_percentile according to the volume or the number
    # of vertices in the data mask.
    adjusted_percentile = adjust_screening_percentile(
        screening_percentile,
        mask_img,
        verbose=verbose,
        mesh_n_vertices=mesh_n_vertices,
    )
    return SelectPercentile(score_func, percentile=int(adjusted_percentile))

321 

322 

def check_run_sample_masks(n_runs, sample_masks):
    """Check that number of sample_mask matches number of runs.

    A single ndarray is treated as one run. Each mask is converted from
    boolean form to indices (if needed) and cast to int32.
    """
    if not isinstance(sample_masks, (list, tuple, np.ndarray)):
        raise TypeError(
            f"sample_mask has an unhandled type: {sample_masks.__class__}"
        )

    if isinstance(sample_masks, np.ndarray):
        sample_masks = (sample_masks,)

    # Normalize every mask: booleans become indices, then force int32.
    checked_sample_masks = [
        _cast_to_int32(_convert_bool2index(mask)) for mask in sample_masks
    ]

    if len(checked_sample_masks) != n_runs:
        raise ValueError(
            f"Number of sample_mask ({len(checked_sample_masks)}) not "
            f"matching number of runs ({n_runs})."
        )
    return checked_sample_masks

342 

343 

344def _convert_bool2index(sample_mask): 

345 """Convert boolean to index.""" 

346 check_boolean = [ 

347 type(i) is bool or type(i) is np.bool_ for i in sample_mask 

348 ] 

349 if all(check_boolean): 

350 sample_mask = np.where(sample_mask)[0] 

351 return sample_mask 

352 

353 

354def _cast_to_int32(sample_mask): 

355 """Ensure the sample mask dtype is signed.""" 

356 new_dtype = np.int32 

357 if np.min(sample_mask) < 0: 

358 msg = "sample_mask should not contain negative values." 

359 raise ValueError(msg) 

360 

361 if highest := np.max(sample_mask) > np.iinfo(new_dtype).max: 

362 msg = f"Max value in sample mask is larger than \ 

363 what can be represented by int32: {highest}." 

364 raise ValueError(msg) 

365 return np.asarray(sample_mask, new_dtype) 

366 

367 

# Dictionary mapping a parameter / attribute name to the type
# (declared in nilearn.typing) that check_params() validates it against.
TYPE_MAPS = {
    "annotate": nilearn_typing.Annotate,
    "border_size": nilearn_typing.BorderSize,
    "bg_on_data": nilearn_typing.BgOnData,
    "colorbar": nilearn_typing.ColorBar,
    "connected": nilearn_typing.Connected,
    "data_dir": nilearn_typing.DataDir,
    "draw_cross": nilearn_typing.DrawCross,
    "detrend": nilearn_typing.Detrend,
    "high_pass": nilearn_typing.HighPass,
    "hrf_model": nilearn_typing.HrfModel,
    "keep_masked_labels": nilearn_typing.KeepMaskedLabels,
    "keep_masked_maps": nilearn_typing.KeepMaskedMaps,
    "low_pass": nilearn_typing.LowPass,
    "lower_cutoff": nilearn_typing.LowerCutoff,
    "memory": nilearn_typing.MemoryLike,
    "memory_level": nilearn_typing.MemoryLevel,
    "n_jobs": nilearn_typing.NJobs,
    "n_perm": nilearn_typing.NPerm,
    "opening": nilearn_typing.Opening,
    "radiological": nilearn_typing.Radiological,
    "random_state": nilearn_typing.RandomState,
    "resolution": nilearn_typing.Resolution,
    "resume": nilearn_typing.Resume,
    "smoothing_fwhm": nilearn_typing.SmoothingFwhm,
    "standardize_confounds": nilearn_typing.StandardizeConfounds,
    "t_r": nilearn_typing.Tr,
    "tfce": nilearn_typing.Tfce,
    "threshold": nilearn_typing.Threshold,
    "title": nilearn_typing.Title,
    "two_sided_test": nilearn_typing.TwoSidedTest,
    "target_affine": nilearn_typing.TargetAffine,
    "target_shape": nilearn_typing.TargetShape,
    "transparency": nilearn_typing.Transparency,
    "transparency_range": nilearn_typing.TransparencyRange,
    "url": nilearn_typing.Url,
    "upper_cutoff": nilearn_typing.UpperCutoff,
    "verbose": nilearn_typing.Verbose,
    "vmax": nilearn_typing.Vmax,
    "vmin": nilearn_typing.Vmin,
}

410 

411 

def check_params(fn_dict):
    """Check types of inputs passed to a function / method / class.

    This function checks the types of function / method parameters or
    the attributes of the class.

    It is made to check the types of the parameters described in
    ``nilearn._utils.docs`` that are shared by many functions / methods /
    classes, and thus ensures a generic way to do input validation in
    several important points in the code base.

    In most cases this means that this function can be used
    on functions / classes that have the ``@fill_doc`` decorator,
    or whose doc string uses parameter templates
    (for example ``%(data_dir)s``).

    If the function cannot (yet) check any of the parameters / attributes,
    it will throw an error to say that its use is not needed.

    Typical usage:

    .. code-block:: python

        def some_function(param_1, param_2="a"):
            check_params(locals())
            ...

        Class MyClass:
            def __init__(param_1, param_2="a")
                ...

            def fit(X):
                # check attributes of the class instance
                check_params(self.__dict__)
                # check parameters passed to the method
                check_params(locals())

    """
    keys_to_check = set(TYPE_MAPS).intersection(fn_dict)
    # Fail loudly so developers remove needless calls to this helper.
    if not keys_to_check:
        raise ValueError(
            "No known parameter to check.\n"
            "You probably do not need to use 'check_params' here."
        )

    # On Python <= 3.9 some type maps cannot express "or None", so None
    # values are skipped there.
    # TODO update when dropping python 3.9
    skip_none = sys.version_info[1] <= 9

    for key in keys_to_check:
        expected_type = TYPE_MAPS[key]
        value = fn_dict[key]
        if skip_none and value is None:
            continue
        if not isinstance(value, expected_type):
            raise TypeError(
                f"'{key}' should be of type '{expected_type}'."
                f"\nGot: '{type(value)}'"
            )

472 

473 

def check_reduction_strategy(strategy: str):
    """Check that the provided strategy is supported.

    Parameters
    ----------
    %(strategy)s
    """
    # Closed set of supported signal-reduction strategies.
    valid_strategies = {
        "mean",
        "median",
        "sum",
        "minimum",
        "maximum",
        "standard_deviation",
        "variance",
    }

    if strategy in valid_strategies:
        return

    raise ValueError(
        f"Invalid strategy '{strategy}'. "
        f"Valid strategies are {valid_strategies}."
    )