Coverage for nilearn/glm/_base.py: 13%
207 statements
import warnings
from collections import OrderedDict
from copy import deepcopy
from pathlib import Path

import numpy as np
from nibabel.onetime import auto_attr
from sklearn.base import BaseEstimator
from sklearn.utils.estimator_checks import check_is_fitted

from nilearn._utils import CacheMixin
from nilearn._utils.glm import coerce_to_dict
from nilearn._utils.logger import find_stack_level
from nilearn._utils.tags import SKLEARN_LT_1_6
from nilearn.externals import tempita
from nilearn.interfaces.bids.utils import bids_entities, create_bids_filename
from nilearn.maskers import SurfaceMasker
from nilearn.surface import SurfaceImage

FIGURE_FORMAT = "png"

class BaseGLM(CacheMixin, BaseEstimator):
    """Implement a base class \
    for the :term:`General Linear Model<GLM>`.
    """

    def _is_volume_glm(self):
        """Return whether the model is run on volume data."""
        return not (
            (
                hasattr(self, "mask_img")
                and isinstance(self.mask_img, (SurfaceMasker, SurfaceImage))
            )
            or (
                self.__sklearn_is_fitted__()
                and hasattr(self, "masker_")
                and isinstance(self.masker_, SurfaceMasker)
            )
        )
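    # Illustrative note (not from the original source): the check above makes
    # a GLM count as surface-based when either the user-supplied ``mask_img``
    # is a SurfaceMasker / SurfaceImage, or the model is fitted and its
    # ``masker_`` is a SurfaceMasker. Sketch, assuming an estimator ``model``
    # deriving from BaseGLM:
    #
    #   model.mask_img = SurfaceImage(...)  # -> model._is_volume_glm() is False
    #   model.mask_img = "mask.nii.gz"      # -> True, unless the fitted
    #                                       #    masker_ is a SurfaceMasker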
    def _attributes_to_dict(self):
        """Return dict with pertinent model attributes & information.

        Returns
        -------
        dict
        """
        selected_attributes = [
            "subject_label",
            "drift_model",
            "hrf_model",
            "standardize",
            "noise_model",
            "t_r",
            "signal_scaling",
            "scaling_axis",
            "smoothing_fwhm",
            "slice_time_ref",
        ]
        if self._is_volume_glm():
            selected_attributes.extend(["target_shape", "target_affine"])
        if self.__str__() == "First Level Model":
            if self.hrf_model == "fir":
                selected_attributes.append("fir_delays")

            if self.drift_model == "cosine":
                selected_attributes.append("high_pass")
            elif self.drift_model == "polynomial":
                selected_attributes.append("drift_order")

        selected_attributes.sort()

        model_param = OrderedDict(
            (attr_name, getattr(self, attr_name))
            for attr_name in selected_attributes
            if getattr(self, attr_name, None) is not None
        )

        for k, v in model_param.items():
            if isinstance(v, np.ndarray):
                model_param[k] = v.tolist()

        return model_param
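    # Illustrative example (hedged, not taken from this file): for a typical
    # fitted first-level model the returned mapping could look like
    #
    #   OrderedDict([("drift_model", "cosine"),
    #                ("high_pass", 0.01),
    #                ("hrf_model", "glover"),
    #                ("noise_model", "ar1"),
    #                ("t_r", 2.0)])
    #
    # Only attributes present on the model with non-None values are kept,
    # keys are sorted alphabetically, and numpy arrays are converted to
    # plain lists (e.g. so they can be serialized).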
    def _more_tags(self):
        """Return estimator tags.

        TODO remove when bumping sklearn_version > 1.5
        """
        return self.__sklearn_tags__()

    def __sklearn_tags__(self):
        """Return estimator tags.

        See the sklearn documentation for more details on tags
        https://scikit-learn.org/1.6/developers/develop.html#estimator-tags
        """
        # TODO
        # get rid of if block
        if SKLEARN_LT_1_6:
            from nilearn._utils.tags import tags

            return tags(surf_img=True, niimg_like=True, glm=True)

        from nilearn._utils.tags import InputTags

        tags = super().__sklearn_tags__()
        tags.input_tags = InputTags(surf_img=True, niimg_like=True, glm=True)
        return tags
    # @auto_attr stores the value as an object attribute after the initial
    # call: better performance than @property
    @auto_attr
    def residuals(self):
        """Transform voxelwise residuals to the same shape \
        as the input Nifti1Image(s).

        Returns
        -------
        output : list
            A list of Nifti1Image(s).

        """
        return self._get_element_wise_model_attribute(
            "residuals", result_as_time_series=True
        )

    @auto_attr
    def predicted(self):
        """Transform voxelwise predicted values to the same shape \
        as the input Nifti1Image(s).

        Returns
        -------
        output : list
            A list of Nifti1Image(s).

        """
        return self._get_element_wise_model_attribute(
            "predicted", result_as_time_series=True
        )

    @auto_attr
    def r_square(self):
        """Transform voxelwise r-squared values to the same shape \
        as the input Nifti1Image(s).

        Returns
        -------
        output : list
            A list of Nifti1Image(s).

        """
        return self._get_element_wise_model_attribute(
            "r_square", result_as_time_series=False
        )
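    # Usage sketch (hedged; FirstLevelModel and its ``minimize_memory`` flag
    # come from the public nilearn API, not from this file, and ``run_imgs``
    # / ``events`` stand in for the user's data):
    #
    #   from nilearn.glm.first_level import FirstLevelModel
    #
    #   model = FirstLevelModel(t_r=2.0, minimize_memory=False)
    #   model.fit(run_imgs, events=events)
    #   resid_imgs = model.residuals   # list with one image per run
    #   r2_imgs = model.r_square       # list with one image per run
    #
    # Thanks to @auto_attr, the first access computes and caches the result
    # as an instance attribute; later accesses reuse the stored value.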
    def _generate_filenames_output(
        self, prefix, contrasts, contrast_types, out_dir, entities_to_drop=None
    ):
        """Generate output filenames for a series of contrasts.

        This function constructs and stores the expected output filenames
        for contrast-related statistical maps and design matrices within
        the model.

        Output files try to follow the BIDS convention where applicable.
        For first level models,
        if no prefix is passed,
        and str or Path were used as input files to the GLM,
        the output filenames will be based on the input files.

        See nilearn.interfaces.bids.save_glm_to_bids for more details.

        Parameters
        ----------
        prefix : :obj:`str`
            String to prepend to generated filenames.
            If a string is provided, '_' will be added to the end.

        contrasts : :obj:`str` or array of shape (n_col) or :obj:`list` \
            of (:obj:`str` or array of shape (n_col)) or :obj:`dict`
            Contrast definitions.

        contrast_types : :obj:`dict` of :obj:`str`
            An optional dictionary mapping some
            or all of the :term:`contrast` names to
            specific contrast types ('t' or 'F').

        out_dir : :obj:`str` or :obj:`pathlib.Path`
            Output directory for files.

        entities_to_drop : :obj:`list` of :obj:`str` or None, default=None
            Names of BIDS entities to drop
            from input filenames
            when generating output filenames.
            If None is passed, this defaults to:
            ["part", "echo", "hemi", "desc"]

        Notes
        -----
        - The function ensures that contrast names are valid strings.
        - It constructs filenames for effect sizes, statistical maps,
          and design matrices in a structured manner.
        - The output directory structure may include a subject-level
          or group-level subdirectory based on the model type.
        """
        check_is_fitted(self)

        generate_bids_name = _use_input_files_for_filenaming(self, prefix)

        contrasts = coerce_to_dict(contrasts)
        for k, v in contrasts.items():
            if not isinstance(k, str):
                raise ValueError(
                    f"contrast names must be strings, not {type(k)}"
                )

            if not isinstance(v, (str, np.ndarray, list)):
                raise ValueError(
                    "contrast definitions must be strings or array_likes, "
                    f"not {type(v)}"
                )

        entities = {
            "sub": None,
            "ses": None,
            "task": None,
            "space": None,
        }

        if generate_bids_name:
            # try to figure out filename entities from input files
            # only keep entity label if unique across runs
            for k in entities:
                label = [
                    x["entities"].get(k)
                    for x in self._reporting_data["run_imgs"].values()
                    if x["entities"].get(k) is not None
                ]

                label = set(label)
                if len(label) != 1:
                    continue
                label = next(iter(label))
                entities[k] = label

        elif not isinstance(prefix, str):
            prefix = ""

        if self.__str__() == "Second Level Model":
            sub = "group"
        elif entities["sub"]:
            sub = f"sub-{entities['sub']}"
        else:
            sub = prefix.split("_")[0] if prefix.startswith("sub-") else ""

        if self.__str__() == "Second Level Model":
            design_matrices = [self.design_matrix_]
        else:
            design_matrices = self.design_matrices_

        # dropping some entities to avoid polluting output names
        all_entities = [
            *bids_entities()["raw"],
            *bids_entities()["derivatives"],
        ]
        if entities_to_drop is None:
            entities_to_drop = ["part", "echo", "hemi", "desc"]
        assert all(isinstance(x, str) for x in entities_to_drop)
        entities_to_include = [
            x for x in all_entities if x not in entities_to_drop
        ]
        if not generate_bids_name:
            entities_to_include = ["run"]
        entities_to_include.extend(["contrast", "stat"])

        mask = _generate_mask(
            prefix, generate_bids_name, entities, entities_to_include
        )

        statistical_maps = _generate_statistical_maps(
            prefix,
            contrasts,
            contrast_types,
            generate_bids_name,
            entities,
            entities_to_include,
        )

        model_level_mapping = _generate_model_level_mapping(
            self,
            prefix,
            design_matrices,
            generate_bids_name,
            entities,
            entities_to_include,
        )

        design_matrices_dict = _generate_design_matrices_dict(
            self,
            prefix,
            design_matrices,
            generate_bids_name,
            entities_to_include,
        )

        contrasts_dict = _generate_contrasts_dict(
            self,
            prefix,
            contrasts,
            design_matrices,
            generate_bids_name,
            entities,
            entities_to_include,
        )

        out_dir = Path(out_dir) / sub

        # consider using a class or data class
        # to better standardize naming
        self._reporting_data["filenames"] = {
            "dir": out_dir,
            "mask": mask,
            "design_matrices_dict": design_matrices_dict,
            "contrasts_dict": contrasts_dict,
            "statistical_maps": statistical_maps,
            "model_level_mapping": model_level_mapping,
        }
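    # Sketch of the resulting structure (hedged; exact filenames depend on
    # create_bids_filename and on the model's inputs). In nilearn this method
    # is typically driven by nilearn.interfaces.bids.save_glm_to_bids, which
    # then writes the files listed here:
    #
    #   filenames = model._reporting_data["filenames"]
    #   filenames["dir"]                      # Path, e.g. <out_dir>/sub-01
    #   filenames["mask"]                     # GLM mask filename
    #   filenames["statistical_maps"][name]   # per-contrast statmap names
    #   filenames["design_matrices_dict"][0]  # per-run design-matrix files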

def _generate_mask(
    prefix: str,
    generate_bids_name: bool,
    entities,
    entities_to_include: list[str],
):
    """Return filename for GLM mask."""
    fields = {
        "prefix": prefix,
        "suffix": "mask",
        "extension": "nii.gz",
        "entities": deepcopy(entities),
    }
    fields["entities"].pop("run", None)
    fields["entities"].pop("ses", None)

    if generate_bids_name:
        fields["prefix"] = None

    return create_bids_filename(fields, entities_to_include)


def _generate_statistical_maps(
    prefix: str,
    contrasts,
    contrast_types,
    generate_bids_name: bool,
    entities,
    entities_to_include: list[str],
):
    """Return dictionary containing statmap filenames for each contrast.

    statistical_maps[contrast_name][statmap_label] = filename
    """
    if not isinstance(contrast_types, dict):
        contrast_types = {}

    statistical_maps: dict[str, dict[str, str]] = {}

    for contrast_name in contrasts:
        # Extract stat_type
        contrast_matrix = contrasts[contrast_name]
        # Strings and 1D arrays are assumed to be t-contrasts
        if isinstance(contrast_matrix, str) or (contrast_matrix.ndim == 1):
            stat_type = "t"
        else:
            stat_type = "F"
        # Override automatic detection with explicit type if provided
        stat_type = contrast_types.get(contrast_name, stat_type)

        fields = {
            "prefix": prefix,
            "suffix": "statmap",
            "extension": "nii.gz",
            "entities": deepcopy(entities),
        }

        if generate_bids_name:
            fields["prefix"] = None

        fields["entities"]["contrast"] = _clean_contrast_name(contrast_name)

        tmp = {}
        for key, stat_label in zip(
            [
                "effect_size",
                "stat",
                "effect_variance",
                "z_score",
                "p_value",
            ],
            ["effect", stat_type, "variance", "z", "p"],
        ):
            fields["entities"]["stat"] = stat_label
            tmp[key] = create_bids_filename(fields, entities_to_include)

        fields["entities"]["stat"] = None
        fields["suffix"] = "clusters"
        fields["extension"] = "tsv"
        tmp["clusters_tsv"] = create_bids_filename(fields, entities_to_include)

        fields["extension"] = "json"
        tmp["metadata"] = create_bids_filename(fields, entities_to_include)

        statistical_maps[contrast_name] = tempita.bunch(**tmp)

    return statistical_maps
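# Illustrative example (hedged; the exact entity ordering is delegated to
# create_bids_filename, which is defined elsewhere). For a contrast named
# "face > house" on a first-level model whose inputs carry sub-01 / task-loc
# entities, the per-contrast bunch would hold names along the lines of:
#
#   statistical_maps["face > house"]["z_score"]
#       -> "sub-01_task-loc_contrast-faceGtHouse_stat-z_statmap.nii.gz"
#   statistical_maps["face > house"]["clusters_tsv"]
#       -> "sub-01_task-loc_contrast-faceGtHouse_clusters.tsv"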

def _generate_model_level_mapping(
    model,
    prefix: str,
    design_matrices,
    generate_bids_name: bool,
    entities,
    entities_to_include: list[str],
):
    """Return dictionary of filenames for nifti of runwise error & residuals.

    model_level_mapping[i_run][statmap_label] = filename
    """
    fields = {
        "prefix": prefix,
        "suffix": "statmap",
        "extension": "nii.gz",
        "entities": deepcopy(entities),
    }

    if generate_bids_name:
        fields["prefix"] = None

    model_level_mapping = {}

    for i_run, _ in enumerate(design_matrices):
        if _is_flm_with_single_run(model):
            fields["entities"]["run"] = i_run + 1
        if generate_bids_name:
            fields["entities"] = deepcopy(
                model._reporting_data["run_imgs"][i_run]["entities"]
            )

        tmp = {}
        for key, stat_label in zip(
            ["residuals", "r_square"],
            ["errorts", "rsquared"],
        ):
            fields["entities"]["stat"] = stat_label
            tmp[key] = create_bids_filename(fields, entities_to_include)

        model_level_mapping[i_run] = tempita.bunch(**tmp)

    return model_level_mapping


def _generate_design_matrices_dict(
    model,
    prefix: str,
    design_matrices,
    generate_bids_name: bool,
    entities_to_include: list[str],
) -> dict[int, dict[str, str]]:
    """Return dictionary with filenames for design_matrices figures / tables.

    design_matrices_dict[i_run][key] = filename
    """
    fields = {"prefix": prefix, "extension": FIGURE_FORMAT, "entities": {}}
    if generate_bids_name:
        fields["prefix"] = None  # type: ignore[assignment]

    design_matrices_dict = tempita.bunch()

    for i_run, _ in enumerate(design_matrices):
        if _is_flm_with_single_run(model):
            fields["entities"] = {"run": i_run + 1}  # type: ignore[assignment]
        if generate_bids_name:
            fields["entities"] = deepcopy(
                model._reporting_data["run_imgs"][i_run]["entities"]
            )

        tmp = {}
        for extension in [FIGURE_FORMAT, "tsv"]:
            for key, suffix in zip(
                ["design_matrix", "correlation_matrix"],
                ["design", "corrdesign"],
            ):
                fields["extension"] = extension
                fields["suffix"] = suffix
                tmp[f"{key}_{extension}"] = create_bids_filename(
                    fields, entities_to_include
                )

        design_matrices_dict[i_run] = tempita.bunch(**tmp)

    return design_matrices_dict
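# Illustrative example (hedged; prefix handling and separators are delegated
# to create_bids_filename). With a non-BIDS prefix and more than one run,
# each per-run bunch holds four names, e.g. for the first run:
#
#   design_matrices_dict[0]["design_matrix_png"]        # ..._run-1_design.png
#   design_matrices_dict[0]["design_matrix_tsv"]        # ..._run-1_design.tsv
#   design_matrices_dict[0]["correlation_matrix_png"]   # ..._run-1_corrdesign.png
#   design_matrices_dict[0]["correlation_matrix_tsv"]   # ..._run-1_corrdesign.tsv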

def _generate_contrasts_dict(
    model,
    prefix: str,
    contrasts,
    design_matrices,
    generate_bids_name: bool,
    entities,
    entities_to_include: list[str],
) -> dict[int, dict[str, str]]:
    """Return dictionary with filenames for contrast matrices figures.

    contrasts_dict[i_run][contrast_name] = filename
    """
    fields = {
        "prefix": prefix,
        "extension": FIGURE_FORMAT,
        "entities": deepcopy(entities),
        "suffix": "design",
    }
    if generate_bids_name:
        fields["prefix"] = None

    contrasts_dict = tempita.bunch()

    for i_run, _ in enumerate(design_matrices):
        if _is_flm_with_single_run(model):
            fields["entities"]["run"] = i_run + 1
        if generate_bids_name:
            fields["entities"] = deepcopy(
                model._reporting_data["run_imgs"][i_run]["entities"]
            )

        tmp = {}
        for contrast_name in contrasts:
            fields["entities"]["contrast"] = _clean_contrast_name(
                contrast_name
            )
            tmp[contrast_name] = create_bids_filename(
                fields, entities_to_include
            )

        contrasts_dict[i_run] = tempita.bunch(**tmp)

    return contrasts_dict


def _use_input_files_for_filenaming(self, prefix) -> bool:
    """Determine if we should try to use input files to generate \
    output filenames.
    """
    if self.__str__() == "Second Level Model" or prefix is not None:
        return False

    input_files = self._reporting_data["run_imgs"]

    files_used_as_input = all(len(x) > 0 for x in input_files.values())
    tmp = {x.get("sub") for x in input_files.values()}
    all_files_have_same_sub = len(tmp) == 1 and tmp is not None

    return files_used_as_input and all_files_have_same_sub


def _is_flm_with_single_run(model) -> bool:
    return (
        model.__str__() == "First Level Model"
        and len(model._reporting_data["run_imgs"]) > 1
    )

def _clean_contrast_name(contrast_name):
    """Remove prohibited characters from name and convert to camelCase.

    .. versionadded:: 0.9.2

    BIDS filenames, in which the contrast name will appear as a
    contrast-<name> key/value pair, must be alphanumeric strings.

    Parameters
    ----------
    contrast_name : :obj:`str`
        Contrast name to clean.

    Returns
    -------
    new_name : :obj:`str`
        Contrast name converted to alphanumeric-only camelCase.
    """
    new_name = contrast_name[:]

    # Some characters translate to words
    new_name = new_name.replace("-", " Minus ")
    new_name = new_name.replace("+", " Plus ")
    new_name = new_name.replace(">", " Gt ")
    new_name = new_name.replace("<", " Lt ")

    # Others translate to spaces
    new_name = new_name.replace("_", " ")

    # Convert to camelCase
    new_name = new_name.split(" ")
    new_name[0] = new_name[0].lower()
    new_name[1:] = [c.title() for c in new_name[1:]]
    new_name = " ".join(new_name)

    # Remove non-alphanumeric characters
    new_name = "".join(ch for ch in new_name if ch.isalnum())

    # Let users know if the name was changed
    if new_name != contrast_name:
        warnings.warn(
            f'Contrast name "{contrast_name}" changed to "{new_name}"',
            stacklevel=find_stack_level(),
        )
    return new_name
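# Worked examples for _clean_contrast_name (derived by tracing the function
# above; shown doctest-style as comments only):
#
#   >>> _clean_contrast_name("audio - video")
#   'audioMinusVideo'
#   >>> _clean_contrast_name("face > house")
#   'faceGtHouse'
#   >>> _clean_contrast_name("left_button_press")
#   'leftButtonPress'
#
# Each call that changes the name also emits a UserWarning reporting the
# new contrast name.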