Coverage for nilearn/plotting/matrix/_matplotlib_backend.py: 0%
92 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1import matplotlib.pyplot as plt
2import numpy as np
4from nilearn._utils import constrained_layout_kwargs
# Values accepted for the ``tri`` parameter; ``_sanitize_tri`` falls back to
# this tuple when no explicit ``allowed_values`` are given.
VALID_TRI_VALUES = (
    "full",
    "lower",
    "diag",
)
9def _configure_axis(
10 axes, labels, label_size, x_label_rotation, y_label_rotation
11):
12 """Help for plot_matrix."""
13 if not labels:
14 axes.xaxis.set_major_formatter(plt.NullFormatter())
15 axes.yaxis.set_major_formatter(plt.NullFormatter())
16 else:
17 axes.set_xticks(np.arange(len(labels)))
18 axes.set_xticklabels(labels, size=label_size)
19 for label in axes.get_xticklabels():
20 label.set_ha("right")
21 label.set_rotation(x_label_rotation)
22 axes.set_yticks(np.arange(len(labels)))
23 axes.set_yticklabels(labels, size=label_size)
24 for label in axes.get_yticklabels():
25 label.set_ha("right")
26 label.set_va("top")
27 label.set_rotation(y_label_rotation)
30def _configure_grid(axes, tri, size):
31 """Help for plot_matrix."""
32 # Different grids for different layouts
33 if tri == "lower":
34 for i in range(size):
35 # Correct for weird mis-sizing
36 i = 1.001 * i
37 axes.plot([i + 0.5, i + 0.5], [size - 0.5, i + 0.5], color="gray")
38 axes.plot([i + 0.5, -0.5], [i + 0.5, i + 0.5], color="gray")
39 elif tri == "diag":
40 for i in range(size):
41 # Correct for weird mis-sizing
42 i = 1.001 * i
43 axes.plot([i + 0.5, i + 0.5], [size - 0.5, i - 0.5], color="gray")
44 axes.plot([i + 0.5, -0.5], [i - 0.5, i - 0.5], color="gray")
45 else:
46 for i in range(size):
47 # Correct for weird mis-sizing
48 i = 1.001 * i
49 axes.plot([i + 0.5, i + 0.5], [size - 0.5, -0.5], color="gray")
50 axes.plot([size - 0.5, -0.5], [i + 0.5, i + 0.5], color="gray")
def _fit_axes(axes):
    """Help for plot_matrix.

    Move the left and bottom edges of ``axes`` (in figure-fraction
    coordinates) when they are too close to the figure border for the
    y- and x-axis labels to fit, keeping a 10% padding.
    """
    renderer = axes.get_figure().canvas.get_renderer()

    def _bbox_in_figure(axis):
        # Tight bounding box of ``axis`` converted from display units to
        # figure-fraction units.
        tight = axis.get_tightbbox(renderer)
        return tight.transformed(axes.figure.transFigure.inverted())

    # Left edge must leave room for 110% of the y-axis label width.
    min_left = 1.1 * _bbox_in_figure(axes.yaxis).width
    if axes.get_position().xmin < min_left:
        shifted = axes.get_position()
        shifted.x0 = min_left
        axes.set_position(shifted)

    # Bottom edge must leave room for 110% of the x-axis label height.
    min_bottom = 1.1 * _bbox_in_figure(axes.xaxis).height
    if axes.get_position().ymin < min_bottom:
        shifted = axes.get_position()
        shifted.y0 = min_bottom
        axes.set_position(shifted)
84def _sanitize_figure_and_axes(figure, axes):
85 """Help for plot_matrix."""
86 if axes is not None and figure is not None:
87 raise ValueError(
88 "Parameters figure and axes cannot be specified together. "
89 f"You gave 'figure={figure}, axes={axes}'."
90 )
91 if figure is not None:
92 if isinstance(figure, plt.Figure):
93 fig = figure
94 if hasattr(fig, "set_layout_engine"): # can be removed w/mpl 3.5
95 fig.set_layout_engine("constrained")
96 else:
97 fig = plt.figure(figsize=figure, **constrained_layout_kwargs())
98 axes = plt.gca()
99 own_fig = True
100 elif axes is None:
101 fig, axes = plt.subplots(
102 1,
103 1,
104 figsize=(7, 5),
105 **constrained_layout_kwargs(),
106 )
107 own_fig = True
108 else:
109 fig = axes.figure
110 own_fig = False
111 return fig, axes, own_fig
114def _sanitize_labels(mat_shape, labels):
115 """Help for plot_matrix."""
116 # we need a list so an empty one will be cast to False
117 if isinstance(labels, np.ndarray):
118 labels = labels.tolist()
119 if labels and len(labels) != mat_shape[0]:
120 raise ValueError(
121 f"Length of labels ({len(labels)}) "
122 f"unequal to length of matrix ({mat_shape[0]})."
123 )
124 return labels
def _sanitize_inputs_plot_matrix(
    mat_shape, tri, labels, reorder, figure, axes
):
    """Help for plot_matrix.

    Validate and normalise every plot_matrix input, in this order: ``tri``,
    ``labels``, ``reorder``, then ``figure``/``axes``. The first invalid
    input raises the corresponding helper's ValueError.

    Returns
    -------
    tuple
        ``(labels, reorder, fig, axes, own_fig)``.
    """
    _sanitize_tri(tri)
    clean_labels = _sanitize_labels(mat_shape, labels)
    clean_reorder = _sanitize_reorder(reorder)
    fig_out, axes_out, owns_figure = _sanitize_figure_and_axes(figure, axes)
    return clean_labels, clean_reorder, fig_out, axes_out, owns_figure
141def _sanitize_reorder(reorder):
142 """Help for plot_matrix."""
143 VALID_REORDER_ARGS = (True, False, "single", "complete", "average")
144 if reorder not in VALID_REORDER_ARGS:
145 param_to_print = []
146 for item in VALID_REORDER_ARGS:
147 if isinstance(item, str):
148 param_to_print.append(f'"{item}"')
149 else:
150 param_to_print.append(str(item))
151 raise ValueError(
152 "Parameter reorder needs to be one of:"
153 f"\n{', '.join(param_to_print)}."
154 )
155 reorder = "average" if reorder is True else reorder
156 return reorder
159def _sanitize_tri(tri, allowed_values=None):
160 """Help for plot_matrix."""
161 if allowed_values is None:
162 allowed_values = VALID_TRI_VALUES
163 if tri not in allowed_values:
164 raise ValueError(
165 f"Parameter tri needs to be one of: {', '.join(allowed_values)}."
166 )