Coverage for nilearn/glm/_base.py: 13%

207 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1import warnings 

2from collections import OrderedDict 

3from copy import deepcopy 

4from pathlib import Path 

5 

6import numpy as np 

7from nibabel.onetime import auto_attr 

8from sklearn.base import BaseEstimator 

9from sklearn.utils.estimator_checks import check_is_fitted 

10 

11from nilearn._utils import CacheMixin 

12from nilearn._utils.glm import coerce_to_dict 

13from nilearn._utils.logger import find_stack_level 

14from nilearn._utils.tags import SKLEARN_LT_1_6 

15from nilearn.externals import tempita 

16from nilearn.interfaces.bids.utils import bids_entities, create_bids_filename 

17from nilearn.maskers import SurfaceMasker 

18from nilearn.surface import SurfaceImage 

19 

# Raster image format used for all generated report figures
# (design matrices, correlation matrices, contrast plots).
FIGURE_FORMAT = "png"

21 

22 

class BaseGLM(CacheMixin, BaseEstimator):
    """Implement a base class \
    for the :term:`General Linear Model<GLM>`.

    Shared plumbing for first- and second-level models: surface/volume
    detection, sklearn tag support, element-wise model outputs
    (residuals, predicted, r_square) and output filename generation.
    """

    def _is_volume_glm(self) -> bool:
        """Return if model is run on volume data or not.

        A model is considered surface-based (returns False) either when
        ``mask_img`` was given as a surface object, or when the model is
        already fitted and its masker is a :class:`SurfaceMasker`.
        """
        return not (
            (
                hasattr(self, "mask_img")
                and isinstance(self.mask_img, (SurfaceMasker, SurfaceImage))
            )
            or (
                self.__sklearn_is_fitted__()
                and hasattr(self, "masker_")
                and isinstance(self.masker_, SurfaceMasker)
            )
        )

    def _attributes_to_dict(self):
        """Return dict with pertinent model attributes & information.

        Only attributes that exist on the instance and are not None are
        included; keys are sorted alphabetically and numpy arrays are
        converted to plain lists (e.g. for JSON serialization).

        Returns
        -------
        dict
        """
        selected_attributes = [
            "subject_label",
            "drift_model",
            "hrf_model",
            "standardize",
            "noise_model",
            "t_r",
            "signal_scaling",
            "scaling_axis",
            "smoothing_fwhm",
            "slice_time_ref",
        ]
        # Resampling targets only make sense for volume data.
        if self._is_volume_glm():
            selected_attributes.extend(["target_shape", "target_affine"])
        # First-level-only parameters, conditional on the chosen models.
        if self.__str__() == "First Level Model":
            if self.hrf_model == "fir":
                selected_attributes.append("fir_delays")

            if self.drift_model == "cosine":
                selected_attributes.append("high_pass")
            elif self.drift_model == "polynomial":
                selected_attributes.append("drift_order")

        selected_attributes.sort()

        # getattr with default None both guards against missing attributes
        # and filters out unset (None) parameters.
        model_param = OrderedDict(
            (attr_name, getattr(self, attr_name))
            for attr_name in selected_attributes
            if getattr(self, attr_name, None) is not None
        )

        # Arrays (e.g. target_affine) are converted for serializability.
        for k, v in model_param.items():
            if isinstance(v, np.ndarray):
                model_param[k] = v.tolist()

        return model_param

    def _more_tags(self):
        """Return estimator tags.

        Legacy sklearn (<1.6) tag hook; delegates to __sklearn_tags__.

        TODO remove when bumping sklearn_version > 1.5
        """
        return self.__sklearn_tags__()

    def __sklearn_tags__(self):
        """Return estimator tags.

        See the sklearn documentation for more details on tags
        https://scikit-learn.org/1.6/developers/develop.html#estimator-tags
        """
        # TODO
        # get rid of if block
        if SKLEARN_LT_1_6:
            from nilearn._utils.tags import tags

            return tags(surf_img=True, niimg_like=True, glm=True)

        from nilearn._utils.tags import InputTags

        tags = super().__sklearn_tags__()
        tags.input_tags = InputTags(surf_img=True, niimg_like=True, glm=True)
        return tags

    # @auto_attr store the value as an object attribute after initial call
    # better performance than @property
    @auto_attr
    def residuals(self):
        """Transform voxelwise residuals to the same shape \
        as the input Nifti1Image(s).

        Returns
        -------
        output : list
            A list of Nifti1Image(s).

        """
        return self._get_element_wise_model_attribute(
            "residuals", result_as_time_series=True
        )

    @auto_attr
    def predicted(self):
        """Transform voxelwise predicted values to the same shape \
        as the input Nifti1Image(s).

        Returns
        -------
        output : list
            A list of Nifti1Image(s).

        """
        return self._get_element_wise_model_attribute(
            "predicted", result_as_time_series=True
        )

    @auto_attr
    def r_square(self):
        """Transform voxelwise r-squared values to the same shape \
        as the input Nifti1Image(s).

        Returns
        -------
        output : list
            A list of Nifti1Image(s).

        """
        return self._get_element_wise_model_attribute(
            "r_square", result_as_time_series=False
        )

    def _generate_filenames_output(
        self, prefix, contrasts, contrast_types, out_dir, entities_to_drop=None
    ):
        """Generate output filenames for a series of contrasts.

        This function constructs and stores the expected output filenames
        for contrast-related statistical maps and design matrices within
        the model.

        Output files try to follow the BIDS convention where applicable.
        For first level models,
        if no prefix is passed,
        and str or Path were used as input files to the GLM
        the output filenames will be based on the input files.

        See nilearn.interfaces.bids.save_glm_to_bids for more details.

        Parameters
        ----------
        prefix : :obj:`str`
            String to prepend to generated filenames.
            If a string is provided, '_' will be added to the end.

        contrasts : :obj:`str` or array of shape (n_col) or :obj:`list` \
                    of (:obj:`str` or array of shape (n_col)) or :obj:`dict`
            Contrast definitions.

        contrast_types ::obj:`dict` of :obj:`str`
            An optional dictionary mapping some
            or all of the :term:`contrast` names to
            specific contrast types ('t' or 'F').

        out_dir : :obj:`str` or :obj:`pathlib.Path`
            Output directory for files.

        entities_to_drop : :obj:`list` of :obj:`str` or None, default=None
            name of BIDS entities to drop
            from input filenames
            when generating output filenames.
            If None is passed this will default to:
            ["part", "echo", "hemi", "desc"]

        Notes
        -----
        - The function ensures that contrast names are valid strings.
        - It constructs filenames for effect sizes, statistical maps,
          and design matrices in a structured manner.
        - The output directory structure may include a subject-level
          or group-level subdirectory based on the model type.
        - Results are stored in ``self._reporting_data["filenames"]``;
          nothing is returned.
        """
        check_is_fitted(self)

        # True only for first-level models called without a prefix whose
        # inputs were files sharing a single subject label.
        generate_bids_name = _use_input_files_for_filenaming(self, prefix)

        contrasts = coerce_to_dict(contrasts)
        for k, v in contrasts.items():
            if not isinstance(k, str):
                raise ValueError(
                    f"contrast names must be strings, not {type(k)}"
                )

            if not isinstance(v, (str, np.ndarray, list)):
                raise ValueError(
                    "contrast definitions must be strings or array_likes, "
                    f"not {type(v)}"
                )

        entities = {
            "sub": None,
            "ses": None,
            "task": None,
            "space": None,
        }

        if generate_bids_name:
            # try to figure out filename entities from input files
            # only keep entity label if unique across runs
            for k in entities:
                label = [
                    x["entities"].get(k)
                    for x in self._reporting_data["run_imgs"].values()
                    if x["entities"].get(k) is not None
                ]

                label = set(label)
                if len(label) != 1:
                    continue
                label = next(iter(label))
                entities[k] = label

        elif not isinstance(prefix, str):
            prefix = ""

        if self.__str__() == "Second Level Model":
            sub = "group"
        elif entities["sub"]:
            sub = f"sub-{entities['sub']}"
        else:
            # NOTE(review): if generate_bids_name is True, prefix is still
            # None here (see _use_input_files_for_filenaming) and, when no
            # unique "sub" entity was found above, prefix.startswith would
            # raise AttributeError — TODO confirm this branch is unreachable
            # in that configuration.
            sub = prefix.split("_")[0] if prefix.startswith("sub-") else ""

        # Second-level models have a single design matrix; wrap it in a
        # list so both cases can be iterated uniformly below.
        if self.__str__() == "Second Level Model":
            design_matrices = [self.design_matrix_]
        else:
            design_matrices = self.design_matrices_

        # dropping some entities to avoid polluting output names
        all_entities = [
            *bids_entities()["raw"],
            *bids_entities()["derivatives"],
        ]
        if entities_to_drop is None:
            entities_to_drop = ["part", "echo", "hemi", "desc"]
        # NOTE(review): assert is stripped under `python -O`; an explicit
        # TypeError would be a more robust validation.
        assert all(isinstance(x, str) for x in entities_to_drop)
        entities_to_include = [
            x for x in all_entities if x not in entities_to_drop
        ]
        # Without BIDS-style naming, only run / contrast / stat entities
        # are embedded in output filenames.
        if not generate_bids_name:
            entities_to_include = ["run"]
        entities_to_include.extend(["contrast", "stat"])

        mask = _generate_mask(
            prefix, generate_bids_name, entities, entities_to_include
        )

        statistical_maps = _generate_statistical_maps(
            prefix,
            contrasts,
            contrast_types,
            generate_bids_name,
            entities,
            entities_to_include,
        )

        model_level_mapping = _generate_model_level_mapping(
            self,
            prefix,
            design_matrices,
            generate_bids_name,
            entities,
            entities_to_include,
        )

        design_matrices_dict = _generate_design_matrices_dict(
            self,
            prefix,
            design_matrices,
            generate_bids_name,
            entities_to_include,
        )

        contrasts_dict = _generate_contrasts_dict(
            self,
            prefix,
            contrasts,
            design_matrices,
            generate_bids_name,
            entities,
            entities_to_include,
        )

        # "group" for second-level, "sub-<label>" or "" otherwise.
        out_dir = Path(out_dir) / sub

        # consider using a class or data class
        # to better standardize naming
        self._reporting_data["filenames"] = {
            "dir": out_dir,
            "mask": mask,
            "design_matrices_dict": design_matrices_dict,
            "contrasts_dict": contrasts_dict,
            "statistical_maps": statistical_maps,
            "model_level_mapping": model_level_mapping,
        }

