Coverage for nilearn/_utils/tests/test_masker_validation.py: 0%
76 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
1import numpy as np
2import pytest
3from joblib import Memory
4from nibabel import Nifti1Image
5from sklearn.base import BaseEstimator
7from nilearn._utils.masker_validation import check_embedded_masker
8from nilearn.maskers import MultiNiftiMasker, NiftiMasker, SurfaceMasker
class OwningClass(BaseEstimator):
    """Minimal scikit-learn estimator that owns an embedded masker.

    Every constructor argument is stored as an attribute of the same name,
    so ``BaseEstimator.get_params`` can report it and
    ``check_embedded_masker`` can read it from the owner. ``mask`` may be
    ``None``, a masker instance, or an image.

    Note: ``memory=None`` is replaced by ``Memory(location=None)``, so the
    stored ``memory`` attribute is never ``None``. scikit-learn's estimator
    guidelines advise against altering parameters in ``__init__``; this
    deviation is kept because the tests compare the embedded masker's
    ``memory`` against the owner's.
    """

    def __init__(
        self,
        mask=None,
        smoothing_fwhm=None,
        standardize=False,
        detrend=False,
        low_pass=None,
        high_pass=None,
        t_r=None,
        target_affine=None,
        target_shape=None,
        mask_strategy="background",
        mask_args=None,
        memory=None,
        memory_level=0,
        n_jobs=1,
        verbose=0,
        dummy=None,
    ):
        # Masker-related parameters, stored verbatim.
        self.mask = mask
        self.smoothing_fwhm = smoothing_fwhm
        self.standardize = standardize
        self.detrend = detrend
        self.low_pass = low_pass
        self.high_pass = high_pass
        self.t_r = t_r
        self.target_affine = target_affine
        self.target_shape = target_shape
        self.mask_strategy = mask_strategy
        self.mask_args = mask_args

        # Caching / execution parameters. A missing ``memory`` is swapped
        # for a ``Memory`` object with no cache location.
        self.memory = Memory(location=None) if memory is None else memory
        self.memory_level = memory_level
        self.n_jobs = n_jobs
        self.verbose = verbose

        self.dummy = dummy
class DummyEstimator:
    """Plain object used as an owner for ``check_embedded_masker``.

    Its only instance attributes are the keyword arguments given to the
    constructor. This lets tests build owners that lack attributes such as
    ``verbose`` or ``memory``.
    """

    def __init__(self, **kwargs):
        # Same effect as calling ``setattr`` for each item: the class defines
        # no properties or ``__slots__`` that could intercept the assignment.
        vars(self).update(kwargs)

    def fit(self, *args, **kwargs):  # noqa: ARG002
        """Build the embedded masker from this object's attributes.

        All positional and keyword arguments are accepted and ignored. The
        masker returned by ``check_embedded_masker`` is stored in
        ``self.masker``.
        """
        self.masker = check_embedded_masker(self)
def test_check_embedded_masker_defaults():
    """Check the fallbacks used when the owner lacks ``verbose`` or ``memory``."""
    # Owner without ``verbose``: a warning is emitted and verbose falls back
    # to 0. memory_level is expected to be 0 even though the owner sets 1
    # (presumably because memory is None -- confirm in check_embedded_masker).
    no_verbose = DummyEstimator(memory=None, memory_level=1)
    with pytest.warns(
        Warning, match="Provided estimator has no verbose attribute set."
    ):
        no_verbose.fit()
    built = no_verbose.masker
    assert built.memory_level == 0
    assert built.verbose == 0

    # Owner without ``memory``: a warning is emitted and the masker gets a
    # ``Memory`` whose location is None; the owner's verbose is kept.
    no_memory = DummyEstimator(verbose=1)
    with pytest.warns(
        Warning, match="Provided estimator has no memory attribute set."
    ):
        no_memory.fit()
    built = no_memory.masker
    assert isinstance(built.memory, Memory)
    assert built.memory.location is None
    assert built.memory_level == 0
    assert built.verbose == 1
def test_check_embedded_masker():
    """Check masker construction, parameter forwarding and conflict warning."""
    # Parameters the embedded masker takes from the owning estimator; every
    # other parameter is expected to match the masker passed as ``mask``.
    from_owner = {"memory", "memory_level", "n_jobs", "verbose"}

    # Without a mask, the default embedded masker is exactly a
    # MultiNiftiMasker.
    default_masker = check_embedded_masker(OwningClass())
    assert type(default_masker) is MultiNiftiMasker

    # A masker given as ``mask`` yields a masker of the same class whose
    # parameters come from either the owner or the template.
    cases = [
        (MultiNiftiMasker(), "multi_nii"),
        (NiftiMasker(), "nii"),
        (SurfaceMasker(), "surface"),
    ]
    for template, masker_type in cases:
        owner = OwningClass(mask=template)
        masker = check_embedded_masker(owner, masker_type=masker_type)
        assert isinstance(masker, type(template))
        for name in masker.get_params():
            expected_from = owner if name in from_owner else template
            assert getattr(masker, name) == getattr(expected_from, name)

    # An image given as ``mask`` is used, by identity, as ``mask_img``.
    mask_img = Nifti1Image(np.ones((6, 8, 10), dtype=np.int8), np.eye(4))
    masker = check_embedded_masker(OwningClass(mask=mask_img))
    assert masker.mask_img is mask_img

    # A fitted masker given as ``mask`` forwards its fitted ``mask_img_``.
    volume = np.zeros((9, 9, 9))
    volume[2:-2, 2:-2, 2:-2] = 10
    fitted = MultiNiftiMasker()
    fitted.fit([[Nifti1Image(volume, np.eye(4))]])
    masker = check_embedded_masker(OwningClass(mask=fitted))
    assert masker.mask_img is fitted.mask_img_

    # A masker whose mask_strategy ("epi") differs from the owner's default
    # ("background") is expected to trigger a UserWarning.
    conflicting_owner = OwningClass(mask=NiftiMasker(mask_strategy="epi"))
    with pytest.warns(UserWarning):
        check_embedded_masker(conflicting_owner)