Coverage for nilearn/connectome/tests/test_group_sparse_cov.py: 0%
64 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1import numpy as np
2import pytest
3from sklearn.utils.estimator_checks import parametrize_with_checks
5from nilearn._utils.data_gen import generate_group_sparse_gaussian_graphs
6from nilearn._utils.estimator_checks import (
7 check_estimator,
8 nilearn_check_estimator,
9 return_expected_failed_checks,
10)
11from nilearn._utils.tags import SKLEARN_LT_1_6
12from nilearn.connectome import GroupSparseCovariance, GroupSparseCovarianceCV
13from nilearn.connectome.group_sparse_cov import (
14 group_sparse_covariance,
15 group_sparse_scores,
16)
# Estimator instances run through both sklearn's and nilearn's check suites.
ESTIMATORS_TO_CHECK = [GroupSparseCovarianceCV(), GroupSparseCovariance()]

if SKLEARN_LT_1_6:
    # sklearn < 1.6 (per the SKLEARN_LT_1_6 flag): use nilearn's
    # check_estimator wrapper, which yields (estimator, check, name)
    # triples suitable for parametrization.

    @pytest.mark.parametrize(
        "estimator, check, name",
        (check_estimator(estimators=ESTIMATORS_TO_CHECK)),
    )
    def test_check_estimator_group_sparse_covariance(estimator, check, name):  # noqa: ARG001
        """Check compliance with sklearn estimators."""
        check(estimator)

    # Checks known to be invalid for these estimators are expected to fail.
    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_invalid_group_sparse_covariance(
        estimator,
        check,
        name,  # noqa: ARG001
    ):
        """Check compliance with sklearn estimators."""
        check(estimator)

else:
    # sklearn >= 1.6: rely on sklearn's parametrize_with_checks and let
    # nilearn report the expected failures via return_expected_failed_checks.

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Check compliance with sklearn estimators."""
        check(estimator)
# Each (estimator, check, name) triple comes from nilearn's own estimator
# check suite; ``name`` is only consumed by pytest to build readable test IDs.
@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(estimators=ESTIMATORS_TO_CHECK),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Check compliance with nilearn estimators rules."""
    check(estimator)
def test_group_sparse_covariance(rng):
    """Run the solver in debug mode and check verbosity-independence.

    Debug mode must not fail, both runs must hit the tolerance limit,
    and the estimated precision matrices must agree regardless of the
    verbosity level.
    """
    signals, _, _ = generate_group_sparse_gaussian_graphs(
        density=0.1,
        n_subjects=5,
        n_features=10,
        min_n_samples=100,
        max_n_samples=151,
        random_state=rng,
    )

    alpha = 0.1

    # Same solve twice, differing only in verbosity; keep the precision
    # matrices (second element of the returned tuple).
    precisions = [
        group_sparse_covariance(
            signals, alpha, max_iter=20, tol=1e-2, debug=True, verbose=verbosity
        )[1]
        for verbosity in (1, 0)
    ]

    np.testing.assert_almost_equal(precisions[0], precisions[1], decimal=4)
@pytest.mark.parametrize("duality_gap", [True, False])
def test_group_sparse_covariance_with_probe_function(rng, duality_gap):
    """Use a probe callback to check iteration count and objective decrease."""
    signals, _, _ = generate_group_sparse_gaussian_graphs(
        density=0.1,
        n_subjects=5,
        n_features=10,
        min_n_samples=100,
        max_n_samples=151,
        random_state=rng,
    )

    alpha = 0.1

    class ObjectiveRecorder:
        """Collect the objective value at every solver iteration."""

        def __init__(self):
            self.objective = []

        def __call__(
            self,
            emp_covs,
            n_samples,
            alpha,
            max_iter,  # noqa: ARG002
            tol,  # noqa: ARG002
            n,
            omega,
            omega_diff,  # noqa: ARG002
        ):
            # n < 0 flags the pre-iteration call; skip it.
            if n < 0:
                return
            # With duality_gap=True the scores are a 3-tuple
            # (log-likelihood, objective, gap); otherwise a 2-tuple.
            # The objective sits at index 1 in both cases.
            scores = group_sparse_scores(
                omega,
                n_samples,
                emp_covs,
                alpha,
                duality_gap=duality_gap,
            )
            self.objective.append(scores[1])

    # Use a probe to test for number of iterations and decreasing objective.
    probe = ObjectiveRecorder()
    _, precision = group_sparse_covariance(
        signals, alpha, max_iter=4, tol=None, verbose=0, probe_function=probe
    )

    objective = probe.objective
    # One recorded value per iteration.
    assert len(objective) == 4

    # np.testing.assert_array_less is a strict comparison and
    # np.diff(objective) may contain zeros, so use <= instead.
    assert np.all(np.diff(objective) <= 0)
    assert precision.shape == (10, 10, 5)
def test_group_sparse_covariance_check_consistency_between_classes(rng):
    """Refitting with the CV-selected alpha must reproduce the CV result."""
    signals, _, _ = generate_group_sparse_gaussian_graphs(
        density=0.1,
        n_subjects=5,
        n_features=10,
        min_n_samples=100,
        max_n_samples=151,
        random_state=rng,
    )

    # Fit the cross-validated estimator first to select alpha.
    cv_estimator = GroupSparseCovarianceCV(
        alphas=4, tol=1e-1, max_iter=20, verbose=0, early_stopping=True
    )
    cv_estimator.fit(signals)

    # Fit the plain estimator at the selected alpha.
    plain_estimator = GroupSparseCovariance(
        alpha=cv_estimator.alpha_, tol=1e-1, max_iter=20, verbose=0
    )
    plain_estimator.fit(signals)

    np.testing.assert_almost_equal(
        cv_estimator.precisions_, plain_estimator.precisions_, decimal=4
    )
def test_group_sparse_covariance_errors(rng):
    """Invalid arguments to group_sparse_covariance must raise ValueError."""
    subject_signals, _, _ = generate_group_sparse_gaussian_graphs(
        density=0.1,
        n_subjects=5,
        n_features=10,
        min_n_samples=100,
        max_n_samples=151,
        random_state=rng,
    )

    # A non-numeric regularization parameter is rejected.
    with pytest.raises(ValueError, match="must be a positive number"):
        group_sparse_covariance(subject_signals, "")

    # The signals argument must be an iterable of per-subject arrays.
    with pytest.raises(ValueError, match="subjects' .* must be .* iterable"):
        group_sparse_covariance(1, 0.1)

    # Every subject must have the same number of features.
    mismatched_subjects = [np.ones((2, 2)), np.ones((2, 3))]
    with pytest.raises(
        ValueError, match="All subjects must have the same number of features."
    ):
        group_sparse_covariance(mismatched_subjects, 0.1)