Coverage for nilearn/maskers/tests/test_surface_maps_masker.py: 0%

59 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1import numpy as np 

2import pytest 

3from sklearn.utils.estimator_checks import parametrize_with_checks 

4 

5from nilearn._utils.estimator_checks import ( 

6 check_estimator, 

7 nilearn_check_estimator, 

8 return_expected_failed_checks, 

9) 

10from nilearn._utils.tags import SKLEARN_LT_1_6 

11from nilearn.conftest import _surf_maps_img 

12from nilearn.maskers import SurfaceMapsMasker 

13from nilearn.surface import SurfaceImage 

14 

# One SurfaceMapsMasker, built from the shared synthetic maps image, is fed
# to every estimator check defined below.
ESTIMATORS_TO_CHECK = [SurfaceMapsMasker(_surf_maps_img())]

if not SKLEARN_LT_1_6:
    # scikit-learn >= 1.6: ``parametrize_with_checks`` generates the checks;
    # checks known to fail are supplied by ``return_expected_failed_checks``.
    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Run one scikit-learn estimator check on the masker."""
        check(estimator)

else:
    # scikit-learn < 1.6: nilearn's ``check_estimator`` helper returns the
    # checks expected to pass (default) or, with ``valid=False``, those
    # expected to fail.
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Run one scikit-learn check that the masker should pass."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Run one scikit-learn check the masker is expected to fail."""
        check(estimator)

45 

46 

@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(estimators=ESTIMATORS_TO_CHECK),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Check compliance with nilearn-specific estimator checks.

    Each parametrized case is one check produced by
    ``nilearn_check_estimator``; ``name`` is not used in the body (hence the
    ``noqa``).
    """
    check(estimator)

54 

55 

def test_surface_maps_masker_fit_transform_mask_vs_no_mask(
    surf_maps_img, surf_img_2d, surf_mask_1d
):
    """Test that using a mask changes the extracted region signals.

    The same input image is transformed by a masker fitted with a mask and
    by one fitted without, so any difference between the two outputs can
    only come from the mask.
    """
    # Build the input once. If ``surf_img_2d`` draws new data on each call,
    # comparing outputs from two separate calls would pass even when the
    # mask has no effect at all.
    img = surf_img_2d(50)

    masker_with_mask = SurfaceMapsMasker(surf_maps_img, surf_mask_1d).fit()
    region_signals_with_mask = masker_with_mask.transform(img)

    masker_no_mask = SurfaceMapsMasker(surf_maps_img).fit()
    region_signals_no_mask = masker_no_mask.transform(img)

    assert not np.array_equal(
        region_signals_with_mask, region_signals_no_mask
    )

69 

70 

def test_surface_maps_masker_fit_transform_actual_output(surf_mesh, rng):
    """Test that fit_transform recovers known region signals.

    The SurfaceMapsMasker is expected to solve ``A @ x.T = B`` for ``x``,
    where ``A`` is the maps_img data (vertices x regions), ``x`` the region
    signals (timepoints x regions) and ``B`` the image (vertices x
    timepoints). Building ``B`` from a known ``x`` makes the answer exact.
    """
    n_vertices, n_regions, n_timepoints = 9, 2, 50

    # Maps: the first 4 vertices form the "left" part, the rest the "right".
    maps = rng.random((n_vertices, n_regions))
    surf_maps_img = SurfaceImage(
        surf_mesh, {"left": maps[:4, :], "right": maps[4:, :]}
    )

    # Ground-truth region signals x.
    expected_region_signals = rng.random((n_timepoints, n_regions))

    # Image B = A @ x.T, split across the two parts like the maps.
    img_data = maps @ expected_region_signals.T
    surf_img = SurfaceImage(
        surf_mesh, {"left": img_data[:4, :], "right": img_data[4:, :]}
    )

    masker = SurfaceMapsMasker(surf_maps_img)
    region_signals = masker.fit_transform(surf_img)

    assert region_signals.shape == expected_region_signals.shape
    assert np.allclose(region_signals, expected_region_signals)

94 

95 

def test_surface_maps_masker_inverse_transform_actual_output(surf_mesh, rng):
    """Test that inverse_transform reconstructs the original image.

    The image is built as ``B = A @ x.T`` from random maps ``A`` and random
    region signals ``x``, so it lies exactly in the span of the maps. A
    fit_transform followed by inverse_transform should reproduce it.
    """
    # create a maps_img with 9 vertices (4 "left", 5 "right") and 2 regions
    A = rng.random((9, 2))
    maps_data = {"left": A[:4, :], "right": A[4:, :]}
    surf_maps_img = SurfaceImage(surf_mesh, maps_data)

    # random region signals x: 50 timepoints x 2 regions
    expected_region_signals = rng.random((50, 2))

    # create an img with 9 vertices and 50 timepoints as B = A @ x.T
    B = A @ expected_region_signals.T
    img_data = {"left": B[:4, :], "right": B[4:, :]}
    surf_img = SurfaceImage(surf_mesh, img_data)

    # fit_transform fits the masker itself; a separate .fit() beforehand
    # would only fit it twice.
    masker = SurfaceMapsMasker(surf_maps_img)
    region_signals = masker.fit_transform(surf_img)
    X_inverse_transformed = masker.inverse_transform(region_signals)

    for part in ("left", "right"):
        assert np.allclose(
            X_inverse_transformed.data.parts[part], img_data[part]
        )

121 ) 

122 

123 

def test_surface_maps_masker_1d_maps_img(surf_img_1d):
    """Check that fitting with a 1D maps_img raises a ValueError."""
    expected_msg = "maps_img should be 2D"
    # Construction stays inside the context so an error raised by either
    # __init__ or fit() is accepted.
    with pytest.raises(ValueError, match=expected_msg):
        SurfaceMapsMasker(maps_img=surf_img_1d).fit()

131 

132 

def test_surface_maps_masker_labels_img_none():
    """Check that fitting without a maps_img raises a ValueError.

    Despite ``labels_img`` in the test name, the image under test is the
    masker's ``maps_img``.
    """
    expected_msg = "provide a maps_img during initialization"
    with pytest.raises(ValueError, match=expected_msg):
        SurfaceMapsMasker(maps_img=None).fit()