Coverage for nilearn/connectome/tests/test_group_sparse_cov.py: 0%

64 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1import numpy as np 

2import pytest 

3from sklearn.utils.estimator_checks import parametrize_with_checks 

4 

5from nilearn._utils.data_gen import generate_group_sparse_gaussian_graphs 

6from nilearn._utils.estimator_checks import ( 

7 check_estimator, 

8 nilearn_check_estimator, 

9 return_expected_failed_checks, 

10) 

11from nilearn._utils.tags import SKLEARN_LT_1_6 

12from nilearn.connectome import GroupSparseCovariance, GroupSparseCovarianceCV 

13from nilearn.connectome.group_sparse_cov import ( 

14 group_sparse_covariance, 

15 group_sparse_scores, 

16) 

17 

# Default-constructed estimator instances fed to the compliance checks below.
ESTIMATORS_TO_CHECK = [
    GroupSparseCovarianceCV(),
    GroupSparseCovariance(),
]

19 

# Which sklearn compliance tests get defined depends on the SKLEARN_LT_1_6 flag:
# when it is false, sklearn's own ``parametrize_with_checks`` is used;
# otherwise the checks come from nilearn's ``check_estimator`` helper.
if not SKLEARN_LT_1_6:

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Run each sklearn compliance check on each estimator."""
        check(estimator)

else:

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_group_sparse_covariance(
        estimator,
        check,
        name,  # noqa: ARG001
    ):
        """Run the checks yielded by ``check_estimator`` with default args."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_invalid_group_sparse_covariance(
        estimator,
        check,
        name,  # noqa: ARG001
    ):
        """Run the checks yielded with ``valid=False``; marked as xfail."""
        check(estimator)

52 

53 

@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(estimators=ESTIMATORS_TO_CHECK),
)
def test_check_estimator_nilearn(
    estimator,
    check,
    name,  # noqa: ARG001
):
    """Run each nilearn-specific estimator check on each estimator.

    ``name`` is supplied by the parametrization but is not used in the body.
    """
    check(estimator)

61 

62 

def test_group_sparse_covariance(rng):
    """Check that verbosity does not change the estimated precisions.

    The solver is run twice in debug mode on the same signals, first with
    ``verbose=1`` and then with ``verbose=0``; the two precision estimates
    must agree to 4 decimals.
    """
    subjects, _, _ = generate_group_sparse_gaussian_graphs(
        density=0.1,
        n_subjects=5,
        n_features=10,
        min_n_samples=100,
        max_n_samples=151,
        random_state=rng,
    )

    regularization = 0.1

    # These executions must hit the tolerance limit
    solver_options = {"max_iter": 20, "tol": 1e-2, "debug": True}
    verbose_omega, quiet_omega = [
        group_sparse_covariance(
            subjects, regularization, verbose=level, **solver_options
        )[1]
        for level in (1, 0)
    ]

    np.testing.assert_almost_equal(verbose_omega, quiet_omega, decimal=4)

87 

88 

@pytest.mark.parametrize("duality_gap", [True, False])
def test_group_sparse_covariance_with_probe_function(rng, duality_gap):
    """Check iteration count and objective monotonicity through a probe.

    A callable probe records the objective reported by
    ``group_sparse_scores`` on every call with ``n >= 0``. After a run with
    ``max_iter=4`` and ``tol=None``, exactly 4 values must have been
    recorded and they must be non-increasing.
    """
    subjects, _, _ = generate_group_sparse_gaussian_graphs(
        density=0.1,
        n_subjects=5,
        n_features=10,
        min_n_samples=100,
        max_n_samples=151,
        random_state=rng,
    )

    regularization = 0.1

    class Probe:
        """Callable that stores the objective seen at each recorded call."""

        def __init__(self):
            self.objective = []

        def __call__(
            self,
            emp_covs,
            n_samples,
            alpha,
            max_iter,  # noqa: ARG002
            tol,  # noqa: ARG002
            n,
            omega,
            omega_diff,  # noqa: ARG002
        ):
            # Calls made with a negative iteration index are not recorded.
            if n < 0:
                return

            score_args = (omega, n_samples, emp_covs, alpha)
            # The scores tuple has one extra trailing element when the
            # duality gap is requested; the objective is the second item.
            if duality_gap:
                _, value, _ = group_sparse_scores(
                    *score_args, duality_gap=duality_gap
                )
            else:
                _, value = group_sparse_scores(
                    *score_args, duality_gap=duality_gap
                )
            self.objective.append(value)

    # Use a probe to test for number of iterations and decreasing objective.
    recorder = Probe()
    _, precisions = group_sparse_covariance(
        subjects,
        regularization,
        max_iter=4,
        tol=None,
        verbose=0,
        probe_function=recorder,
    )
    recorded = recorder.objective

    # check number of iterations
    assert len(recorded) == 4

    # np.testing.assert_array_less is a strict comparison.
    # Zeros can occur in the successive differences, hence "<=".
    steps = np.diff(recorded)
    assert np.all(steps <= 0)
    assert precisions.shape == (10, 10, 5)

149 

150 

def test_group_sparse_covariance_check_consistency_between_classes(rng):
    """Check that the CV and non-CV estimators agree for the same alpha.

    ``GroupSparseCovarianceCV`` is fitted first; its selected ``alpha_`` is
    then given to ``GroupSparseCovariance`` with the same solver settings,
    and both ``precisions_`` must match to 4 decimals.
    """
    subjects, _, _ = generate_group_sparse_gaussian_graphs(
        density=0.1,
        n_subjects=5,
        n_features=10,
        min_n_samples=100,
        max_n_samples=151,
        random_state=rng,
    )

    # Solver settings shared by both estimators.
    solver_options = {"tol": 1e-1, "max_iter": 20, "verbose": 0}

    cv_estimator = GroupSparseCovarianceCV(
        alphas=4, early_stopping=True, **solver_options
    )
    cv_estimator.fit(subjects)

    plain_estimator = GroupSparseCovariance(
        alpha=cv_estimator.alpha_, **solver_options
    )
    plain_estimator.fit(subjects)

    np.testing.assert_almost_equal(
        cv_estimator.precisions_, plain_estimator.precisions_, decimal=4
    )

175 

176 

def test_group_sparse_covariance_errors(rng):
    """Check that invalid arguments raise ``ValueError`` with a clear message.

    Three cases are exercised: a non-numeric ``alpha``, a non-iterable
    ``subjects`` argument, and subjects with differing numbers of features.
    """
    subjects, _, _ = generate_group_sparse_gaussian_graphs(
        density=0.1,
        n_subjects=5,
        n_features=10,
        min_n_samples=100,
        max_n_samples=151,
        random_state=rng,
    )

    regularization = 0.1

    # Test input argument checking: (positional args, expected message regex)
    invalid_calls = (
        ((subjects, ""), "must be a positive number"),
        ((1, regularization), "subjects' .* must be .* iterable"),
        (
            ([np.ones((2, 2)), np.ones((2, 3))], regularization),
            "All subjects must have the same number of features.",
        ),
    )
    for call_args, message in invalid_calls:
        with pytest.raises(ValueError, match=message):
            group_sparse_covariance(*call_args)