Coverage for nilearn/plotting/matrix/tests/test_matrix_plotting.py: 0%
156 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
1import matplotlib as mpl
2import matplotlib.pyplot as plt
3import numpy as np
4import pandas as pd
5import pytest
7from nilearn._utils import constrained_layout_kwargs
8from nilearn.glm.first_level.design_matrix import (
9 make_first_level_design_matrix,
10)
11from nilearn.glm.tests._testing import block_paradigm, modulated_event_paradigm
12from nilearn.plotting.matrix._utils import VALID_TRI_VALUES
13from nilearn.plotting.matrix.matrix_plotting import (
14 _sanitize_figure_and_axes,
15 plot_contrast_matrix,
16 plot_design_matrix,
17 plot_design_matrix_correlation,
18 plot_event,
19 plot_matrix,
20)
@pytest.fixture
def mat():
    """Provide a fresh 10 x 10 array filled with zeros."""
    side = 10
    return np.zeros((side, side))
@pytest.fixture
def labels():
    """Provide the ten string labels "0" through "9"."""
    return list(map(str, range(10)))
33##############################################################################
34# Some smoke testing for graphics-related code
@pytest.mark.parametrize(
    ("fig", "axes"),
    [("foo", "bar"), (1, 2), plt.subplots(1, 1, figsize=(7, 5))],
)
def test_sanitize_figure_and_axes_error(fig, axes):
    """Check that passing both a figure and axes raises a ValueError."""
    expected_message = (
        "Parameters figure and axes cannot be specified together."
    )
    with pytest.raises(ValueError, match=expected_message):
        _sanitize_figure_and_axes(fig, axes)
@pytest.mark.parametrize(
    ("fig", "axes", "expected"),
    [
        ((6, 4), None, True),
        (plt.figure(figsize=(3, 2)), None, True),
        (None, None, True),
        (None, plt.subplots(1, 1)[1], False),
    ],
)
def test_sanitize_figure_and_axes(fig, axes, expected):
    """Check the figure, axes and flag returned for valid inputs.

    The third returned value is compared against ``expected``.
    """
    result = _sanitize_figure_and_axes(fig, axes)
    out_fig, out_axes, own_fig = result

    assert isinstance(out_fig, plt.Figure)
    assert isinstance(out_axes, plt.Axes)
    assert own_fig == expected
@pytest.mark.parametrize(
    ("matrix", "labels", "reorder"),
    [
        (np.zeros((10, 10)), [0, 1, 2], False),
        (np.zeros((10, 10)), None, True),
        (np.zeros((10, 10)), [str(i) for i in range(10)], " "),
    ],
)
def test_matrix_plotting_errors(matrix, labels, reorder):
    """Check that plot_matrix raises ValueError for each invalid input."""
    bad_kwargs = {"labels": labels, "reorder": reorder}
    with pytest.raises(ValueError):
        plot_matrix(matrix, **bad_kwargs)
@pytest.mark.parametrize("tri", VALID_TRI_VALUES)
def test_matrix_plotting_with_labels_and_different_tri(mat, labels, tri):
    """Check the title and tick labels of plot_matrix for each tri value."""
    display = plot_matrix(mat, labels=labels, tri=tri)

    assert isinstance(display, mpl.image.AxesImage)

    plot_axes = display.axes
    plot_axes.set_title("Title")
    assert plot_axes.get_title() == "Title"

    n_labels = len(labels)
    for axis in (plot_axes.xaxis, plot_axes.yaxis):
        ticks = axis.majorTicks
        assert len(ticks) == n_labels
        # Every major tick must carry the label given at the same position.
        for tick, expected_text in zip(ticks, labels):
            assert tick.label1.get_text() == expected_text
@pytest.mark.parametrize("title", ["foo", "foo bar", " ", None])
def test_matrix_plotting_set_title(mat, labels, title):
    """Check that the title given to plot_matrix ends up on the axes."""
    display = plot_matrix(mat, labels=labels, title=title)

    shown_title = display.axes.title.get_text()

    if title is None:
        # No title requested: the axes title text must be empty.
        assert len(shown_title) == 0
    else:
        assert len(shown_title) == len(title)
        assert shown_title == title
