Coverage for nilearn/_utils/param_validation.py: 11% (118 statements)
coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1"""Utilities to check for valid parameters."""
3import numbers
4import sys
5import warnings
7import numpy as np
8from sklearn.feature_selection import SelectPercentile, f_classif, f_regression
10import nilearn.typing as nilearn_typing
11from nilearn._utils import logger
12from nilearn._utils.logger import find_stack_level
13from nilearn._utils.niimg import _get_data
15# Volume of a standard (MNI152) brain mask in mm^3
16MNI152_BRAIN_VOLUME = 1827243.0


def check_threshold(
    threshold, data, percentile_func, name="threshold", two_sided=True
):
    """Check if the given threshold is in correct format and within the limit.

    If threshold is a string, this function returns the score of the data
    computed with the given percentile function.

    Parameters
    ----------
    threshold : :obj:`float` or :obj:`str`
        Threshold that is used to set certain data values to zero.
        If threshold is a float, it should be within the range of the minimum
        and the maximum intensity of the data.
        If `two_sided` is True, threshold cannot be negative.
        If threshold is a str, the given string should be within the range of
        "0%" to "100%".

    data : ndarray
        An array of the input masked data.

    percentile_func : function {scoreatpercentile, fast_abs_percentile}
        Percentile function, for example scipy.stats.scoreatpercentile,
        used to compute the score on the data.

    name : :obj:`str`, default='threshold'
        A string used to represent the name of the threshold in a
        precise error message.

    two_sided : :obj:`bool`, default=True
        Whether the thresholding should yield both positive and negative
        parts of the maps.

        .. versionadded:: 0.11.2dev

    Returns
    -------
    threshold : :obj:`float`
        The score of the percentile on the data, or the threshold as
        given if it is not a string percentile.

    Raises
    ------
    ValueError
        If threshold is of type str but is not a non-negative number followed
        by the percent sign.
        If threshold is a negative float and `two_sided` is True.
    TypeError
        If threshold is neither a float nor a string in the correct percentile
        format.
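
    Examples
    --------
    A minimal, illustrative sketch (values are made up;
    ``scoreatpercentile`` is one valid choice of ``percentile_func``):

    >>> import numpy as np
    >>> from scipy.stats import scoreatpercentile
    >>> data = np.array([-2.0, -0.5, 0.0, 1.0, 3.0])
    >>> t = check_threshold("50%", data, scoreatpercentile)
    >>> float(t) >= 0.0
    True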
69 """
70 percentile = False
71 if isinstance(threshold, str):
72 message = (
73 f'If "{name}" is given as string it '
74 "should be a number followed by the percent "
75 'sign, e.g. "25.3%"'
76 )
77 if not threshold.endswith("%"):
78 raise ValueError(message)
79 try:
80 threshold = float(threshold[:-1])
81 percentile = True
82 except ValueError as exc:
83 exc.args += (message,)
84 raise
85 elif not isinstance(threshold, numbers.Real):
86 raise TypeError(
87 f"{name} should be either a number "
88 "or a string finishing with a percent sign"
89 )
91 if threshold >= 0:
92 data = abs(data) if two_sided else np.extract(data >= 0, data)
94 if percentile:
95 threshold = percentile_func(data, threshold)
96 else:
97 value_check = data.max()
98 if threshold > value_check:
99 warnings.warn(
100 f"The given float value must not exceed {value_check}. "
101 f"But, you have given threshold={threshold}.",
102 category=UserWarning,
103 stacklevel=find_stack_level(),
104 )
105 else:
106 if two_sided:
107 raise ValueError(
108 f'"{name}" should not be a negative value when two_sided=True.'
109 )
110 if percentile:
111 raise ValueError(
112 f'"{name}" should not be a negative percentile value.'
113 )
114 data = np.extract(data <= 0, data)
115 value_check = data.min()
116 if threshold < value_check:
117 warnings.warn(
118 f"The given float value must not be less than "
119 f"{value_check}. But, you have given "
120 f"threshold={threshold}.",
121 category=UserWarning,
122 stacklevel=find_stack_level(),
123 )
125 return threshold


def _get_mask_extent(mask_img):
    """Compute the extent of the provided brain mask.

    The extent is the volume of the mask in mm^3 if mask_img is a
    Nifti1Image, or the number of vertices if mask_img is a SurfaceImage.

    Parameters
    ----------
    mask_img : Nifti1Image or SurfaceImage
        The Nifti1Image whose voxel dimensions are to be computed,
        or the SurfaceImage whose number of vertices is to be counted.

    Returns
    -------
    mask_extent : float
        The computed volume in mm^3 (if mask_img is a Nifti1Image) or the
        number of vertices (if mask_img is a SurfaceImage).
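
    Examples
    --------
    A minimal sketch with a tiny synthetic mask (2 mm isotropic voxels,
    8 voxels, hence 64 mm^3; illustrative only):

    >>> import numpy as np
    >>> import nibabel as nib
    >>> mask = nib.Nifti1Image(
    ...     np.ones((2, 2, 2), dtype=np.int8), np.diag([2.0, 2.0, 2.0, 1.0])
    ... )
    >>> float(_get_mask_extent(mask))
    64.0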
145 """
146 if not hasattr(mask_img, "affine"):
147 # sum number of True values in both hemispheres
148 return (
149 mask_img.data.parts["left"].sum()
150 + mask_img.data.parts["right"].sum()
151 )
152 affine = mask_img.affine
153 prod_vox_dims = 1.0 * np.abs(np.linalg.det(affine[:3, :3]))
154 return prod_vox_dims * _get_data(mask_img).astype(bool).sum()


def adjust_screening_percentile(
    screening_percentile,
    mask_img,
    verbose=0,
    mesh_n_vertices=None,
):
    """Adjust the screening percentile according to the MNI152 template or
    the number of vertices of the provided standard brain mesh.

    Parameters
    ----------
    screening_percentile : float in the interval [0, 100]
        Percentile value for ANOVA univariate feature selection. A value of
        100 means 'keep all features'. This percentile is expressed
        w.r.t. the volume of either a standard (MNI152) brain (if mask_img is
        a 3D volume) or the number of vertices in the standard brain mesh
        (if mask_img is a SurfaceImage). This means that the
        `screening_percentile` is corrected at runtime by premultiplying it
        with the ratio of the volume of the mask of the data to the volume of
        the standard brain.

    mask_img : Nifti1Image or SurfaceImage
        The Nifti1Image whose voxel dimensions are to be computed,
        or the SurfaceImage whose number of vertices is to be counted.

    %(verbose0)s

    mesh_n_vertices : int, default=None
        Number of vertices of the reference brain mesh, e.g., fsaverage5,
        fsaverage7, etc. If provided, the screening percentile will be
        adjusted according to the number of vertices.

    Returns
    -------
    screening_percentile : float in the interval [0, 100]
        Percentile value for ANOVA univariate feature selection.
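
    Examples
    --------
    The correction itself is simple arithmetic: if the data mask covers half
    of the reference extent, a 10% screening percentile becomes 20%, and the
    result is capped at 100 (numbers are illustrative only):

    >>> min(10.0 * (1827243.0 / 913621.5), 100.0)
    20.0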
194 """
195 original_screening_percentile = screening_percentile
196 # correct screening_percentile according to the volume of the data mask
197 # or the number of vertices of the reference mesh
198 mask_extent = _get_mask_extent(mask_img)
199 # if mask_img is a surface mesh, reference is the number of vertices
200 # in the standard mesh otherwise it is the volume of the MNI152 brain
201 # template
202 reference_extent = (
203 MNI152_BRAIN_VOLUME if mesh_n_vertices is None else mesh_n_vertices
204 )
205 if mask_extent > 1.1 * reference_extent:
206 warnings.warn(
207 "Brain mask is bigger than the standard "
208 "human brain. This object is probably not tuned to "
209 "be used on such data.",
210 stacklevel=find_stack_level(),
211 )
212 elif mask_extent < 0.005 * reference_extent:
213 warnings.warn(
214 "Brain mask is smaller than .5% of the size of the standard "
215 "human brain. This object is probably not tuned to "
216 "be used on such data.",
217 stacklevel=find_stack_level(),
218 )
220 if screening_percentile < 100.0:
221 screening_percentile = screening_percentile * (
222 reference_extent / mask_extent
223 )
224 screening_percentile = min(screening_percentile, 100.0)
225 # if screening_percentile is 100, we don't do anything
227 if hasattr(mask_img, "mesh"):
228 log_mask = f"Mask n_vertices = {mask_extent:g}"
229 else:
230 log_mask = (
231 f"Mask volume = {mask_extent:g}mm^3 = {mask_extent / 1000.0:g}cm^3"
232 )
233 logger.log(
234 log_mask,
235 verbose=verbose,
236 msg_level=1,
237 )
238 if hasattr(mask_img, "mesh"):
239 log_ref = f"Reference mesh n_vertices = {reference_extent:g}"
240 else:
241 log_ref = f"Standard brain volume = {MNI152_BRAIN_VOLUME:g}mm^3"
242 logger.log(
243 log_ref,
244 verbose=verbose,
245 msg_level=1,
246 )
247 logger.log(
248 f"Original screening-percentile: {original_screening_percentile:g}",
249 verbose=verbose,
250 msg_level=1,
251 )
252 logger.log(
253 f"Corrected screening-percentile: {screening_percentile:g}",
254 verbose=verbose,
255 msg_level=1,
256 )
257 return screening_percentile


def check_feature_screening(
    screening_percentile,
    mask_img,
    is_classification,
    verbose=0,
    mesh_n_vertices=None,
):
    """Check the feature screening method.

    Turns floats in the interval [0, 100) into SelectPercentile objects;
    a value of 100 (or None) means no screening.

    Parameters
    ----------
    screening_percentile : float in the interval [0, 100]
        Percentile value for :term:`ANOVA` univariate feature selection.
        A value of 100 means 'keep all features'.
        This percentile is expressed
        w.r.t. the volume of a standard (MNI152) brain, and so is corrected
        at runtime by premultiplying it with the ratio of the volume of the
        mask of the data to the volume of a standard brain.

    mask_img : nibabel image object
        Input image whose :term:`voxel` dimensions are to be computed.

    is_classification : bool
        If is_classification is True, it indicates that a classification task
        is performed. Otherwise, a regression task is performed.

    %(verbose0)s

    mesh_n_vertices : int, default=None
        Number of vertices of the reference mesh, e.g., fsaverage5,
        fsaverage7, etc. If provided, the screening percentile will be
        adjusted according to the number of vertices.

    Returns
    -------
    selector : SelectPercentile instance or None
        Used to perform the :term:`ANOVA` univariate feature selection;
        None when screening_percentile is 100 or None.
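
    Examples
    --------
    A minimal sketch of the pass-through case (no screening requested, so no
    mask image is needed; illustrative only):

    >>> check_feature_screening(100, None, is_classification=True) is None
    True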
300 """
301 f_test = f_classif if is_classification else f_regression
303 if screening_percentile == 100 or screening_percentile is None:
304 return None
305 elif not (0.0 <= screening_percentile <= 100.0):
306 raise ValueError(
307 "screening_percentile should be in the interval"
308 f" [0, 100], got {screening_percentile:g}"
309 )
310 else:
311 # correct screening_percentile according to the volume or the number of
312 # vertices in the data mask
313 screening_percentile_ = adjust_screening_percentile(
314 screening_percentile,
315 mask_img,
316 verbose=verbose,
317 mesh_n_vertices=mesh_n_vertices,
318 )
320 return SelectPercentile(f_test, percentile=int(screening_percentile_))


def check_run_sample_masks(n_runs, sample_masks):
    """Check that the number of sample_masks matches the number of runs.
    if not isinstance(sample_masks, (list, tuple, np.ndarray)):
        raise TypeError(
            f"sample_mask has an unhandled type: {sample_masks.__class__}"
        )

    if isinstance(sample_masks, np.ndarray):
        sample_masks = (sample_masks,)

    checked_sample_masks = [_convert_bool2index(sm) for sm in sample_masks]
    checked_sample_masks = [_cast_to_int32(sm) for sm in checked_sample_masks]

    if len(checked_sample_masks) != n_runs:
        raise ValueError(
            f"Number of sample_mask ({len(checked_sample_masks)}) not "
            f"matching number of runs ({n_runs})."
        )
    return checked_sample_masks


def _convert_bool2index(sample_mask):
    """Convert a boolean mask to an index array.
    check_boolean = [
        type(i) is bool or type(i) is np.bool_ for i in sample_mask
    ]
    if all(check_boolean):
        sample_mask = np.where(sample_mask)[0]
    return sample_mask


def _cast_to_int32(sample_mask):
    """Ensure the sample mask dtype is a signed 32-bit integer.
    new_dtype = np.int32
    if np.min(sample_mask) < 0:
        msg = "sample_mask should not contain negative values."
        raise ValueError(msg)

    # parenthesize the assignment expression so `highest` is bound to the
    # maximum value, not to the boolean result of the comparison
    if (highest := np.max(sample_mask)) > np.iinfo(new_dtype).max:
        msg = (
            "Max value in sample mask is larger than "
            f"what can be represented by int32: {highest}."
        )
        raise ValueError(msg)
    return np.asarray(sample_mask, new_dtype)


# dictionary that matches a given parameter / attribute name to a type
TYPE_MAPS = {
    "annotate": nilearn_typing.Annotate,
    "border_size": nilearn_typing.BorderSize,
    "bg_on_data": nilearn_typing.BgOnData,
    "colorbar": nilearn_typing.ColorBar,
    "connected": nilearn_typing.Connected,
    "data_dir": nilearn_typing.DataDir,
    "draw_cross": nilearn_typing.DrawCross,
    "detrend": nilearn_typing.Detrend,
    "high_pass": nilearn_typing.HighPass,
    "hrf_model": nilearn_typing.HrfModel,
    "keep_masked_labels": nilearn_typing.KeepMaskedLabels,
    "keep_masked_maps": nilearn_typing.KeepMaskedMaps,
    "low_pass": nilearn_typing.LowPass,
    "lower_cutoff": nilearn_typing.LowerCutoff,
    "memory": nilearn_typing.MemoryLike,
    "memory_level": nilearn_typing.MemoryLevel,
    "n_jobs": nilearn_typing.NJobs,
    "n_perm": nilearn_typing.NPerm,
    "opening": nilearn_typing.Opening,
    "radiological": nilearn_typing.Radiological,
    "random_state": nilearn_typing.RandomState,
    "resolution": nilearn_typing.Resolution,
    "resume": nilearn_typing.Resume,
    "smoothing_fwhm": nilearn_typing.SmoothingFwhm,
    "standardize_confounds": nilearn_typing.StandardizeConfounds,
    "t_r": nilearn_typing.Tr,
    "tfce": nilearn_typing.Tfce,
    "threshold": nilearn_typing.Threshold,
    "title": nilearn_typing.Title,
    "two_sided_test": nilearn_typing.TwoSidedTest,
    "target_affine": nilearn_typing.TargetAffine,
    "target_shape": nilearn_typing.TargetShape,
    "transparency": nilearn_typing.Transparency,
    "transparency_range": nilearn_typing.TransparencyRange,
    "url": nilearn_typing.Url,
    "upper_cutoff": nilearn_typing.UpperCutoff,
    "verbose": nilearn_typing.Verbose,
    "vmax": nilearn_typing.Vmax,
    "vmin": nilearn_typing.Vmin,
}


def check_params(fn_dict):
    """Check types of inputs passed to a function / method / class.

    This function checks the types of function / method parameters,
    or of the attributes of a class.

    This function is made to check the types of the parameters
    described in ``nilearn._utils.docs``
    that are shared by many functions / methods / classes
    and thus ensure a generic way to do input validation
    in several important points in the code base.

    In most cases this means that this function can be used
    on functions / classes that have the ``@fill_doc`` decorator,
    or whose docstring uses parameter templates
    (for example ``%(data_dir)s``).

    If the function cannot (yet) check any of the parameters / attributes,
    it will throw an error to say that its use is not needed.

    Typical usage:

    .. code-block:: python

        def some_function(param_1, param_2="a"):
            check_params(locals())
            ...


        class MyClass:
            def __init__(self, param_1, param_2="a"): ...

            def fit(self, X):
                # check attributes of the class instance
                check_params(self.__dict__)
                # check parameters passed to the method
                check_params(locals())
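
    Calling it on a dict that contains no known parameter names raises an
    error (a quick, illustrative sketch; message shown verbatim):

    >>> check_params({"some_unknown_param": 1})
    Traceback (most recent call last):
    ValueError: No known parameter to check.
    You probably do not need to use 'check_params' here.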
450 """
451 keys_to_check = set(TYPE_MAPS.keys()).intersection(set(fn_dict.keys()))
452 # Send a message to dev if they are using this function needlessly.
453 if not keys_to_check:
454 raise ValueError(
455 "No known parameter to check.\n"
456 "You probably do not need to use 'check_params' here."
457 )
459 for k in keys_to_check:
460 type_to_check = TYPE_MAPS[k]
461 value = fn_dict[k]
463 # TODO update when dropping python 3.9
464 error_msg = (
465 f"'{k}' should be of type '{type_to_check}'.\nGot: '{type(value)}'"
466 )
467 if sys.version_info[1] > 9:
468 if not isinstance(value, type_to_check):
469 raise TypeError(error_msg)
470 elif value is not None and not isinstance(value, type_to_check):
471 raise TypeError(error_msg)


def check_reduction_strategy(strategy: str):
    """Check that the provided strategy is supported.

    Parameters
    ----------
    %(strategy)s
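
    Examples
    --------
    A quick, illustrative sketch: a supported strategy passes silently,
    an unsupported one raises a ValueError.

    >>> check_reduction_strategy("mean")
    >>> check_reduction_strategy("average")  # doctest: +ELLIPSIS
    Traceback (most recent call last):
    ValueError: Invalid strategy 'average'. ...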
480 """
481 available_reduction_strategies = {
482 "mean",
483 "median",
484 "sum",
485 "minimum",
486 "maximum",
487 "standard_deviation",
488 "variance",
489 }
491 if strategy not in available_reduction_strategies:
492 raise ValueError(
493 f"Invalid strategy '{strategy}'. "
494 f"Valid strategies are {available_reduction_strategies}."
495 )