Coverage for nilearn/_utils/tests/test_data_gen.py: 0%
296 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
1"""Tests for the data generation utilities."""
from __future__ import annotations

import json
import re

import numpy as np
import pandas as pd
import pytest
from nibabel import load
from numpy.testing import assert_almost_equal
from pandas.api.types import is_numeric_dtype, is_object_dtype
from pandas.testing import assert_frame_equal

from nilearn._utils.data_gen import (
    add_metadata_to_bids_dataset,
    basic_paradigm,
    create_fake_bids_dataset,
    generate_fake_fmri,
    generate_fake_fmri_data_and_design,
    generate_group_sparse_gaussian_graphs,
    generate_labeled_regions,
    generate_maps,
    generate_mni_space_img,
    generate_random_img,
    generate_regions_ts,
    generate_timeseries,
    write_fake_bold_img,
    write_fake_fmri_data_and_design,
)
from nilearn.image import get_data
def test_add_metadata_to_bids_derivatives_default_path(tmp_path):
    """Check the default JSON file written by add_metadata_to_bids_dataset.

    When no ``json_file`` is passed, the file gets a fixed default name
    and contains exactly the metadata that was passed.
    """
    func_dir = tmp_path / "derivatives" / "sub-01" / "ses-01" / "func"
    func_dir.mkdir(parents=True)

    written = add_metadata_to_bids_dataset(
        bids_path=tmp_path, metadata={"foo": "bar"}
    )

    assert written.exists()
    expected_name = (
        "sub-01_ses-01_task-main_run-01_space-MNI_desc-preproc_bold.json"
    )
    assert written.name == expected_name
    assert json.loads(written.read_text()) == {"foo": "bar"}
def test_add_metadata_to_bids_derivatives_with_json_path(tmp_path):
    """Smoke test passing an explicit ``json_file`` path.

    The path is given relative to ``bids_path``.
    """
    (tmp_path / "derivatives" / "sub-02").mkdir(parents=True)
    relative_json = "derivatives/sub-02/sub-02_task-main_bold.json"

    written = add_metadata_to_bids_dataset(
        bids_path=tmp_path, metadata={"foo": "bar"}, json_file=relative_json
    )

    assert written.exists()
    assert written.name == "sub-02_task-main_bold.json"
    assert json.loads(written.read_text()) == {"foo": "bar"}
@pytest.mark.parametrize("have_spaces", [False, True])
def test_basic_paradigm(have_spaces):
    """Check columns and dtypes of the basic paradigm events table."""
    events = basic_paradigm(condition_names_have_spaces=have_spaces)

    expected_columns = pd.Index(["trial_type", "onset", "duration"])
    assert events.columns.equals(expected_columns)
    assert is_object_dtype(events["trial_type"])
    for numeric_column in ("onset", "duration"):
        assert is_numeric_dtype(events[numeric_column])
    # condition names contain spaces only when requested
    assert events["trial_type"].str.contains(" ").any() == have_spaces
@pytest.mark.parametrize("shape", [(3, 4, 5), (2, 3, 5, 7)])
@pytest.mark.parametrize("affine", [None, np.diag([0.5, 0.3, 1, 1])])
def test_write_fake_bold_img(tmp_path, shape, affine, rng):
    """Check the written image has the requested shape and affine."""
    out_file = tmp_path / "fake_bold.nii"
    img = load(
        write_fake_bold_img(
            file_path=out_file,
            shape=shape,
            affine=affine,
            random_state=rng,
        )
    )

    assert img.get_fdata().shape == shape
    if affine is None:
        return
    assert_almost_equal(img.affine, affine)