def test_matrix_plotting_reorder(mat, labels):
    """Check that reordering groups a block of similar entries together."""
    from itertools import permutations

    # Set the off-diagonal entries among indices 2, 3 and 5 to 1 so the
    # matrix holds a symmetric block the default linkage can recover.
    block = [2, 3, 5]
    for row, col in permutations(block, 2):
        mat[row, col] = 1

    display = plot_matrix(mat, labels=labels, reorder=True)

    tick_labels = display.axes.get_xticklabels()
    assert len(labels) == len(tick_labels)

    new_order = [int(tick.get_text()) for tick in tick_labels]
    # The block may land at either end of the axis.
    block_first = new_order[:3] == block
    block_last = new_order[-3:] == block
    assert block_first or block_last, (
        "Clustering does not find block structure."
    )

    plt.close()

    # Reordering with an explicitly named linkage must also run.
    plot_matrix(mat, labels=labels, reorder="complete")
def test_show_design_matrix(tmp_path):
    """Check that plot_design_matrix writes PNG and PDF output files."""
    frame_times = np.linspace(0, 127.0, 128)
    design = make_first_level_design_matrix(
        frame_times, drift_model="polynomial", drift_order=3
    )
    png_file = tmp_path / "dmtx.png"
    pdf_file = tmp_path / "dmtx.pdf"

    returned = plot_design_matrix(design, output_file=png_file)

    assert png_file.exists()
    # Nothing is returned when the plot is saved to a file.
    assert returned is None

    plot_design_matrix(design, output_file=pdf_file)

    assert pdf_file.exists()
@pytest.mark.parametrize(("suffix", "sep"), [(".csv", ","), (".tsv", "\t")])
def test_plot_design_matrix_path_str(tmp_path, suffix, sep):
    """Check plot_design_matrix with a file given as a Path and as a str."""
    frame_times = np.linspace(0, 127.0, 128)
    design = make_first_level_design_matrix(
        frame_times, drift_model="polynomial", drift_order=3
    )
    filename = tmp_path / f"tmp{suffix}"
    design.to_csv(filename, sep=sep, index=False)

    # Same file, first as a Path object and then as a plain string.
    for target in (filename, str(filename)):
        returned = plot_design_matrix(target)

        assert returned is not None
def test_show_event_plot(tmp_path):
    """Test plot_event with a dataframe, a list of dataframes and saving.

    The events table holds twenty events with a duration of 0.5 spread
    over ten conditions ("a" to "j"), each condition occurring twice.
    Three of the events get a modulation of 0.5, the others 1.
    """
    onset = np.linspace(0, 19.0, 20)
    duration = np.full(20, 0.5)
    trial_idx = np.arange(20)

    # Map events 10..19 back onto conditions 0..9.
    trial_idx[10:] -= 10
    condition_ids = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]

    # add some modulation
    # The fill value must be a float: np.full(20, 1) would create an
    # integer array and the 0.5 assigned below would be truncated to 0.
    modulation = np.full(20, 1.0)
    modulation[[1, 5, 15]] = 0.5

    trial_type = np.array([condition_ids[i] for i in trial_idx])

    model_event = pd.DataFrame(
        {
            "onset": onset,
            "duration": duration,
            "trial_type": trial_type,
            "modulation": modulation,
        }
    )
    # Test Dataframe
    fig = plot_event(model_event)

    assert fig is not None

    # Test List
    fig = plot_event([model_event, model_event])

    assert fig is not None

    # Test save
    fig = plot_event(model_event, output_file=tmp_path / "event.png")

    assert (tmp_path / "event.png").exists()
    # Nothing is returned when the plot is saved to a file.
    assert fig is None

    plot_event(model_event, output_file=tmp_path / "event.pdf")

    assert (tmp_path / "event.pdf").exists()
def test_plot_event_error():
    """Test plot_event error when the colormap has too few colors.

    The events table holds eleven distinct trial types and is plotted
    with the "tab10" colormap; a ValueError is expected.
    """
    onset = np.linspace(0, 19.0, 20)
    duration = np.full(20, 0.5)
    trial_idx = np.arange(20)
    # This makes 11 event types in order to test cmap error
    trial_idx[11:] -= 10
    condition_ids = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"]

    # add some modulation
    # The fill value must be a float: np.full(20, 1) would create an
    # integer array and the 0.5 assigned below would be truncated to 0.
    modulation = np.full(20, 1.0)
    modulation[[1, 5, 15]] = 0.5

    trial_type = np.array([condition_ids[i] for i in trial_idx])

    model_event = pd.DataFrame(
        {
            "onset": onset,
            "duration": duration,
            "trial_type": trial_type,
            "modulation": modulation,
        }
    )

    with pytest.raises(
        ValueError,
        match="The number of event types is greater than colors in colormap",
    ):
        plot_event(model_event, cmap="tab10")
