Coverage for nilearn/_utils/tags.py: 42%

38 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1"""Nilearn tags for estimators. 

2 

3These tags override or extend some of the sklearn tags. 

4 

5With those tags we can specify if one of Nilearn's 'estimator' 

6(those include our maskers) 

7has certain characteristics or expected behavior. 

8For example if the estimator can accept nifti and / or surface images 

9during fitting. 

10 

11This is mostly used internally to run some checks on our API 

12and its behavior. 

13 

14See the sklearn documentation for more details on tags 

15https://scikit-learn.org/1.6/developers/develop.html#estimator-tags 

16""" 

17 

18from dataclasses import dataclass 

19 

20from packaging.version import parse 

21from sklearn import __version__ as sklearn_version 

22 

# True when the installed scikit-learn is older than 1.6.
# Compare the (major, minor) pair as a tuple: looking at the minor component
# alone misclassifies any version whose major number is not 1
# (e.g. 0.24 would be seen as ">= 1.6" and 2.0 as "< 1.6").
SKLEARN_LT_1_6 = parse(sklearn_version).release[:2] < (1, 6)

24 

25if SKLEARN_LT_1_6: 25 ↛ 58line 25 didn't jump to line 58 because the condition on line 25 was always true

26 

def tags(
    niimg_like=True,
    surf_img=False,
    masker=False,
    multi_masker=False,
    glm=False,
    **kwargs,
):
    """Add nilearn tags to estimator.

    Build the tag dictionary by extending the ``X_types`` entry
    with ``"2darray"`` and one nilearn-specific entry per enabled flag.

    See also: InputTags

    TODO remove when dropping sklearn 1.5

    Parameters
    ----------
    niimg_like : bool, default=True
        If True, add ``"niimg_like"`` to ``X_types``.

    surf_img : bool, default=False
        If True, add ``"surf_img"`` to ``X_types``.

    masker : bool, default=False
        If True, add ``"masker"`` to ``X_types``.

    multi_masker : bool, default=False
        If True, add ``"multi_masker"`` to ``X_types``.

    glm : bool, default=False
        If True, add ``"glm"`` to ``X_types``.

    **kwargs
        Any other tags, returned unchanged. An ``X_types`` entry, if given,
        is used as the starting list of input types; it is not modified.

    Returns
    -------
    dict
        The extra tags from ``kwargs`` plus ``X_types``, a list without
        duplicates, in first-seen order.
    """
    # Pop (rather than get) so that X_types is not handed to dict() twice
    # below, and copy so that the list owned by the caller is left untouched.
    X_types = list(kwargs.pop("X_types", None) or [])
    X_types.append("2darray")

    enabled_by_tag = {
        "niimg_like": niimg_like,
        "surf_img": surf_img,
        "masker": masker,
        "multi_masker": multi_masker,
        "glm": glm,
    }
    X_types.extend(tag for tag, enabled in enabled_by_tag.items() if enabled)

    # dict.fromkeys drops duplicates while keeping a deterministic order,
    # unlike a round trip through set().
    X_types = list(dict.fromkeys(X_types))

    return dict(X_types=X_types, **kwargs)

56 

57else: 

58 from sklearn.utils import InputTags as SkInputTags 

59 

@dataclass
class InputTags(SkInputTags):
    """Tags for the input data.

    Nilearn version of sklearn.utils.InputTags
    https://scikit-learn.org/1.6/modules/generated/sklearn.utils.InputTags.html#sklearn.utils.InputTags

    On top of the sklearn input tags, this adds Nilearn-specific flags
    stating which image inputs an estimator accepts
    and whether the estimator is a masker, a multi masker or a GLM.
    """

    # Same as the base input tags of sklearn.utils.InputTags.
    # The field order is kept as is because it defines the order
    # of the arguments of the generated __init__.
    one_d_array: bool = False
    two_d_array: bool = True
    three_d_array: bool = False
    sparse: bool = False
    categorical: bool = False
    string: bool = False
    dict: bool = False
    positive_only: bool = False
    allow_nan: bool = False
    pairwise: bool = False

    # Nilearn specific tags

    # estimator accepts a str or Path to a .nii[.gz] file,
    # or a NiftiImage object
    niimg_like: bool = True
    # estimator accepts a SurfaceImage object
    surf_img: bool = False

    # estimators that are maskers
    # TODO: implement a masker_tags attribute
    masker: bool = False
    multi_masker: bool = False

    # estimators that are GLMs
    glm: bool = False