Coverage for nilearn/_utils/tests/test_estimator_checks.py: 0%
40 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
"""Test the estimator_checks module.

These tests check that the helper checks defined in
``nilearn._utils.estimator_checks`` raise the expected errors
when given estimators that violate the contract being checked.
"""
9import pytest
11from nilearn._utils.estimator_checks import (
12 check_estimator_has_sklearn_is_fitted,
13 check_masker_dict_unchanged,
14)
15from nilearn.maskers.base_masker import BaseMasker
def test_check_estimator_has_sklearn_is_fitted():
    """Check errors thrown by check_estimator_has_sklearn_is_fitted.

    The check must fail:

    - with a TypeError when the estimator has no
      ``__sklearn_is_fitted__`` method;
    - with a ValueError when ``__sklearn_is_fitted__`` returns True
      on an estimator that has not been fitted.
    """

    class EstimatorWithoutIsFitted:
        """Estimator lacking a ``__sklearn_is_fitted__`` method."""

        def __init__(self):
            pass

    with pytest.raises(
        TypeError, match="must have __sklearn_is_fitted__ method"
    ):
        check_estimator_has_sklearn_is_fitted(EstimatorWithoutIsFitted())

    class EstimatorAlwaysFitted:
        """Estimator that reports being fitted even though fit never ran."""

        def __init__(self):
            pass

        def __sklearn_is_fitted__(self):
            return True

    with pytest.raises(ValueError, match="must return False before fit"):
        check_estimator_has_sklearn_is_fitted(EstimatorAlwaysFitted())
def test_check_masker_dict_unchanged():
    """Check errors thrown by check_masker_dict_unchanged.

    The check must fail with a ValueError when a masker's ``transform``:

    - adds a new key to the masker's ``__dict__``;
    - changes the value of an existing ``__dict__`` key.
    """

    class MaskerAddingAttribute(BaseMasker):
        """Masker whose transform method adds a new attribute."""

        def __init__(self, mask_img=None):
            self.mask_img = mask_img

        def fit(self, imgs):
            self.imgs = imgs
            return self

        def transform(self, imgs):
            # ``_imgs`` is set neither by __init__ nor by fit, so this
            # introduces a brand-new key in ``__dict__``.
            self._imgs = imgs

    with pytest.raises(
        ValueError, match="Estimator changes '__dict__' keys during transform."
    ):
        check_masker_dict_unchanged(MaskerAddingAttribute())

    class MaskerModifyingAttribute(BaseMasker):
        """Masker whose transform method modifies an existing attribute."""

        def __init__(self, mask_img=None):
            self.mask_img = mask_img

        def fit(self, imgs):
            self.imgs = imgs
            return self

        def transform(self, imgs):
            del imgs
            # ``imgs`` is also set by fit; reassigning it keeps the same
            # keys but changes one of their values.
            self.imgs = 1

    with pytest.raises(
        ValueError, match="Estimator changes the following '__dict__' keys"
    ):
        check_masker_dict_unchanged(MaskerModifyingAttribute())