muutils.parallel
parallel processing utilities, chiefly run_maybe_parallel
"parallel processing utilities, chiefly `run_maybe_parallel`"

from __future__ import annotations

import multiprocessing
import functools
from typing import (
    Any,
    Callable,
    Iterable,
    Literal,
    Optional,
    Tuple,
    TypeVar,
    Dict,
    List,
    Union,
    Protocol,
)

# for no tqdm fallback
from muutils.spinner import SpinnerContext
from muutils.validate_type import get_fn_allowed_kwargs


InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")
# typevars for our iterable and map


class ProgressBarFunction(Protocol):
    "a protocol for a progress bar function"

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable: ...


ProgressBarOption = Literal["tqdm", "spinner", "none", None]
# type for the progress bar option


DEFAULT_PBAR_FN: ProgressBarOption
# default progress bar option: "tqdm" if it can be imported, otherwise "spinner"

try:
    # use tqdm if it's available
    import tqdm  # type: ignore[import-untyped]

    DEFAULT_PBAR_FN = "tqdm"

except ImportError:
    # use the spinner as a fallback
    DEFAULT_PBAR_FN = "spinner"


def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """consume `x` into a list while displaying a `SpinnerContext`

    only kwargs accepted by `SpinnerContext.__init__` are forwarded. a tqdm-style
    `desc` is used as the spinner `message` if no `message` is given; failing
    both, a `total` kwarg produces the message "Processing {total} items".
    """
    spinnercontext_allowed_kwargs: set[str] = get_fn_allowed_kwargs(
        SpinnerContext.__init__
    )
    mapped_kwargs: dict = {
        k: v for k, v in kwargs.items() if k in spinnercontext_allowed_kwargs
    }
    if "desc" in kwargs and "message" not in mapped_kwargs:
        mapped_kwargs["message"] = kwargs["desc"]

    if "message" not in mapped_kwargs and "total" in kwargs:
        mapped_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**mapped_kwargs):
        output = list(x)

    return output


def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """map generic progress bar kwargs to ones accepted by `tqdm.tqdm`

    kwargs not accepted by `tqdm.tqdm.__init__` are dropped. if no `desc` is given,
    `message` is used as `desc`, or failing that a `total` kwarg produces
    "Processing {total} items".

    (tqdm is called with mapped kwargs rather than being wrapped, because wrapping
    it seemed to make the progress bar disappear -- unconfirmed)
    """
    tqdm_allowed_kwargs: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in tqdm_allowed_kwargs}

    if "desc" not in kwargs:
        if "message" in kwargs:
            mapped_kwargs["desc"] = kwargs["message"]

        elif "total" in kwargs:
            mapped_kwargs["desc"] = f"Processing {kwargs.get('total')} items"
    return mapped_kwargs


def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    "fallback to no progress bar: returns `x` unchanged and ignores all kwargs"
    return x


def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """set up the progress bar function and its kwargs

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option. if a function, we return it as-is. if a string, we figure out which progress bar to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function
        (defaults to `None`)
     - `**extra_kwargs`
        additional kwargs (e.g. `total`); entries in `pbar_kwargs` take precedence over these

    # Returns:
     - `Tuple[ProgressBarFunction, dict]`
        a tuple of the progress bar function and the kwargs to call it with

    # Raises:
     - `ValueError` : if `pbar` is not one of the valid options
     - `ImportError` : if `pbar="tqdm"` but `tqdm` is not installed
    """
    pbar_fn: ProgressBarFunction

    if pbar_kwargs is None:
        pbar_kwargs = dict()

    # explicit `pbar_kwargs` override `extra_kwargs` on key collisions
    pbar_kwargs = {**extra_kwargs, **pbar_kwargs}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
    if (pbar is None) or (pbar == "none") or pbar_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # if `pbar` is a different string, figure out which progress bar to use
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            # import here so a missing `tqdm` gives a clear ImportError rather
            # than a NameError on the module-level name
            try:
                import tqdm  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package -- install it with `pip install tqdm`, or use `pbar='spinner'`"
                ) from e
            pbar_fn = tqdm.tqdm
            pbar_kwargs = map_kwargs_for_tqdm(pbar_kwargs)
        elif pbar == "spinner":
            # the spinner's kwargs are bound here, so callers pass no extra kwargs
            pbar_fn = functools.partial(spinner_fn_wrap, **pbar_kwargs)
            pbar_kwargs = dict()
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )
    else:
        # a user-supplied progress bar callable: use it as-is, passing the kwargs through
        pbar_fn = pbar

    return pbar_fn, pbar_kwargs


