Coverage for nilearn/regions/tests/test_region_extractor.py: 0%
242 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1"""Test Region Extractor and its functions."""
3import numpy as np
4import pytest
5from nibabel import Nifti1Image
6from scipy.ndimage import label
8from nilearn._utils.data_gen import generate_labeled_regions, generate_maps
9from nilearn._utils.estimator_checks import (
10 check_estimator,
11 nilearn_check_estimator,
12)
13from nilearn._utils.exceptions import DimensionError
14from nilearn._utils.tags import SKLEARN_LT_1_6
15from nilearn.conftest import _affine_eye, _img_4d_zeros, _shape_3d_large
16from nilearn.image import get_data
17from nilearn.regions import (
18 RegionExtractor,
19 connected_label_regions,
20 connected_regions,
21)
22from nilearn.regions.region_extractor import (
23 _remove_small_regions,
24 _threshold_maps_ratio,
25)
@pytest.fixture
def negative_regions():
    """Return the default ``negative_regions`` flag used by ``maps``.

    Tests may override it with ``pytest.mark.parametrize``.
    """
    generate_negative = False
    return generate_negative
@pytest.fixture
def dummy_map(shape_3d_default, n_regions):
    """Return a small generated maps image for the error-handling tests.

    Only the maps image (first output of ``generate_maps``) is kept.
    """
    maps_img, _ = generate_maps(shape=shape_3d_default, n_regions=n_regions)
    return maps_img
@pytest.fixture
def map_img_3d(rng, affine_eye, shape_3d_default):
    """Return a 3D image of low-amplitude (0.1 * standard normal) noise."""
    background = np.zeros(shape_3d_default)
    noise = 0.1 * rng.standard_normal(size=shape_3d_default)
    return Nifti1Image(background + noise, affine=affine_eye)
# Lower bound on the number of regions several tests below expect to be
# extracted from the generated maps.
N_REGIONS = 3
@pytest.fixture
def maps(negative_regions, n_regions, shape_3d_large):
    """Return a reproducible maps image (``random_state=42``).

    ``negative_regions`` is forwarded to ``generate_maps``; it is a fixture
    so that individual tests can parametrize it.
    """
    maps_img, _ = generate_maps(
        shape=shape_3d_large,
        n_regions=n_regions,
        random_state=42,
        negative_regions=negative_regions,
    )
    return maps_img
@pytest.fixture
def maps_and_mask(n_regions, shape_3d_large):
    """Return the full ``generate_maps`` output: ``(maps_img, mask_img)``."""
    generated = generate_maps(
        shape=shape_3d_large,
        n_regions=n_regions,
        random_state=42,
    )
    return generated
# Estimators run through the scikit-learn compliance checks below.
ESTIMATORS_TO_CHECK = [RegionExtractor()]

if not SKLEARN_LT_1_6:
    # Let scikit-learn parametrize the checks itself; known failures are
    # reported through the ``expected_failed_checks`` callback.
    from sklearn.utils.estimator_checks import parametrize_with_checks

    from nilearn._utils.estimator_checks import (
        return_expected_failed_checks,
    )

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Check compliance with sklearn estimators."""
        check(estimator)

else:
    # Older code path: nilearn's own ``check_estimator`` splits the checks
    # into those expected to pass and those expected to fail.

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Run the sklearn checks that are expected to pass."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Run the sklearn checks that are expected to fail (xfail)."""
        check(estimator)