331 

332 

def _generate_mask(
    prefix: str,
    generate_bids_name: bool,
    entities, 
    entities_to_include: list[str],
):
    """Return filename for GLM mask.

    The mask is shared across runs and sessions, so the ``run`` and
    ``ses`` entities are removed before building the name.
    """
    # Work on a copy so the caller's entities dict is left untouched.
    mask_entities = {
        key: value
        for key, value in deepcopy(entities).items()
        if key not in ("run", "ses")
    }
    fields = {
        # BIDS-style names carry their information via entities, not prefix.
        "prefix": None if generate_bids_name else prefix,
        "suffix": "mask",
        "extension": "nii.gz",
        "entities": mask_entities,
    }
    return create_bids_filename(fields, entities_to_include)

353 

354 

def _generate_statistical_maps(
    prefix: str,
    contrasts,
    contrast_types,
    generate_bids_name: bool,
    entities,
    entities_to_include: list[str],
):
    """Return dictionary containing statmap filenames for each contrast.

    statistical_maps[contrast_name][statmap_label] = filename

    For every contrast, filenames are generated for the five statistical
    maps (effect size, stat, effect variance, z score, p value) plus a
    clusters TSV table and a metadata JSON sidecar.
    """
    # A non-dict contrast_types (e.g. None) means "no explicit overrides".
    if not isinstance(contrast_types, dict):
        contrast_types = {}

    statistical_maps: dict[str, dict[str, str]] = {}

    for contrast_name in contrasts:
        # Extract stat_type
        contrast_matrix = contrasts[contrast_name]
        # Strings and 1D arrays are assumed to be t-contrasts
        if isinstance(contrast_matrix, str) or (contrast_matrix.ndim == 1):
            stat_type = "t"
        else:
            stat_type = "F"
        # Override automatic detection with explicit type if provided
        stat_type = contrast_types.get(contrast_name, stat_type)

        # Fresh fields dict per contrast; the entities copy is mutated
        # below (contrast / stat keys), so each contrast starts clean.
        fields = {
            "prefix": prefix,
            "suffix": "statmap",
            "extension": "nii.gz",
            "entities": deepcopy(entities),
        }

        if generate_bids_name:
            fields["prefix"] = None

        fields["entities"]["contrast"] = _clean_contrast_name(contrast_name)

        tmp = {}
        # One statmap filename per statistic; the "stat" entity encodes
        # which statistic the file holds.
        for key, stat_label in zip(
            [
                "effect_size",
                "stat",
                "effect_variance",
                "z_score",
                "p_value",
            ],
            ["effect", stat_type, "variance", "z", "p"],
        ):
            fields["entities"]["stat"] = stat_label
            tmp[key] = create_bids_filename(fields, entities_to_include)

        # Clusters table / metadata are per-contrast, not per-statistic:
        # drop the "stat" entity before building those names.
        fields["entities"]["stat"] = None
        fields["suffix"] = "clusters"
        fields["extension"] = "tsv"
        tmp["clusters_tsv"] = create_bids_filename(fields, entities_to_include)

        fields["extension"] = "json"
        tmp["metadata"] = create_bids_filename(fields, entities_to_include)

        # tempita.bunch allows attribute-style access in report templates.
        statistical_maps[contrast_name] = tempita.bunch(**tmp)

    return statistical_maps

