Coverage for nilearn/_utils/tests/test_param_validation.py: 0%
102 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
1"""Test the _utils.param_validation module."""
3import warnings
4from pathlib import Path
6import numpy as np
7import pytest
8from nibabel import Nifti1Image, load
9from scipy.stats import scoreatpercentile
10from sklearn.base import BaseEstimator
12from nilearn._utils.extmath import fast_abs_percentile
13from nilearn._utils.param_validation import (
14 MNI152_BRAIN_VOLUME,
15 _cast_to_int32,
16 _get_mask_extent,
17 check_feature_screening,
18 check_params,
19 check_threshold,
20)
# Location of the FSL standard 1 mm MNI152 brain mask; it is only present on
# machines with FSL installed, so tests using it must check for its existence.
mni152_brain_mask = (
    "/usr/share/fsl/data/standard/"
    "MNI152_T1_1mm_brain_mask.nii.gz"
)
@pytest.fixture
def matrix():
    """Return a 2x5 float array with negative, zero and positive entries.

    Its minimum is -6, its largest positive value is 5 and its largest
    absolute value is 6; the threshold tests below rely on these bounds.
    """
    rows = [
        [-3.0, 2.0, -1.0, 0.0, -4.0],
        [4.0, -6.0, 5.0, 1.0, -3.0],
    ]
    return np.array(rows)
def test_check_threshold_positive_and_zero_ts_true(matrix):
    """Tests nilearn._utils.param_validation.check_threshold when
    two_sided=True, threshold is specified as a number and threshold >=0.
    """
    # Numeric (non-string) thresholds are returned unchanged.
    for value in (0, 6):
        returned = check_threshold(
            value, matrix, scoreatpercentile, two_sided=True
        )
        assert returned == value

    # A numpy scalar must behave exactly like the equivalent Python float.
    from_float = check_threshold(
        2.0, matrix, scoreatpercentile, two_sided=True
    )
    from_numpy = check_threshold(
        np.float64(2.0), matrix, scoreatpercentile, two_sided=True
    )
    assert from_float == from_numpy

    # 6.5 is larger than every absolute value in ``matrix`` (max is 6),
    # so a UserWarning is expected.
    with pytest.warns(UserWarning):
        check_threshold(6.5, matrix, scoreatpercentile, two_sided=True)
def test_check_threshold_positive_and_zero_ts_false(matrix):
    """Tests nilearn._utils.param_validation.check_threshold when
    two_sided=False, threshold is specified as a number and threshold >=0.
    """
    # Numeric (non-string) thresholds are returned unchanged: 5 comes back
    # as 5.
    assert check_threshold(5, matrix, scoreatpercentile, two_sided=False) == 5

    # With two_sided=False negative values are not considered; the largest
    # positive value in ``matrix`` is 5, so a threshold of 6 should warn.
    with pytest.warns(UserWarning):
        check_threshold(6, matrix, scoreatpercentile, two_sided=False)
def test_check_threshold_percentile_positive_and_zero_ts_true(matrix):
    """Tests nilearn._utils.param_validation.check_threshold when
    two_sided=True, threshold is specified as percentile (str ending with a %)
    and threshold >=0.
    """
    # Each percentile string must resolve to a value strictly inside the
    # given (low, high) interval of the data.
    expected_ranges = (
        ("10%", 0, 1.0),
        ("40%", 2.0, 3.0),
        ("90%", 5.0, 6.0),
    )
    for percentile, low, high in expected_ranges:
        resolved = check_threshold(
            percentile, matrix, scoreatpercentile, two_sided=True
        )
        assert low < resolved < high
def test_check_threshold_percentile_positive_and_zero_ts_false(matrix):
    """Tests nilearn._utils.param_validation.check_threshold when
    two_sided=False, threshold is specified as percentile (str ending with a %)
    and threshold >=0.
    """
    # Each percentile string must resolve to a value strictly inside the
    # given (low, high) interval when only one side is considered.
    expected_ranges = (
        ("10%", 0, 1.0),
        ("40%", 1.0, 2.0),
        ("90%", 4.0, 5.0),
    )
    for percentile, low, high in expected_ranges:
        resolved = check_threshold(
            percentile, matrix, scoreatpercentile, two_sided=False
        )
        assert low < resolved < high
def test_check_threshold_negative_ts_false(matrix):
    """Tests nilearn._utils.param_validation.check_threshold when
    two_sided=False, threshold is specified as a number and threshold <=0.
    """
    # Numeric (non-string) thresholds are returned unchanged.
    assert check_threshold(0, matrix, scoreatpercentile, two_sided=False) == 0

    # A negative numeric threshold is returned unchanged too: -6 equals the
    # minimum of ``matrix``, so it is still within the data range.
    assert (
        check_threshold(-6, matrix, scoreatpercentile, two_sided=False) == -6
    )

    # -7 is lower than every value in ``matrix`` (minimum -6), so a
    # UserWarning is expected.
    with pytest.warns(UserWarning):
        check_threshold(-7, matrix, scoreatpercentile, two_sided=False)
