Coverage for nilearn/_utils/tests/test_masker_validation.py: 0%

76 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1import numpy as np 

2import pytest 

3from joblib import Memory 

4from nibabel import Nifti1Image 

5from sklearn.base import BaseEstimator 

6 

7from nilearn._utils.masker_validation import check_embedded_masker 

8from nilearn.maskers import MultiNiftiMasker, NiftiMasker, SurfaceMasker 

9 

10 

class OwningClass(BaseEstimator):
    """Minimal estimator exposing the parameters ``check_embedded_masker`` reads.

    Every constructor argument is stored unchanged as a same-named
    attribute, with one exception: a ``memory`` of ``None`` is replaced by
    ``Memory(location=None)``.
    """

    def __init__(
        self,
        mask=None,
        smoothing_fwhm=None,
        standardize=False,
        detrend=False,
        low_pass=None,
        high_pass=None,
        t_r=None,
        target_affine=None,
        target_shape=None,
        mask_strategy="background",
        mask_args=None,
        memory=None,
        memory_level=0,
        n_jobs=1,
        verbose=0,
        dummy=None,
    ):
        # Mask (image, masker instance, or None).
        self.mask = mask

        # Signal cleaning / filtering parameters.
        self.smoothing_fwhm = smoothing_fwhm
        self.standardize = standardize
        self.detrend = detrend
        self.low_pass = low_pass
        self.high_pass = high_pass
        self.t_r = t_r

        # Resampling and mask-computation parameters.
        self.target_affine = target_affine
        self.target_shape = target_shape
        self.mask_strategy = mask_strategy
        self.mask_args = mask_args

        # System parameters. A missing cache becomes an explicit
        # no-location Memory; test_check_embedded_masker compares this
        # stored value with the built masker's ``memory``.
        self.memory = Memory(location=None) if memory is None else memory
        self.memory_level = memory_level
        self.n_jobs = n_jobs
        self.verbose = verbose

        # Extra parameter unrelated to any masker.
        self.dummy = dummy

50 

51 

class DummyEstimator:
    """Plain object whose attributes are exactly its keyword arguments.

    Unlike ``OwningClass`` it does not derive from ``BaseEstimator``. It
    therefore lacks any attribute that was not passed explicitly, which
    lets tests exercise the missing-attribute fallbacks.
    """

    def __init__(self, **kwargs):
        # Store each keyword argument as an instance attribute.
        vars(self).update(kwargs)

    def fit(self, *args, **kwargs):  # noqa: ARG002
        """Build the embedded masker from this object and keep it as ``masker``."""
        self.masker = check_embedded_masker(self)

59 

60 

def test_check_embedded_masker_defaults():
    """Missing ``verbose``/``memory`` attributes trigger warnings and defaults."""
    # No ``verbose`` attribute: a warning is expected, and the built masker
    # ends up with memory_level 0 and verbose 0.
    without_verbose = DummyEstimator(memory=None, memory_level=1)
    with pytest.warns(
        Warning, match="Provided estimator has no verbose attribute set."
    ):
        without_verbose.fit()
    built = without_verbose.masker
    assert built.memory_level == 0
    assert built.verbose == 0

    # No ``memory`` (nor ``memory_level``) attribute: a warning is expected,
    # and the masker gets a no-location Memory and memory_level 0, while
    # the provided verbose value is kept.
    without_memory = DummyEstimator(verbose=1)
    with pytest.warns(
        Warning, match="Provided estimator has no memory attribute set."
    ):
        without_memory.fit()
    built = without_memory.masker
    assert isinstance(built.memory, Memory)
    assert built.memory.location is None
    assert built.memory_level == 0
    assert built.verbose == 1

78 

79 

def test_check_embedded_masker():
    """Check how ``check_embedded_masker`` builds a masker from an owner.

    Covers:
    - the default masker type;
    - re-use of a provided masker's parameters and forwarding of the
      owner's system parameters;
    - a Nifti image used as ``mask_img``;
    - forwarding of a fitted masker's ``mask_img_``;
    - the warning raised when masker parameters override the owner's.
    """
    # Without a mask, the default masker_type yields a MultiNiftiMasker.
    owner = OwningClass()
    masker = check_embedded_masker(owner)
    assert type(masker) is MultiNiftiMasker

    # Parameters the built masker takes from the owner; every other
    # parameter must come from the provided masker.
    from_owner = {"memory", "memory_level", "n_jobs", "verbose"}
    for mask, masker_type in (
        (MultiNiftiMasker(), "multi_nii"),
        (NiftiMasker(), "nii"),
        (SurfaceMasker(), "surface"),
    ):
        owner = OwningClass(mask=mask)
        masker = check_embedded_masker(owner, masker_type=masker_type)
        # Exact type check, consistent with the default case above:
        # isinstance would also accept any subclass of the expected masker
        # class, letting a wrong (more derived) masker type slip through.
        assert type(masker) is type(mask)
        for param_key in masker.get_params():
            source = owner if param_key in from_owner else mask
            assert getattr(masker, param_key) == getattr(source, param_key)

    # A Nifti image passed as ``mask`` is used directly as ``mask_img``.
    shape = (6, 8, 10, 5)
    affine = np.eye(4)
    mask = Nifti1Image(np.ones(shape[:3], dtype=np.int8), affine)
    owner = OwningClass(mask=mask)
    masker = check_embedded_masker(owner)
    assert masker.mask_img is mask

    # A fitted masker's ``mask_img_`` is forwarded as ``mask_img``.
    data = np.zeros((9, 9, 9))
    data[2:-2, 2:-2, 2:-2] = 10
    imgs = Nifti1Image(data, np.eye(4))
    mask = MultiNiftiMasker()
    mask.fit([[imgs]])
    owner = OwningClass(mask=mask)
    masker = check_embedded_masker(owner)
    assert masker.mask_img is mask.mask_img_

    # The masker's mask_strategy ("epi") differs from the owner's default
    # ("background"). Pin the expected override warning so that an
    # unrelated UserWarning cannot satisfy this check.
    mask = NiftiMasker(mask_strategy="epi")
    owner = OwningClass(mask=mask)
    with pytest.warns(
        UserWarning, match="Overriding provided-default estimator parameters"
    ):
        check_embedded_masker(owner)