420 

421 

def _generate_model_level_mapping(
    model,
    prefix: str,
    design_matrices,
    generate_bids_name: bool,
    entities,
    entities_to_include: list[str],
):
    """Return dictionary of filenames for nifti of runwise error & residuals.

    model_level_mapping[i_run][statmap_label] = filename

    Maps "residuals" -> a "stat-errorts" statmap filename and
    "r_square" -> a "stat-rsquared" statmap filename, per run.
    """
    # Single shared fields dict, deliberately mutated inside the loop;
    # the run/stat entities are overwritten on each iteration.
    fields = {
        "prefix": prefix,
        "suffix": "statmap",
        "extension": "nii.gz",
        "entities": deepcopy(entities),
    }

    if generate_bids_name:
        fields["prefix"] = None

    model_level_mapping = {}

    for i_run, _ in enumerate(design_matrices):
        # Despite its name, _is_flm_with_single_run is True for a
        # first-level model with MORE than one run (see its body); the
        # run entity is only needed to disambiguate multiple runs.
        if _is_flm_with_single_run(model):
            fields["entities"]["run"] = i_run + 1
        # For BIDS naming, reuse the entities parsed from that run's
        # input file (replaces the whole entities dict).
        if generate_bids_name:
            fields["entities"] = deepcopy(
                model._reporting_data["run_imgs"][i_run]["entities"]
            )

        tmp = {}
        for key, stat_label in zip(
            ["residuals", "r_square"],
            ["errorts", "rsquared"],
        ):
            fields["entities"]["stat"] = stat_label
            tmp[key] = create_bids_filename(fields, entities_to_include)

        model_level_mapping[i_run] = tempita.bunch(**tmp)

    return model_level_mapping

