Coverage for nilearn/plotting/_utils.py: 0%

96 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1from numbers import Number 

2from pathlib import Path 

3from warnings import warn 

4 

5import matplotlib.pyplot as plt 

6import numpy as np 

7from matplotlib.colors import LinearSegmentedColormap 

8 

9from nilearn._utils.logger import find_stack_level 

10 

# Name of the plotting engine treated as the default
# (presumably the fallback when a caller does not choose one -- confirm
# against the call sites).
DEFAULT_ENGINE = "matplotlib"

12 

13 

def engine_warning(engine):
    """Emit a warning saying that a plotting engine is not installed.

    Parameters
    ----------
    engine : object
        Value formatted into the warning text (presumably the engine
        name as a string -- confirm against callers).
    """
    install_hint = "".join(
        [
            f"'{engine}' is not installed. To be able to use '{engine}' as ",
            "plotting engine for 'nilearn.plotting' package:\n",
            " pip install 'nilearn[plotting]'",
        ]
    )
    warn(install_hint, stacklevel=find_stack_level())

21 

22 

def save_figure_if_needed(fig, output_file):
    """Save figure if an output file value is given.

    Create output path if required.

    Parameters
    ----------
    fig : figure, axes, or display instance
        Object to save. Anything that is neither a matplotlib ``Figure``
        nor a ``BaseSlicer`` is replaced by its ``.figure`` attribute
        before saving.

    output_file : str, Path or None
        Destination file. Missing parent directories are created.
        When None, nothing is saved.

    Returns
    -------
    ``fig`` unchanged if ``output_file`` is None.
    None otherwise: the figure is saved to ``output_file`` and closed.
    """
    # Nothing to save: hand the figure back untouched.
    if output_file is None:
        return fig

    # Deferred import to avoid a circular import. It sits after the early
    # return so the no-op path does not trigger it at all.
    from nilearn.plotting.displays import BaseSlicer

    output_file = Path(output_file)
    output_file.parent.mkdir(exist_ok=True, parents=True)

    # Objects other than Figure / BaseSlicer are saved through the figure
    # they expose via ``.figure``.
    if not isinstance(fig, (plt.Figure, BaseSlicer)):
        fig = fig.figure

    fig.savefig(output_file)

    # matplotlib figures are closed through pyplot; other objects are
    # closed through their own ``close`` method.
    if isinstance(fig, plt.Figure):
        plt.close(fig)
    else:
        fig.close()

    return None

57 

58 

def get_cbar_ticks(vmin, vmax, offset, n_ticks=5):
    """Compute colorbar tick positions (helper for BaseSlicer).

    Parameters
    ----------
    vmin, vmax : number
        Lower and upper bounds of the colorbar.

    offset : number or None
        Threshold value. When ``offset / vmax`` exceeds 0.12, two of the
        ticks are moved onto ``-offset`` and ``+offset``.

    n_ticks : int, default=5
        Number of evenly spaced ticks generated between the bounds.

    Returns
    -------
    numpy.ndarray
        Tick positions; a single tick when ``vmin == vmax``.
    """
    # Single-valued data: plotting the colorbar would otherwise fail with
    # a cryptic matplotlib error, so return exactly one tick.
    if vmin == vmax:
        return np.linspace(vmin, vmax, 1)

    # All-negative data whose vmax is exactly 0: shift vmax by a tiny
    # amount (it is used as a divisor below).
    if vmax == 0:
        vmax += np.finfo(np.float32).eps

    ticks = np.linspace(vmin, vmax, n_ticks)

    # With a threshold, two ticks should sit on -threshold / +threshold.
    # A threshold that is very small compared to vmax would be hard to
    # see, so the plain evenly spaced ticks are kept in that case.
    show_threshold = offset is not None and offset / vmax > 0.12
    if show_threshold:
        distances = [abs(abs(tick) - offset) for tick in ticks]
        n_ties = distances.count(min(distances))
        if n_ties == 4:
            # The threshold is equally far from 4 ticks: of those four,
            # relabel the two middle ones.
            nearest_four = np.sort(np.argpartition(distances, 4)[:4])
            inner_pair = np.sort(ticks[nearest_four])[1:3]
            targets = np.isin(ticks, inner_pair)
        else:
            # Take the 2 closest ticks ...
            targets = np.sort(np.argpartition(distances, 2)[:2])
            if 0 in ticks[targets]:
                # ... unless one of them is the 0 tick: then look at the
                # 3 closest and use the first and last of them.
                nearest_three = np.sort(np.argpartition(distances, 3)[:3])
                targets = nearest_three[[0, 2]]
        ticks[targets] = [-offset, offset]

    # Relabelling may have pushed the first tick below the lower bound.
    if len(ticks) > 0 and ticks[0] < vmin:
        ticks[0] = vmin

    return ticks

93 

94 

