1"""Test the datasets module.""" 

2 

3import gzip 

4import os 

5import re 

6import shutil 

7import tarfile 

8import urllib 

9from pathlib import Path 

10from unittest.mock import MagicMock 

11from zipfile import ZipFile 

12 

13import numpy as np 

14import pytest 

15import requests 

16 

17from nilearn.datasets import _utils, utils 

18from nilearn.datasets.tests._testing import Response 

19 

20datadir = _utils.PACKAGE_DIRECTORY / "data" 

21 

22DATASET_NAMES = { 

23 "ABIDE_pcp", 

24 "adhd", 

25 "bids_langloc", 

26 "brainomics_localizer", 

27 "development_fmri", 

28 "dosenbach_2010", 

29 "fiac", 

30 "fsaverage3", 

31 "fsaverage4", 

32 "fsaverage5", 

33 "fsaverage6", 

34 "fsaverage", 

35 "haxby2001", 

36 "icbm152_2009", 

37 "language_localizer_demo", 

38 "localizer_first_level", 

39 "Megatrawls", 

40 "mixed_gambles", 

41 "miyawaki2008", 

42 "neurovault", 

43 "nki_enhanced_surface", 

44 "oasis1", 

45 "power_2011", 

46 "spm_auditory", 

47 "spm_multimodal", 

48} 

49 

50 

51@pytest.mark.parametrize("name", DATASET_NAMES) 

52def test_get_dataset_descr(name): 

53 """Test function ``get_dataset_descr()``. 

54 

55 Not needed for atlas datasets as this is checked in 

56 nilearn/datasets/tests/test_atlas.py 

57 """ 

58 descr = _utils.get_dataset_descr(name) 

59 

60 assert isinstance(descr, str) 

61 assert len(descr) > 0 

62 

63 

64def test_get_dataset_descr_warning(): 

65 """Tests that function ``get_dataset_descr()`` gives a warning \ 

66 when no description is available. 

67 """ 

68 with pytest.warns( 

69 UserWarning, match="Could not find dataset description." 

70 ): 

71 descr = _utils.get_dataset_descr("") 

72 

73 assert descr == "" 

74 

75 

76def test_get_dataset_dir(tmp_path): 

77 """Test folder creation under different environments. 

78 

79 Enforcing a custom clean install. 

80 """ 

81 os.environ.pop("NILEARN_DATA", None) 

82 os.environ.pop("NILEARN_SHARED_DATA", None) 

83 

84 expected_base_dir = Path("~/nilearn_data").expanduser() 

85 data_dir = _utils.get_dataset_dir("test", verbose=0) 

86 

87 assert data_dir == expected_base_dir / "test" 

88 assert data_dir.exists() 

89 

90 shutil.rmtree(data_dir) 

91 

92 expected_base_dir = tmp_path / "test_nilearn_data" 

93 os.environ["NILEARN_DATA"] = str(expected_base_dir) 

94 data_dir = _utils.get_dataset_dir("test", verbose=0) 

95 

96 assert data_dir == expected_base_dir / "test" 

97 assert data_dir.exists() 

98 

99 shutil.rmtree(data_dir) 

100 

101 expected_base_dir = tmp_path / "nilearn_shared_data" 

102 os.environ["NILEARN_SHARED_DATA"] = str(expected_base_dir) 

103 data_dir = _utils.get_dataset_dir("test", verbose=0) 

104 

105 assert data_dir == expected_base_dir / "test" 

106 assert data_dir.exists() 

107 

108 shutil.rmtree(data_dir) 

109 

110 # Verify exception for a path which exists and is a file 

111 test_file = tmp_path / "some_file" 

112 test_file.write_text("abcfeg") 

113 

114 with pytest.raises( 

115 OSError, 

116 match="Nilearn tried to store the dataset in the following " 

117 "directories, but", 

118 ): 

119 _utils.get_dataset_dir("test", test_file, verbose=0) 
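

# The assertions above pin down where the default data directory comes
# from. A minimal sketch of that lookup order, inferred from this test
# rather than copied from the implementation (``_toy_default_data_dir``
# is a hypothetical helper, not a nilearn API):
def _toy_default_data_dir():
    """Resolve the data dir: NILEARN_SHARED_DATA, then NILEARN_DATA,
    then ~/nilearn_data (illustrative only)."""
    for var in ("NILEARN_SHARED_DATA", "NILEARN_DATA"):
        value = os.environ.get(var)
        if value:
            return Path(value)
    return Path("~/nilearn_data").expanduser()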



def test_add_readme_to_default_data_locations(tmp_path):
    """Make sure get_dataset_dir creates a README."""
    assert not (tmp_path / "README.md").exists()

    _utils.get_dataset_dir(dataset_name="test", verbose=0, data_dir=tmp_path)

    assert (tmp_path / "README.md").exists()


def test_get_dataset_dir_path_as_str(tmp_path):
    """Make sure get_dataset_dir can handle string paths."""
    expected_base_dir = tmp_path / "env_data"
    expected_dataset_dir = expected_base_dir / "test"
    data_dir = _utils.get_dataset_dir(
        "test", default_paths=[expected_dataset_dir], verbose=0
    )

    assert data_dir == expected_dataset_dir
    assert data_dir.exists()

    shutil.rmtree(data_dir)


def test_get_dataset_dir_write_access(tmp_path):
    """Check get_dataset_dir can deal with folders with special permissions."""
    os.environ.pop("NILEARN_SHARED_DATA", None)

    no_write = tmp_path / "no_write"
    no_write.mkdir(parents=True)
    no_write.chmod(0o400)

    expected_base_dir = tmp_path / "nilearn_shared_data"
    os.environ["NILEARN_SHARED_DATA"] = str(expected_base_dir)
    data_dir = _utils.get_dataset_dir(
        "test", default_paths=[no_write], verbose=0
    )

    # The non-writable dir is returned because the dataset may be in there.
    assert data_dir == no_write
    assert data_dir.exists()

    no_write.chmod(0o600)
    shutil.rmtree(data_dir)


def test_get_dataset_dir_symlink(tmp_path):
    """Make sure get_dataset_dir can handle symlinks."""
    expected_linked_dir = tmp_path / "linked"
    expected_linked_dir.mkdir(parents=True)
    expected_base_dir = tmp_path / "env_data"
    expected_base_dir.mkdir()
    symlink_dir = expected_base_dir / "test"
    symlink_dir.symlink_to(expected_linked_dir)

    assert symlink_dir.exists()

    data_dir = _utils.get_dataset_dir(
        "test", default_paths=[symlink_dir], verbose=0
    )

    assert data_dir == expected_linked_dir
    assert data_dir.exists()


def test_md5_sum_file(tmp_path):
    """Tests nilearn.datasets._utils._md5_sum_file."""
    # Create a dummy temporary file
    f = tmp_path / "test"
    f.write_bytes(b"abcfeg")

    assert _utils._md5_sum_file(f) == "18f32295c556b2a1a3a8e68fe1ad40f7"
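

# The digest asserted above can be reproduced with the standard library
# alone. A minimal sketch, assuming only that _md5_sum_file hashes the
# raw file bytes (``_toy_md5_of_file`` is a hypothetical helper, not a
# nilearn API):
def _toy_md5_of_file(path):
    """Return the hex MD5 digest of a file's bytes (illustrative only)."""
    import hashlib

    return hashlib.md5(path.read_bytes()).hexdigest()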



def test_read_md5_sum_file(tmp_path):
    """Tests nilearn.datasets._utils.read_md5_sum_file."""
    # Create a dummy temporary file
    f = tmp_path / "test"
    f.write_bytes(
        b"20861c8c3fe177da19a7e9539a5dbac /tmp/test\n"
        b"70886dcabe7bf5c5a1c24ca24e4cbd94 test/some_image.nii",
    )

    h = _utils.read_md5_sum_file(f)

    assert "/tmp/test" in h
    assert "/etc/test" not in h
    assert h["test/some_image.nii"] == "70886dcabe7bf5c5a1c24ca24e4cbd94"
    assert h["/tmp/test"] == "20861c8c3fe177da19a7e9539a5dbac"
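

# A hedged sketch of parsing the ``md5sum``-style format exercised
# above, one "<hex digest> <path>" pair per line (``_toy_parse_md5_sums``
# is a hypothetical helper, not nilearn's implementation):
def _toy_parse_md5_sums(text):
    """Map each listed path to its digest (illustrative only)."""
    hashes = {}
    for line in text.splitlines():
        digest, _, path = line.partition(" ")
        hashes[path.strip()] = digest
    return hashes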



def test_tree(tmp_path):
    """Tests nilearn.datasets._utils.tree."""
    dir1 = tmp_path / "dir1"
    dir11 = dir1 / "dir11"
    dir12 = dir1 / "dir12"
    dir2 = tmp_path / "dir2"

    dir1.mkdir()
    dir11.mkdir()
    dir12.mkdir()
    dir2.mkdir()

    (tmp_path / "file1").touch()
    (tmp_path / "file2").touch()
    (dir1 / "file11").touch()
    (dir1 / "file12").touch()
    (dir11 / "file111").touch()
    (dir2 / "file21").touch()

    # Test the list return value
    tree_ = _utils.tree(tmp_path)

    # Check the tree
    assert type(tree_[0]) is tuple
    assert type(tree_[0][1]) is list
    assert type(tree_[0][1][0]) is tuple
    assert type(tree_[1]) is tuple
    assert type(tree_[1][1]) is list
    assert tree_[0][1][0][1][0] == str(dir11 / "file111")
    assert len(tree_[0][1][1][1]) == 0
    assert tree_[0][1][2] == str(dir1 / "file11")
    assert tree_[0][1][3] == str(dir1 / "file12")
    assert tree_[1][1][0] == str(dir2 / "file21")
    assert tree_[2] == str(tmp_path / "file1")
    assert tree_[3] == str(tmp_path / "file2")

    # Test the dictionary return value
    tree_ = _utils.tree(tmp_path, dictionary=True)

    # Check the tree
    assert type(tree_[dir1.name]) is dict
    assert type(tree_[dir1.name][dir11.name]) is list
    assert len(tree_[dir1.name][dir12.name]) == 0
    assert type(tree_[dir2.name]) is list
    assert type(tree_["."]) is list
    assert tree_[dir1.name][dir11.name][0] == str(dir11 / "file111")
    assert tree_[dir1.name]["."][0] == str(dir1 / "file11")
    assert tree_[dir1.name]["."][1] == str(dir1 / "file12")
    assert tree_[dir2.name][0] == str(dir2 / "file21")
    assert tree_["."] == [str(tmp_path / "file1"), str(tmp_path / "file2")]
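

# For orientation, the nested list structure asserted above can be
# rebuilt with the standard library: directories become (path, children)
# tuples listed first, plain files become path strings. A toy stand-in
# (``_toy_tree`` is a hypothetical helper, not nilearn's implementation):
def _toy_tree(path):
    """Recursively list a directory as tuples and strings
    (illustrative only)."""
    dirs, files = [], []
    for entry in sorted(path.iterdir()):
        if entry.is_dir():
            dirs.append((str(entry), _toy_tree(entry)))
        else:
            files.append(str(entry))
    return dirs + files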



def test_movetree(tmp_path):
    """Tests nilearn.datasets._utils.movetree."""
    dir1 = tmp_path / "dir1"
    dir111 = dir1 / "dir11"
    dir112 = dir1 / "dir12"
    dir2 = tmp_path / "dir2"
    dir212 = dir2 / "dir12"

    dir1.mkdir()
    dir111.mkdir()
    dir112.mkdir()
    dir2.mkdir()
    dir212.mkdir()

    (dir1 / "file11").touch()
    (dir1 / "file12").touch()
    (dir111 / "file1111").touch()
    (dir112 / "file1121").touch()
    (dir2 / "file21").touch()

    _utils.movetree(dir1, dir2)

    assert not dir111.exists()
    assert not dir112.exists()
    assert not (dir1 / "file11").exists()
    assert not (dir1 / "file12").exists()
    assert not (dir111 / "file1111").exists()
    assert not (dir112 / "file1121").exists()

    dir211 = dir2 / "dir11"
    dir212 = dir2 / "dir12"

    assert dir211.exists()
    assert dir212.exists()
    assert (dir2 / "file21").exists()
    assert (dir2 / "file11").exists()
    assert (dir2 / "file12").exists()
    assert (dir211 / "file1111").exists()
    assert (dir212 / "file1121").exists()


def test_filter_columns():
    """Test filter_columns."""
    # Create a fake structured array
    value1 = np.arange(500)
    strings = np.asarray(["a", "b", "c"])
    value2 = strings[value1 % 3]

    values = np.asarray(
        list(zip(value1, value2)), dtype=[("INT", int), ("STR", "S1")]
    )

    f = _utils.filter_columns(values, {"INT": (23, 46)})

    assert np.sum(f) == 24

    f = _utils.filter_columns(values, {"INT": [0, 9, (12, 24)]})

    assert np.sum(f) == 15

    value1 = value1 % 2
    values = np.asarray(
        list(zip(value1, value2)), dtype=[("INT", int), ("STR", b"S1")]
    )

    # No filter
    f = _utils.filter_columns(values, [])

    assert np.sum(f) == 500

    f = _utils.filter_columns(values, {"STR": b"b"})

    assert np.sum(f) == 167

    f = _utils.filter_columns(values, {"STR": "b"})

    assert np.sum(f) == 167

    f = _utils.filter_columns(values, {"INT": 1, "STR": b"b"})

    assert np.sum(f) == 84

    f = _utils.filter_columns(
        values, {"INT": 1, "STR": b"b"}, combination="or"
    )

    assert np.sum(f) == 333
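

# For reference, a (min, max) criterion such as the first filter above
# has a direct numpy equivalent. A minimal sketch, assuming inclusive
# bounds, which is consistent with np.sum(f) == 24 for the range
# (23, 46) (``_toy_int_range_mask`` is a hypothetical helper, not
# nilearn's implementation):
def _toy_int_range_mask(values, low, high):
    """Boolean mask selecting rows with low <= INT <= high
    (illustrative only)."""
    return (values["INT"] >= low) & (values["INT"] <= high)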



@pytest.mark.parametrize(
    "ext, mode", [("tar", "w"), ("tar.gz", "w:gz"), ("tgz", "w:gz")]
)
def test_uncompress_tar(tmp_path, ext, mode):
    """Tests nilearn.datasets._utils.uncompress_file for tar files.

    For each kind of compression, we create:
    - a compressed object (ztemp)
    - a temporary file to compress into ztemp
    We then uncompress ztemp in place and check that the
    extracted file ftemp exists.
    """
    ztemp = tmp_path / f"test.{ext}"
    ftemp = "test"
    with tarfile.open(ztemp, mode) as testtar:
        temp = tmp_path / ftemp
        temp.write_text(ftemp)
        testtar.add(temp)

    _utils.uncompress_file(ztemp, verbose=0)

    assert (tmp_path / ftemp).exists()


def test_uncompress_zip(tmp_path):
    """Tests nilearn.datasets._utils.uncompress_file for zip files.

    We create:
    - a compressed object (ztemp)
    - a temporary file to compress into ztemp
    We then uncompress ztemp in place and check that the
    extracted file ftemp exists.
    """
    ztemp = tmp_path / "test.zip"
    ftemp = "test"
    with ZipFile(ztemp, "w") as testzip:
        testzip.writestr(ftemp, " ")

    _utils.uncompress_file(ztemp, verbose=0)

    assert (tmp_path / ftemp).exists()


@pytest.mark.parametrize("ext", [".gz", ""])
def test_uncompress_gzip(tmp_path, ext):
    """Tests nilearn.datasets._utils.uncompress_file for gzip files.

    We create:
    - a compressed object (ztemp)
    - a temporary file to compress into ztemp
    We then uncompress ztemp in place and check that the
    extracted file ftemp exists.
    """
    ztemp = tmp_path / f"test{ext}"
    ftemp = "test"

    with gzip.open(ztemp, "wt") as testgzip:
        testgzip.write(ftemp)

    _utils.uncompress_file(ztemp, verbose=0)

    assert (tmp_path / ftemp).exists()


def test_safe_extract(tmp_path):
    """Test vulnerability patch by mimicking path traversal."""
    ztemp = tmp_path / "test.tar"
    in_archive_file = tmp_path / "something.txt"
    in_archive_file.write_text("hello")
    with tarfile.open(ztemp, "w") as tar:
        arcname = "../test.tar"
        tar.add(in_archive_file, arcname=arcname)

    with pytest.raises(
        Exception, match="Attempted Path Traversal in Tar File"
    ):
        _utils.uncompress_file(ztemp, verbose=0)
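

# The traversal guard exercised above follows a well-known pattern:
# resolve each member's destination and reject anything escaping the
# extraction root. A minimal sketch (``_toy_is_within_directory`` is a
# hypothetical helper, not nilearn's code); recent Python versions also
# provide built-in tarfile extraction filters for the same purpose.
def _toy_is_within_directory(base, target):
    """Return True if target resolves inside base (illustrative only)."""
    try:
        Path(target).resolve().relative_to(Path(base).resolve())
    except ValueError:
        return False
    return True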



def test_fetch_single_file_part(tmp_path, capsys, request_mocker):
    """Check that fetch_single_file can resume a partially downloaded file."""

    def get_response(match, request):
        """Create a mock Response with the correct Content-Range header."""
        req_range = request.headers.get("Range")
        resp = Response(b"dummy content", match)

        # Set up the Response object to return partial content
        # and update the header accordingly.
        if req_range is not None:
            resp.iter_start = int(re.match(r"bytes=(\d+)-", req_range)[1])
            resp.headers["Content-Range"] = (
                f"bytes {resp.iter_start}-{len(resp.content) - 1}"
                f"/{len(resp.content)}"
            )

        return resp

    url = "http://foo/temp.txt"
    file_full = tmp_path / "temp.txt"
    file_part = tmp_path / "temp.txt.part"
    file_part.write_text("D")  # should not be overwritten

    request_mocker.url_mapping[url] = get_response

    _utils.fetch_single_file(
        url=url, data_dir=tmp_path, verbose=1, resume=True
    )

    assert file_full.exists()
    assert file_full.read_text() == "Dummy content"  # not overwritten
    assert "Resuming failed" not in capsys.readouterr().out

    file_full.unlink()
    assert not file_full.exists()
    assert not file_part.exists()

    # Test the overwrite behavior
    file_part.write_text("D")  # should be overwritten

    _utils.fetch_single_file(
        url=url, data_dir=tmp_path, resume=True, overwrite=True
    )

    assert file_full.exists()
    assert file_full.read_text() == "dummy content"  # overwritten
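

# The mocked exchange above mirrors standard HTTP resumption: the client
# sends a ``Range: bytes=<offset>-`` header and the server replies with
# a matching ``Content-Range`` header. A hedged sketch of both header
# shapes (``_toy_range_headers`` is a hypothetical helper, not
# nilearn's code):
def _toy_range_headers(offset, total):
    """Request/response headers for resuming a download at ``offset``
    (illustrative only)."""
    request = {"Range": f"bytes={offset}-"}
    response = {"Content-Range": f"bytes {offset}-{total - 1}/{total}"}
    return request, response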



def test_fetch_single_file_part_error(tmp_path, capsys, request_mocker):
    """Check the error path of fetch_single_file when resuming fails."""
    url = "http://foo/temp.txt"
    file_part = tmp_path / "temp.txt.part"
    file_part.touch()  # should not be overwritten

    # The default Response from the mocker does not handle Range requests.
    request_mocker.url_mapping[url] = "dummy content"

    _utils.fetch_single_file(
        url=url, data_dir=tmp_path, verbose=1, resume=True
    )

    assert (
        "Resuming failed, try to download the whole file."
        in capsys.readouterr().out
    )


def test_fetch_single_file_overwrite(tmp_path, request_mocker):
    """Check that fetch_single_file can overwrite files."""
    # Overwrite a non-existing file.
    fil = _utils.fetch_single_file(
        url="http://foo/", data_dir=tmp_path, verbose=0, overwrite=True
    )

    assert request_mocker.url_count == 1
    assert fil.exists()
    assert fil.read_text() == ""

    # Modify the content.
    fil.write_text("some content")

    # Don't overwrite the existing file.
    fil = _utils.fetch_single_file(
        url="http://foo/", data_dir=tmp_path, verbose=0, overwrite=False
    )

    assert request_mocker.url_count == 1
    assert fil.exists()
    assert fil.read_text() == "some content"

    # Overwrite the existing file.
    fil = _utils.fetch_single_file(
        url="http://foo/", data_dir=tmp_path, verbose=0, overwrite=True
    )

    assert request_mocker.url_count == 2
    assert fil.exists()
    assert fil.read_text() == ""


@pytest.mark.parametrize("should_cast_path_to_string", [False, True])
def test_fetch_files_use_session(
    should_cast_path_to_string,
    tmp_path,
    request_mocker,  # noqa: ARG001
):
    """Use the session parameter of fetch_files."""
    if should_cast_path_to_string:
        tmp_path = str(tmp_path)

    # Regression test for https://github.com/nilearn/nilearn/issues/2863
    session = MagicMock()
    _utils.fetch_files(
        files=[
            ("example1", "https://example.org/example1", {"overwrite": True}),
            ("example2", "https://example.org/example2", {"overwrite": True}),
        ],
        data_dir=tmp_path,
        session=session,
    )

    assert session.send.call_count == 2


@pytest.mark.parametrize("should_cast_path_to_string", [False, True])
def test_fetch_files_overwrite(
    should_cast_path_to_string, tmp_path, request_mocker
):
    """Check that fetch_files can overwrite files."""
    if should_cast_path_to_string:
        tmp_path = str(tmp_path)

    # Overwrite a non-existing file.
    files = ("1.txt", "http://foo/1.txt")
    fil = Path(
        _utils.fetch_files(
            data_dir=tmp_path,
            verbose=0,
            files=[(*files, {"overwrite": True})],
        )[0]
    )

    assert request_mocker.url_count == 1
    assert fil.exists()
    assert not fil.read_text()

    # Modify the content.
    fil.write_text("some content")

    # Don't overwrite the existing file.
    fil = Path(
        _utils.fetch_files(
            data_dir=tmp_path,
            verbose=0,
            files=[(*files, {"overwrite": False})],
        )[0]
    )

    assert request_mocker.url_count == 1
    assert fil.exists()
    assert fil.read_text() == "some content"

    # Overwrite the existing file.
    fil = Path(
        _utils.fetch_files(
            data_dir=tmp_path,
            verbose=0,
            files=[(*files, {"overwrite": True})],
        )[0]
    )

    assert request_mocker.url_count == 2
    assert fil.exists()
    assert not fil.read_text()


def test_naive_ftp_adapter():
    """Test the error handling of _NaiveFTPAdapter."""
    sender = _utils._NaiveFTPAdapter()
    resp = sender.send(requests.Request("GET", "ftp://example.com").prepare())
    resp.close()
    resp.raw.close.assert_called_with()
    urllib.request.OpenerDirector.open.side_effect = urllib.error.URLError(
        "timeout"
    )
    with pytest.raises(requests.RequestException, match="timeout"):
        resp = sender.send(
            requests.Request("GET", "ftp://example.com").prepare()
        )
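

# For context, a transport adapter like the one tested above is normally
# attached to a session via requests.Session.mount (a real requests
# API). A hedged usage sketch (``_toy_ftp_session`` is a hypothetical
# helper, not part of nilearn):
def _toy_ftp_session():
    """Session routing ftp:// URLs through the naive adapter
    (illustrative only)."""
    session = requests.Session()
    session.mount("ftp://", _utils._NaiveFTPAdapter())
    return session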



def test_load_sample_motor_activation_image():
    """Test deprecation of utils.load_sample_motor_activation_image.

    Remove when version >= 0.13.
    """
    with pytest.warns(
        DeprecationWarning,
        match="Please import this function from 'nilearn.datasets.func'",
    ):
        utils.load_sample_motor_activation_image()