Coverage for nilearn/datasets/_utils.py: 9%
377 statements
coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1"""Private utility functions to the datasets module."""
3import collections.abc
4import contextlib
5import fnmatch
6import hashlib
7import os
8import pickle
9import shutil
10import tarfile
11import time
12import urllib
13import warnings
14import zipfile
15from pathlib import Path
17import numpy as np
18import requests
20from nilearn._utils import fill_doc, logger
21from nilearn._utils.logger import find_stack_level
23from .utils import get_data_dirs
25_REQUESTS_TIMEOUT = (15.1, 61)
26PACKAGE_DIRECTORY = Path(__file__).absolute().parent
29ALLOWED_DATA_TYPES = (
30 "curvature",
31 "sulcal",
32 "thickness",
33)
35ALLOWED_MESH_TYPES = {
36 "pial",
37 "white_matter",
38 "inflated",
39 "sphere",
40 "flat",
41}


def md5_hash(string):
    """Calculate the MD5 hash of a string."""
    m = hashlib.md5()
    m.update(string.encode("utf-8"))
    return m.hexdigest()


def _format_time(t):
    return f"{t / 60.0:4.1f}min" if t > 60 else f" {t:5.1f}s"


def _md5_sum_file(path):
    """Calculate the MD5 sum of a file."""
    with Path(path).open("rb") as f:
        m = hashlib.md5()
        while True:
            data = f.read(8192)
            if data:
                m.update(data)
            else:
                break
    return m.hexdigest()


def read_md5_sum_file(path):
    """Read an MD5 checksum file and return the hashes as a dictionary."""
    with Path(path).open() as f:
        hashes = {}
        while True:
            line = f.readline()
            if not line:
                break
            h, name = line.rstrip().split(" ", 1)
            hashes[name] = h
    return hashes
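

# Illustrative sketch (not part of the nilearn API): how the dictionary
# returned by read_md5_sum_file() can be combined with _md5_sum_file() to
# verify files on disk. The checksum file is expected to hold one
# "<md5 digest> <file name>" entry per line; the paths used by callers of
# this helper are hypothetical.
def _example_verify_checksums(checksum_file, data_dir):
    hashes = read_md5_sum_file(checksum_file)
    for name, expected in hashes.items():
        actual = _md5_sum_file(Path(data_dir) / name)
        if actual != expected:
            raise ValueError(f"Checksum mismatch for {name}")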


def _chunk_report_(bytes_so_far, total_size, initial_size, t0):
    """Show downloading percentage.

    Parameters
    ----------
    bytes_so_far : int
        Number of downloaded bytes.

    total_size : int
        Total size of the file (may be 0/None, depending on download method).

    initial_size : int
        If resuming, indicate the initial size of the file.
        If not resuming, set to zero.

    t0 : int
        The time in seconds (as returned by time.time()) at which the
        download was resumed / started.

    """
    if not total_size:
        logger.log(f"\rDownloaded {int(bytes_so_far)} of ? bytes.")

    else:
        # Estimate remaining download time
        total_percent = float(bytes_so_far) / total_size

        current_download_size = bytes_so_far - initial_size
        bytes_remaining = total_size - bytes_so_far
        dt = time.time() - t0
        download_rate = current_download_size / max(1e-8, float(dt))
        # Minimum rate of 0.01 bytes/s, to avoid dividing by zero.
        time_remaining = bytes_remaining / max(0.01, download_rate)

        # Trailing whitespace is to erase extra char when message length varies
        logger.log(
            f"\rDownloaded {bytes_so_far} of {total_size} bytes "
            f"({total_percent * 100:.1f}%, "
            f"{_format_time(time_remaining)} remaining)",
        )


@fill_doc
def _chunk_read_(
    response,
    local_file,
    chunk_size=8192,
    report_hook=None,
    initial_size=0,
    total_size=None,
    verbose=1,
):
    """Download a file chunk by chunk and show progress.

    Parameters
    ----------
    response : requests.Response
        Response to the download request, used to get the file size.

    local_file : file
        Hard disk file where data should be written.

    chunk_size : int, default=8192
        Size of downloaded chunks.

    report_hook : bool, optional
        Whether or not to report download progress. Default: None

    initial_size : int, default=0
        If resuming, indicate the initial size of the file.

    total_size : int, optional
        Expected final size of download (None means it is unknown).
    %(verbose)s

    Returns
    -------
    None
        The downloaded data is written to ``local_file``.

    """
    try:
        if total_size is None:
            total_size = response.headers.get("Content-Length").strip()
        total_size = int(total_size) + initial_size
    except Exception as e:
        logger.log(
            "Warning: total size could not be determined.",
            verbose=verbose,
            msg_level=2,
        )
        logger.log(
            f"Full stack trace: {e}",
            verbose=verbose,
            msg_level=3,
        )
        total_size = None
    bytes_so_far = initial_size

    t0 = time_last_display = time.time()
    for chunk in response.iter_content(chunk_size):
        bytes_so_far += len(chunk)
        time_last_read = time.time()
        if (
            report_hook
            and
            # Refresh report every second or when download is
            # finished.
            (time_last_read > time_last_display + 1.0 or not chunk)
        ):
            _chunk_report_(bytes_so_far, total_size, initial_size, t0)
            time_last_display = time_last_read
        if chunk:
            local_file.write(chunk)
        else:
            break
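

# Illustrative sketch (not part of the nilearn API): streaming a download
# through _chunk_read_() with a plain requests call. The URL and local path
# are hypothetical.
def _example_chunk_read():
    url = "https://example.org/data/big_file.zip"
    with requests.get(
        url, stream=True, timeout=_REQUESTS_TIMEOUT
    ) as resp, Path("/tmp/big_file.zip").open("wb") as fh:
        resp.raise_for_status()
        _chunk_read_(resp, fh, report_hook=True, verbose=1)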


@fill_doc
def get_dataset_dir(
    dataset_name, data_dir=None, default_paths=None, verbose=1
) -> Path:
    """Create if necessary and return data directory of given dataset.

    Parameters
    ----------
    dataset_name : string
        The unique name of the dataset.
    %(data_dir)s
    default_paths : list of string, optional
        Default system paths in which the dataset may already have been
        installed by third party software. They will be checked first.
    %(verbose)s

    Returns
    -------
    data_dir : pathlib.Path
        Path of the given dataset directory.

    Notes
    -----
    This function retrieves the datasets directory (or data directory) using
    the following priority:

    1. default system paths
    2. the keyword argument data_dir
    3. the global environment variable NILEARN_SHARED_DATA
    4. the user environment variable NILEARN_DATA
    5. nilearn_data in the user home folder

    """
    paths = []
    # Search possible data-specific system paths
    if default_paths is not None:
        for default_path in default_paths:
            paths.extend(
                [(Path(d), True) for d in str(default_path).split(os.pathsep)]
            )

    paths.extend([(Path(d), False) for d in get_data_dirs(data_dir=data_dir)])

    logger.log(f"Dataset search paths: {paths}", verbose=verbose, msg_level=2)

    # Check if the dataset exists somewhere
    for path, is_pre_dir in paths:
        if not is_pre_dir:
            path = path / dataset_name
        if path.is_symlink():
            # Resolve path
            path = path.resolve()
        if path.exists() and path.is_dir():
            logger.log(
                f"Dataset found in {path}", verbose=verbose, msg_level=1
            )
            return path

    # If not, create a folder in the first writable directory
    errors = []
    for path, is_pre_dir in paths:
        if not is_pre_dir:
            path = path / dataset_name
        if not path.exists():
            try:
                path.mkdir(parents=True)
                _add_readme_to_default_data_locations(
                    data_dir=data_dir,
                    verbose=verbose,
                )

                logger.log(f"Dataset created in {path}", verbose)

                return path
            except Exception as exc:
                short_error_message = getattr(exc, "strerror", str(exc))
                errors.append(f"\n -{path} ({short_error_message})")

    raise OSError(
        "Nilearn tried to store the dataset in the following "
        f"directories, but: {''.join(errors)}"
    )
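

# Illustrative sketch (not part of the nilearn API): with no arguments the
# returned directory follows the priority order documented above, typically
# ending up under ~/nilearn_data unless NILEARN_SHARED_DATA or NILEARN_DATA
# is set. The dataset name and override path used here are hypothetical.
def _example_get_dataset_dir():
    data_dir = get_dataset_dir("my_dataset", verbose=0)
    # An explicit location can be requested instead of the defaults, e.g.
    # get_dataset_dir("my_dataset", data_dir="/tmp/data").
    return data_dir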


def _add_readme_to_default_data_locations(data_dir=None, verbose=1):
    for d in get_data_dirs(data_dir=data_dir):
        file = Path(d) / "README.md"
        if file.parent.exists() and not file.exists():
            with file.open("w") as f:
                f.write(
                    """# Nilearn data folder

This directory is used by Nilearn to store datasets
and atlases downloaded from the internet.
It can be safely deleted.
If you delete it, previously downloaded data will be downloaded again."""
                )
            logger.log(f"Added README.md to {d}", verbose=verbose)


# The functions _is_within_directory and _safe_extract were implemented in
# https://github.com/nilearn/nilearn/pull/3391 to address a directory
# traversal vulnerability https://github.com/advisories/GHSA-gw9q-c7gh-j9vm
def _is_within_directory(directory, target):
    abs_directory = Path(directory).resolve().absolute()
    abs_target = Path(target).resolve().absolute()

    prefix = os.path.commonprefix([abs_directory, abs_target])

    return prefix == str(abs_directory)


def _safe_extract(tar, path=".", members=None, *, numeric_owner=False):
    path = Path(path)
    for member in tar.getmembers():
        member_path = path / member.name
        if not _is_within_directory(path, member_path):
            raise Exception("Attempted Path Traversal in Tar File")

    tar.extractall(path, members, numeric_owner=numeric_owner)
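

# Illustrative sketch (not part of the nilearn API): the traversal check
# rejects archive members whose resolved path escapes the extraction
# directory, e.g. entries named "../escape". The paths below are
# hypothetical.
def _example_traversal_check():
    assert _is_within_directory("/tmp/extract", "/tmp/extract/sub/file.nii")
    assert not _is_within_directory("/tmp/extract", "/tmp/extract/../escape")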


@fill_doc
def uncompress_file(file_, delete_archive=True, verbose=1):
    """Uncompress files contained in a dataset.

    Parameters
    ----------
    file_ : string
        Path of file to be uncompressed.

    delete_archive : bool, default=True
        Whether or not to delete archive once it is uncompressed.
    %(verbose)s

    Notes
    -----
    This handles zip, tar, gzip and bzip files only.

    """
    logger.log(f"Extracting data from {file_}...", verbose=verbose)

    file_ = Path(file_)
    data_dir = file_.parent

    # We first try to see if it is a zip file
    try:
        filename = data_dir / file_.stem
        with file_.open("rb") as fd:
            header = fd.read(4)
        processed = False
        if zipfile.is_zipfile(file_):
            z = zipfile.ZipFile(file_)
            z.extractall(path=data_dir)
            z.close()
            if delete_archive:
                file_.unlink()
            processed = True
        elif file_.suffix == ".gz" or header.startswith(b"\x1f\x8b"):
            import gzip

            if file_.suffix == ".tgz":
                filename = filename.with_suffix(".tar")
            elif not file_.suffix:
                # We rely on the assumption that gzip files have an extension
                shutil.move(file_, f"{file_}.gz")
                file_ = file_.with_suffix(".gz")
            with gzip.open(file_) as gz, filename.open("wb") as out:
                shutil.copyfileobj(gz, out, 8192)
            # If file is .tar.gz, this will be handled in the next case
            if delete_archive:
                file_.unlink()
            file_ = filename
            processed = True
        if file_.is_file() and tarfile.is_tarfile(file_):
            with contextlib.closing(tarfile.open(file_, "r")) as tar:
                _safe_extract(tar, path=data_dir)
            if delete_archive:
                file_.unlink()
            processed = True
        if not processed:
            raise OSError(f"[Uncompress] unknown archive file format: {file_}")

        logger.log(".. done.\n", verbose=verbose)
    except Exception as e:
        logger.log(f"Error uncompressing file: {e}", verbose=verbose)
        raise
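

# Illustrative sketch (not part of the nilearn API): uncompress_file()
# extracts an archive next to itself and, by default, removes the archive
# afterwards. The archive path is hypothetical.
def _example_uncompress():
    # Extract /tmp/nilearn_data/demo/files.tar.gz into /tmp/nilearn_data/demo,
    # keeping the original archive.
    uncompress_file("/tmp/nilearn_data/demo/files.tar.gz", delete_archive=False)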


def _filter_column(array, col, criteria):
    """Return index array matching criteria.

    Parameters
    ----------
    array : numpy array with columns
        Array in which data will be filtered.

    col : string
        Name of the column.

    criteria : integer (or float), pair of integers, string or list of these
        If an integer, select elements in the column matching the integer.
        If a tuple, select elements between the limits given by the tuple.
        If a string, select elements that match the string.

    """
    # Raise an error if the column does not exist. This is the only way to
    # test it across all possible types (pandas, recarray...)
    try:
        array[col]
    except Exception:
        raise KeyError(f"Filtering criterion {col} does not exist")

    if (
        not isinstance(criteria, str)
        and not isinstance(criteria, bytes)
        and not isinstance(criteria, tuple)
        and isinstance(criteria, collections.abc.Iterable)
    ):
        filter = np.zeros(array.shape[0], dtype=bool)
        for criterion in criteria:
            filter = np.logical_or(
                filter, _filter_column(array, col, criterion)
            )
        return filter

    if isinstance(criteria, tuple):
        if len(criteria) != 2:
            raise ValueError("An interval must have 2 values")
        if criteria[0] is None:
            return array[col] <= criteria[1]
        if criteria[1] is None:
            return array[col] >= criteria[0]
        filter = array[col] <= criteria[1]
        return np.logical_and(filter, array[col] >= criteria[0])

    # Handle strings with different encodings
    if isinstance(criteria, (str, bytes)):
        criteria = np.array(criteria).astype(array[col].dtype)

    return array[col] == criteria
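

# Illustrative sketch (not part of the nilearn API): how the different
# criteria types behave on a small structured array. The field names are
# hypothetical.
def _example_filter_column():
    phenotypes = np.array(
        [("sub-01", 23), ("sub-02", 35), ("sub-03", 41)],
        dtype=[("participant_id", "U10"), ("age", int)],
    )
    young = _filter_column(phenotypes, "age", (None, 30))  # age <= 30
    in_range = _filter_column(phenotypes, "age", (30, 40))  # 30 <= age <= 40
    sub_2 = _filter_column(phenotypes, "participant_id", "sub-02")
    return young, in_range, sub_2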


def filter_columns(array, filters, combination="and"):
    """Return indices of recarray entries that match criteria.

    Parameters
    ----------
    array : numpy array with columns
        Array in which data will be filtered.

    filters : dict
        Dictionary mapping column names to criteria (see _filter_column).

    combination : string {'and', 'or'}, default='and'
        String describing the combination operator. Possible values are "and"
        and "or".

    """
    if combination == "and":
        fcomb = np.logical_and
        mask = np.ones(array.shape[0], dtype=bool)
    elif combination == "or":
        fcomb = np.logical_or
        mask = np.zeros(array.shape[0], dtype=bool)
    else:
        raise ValueError(f"Combination mode not known: {combination}")

    for column in filters:
        mask = fcomb(mask, _filter_column(array, column, filters[column]))
    return mask
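

# Illustrative sketch (not part of the nilearn API): combining several
# per-column criteria into a single boolean mask. The field names are
# hypothetical.
def _example_filter_columns():
    phenotypes = np.array(
        [("sub-01", 23, "F"), ("sub-02", 35, "M"), ("sub-03", 41, "F")],
        dtype=[("participant_id", "U10"), ("age", int), ("sex", "U1")],
    )
    # Keep rows where age is in [20, 40] AND sex == "F".
    mask = filter_columns(phenotypes, {"age": (20, 40), "sex": "F"})
    return phenotypes[mask]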


class _NaiveFTPAdapter(requests.adapters.BaseAdapter):
    def send(
        self,
        request,
        timeout=None,
        **kwargs,  # noqa: ARG002
    ):
        with contextlib.suppress(Exception):
            timeout, _ = timeout
        try:
            data = urllib.request.urlopen(request.url, timeout=timeout)
        except Exception as e:
            raise requests.RequestException(e.reason)
        data.release_conn = data.close
        resp = requests.Response()
        resp.url = data.geturl()
        resp.status_code = data.getcode() or 200
        resp.raw = data
        resp.headers = dict(data.info().items())
        return resp

    def close(self):
        pass


@fill_doc
def fetch_single_file(
    url,
    data_dir,
    resume=True,
    overwrite=False,
    md5sum=None,
    username=None,
    password=None,
    verbose=1,
    session=None,
):
    """Load requested file, downloading it if needed or requested.

    Parameters
    ----------
    %(url)s
    %(data_dir)s
    %(resume)s
    overwrite : bool, default=False
        If True and the file already exists, delete it.

    md5sum : string, optional
        MD5 sum of the file. Checked if download of the file is required.

    username : string, optional
        Username used for basic HTTP authentication.

    password : string, optional
        Password used for basic HTTP authentication.
    %(verbose)s
    session : requests.Session, optional
        Session to use to send requests.

    Returns
    -------
    files : pathlib.Path
        Absolute path of downloaded file.

    Notes
    -----
    If, for any reason, the download procedure fails, all downloaded files are
    removed.

    """
    if session is None:
        with requests.Session() as sess:
            sess.mount("ftp:", _NaiveFTPAdapter())
            return fetch_single_file(
                url,
                data_dir,
                resume=resume,
                overwrite=overwrite,
                md5sum=md5sum,
                username=username,
                password=password,
                verbose=verbose,
                session=sess,
            )

    # Determine data path
    data_dir.mkdir(parents=True, exist_ok=True)

    # Determine filename using URL
    parse = urllib.parse.urlparse(url)
    file_name = Path(parse.path).name
    if file_name == "":
        file_name = md5_hash(parse.path)

    temp_file_name = f"{file_name}.part"
    full_name = data_dir / file_name
    temp_full_name = data_dir / temp_file_name
    if full_name.exists():
        if overwrite:
            full_name.unlink()
        else:
            return full_name
    if temp_full_name.exists() and overwrite:
        temp_full_name.unlink()
    t0 = time.time()
    initial_size = 0

    try:
        # Download data
        headers = {}
        auth = None
        if username is not None and password is not None:
            if not url.startswith("https"):
                raise ValueError(
                    "Authentication was requested "
                    f"on a non-secured URL ({url}). "
                    "Request has been blocked for security reasons."
                )
            auth = (username, password)

        displayed_url = url.split("?")[0] if verbose == 1 else url
        logger.log(f"Downloading data from {displayed_url} ...", verbose)

        if resume and temp_full_name.exists():
            # Download has been interrupted, we try to resume it.
            local_file_size = temp_full_name.stat().st_size
            # If the file exists, then only download the remainder
            headers["Range"] = f"bytes={local_file_size}-"
            try:
                req = requests.Request(
                    method="GET", url=url, headers=headers, auth=auth
                )
                prepped = session.prepare_request(req)
                with session.send(
                    prepped, stream=True, timeout=_REQUESTS_TIMEOUT
                ) as resp:
                    resp.raise_for_status()
                    content_range = resp.headers.get("Content-Range")
                    if content_range is None or not content_range.startswith(
                        f"bytes {local_file_size}-"
                    ):
                        raise OSError("Server does not support resuming")
                    initial_size = local_file_size
                    with temp_full_name.open("ab") as fh:
                        _chunk_read_(
                            resp,
                            fh,
                            report_hook=(verbose > 0),
                            initial_size=initial_size,
                            verbose=verbose,
                        )
            except OSError:
                logger.log(
                    "Resuming failed, try to download the whole file.", verbose
                )
                return fetch_single_file(
                    url,
                    data_dir,
                    resume=False,
                    overwrite=overwrite,
                    md5sum=md5sum,
                    username=username,
                    password=password,
                    verbose=verbose,
                    session=session,
                )
        else:
            req = requests.Request(
                method="GET", url=url, headers=headers, auth=auth
            )
            prepped = session.prepare_request(req)
            with session.send(
                prepped, stream=True, timeout=_REQUESTS_TIMEOUT
            ) as resp:
                resp.raise_for_status()
                with temp_full_name.open("wb") as fh:
                    _chunk_read_(
                        resp,
                        fh,
                        report_hook=(verbose > 0),
                        initial_size=initial_size,
                        verbose=verbose,
                    )
        shutil.move(temp_full_name, full_name)
        dt = time.time() - t0

        # Complete the reporting hook
        logger.log(
            f" ...done. ({dt:.0f} seconds, {dt // 60:.0f} min)\n",
            verbose=verbose,
        )
    except requests.RequestException:
        logger.log(
            f"Error while fetching file {file_name}; dataset fetching aborted."
        )
        raise
    if md5sum is not None and _md5_sum_file(full_name) != md5sum:
        raise ValueError(
            f"File {full_name} checksum verification has failed."
            " Dataset fetching aborted."
        )
    return full_name
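

# Illustrative sketch (not part of the nilearn API): downloading a single
# file into a dataset directory, with resuming enabled. The URL and dataset
# name are hypothetical.
def _example_fetch_single_file():
    data_dir = get_dataset_dir("my_dataset", verbose=0)
    return fetch_single_file(
        "https://example.org/data/atlas.nii.gz",
        data_dir,
        resume=True,
        verbose=0,
    )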


def get_dataset_descr(ds_name):
    """Return the description of a dataset."""
    try:
        with (PACKAGE_DIRECTORY / "description" / f"{ds_name}.rst").open(
            "rb"
        ) as rst_file:
            descr = rst_file.read()
    except OSError:
        descr = ""

    if not descr:
        warnings.warn(
            "Could not find dataset description.",
            stacklevel=find_stack_level(),
        )

    if isinstance(descr, bytes):
        descr = descr.decode("utf-8")

    return descr


def movetree(src, dst):
    """Move entire tree under `src` inside `dst`.

    Creates `dst` if it does not already exist.

    Any existing file is overwritten.

    The difference with `shutil.move` is that `shutil.move` moves `src`
    under `dst` if `dst` already exists.
    """
    src = Path(src)

    # Create destination dir if it does not exist
    dst = Path(dst)
    dst.mkdir(parents=True, exist_ok=True)

    errors = []

    for srcfile in src.iterdir():
        dstfile = dst / srcfile.name
        try:
            if srcfile.is_dir() and dstfile.is_dir():
                movetree(srcfile, dstfile)
                srcfile.rmdir()
            else:
                shutil.move(srcfile, dstfile)
        except OSError as why:
            errors.append((srcfile, dstfile, str(why)))
        # catch the error from the recursive movetree so that we can
        # continue with other files
        except Exception as err:
            errors.extend(err.args[0])
    if errors:
        raise Exception(errors)
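

# Illustrative sketch (not part of the nilearn API): unlike shutil.move,
# movetree() merges the *contents* of ``src`` into ``dst`` even when ``dst``
# already exists, overwriting files with the same name. The paths are
# hypothetical.
def _example_movetree():
    # /tmp/staging/sub-01/... is merged into /tmp/dataset/sub-01/...
    movetree("/tmp/staging", "/tmp/dataset")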


@fill_doc
def fetch_files(data_dir, files, resume=True, verbose=1, session=None):
    """Load requested dataset, downloading it if needed or requested.

    This function retrieves files from the hard drive or downloads them from
    the given urls. Note to developers: All the files will be first
    downloaded in a sandbox and, if everything goes well, they will be moved
    into the folder of the dataset. This prevents corrupting previously
    downloaded data. In case of a big dataset, do not hesitate to make several
    calls if needed.

    Parameters
    ----------
    %(data_dir)s
    files : list of (string, string, dict)
        List of files and their corresponding url with a dictionary that
        contains options regarding the files, e.g. (file_path, url, opt).
        If, after fetching, a file_path is not found in data_dir (that is,
        at data_dir/file_path), the fetch is aborted and any downloaded
        files are deleted.
        Options supported are:

        * 'move' if renaming the file or moving it to a subfolder is needed
        * 'uncompress' to indicate that the file is an archive
        * 'md5sum' to check the md5 sum of the file
        * 'overwrite' if the file should be re-downloaded even if it exists
    %(resume)s
    %(verbose)s
    session : `requests.Session`, optional
        Session to use to send requests.

    Returns
    -------
    files : list of string
        Absolute paths of downloaded files on disk.

    """
    if session is None:
        with requests.Session() as sess:
            sess.mount("ftp:", _NaiveFTPAdapter())
            return fetch_files(
                data_dir,
                files,
                resume=resume,
                verbose=verbose,
                session=sess,
            )
    # There are two working directories here:
    # - data_dir is the destination directory of the dataset
    # - temp_dir is a temporary directory dedicated to this fetching call. All
    #   files that must be downloaded will be in this directory. If a corrupted
    #   file is found, or a file is missing, this working directory will be
    #   deleted.
    files = list(files)
    files_pickle = pickle.dumps([(file_, url) for file_, url, _ in files])
    files_md5 = hashlib.md5(files_pickle).hexdigest()
    data_dir = Path(data_dir)
    temp_dir = data_dir / files_md5

    # Create destination dir
    data_dir.mkdir(parents=True, exist_ok=True)

    # Abort flag, in case of error
    abort = None

    files_ = []
    for file_, url, opts in files:
        # 3 possibilities:
        # - the file exists in data_dir, nothing to do.
        # - the file does not exist: we download it in temp_dir
        # - the file exists in temp_dir: this can happen if an archive has
        #   been downloaded. There is nothing to do

        # Target file in the data_dir
        target_file = data_dir / file_
        # Target file in temp dir
        temp_target_file = temp_dir / file_
        # Whether to keep existing files
        overwrite = opts.get("overwrite", False)
        if abort is None and (
            overwrite
            or (not target_file.exists() and not temp_target_file.exists())
        ):
            # We may be in a global read-only repository. If so, we cannot
            # download files.
            if not os.access(data_dir, os.W_OK):
                raise ValueError(
                    "Dataset files are missing but dataset"
                    " repository is read-only. Contact your data"
                    " administrator to solve the problem"
                )

            temp_dir.mkdir(parents=True, exist_ok=True)
            md5sum = opts.get("md5sum", None)

            dl_file = fetch_single_file(
                url,
                temp_dir,
                resume=resume,
                verbose=verbose,
                md5sum=md5sum,
                username=opts.get("username", None),
                password=opts.get("password", None),
                session=session,
                overwrite=overwrite,
            )
            if "move" in opts:
                # XXX: here, move is supposed to be a dir, it can be a name
                move = temp_dir / opts["move"]
                move_dir = move.parent
                move_dir.mkdir(parents=True, exist_ok=True)
                shutil.move(dl_file, move)
                dl_file = move
            if "uncompress" in opts:
                try:
                    uncompress_file(dl_file, verbose=verbose)
                except Exception as e:
                    abort = str(e)

        if (
            abort is None
            and not target_file.exists()
            and not temp_target_file.exists()
        ):
            warnings.warn(
                f"An error occurred while fetching {file_}",
                stacklevel=find_stack_level(),
            )
            abort = (
                "Dataset has been downloaded but requested file was "
                f"not provided:\nURL: {url}\n"
                f"Target file: {target_file}\nDownloaded: {dl_file}"
            )
        if abort is not None:
            if temp_dir.exists():
                shutil.rmtree(temp_dir)
            raise OSError(f"Fetching aborted: {abort}")
        files_.append(str(target_file))
    # If needed, move files from temp directory to final directory.
    if temp_dir.exists():
        # XXX We could move only the files requested
        # XXX Movetree can go wrong
        movetree(temp_dir, data_dir)
        shutil.rmtree(temp_dir)
    return files_
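

# Illustrative sketch (not part of the nilearn API): the ``files`` argument
# is a list of (file_path, url, options) tuples. Here a hypothetical archive
# is downloaded, checked against an MD5 sum and uncompressed, and the fetch
# succeeds only if "atlas/maps.nii.gz" exists afterwards. URL, checksum and
# paths are made up.
def _example_fetch_files():
    data_dir = get_dataset_dir("my_dataset", verbose=0)
    files = [
        (
            "atlas/maps.nii.gz",
            "https://example.org/data/atlas.tar.gz",
            {"uncompress": True, "md5sum": "0123456789abcdef0123456789abcdef"},
        ),
    ]
    return fetch_files(data_dir, files, verbose=0)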


def tree(path, pattern=None, dictionary=False):
    """Return a directory tree in the form of a dictionary or list.

    Parameters
    ----------
    path : string or pathlib.Path
        Path browsed.

    pattern : string, optional
        Pattern used to filter files (see fnmatch).

    dictionary : boolean, default=False
        If True, the function will return a dict instead of a list.

    """
    path = Path(path)
    files = []
    dirs = {} if dictionary else []

    for file_path in path.iterdir():
        if file_path.is_dir():
            if dictionary:
                dirs[file_path.name] = tree(file_path, pattern, dictionary)
            else:
                dirs.append(
                    (file_path.name, tree(file_path, pattern, dictionary))
                )
        elif pattern is None or fnmatch.fnmatch(file_path.name, pattern):
            files.append(str(file_path))
    files = sorted(files)
    if not dictionary:
        return sorted(dirs) + files
    if len(dirs) == 0:
        return files
    if len(files) > 0:
        dirs["."] = files
    return dirs
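

# Illustrative sketch (not part of the nilearn API): listing NIfTI files in a
# dataset directory, first as a nested list, then as a dictionary keyed by
# sub-directory (files directly under the root end up under "."). The path is
# hypothetical.
def _example_tree():
    as_list = tree("/tmp/nilearn_data/my_dataset", pattern="*.nii.gz")
    as_dict = tree(
        "/tmp/nilearn_data/my_dataset", pattern="*.nii.gz", dictionary=True
    )
    return as_list, as_dict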