def _bids_path_template(
    task,
    suffix,
    n_runs=None,
    space=None,
    desc=None,
    extra_entity=None,
):
    """Create a BIDS glob pattern from a template.

    The pattern is relative to the BIDS root folder
    and contains a session level folder.

    Parameters
    ----------
    task : str
        Label of the ``task`` entity.

    suffix : str
        Suffix and extension of the file, for example ``"bold.nii.gz"``.

    n_runs : int or None, default=None
        If not None, the pattern requires a ``run`` entity.

    space : str or None, default=None
        Label of the ``space`` entity, if any.

    desc : str or None, default=None
        Label of the ``desc`` entity, if any.

    extra_entity : dict or None, default=None
        Extra entities to include in the pattern.
        Only the ``"acq"`` and ``"res"`` keys are used
        (both can be given at once); other keys are ignored.

    Returns
    -------
    str
        Glob pattern in which every unspecified entity is matched by ``*``.
    """
    task = f"task-{task}_*"
    run = "run-*_*" if n_runs is not None else "*"
    space = f"space-{space}_*" if space is not None else "*"
    desc = f"desc-{desc}_*" if desc is not None else "*"

    # only supporting acquisition and resolution entities (for now)
    extra_entity = extra_entity or {}
    acq = f"acq-{extra_entity['acq']}_*" if "acq" in extra_entity else "*"
    res = f"res-{extra_entity['res']}_*" if "res" in extra_entity else "*"

    path = "sub-*/ses-*/func/sub-*_ses-*_*"
    path += f"{task}{acq}{run}{space}{res}{desc}{suffix}"
    # Collapse every run of consecutive wildcards into a single one.
    # Chained str.replace calls left "**" behind for 5 consecutive "*",
    # which older pathlib versions reject inside a path component.
    return re.sub(r"\*+", "*", path)
@pytest.mark.parametrize("n_sub", [1, 2])
@pytest.mark.parametrize("n_ses", [1, 2])
@pytest.mark.parametrize(
    "tasks,n_runs",
    [(["main"], [1]), (["main"], [2]), (["main", "localizer"], [2, 1])],
)
def test_fake_bids_raw_with_session_and_runs(
    tmp_path, n_sub, n_ses, tasks, n_runs
):
    """Check number of each file 'type' created in raw."""
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path, n_sub=n_sub, n_ses=n_ses, tasks=tasks, n_runs=n_runs
    )

    # one anatomical image per subject
    anat_pattern = "sub-*/ses-*/anat/sub-*ses-*T1w.nii.gz"
    assert len(list(bids_path.glob(anat_pattern))) == n_sub

    for task, task_runs in zip(tasks, n_runs):
        for suffix in ["bold.nii.gz", "bold.json", "events.tsv"]:
            pattern = _bids_path_template(
                task=task, suffix=suffix, n_runs=task_runs
            )
            n_found = len(list(bids_path.glob(pattern)))
            assert n_found == n_sub * n_ses * task_runs

    # per subject: 1 anat + (1 event + 1 json + 1 bold) per run per session
    n_expected = n_sub * (1 + 3 * sum(n_runs) * n_ses)
    assert len(list(bids_path.glob("sub-*/ses-*/*/*"))) == n_expected
def _check_n_files_derivatives_for_task(
    bids_path,
    n_sub,
    n_ses,
    task,
    n_run,
    extra_entity=None,
):
    """Check number of each file 'type' in derivatives for a given task.

    Parameters
    ----------
    bids_path : pathlib.Path
        Root folder of the fake BIDS dataset.

    n_sub : int
        Expected number of subjects.

    n_ses : int
        Expected number of sessions per subject.

    task : str
        Label of the task to check.

    n_run : int
        Expected number of runs for this task.

    extra_entity : dict or None, default=None
        Extra entity the files are expected to carry,
        for example ``{"acq": "foo"}`` or ``{"res": "foo"}``.
    """
    n_expected = n_sub * n_ses * n_run

    def _n_derivatives(**template_kwargs):
        """Count derivatives files matching the pattern for this task."""
        pattern = _bids_path_template(
            task=task, n_runs=n_run, **template_kwargs
        )
        return len(list(bids_path.glob(f"derivatives/{pattern}")))

    # 1 confound per raw file: confounds only carry raw entities (like acq),
    # never derivatives-only entities (like res).
    # Any other entity is ignored here, instead of leaving the pattern
    # undefined (the previous if/elif raised UnboundLocalError).
    confounds_entity = (
        {"acq": extra_entity["acq"]}
        if extra_entity is not None and "acq" in extra_entity
        else None
    )
    n_confounds = _n_derivatives(
        suffix="timeseries.tsv", extra_entity=confounds_entity
    )
    assert n_confounds == n_expected

    for space in ["MNI", "T1w"]:
        n_files = _n_derivatives(
            suffix="bold.nii.gz",
            space=space,
            desc="preproc",
            extra_entity=extra_entity,
        )
        assert n_files == n_expected

    # only T1w have desc-fmriprep_bold
    n_files = _n_derivatives(
        suffix="bold.nii.gz",
        space="T1w",
        desc="fmriprep",
        extra_entity=extra_entity,
    )
    assert n_files == n_expected

    n_files = _n_derivatives(
        suffix="bold.nii.gz",
        space="MNI",
        desc="fmriprep",
        extra_entity=extra_entity,
    )
    assert n_files == 0
