Coverage for nilearn/plotting/matrix/_matplotlib_backend.py: 0%

92 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1import matplotlib.pyplot as plt 

2import numpy as np 

3 

4from nilearn._utils import constrained_layout_kwargs 

5 

6VALID_TRI_VALUES = ("full", "lower", "diag") 

7 

8 

def _configure_axis(
    axes, labels, label_size, x_label_rotation, y_label_rotation
):
    """Help for plot_matrix.

    Set up the tick labels of both axes. When ``labels`` is empty or None
    a null formatter is installed on both axes. Otherwise one tick is
    placed per label, and the tick labels are sized, aligned and rotated.
    """
    if not labels:
        # Nothing to display: install a null formatter on both axes.
        for axis in (axes.xaxis, axes.yaxis):
            axis.set_major_formatter(plt.NullFormatter())
        return

    n_labels = len(labels)

    # Horizontal axis: one tick per label, right-aligned and rotated.
    axes.set_xticks(np.arange(n_labels))
    axes.set_xticklabels(labels, size=label_size)
    for tick_text in axes.get_xticklabels():
        tick_text.set_ha("right")
        tick_text.set_rotation(x_label_rotation)

    # Vertical axis: same labels, additionally anchored at their top.
    axes.set_yticks(np.arange(n_labels))
    axes.set_yticklabels(labels, size=label_size)
    for tick_text in axes.get_yticklabels():
        tick_text.set_ha("right")
        tick_text.set_va("top")
        tick_text.set_rotation(y_label_rotation)

28 

29 

def _configure_grid(axes, tri, size):
    """Help for plot_matrix.

    Draw ``size`` pairs of gray lines (one vertical, one horizontal) on
    ``axes``. The end points of the lines depend on ``tri``: "lower" and
    "diag" each use their own end points, any other value draws lines
    spanning from ``-0.5`` to ``size - 0.5``.
    """
    near_edge = -0.5
    far_edge = size - 0.5
    for index in range(size):
        # Correct for weird mis-sizing
        scaled = 1.001 * index
        upper = scaled + 0.5
        # Different grids for different layouts
        if tri == "lower":
            axes.plot([upper, upper], [far_edge, upper], color="gray")
            axes.plot([upper, near_edge], [upper, upper], color="gray")
        elif tri == "diag":
            lower = scaled - 0.5
            axes.plot([upper, upper], [far_edge, lower], color="gray")
            axes.plot([upper, near_edge], [lower, lower], color="gray")
        else:
            axes.plot([upper, upper], [far_edge, near_edge], color="gray")
            axes.plot([far_edge, near_edge], [upper, upper], color="gray")

51 

52 

def _fit_axes(axes):
    """Help for plot_matrix.

    Move the left and bottom edges of ``axes`` when they sit closer to the
    figure border than 1.1 times the extent of the corresponding axis'
    tight bounding box, so that the labels fit.
    """
    figure = axes.get_figure()
    renderer = figure.canvas.get_renderer()

    def _bbox_in_figure_coords(axis):
        """Return the tight bounding box of ``axis`` in figure coordinates."""
        return axis.get_tightbbox(renderer).transformed(
            axes.figure.transFigure.inverted()
        )

    # Left edge: compare against the width taken by the y axis.
    min_left = 1.1 * _bbox_in_figure_coords(axes.yaxis).width
    if axes.get_position().xmin < min_left:
        # we need to move it over
        shifted = axes.get_position()
        shifted.x0 = min_left  # pad a little
        axes.set_position(shifted)

    # Bottom edge: measured after the horizontal adjustment above.
    min_bottom = 1.1 * _bbox_in_figure_coords(axes.xaxis).height
    if axes.get_position().ymin < min_bottom:
        # we need to move it over
        shifted = axes.get_position()
        shifted.y0 = min_bottom  # pad a little
        axes.set_position(shifted)

82 

83 

