Coverage for nilearn/_utils/tests/test_tags.py: 0%
21 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
1"""Check Nilearn estimators tags."""
3from sklearn.base import BaseEstimator
5from nilearn._utils.tags import SKLEARN_LT_1_6
class NilearnEstimator(BaseEstimator):
    """Minimal estimator accepting surface images but not Nifti-like ones.

    Only used to exercise the Nilearn tag helpers in the tests of this module.
    """

    def __sklearn_tags__(self):
        """Return tags flagging surface input as supported, niimg as not.

        The kind of object returned depends on ``SKLEARN_LT_1_6``:
        the result of the ``tags`` helper when it is true, otherwise the
        parent's tags object with its ``input_tags`` replaced.
        """
        # TODO
        # drop the legacy branch at the bottom once the minimum
        # supported sklearn version is above 1.5
        if not SKLEARN_LT_1_6:
            from nilearn._utils.tags import InputTags

            est_tags = super().__sklearn_tags__()
            est_tags.input_tags = InputTags(surf_img=True, niimg_like=False)
            return est_tags

        from nilearn._utils.tags import tags

        return tags(surf_img=True, niimg_like=False)
def test_nilearn_tags():
    """Check that Nilearn estimator tags are set as expected.

    Both tag layouts are covered, the one exercised being selected
    by ``SKLEARN_LT_1_6``.
    """
    estimator_tags = NilearnEstimator().__sklearn_tags__()

    if not SKLEARN_LT_1_6:
        input_tags = estimator_tags.input_tags
        assert not input_tags.niimg_like
        assert input_tags.surf_img
        # two_d_array must remain enabled:
        # some sklearn checks need it to run
        assert input_tags.two_d_array
        return

    x_types = estimator_tags["X_types"]
    assert "niimg_like" not in x_types
    assert "surf_img" in x_types
    # "2darray" must remain listed:
    # some sklearn checks need it to run
    assert "2darray" in x_types