@pytest.mark.parametrize("n_sub", [1, 2])
@pytest.mark.parametrize("n_ses", [1, 2])
@pytest.mark.parametrize(
    "tasks,n_runs",
    [(["main"], [1]), (["main"], [2]), (["main", "localizer"], [2, 1])],
)
def test_fake_bids_derivatives_with_session_and_runs(
    tmp_path, n_sub, n_ses, tasks, n_runs
):
    """Check number of each file 'type' created in derivatives."""
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path, n_sub=n_sub, n_ses=n_ses, tasks=tasks, n_runs=n_runs
    )

    for task, n_run in zip(tasks, n_runs):
        _check_n_files_derivatives_for_task(
            bids_path=bids_path,
            n_sub=n_sub,
            n_ses=n_ses,
            task=task,
            n_run=n_run,
        )

    # per subject: (2 confound + 3 bold + 2 gifti) per run per session
    n_files_per_run = 2 + 3 + 2
    n_expected = n_sub * n_files_per_run * sum(n_runs) * n_ses
    derivatives_files = list(bids_path.glob("derivatives/sub-*/ses-*/*/*"))
    assert len(derivatives_files) == n_expected
def test_bids_dataset_no_run_entity(tmp_path):
    """Check that n_runs = 0 produces files without the run entity."""
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path,
        n_sub=1,
        n_ses=1,
        tasks=["main"],
        n_runs=[0],
        with_derivatives=True,
    )

    assert not list(bids_path.glob("**/*run-*"))

    # nifti: 1 anat + 1 raw bold + 3 derivatives bold
    assert len(list(bids_path.glob("**/*.nii.gz"))) == 5

    # a single events, confounds or bold sidecar file
    for suffix in ["events.tsv", "timeseries.tsv", "bold.json"]:
        assert len(list(bids_path.glob(f"**/*{suffix}"))) == 1
def test_bids_dataset_no_session(tmp_path):
    """Check that n_ses = 0 prevents the creation of a session folder."""
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path,
        n_sub=1,
        n_ses=0,
        tasks=["main"],
        n_runs=[1],
        with_derivatives=True,
    )

    assert not list(bids_path.glob("**/*ses-*"))

    # nifti: 1 anat + 1 raw bold + 3 derivatives bold
    assert len(list(bids_path.glob("**/*.nii.gz"))) == 5

    # a single events, confounds or bold sidecar file
    for suffix in ["events.tsv", "timeseries.tsv", "bold.json"]:
        assert len(list(bids_path.glob(f"**/*{suffix}"))) == 1
def test_create_fake_bids_dataset_no_derivatives(tmp_path):
    """Check that nothing is created in derivatives when disabled."""
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path,
        n_sub=1,
        n_ses=1,
        tasks=["main"],
        n_runs=[2],
        with_derivatives=False,
    )

    derivatives_content = list(bids_path.glob("derivatives/**"))
    assert not derivatives_content
@pytest.mark.parametrize(
    "confounds_tag,with_confounds", [(None, True), ("_timeseries", False)]
)
def test_create_fake_bids_dataset_no_confounds(
    tmp_path, confounds_tag, with_confounds
):
    """Check that derivatives are created, but without confounds files."""
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path,
        n_sub=1,
        n_ses=1,
        tasks=["main"],
        n_runs=[2],
        with_confounds=with_confounds,
        confounds_tag=confounds_tag,
    )

    # the derivatives folder is not empty...
    assert list(bids_path.glob("derivatives/*"))
    # ...but it holds no confounds file
    confounds = list(bids_path.glob("derivatives/*/*/func/*timeseries.tsv"))
    assert not confounds
def test_fake_bids_errors(tmp_path):
    """Check errors for invalid labels and mismatched tasks / runs."""
    common = {"base_dir": tmp_path, "n_sub": 1, "n_ses": 1}

    # task labels must be alphanumeric...
    with pytest.raises(ValueError, match="labels.*alphanumeric"):
        create_fake_bids_dataset(**common, tasks=["foo_bar"], n_runs=[1])

    # ...and so must extra entity labels
    with pytest.raises(ValueError, match="labels.*alphanumeric"):
        create_fake_bids_dataset(
            **common,
            tasks=["main"],
            n_runs=[1],
            entities={"acq": "foo_bar"},
        )

    # one number of runs is expected per task
    with pytest.raises(ValueError, match="number.*tasks.*runs.*same"):
        create_fake_bids_dataset(**common, tasks=["main"], n_runs=[1, 2])