def _sanitize_figure_and_axes(figure, axes):
    """Help for plot_matrix.

    Resolve the figure and axes to draw on.

    Parameters
    ----------
    figure : Figure instance, figure size, or None
        Either an existing ``plt.Figure`` to draw in, or a value that is
        forwarded as ``figsize`` to ``plt.figure`` to create a new one.
    axes : axes object or None
        Existing axes to draw in. Cannot be given together with ``figure``.

    Returns
    -------
    fig, axes, own_fig
        The figure, the axes, and a flag that is False only when the
        caller supplied ``axes``.

    Raises
    ------
    ValueError
        If both ``figure`` and ``axes`` are given.
    """
    if axes is not None and figure is not None:
        raise ValueError(
            "Parameters figure and axes cannot be specified together. "
            f"You gave 'figure={figure}, axes={axes}'."
        )
    if figure is not None:
        if isinstance(figure, plt.Figure):
            fig = figure
            # Guarded because not every matplotlib exposes this method;
            # the original note says the guard can be removed eventually.
            if hasattr(fig, "set_layout_engine"):
                fig.set_layout_engine("constrained")
        else:
            fig = plt.figure(figsize=figure, **constrained_layout_kwargs())
        # Ask the resolved figure itself for its axes: plt.gca() would use
        # pyplot's current figure, which is not necessarily the figure the
        # caller passed in.
        axes = fig.gca()
        own_fig = True
    elif axes is None:
        fig, axes = plt.subplots(
            1,
            1,
            figsize=(7, 5),
            **constrained_layout_kwargs(),
        )
        own_fig = True
    else:
        fig = axes.figure
        own_fig = False
    return fig, axes, own_fig

112 

113 

def _sanitize_labels(mat_shape, labels):
    """Help for plot_matrix.

    Return ``labels`` (converted to a list when given as a numpy array)
    after checking that a non-empty set of labels has as many entries as
    ``mat_shape[0]``. Raise ValueError otherwise.
    """
    # we need a list so an empty one will be cast to False
    checked = labels.tolist() if isinstance(labels, np.ndarray) else labels
    if not checked:
        return checked

    n_labels, n_rows = len(checked), mat_shape[0]
    if n_labels == n_rows:
        return checked

    raise ValueError(
        f"Length of labels ({n_labels}) "
        f"unequal to length of matrix ({n_rows})."
    )

125 

126 

def _sanitize_inputs_plot_matrix(
    mat_shape, tri, labels, reorder, figure, axes
):
    """Help for plot_matrix.

    Run every input check used by plot_matrix and return, in order, the
    sanitized labels, the sanitized reorder value, and the figure, axes and
    own_fig flag produced by _sanitize_figure_and_axes.
    """
    _sanitize_tri(tri)
    checked_labels = _sanitize_labels(mat_shape, labels)
    checked_reorder = _sanitize_reorder(reorder)
    return (
        checked_labels,
        checked_reorder,
        *_sanitize_figure_and_axes(figure, axes),
    )

139 

140 

def _sanitize_reorder(reorder):
    """Help for plot_matrix.

    Validate ``reorder`` and normalize it.

    Parameters
    ----------
    reorder : bool or str
        One of True, False, "single", "complete" or "average".

    Returns
    -------
    str or False
        The given string unchanged, "average" when ``reorder`` is true,
        False when it is false.

    Raises
    ------
    ValueError
        If ``reorder`` is not one of the accepted values.
    """
    valid_reorder_args = (True, False, "single", "complete", "average")
    if reorder not in valid_reorder_args:
        param_to_print = [
            f'"{item}"' if isinstance(item, str) else str(item)
            for item in valid_reorder_args
        ]
        raise ValueError(
            "Parameter reorder needs to be one of:"
            f"\n{', '.join(param_to_print)}."
        )
    if isinstance(reorder, str):
        return reorder
    # The membership test above compares with ==, so values that merely
    # equal True/False (numpy booleans, 1, 0) reach this point as well.
    # Normalize them by truthiness instead of identity so they are handled
    # exactly like the builtin booleans.
    return "average" if reorder else False

157 

158 

def _sanitize_tri(tri, allowed_values=None):
    """Help for plot_matrix.

    Raise ValueError when ``tri`` is not among ``allowed_values``. When
    ``allowed_values`` is None, VALID_TRI_VALUES is used.
    """
    choices = VALID_TRI_VALUES if allowed_values is None else allowed_values
    if tri in choices:
        return
    raise ValueError(
        f"Parameter tri needs to be one of: {', '.join(choices)}."
    )