@pytest.mark.timeout(0)
@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(
        estimators=[
            RegionExtractor(
                maps_img=generate_maps(
                    shape=_shape_3d_large(),
                    n_regions=2,
                    random_state=42,
                    affine=_affine_eye(),
                )[0],
            ),
        ],
    ),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Run nilearn's estimator checks on a RegionExtractor whose
    ``maps_img`` holds two generated maps.
    """
    check(estimator)
@pytest.mark.parametrize("invalid_threshold", ["80%", "auto", -1.0])
def test_invalid_thresholds_in_threshold_maps_ratio(
    dummy_map, invalid_threshold
):
    """Non-numeric or negative ratio thresholds raise a ValueError."""
    n_maps = dummy_map.shape[-1]
    expected_message = (
        "threshold given as ratio to the number of voxels must "
        "be Real number and should be positive "
        "and between 0 and total number of maps "
        f"i.e. n_maps={n_maps}. "
        f"You provided {invalid_threshold}"
    )
    with pytest.raises(ValueError, match=expected_message):
        _threshold_maps_ratio(maps_img=dummy_map, threshold=invalid_threshold)
def test_nans_threshold_maps_ratio(maps, affine_eye):
    """Check _threshold_maps_ratio handles maps containing NaNs.

    The thresholded image must keep the shape of its input.
    """
    data = get_data(maps)
    # Fill the whole first slice along the third axis with NaNs.
    data[:, :, 0] = np.nan

    maps_img = Nifti1Image(data, affine_eye)
    thr_maps = _threshold_maps_ratio(maps_img, threshold=0.8)

    assert thr_maps.shape == maps_img.shape
def test_threshold_maps_ratio(maps):
    """Check _threshold_maps_ratio with randomly generated maps."""
    # Write into the maps data, then snapshot it, so that any side effect
    # of thresholding on the input image is detected.
    get_data(maps)[:3] = 100
    snapshot = get_data(maps).copy()

    thresholded = _threshold_maps_ratio(maps, threshold=1.0)

    np.testing.assert_array_equal(get_data(maps), snapshot)
    # The number of maps (last dimension) is preserved by thresholding.
    assert thresholded.shape[-1] == maps.shape[-1]
def test_threshold_maps_ratio_3d(map_img_3d):
    """Thresholding a 3D image leaves its shape unchanged."""
    thresholded = _threshold_maps_ratio(map_img_3d, threshold=0.5)
    assert thresholded.shape == map_img_3d.shape
@pytest.mark.parametrize("invalid_extract_type", ["spam", 1])
def test_invalids_extract_types_in_connected_regions(
    dummy_map, invalid_extract_type
):
    """Unsupported ``extract_type`` values raise a ValueError."""
    allowed_types = ["connected_components", "local_regions"]
    # NOTE(review): ``match`` is a regular expression, so the bracketed list
    # repr below acts as a character class rather than literal text.
    with pytest.raises(
        ValueError, match=f"'extract_type' should be {allowed_types}"
    ):
        connected_regions(dummy_map, extract_type=invalid_extract_type)
@pytest.mark.parametrize(
    "extract_type", ["connected_components", "local_regions"]
)
def test_connected_regions_4d(maps, extract_type):
    """Regions extracted should be equal or more than already present.

    Every extracted region must also be traced back to the index of the
    input map it came from.
    """
    connected_extraction_img, index = connected_regions(
        maps, min_region_size=10, extract_type=extract_type
    )
    n_extracted = connected_extraction_img.shape[-1]

    assert n_extracted >= N_REGIONS
    # One source-map index per extracted region, each pointing to one of
    # the input maps.
    assert len(index) == n_extracted
    assert all(0 <= i < maps.shape[-1] for i in index)
@pytest.mark.parametrize(
    "extract_type", ["connected_components", "local_regions"]
)
def test_connected_regions_3d(map_img_3d, extract_type):
    """A 3D input image yields at least one extracted region."""
    regions_img, _ = connected_regions(
        maps_img=map_img_3d,
        min_region_size=10,
        extract_type=extract_type,
    )
    n_found = regions_img.shape[-1]
    assert n_found >= 1
def test_connected_regions_different_results_with_different_mask_images(
    maps_and_mask,
):
    """``mask_img`` changes the output of ``connected_regions``, including
    when the mask has a different shape and affine than the maps.
    """
    maps, mask_img = maps_and_mask

    # Exclude one voxel from the mask (edited through get_data on mask_img).
    mask = get_data(mask_img)
    mask[1, 1, 1] = 0

    masked_extraction, _ = connected_regions(maps, mask_img=mask_img)
    assert masked_extraction.shape[-1] >= 1

    unmasked_extraction, _ = connected_regions(maps)

    outside_mask = mask == 0
    assert np.all(get_data(masked_extraction)[outside_mask] == 0.0)
    assert not np.all(get_data(unmasked_extraction)[outside_mask] == 0.0)

    # Mask with another shape and affine: the output keeps the spatial
    # shape of the maps, not of the mask.
    other_mask = np.zeros(shape=(10, 11, 12), dtype="uint8")
    other_mask[1:-1, 1:-1, 1:-1] = 1
    # NOTE(review): the last diagonal entry of this affine is 2.0 rather
    # than the usual 1.0 — confirm this is intended.
    other_affine = np.diag([2.0, 2.0, 2.0, 2.0])
    other_mask_img = Nifti1Image(other_mask, affine=other_affine)

    other_fov_masked, _ = connected_regions(maps, mask_img=other_mask_img)
    assert maps.shape[:3] == other_fov_masked.shape[:3]
    assert other_mask_img.shape != other_fov_masked.shape[:3]

    other_fov_unmasked, _ = connected_regions(maps)
    # The unmasked extraction holds more zero-valued voxels than the
    # masked one.
    n_zeros_unmasked = np.sum(get_data(other_fov_unmasked) == 0)
    n_zeros_masked = np.sum(get_data(other_fov_masked) == 0)
    assert n_zeros_unmasked > n_zeros_masked
def test_invalid_threshold_strategies(dummy_map):
    """An unknown ``thresholding_strategy`` is rejected when fitting."""
    extractor = RegionExtractor(dummy_map, thresholding_strategy="n_")
    expected_message = "'thresholding_strategy' should be "
    with pytest.raises(ValueError, match=expected_message):
        extractor.fit()
@pytest.mark.parametrize("threshold", [None, "30%"])
def test_threshold_as_none_and_string_cases(dummy_map, threshold):
    """``None`` and percentage strings are rejected as thresholds."""
    extractor = RegionExtractor(dummy_map, threshold=threshold)
    expected_message = "The given input to threshold is not valid."
    with pytest.raises(ValueError, match=expected_message):
        extractor.fit()
def test_region_extractor_fit_and_transform(maps_and_mask):
    """Fitted regions are zero outside the mask only when a mask is given.

    Only ``fit`` is exercised here.
    """
    maps, mask_img = maps_and_mask

    # Exclude one voxel from the mask (edited through get_data on mask_img).
    mask_data = get_data(mask_img)
    mask_data[1, 1, 1] = 0

    unmasked = RegionExtractor(maps)
    unmasked.fit()
    masked = RegionExtractor(maps, mask_img=mask_img)
    masked.fit()

    excluded = mask_data == 0
    assert not np.all(get_data(unmasked.regions_img_)[excluded] == 0.0)
    assert np.all(get_data(masked.regions_img_)[excluded] == 0.0)
def test_region_extractor_strategy_ratio_n_voxels(maps):
    """Fit with the ``ratio_n_voxels`` strategy and check the regions."""
    extract_ratio = RegionExtractor(
        maps, threshold=0.2, thresholding_strategy="ratio_n_voxels"
    )
    extract_ratio.fit()

    # An image object compared to "" is always unequal, so check the type
    # of the fitted attribute instead.
    assert isinstance(extract_ratio.regions_img_, Nifti1Image)
    assert extract_ratio.regions_img_.shape[-1] >= N_REGIONS
@pytest.mark.parametrize("negative_regions", [True])
def test_region_extractor_two_sided(maps):
    """On maps with negative regions, ``two_sided`` changes the values
    found in the extracted regions.
    """
    shared_params = {
        "threshold": 0.4,
        "thresholding_strategy": "img_value",
        "min_region_size": 5,
        "extractor": "connected_components",
    }

    one_sided = RegionExtractor(maps, two_sided=False, **shared_params)
    one_sided.fit()

    two_sided = RegionExtractor(maps, two_sided=True, **shared_params)
    two_sided.fit()

    one_sided_values = np.unique(one_sided.regions_img_.get_fdata())
    two_sided_values = np.unique(two_sided.regions_img_.get_fdata())
    assert not np.array_equal(one_sided_values, two_sided_values)
def test_region_extractor_strategy_percentile(maps_and_mask):
    """Fit with the ``percentile`` strategy, then transform 4D images."""
    maps, mask_img = maps_and_mask

    extractor = RegionExtractor(
        maps,
        threshold=30,
        thresholding_strategy="percentile",
        mask_img=mask_img,
        two_sided=True,
    )
    extractor.fit()

    regions_img = extractor.regions_img_
    n_regions_extracted = regions_img.shape[-1]

    assert isinstance(regions_img, Nifti1Image)
    assert n_regions_extracted >= N_REGIONS
    # One source-map index is recorded for every extracted region.
    assert len(extractor.index_) == n_regions_extracted

    n_timepoints = 7
    shape = (91, 109, 91, n_timepoints)
    expected_signal_shape = (n_timepoints, n_regions_extracted)
    n_subjects = 3
    for _ in range(n_subjects):
        # smoke test NiftiMapsMasker transform inherited in Region Extractor
        signal = extractor.transform(_img_4d_zeros(shape=shape))

        assert expected_signal_shape == signal.shape
def test_region_extractor_high_resolution_image(
    affine_eye, n_regions, shape_3d_large
):
    """Extraction on maps whose affine is scaled by 0.2 keeps all regions."""
    maps, _ = generate_maps(
        shape=shape_3d_large, n_regions=n_regions, affine=0.2 * affine_eye
    )

    extract_ratio = RegionExtractor(
        maps,
        thresholding_strategy="ratio_n_voxels",
        smoothing_fwhm=0.6,
        min_region_size=0.4,
    )
    extract_ratio.fit()

    # Check the fitted attribute's type; comparing an image to "" is
    # always unequal and tests nothing.
    assert isinstance(extract_ratio.regions_img_, Nifti1Image)
    assert extract_ratio.regions_img_.shape[-1] >= n_regions
def test_region_extractor_zeros_affine_diagonal(affine_eye, n_regions):
    """Extraction works when the affine has zeros on its diagonal."""
    # Work on a copy so the fixture array is never mutated in place.
    affine = affine_eye.copy()
    affine[[0, 1]] = affine[[1, 0]]  # permutes first and second lines
    maps, _ = generate_maps(
        shape=[40, 40, 40], n_regions=n_regions, affine=affine, random_state=42
    )

    extract_ratio = RegionExtractor(
        maps, threshold=0.2, thresholding_strategy="ratio_n_voxels"
    )
    extract_ratio.fit()

    assert isinstance(extract_ratio.regions_img_, Nifti1Image)
    assert extract_ratio.regions_img_.shape[-1] >= n_regions
def test_error_messages_connected_label_regions(img_labels):
    """Invalid ``min_size`` and ``connect_diag`` raise a ValueError."""
    invalid_cases = [
        (
            {"min_size": "a"},
            "Expected 'min_size' to be specified as integer.",
        ),
        (
            {"connect_diag": None},
            "'connect_diag' must be specified as True or False.",
        ),
    ]
    for bad_kwargs, expected_message in invalid_cases:
        with pytest.raises(ValueError, match=expected_message):
            connected_label_regions(labels_img=img_labels, **bad_kwargs)
def test_remove_small_regions(affine_eye):
    """Regions smaller than ``min_size`` are dropped from a label map."""
    data = np.array(
        [
            [[0.0, 1.0, 0.0], [0.0, 1.0, 1.0], [0.0, 0.0, 0.0]],
            [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
            [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 1.0]],
        ]
    )
    # _remove_small_regions expects labeled data, so first label the
    # connected components of the non-zero voxels.
    label_map, _ = label(data)
    total_before = np.sum(label_map)

    min_size = 10
    cleaned = _remove_small_regions(label_map, affine_eye, min_size)
    total_after = np.sum(cleaned)

    assert total_after < total_before
def test_connected_label_regions(img_labels):
    """Compare label counts before and after region extraction."""
    n_input_labels = np.unique(get_data(img_labels)).size

    # Without min_size, more labels are found than in the input.
    without_min = connected_label_regions(img_labels)
    n_labels_without_min = np.unique(get_data(without_min)).size
    assert n_input_labels < n_labels_without_min

    # With min_size=100, fewer labels remain than without min_size.
    with_min = connected_label_regions(img_labels, min_size=100)
    n_labels_with_min = np.unique(get_data(with_min)).size
    assert n_labels_without_min > n_labels_with_min
def test_connected_label_regions_connect_diag_false(img_labels):
    """With ``connect_diag=False`` more labels are found than in the input."""
    n_input_labels = np.unique(get_data(img_labels)).size

    extracted_img = connected_label_regions(img_labels, connect_diag=False)
    n_extracted_labels = np.unique(get_data(extracted_img)).size

    assert n_extracted_labels > n_input_labels
def test_connected_label_regions_return_empty_for_large_min_size(img_labels):
    """If min_size is large and if all the regions are removed \
    then empty image will be returned.
    """
    extract_reg_min_size_large = connected_label_regions(
        img_labels, min_size=500
    )

    # Compare every voxel rather than relying on the truthiness of the
    # ``np.unique(...) == 0`` array, which is ambiguous (ValueError) when
    # more than one unique value remains.
    assert np.all(get_data(extract_reg_min_size_large) == 0)
def test_connected_label_regions_check_labels(img_labels):
    """Test the names of the brain regions given in labels."""
    # Test labels for 9 regions in n_regions
    labels = [f"region_{x}" for x in "abcdefghi"]

    # When labels are provided, the first output is the extracted labels
    # image and the second is the list of names generated for the new
    # regions (for example, a region split across both hemispheres keeps
    # the same name in both parts).
    _, new_labels = connected_label_regions(
        img_labels, min_size=100, labels=labels
    )
    # The number of new labels depends on min_size: a large min_size removes
    # small regions (fewer labels), a small one can split regions (more
    # labels). The result may therefore be empty, so check its type rather
    # than comparing it to "" (a list never equals a string).
    assert isinstance(new_labels, list)
    assert len(new_labels) <= len(labels)
def test_connected_label_regions_check_labels_as_numpy_array(img_labels):
    """Test the names of the brain regions given in labels."""
    # labels given as a numpy array, for the 9 regions in n_regions
    labels = np.asarray([f"region_{x}" for x in "abcdefghi"])
    _, new_labels2 = connected_label_regions(img_labels, labels=labels)

    # By default min_size is small, so there can be more new labels than
    # provided ones; this also guarantees the returned labels are not empty.
    assert len(new_labels2) >= len(labels)

    # Providing fewer labels than there are unique (non-zero) labels in
    # img_labels must raise an error.
    unique_labels = set(np.unique(np.asarray(get_data(img_labels))))
    unique_labels.remove(0)

    # labels given are less than n_regions=9
    provided_labels = [f"region_{x}" for x in "acfghi"]

    assert len(provided_labels) < len(unique_labels)

    with pytest.raises(ValueError):
        connected_label_regions(img_labels, labels=provided_labels)
def test_connected_label_regions_unknonw_labels(
    img_labels,  # noqa: ARG001
    affine_eye,
    shape_3d_default,
):
    """If unknown/negative integers are provided as labels in img_labels, \
    we raise an error and test the same whether error is raised.

    Introduce data type of float

    See issue: https://github.com/nilearn/nilearn/issues/2580
    """
    # Build a float labels volume split into octants, one of which holds a
    # negative label and two of which hold non-finite values.
    labels_data = np.zeros(shape_3d_default, dtype=np.float32)
    h0, h1, h2 = (x // 2 for x in shape_3d_default)
    labels_data[:h0, :h1, :h2] = 1
    labels_data[:h0, :h1, h2:] = 2
    labels_data[:h0, h1:, :h2] = 3
    labels_data[:h0, h1:, h2:] = -4
    labels_data[h0:, :h1, :h2] = 5
    labels_data[h0:, :h1, h2:] = 6
    labels_data[h0:, h1:, :h2] = np.nan
    labels_data[h0:, h1:, h2:] = np.inf

    neg_labels_img = Nifti1Image(labels_data, affine_eye)

    with pytest.raises(ValueError):
        connected_label_regions(labels_img=neg_labels_img)

    # If labels_img provided is 4D Nifti image, then test whether error is
    # raised or not. Since this function accepts only 3D image.
    labels_4d_data = np.zeros((*shape_3d_default, 2))
    labels_data[h0:, h1:, :h2] = 0
    labels_data[h0:, h1:, h2:] = 0
    labels_4d_data[..., 0] = labels_data
    labels_4d_data[..., 1] = labels_data
    labels_img_4d = Nifti1Image(labels_4d_data, affine_eye)

    with pytest.raises(DimensionError):
        connected_label_regions(labels_img=labels_img_4d)
def test_connected_label_regions_check_labels_string_without_list(
    img_labels, affine_eye, shape_3d_default
):
    """A single string given as labels is returned as a list of names.

    A mixed list of names is also accepted, and at least as many names as
    provided are returned.
    """
    single_region_img = generate_labeled_regions(
        shape=shape_3d_default, affine=affine_eye, n_regions=1
    )
    _, names_from_str = connected_label_regions(
        single_region_img, labels="region_a"
    )
    assert isinstance(names_from_str, list)

    combined_labels = [
        "region_a",
        "1",
        "region_b",
        "2",
        "region_c",
        "3",
        "region_d",
        "4",
        "region_e",
    ]
    _, names_from_list = connected_label_regions(
        img_labels, labels=combined_labels
    )
    assert len(names_from_list) >= len(combined_labels)