def test_fake_bids_extra_raw_entity(tmp_path):
    """Check files with an extra raw entity (acq) are created properly."""
    n_sub, n_ses = 2, 2
    tasks = ["main"]
    n_runs = [2]
    entities = {"acq": ["foo", "bar"]}
    acq_labels = entities["acq"]
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path,
        n_sub=n_sub,
        n_ses=n_ses,
        tasks=tasks,
        n_runs=n_runs,
        entities=entities,
    )

    # raw
    for task, task_runs in zip(tasks, n_runs):
        for suffix in ["bold.nii.gz", "bold.json", "events.tsv"]:
            for label in acq_labels:
                pattern = _bids_path_template(
                    task=task,
                    suffix=suffix,
                    n_runs=task_runs,
                    extra_entity={"acq": label},
                )
                n_found = len(list(bids_path.glob(pattern)))
                assert n_found == n_sub * n_ses * task_runs

    # per subject:
    # 1 anat + (1 event + 1 json + 1 bold) per entity per run per session
    n_raw_expected = n_sub * (
        1 + 3 * sum(n_runs) * n_ses * len(acq_labels)
    )
    assert len(list(bids_path.glob("sub-*/ses-*/*/*"))) == n_raw_expected

    # derivatives
    for label in acq_labels:
        for task, n_run in zip(tasks, n_runs):
            _check_n_files_derivatives_for_task(
                bids_path=bids_path,
                n_sub=n_sub,
                n_ses=n_ses,
                task=task,
                n_run=n_run,
                extra_entity={"acq": label},
            )

    # per subject: (2 confound + 3 bold + 2 gifti)
    # per run per session per entity
    n_derivatives_expected = (
        n_sub * (7 * sum(n_runs) * n_ses) * len(acq_labels)
    )
    derivatives_files = list(bids_path.glob("derivatives/sub-*/ses-*/*/*"))
    assert len(derivatives_files) == n_derivatives_expected
def test_fake_bids_extra_derivative_entity(tmp_path):
    """Check files with an extra derivatives entity (res) are created.

    The ``res`` entity must not appear in raw data, only in derivatives.
    """
    n_sub = 2
    n_ses = 2
    tasks = ["main"]
    n_runs = [2]
    entities = {"res": ["foo", "bar"]}
    bids_path = create_fake_bids_dataset(
        base_dir=tmp_path,
        n_sub=n_sub,
        n_ses=n_ses,
        tasks=tasks,
        n_runs=n_runs,
        entities=entities,
    )

    # raw: no file carries the res entity
    raw_files = list(bids_path.glob("sub-*/ses-*/*/*res*"))
    assert not raw_files

    # derivatives
    for label in entities["res"]:
        for task, n_run in zip(tasks, n_runs):
            _check_n_files_derivatives_for_task(
                bids_path=bids_path,
                n_sub=n_sub,
                n_ses=n_ses,
                task=task,
                n_run=n_run,
                extra_entity={"res": label},
            )

    all_files = list(bids_path.glob("derivatives/sub-*/ses-*/*/*"))
    # per subject:
    # 2 confound per run per session (not duplicated per res label)
    # + (3 bold + 2 gifti) per run per session per entity
    n_derivatives_files_expected = n_sub * (
        2 * sum(n_runs) * n_ses
        + 5 * sum(n_runs) * n_ses * len(entities["res"])
    )
    assert len(all_files) == n_derivatives_files_expected
def test_fake_bids_extra_entity_not_bids_entity(tmp_path):
    """Check that an extra entity that is not a BIDS entity raises."""
    with pytest.raises(ValueError, match="Invalid entity"):
        create_fake_bids_dataset(
            base_dir=tmp_path,
            entities={"egg": ["spam"]},
        )
@pytest.mark.parametrize("window", ["boxcar", "hamming"])
def test_generate_regions_ts_no_overlap(window):
    """Check that, without overlap, each voxel is in exactly one region."""
    n_voxels, n_regions = 50, 10

    regions = generate_regions_ts(
        n_voxels, n_regions, overlap=0, window=window
    )
    assert regions.shape == (n_regions, n_voxels)

    # number of regions covering each voxel
    regions_per_voxel = (regions > 0).sum(axis=0)
    # check no overlap: at most one region per voxel
    np.testing.assert_array_less(regions_per_voxel - 0.1, np.ones(n_voxels))
    # check: a region everywhere
    np.testing.assert_array_less(np.zeros(n_voxels), regions_per_voxel)
