Coverage for nilearn/plotting/_utils.py: 0%
96 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1from numbers import Number
2from pathlib import Path
3from warnings import warn
5import matplotlib.pyplot as plt
6import numpy as np
7from matplotlib.colors import LinearSegmentedColormap
9from nilearn._utils.logger import find_stack_level
# Name of the default plotting engine.
# NOTE(review): presumably used when a caller does not request a specific
# engine — confirm at the call sites.
DEFAULT_ENGINE = "matplotlib"
def engine_warning(engine):
    """Warn that the requested plotting engine is not installed.

    The warning also tells the user how to install the optional
    plotting dependencies of nilearn.

    Parameters
    ----------
    engine : str
        Name of the plotting engine that is missing.
    """
    install_hint = " pip install 'nilearn[plotting]'"
    warn(
        f"'{engine}' is not installed. To be able to use '{engine}' as "
        "plotting engine for 'nilearn.plotting' package:\n"
        f"{install_hint}",
        stacklevel=find_stack_level(),
    )
def save_figure_if_needed(fig, output_file):
    """Save figure if an output file value is given.

    Create output path if required.

    Parameters
    ----------
    fig : figure, axes, or display instance
        Object to save. Anything that is neither a matplotlib ``Figure``
        nor a ``BaseSlicer`` display is saved through its ``.figure``
        attribute (e.g. matplotlib axes).

    output_file : str, Path or None
        Destination file. Missing parent directories are created.
        If None, nothing is saved.

    Returns
    -------
    ``fig`` if ``output_file`` is None, None otherwise.
    When the figure is saved, it is also closed.
    """
    # avoid circular import
    from nilearn.plotting.displays import BaseSlicer

    if output_file is None:
        return fig

    output_file = Path(output_file)
    output_file.parent.mkdir(exist_ok=True, parents=True)

    # Axes-like objects are saved through the figure they belong to.
    if not isinstance(fig, (plt.Figure, BaseSlicer)):
        fig = fig.figure

    fig.savefig(output_file)

    # Close the figure once it has been written to disk: matplotlib figures
    # go through pyplot, displays use their own ``close`` method.
    if isinstance(fig, plt.Figure):
        plt.close(fig)
    else:
        fig.close()

    return None
def get_cbar_ticks(vmin, vmax, offset, n_ticks=5):
    """Compute colorbar tick positions (helper for BaseSlicer).

    Parameters
    ----------
    vmin, vmax : float
        Colorbar limits.

    offset : float or None
        Threshold value. When given and ``offset / vmax > 0.12``, the two
        ticks whose absolute value is closest to ``offset`` are replaced by
        ``-offset`` and ``+offset``.

    n_ticks : int, default=5
        Number of evenly spaced ticks between ``vmin`` and ``vmax``.
        The threshold adjustment needs at least 3 ticks; with fewer,
        the evenly spaced ticks are returned unchanged.

    Returns
    -------
    numpy.ndarray
        Tick positions. A single tick is returned when ``vmin == vmax``.
    """
    # edge case where the data has a single value yields
    # a cryptic matplotlib error message when trying to plot the color bar
    if vmin == vmax:
        return np.linspace(vmin, vmax, 1)

    # edge case where the data has all negative values but vmax is exactly 0
    if vmax == 0:
        vmax += np.finfo(np.float32).eps

    # If a threshold is specified, we want two of the tick
    # to correspond to -threshold and +threshold on the colorbar.
    # If the threshold is very small compared to vmax,
    # we use a simple linspace as the result would be very difficult to see.
    ticks = np.linspace(vmin, vmax, n_ticks)
    # Replacing two ticks (and skipping a zero tick) needs at least 3 ticks.
    if offset is not None and len(ticks) >= 3 and offset / vmax > 0.12:
        diff = [abs(abs(tick) - offset) for tick in ticks]
        # np.argpartition requires kth < len(diff); clamp it so that fewer
        # than 5 ticks do not raise. For 5+ ticks kth is unchanged.
        last = len(diff) - 1
        # Edge case where the thresholds are exactly
        # at the same distance to 4 ticks
        if diff.count(min(diff)) == 4:
            idx_closest = np.sort(np.argpartition(diff, min(4, last))[:4])
            idx_closest = np.isin(ticks, np.sort(ticks[idx_closest])[1:3])
        else:
            # Find the closest 2 ticks
            idx_closest = np.sort(np.argpartition(diff, min(2, last))[:2])
            if 0 in ticks[idx_closest]:
                # Do not overwrite the zero tick: take the 3 closest and
                # keep the outer two.
                idx_closest = np.sort(
                    np.argpartition(diff, min(3, last))[:3]
                )
                idx_closest = idx_closest[[0, 2]]
        ticks[idx_closest] = [-offset, offset]
    if len(ticks) > 0 and ticks[0] < vmin:
        ticks[0] = vmin

    return ticks