def get_colorbar_and_data_ranges(
    stat_map_data,
    vmin=None,
    vmax=None,
    symmetric_cbar=True,
    force_min_stat_map_value=None,
):
    """Set colormap and colorbar limits.

    Used by plot_stat_map, plot_glass_brain and plot_img_on_surf.

    The limits for the colorbar depend on the symmetric_cbar argument. Please
    refer to docstring of plot_stat_map.

    Parameters
    ----------
    stat_map_data : array-like
        Data whose NaN-ignoring min / max serve as fallback limits.
        If it has a ``_mask`` attribute, only unmasked entries are used.

    vmin, vmax : number or None
        Requested limits. Values that are not finite numbers are
        treated as None.

    symmetric_cbar : bool or "auto", default=True
        With "auto", symmetry is decided from the limits when both are
        given, otherwise from whether the (clipped) data range
        straddles 0.

    force_min_stat_map_value : number or None
        When not None, used in place of the data minimum.

    Returns
    -------
    cbar_vmin, cbar_vmax, vmin, vmax
        The colorbar limits (either may be None in the non-symmetric
        case) followed by the data limits converted to ``float``.

    Raises
    ------
    ValueError
        If the colorbar is symmetric and both limits are given but
        ``vmin`` is not close to ``-vmax``.
    """
    # Discard limits that are not finite numbers.
    if not (isinstance(vmin, Number) and np.isfinite(vmin)):
        vmin = None
    if not (isinstance(vmax, Number) and np.isfinite(vmax)):
        vmax = None

    # Masked arrays: continue with the unmasked entries as a plain array.
    if hasattr(stat_map_data, "_mask"):
        unmasked = np.logical_not(stat_map_data._mask)
        stat_map_data = np.asarray(stat_map_data[unmasked])

    data_min = (
        np.nanmin(stat_map_data)
        if force_min_stat_map_value is None
        else force_min_stat_map_value
    )
    data_max = np.nanmax(stat_map_data)

    # Resolve "auto" into an actual truth value.
    if symmetric_cbar == "auto":
        if vmin is not None and vmax is not None:
            symmetric_cbar = np.isclose(vmin, -vmax)
        else:
            lowest = data_min if vmin is None else max(vmin, data_min)
            highest = data_max if vmax is None else min(data_max, vmax)
            symmetric_cbar = lowest < 0 < highest

    if symmetric_cbar:
        # Fill in whichever limit is missing by mirroring the other one
        # and reject limits that are both given but not mirrored.
        if vmin is None and vmax is None:
            vmax = max(-data_min, data_max)
            vmin = -vmax
        elif vmin is None:
            vmin = -vmax
        elif vmax is None:
            vmax = -vmin
        elif not np.isclose(vmin, -vmax):
            raise ValueError(
                "vmin must be equal to -vmax unless symmetric_cbar is False."
            )
        cbar_vmin, cbar_vmax = vmin, vmax
    elif data_min >= 0:
        # Non-negative data: the colorbar starts at 0 unless vmin is set.
        cbar_vmin = 0 if vmin is None else vmin
        cbar_vmax = vmax
    elif data_max <= 0:
        # Non-positive data: the colorbar ends at 0 unless vmax is set.
        cbar_vmin = vmin
        cbar_vmax = 0 if vmax is None else vmax
    else:
        # Mixed-sign data: limit the colorbar to the plotted values.
        cbar_vmin, cbar_vmax = vmin, vmax

    # Limits still unset fall back to the data range.
    if vmin is None:
        vmin = data_min
    if vmax is None:
        vmax = data_max

    return cbar_vmin, cbar_vmax, float(vmin), float(vmax)

176 

177 

def check_threshold_not_negative(threshold):
    """Make sure threshold is non negative number.

    If threshold == "auto", it may be set to very small value.
    So we allow for that: values down to -1e-5 are tolerated.

    Parameters
    ----------
    threshold : object
        Value to validate. Only real numbers are checked (Python
        ``int`` / ``float`` as well as NumPy integer and floating
        scalars); anything else, e.g. a string or None, is accepted
        as is.

    Raises
    ------
    ValueError
        If ``threshold`` is a real number below -1e-5.
    """
    # numbers.Real also matches NumPy scalars such as np.float32 or
    # np.int64, which are not subclasses of the builtin int / float and
    # would otherwise slip through the check.
    from numbers import Real

    if isinstance(threshold, Real) and threshold < -1e-5:
        raise ValueError("Threshold should be a non-negative number!")

186 

187 

def create_colormap_from_lut(cmap, default_cmap="gist_ncar"):
    """
    Build a Matplotlib colormap from a look-up table DataFrame.

    Parameters
    ----------
    cmap : pd.DataFrame
        DataFrame with columns 'index', 'name', and 'color' (hex values)

    default_cmap : default="gist_ncar"
        Value returned as is when ``cmap`` has no 'color' column.

    Returns
    -------
    colormap (LinearSegmentedColormap): A Matplotlib colormap built from
    the 'color' column ordered by 'index', with one level per color.
    If ``cmap`` has no 'color' column, a warning is issued and
    ``default_cmap`` is returned instead.
    """
    if "color" in cmap.columns:
        # Order the rows by 'index' so colors line up with label order.
        ordered_lut = cmap.sort_values(by="index")
        palette = ordered_lut["color"].tolist()
        return LinearSegmentedColormap.from_list(
            "custom_colormap", palette, N=len(palette)
        )

    # No color information available: warn and fall back to the default.
    warn(
        "No 'color' column found in the look-up table. "
        "Will use the default colormap instead.",
        stacklevel=find_stack_level(),
    )
    return default_cmap