Coverage for nilearn/maskers/tests/test_surface_masker.py: 0%

64 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1import numpy as np 

2import pytest 

3from sklearn.utils.estimator_checks import parametrize_with_checks 

4 

5from nilearn._utils.estimator_checks import ( 

6 check_estimator, 

7 nilearn_check_estimator, 

8 return_expected_failed_checks, 

9) 

10from nilearn._utils.tags import SKLEARN_LT_1_6 

11from nilearn.maskers import SurfaceMasker 

12from nilearn.surface import SurfaceImage 

13from nilearn.surface.utils import ( 

14 assert_polydata_equal, 

15 assert_surface_image_equal, 

16) 

17 

ESTIMATORS_TO_CHECK = [SurfaceMasker()]

# The scikit-learn < 1.6 path uses nilearn's own ``check_estimator`` helper,
# split into checks expected to pass (``valid``) and checks expected to fail.
# Newer versions use sklearn's ``parametrize_with_checks`` with an
# ``expected_failed_checks`` callback instead (that keyword is only
# available from scikit-learn 1.6 per its documentation).
if SKLEARN_LT_1_6:

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Check compliance with sklearn estimators."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Run the sklearn checks that the estimators are expected to fail.

        These are the checks selected with ``valid=False``; the test is
        marked ``xfail`` so that each of them is expected to fail.
        """
        check(estimator)

else:

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Check compliance with sklearn estimators."""
        check(estimator)

48 

49 

@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(estimators=ESTIMATORS_TO_CHECK),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Check compliance with nilearn-specific estimator checks.

    ``name`` is unused in the body; it is part of the parametrization
    produced by ``nilearn_check_estimator``.
    """
    check(estimator)

57 

58 

def test_fit_list_surf_images(surf_img_2d):
    """Fit on a list of surface images without providing a mask.

    Expected outcome:

    - the fitted mask is one-dimensional (one entry per vertex, i.e. a
      single 'timepoint');
    - every vertex is part of the mask, because no mask was given.

    """
    masker = SurfaceMasker()
    masker.fit([surf_img_2d(3), surf_img_2d(5)])

    n_vertices = surf_img_2d(1).shape[0]
    fitted_shape = masker.mask_img_.shape
    # one mask value per vertex ...
    assert fitted_shape == (n_vertices,)
    # ... and all of them are retained
    assert fitted_shape == (masker.n_elements_,)

70 

71 

def test_fit_list_surf_images_with_mask(surf_mask_1d, surf_img_2d):
    """Fit on a list of surface images when the masker is given a mask.

    The fitted mask must hold exactly one value per mesh vertex.
    """
    masker = SurfaceMasker(mask_img=surf_mask_1d)
    masker.fit([surf_img_2d(3), surf_img_2d(5)])

    n_vertices = surf_img_2d(1).shape[0]
    assert masker.mask_img_.shape == (n_vertices,)

77 

78 

@pytest.mark.parametrize("n_timepoints", [1, 3])
def test_transform_inverse_transform_no_mask(surf_mesh, n_timepoints):
    """Check output of transform and inverse transform when not using a mask.

    Without a mask every vertex is kept, so ``transform`` must return all
    values and ``inverse_transform`` must give back the input data unchanged.
    """
    # make a sample image with data on the first timepoint/sample 1-4 on
    # left part and 10-50 on right part: each part is numbered 1, 2, ...
    # in (timepoint, vertex) row-major order, then scaled by 10**i
    img_data = {}
    for i, (key, val) in enumerate(surf_mesh.parts.items()):
        data_shape = (val.n_vertices, n_timepoints)
        data_part = (
            np.arange(np.prod(data_shape)).reshape(data_shape[::-1]) + 1.0
        ) * 10**i
        # stored as (n_vertices, n_timepoints)
        img_data[key] = data_part.T

    img = SurfaceImage(surf_mesh, img_data)
    masker = SurfaceMasker().fit(img)
    signals = masker.transform(img)

    # make sure none of the data has been removed: all 9 vertices
    # (4 left + 5 right) are present along the last (feature) axis ...
    assert signals.shape[-1] == 9
    # ... and the first timepoint holds every vertex value in order.
    # ravel() keeps this valid with or without a leading sample axis.
    expected_first_sample = [1, 2, 3, 4, 10, 20, 30, 40, 50]
    assert np.array_equal(signals.ravel()[:9], expected_first_sample)

    unmasked_img = masker.inverse_transform(signals)
    assert_polydata_equal(img.data, unmasked_img.data)

100 

101 

@pytest.mark.parametrize("n_timepoints", [1, 3])
def test_transform_inverse_transform_with_mask(surf_mesh, n_timepoints):
    """Check output of transform and inverse transform when using a mask.

    Vertices outside the mask are dropped by ``transform`` and come back as
    0 after ``inverse_transform``; every other value round-trips unchanged.
    """
    # make a sample image with data on the first timepoint/sample 1-4 on
    # left part and 10-50 on right part
    img_data = {}
    for i, (key, val) in enumerate(surf_mesh.parts.items()):
        data_shape = (val.n_vertices, n_timepoints)
        data_part = (
            np.arange(np.prod(data_shape)).reshape(data_shape[::-1]) + 1.0
        ) * 10**i
        img_data[key] = data_part.T
    img = SurfaceImage(surf_mesh, img_data)

    # make a mask that removes the first vertex of each part
    # (total: 2 removed, 7 kept)
    mask_data = {
        "left": np.asarray([False, True, True, True]),
        "right": np.asarray([False, True, True, True, True]),
    }
    mask = SurfaceImage(surf_mesh, mask_data)

    masker = SurfaceMasker(mask).fit(img)
    signals = masker.transform(img)

    # exactly the 7 in-mask vertices remain along the last (feature) axis
    assert signals.shape[-1] == 7
    # the first timepoint holds the values of the kept vertices, in order
    assert np.array_equal(signals.ravel()[:7], [2, 3, 4, 20, 30, 40, 50])

    # check that inverse transform restores the image, with the masked-out
    # vertices set to 0
    unmasked_img = masker.inverse_transform(signals)
    # recreate data that we expect after unmasking: zero row 0 (the first
    # vertex, across all timepoints) of each part
    expected_data = {k: v.copy() for (k, v) in img.data.parts.items()}
    for v in expected_data.values():
        v[0] = 0.0
    expected_img = SurfaceImage(img.mesh, expected_data)
    assert_surface_image_equal(unmasked_img, expected_img)