# nilearn/datasets/tests/test_utils.py
1"""Test the datasets module."""
3import gzip
4import os
5import re
6import shutil
7import tarfile
8import urllib
9from pathlib import Path
10from unittest.mock import MagicMock
11from zipfile import ZipFile
13import numpy as np
14import pytest
15import requests
17from nilearn.datasets import _utils, utils
18from nilearn.datasets.tests._testing import Response
20datadir = _utils.PACKAGE_DIRECTORY / "data"

DATASET_NAMES = {
    "ABIDE_pcp",
    "adhd",
    "bids_langloc",
    "brainomics_localizer",
    "development_fmri",
    "dosenbach_2010",
    "fiac",
    "fsaverage3",
    "fsaverage4",
    "fsaverage5",
    "fsaverage6",
    "fsaverage",
    "haxby2001",
    "icbm152_2009",
    "language_localizer_demo",
    "localizer_first_level",
    "Megatrawls",
    "mixed_gambles",
    "miyawaki2008",
    "neurovault",
    "nki_enhanced_surface",
    "oasis1",
    "power_2011",
    "spm_auditory",
    "spm_multimodal",
}


@pytest.mark.parametrize("name", DATASET_NAMES)
def test_get_dataset_descr(name):
    """Test function ``get_dataset_descr()``.

    Not needed for atlas datasets as this is checked in
    nilearn/datasets/tests/test_atlas.py.
    """
    descr = _utils.get_dataset_descr(name)

    assert isinstance(descr, str)
    assert len(descr) > 0


def test_get_dataset_descr_warning():
    """Test that ``get_dataset_descr()`` gives a warning \
    when no description is available.
    """
    with pytest.warns(
        UserWarning, match="Could not find dataset description."
    ):
        descr = _utils.get_dataset_descr("")

    assert descr == ""


def test_get_dataset_dir(tmp_path):
    """Test folder creation under different environments.

    Enforce a clean install by removing the environment variables.
    """
    os.environ.pop("NILEARN_DATA", None)
    os.environ.pop("NILEARN_SHARED_DATA", None)

    expected_base_dir = Path("~/nilearn_data").expanduser()
    data_dir = _utils.get_dataset_dir("test", verbose=0)

    assert data_dir == expected_base_dir / "test"
    assert data_dir.exists()

    shutil.rmtree(data_dir)

    expected_base_dir = tmp_path / "test_nilearn_data"
    os.environ["NILEARN_DATA"] = str(expected_base_dir)
    data_dir = _utils.get_dataset_dir("test", verbose=0)

    assert data_dir == expected_base_dir / "test"
    assert data_dir.exists()

    shutil.rmtree(data_dir)

    expected_base_dir = tmp_path / "nilearn_shared_data"
    os.environ["NILEARN_SHARED_DATA"] = str(expected_base_dir)
    data_dir = _utils.get_dataset_dir("test", verbose=0)

    assert data_dir == expected_base_dir / "test"
    assert data_dir.exists()

    shutil.rmtree(data_dir)
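
    # The three blocks above exercise the fallback chain as observed here:
    # with no configuration, ~/nilearn_data is used; NILEARN_DATA overrides
    # that default; NILEARN_SHARED_DATA in turn takes precedence once set.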

    # Verify exception for a path which exists and is a file
    test_file = tmp_path / "some_file"
    test_file.write_text("abcfeg")

    with pytest.raises(
        OSError,
        match="Nilearn tried to store the dataset in the following "
        "directories, but",
    ):
        _utils.get_dataset_dir("test", test_file, verbose=0)


def test_add_readme_to_default_data_locations(tmp_path):
    """Make sure get_dataset_dir creates a README."""
    assert not (tmp_path / "README.md").exists()

    _utils.get_dataset_dir(dataset_name="test", verbose=0, data_dir=tmp_path)

    assert (tmp_path / "README.md").exists()


def test_get_dataset_dir_path_as_str(tmp_path):
    """Make sure get_dataset_dir can handle strings."""
    expected_base_dir = tmp_path / "env_data"
    expected_dataset_dir = expected_base_dir / "test"
    data_dir = _utils.get_dataset_dir(
        "test", default_paths=[expected_dataset_dir], verbose=0
    )

    assert data_dir == expected_dataset_dir
    assert data_dir.exists()

    shutil.rmtree(data_dir)


def test_get_dataset_dir_write_access(tmp_path):
    """Check get_dataset_dir can deal with folders with special permissions."""
    os.environ.pop("NILEARN_SHARED_DATA", None)

    no_write = tmp_path / "no_write"
    no_write.mkdir(parents=True)
    no_write.chmod(0o400)

    expected_base_dir = tmp_path / "nilearn_shared_data"
    os.environ["NILEARN_SHARED_DATA"] = str(expected_base_dir)
    data_dir = _utils.get_dataset_dir(
        "test", default_paths=[no_write], verbose=0
    )

    # Non writable dir is returned because dataset may be in there.
    assert data_dir == no_write
    assert data_dir.exists()

    no_write.chmod(0o600)
    shutil.rmtree(data_dir)


def test_get_dataset_dir_symlink(tmp_path):
    """Make sure get_dataset_dir can handle symlinks."""
    expected_linked_dir = tmp_path / "linked"
    expected_linked_dir.mkdir(parents=True)
    expected_base_dir = tmp_path / "env_data"
    expected_base_dir.mkdir()
    symlink_dir = expected_base_dir / "test"
    symlink_dir.symlink_to(expected_linked_dir)

    assert symlink_dir.exists()

    data_dir = _utils.get_dataset_dir(
        "test", default_paths=[symlink_dir], verbose=0
    )

    assert data_dir == expected_linked_dir
    assert data_dir.exists()


def test_md5_sum_file(tmp_path):
    """Test nilearn.datasets._utils._md5_sum_file."""
    # Create dummy temporary file
    f = tmp_path / "test"
    f.write_bytes(b"abcfeg")

    assert _utils._md5_sum_file(f) == "18f32295c556b2a1a3a8e68fe1ad40f7"
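
    # The expected value is simply the MD5 hex digest of the file content,
    # i.e. hashlib.md5(b"abcfeg").hexdigest().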


def test_read_md5_sum_file(tmp_path):
    """Test nilearn.datasets._utils.read_md5_sum_file."""
    # Create dummy temporary file
    f = tmp_path / "test"
    f.write_bytes(
        b"20861c8c3fe177da19a7e9539a5dbac /tmp/test\n"
        b"70886dcabe7bf5c5a1c24ca24e4cbd94 test/some_image.nii",
    )

    h = _utils.read_md5_sum_file(f)

    assert "/tmp/test" in h
    assert "/etc/test" not in h
    assert h["test/some_image.nii"] == "70886dcabe7bf5c5a1c24ca24e4cbd94"
    assert h["/tmp/test"] == "20861c8c3fe177da19a7e9539a5dbac"


def test_tree(tmp_path):
    """Test nilearn.datasets._utils.tree."""
    dir1 = tmp_path / "dir1"
    dir11 = dir1 / "dir11"
    dir12 = dir1 / "dir12"
    dir2 = tmp_path / "dir2"

    dir1.mkdir()
    dir11.mkdir()
    dir12.mkdir()
    dir2.mkdir()

    (tmp_path / "file1").touch()
    (tmp_path / "file2").touch()
    (dir1 / "file11").touch()
    (dir1 / "file12").touch()
    (dir11 / "file111").touch()
    (dir2 / "file21").touch()
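
    # Layout the assertions below expect, in list form: directories come
    # first as (path, children) tuples, then plain file paths:
    # [(dir1, [(dir11, [file111]), (dir12, []), file11, file12]),
    #  (dir2, [file21]),
    #  file1, file2]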

    # test for list return value
    tree_ = _utils.tree(tmp_path)

    # Check the tree
    assert type(tree_[0]) is tuple
    assert type(tree_[0][1]) is list
    assert type(tree_[0][1][0]) is tuple
    assert type(tree_[1]) is tuple
    assert type(tree_[1][1]) is list
    assert tree_[0][1][0][1][0] == str(dir11 / "file111")
    assert len(tree_[0][1][1][1]) == 0
    assert tree_[0][1][2] == str(dir1 / "file11")
    assert tree_[0][1][3] == str(dir1 / "file12")
    assert tree_[1][1][0] == str(dir2 / "file21")
    assert tree_[2] == str(tmp_path / "file1")
    assert tree_[3] == str(tmp_path / "file2")

    # test for dictionary return value
    tree_ = _utils.tree(tmp_path, dictionary=True)

    # Check the tree
    assert type(tree_[dir1.name]) is dict
    assert type(tree_[dir1.name][dir11.name]) is list
    assert len(tree_[dir1.name][dir12.name]) == 0
    assert type(tree_[dir2.name]) is list
    assert type(tree_["."]) is list
    assert tree_[dir1.name][dir11.name][0] == str(dir11 / "file111")
    assert tree_[dir1.name]["."][0] == str(dir1 / "file11")
    assert tree_[dir1.name]["."][1] == str(dir1 / "file12")
    assert tree_[dir2.name][0] == str(dir2 / "file21")
    assert tree_["."] == [str(tmp_path / "file1"), str(tmp_path / "file2")]


def test_movetree(tmp_path):
    """Test nilearn.datasets._utils.movetree."""
    dir1 = tmp_path / "dir1"
    dir111 = dir1 / "dir11"
    dir112 = dir1 / "dir12"
    dir2 = tmp_path / "dir2"
    dir212 = dir2 / "dir12"

    dir1.mkdir()
    dir111.mkdir()
    dir112.mkdir()
    dir2.mkdir()
    dir212.mkdir()

    (dir1 / "file11").touch()
    (dir1 / "file12").touch()
    (dir111 / "file1111").touch()
    (dir112 / "file1121").touch()
    (dir2 / "file21").touch()
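
    # dir2/dir12 already exists, so movetree must merge dir1/dir12 into it
    # instead of failing.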

    _utils.movetree(dir1, dir2)

    assert not dir111.exists()
    assert not dir112.exists()
    assert not (dir1 / "file11").exists()
    assert not (dir1 / "file12").exists()
    assert not (dir111 / "file1111").exists()
    assert not (dir112 / "file1121").exists()

    dir211 = dir2 / "dir11"
    dir212 = dir2 / "dir12"

    assert dir211.exists()
    assert dir212.exists()
    assert (dir2 / "file21").exists()
    assert (dir2 / "file11").exists()
    assert (dir2 / "file12").exists()
    assert (dir211 / "file1111").exists()
    assert (dir212 / "file1121").exists()


def test_filter_columns():
    """Test filter_columns."""
    # Create fake recarray
    value1 = np.arange(500)
    strings = np.asarray(["a", "b", "c"])
    value2 = strings[value1 % 3]

    values = np.asarray(
        list(zip(value1, value2)), dtype=[("INT", int), ("STR", "S1")]
    )

    f = _utils.filter_columns(values, {"INT": (23, 46)})

    assert np.sum(f) == 24

    f = _utils.filter_columns(values, {"INT": [0, 9, (12, 24)]})

    assert np.sum(f) == 15

    value1 = value1 % 2
    values = np.asarray(
        list(zip(value1, value2)), dtype=[("INT", int), ("STR", b"S1")]
    )

    # No filter
    f = _utils.filter_columns(values, [])

    assert np.sum(f) == 500

    f = _utils.filter_columns(values, {"STR": b"b"})

    assert np.sum(f) == 167

    f = _utils.filter_columns(values, {"STR": "b"})

    assert np.sum(f) == 167

    f = _utils.filter_columns(values, {"INT": 1, "STR": b"b"})

    assert np.sum(f) == 84

    f = _utils.filter_columns(
        values, {"INT": 1, "STR": b"b"}, combination="or"
    )

    assert np.sum(f) == 333
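
    # Where the expected counts come from:
    # - (23, 46) is an inclusive range, hence 46 - 23 + 1 == 24 rows;
    # - [0, 9, (12, 24)] matches {0, 9} plus 12..24, i.e. 2 + 13 == 15 rows;
    # - STR == b"b" holds for indices i with i % 3 == 1: 167 of 500 rows;
    # - after value1 %= 2, INT == 1 and STR == b"b" requires i % 6 == 1,
    #   which happens 84 times; with combination="or" the union has
    #   250 + 167 - 84 == 333 rows.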


@pytest.mark.parametrize(
    "ext, mode", [("tar", "w"), ("tar.gz", "w:gz"), ("tgz", "w:gz")]
)
def test_uncompress_tar(tmp_path, ext, mode):
    """Test nilearn.datasets._utils.uncompress_file for tar files.

    For each kind of compression, we create:
    - a compressed archive (ztemp)
    - a temporary file (ftemp) to compress into ztemp
    We then uncompress ztemp in place and check that ftemp exists.
    """
    ztemp = tmp_path / f"test.{ext}"
    ftemp = "test"
    with tarfile.open(ztemp, mode) as testtar:
        temp = tmp_path / ftemp
        temp.write_text(ftemp)
        testtar.add(temp)

    _utils.uncompress_file(ztemp, verbose=0)

    assert (tmp_path / ftemp).exists()


def test_uncompress_zip(tmp_path):
    """Test nilearn.datasets._utils.uncompress_file for zip files.

    We create:
    - a zip archive (ztemp)
    - a member (ftemp) written into ztemp
    We then uncompress ztemp in place and check that ftemp exists.
    """
    ztemp = tmp_path / "test.zip"
    ftemp = "test"
    with ZipFile(ztemp, "w") as testzip:
        testzip.writestr(ftemp, " ")

    _utils.uncompress_file(ztemp, verbose=0)

    assert (tmp_path / ftemp).exists()


@pytest.mark.parametrize("ext", [".gz", ""])
def test_uncompress_gzip(tmp_path, ext):
    """Test nilearn.datasets._utils.uncompress_file for gzip files.

    With and without the .gz extension, we create:
    - a gzip-compressed file (ztemp)
    - its content (ftemp) written through gzip into ztemp
    We then uncompress ztemp in place and check that ftemp exists.
    """
    ztemp = tmp_path / f"test{ext}"
    ftemp = "test"

    with gzip.open(ztemp, "wt") as testgzip:
        testgzip.write(ftemp)

    _utils.uncompress_file(ztemp, verbose=0)

    assert (tmp_path / ftemp).exists()


def test_safe_extract(tmp_path):
    """Test vulnerability patch by mimicking path traversal."""
    ztemp = tmp_path / "test.tar"
    in_archive_file = tmp_path / "something.txt"
    in_archive_file.write_text("hello")
    with tarfile.open(ztemp, "w") as tar:
        arcname = "../test.tar"
        tar.add(in_archive_file, arcname=arcname)
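
    # The member name "../test.tar" resolves outside the extraction
    # directory; a naive extractall() would write one level above tmp_path
    # (the classic tarfile traversal issue, CVE-2007-4559), so
    # uncompress_file must refuse it.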

    with pytest.raises(
        Exception, match="Attempted Path Traversal in Tar File"
    ):
        _utils.uncompress_file(ztemp, verbose=0)


def test_fetch_single_file_part(tmp_path, capsys, request_mocker):
    """Check that fetch_single_file can fetch part of a file."""

    def get_response(match, request):
        """Create a mock Response with a correct Content-Range header."""
        req_range = request.headers.get("Range")
        resp = Response(b"dummy content", match)

        # set up Response object to return partial content
        # and update header accordingly
        if req_range is not None:
            resp.iter_start = int(re.match(r"bytes=(\d+)-", req_range)[1])
            resp.headers["Content-Range"] = (
                f"bytes {resp.iter_start}-{len(resp.content) - 1}"
                f"/{len(resp.content)}"
            )

        return resp
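
    # Example of the exchange this mock emulates: resuming after 1 byte
    # sends "Range: bytes=1-" and gets back
    # "Content-Range: bytes 1-12/13" (len(b"dummy content") == 13).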

    url = "http://foo/temp.txt"
    file_full = tmp_path / "temp.txt"
    file_part = tmp_path / "temp.txt.part"
    file_part.write_text("D")  # should not be overwritten

    request_mocker.url_mapping[url] = get_response

    _utils.fetch_single_file(
        url=url, data_dir=tmp_path, verbose=1, resume=True
    )

    assert file_full.exists()
    assert file_full.read_text() == "Dummy content"  # not overwritten

    file_full.unlink()
    assert not file_full.exists()
    assert not file_part.exists()

    # test for overwrite
    file_part.write_text("D")  # should be overwritten

    _utils.fetch_single_file(
        url=url, data_dir=tmp_path, resume=True, overwrite=True
    )

    assert file_full.exists()
    assert file_full.read_text() == "dummy content"  # overwritten


def test_fetch_single_file_part_error(tmp_path, capsys, request_mocker):
    """Check that fetch_single_file falls back to a full download \
    when resuming fails.
    """
    url = "http://foo/temp.txt"
    file_part = tmp_path / "temp.txt.part"
    file_part.touch()  # should not be overwritten

    # the default Response from the mocker does not handle Range requests
    request_mocker.url_mapping[url] = "dummy content"

    _utils.fetch_single_file(
        url=url, data_dir=tmp_path, verbose=1, resume=True
    )

    assert (
        "Resuming failed, try to download the whole file."
        in capsys.readouterr().out
    )


def test_fetch_single_file_overwrite(tmp_path, request_mocker):
    """Check that fetch_single_file can overwrite files."""
    # overwrite non-existing file.
    fil = _utils.fetch_single_file(
        url="http://foo/", data_dir=tmp_path, verbose=0, overwrite=True
    )

    assert request_mocker.url_count == 1
    assert fil.exists()
    assert fil.read_text() == ""

    # Modify content
    fil.write_text("some content")

    # Don't overwrite existing file.
    fil = _utils.fetch_single_file(
        url="http://foo/", data_dir=tmp_path, verbose=0, overwrite=False
    )

    assert request_mocker.url_count == 1
    assert fil.exists()
    assert fil.read_text() == "some content"

    # Overwrite existing file.
    fil = _utils.fetch_single_file(
        url="http://foo/", data_dir=tmp_path, verbose=0, overwrite=True
    )

    assert request_mocker.url_count == 2
    assert fil.exists()
    assert fil.read_text() == ""


@pytest.mark.parametrize("should_cast_path_to_string", [False, True])
def test_fetch_files_use_session(
    should_cast_path_to_string,
    tmp_path,
    request_mocker,  # noqa: ARG001
):
    """Use session parameter of fetch_files."""
    if should_cast_path_to_string:
        tmp_path = str(tmp_path)

    # regression test for https://github.com/nilearn/nilearn/issues/2863
    session = MagicMock()
    _utils.fetch_files(
        files=[
            ("example1", "https://example.org/example1", {"overwrite": True}),
            ("example2", "https://example.org/example2", {"overwrite": True}),
        ],
        data_dir=tmp_path,
        session=session,
    )

    assert session.send.call_count == 2


@pytest.mark.parametrize("should_cast_path_to_string", [False, True])
def test_fetch_files_overwrite(
    should_cast_path_to_string, tmp_path, request_mocker
):
    """Check that fetch_files can overwrite files."""
    if should_cast_path_to_string:
        tmp_path = str(tmp_path)

    # overwrite non-existing file.
    files = ("1.txt", "http://foo/1.txt")
    fil = Path(
        _utils.fetch_files(
            data_dir=tmp_path,
            verbose=0,
            files=[(*files, {"overwrite": True})],
        )[0]
    )

    assert request_mocker.url_count == 1
    assert fil.exists()
    assert not fil.read_text()

    # Modify content
    fil.write_text("some content")

    # Don't overwrite existing file.
    fil = Path(
        _utils.fetch_files(
            data_dir=tmp_path,
            verbose=0,
            files=[(*files, {"overwrite": False})],
        )[0]
    )

    assert request_mocker.url_count == 1
    assert fil.exists()
    assert fil.read_text() == "some content"

    # Overwrite existing file.
    fil = Path(
        _utils.fetch_files(
            data_dir=tmp_path,
            verbose=0,
            files=[(*files, {"overwrite": True})],
        )[0]
    )

    assert request_mocker.url_count == 2
    assert fil.exists()
    assert not fil.read_text()


def test_naive_ftp_adapter():
    """Test _NaiveFTPAdapter error handling."""
    sender = _utils._NaiveFTPAdapter()
    resp = sender.send(requests.Request("GET", "ftp://example.com").prepare())
    resp.close()
    resp.raw.close.assert_called_with()
    urllib.request.OpenerDirector.open.side_effect = urllib.error.URLError(
        "timeout"
    )
    with pytest.raises(requests.RequestException, match="timeout"):
        resp = sender.send(
            requests.Request("GET", "ftp://example.com").prepare()
        )


def test_load_sample_motor_activation_image():
    """Test deprecation of utils.load_sample_motor_activation_image.

    Remove when version >= 0.13.
    """
    with pytest.warns(
        DeprecationWarning,
        match="Please import this function from 'nilearn.datasets.func'",
    ):
        utils.load_sample_motor_activation_image()