Coverage for nilearn/decomposition/tests/test_dict_learning.py: 0%
55 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
1import numpy as np
2import pytest
4from nilearn.decomposition.dict_learning import DictLearning
5from nilearn.decomposition.tests.conftest import (
6 RANDOM_STATE,
7 check_decomposition_estimator,
8)
9from nilearn.image import get_data, iter_img
10from nilearn.maskers import NiftiMasker, SurfaceMasker
11from nilearn.surface.surface import get_data as get_surface_data
@pytest.mark.parametrize("data_type", ["nifti", "surface"])
@pytest.mark.parametrize("n_epochs", [1, 2, 10])
def test_check_values_epoch_argument_smoke(
    decomposition_mask_img, n_epochs, canica_components, canica_data, data_type
):
    """Smoke test to check different values of the epoch argument.

    DictLearning is initialized from the in-mask part of the reference
    components, fitted for the parametrized number of epochs, and then
    passed through the shared decomposition-estimator checks.
    """
    # Volume and surface inputs differ only in which masker and which
    # data accessor are used; pick both at once.
    is_volume = data_type == "nifti"
    masker_cls = NiftiMasker if is_volume else SurfaceMasker
    read_data = get_data if is_volume else get_surface_data

    masker = masker_cls(mask_img=decomposition_mask_img).fit()
    in_mask = (read_data(decomposition_mask_img) != 0).ravel()

    estimator = DictLearning(
        n_components=4,
        random_state=RANDOM_STATE,
        dict_init=masker.inverse_transform(canica_components[:, in_mask]),
        mask=decomposition_mask_img,
        n_epochs=n_epochs,
        smoothing_fwhm=None,
        alpha=1,
    )
    estimator.fit(canica_data)

    check_decomposition_estimator(estimator, data_type)
def _normalize_rows(arr):
    """Return a copy of 2-D ``arr`` with every row divided by its L2 norm.

    Rows whose norm is zero are divided by 1 instead (i.e. left as-is) to
    avoid a division by zero. The input array is not modified.
    """
    norms = np.sqrt(np.sum(arr**2, axis=1))
    norms[norms == 0] = 1
    return arr / norms[:, np.newaxis]


@pytest.mark.parametrize("data_type", ["nifti"])
def test_dict_learning(
    decomposition_mask_img, canica_components, canica_data, data_type
) -> None:
    """Check content of components_img_.

    DictLearning is fitted twice: once initialized from the reference
    components and once with automatic initialization (``n_epochs=10``).
    Both fitted estimators go through the shared decomposition checks.
    For the explicitly initialized model, at least two
    (reference, estimated) component pairs must have an absolute cosine
    similarity above 0.9.
    """
    masker = NiftiMasker(mask_img=decomposition_mask_img).fit()
    mask = get_data(decomposition_mask_img) != 0
    flat_mask = mask.ravel()
    masked_components = canica_components[:, flat_mask]
    dict_init = masker.inverse_transform(masked_components)

    # Keep smoothing disabled: adding smoothing makes this test break.
    smoothing_fwhm = None

    dict_learning = DictLearning(
        n_components=4,
        random_state=RANDOM_STATE,
        dict_init=dict_init,
        mask=decomposition_mask_img,
        smoothing_fwhm=smoothing_fwhm,
        alpha=1,
    )

    dict_learning_auto_init = DictLearning(
        n_components=4,
        random_state=RANDOM_STATE,
        mask=decomposition_mask_img,
        n_epochs=10,
        smoothing_fwhm=smoothing_fwhm,
        alpha=1,
    )

    for estimator in (dict_learning, dict_learning_auto_init):
        estimator.fit(canica_data)
        # Check the estimator that was just fitted. Previously
        # ``dict_learning`` was checked on both iterations and the
        # auto-initialized estimator was never checked.
        check_decomposition_estimator(estimator, data_type)

    # components_img_ is 4-D with components along the last axis. Move
    # that axis first, then keep in-mask voxels only, which gives an
    # array of shape (n_components, n_in_mask_voxels).
    these_maps = np.moveaxis(get_data(dict_learning.components_img_), 3, 0)[
        :, mask
    ]

    # Absolute cosine similarity between every reference component and
    # every estimated component.
    similarity = np.abs(
        _normalize_rows(masked_components).dot(_normalize_rows(these_maps).T)
    )
    recovered_maps = np.sum(similarity > 0.9)

    assert recovered_maps >= 2
@pytest.mark.parametrize("data_type", ["nifti", "surface"])
def test_component_sign(
    decomposition_mask_img, canica_data, data_type
) -> None:
    """Check sign of extracted components.

    Regression test:
    DictLearning should use a heuristic that flips the sign of each
    component so that it has at least as many positive values as negative
    values. Each extracted map is checked for that property.
    """
    dict_learning = DictLearning(
        n_components=4,
        random_state=RANDOM_STATE,
        mask=decomposition_mask_img,
        smoothing_fwhm=None,
        alpha=1,
    )
    dict_learning.fit(canica_data)

    check_decomposition_estimator(dict_learning, data_type)

    for mp in iter_img(dict_learning.components_img_):
        mp = get_data(mp) if data_type == "nifti" else get_surface_data(mp)
        # The former check, ``np.sum(mp[mp <= 0]) <= np.sum(mp[mp > 0])``,
        # compared a sum that is always <= 0 with one that is always >= 0,
        # so it could never fail. Count strictly positive and strictly
        # negative entries instead. Zeros, such as out-of-mask voxels,
        # count toward neither side.
        n_positive = np.count_nonzero(mp > 0)
        n_negative = np.count_nonzero(mp < 0)
        assert n_negative <= n_positive