@pytest.mark.parametrize("window", ["boxcar", "hamming"])
def test_generate_regions_ts_with_overlap(window):
    """Check that regions overlap when overlap > 0 and cover all voxels."""
    n_voxels = 50
    n_regions = 10

    regions = generate_regions_ts(
        n_voxels, n_regions, overlap=1, window=window
    )

    assert regions.shape == (n_regions, n_voxels)
    # number of regions covering each voxel
    n_regions_per_voxel = (regions > 0).sum(axis=0)
    # check overlap: some voxel belongs to several regions.
    # Summing over axis=-1 would count voxels per region instead,
    # which is > 1 even without any overlap.
    assert np.any(n_regions_per_voxel > 1)
    # check: a region everywhere
    np.testing.assert_array_less(
        np.zeros(regions.shape[1]), n_regions_per_voxel
    )
def test_generate_labeled_regions():
    """Minimal testing of generate_labeled_regions."""
    shape = (3, 4, 5)
    n_regions = 10

    labels_img = generate_labeled_regions(shape, n_regions)

    assert labels_img.shape == shape
    # one value per region plus one extra value (presumably background)
    assert np.unique(get_data(labels_img)).size == n_regions + 1
def test_generate_maps():
    """Basic testing of generate_maps."""
    shape = (10, 11, 12)
    n_regions = 9

    maps_img, _ = generate_maps(shape, n_regions, border=1)
    maps = get_data(maps_img)

    assert maps.shape == (*shape, n_regions)
    # no empty map
    assert np.all(np.abs(maps).sum(axis=(0, 1, 2)) > 0)
    # check border: the first slice along each spatial axis is empty
    for axis in range(3):
        assert np.all(np.take(maps, 0, axis=axis) == 0)
@pytest.mark.parametrize("shape", [(10, 11, 12), (6, 6, 7)])
@pytest.mark.parametrize("length", [16, 20])
@pytest.mark.parametrize("kind", ["noise", "step"])
@pytest.mark.parametrize(
    "n_block,block_size,block_type",
    [
        (None, None, None),
        (1, 1, "classification"),
        (4, 3, "classification"),
        (4, 4, "regression"),
    ],
)
def test_generate_fake_fmri(
    shape, length, kind, n_block, block_size, block_type, rng
):
    """Check the shape of the fake fMRI image and of the block target."""
    outputs = generate_fake_fmri(
        shape=shape,
        length=length,
        kind=kind,
        n_blocks=n_block,
        block_size=block_size,
        block_type=block_type,
        random_state=rng,
    )

    fmri_img = outputs[0]
    assert fmri_img.shape[:-1] == shape
    assert fmri_img.shape[-1] == length

    if n_block is None:
        return
    # with blocks, the third output has one value per time point
    assert outputs[2].size == length
def test_generate_fake_fmri_error(rng):
    """Check that too many blocks for the requested length raise."""
    too_many_blocks = {"n_blocks": 10, "block_size": 3}
    with pytest.raises(ValueError, match="10 is too small"):
        generate_fake_fmri(length=10, random_state=rng, **too_many_blocks)
@pytest.mark.parametrize(
    "shapes", [[(2, 3, 5, 7)], [(5, 5, 5, 3), (5, 5, 5, 5)]]
)
@pytest.mark.parametrize("rank", [1, 3, 5])
@pytest.mark.parametrize("affine", [None, np.diag([0.5, 0.3, 1, 1])])
def test_fake_fmri_data_and_design_generate(shapes, rank, affine):
    """Check shapes and affines of generated images and design matrices."""
    mask, fmri_data, design_matrices = generate_fake_fmri_data_and_design(
        shapes, rk=rank, affine=affine, random_state=42
    )

    check_affine = affine is not None
    for fmri, expected_shape in zip(fmri_data, shapes):
        assert mask.shape == expected_shape[:3]
        assert fmri.shape == expected_shape
        if check_affine:
            assert_almost_equal(fmri.affine, affine)

    # one row per scan, one column per regressor
    for design, expected_shape in zip(design_matrices, shapes):
        n_scans = expected_shape[3]
        assert design.shape == (n_scans, rank)