def get_colorbar_and_data_ranges(
    stat_map_data,
    vmin=None,
    vmax=None,
    symmetric_cbar=True,
    force_min_stat_map_value=None,
):
    """Set colormap and colorbar limits.

    Used by plot_stat_map, plot_glass_brain and plot_img_on_surf.

    The limits for the colorbar depend on the symmetric_cbar argument. Please
    refer to docstring of plot_stat_map.

    Parameters
    ----------
    stat_map_data : array-like
        Data whose range is used for any limit not given explicitly.
        If it has a ``_mask`` attribute, masked entries are dropped first.
        NaNs are ignored when computing the range.

    vmin, vmax : number or None
        Requested limits. Non-numeric or non-finite values are treated
        as None.

    symmetric_cbar : bool or "auto", default=True
        Whether the limits must satisfy ``vmin == -vmax``. With "auto" this
        is decided from the requested limits when both are given, otherwise
        from whether the (clipped) data range straddles zero.

    force_min_stat_map_value : number or None
        If given, used instead of the data minimum.

    Returns
    -------
    cbar_vmin, cbar_vmax : number or None
        Colorbar limits; None means the limit is left unset.

    vmin, vmax : float
        Data limits, falling back to the data range when not set.

    Raises
    ------
    ValueError
        If the colorbar is symmetric and both limits are given but ``vmin``
        is not close to ``-vmax``.
    """

    def _finite_or_none(limit):
        # Keep only finite numbers; anything else means "no limit".
        if isinstance(limit, Number) and np.isfinite(limit):
            return limit
        return None

    vmin = _finite_or_none(vmin)
    vmax = _finite_or_none(vmax)

    # Work on the unmasked values only, as a plain ndarray.
    if hasattr(stat_map_data, "_mask"):
        keep = np.logical_not(stat_map_data._mask)
        stat_map_data = np.asarray(stat_map_data[keep])

    if force_min_stat_map_value is not None:
        stat_map_min = force_min_stat_map_value
    else:
        stat_map_min = np.nanmin(stat_map_data)
    stat_map_max = np.nanmax(stat_map_data)

    if symmetric_cbar == "auto":
        if vmin is not None and vmax is not None:
            symmetric_cbar = np.isclose(vmin, -vmax)
        else:
            # Clip the data range by whichever limit was requested and go
            # symmetric only if that range straddles zero.
            low = stat_map_min if vmin is None else max(vmin, stat_map_min)
            high = stat_map_max if vmax is None else min(stat_map_max, vmax)
            symmetric_cbar = low < 0 < high

    if symmetric_cbar:
        if vmin is None:
            if vmax is None:
                vmax = max(-stat_map_min, stat_map_max)
            vmin = -vmax
        elif vmax is None:
            vmax = -vmin
        elif not np.isclose(vmin, -vmax):
            raise ValueError(
                "vmin must be equal to -vmax unless symmetric_cbar is False."
            )
        cbar_vmin, cbar_vmax = vmin, vmax
    elif stat_map_min >= 0:
        # Non-negative data: the colorbar starts at 0 unless vmin is set.
        cbar_vmin = 0 if vmin is None else vmin
        cbar_vmax = vmax
    elif stat_map_max <= 0:
        # Non-positive data: the colorbar ends at 0 unless vmax is set.
        cbar_vmin = vmin
        cbar_vmax = 0 if vmax is None else vmax
    else:
        # Otherwise limit the colorbar to the requested values.
        cbar_vmin, cbar_vmax = vmin, vmax

    data_vmin = stat_map_min if vmin is None else vmin
    data_vmax = stat_map_max if vmax is None else vmax
    return cbar_vmin, cbar_vmax, float(data_vmin), float(data_vmax)
def check_threshold_not_negative(threshold):
    """Make sure threshold is non negative number.

    If threshold == "auto", it may be set to very small value.
    So we allow for that: only values below ``-1e-5`` are rejected.

    Parameters
    ----------
    threshold : object
        Value to check. Only real numbers (builtin ``int``/``float`` and
        numpy integer/floating scalars) are validated; any other value
        (e.g. a string or None) is accepted unchanged.

    Raises
    ------
    ValueError
        If ``threshold`` is a number lower than ``-1e-5``.
    """
    # numpy scalars such as np.float32 or np.int64 do not subclass the
    # builtin int/float, so they must be listed explicitly to be checked.
    is_number = isinstance(threshold, (int, float, np.integer, np.floating))
    if is_number and threshold < -1e-5:
        raise ValueError("Threshold should be a non-negative number!")
def create_colormap_from_lut(cmap, default_cmap="gist_ncar"):
    """
    Create a Matplotlib colormap from a DataFrame containing color mappings.

    Parameters
    ----------
    cmap : pd.DataFrame
        DataFrame with columns 'index', 'name', and 'color' (hex values)

    default_cmap : str, default="gist_ncar"
        Returned unchanged (with a warning) when ``cmap`` has no
        'color' column.

    Returns
    -------
    colormap : LinearSegmentedColormap or ``default_cmap``
        A Matplotlib colormap built from the 'color' column, with rows
        ordered by 'index' and one colormap entry per row, or
        ``default_cmap`` if ``cmap`` has no 'color' column.
    """
    if "color" not in cmap.columns:
        warn(
            "No 'color' column found in the look-up table. "
            "Will use the default colormap instead.",
            stacklevel=find_stack_level(),
        )
        return default_cmap

    # Ensure colors are properly extracted from DataFrame
    colors = cmap.sort_values(by="index")["color"].tolist()
    n_colors = len(colors)

    if n_colors == 1:
        # from_list places the colors at np.linspace(0, 1, len(colors)).
        # A single color would only define x=0, and matplotlib rejects
        # segment data that does not end at x=1. Repeating the color
        # gives a valid constant colormap.
        colors = colors * 2

    # Create a colormap from the list of colors
    return LinearSegmentedColormap.from_list(
        "custom_colormap", colors, N=n_colors
    )