Coverage for nilearn/_utils/tests/test_estimator_checks.py: 0%

40 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

"""
Test the estimator_checks module.

Historically, tests for private modules lived in nilearn/tests because Nose,
the test runner we used at the time, ignores modules whose name starts with
an underscore. This file now lives in nilearn/_utils/tests.
"""

8 

9import pytest 

10 

11from nilearn._utils.estimator_checks import ( 

12 check_estimator_has_sklearn_is_fitted, 

13 check_masker_dict_unchanged, 

14) 

15from nilearn.maskers.base_masker import BaseMasker 

16 

17 

def test_check_estimator_has_sklearn_is_fitted():
    """Check the errors raised by check_estimator_has_sklearn_is_fitted.

    Two invalid estimators are fed to the checker:

    - one without a ``__sklearn_is_fitted__`` method, which must be
      rejected with a ``TypeError``;
    - one whose ``__sklearn_is_fitted__`` reports the estimator as fitted
      although ``fit`` was never called, which must be rejected with a
      ``ValueError``.
    """

    class EstimatorWithoutIsFitted:
        def __init__(self):
            pass

    class EstimatorAlwaysFitted:
        def __init__(self):
            pass

        def __sklearn_is_fitted__(self):
            # Claims to be fitted even though fit was never called.
            return True

    # (estimator class, expected exception, pattern for the error message)
    invalid_cases = (
        (
            EstimatorWithoutIsFitted,
            TypeError,
            "must have __sklearn_is_fitted__ method",
        ),
        (
            EstimatorAlwaysFitted,
            ValueError,
            "must return False before fit",
        ),
    )

    for estimator_cls, expected_error, message in invalid_cases:
        with pytest.raises(expected_error, match=message):
            check_estimator_has_sklearn_is_fitted(estimator_cls())

44 

45 

def test_check_masker_dict_unchanged():
    """Check that check_masker_dict_unchanged flags transform side effects.

    Two maskers whose ``transform`` mutates the instance are checked:

    - one that adds a new attribute (a new ``__dict__`` key);
    - one that overwrites the value of an attribute set by ``fit``.

    Both must make the checker raise a ``ValueError``.
    """

    class MaskerAddingAttribute(BaseMasker):
        """Masker with a transform method that adds a new attribute."""

        def __init__(self, mask_img=None):
            self.mask_img = mask_img

        def fit(self, imgs):
            self.imgs = imgs
            return self

        def transform(self, imgs):
            # Adds ``_imgs``, a key that ``fit`` never sets.
            self._imgs = imgs

    # ``match`` is a regular expression applied with ``re.search``; the
    # trailing period is escaped so it matches a literal "." rather than
    # any character.
    with pytest.raises(
        ValueError,
        match=r"Estimator changes '__dict__' keys during transform\.",
    ):
        check_masker_dict_unchanged(MaskerAddingAttribute())

    class MaskerModifyingAttribute(BaseMasker):
        """Masker with a transform method that modifies an attribute."""

        def __init__(self, mask_img=None):
            self.mask_img = mask_img

        def fit(self, imgs):
            self.imgs = imgs
            return self

        def transform(self, imgs):
            del imgs
            # Overwrites ``imgs``, the attribute stored by ``fit``.
            self.imgs = 1

    with pytest.raises(
        ValueError, match="Estimator changes the following '__dict__' keys"
    ):
        check_masker_dict_unchanged(MaskerModifyingAttribute())