@pytest.mark.parametrize(("suffix", "sep"), [(".csv", ","), (".tsv", "\t")])
def test_plot_event_path_tsv_csv(tmp_path, suffix, sep):
    """Check plot_event with events stored in a CSV or TSV file."""
    events = block_paradigm()
    events_file = tmp_path / f"tmp{suffix}"
    events.to_csv(events_file, sep=sep, index=False)

    # A single Path first, then a list mixing a Path and a str.
    plot_event(events_file)
    plot_event([events_file, str(events_file)])
def test_show_contrast_matrix(tmp_path):
    """Check that plot_contrast_matrix writes PNG and PDF output files."""
    frame_times = np.linspace(0, 127.0, 128)
    design = make_first_level_design_matrix(
        frame_times, drift_model="polynomial", drift_order=3
    )
    contrast = np.ones(4)
    png_file = tmp_path / "contrast.png"
    pdf_file = tmp_path / "contrast.pdf"

    returned = plot_contrast_matrix(contrast, design, output_file=png_file)
    assert png_file.exists()

    # Nothing is returned when the plot is saved to a file.
    assert returned is None

    plot_contrast_matrix(contrast, design, output_file=pdf_file)

    assert pdf_file.exists()
def test_show_contrast_matrix_axes():
    """Test passing axes to plot_contrast_matrix."""
    frame_times = np.linspace(0, 127.0, 128)
    design = make_first_level_design_matrix(
        frame_times, drift_model="polynomial", drift_order=3
    )
    contrast = np.ones(4)
    layout_kwargs = constrained_layout_kwargs()
    fig, ax = plt.subplots(**layout_kwargs)

    plot_contrast_matrix(contrast, design, axes=ax)

    # The layout-engine check needs get_layout_engine; the call above
    # already exercises the kwargs on versions where it is unavailable,
    # so the skip is placed only here.
    pytest.importorskip("matplotlib", minversion="3.5.0")
    engine_name = fig.get_layout_engine().__class__.__name__
    assert "constrained" in engine_name.lower()
@pytest.mark.parametrize("cmap", ["RdBu_r", "bwr", "seismic_r"])
def test_plot_design_matrix_correlation(cmap, tmp_path):
    """Check that each accepted cmap produces the requested output file."""
    frame_times = np.linspace(0, 127.0, 128)
    events = modulated_event_paradigm()
    design = make_first_level_design_matrix(frame_times, events=events)
    output_file = tmp_path / "corr_mat.png"

    plot_design_matrix_correlation(
        design, cmap=cmap, output_file=output_file
    )

    assert output_file.exists()
def test_plot_design_matrix_correlation_smoke_path(tmp_path):
    """Check plot_design_matrix_correlation with a Path and a str file."""
    frame_times = np.linspace(0, 127.0, 128)
    events = modulated_event_paradigm()
    design = make_first_level_design_matrix(frame_times, events=events)

    table_file = tmp_path / "tmp.tsv"
    design.to_csv(table_file, sep="\t", index=False)

    # Same file, first as a Path object and then as a plain string.
    for target in (table_file, str(table_file)):
        plot_design_matrix_correlation(target)
def test_plot_design_matrix_correlation_errors(mat):
    """Check the ValueErrors raised by plot_design_matrix_correlation."""
    # A string that is not a TSV or CSV file name.
    with pytest.raises(
        ValueError, match="Tables to load can only be TSV or CSV."
    ):
        plot_design_matrix_correlation("foo")

    empty_frame = pd.DataFrame()
    with pytest.raises(ValueError, match="dataframe cannot be empty."):
        plot_design_matrix_correlation(empty_frame)

    # An unsupported colormap name.
    zeros_frame = pd.DataFrame(mat)
    with pytest.raises(ValueError, match="cmap must be one of"):
        plot_design_matrix_correlation(zeros_frame, cmap="foo")

    # An unsupported value for tri.
    three_columns = pd.DataFrame(
        {"event_1": [0, 1], "constant": [1, 1], "drift_1": [0, 1]}
    )
    with pytest.raises(ValueError, match="tri needs to be one of"):
        plot_design_matrix_correlation(three_columns, tri="lower")

    # Only "constant" and "drift_1" columns; presumably these are dropped
    # before plotting, leaving nothing to show.
    no_event_columns = pd.DataFrame({"constant": [1, 1], "drift_1": [0, 1]})
    with pytest.raises(ValueError, match="Nothing left to plot after "):
        plot_design_matrix_correlation(no_event_columns)