def test_check_threshold_for_error(matrix):
    """Tests nilearn._utils.param_validation.check_threshold for errors."""
    name = "threshold"
    sides = (True, False)

    # Strings that are neither plain numbers nor well-formed "<number>%".
    for bad_string in ("0.1", "10", "10.2.3%", "asdf%"):
        for two_sided in sides:
            with pytest.raises(
                ValueError,
                match=f"{name}.+should be a number followed",
            ):
                check_threshold(
                    bad_string, matrix, fast_abs_percentile, name, two_sided
                )

    # Values that are neither numbers nor strings are a type error.
    not_a_threshold = object()
    for two_sided in sides:
        with pytest.raises(
            TypeError, match=f"{name}.+should be either a number or a string"
        ):
            check_threshold(
                not_a_threshold, matrix, fast_abs_percentile, name, two_sided
            )

    # Negative thresholds (number or percentile) are rejected when
    # two_sided=True ...
    negative_pattern = f"{name}.+should not be a negative"
    for negative in (-10, "-10%"):
        with pytest.raises(ValueError, match=negative_pattern):
            check_threshold(
                negative, matrix, fast_abs_percentile, name, True
            )

    # ... and a negative percentile is rejected with two_sided=False too.
    with pytest.raises(ValueError, match=negative_pattern):
        check_threshold(
            "-10%", matrix, fast_abs_percentile, name, two_sided=False
        )
def test_get_mask_extent():
    """Check the hard-coded MNI152 brain volume against FSL's standard mask.

    The reference mask file only exists where FSL is installed. When it is
    missing the test is skipped, so that it does not pass without having
    checked anything (previously it only emitted a warning and passed).
    """
    if not Path(mni152_brain_mask).is_file():
        pytest.skip(f"Couldn't find {mni152_brain_mask} (for testing)")

    # The hard-coded standard mask volume must match the computed extent.
    assert _get_mask_extent(load(mni152_brain_mask)) == MNI152_BRAIN_VOLUME
def test_feature_screening(affine_eye):
    """Check check_feature_screening for each kind of screening_percentile.

    For both classification and regression:
    - 100 and ``None`` must return ``None``;
    - the out-of-range values -1 and 101 must raise a ValueError;
    - any other listed percentile (20 and 10) must return an estimator.
      Previously 10 matched no branch and was silently never checked.
    """
    # Dummy mask: a box of ones inside a (182, 218, 182) grid.
    # NOTE(review): assumes the ``affine_eye`` fixture is an identity affine.
    mask_img_data = np.zeros((182, 218, 182))
    mask_img_data[30:-30, 30:-30, 30:-30] = 1
    mask_img = Nifti1Image(mask_img_data, affine=affine_eye)

    for is_classif in [True, False]:
        for screening_percentile in [100, None, 20, 101, -1, 10]:
            if screening_percentile == 100 or screening_percentile is None:
                assert (
                    check_feature_screening(
                        screening_percentile, mask_img, is_classif
                    )
                    is None
                )
            elif screening_percentile in {-1, 101}:
                with pytest.raises(ValueError):
                    check_feature_screening(
                        screening_percentile,
                        mask_img,
                        is_classif,
                    )
            else:
                # Valid percentiles (20 and 10) must yield an estimator.
                assert isinstance(
                    check_feature_screening(
                        screening_percentile, mask_img, is_classif
                    ),
                    BaseEstimator,
                )
@pytest.mark.parametrize("dtype", (np.uint8, np.uint16, np.uint32, np.int8))
def test_sample_mask_signed(dtype):
    """Check an integer sample_mask has a signed dtype after casting.

    Unsigned inputs must be converted to a signed integer dtype, and an
    already signed input (int8) must stay signed.
    """
    sample_mask = np.arange(2, dtype=dtype)
    cast = _cast_to_int32(sample_mask)
    assert cast.dtype.kind == "i"
def test_sample_mask_raises_on_negative():
    """Check that a sample_mask containing negative values is rejected."""
    expected_message = "sample_mask should not contain negative values"
    with pytest.raises(ValueError, match=expected_message):
        _cast_to_int32(np.array([-1, -2, 1]))
def test_sample_mask_raises_on_high_index():
    """Check that a sample_mask index too large for the cast is rejected."""
    # 2**66 does not fit in a 64-bit integer, let alone int32.
    huge_index = 2**66
    expected_message = "Max value in sample mask is larger than"
    with pytest.raises(ValueError, match=expected_message):
        _cast_to_int32(np.array(huge_index))
def test_check_params():
    """Check that passing incorrect type to a function raises TypeError."""

    def _accepts_data_dir(data_dir):
        # check_params must see only the function's parameters in locals().
        check_params(locals())
        return data_dir

    # A string is an accepted type for data_dir: no error expected.
    _accepts_data_dir(data_dir="foo")

    # An int is not an accepted type for data_dir.
    with pytest.raises(TypeError, match="'data_dir' should be of type"):
        _accepts_data_dir(data_dir=1)
def test_check_params_not_necessary():
    """Check check_params complains when none of the parameters is known."""

    def _only_unknown_param(foo):
        # ``foo`` is not a parameter check_params knows how to validate.
        check_params(locals())
        return foo

    with pytest.raises(ValueError, match="No known parameter to check."):
        _only_unknown_param(foo=1)