docs for muutils v0.8.12
View Source on GitHub

muutils.parallel

parallel processing utilities, chiefly run_maybe_parallel


  1"parallel processing utilities, chiefly `run_maybe_parallel`"
  2
  3from __future__ import annotations
  4
  5import multiprocessing
  6import functools
  7from typing import (
  8    Any,
  9    Callable,
 10    Iterable,
 11    Literal,
 12    Optional,
 13    Tuple,
 14    TypeVar,
 15    Dict,
 16    List,
 17    Union,
 18    Protocol,
 19)
 20
 21# for no tqdm fallback
 22from muutils.spinner import SpinnerContext
 23from muutils.validate_type import get_fn_allowed_kwargs
 24
 25
 26InputType = TypeVar("InputType")
 27OutputType = TypeVar("OutputType")
 28# typevars for our iterable and map
 29
 30
 31class ProgressBarFunction(Protocol):
 32    "a protocol for a progress bar function"
 33
 34    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable: ...
 35
 36
 37ProgressBarOption = Literal["tqdm", "spinner", "none", None]
 38# type for the progress bar option
 39
 40
 41DEFAULT_PBAR_FN: ProgressBarOption
 42# default progress bar function
 43
 44try:
 45    # use tqdm if it's available
 46    import tqdm  # type: ignore[import-untyped]
 47
 48    DEFAULT_PBAR_FN = "tqdm"
 49
 50except ImportError:
 51    # use progress bar as fallback
 52    DEFAULT_PBAR_FN = "spinner"
 53
 54
 55def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
 56    "spinner wrapper"
 57    spinnercontext_allowed_kwargs: set[str] = get_fn_allowed_kwargs(
 58        SpinnerContext.__init__
 59    )
 60    mapped_kwargs: dict = {
 61        k: v for k, v in kwargs.items() if k in spinnercontext_allowed_kwargs
 62    }
 63    if "desc" in kwargs and "message" not in mapped_kwargs:
 64        mapped_kwargs["message"] = kwargs["desc"]
 65
 66    if "message" not in mapped_kwargs and "total" in kwargs:
 67        mapped_kwargs["message"] = f"Processing {kwargs['total']} items"
 68
 69    with SpinnerContext(**mapped_kwargs):
 70        output = list(x)
 71
 72    return output
 73
 74
 75def map_kwargs_for_tqdm(kwargs: dict) -> dict:
 76    "map kwargs for tqdm, cant wrap because the pbar dissapears?"
 77    tqdm_allowed_kwargs: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
 78    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in tqdm_allowed_kwargs}
 79
 80    if "desc" not in kwargs:
 81        if "message" in kwargs:
 82            mapped_kwargs["desc"] = kwargs["message"]
 83
 84        elif "total" in kwargs:
 85            mapped_kwargs["desc"] = f"Processing {kwargs.get('total')} items"
 86    return mapped_kwargs
 87
 88
 89def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
 90    "fallback to no progress bar"
 91    return x
 92
 93
 94def set_up_progress_bar_fn(
 95    pbar: Union[ProgressBarFunction, ProgressBarOption],
 96    pbar_kwargs: Optional[Dict[str, Any]] = None,
 97    **extra_kwargs,
 98) -> Tuple[ProgressBarFunction, dict]:
 99    """set up the progress bar function and its kwargs
100
101    # Parameters:
102     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
103       progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
104     - `pbar_kwargs : Optional[Dict[str, Any]]`
105       kwargs passed to the progress bar function (default to `None`)
106       (defaults to `None`)
107
108    # Returns:
109     - `Tuple[ProgressBarFunction, dict]`
110         a tuple of the progress bar function and its kwargs
111
112    # Raises:
113     - `ValueError` : if `pbar` is not one of the valid options
114    """
115    pbar_fn: ProgressBarFunction
116
117    if pbar_kwargs is None:
118        pbar_kwargs = dict()
119
120    pbar_kwargs = {**extra_kwargs, **pbar_kwargs}
121
122    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
123    if (pbar is None) or (pbar == "none") or pbar_kwargs.get("disable", False):
124        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]
125
126    # if `pbar` is a different string, figure out which progress bar to use
127    elif isinstance(pbar, str):
128        if pbar == "tqdm":
129            pbar_fn = tqdm.tqdm
130            pbar_kwargs = map_kwargs_for_tqdm(pbar_kwargs)
131        elif pbar == "spinner":
132            pbar_fn = functools.partial(spinner_fn_wrap, **pbar_kwargs)
133            pbar_kwargs = dict()
134        else:
135            raise ValueError(
136                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
137            )
138    else:
139        # the default value is a callable which will resolve to tqdm if available or spinner as a fallback. we pass kwargs to this
140        pbar_fn = pbar
141
142    return pbar_fn, pbar_kwargs
143
144
# TODO: if `parallel` is a negative int, use `multiprocessing.cpu_count() + parallel` to determine the number of processes
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the number of processes is also capped at `len(iterable)`, so for `parallel=True` the maximum
    is `min(len(iterable), multiprocessing.cpu_count())`. if that leaves a single process, the
    work runs serially and no pool is created.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
       function passed to either `map` or `Pool.imap`/`Pool.imap_unordered`
     - `iterable : Iterable[InputType]`
       iterable passed to either `map` or `Pool.imap`. iterables without a `len()`
       (e.g. generators) are first materialized into a list
     - `parallel : bool | int`
       whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function
       (defaults to `None`)
     - `chunksize : Optional[int]`
       chunksize passed to the pool's map. if `None`, uses `max(1, len(iterable) // num_processes)`
       (defaults to `None`)
     - `keep_ordered : bool`
       if `True` use `Pool.imap` (outputs in input order), otherwise `Pool.imap_unordered`
       (defaults to `True`)
     - `use_multiprocess : bool`
       use the `multiprocess` package instead of `multiprocessing`. mostly useful when you need to pickle a lambda
       (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar to use: `"tqdm"`, `"spinner"`, `"none"`/`None`, or a custom callable
       (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
       a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """
    # the input count is needed for the progress bar total and to cap the process count;
    # unsized iterables (e.g. generators) are materialized so they can be counted
    try:
        n_inputs: int = len(iterable)  # type: ignore[arg-type]
    except TypeError:
        iterable = list(iterable)
        n_inputs = len(iterable)

    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # check `use_multiprocess` against the *requested* `parallel`, before the process count is
    # capped below -- otherwise a single-element iterable (or a single-CPU machine) would raise
    # this error even though the caller passed `parallel=True`
    if use_multiprocess and not parallel:
        raise ValueError("`use_multiprocess=True` requires `parallel=True`")

    # never use more processes than inputs, and skip the pool entirely if only one is needed
    num_processes = min(num_processes, n_inputs)
    run_in_pool: bool = num_processes > 1

    # we call `mp.Pool` so that `multiprocess` can stand in for `multiprocessing`
    mp = multiprocessing
    if use_multiprocess:
        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    if not run_in_pool:
        # serial: plain `map`, consumed through the progress bar
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    pool = mp.Pool(num_processes)
    try:
        # `imap` keeps the input order, `imap_unordered` does not guarantee it
        do_map = pool.imap if keep_ordered else pool.imap_unordered
        output: List[OutputType] = list(
            pbar_fn(
                do_map(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # `func` (or the progress bar) raised, or we were interrupted:
        # stop the workers instead of leaking the pool and its processes
        pool.terminate()
        raise
    else:
        # normal shutdown: no more tasks will be submitted
        pool.close()
    finally:
        pool.join()

    # return the output as a list
    return output

class ProgressBarFunction(Protocol):
    """a protocol for a progress bar function

    structural type: any callable taking an iterable plus arbitrary keyword
    arguments and returning an iterable (`tqdm.tqdm` has this shape)
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable:
        """call signature: wrap `iterable`, returning an iterable"""

a protocol for a progress bar function

def _no_init_or_replace_init(self, *args, **kwargs):
    """placeholder `__init__` of protocol subclasses (from the stdlib `typing` module)

    refuses to instantiate a protocol class itself. otherwise, on first use it finds
    the first real `__init__` along the MRO, installs it on the class (so later
    instantiations no longer go through this function), and calls it.
    """
    klass = type(self)

    if klass._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # a custom `__init__` is already in place; recomputing it here can lead to a
    # RecursionError (see bpo-45121)
    if klass.__init__ is not _no_init_or_replace_init:
        return

    placeholder = _no_init_or_replace_init
    # each class's own `__init__` along the MRO, treating "none defined" like the placeholder
    candidates = (
        base.__dict__.get('__init__', placeholder) for base in klass.__mro__
    )
    # the first real `__init__` wins; `object.__init__` is a should-not-happen fallback
    klass.__init__ = next(
        (init for init in candidates if init is not placeholder), object.__init__
    )

    klass.__init__(self, *args, **kwargs)
# accepted string (or None) selectors for the progress bar; the module sets the
# default to "tqdm" when `tqdm` is importable and to "spinner" otherwise
ProgressBarOption = Literal["tqdm", "spinner", "none", None]
DEFAULT_PBAR_FN: ProgressBarOption = "tqdm"
def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """consume `x` while a `SpinnerContext` spinner is shown, returning the items as a list

    only kwargs present in `get_fn_allowed_kwargs(SpinnerContext.__init__)` are
    forwarded. tqdm-style kwargs are translated: when no `message` was forwarded,
    `desc` is used as the message, or failing that a "Processing N items" message
    is built from `total`.
    """
    accepted: set[str] = get_fn_allowed_kwargs(SpinnerContext.__init__)
    spinner_kwargs: dict = dict(
        (key, value) for key, value in kwargs.items() if key in accepted
    )

    if "message" not in spinner_kwargs:
        if "desc" in kwargs:
            spinner_kwargs["message"] = kwargs["desc"]
        elif "total" in kwargs:
            spinner_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**spinner_kwargs):
        # fully consume the iterable while the spinner is active
        items = list(x)

    return items

spinner wrapper

def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """map kwargs for tqdm, can't wrap because the pbar disappears?

    keeps only the kwargs present in `get_fn_allowed_kwargs(tqdm.tqdm.__init__)`.
    if no `desc` was given, it is taken from a spinner-style `message`, or
    failing that built from `total` as "Processing N items".
    """
    accepted: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
    tqdm_kwargs: dict = {}
    for key, value in kwargs.items():
        if key in accepted:
            tqdm_kwargs[key] = value

    # an explicit `desc` always wins
    if "desc" in kwargs:
        return tqdm_kwargs

    if "message" in kwargs:
        tqdm_kwargs["desc"] = kwargs["message"]
    elif "total" in kwargs:
        tqdm_kwargs["desc"] = f"Processing {kwargs.get('total')} items"
    return tqdm_kwargs

map kwargs for tqdm, can't wrap because the pbar disappears?

def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """fallback to no progress bar: return `x` unchanged

    keyword arguments (e.g. `total`, `desc`) are accepted only so this matches
    the progress bar call signature; they are ignored
    """
    del kwargs  # accepted for interface compatibility only
    return x

fallback to no progress bar

def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """set up the progress bar function and its kwargs

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar function or option. if a callable, it is returned as-is and receives
       all the kwargs. if `"tqdm"` or `"spinner"`, that progress bar is used. `None` or
       `"none"` disables the progress bar
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function. the dict itself is not modified
       (defaults to `None`)
     - `**extra_kwargs`
       default kwargs for the progress bar (e.g. `total`). any key also present in
       `pbar_kwargs` takes the `pbar_kwargs` value

    # Returns:
     - `Tuple[ProgressBarFunction, dict]`
         a tuple of the progress bar function and the kwargs it should be called with

    # Raises:
     - `ValueError` : if `pbar` is a string that is not one of the valid options
     - `ImportError` : if `pbar="tqdm"` but the `tqdm` package is not installed
    """
    pbar_fn: ProgressBarFunction

    # build a fresh dict so the caller's `pbar_kwargs` is never modified;
    # explicitly passed `pbar_kwargs` take precedence over `extra_kwargs`
    merged_kwargs: Dict[str, Any] = {
        **extra_kwargs,
        **(pbar_kwargs if pbar_kwargs is not None else {}),
    }

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in the kwargs
    if (pbar is None) or (pbar == "none") or merged_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # if `pbar` is a different string, figure out which progress bar to use
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            # the module-level `tqdm` name only exists if its import succeeded, so an
            # explicit `pbar="tqdm"` without tqdm installed would otherwise be a NameError
            try:
                import tqdm as tqdm_module  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package -- install it with `pip install tqdm`, or use `pbar='spinner'` or `pbar=None`"
                ) from e
            pbar_fn = tqdm_module.tqdm
            merged_kwargs = map_kwargs_for_tqdm(merged_kwargs)
        elif pbar == "spinner":
            # the kwargs are bound into the wrapper, so none are returned separately
            pbar_fn = functools.partial(spinner_fn_wrap, **merged_kwargs)
            merged_kwargs = dict()
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )
    else:
        # `pbar` is a user-supplied callable: use it as-is, passing all kwargs through
        pbar_fn = pbar

    return pbar_fn, merged_kwargs

set up the progress bar function and its kwargs

Parameters:

  • pbar : Union[ProgressBarFunction, ProgressBarOption] progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
  • pbar_kwargs : Optional[Dict[str, Any]] kwargs passed to the progress bar function (defaults to None)

Returns:

  • Tuple[ProgressBarFunction, dict] a tuple of the progress bar function and its kwargs

Raises:

  • ValueError : if pbar is not one of the valid options
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the number of processes is also capped at `len(iterable)`, so for `parallel=True` the maximum
    is `min(len(iterable), multiprocessing.cpu_count())`. if that leaves a single process, the
    work runs serially and no pool is created.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
       function passed to either `map` or `Pool.imap`/`Pool.imap_unordered`
     - `iterable : Iterable[InputType]`
       iterable passed to either `map` or `Pool.imap`. iterables without a `len()`
       (e.g. generators) are first materialized into a list
     - `parallel : bool | int`
       whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function
       (defaults to `None`)
     - `chunksize : Optional[int]`
       chunksize passed to the pool's map. if `None`, uses `max(1, len(iterable) // num_processes)`
       (defaults to `None`)
     - `keep_ordered : bool`
       if `True` use `Pool.imap` (outputs in input order), otherwise `Pool.imap_unordered`
       (defaults to `True`)
     - `use_multiprocess : bool`
       use the `multiprocess` package instead of `multiprocessing`. mostly useful when you need to pickle a lambda
       (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar to use: `"tqdm"`, `"spinner"`, `"none"`/`None`, or a custom callable
       (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
       a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """
    # the input count is needed for the progress bar total and to cap the process count;
    # unsized iterables (e.g. generators) are materialized so they can be counted
    try:
        n_inputs: int = len(iterable)  # type: ignore[arg-type]
    except TypeError:
        iterable = list(iterable)
        n_inputs = len(iterable)

    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # check `use_multiprocess` against the *requested* `parallel`, before the process count is
    # capped below -- otherwise a single-element iterable (or a single-CPU machine) would raise
    # this error even though the caller passed `parallel=True`
    if use_multiprocess and not parallel:
        raise ValueError("`use_multiprocess=True` requires `parallel=True`")

    # never use more processes than inputs, and skip the pool entirely if only one is needed
    num_processes = min(num_processes, n_inputs)
    run_in_pool: bool = num_processes > 1

    # we call `mp.Pool` so that `multiprocess` can stand in for `multiprocessing`
    mp = multiprocessing
    if use_multiprocess:
        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    if not run_in_pool:
        # serial: plain `map`, consumed through the progress bar
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    pool = mp.Pool(num_processes)
    try:
        # `imap` keeps the input order, `imap_unordered` does not guarantee it
        do_map = pool.imap if keep_ordered else pool.imap_unordered
        output: List[OutputType] = list(
            pbar_fn(
                do_map(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # `func` (or the progress bar) raised, or we were interrupted:
        # stop the workers instead of leaking the pool and its processes
        pool.terminate()
        raise
    else:
        # normal shutdown: no more tasks will be submitted
        pool.close()
    finally:
        pool.join()

    # return the output as a list
    return output

a function to make it easier to sometimes parallelize an operation

  • if parallel is False, then the function will run in serial, running map(func, iterable)
  • if parallel is True, then the function will run in parallel with the maximum number of processes
  • if parallel is an int, it must be greater than 1, and the function will run in parallel with the number of processes specified by parallel

the maximum number of processes is given by the min(len(iterable), multiprocessing.cpu_count())

Parameters:

  • func : Callable[[InputType], OutputType] function passed to either map or Pool.imap
  • iterable : Iterable[InputType] iterable passed to either map or Pool.imap
  • parallel : bool | int whether to run in parallel, and how many processes to use
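  • pbar_kwargs : Optional[Dict[str, Any]] kwargs passed to the progress bar function (defaults to None)
  • chunksize : Optional[int] chunksize passed to Pool.imap / Pool.imap_unordered; if None, max(1, len(iterable) // num_processes) is used (defaults to None)
  • keep_ordered : bool if True, use Pool.imap so outputs match the input order, otherwise use Pool.imap_unordered (defaults to True)
  • use_multiprocess : bool use the multiprocess package instead of multiprocessing, mostly useful when you need to pickle a lambda (defaults to False)
  • pbar : Union[ProgressBarFunction, ProgressBarOption] progress bar to use: "tqdm", "spinner", "none"/None, or a custom callable (defaults to DEFAULT_PBAR_FN)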

Returns:

  • List[OutputType] a list of the output of func for each element in iterable

Raises:

  • ValueError : if parallel is not a boolean or an integer greater than 1
  • ValueError : if use_multiprocess=True and parallel=False
  • ImportError : if use_multiprocess=True and multiprocess is not available