# TODO: if `parallel` is a negative int, use `multiprocessing.cpu_count() + parallel` to determine the number of processes
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the number of processes is capped at `len(iterable)`, so with `parallel=True` the
    maximum is `min(len(iterable), multiprocessing.cpu_count())`. if that leaves a
    single process, the function runs in serial.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
        function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
        iterable passed to either `map` or `Pool.imap`. iterables without `__len__`
        (e.g. generators) are first materialized into a list
     - `parallel : bool | int`
        whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function
        (defaults to `None`)
     - `chunksize : Optional[int]`
        chunksize passed to `Pool.imap` / `Pool.imap_unordered`. if `None`, uses
        `max(1, len(iterable) // num_processes)`
        (defaults to `None`)
     - `keep_ordered : bool`
        if `True` use `Pool.imap`, otherwise `Pool.imap_unordered` (output order not guaranteed)
        (defaults to `True`)
     - `use_multiprocess : bool`
        build the pool from the `multiprocess` package instead of `multiprocessing`
        (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option, resolved by `set_up_progress_bar_fn`
        (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
        a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True`, the run is actually parallel, and `multiprocess` is not available
    """
    # `len` is needed for the progress bar total and the process count, so
    # materialize iterables that don't support it (e.g. generators)
    if not hasattr(iterable, "__len__"):
        iterable = list(iterable)

    # number of inputs in iterable
    n_inputs: int = len(iterable)  # type: ignore[arg-type]
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # check `use_multiprocess` against what the caller asked for, *before* possibly
    # falling back to serial below -- otherwise `parallel=True` with a single input
    # (or on a single-CPU machine) would spuriously raise
    if use_multiprocess and not parallel:
        raise ValueError("`use_multiprocess=True` requires `parallel=True`")

    # make sure we don't have more processes than iterable, and don't bother with parallel if there's only one process
    num_processes = min(num_processes, n_inputs)
    if num_processes == 1:
        parallel = False

    if not parallel:
        # serial: plain `map`, wrapped in the progress bar
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    mp = multiprocessing
    if use_multiprocess:
        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    pool = mp.Pool(num_processes)
    try:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        pool_map = pool.imap if keep_ordered else pool.imap_unordered
        # run the map function with a progress bar
        output: List[OutputType] = list(
            pbar_fn(
                pool_map(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # don't leave worker processes running if `func` or the progress bar raised
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()

    # return the output as a list
    return output
class
ProgressBarFunction(typing.Protocol):
class ProgressBarFunction(Protocol):
    """structural type for a progress bar: anything callable with an iterable
    plus arbitrary keyword options that returns an iterable"""

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable:
        """wrap `iterable`, returning an iterable (progress display is up to the implementer)"""
a protocol for a progress bar function
ProgressBarFunction(*args, **kwargs)
def _no_init_or_replace_init(self, *args, **kwargs):
    # `__init__` placeholder used by `typing` for Protocol classes (listed on this
    # page as the constructor of `ProgressBarFunction`).
    # - raises TypeError if `type(self)` is itself a protocol
    # - otherwise, on first instantiation of a concrete subclass, finds the first
    #   real `__init__` in the MRO, installs it on the class, and calls it
    # returns None
    cls = type(self)

    if cls._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # Already using a custom `__init__`. No need to calculate correct
    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
    if cls.__init__ is not _no_init_or_replace_init:
        return

    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
    # searches for a proper new `__init__` in the MRO. The new `__init__`
    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
    # instantiation of the protocol subclass will thus use the new
    # `__init__` and no longer call `_no_init_or_replace_init`.
    for base in cls.__mro__:
        init = base.__dict__.get('__init__', _no_init_or_replace_init)
        if init is not _no_init_or_replace_init:
            cls.__init__ = init
            break
    else:
        # for-else: runs only if no class in the MRO defined a real `__init__`
        # should not happen
        cls.__init__ = object.__init__

    # delegate to the (now installed) real initializer
    cls.__init__(self, *args, **kwargs)
ProgressBarOption =
typing.Literal['tqdm', 'spinner', 'none', None]
DEFAULT_PBAR_FN: Literal['tqdm', 'spinner', 'none', None] =
'tqdm'
def
spinner_fn_wrap(x: Iterable, **kwargs) -> List:
def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """drain `x` into a list while a `SpinnerContext` is displayed

    only kwargs accepted by `SpinnerContext.__init__` are forwarded. when no
    `message` is among them, a tqdm-style `desc` is used instead, and failing
    that a `total` kwarg yields "Processing {total} items".
    """
    accepted: set[str] = get_fn_allowed_kwargs(SpinnerContext.__init__)
    spinner_kwargs: dict = dict(
        (name, value) for name, value in kwargs.items() if name in accepted
    )

    if "message" not in spinner_kwargs:
        if "desc" in kwargs:
            spinner_kwargs["message"] = kwargs["desc"]
        elif "total" in kwargs:
            spinner_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**spinner_kwargs):
        collected = [item for item in x]

    return collected
spinner wrapper
def
map_kwargs_for_tqdm(kwargs: dict) -> dict:
def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """translate generic progress bar kwargs into ones `tqdm.tqdm` accepts

    anything not accepted by `tqdm.tqdm.__init__` is dropped. without an explicit
    `desc`, a `message` kwarg becomes `desc`, or else a `total` kwarg yields
    "Processing {total} items". (kwargs are mapped rather than tqdm being wrapped;
    the original note says wrapping made the bar disappear -- unconfirmed)
    """
    accepted: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
    tqdm_kwargs: dict = {key: val for key, val in kwargs.items() if key in accepted}

    if "desc" in kwargs:
        return tqdm_kwargs

    if "message" in kwargs:
        tqdm_kwargs["desc"] = kwargs["message"]
    elif "total" in kwargs:
        tqdm_kwargs["desc"] = f"Processing {kwargs.get('total')} items"
    return tqdm_kwargs
map kwargs for tqdm — can't wrap tqdm because the progress bar disappears?
def
no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """fallback used when no progress bar is wanted

    accepts (and discards) the keyword options a real progress bar would take,
    and hands back `x` itself -- not a copy, not a wrapper
    """
    del kwargs  # accepted only so the signature matches other progress bar functions
    return x
fallback to no progress bar
def
set_up_progress_bar_fn( pbar: Union[ProgressBarFunction, Literal['tqdm', 'spinner', 'none', None]], pbar_kwargs: Optional[Dict[str, Any]] = None, **extra_kwargs) -> Tuple[ProgressBarFunction, dict]:
def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """set up the progress bar function and its kwargs

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option. if a function, we return it as-is. if a string, we figure out which progress bar to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function
        (defaults to `None`)
     - `**extra_kwargs`
        additional kwargs (e.g. `total`); entries in `pbar_kwargs` take precedence over these

    # Returns:
     - `Tuple[ProgressBarFunction, dict]`
        a tuple of the progress bar function and the kwargs to call it with

    # Raises:
     - `ValueError` : if `pbar` is not one of the valid options
     - `ImportError` : if `pbar="tqdm"` but `tqdm` is not installed
    """
    pbar_fn: ProgressBarFunction

    if pbar_kwargs is None:
        pbar_kwargs = dict()

    # explicit `pbar_kwargs` override `extra_kwargs` on key collisions
    pbar_kwargs = {**extra_kwargs, **pbar_kwargs}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
    if (pbar is None) or (pbar == "none") or pbar_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # if `pbar` is a different string, figure out which progress bar to use
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            # import here so a missing `tqdm` gives a clear ImportError rather
            # than a NameError on the module-level name
            try:
                import tqdm  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package -- install it with `pip install tqdm`, or use `pbar='spinner'`"
                ) from e
            pbar_fn = tqdm.tqdm
            pbar_kwargs = map_kwargs_for_tqdm(pbar_kwargs)
        elif pbar == "spinner":
            # the spinner's kwargs are bound here, so callers pass no extra kwargs
            pbar_fn = functools.partial(spinner_fn_wrap, **pbar_kwargs)
            pbar_kwargs = dict()
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )
    else:
        # a user-supplied progress bar callable: use it as-is, passing the kwargs through
        pbar_fn = pbar

    return pbar_fn, pbar_kwargs
set up the progress bar function and its kwargs
Parameters:
- `pbar : Union[ProgressBarFunction, ProgressBarOption]` — progress bar function or option. if a function, we return it as-is. if a string, we figure out which progress bar to use
- `pbar_kwargs : Optional[Dict[str, Any]]` — kwargs passed to the progress bar function (defaults to `None`)
Returns:
- `Tuple[ProgressBarFunction, dict]` — a tuple of the progress bar function and its kwargs
Raises:
- `ValueError`: if `pbar` is not one of the valid options
def
run_maybe_parallel( func: Callable[[~InputType], ~OutputType], iterable: Iterable[~InputType], parallel: Union[bool, int], pbar_kwargs: Optional[Dict[str, Any]] = None, chunksize: Optional[int] = None, keep_ordered: bool = True, use_multiprocess: bool = False, pbar: Union[ProgressBarFunction, Literal['tqdm', 'spinner', 'none', None]] = 'tqdm') -> List[~OutputType]:
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the number of processes is capped at `len(iterable)`, so with `parallel=True` the
    maximum is `min(len(iterable), multiprocessing.cpu_count())`. if that leaves a
    single process, the function runs in serial.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
        function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
        iterable passed to either `map` or `Pool.imap`. iterables without `__len__`
        (e.g. generators) are first materialized into a list
     - `parallel : bool | int`
        whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
        kwargs passed to the progress bar function
        (defaults to `None`)
     - `chunksize : Optional[int]`
        chunksize passed to `Pool.imap` / `Pool.imap_unordered`. if `None`, uses
        `max(1, len(iterable) // num_processes)`
        (defaults to `None`)
     - `keep_ordered : bool`
        if `True` use `Pool.imap`, otherwise `Pool.imap_unordered` (output order not guaranteed)
        (defaults to `True`)
     - `use_multiprocess : bool`
        build the pool from the `multiprocess` package instead of `multiprocessing`
        (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
        progress bar function or option, resolved by `set_up_progress_bar_fn`
        (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
        a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True`, the run is actually parallel, and `multiprocess` is not available
    """
    # `len` is needed for the progress bar total and the process count, so
    # materialize iterables that don't support it (e.g. generators)
    if not hasattr(iterable, "__len__"):
        iterable = list(iterable)

    # number of inputs in iterable
    n_inputs: int = len(iterable)  # type: ignore[arg-type]
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # check `use_multiprocess` against what the caller asked for, *before* possibly
    # falling back to serial below -- otherwise `parallel=True` with a single input
    # (or on a single-CPU machine) would spuriously raise
    if use_multiprocess and not parallel:
        raise ValueError("`use_multiprocess=True` requires `parallel=True`")

    # make sure we don't have more processes than iterable, and don't bother with parallel if there's only one process
    num_processes = min(num_processes, n_inputs)
    if num_processes == 1:
        parallel = False

    if not parallel:
        # serial: plain `map`, wrapped in the progress bar
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    mp = multiprocessing
    if use_multiprocess:
        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    pool = mp.Pool(num_processes)
    try:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        pool_map = pool.imap if keep_ordered else pool.imap_unordered
        # run the map function with a progress bar
        output: List[OutputType] = list(
            pbar_fn(
                pool_map(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # don't leave worker processes running if `func` or the progress bar raised
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()

    # return the output as a list
    return output
a function to make it easier to sometimes parallelize an operation
- if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
- if `parallel` is `True`, then the function will run in parallel with the maximum number of processes
- if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`
the maximum number of processes is given by `min(len(iterable), multiprocessing.cpu_count())`
Parameters:
- `func : Callable[[InputType], OutputType]` — function passed to either `map` or `Pool.imap`
- `iterable : Iterable[InputType]` — iterable passed to either `map` or `Pool.imap`
- `parallel : bool | int` — whether to run in parallel, and how many processes to use
- `pbar_kwargs : Dict[str, Any]` — kwargs passed to the progress bar function
Returns:
- `List[OutputType]` — a list of the output of `func` for each element in `iterable`
Raises:
- `ValueError`: if `parallel` is not a boolean or an integer greater than 1
- `ValueError`: if `use_multiprocess=True` and `parallel=False`
- `ImportError`: if `use_multiprocess=True` and `multiprocess` is not available