Coverage for nilearn/maskers/tests/test_surface_masker.py: 0%
64 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1import numpy as np
2import pytest
3from sklearn.utils.estimator_checks import parametrize_with_checks
5from nilearn._utils.estimator_checks import (
6 check_estimator,
7 nilearn_check_estimator,
8 return_expected_failed_checks,
9)
10from nilearn._utils.tags import SKLEARN_LT_1_6
11from nilearn.maskers import SurfaceMasker
12from nilearn.surface import SurfaceImage
13from nilearn.surface.utils import (
14 assert_polydata_equal,
15 assert_surface_image_equal,
16)
ESTIMATORS_TO_CHECK = [SurfaceMasker()]

# The sklearn estimator-check machinery changed in scikit-learn 1.6:
# newer versions accept an ``expected_failed_checks`` callback, older
# ones need nilearn's own split between valid and invalid checks.
if not SKLEARN_LT_1_6:

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Run scikit-learn's estimator checks (scikit-learn >= 1.6)."""
        check(estimator)

else:

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Run the scikit-learn checks expected to pass (scikit-learn < 1.6)."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Run the scikit-learn checks known not to pass; marked xfail."""
        check(estimator)
@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(estimators=ESTIMATORS_TO_CHECK),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Check compliance with nilearn-specific estimator checks.

    Unlike the sklearn tests in this module, the checks come from
    ``nilearn_check_estimator``.
    """
    check(estimator)
def test_fit_list_surf_images(surf_img_2d):
    """Test fit on list of surface images.

    - resulting mask should have a single 'timepoint'
    - all vertices should be included in the mask, because no mask is provided
    """
    masker = SurfaceMasker()
    masker.fit([surf_img_2d(3), surf_img_2d(5)])

    n_vertices = surf_img_2d(1).shape[0]
    # a single 'timepoint': the mask is 1D with one entry per vertex
    assert masker.mask_img_.shape == (n_vertices,)
    assert masker.mask_img_.shape == (masker.n_elements_,)
    # no mask_img was given, so every vertex of every part must be kept
    assert all(np.all(part) for part in masker.mask_img_.data.parts.values())
def test_fit_list_surf_images_with_mask(surf_mask_1d, surf_img_2d):
    """Fitting on a list of images with an explicit mask keeps a 1D mask.

    The fitted ``mask_img_`` must have one entry per vertex of the
    input images.
    """
    imgs = [surf_img_2d(3), surf_img_2d(5)]
    masker = SurfaceMasker(mask_img=surf_mask_1d)
    masker.fit(imgs)

    expected_shape = (surf_img_2d(1).shape[0],)
    assert masker.mask_img_.shape == expected_shape
@pytest.mark.parametrize("n_timepoints", [3])
def test_transform_inverse_transform_no_mask(surf_mesh, n_timepoints):
    """Round-trip an image through transform / inverse_transform unmasked.

    Without a mask no vertex is dropped. The first sample of the signals
    therefore holds 1-4 for the left part and 10-50 for the right part,
    and inverse_transform must give back the input data unchanged.
    """
    img_data = {}
    for part_idx, (part_name, part_mesh) in enumerate(surf_mesh.parts.items()):
        n_vertices = part_mesh.n_vertices
        # row t holds sample t: values 1..n_vertices offset per sample,
        # scaled by 10 for every subsequent part
        samples = (
            np.arange(n_vertices * n_timepoints, dtype=float).reshape(
                n_timepoints, n_vertices
            )
            + 1.0
        )
        img_data[part_name] = (samples * 10**part_idx).T

    img = SurfaceImage(surf_mesh, img_data)
    masker = SurfaceMasker().fit(img)
    signals = masker.transform(img)

    # nothing may be masked out when no mask is provided
    expected_first_sample = [1, 2, 3, 4, 10, 20, 30, 40, 50]
    assert np.array_equal(signals[0], expected_first_sample)

    round_tripped = masker.inverse_transform(signals)
    assert_polydata_equal(img.data, round_tripped.data)
@pytest.mark.parametrize("n_timepoints", [1, 3])
def test_transform_inverse_transform_with_mask(surf_mesh, n_timepoints):
    """Check output of transform and inverse transform when using a mask.

    The mask drops the first vertex of each part. Transform must return
    only the 7 kept vertices. Inverse transform must restore the input,
    with the masked-out vertices set to 0.
    """
    # make a sample image with data on the first timepoint/sample 1-4 on
    # left part and 10-50 on right part
    img_data = {}
    for i, (key, val) in enumerate(surf_mesh.parts.items()):
        data_shape = (val.n_vertices, n_timepoints)
        data_part = (
            np.arange(np.prod(data_shape)).reshape(data_shape[::-1]) + 1.0
        ) * 10**i
        img_data[key] = data_part.T
    img = SurfaceImage(surf_mesh, img_data)

    # make a mask that removes first vertex of each part
    # total 2 removed
    mask_data = {
        "left": np.asarray([False, True, True, True]),
        "right": np.asarray([False, True, True, True, True]),
    }
    mask = SurfaceImage(surf_mesh, mask_data)

    masker = SurfaceMasker(mask).fit(img)
    signals = masker.transform(img)

    # exactly 9 - 2 = 7 vertices must survive the mask
    assert signals.shape[-1] == 7
    # the first sample over the 7 kept vertices is as expected
    assert np.array_equal(signals.ravel()[:7], [2, 3, 4, 20, 30, 40, 50])

    # check whether inverse transform does not change the img
    unmasked_img = masker.inverse_transform(signals)
    # recreate data that we expect after unmasking: masked-out vertices
    # (row 0 of each part) are zero-filled across all samples
    expected_data = {k: v.copy() for (k, v) in img.data.parts.items()}
    for v in expected_data.values():
        v[0] = 0.0
    expected_img = SurfaceImage(img.mesh, expected_data)
    assert_surface_image_equal(unmasked_img, expected_img)