465 

466 

def _generate_design_matrices_dict(
    model,
    prefix: str,
    design_matrices,
    generate_bids_name: bool,
    entities_to_include: list[str],
) -> dict[int, dict[str, str]]:
    """Return dictionary with filenames for design_matrices figures / tables.

    design_matrices_dict[i_run][key] = filename

    Per run, four filenames are produced: design matrix and correlation
    matrix, each both as a figure (FIGURE_FORMAT) and as a TSV table,
    keyed "design_matrix_<ext>" / "correlation_matrix_<ext>".
    """
    # Single shared fields dict, mutated per run / per extension below.
    fields = {"prefix": prefix, "extension": FIGURE_FORMAT, "entities": {}}
    if generate_bids_name:
        fields["prefix"] = None  # type: ignore[assignment]

    design_matrices_dict = tempita.bunch()

    for i_run, _ in enumerate(design_matrices):
        # Despite its name, _is_flm_with_single_run is True for a
        # first-level model with MORE than one run (see its body).
        if _is_flm_with_single_run(model):
            fields["entities"] = {"run": i_run + 1}  # type: ignore[assignment]
        # BIDS naming: reuse the entities parsed from that run's input.
        if generate_bids_name:
            fields["entities"] = deepcopy(
                model._reporting_data["run_imgs"][i_run]["entities"]
            )

        tmp = {}
        for extension in [FIGURE_FORMAT, "tsv"]:
            for key, suffix in zip(
                ["design_matrix", "correlation_matrix"],
                ["design", "corrdesign"],
            ):
                fields["extension"] = extension
                fields["suffix"] = suffix
                tmp[f"{key}_{extension}"] = create_bids_filename(
                    fields, entities_to_include
                )

        design_matrices_dict[i_run] = tempita.bunch(**tmp)

    return design_matrices_dict