@pytest.mark.parametrize(
    "shapes", [[(2, 3, 5, 7)], [(5, 5, 5, 3), (5, 5, 5, 5)]]
)
@pytest.mark.parametrize("rank", [1, 3, 5])
@pytest.mark.parametrize("affine", [None, np.diag([0.5, 0.3, 1, 1])])
def test_fake_fmri_data_and_design_write(tmp_path, shapes, rank, affine):
    """Check that written files match the in-memory generated data."""
    generator_kwargs = {"rk": rank, "affine": affine, "random_state": 42}
    mask, fmri_data, design_matrices = generate_fake_fmri_data_and_design(
        shapes, **generator_kwargs
    )
    mask_file, fmri_files, design_files = write_fake_fmri_data_and_design(
        shapes, file_path=tmp_path, **generator_kwargs
    )

    def _assert_same_img(written_file, expected_img):
        """Compare data and affine of a written image to an expected one."""
        loaded = load(written_file)
        assert_almost_equal(loaded.get_fdata(), expected_img.get_fdata())
        assert_almost_equal(loaded.affine, expected_img.affine)

    _assert_same_img(mask_file, mask)
    for fmri_file, fmri in zip(fmri_files, fmri_data):
        _assert_same_img(fmri_file, fmri)

    for design_file, design in zip(design_files, design_matrices):
        assert_frame_equal(
            pd.read_csv(design_file, sep="\t"), design, check_exact=False
        )
@pytest.mark.parametrize("shape", [(3, 4, 5), (2, 3, 5, 7)])
@pytest.mark.parametrize("affine", [None, np.diag([0.5, 0.3, 1, 1])])
def test_generate_random_img(shape, affine, rng):
    """Check shape and affine of the random image and of its mask."""
    img, mask = generate_random_img(
        shape=shape, affine=affine, random_state=rng
    )

    assert img.shape == shape
    assert mask.shape == shape[:3]
    if affine is None:
        return
    for generated in (img, mask):
        assert_almost_equal(generated.affine, affine)
@pytest.mark.parametrize("n_subjects", [5, 9])
@pytest.mark.parametrize("n_features", [30, 9])
@pytest.mark.parametrize("n_samples_range", [(30, 50), (9, 9)])
@pytest.mark.parametrize("density", [0.1, 1])
def test_generate_group_sparse_gaussian_graphs(
    n_subjects, n_features, n_samples_range, density, rng
):
    """Check sizes of the generated signals, precisions and topology."""
    min_n_samples, max_n_samples = n_samples_range
    signals, precisions, topology = generate_group_sparse_gaussian_graphs(
        n_subjects=n_subjects,
        n_features=n_features,
        min_n_samples=min_n_samples,
        max_n_samples=max_n_samples,
        density=density,
        random_state=rng,
    )

    assert len(signals) == n_subjects
    assert len(precisions) == n_subjects

    signal_shapes = np.array([signal.shape for signal in signals])
    n_samples, n_signal_features = signal_shapes[:, 0], signal_shapes[:, 1]
    assert np.all((n_samples >= min_n_samples) & (n_samples <= max_n_samples))
    assert np.all(n_signal_features == n_features)

    precision_shapes = np.array([precision.shape for precision in precisions])
    assert np.all(precision_shapes == (n_features, n_features))
    assert topology.shape == (n_features, n_features)

    # precision matrices must not have negative eigenvalues
    eigenvalues = np.array(
        [np.linalg.eigvalsh(precision) for precision in precisions]
    )
    assert np.all(eigenvalues >= 0)
@pytest.mark.parametrize("n_timepoints", [1, 9])
@pytest.mark.parametrize("n_features", [1, 9])
def test_generate_timeseries(n_timepoints, n_features, rng):
    """Check the shape of the generated time series."""
    expected_shape = (n_timepoints, n_features)
    timeseries = generate_timeseries(n_timepoints, n_features, rng)
    assert timeseries.shape == expected_shape
@pytest.mark.parametrize("n_scans", [1, 5])
@pytest.mark.parametrize("res", [1, 30])
@pytest.mark.parametrize("mask_dilation", [1, 2])
def test_generate_mni_space_img(n_scans, res, mask_dilation, rng):
    """Check shapes and affines of the generated image and mask."""
    inverse_img, mask_img = generate_mni_space_img(
        n_scans=n_scans, res=res, mask_dilation=mask_dilation, random_state=rng
    )

    # expected size of each spatial dimension once resampled at `res`
    expected_shape = tuple((dim - 2) // res + 2 for dim in (197, 233, 189))

    assert inverse_img.shape[:3] == expected_shape
    assert inverse_img.shape[3] == n_scans
    assert mask_img.shape == expected_shape
    assert_almost_equal(inverse_img.affine, mask_img.affine)