Coverage for nilearn/_utils/tags.py: 42%
38 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
1"""Nilearn tags for estimators.
3These tags override or extend some of the sklearn tags.
5With those tags we can specify if one of Nilearn's 'estimator'
6(those include our maskers)
7has certain characteristics or expected behavior.
8For example if the estimator can accept nifti and / or surface images
9during fitting.
11This is mostly used internally to run some checks on our API
12and its behavior.
14See the sklearn documentation for more details on tags
15https://scikit-learn.org/1.6/developers/develop.html#estimator-tags
16"""
18from dataclasses import dataclass
20from packaging.version import parse
21from sklearn import __version__ as sklearn_version
# True when the installed scikit-learn is older than 1.6.
# Compare the (major, minor) pair rather than the minor component alone,
# so that a future major release (e.g. 2.0) is not mistaken for "< 1.6".
SKLEARN_LT_1_6 = parse(sklearn_version).release[:2] < (1, 6)
25if SKLEARN_LT_1_6: 25 ↛ 58line 25 didn't jump to line 58 because the condition on line 25 was always true
def tags(
    niimg_like=True,
    surf_img=False,
    masker=False,
    multi_masker=False,
    glm=False,
    **kwargs,
):
    """Add nilearn tags to estimator.

    See also: InputTags

    TODO remove when dropping sklearn 1.5

    Parameters
    ----------
    niimg_like, surf_img, masker, multi_masker, glm : bool
        When true, the string of the same name is added to ``X_types``.

    **kwargs
        Any other tags; they are returned unchanged.
        ``X_types``, if given, is an iterable of strings used as
        the starting content of the returned ``X_types`` list.

    Returns
    -------
    dict
        The extra tags plus an ``X_types`` entry: a sorted list without
        duplicates that always contains ``"2darray"``.
    """
    # Pop X_types out of kwargs: leaving it there would pass the keyword
    # twice to dict() below and raise a TypeError.
    # Building a new set also avoids mutating the caller's list.
    x_types = set(kwargs.pop("X_types", None) or ())
    x_types.add("2darray")

    requested = {
        "niimg_like": niimg_like,
        "surf_img": surf_img,
        "masker": masker,
        "multi_masker": multi_masker,
        "glm": glm,
    }
    x_types.update(name for name, enabled in requested.items() if enabled)

    # Sort so that the output does not depend on set iteration order.
    return dict(X_types=sorted(x_types), **kwargs)
57else:
58 from sklearn.utils import InputTags as SkInputTags
@dataclass
class InputTags(SkInputTags):
    """Describe the kinds of input data an estimator accepts.

    Nilearn version of sklearn.utils.InputTags
    https://scikit-learn.org/1.6/modules/generated/sklearn.utils.InputTags.html#sklearn.utils.InputTags
    """

    # ------------------------------------------------------------------
    # Fields redeclared from the base sklearn.utils.InputTags.
    # Declaration order is kept as is: it fixes the order of the
    # arguments of the generated __init__.
    # ------------------------------------------------------------------
    one_d_array: bool = False
    two_d_array: bool = True
    three_d_array: bool = False
    sparse: bool = False
    categorical: bool = False
    string: bool = False
    dict: bool = False
    positive_only: bool = False
    allow_nan: bool = False
    pairwise: bool = False

    # ------------------------------------------------------------------
    # Nilearn-specific fields.
    # ------------------------------------------------------------------

    # The estimator accepts a str or Path to a .nii[.gz] file,
    # or a NiftiImage object.
    niimg_like: bool = True

    # The estimator accepts a SurfaceImage object.
    surf_img: bool = False

    # The estimator is a masker / a multi masker.
    # TODO: implement a masker_tags attribute
    masker: bool = False
    multi_masker: bool = False

    # The estimator is a GLM.
    glm: bool = False