Coverage for nilearn/_utils/tests/test_data_gen.py: 0%

296 statements  

coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1"""Tests for the data generation utilities.""" 

2 

3from __future__ import annotations 

4 

5import json 

6 

7import numpy as np 

8import pandas as pd 

9import pytest 

10from nibabel import load 

11from numpy.testing import assert_almost_equal 

12from pandas.api.types import is_numeric_dtype, is_object_dtype 

13from pandas.testing import assert_frame_equal 

14 

15from nilearn._utils.data_gen import ( 

16 add_metadata_to_bids_dataset, 

17 basic_paradigm, 

18 create_fake_bids_dataset, 

19 generate_fake_fmri, 

20 generate_fake_fmri_data_and_design, 

21 generate_group_sparse_gaussian_graphs, 

22 generate_labeled_regions, 

23 generate_maps, 

24 generate_mni_space_img, 

25 generate_random_img, 

26 generate_regions_ts, 

27 generate_timeseries, 

28 write_fake_bold_img, 

29 write_fake_fmri_data_and_design, 

30) 

31from nilearn.image import get_data 

def test_add_metadata_to_bids_derivatives_default_path(tmp_path):
    """Check that the filename created matches the default value \
    of add_metadata_to_bids_dataset.
    """
    target_dir = tmp_path / "derivatives" / "sub-01" / "ses-01" / "func"
    target_dir.mkdir(parents=True)
    json_file = add_metadata_to_bids_dataset(
        bids_path=tmp_path, metadata={"foo": "bar"}
    )
    assert json_file.exists()
    assert (
        json_file.name
        == "sub-01_ses-01_task-main_run-01_space-MNI_desc-preproc_bold.json"
    )
    with json_file.open() as f:
        metadata = json.load(f)
    assert metadata == {"foo": "bar"}

def test_add_metadata_to_bids_derivatives_with_json_path(tmp_path):
    """Smoke test for an explicitly provided json_file path."""
    target_dir = tmp_path / "derivatives" / "sub-02"
    target_dir.mkdir(parents=True)
    json_file = "derivatives/sub-02/sub-02_task-main_bold.json"
    json_file = add_metadata_to_bids_dataset(
        bids_path=tmp_path, metadata={"foo": "bar"}, json_file=json_file
    )
    assert json_file.exists()
    assert json_file.name == "sub-02_task-main_bold.json"
    with json_file.open() as f:
        metadata = json.load(f)
    assert metadata == {"foo": "bar"}

@pytest.mark.parametrize("have_spaces", [False, True])
def test_basic_paradigm(have_spaces):
    events = basic_paradigm(condition_names_have_spaces=have_spaces)

    assert events.columns.equals(pd.Index(["trial_type", "onset", "duration"]))
    assert is_object_dtype(events["trial_type"])
    assert is_numeric_dtype(events["onset"])
    assert is_numeric_dtype(events["duration"])
    assert events["trial_type"].str.contains(" ").any() == have_spaces

@pytest.mark.parametrize("shape", [(3, 4, 5), (2, 3, 5, 7)])
@pytest.mark.parametrize("affine", [None, np.diag([0.5, 0.3, 1, 1])])
def test_write_fake_bold_img(tmp_path, shape, affine, rng):
    img_file = write_fake_bold_img(
        file_path=tmp_path / "fake_bold.nii",
        shape=shape,
        affine=affine,
        random_state=rng,
    )
    img = load(img_file)

    assert img.get_fdata().shape == shape
    if affine is not None:
        assert_almost_equal(img.affine, affine)

def _bids_path_template(
    task,
    suffix,
    n_runs=None,
    space=None,
    desc=None,
    extra_entity=None,
):
    """Create a BIDS filepath glob pattern from a template.

    The file path is relative to the BIDS root folder
    and contains a session-level folder.
    """
    task = f"task-{task}_*"
    run = "run-*_*" if n_runs is not None else "*"
    space = f"space-{space}_*" if space is not None else "*"
    desc = f"desc-{desc}_*" if desc is not None else "*"

    # only used with the resolution and acquisition entities (for now)
    acq = "*"
    res = "*"
    if extra_entity is not None:
        if "acq" in extra_entity:
            acq = f"acq-{extra_entity['acq']}_*"
        elif "res" in extra_entity:
            res = f"res-{extra_entity['res']}_*"

    path = "sub-*/ses-*/func/sub-*_ses-*_*"
    path += f"{task}{acq}{run}{space}{res}{desc}{suffix}"
    # TODO use regex
    path = path.replace("***", "*")
    path = path.replace("**", "*")
    return path
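
# Hand-traced illustration (added note, not part of the original file):
# _bids_path_template(task="main", suffix="bold.nii.gz", n_runs=2) builds
#     "sub-*/ses-*/func/sub-*_ses-*_*task-main_**run-*_****bold.nii.gz"
# which collapses to
#     "sub-*/ses-*/func/sub-*_ses-*_*task-main_*run-*_*bold.nii.gz"
# after the "***" and "**" replacements.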

@pytest.mark.parametrize("n_sub", [1, 2])
@pytest.mark.parametrize("n_ses", [1, 2])
@pytest.mark.parametrize(
    "tasks,n_runs",
    [(["main"], [1]), (["main"], [2]), (["main", "localizer"], [2, 1])],
)
def test_fake_bids_raw_with_session_and_runs(
    tmp_path, n_sub, n_ses, tasks, n_runs
):
    """Check number of each file 'type' created in raw."""
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path, n_sub=n_sub, n_ses=n_ses, tasks=tasks, n_runs=n_runs
    )

    # raw
    file_pattern = "sub-*/ses-*/anat/sub-*ses-*T1w.nii.gz"
    raw_anat_files = list(bids_path.glob(file_pattern))
    assert len(raw_anat_files) == n_sub

    for i, task in enumerate(tasks):
        for suffix in ["bold.nii.gz", "bold.json", "events.tsv"]:
            file_pattern = _bids_path_template(
                task=task, suffix=suffix, n_runs=n_runs[i]
            )
            files = list(bids_path.glob(file_pattern))
            assert len(files) == n_sub * n_ses * n_runs[i]

    all_files = list(bids_path.glob("sub-*/ses-*/*/*"))
    # per subject: 1 anat + (1 event + 1 json + 1 bold) per run per session
    n_raw_files_expected = n_sub * (1 + 3 * sum(n_runs) * n_ses)
    assert len(all_files) == n_raw_files_expected
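
# Worked example (added note, hand-computed): with n_sub=2, n_ses=2,
# tasks=["main", "localizer"] and n_runs=[2, 1], the expected count is
# 2 * (1 + 3 * 3 * 2) = 38 raw files.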

def _check_n_files_derivatives_for_task(
    bids_path,
    n_sub,
    n_ses,
    task,
    n_run,
    extra_entity=None,
):
    """Check number of each file 'type' in derivatives for a given task."""
    for suffix in ["timeseries.tsv"]:
        # There is 1 confound file per raw file, so the extra entity is
        # not used for derivatives-only entities like res.
        if extra_entity is None or "res" in extra_entity:
            file_pattern = _bids_path_template(
                task=task,
                suffix=suffix,
                n_runs=n_run,
                extra_entity=None,
            )
        elif "acq" in extra_entity:
            file_pattern = _bids_path_template(
                task=task,
                suffix=suffix,
                n_runs=n_run,
                extra_entity=extra_entity,
            )

        files = list(bids_path.glob(f"derivatives/{file_pattern}"))
        assert len(files) == n_sub * n_ses * n_run

    for space in ["MNI", "T1w"]:
        file_pattern = _bids_path_template(
            task=task,
            suffix="bold.nii.gz",
            n_runs=n_run,
            space=space,
            desc="preproc",
            extra_entity=extra_entity,
        )
        files = list(bids_path.glob(f"derivatives/{file_pattern}"))
        assert len(files) == n_sub * n_ses * n_run

    # only space-T1w files have a desc-fmriprep bold file
    file_pattern = _bids_path_template(
        task=task,
        suffix="bold.nii.gz",
        n_runs=n_run,
        space="T1w",
        desc="fmriprep",
        extra_entity=extra_entity,
    )
    files = list(bids_path.glob(f"derivatives/{file_pattern}"))
    assert len(files) == n_sub * n_ses * n_run

    file_pattern = _bids_path_template(
        task=task,
        suffix="bold.nii.gz",
        n_runs=n_run,
        space="MNI",
        desc="fmriprep",
        extra_entity=extra_entity,
    )
    files = list(bids_path.glob(f"derivatives/{file_pattern}"))
    assert not files
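
# Summary of the assertions above (added note): per subject x session x run,
# the helper expects one matching confound file, one desc-preproc bold file
# for each of space-MNI and space-T1w, and one desc-fmriprep bold file for
# space-T1w; no space-MNI desc-fmriprep file may exist.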

@pytest.mark.parametrize("n_sub", [1, 2])
@pytest.mark.parametrize("n_ses", [1, 2])
@pytest.mark.parametrize(
    "tasks,n_runs",
    [(["main"], [1]), (["main"], [2]), (["main", "localizer"], [2, 1])],
)
def test_fake_bids_derivatives_with_session_and_runs(
    tmp_path, n_sub, n_ses, tasks, n_runs
):
    """Check number of each file 'type' created in derivatives."""
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path, n_sub=n_sub, n_ses=n_ses, tasks=tasks, n_runs=n_runs
    )

    # derivatives
    for task, n_run in zip(tasks, n_runs):
        _check_n_files_derivatives_for_task(
            bids_path=bids_path,
            n_sub=n_sub,
            n_ses=n_ses,
            task=task,
            n_run=n_run,
        )

    all_files = list(bids_path.glob("derivatives/sub-*/ses-*/*/*"))
    # per subject: (2 confound + 3 bold + 2 gifti) per run per session
    n_derivatives_files_expected = n_sub * (7 * sum(n_runs) * n_ses)
    assert len(all_files) == n_derivatives_files_expected
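
# Worked example (added note, hand-computed): with n_sub=2, n_ses=2 and
# n_runs=[2, 1], the expected count is 2 * (7 * 3 * 2) = 84 derivative files.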

def test_bids_dataset_no_run_entity(tmp_path):
    """n_runs = 0 produces files without the run entity."""
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path,
        n_sub=1,
        n_ses=1,
        tasks=["main"],
        n_runs=[0],
        with_derivatives=True,
    )

    files = list(bids_path.glob("**/*run-*"))
    assert not files

    # nifti: 1 anat + 1 raw bold + 3 derivatives bold
    files = list(bids_path.glob("**/*.nii.gz"))
    assert len(files) == 5

    # 1 file each for events, confounds and bold json
    for suffix in ["events.tsv", "timeseries.tsv", "bold.json"]:
        files = list(bids_path.glob(f"**/*{suffix}"))
        assert len(files) == 1

def test_bids_dataset_no_session(tmp_path):
    """n_ses = 0 prevents creation of a session folder."""
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path,
        n_sub=1,
        n_ses=0,
        tasks=["main"],
        n_runs=[1],
        with_derivatives=True,
    )

    files = list(bids_path.glob("**/*ses-*"))
    assert not files

    # nifti: 1 anat + 1 raw bold + 3 derivatives bold
    files = list(bids_path.glob("**/*.nii.gz"))
    assert len(files) == 5

    # 1 file each for events, confounds and bold json
    for suffix in ["events.tsv", "timeseries.tsv", "bold.json"]:
        files = list(bids_path.glob(f"**/*{suffix}"))
        assert len(files) == 1

def test_create_fake_bids_dataset_no_derivatives(tmp_path):
    """Check no file is created in derivatives."""
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path,
        n_sub=1,
        n_ses=1,
        tasks=["main"],
        n_runs=[2],
        with_derivatives=False,
    )
    files = list(bids_path.glob("derivatives/**"))
    assert not files

@pytest.mark.parametrize(
    "confounds_tag,with_confounds", [(None, True), ("_timeseries", False)]
)
def test_create_fake_bids_dataset_no_confounds(
    tmp_path, confounds_tag, with_confounds
):
    """Check that derivatives files are created but no confounds."""
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path,
        n_sub=1,
        n_ses=1,
        tasks=["main"],
        n_runs=[2],
        with_confounds=with_confounds,
        confounds_tag=confounds_tag,
    )
    assert list(bids_path.glob("derivatives/*"))
    files = list(bids_path.glob("derivatives/*/*/func/*timeseries.tsv"))
    assert not files

def test_fake_bids_errors(tmp_path):
    """Check that invalid labels or mismatched tasks/runs raise errors."""
    with pytest.raises(ValueError, match="labels.*alphanumeric"):
        create_fake_bids_dataset(
            base_dir=tmp_path, n_sub=1, n_ses=1, tasks=["foo_bar"], n_runs=[1]
        )

    with pytest.raises(ValueError, match="labels.*alphanumeric"):
        create_fake_bids_dataset(
            base_dir=tmp_path,
            n_sub=1,
            n_ses=1,
            tasks=["main"],
            n_runs=[1],
            entities={"acq": "foo_bar"},
        )

    with pytest.raises(ValueError, match="number.*tasks.*runs.*same"):
        create_fake_bids_dataset(
            base_dir=tmp_path,
            n_sub=1,
            n_ses=1,
            tasks=["main"],
            n_runs=[1, 2],
        )

def test_fake_bids_extra_raw_entity(tmp_path):
    """Check files with an extra raw entity are created properly."""
    n_sub = 2
    n_ses = 2
    tasks = ["main"]
    n_runs = [2]
    entities = {"acq": ["foo", "bar"]}
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path,
        n_sub=n_sub,
        n_ses=n_ses,
        tasks=tasks,
        n_runs=n_runs,
        entities=entities,
    )

    # raw
    for i, task in enumerate(tasks):
        for suffix in ["bold.nii.gz", "bold.json", "events.tsv"]:
            for label in entities["acq"]:
                file_pattern = _bids_path_template(
                    task=task,
                    suffix=suffix,
                    n_runs=n_runs[i],
                    extra_entity={"acq": label},
                )
                files = list(bids_path.glob(file_pattern))
                assert len(files) == n_sub * n_ses * n_runs[i]

    all_files = list(bids_path.glob("sub-*/ses-*/*/*"))
    # per subject:
    # 1 anat + (1 event + 1 json + 1 bold) per entity per run per session
    n_raw_files_expected = n_sub * (
        1 + 3 * sum(n_runs) * n_ses * len(entities["acq"])
    )
    assert len(all_files) == n_raw_files_expected

    # derivatives
    for label in entities["acq"]:
        for task, n_run in zip(tasks, n_runs):
            _check_n_files_derivatives_for_task(
                bids_path=bids_path,
                n_sub=n_sub,
                n_ses=n_ses,
                task=task,
                n_run=n_run,
                extra_entity={"acq": label},
            )

    all_files = list(bids_path.glob("derivatives/sub-*/ses-*/*/*"))
    # per subject: (2 confound + 3 bold + 2 gifti)
    # per run per session per entity
    n_derivatives_files_expected = (
        n_sub * (7 * sum(n_runs) * n_ses) * len(entities["acq"])
    )
    assert len(all_files) == n_derivatives_files_expected
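
# Worked example (added note, hand-computed): here n_sub=2, n_ses=2,
# sum(n_runs)=2 and 2 acq labels, so raw = 2 * (1 + 3*2*2*2) = 50 files
# and derivatives = 2 * (7*2*2) * 2 = 112 files.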

def test_fake_bids_extra_derivative_entity(tmp_path):
    """Check files with an extra derivative entity are created properly."""
    n_sub = 2
    n_ses = 2
    tasks = ["main"]
    n_runs = [2]
    entities = {"res": ["foo", "bar"]}
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path,
        n_sub=n_sub,
        n_ses=n_ses,
        tasks=tasks,
        n_runs=n_runs,
        entities=entities,
    )

    # raw files should not contain the res entity
    all_files = list(bids_path.glob("sub-*/ses-*/*/*res*"))
    assert not all_files

    # derivatives
    for label in entities["res"]:
        for task, n_run in zip(tasks, n_runs):
            _check_n_files_derivatives_for_task(
                bids_path=bids_path,
                n_sub=n_sub,
                n_ses=n_ses,
                task=task,
                n_run=n_run,
                extra_entity={"res": label},
            )

    all_files = list(bids_path.glob("derivatives/sub-*/ses-*/*/*"))
    # per subject:
    # 2 confound files per run per session
    # + (3 bold + 2 gifti) per run per session per entity
    n_derivatives_files_expected = n_sub * (
        2 * sum(n_runs) * n_ses
        + 5 * sum(n_runs) * n_ses * len(entities["res"])
    )
    assert len(all_files) == n_derivatives_files_expected
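
# Worked example (added note, hand-computed): with n_sub=2, n_ses=2,
# sum(n_runs)=2 and 2 res labels: 2 * (2*2*2 + 5*2*2*2) = 96 files.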

def test_fake_bids_extra_entity_not_bids_entity(tmp_path):
    """Check that an entity that is not a valid BIDS entity raises."""
    with pytest.raises(ValueError, match="Invalid entity"):
        create_fake_bids_dataset(
            base_dir=tmp_path,
            entities={"egg": ["spam"]},
        )

@pytest.mark.parametrize("window", ["boxcar", "hamming"])
def test_generate_regions_ts_no_overlap(window):
    n_voxels = 50
    n_regions = 10

    regions = generate_regions_ts(
        n_voxels, n_regions, overlap=0, window=window
    )

    assert regions.shape == (n_regions, n_voxels)
    # check no overlap: each voxel belongs to at most one region
    np.testing.assert_array_less(
        (regions > 0).sum(axis=0) - 0.1, np.ones(regions.shape[1])
    )
    # check that every voxel is covered by a region
    np.testing.assert_array_less(
        np.zeros(regions.shape[1]), (regions > 0).sum(axis=0)
    )

@pytest.mark.parametrize("window", ["boxcar", "hamming"])
def test_generate_regions_ts_with_overlap(window):
    n_voxels = 50
    n_regions = 10

    regions = generate_regions_ts(
        n_voxels, n_regions, overlap=1, window=window
    )

    assert regions.shape == (n_regions, n_voxels)
    # check overlap
    assert np.any((regions > 0).sum(axis=-1) > 1.9)
    # check that every voxel is covered by a region
    np.testing.assert_array_less(
        np.zeros(regions.shape[1]), (regions > 0).sum(axis=0)
    )

def test_generate_labeled_regions():
    """Minimal testing of generate_labeled_regions."""
    shape = (3, 4, 5)
    n_regions = 10
    regions = generate_labeled_regions(shape, n_regions)
    assert regions.shape == shape
    # n_regions labels plus the background label 0
    assert len(np.unique(get_data(regions))) == n_regions + 1

def test_generate_maps():
    """Basic testing of generate_maps."""
    shape = (10, 11, 12)
    n_regions = 9
    maps_img, _ = generate_maps(shape, n_regions, border=1)
    maps = get_data(maps_img)
    assert maps.shape == (*shape, n_regions)
    # no empty map
    assert np.all(abs(maps).sum(axis=0).sum(axis=0).sum(axis=0) > 0)
    # check border
    assert np.all(maps[0, ...] == 0)
    assert np.all(maps[:, 0, ...] == 0)
    assert np.all(maps[:, :, 0, :] == 0)

@pytest.mark.parametrize("shape", [(10, 11, 12), (6, 6, 7)])
@pytest.mark.parametrize("length", [16, 20])
@pytest.mark.parametrize("kind", ["noise", "step"])
@pytest.mark.parametrize(
    "n_block,block_size,block_type",
    [
        (None, None, None),
        (1, 1, "classification"),
        (4, 3, "classification"),
        (4, 4, "regression"),
    ],
)
def test_generate_fake_fmri(
    shape, length, kind, n_block, block_size, block_type, rng
):
    fake_fmri = generate_fake_fmri(
        shape=shape,
        length=length,
        kind=kind,
        n_blocks=n_block,
        block_size=block_size,
        block_type=block_type,
        random_state=rng,
    )

    assert fake_fmri[0].shape[:-1] == shape
    assert fake_fmri[0].shape[-1] == length
    if n_block is not None:
        assert fake_fmri[2].size == length
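
# Reading note (added, inferred from the assertions above): fake_fmri[0]
# is the 4D image, and when blocks are requested fake_fmri[2] carries one
# target value per scan, hence size == length.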

def test_generate_fake_fmri_error(rng):
    with pytest.raises(ValueError, match="10 is too small"):
        generate_fake_fmri(
            length=10,
            n_blocks=10,
            block_size=3,
            random_state=rng,
        )

@pytest.mark.parametrize(
    "shapes", [[(2, 3, 5, 7)], [(5, 5, 5, 3), (5, 5, 5, 5)]]
)
@pytest.mark.parametrize("rank", [1, 3, 5])
@pytest.mark.parametrize("affine", [None, np.diag([0.5, 0.3, 1, 1])])
def test_fake_fmri_data_and_design_generate(shapes, rank, affine):
    """Test generation of fake fMRI data and design matrices."""
    mask, fmri_data, design_matrices = generate_fake_fmri_data_and_design(
        shapes, rk=rank, affine=affine, random_state=42
    )

    for fmri, shape in zip(fmri_data, shapes):
        assert mask.shape == shape[:3]
        assert fmri.shape == shape
        if affine is not None:
            assert_almost_equal(fmri.affine, affine)

    for design, shape in zip(design_matrices, shapes):
        assert design.shape == (shape[3], rank)
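
# Example of the expected shapes (added note): for shapes=[(5, 5, 5, 3)]
# and rank=3, mask is (5, 5, 5), fmri is (5, 5, 5, 3) and the design
# matrix is (3, 3), i.e. (n_scans, rank).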

@pytest.mark.parametrize(
    "shapes", [[(2, 3, 5, 7)], [(5, 5, 5, 3), (5, 5, 5, 5)]]
)
@pytest.mark.parametrize("rank", [1, 3, 5])
@pytest.mark.parametrize("affine", [None, np.diag([0.5, 0.3, 1, 1])])
def test_fake_fmri_data_and_design_write(tmp_path, shapes, rank, affine):
    """Check that written data and design match the generated ones."""
    mask, fmri_data, design_matrices = generate_fake_fmri_data_and_design(
        shapes, rk=rank, affine=affine, random_state=42
    )
    mask_file, fmri_files, design_files = write_fake_fmri_data_and_design(
        shapes, rk=rank, affine=affine, random_state=42, file_path=tmp_path
    )

    mask_img = load(mask_file)
    assert_almost_equal(mask_img.get_fdata(), mask.get_fdata())
    assert_almost_equal(mask_img.affine, mask.affine)

    for fmri_file, fmri in zip(fmri_files, fmri_data):
        fmri_img = load(fmri_file)
        assert_almost_equal(fmri_img.get_fdata(), fmri.get_fdata())
        assert_almost_equal(fmri_img.affine, fmri.affine)

    for design_file, design in zip(design_files, design_matrices):
        assert_frame_equal(
            pd.read_csv(design_file, sep="\t"), design, check_exact=False
        )

@pytest.mark.parametrize("shape", [(3, 4, 5), (2, 3, 5, 7)])
@pytest.mark.parametrize("affine", [None, np.diag([0.5, 0.3, 1, 1])])
def test_generate_random_img(shape, affine, rng):
    img, mask = generate_random_img(
        shape=shape, affine=affine, random_state=rng
    )

    assert img.shape == shape
    assert mask.shape == shape[:3]
    if affine is not None:
        assert_almost_equal(img.affine, affine)
        assert_almost_equal(mask.affine, affine)

@pytest.mark.parametrize("n_subjects", [5, 9])
@pytest.mark.parametrize("n_features", [30, 9])
@pytest.mark.parametrize("n_samples_range", [(30, 50), (9, 9)])
@pytest.mark.parametrize("density", [0.1, 1])
def test_generate_group_sparse_gaussian_graphs(
    n_subjects, n_features, n_samples_range, density, rng
):
    signals, precisions, topology = generate_group_sparse_gaussian_graphs(
        n_subjects=n_subjects,
        n_features=n_features,
        min_n_samples=n_samples_range[0],
        max_n_samples=n_samples_range[1],
        density=density,
        random_state=rng,
    )

    assert len(signals) == n_subjects
    assert len(precisions) == n_subjects

    signal_shapes = np.array([s.shape for s in signals])
    precision_shapes = np.array([p.shape for p in precisions])
    assert np.all(
        (signal_shapes[:, 0] >= n_samples_range[0])
        & (signal_shapes[:, 0] <= n_samples_range[1])
    )
    assert np.all(signal_shapes[:, 1] == n_features)
    assert np.all(precision_shapes == (n_features, n_features))
    assert topology.shape == (n_features, n_features)

    # precision matrices must be positive semi-definite
    eigenvalues = np.array([np.linalg.eigvalsh(p) for p in precisions])
    assert np.all(eigenvalues >= 0)
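
# Reading note (added, inferred from the assertions above): each
# signals[i] is (n_samples_i, n_features) with n_samples_i drawn from
# [min_n_samples, max_n_samples]; each precisions[i] is
# (n_features, n_features), as is the shared topology matrix.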

@pytest.mark.parametrize("n_timepoints", [1, 9])
@pytest.mark.parametrize("n_features", [1, 9])
def test_generate_timeseries(n_timepoints, n_features, rng):
    timeseries = generate_timeseries(n_timepoints, n_features, rng)
    assert timeseries.shape == (n_timepoints, n_features)

@pytest.mark.parametrize("n_scans", [1, 5])
@pytest.mark.parametrize("res", [1, 30])
@pytest.mark.parametrize("mask_dilation", [1, 2])
def test_generate_mni_space_img(n_scans, res, mask_dilation, rng):
    inverse_img, mask_img = generate_mni_space_img(
        n_scans=n_scans, res=res, mask_dilation=mask_dilation, random_state=rng
    )

    def resample_dim(orig, res):
        return (orig - 2) // res + 2

    # (197, 233, 189): shape of nilearn's MNI152 template at 1 mm
    expected_shape = (
        resample_dim(197, res),
        resample_dim(233, res),
        resample_dim(189, res),
    )
    assert inverse_img.shape[:3] == expected_shape
    assert inverse_img.shape[3] == n_scans
    assert mask_img.shape == expected_shape
    assert_almost_equal(inverse_img.affine, mask_img.affine)
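
# Worked example (added note, hand-computed): for res=30 the expected
# shape is ((197 - 2) // 30 + 2, (233 - 2) // 30 + 2, (189 - 2) // 30 + 2)
# = (8, 9, 8).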