507 

508 

def _generate_contrasts_dict(
    model,
    prefix: str,
    contrasts,
    design_matrices,
    generate_bids_name: bool,
    entities,
    entities_to_include: list[str],
) -> dict[int, dict[str, str]]:
    """Return dictionary with filenames for contrast matrices figures.

    contrasts_dict[i_run][contrast_name] = filename

    One figure filename (FIGURE_FORMAT, suffix "design") per contrast
    and per run.
    """
    # Single shared fields dict, mutated per run / per contrast below.
    fields = {
        "prefix": prefix,
        "extension": FIGURE_FORMAT,
        "entities": deepcopy(entities),
        "suffix": "design",
    }
    if generate_bids_name:
        fields["prefix"] = None

    contrasts_dict = tempita.bunch()

    for i_run, _ in enumerate(design_matrices):
        # Despite its name, _is_flm_with_single_run is True for a
        # first-level model with MORE than one run (see its body).
        if _is_flm_with_single_run(model):
            fields["entities"]["run"] = i_run + 1
        # BIDS naming: reuse the entities parsed from that run's input
        # (replaces the whole entities dict).
        if generate_bids_name:
            fields["entities"] = deepcopy(
                model._reporting_data["run_imgs"][i_run]["entities"]
            )

        tmp = {}
        for contrast_name in contrasts:
            # The contrast entity uses the sanitized (camelCase,
            # alphanumeric-only) contrast name.
            fields["entities"]["contrast"] = _clean_contrast_name(
                contrast_name
            )
            tmp[contrast_name] = create_bids_filename(
                fields, entities_to_include
            )

        contrasts_dict[i_run] = tempita.bunch(**tmp)

    return contrasts_dict

553 

554 

555def _use_input_files_for_filenaming(self, prefix) -> bool: 

556 """Determine if we should try to use input files to generate \ 

557 output filenames. 

558 """ 

559 if self.__str__() == "Second Level Model" or prefix is not None: 

560 return False 

561 

562 input_files = self._reporting_data["run_imgs"] 

563 

564 files_used_as_input = all(len(x) > 0 for x in input_files.values()) 

565 tmp = {x.get("sub") for x in input_files.values()} 

566 all_files_have_same_sub = len(tmp) == 1 and tmp is not None 

567 

568 return files_used_as_input and all_files_have_same_sub 

569 

570 

571def _is_flm_with_single_run(model) -> bool: 

572 return ( 

573 model.__str__() == "First Level Model" 

574 and len(model._reporting_data["run_imgs"]) > 1 

575 ) 

576 

577 

578def _clean_contrast_name(contrast_name): 

579 """Remove prohibited characters from name and convert to camelCase. 

580 

581 .. versionadded:: 0.9.2 

582 

583 BIDS filenames, in which the contrast name will appear as a 

584 contrast-<name> key/value pair, must be alphanumeric strings. 

585 

586 Parameters 

587 ---------- 

588 contrast_name : :obj:`str` 

589 Contrast name to clean. 

590 

591 Returns 

592 ------- 

593 new_name : :obj:`str` 

594 Contrast name converted to alphanumeric-only camelCase. 

595 """ 

596 new_name = contrast_name[:] 

597 

598 # Some characters translate to words 

599 new_name = new_name.replace("-", " Minus ") 

600 new_name = new_name.replace("+", " Plus ") 

601 new_name = new_name.replace(">", " Gt ") 

602 new_name = new_name.replace("<", " Lt ") 

603 

604 # Others translate to spaces 

605 new_name = new_name.replace("_", " ") 

606 

607 # Convert to camelCase 

608 new_name = new_name.split(" ") 

609 new_name[0] = new_name[0].lower() 

610 new_name[1:] = [c.title() for c in new_name[1:]] 

611 new_name = " ".join(new_name) 

612 

613 # Remove non-alphanumeric characters 

614 new_name = "".join(ch for ch in new_name if ch.isalnum()) 

615 

616 # Let users know if the name was changed 

617 if new_name != contrast_name: 

618 warnings.warn( 

619 f'Contrast name "{contrast_name}" changed to "{new_name}"', 

620 stacklevel=find_stack_level(), 

621 ) 

622 return new_name