Coverage for nilearn/_utils/tests/test_tags.py: 0%

21 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1"""Check Nilearn estimators tags.""" 

2 

3from sklearn.base import BaseEstimator 

4 

5from nilearn._utils.tags import SKLEARN_LT_1_6 

6 

7 

class NilearnEstimator(BaseEstimator):
    """Minimal estimator whose tags accept surface images but not Nifti."""

    def __sklearn_tags__(self):
        """Return estimator tags with ``surf_img=True, niimg_like=False``.

        Which code path runs depends on ``SKLEARN_LT_1_6``: when it is
        falsy, the parent's tags are fetched and their ``input_tags``
        attribute is replaced by a Nilearn ``InputTags`` instance;
        otherwise the result of Nilearn's ``tags`` helper is returned
        directly.
        """
        # TODO
        # Drop the legacy fallback at the bottom of this method
        # once the minimum supported sklearn version is > 1.5.
        if not SKLEARN_LT_1_6:
            # Imported lazily, inside the only branch that uses it.
            from nilearn._utils.tags import InputTags

            estimator_tags = super().__sklearn_tags__()
            estimator_tags.input_tags = InputTags(
                surf_img=True, niimg_like=False
            )
            return estimator_tags

        from nilearn._utils.tags import tags as legacy_tags

        return legacy_tags(surf_img=True, niimg_like=False)

25 

26 

def test_nilearn_tags():
    """Check that adding tags to Nilearn estimators works as expected.

    The assertions differ depending on ``SKLEARN_LT_1_6`` because the
    object returned by ``__sklearn_tags__`` is queried differently in
    each case (item lookup on ``"X_types"`` versus attributes of
    ``input_tags``).
    """
    estimator_tags = NilearnEstimator().__sklearn_tags__()

    if SKLEARN_LT_1_6:
        x_types = estimator_tags["X_types"]
        assert "niimg_like" not in x_types
        assert "surf_img" in x_types
        # "2darray" must remain listed:
        # it is what allows some sklearn checks to run.
        assert "2darray" in x_types
        return

    input_tags = estimator_tags.input_tags
    assert not input_tags.niimg_like
    assert input_tags.surf_img
    # two_d_array must remain enabled:
    # it is what allows some sklearn checks to run.
    assert input_tags.two_d_array