Coverage for nilearn/plotting/matrix/tests/test_matrix_plotting.py: 0%

156 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1import matplotlib as mpl 

2import matplotlib.pyplot as plt 

3import numpy as np 

4import pandas as pd 

5import pytest 

6 

7from nilearn._utils import constrained_layout_kwargs 

8from nilearn.glm.first_level.design_matrix import ( 

9 make_first_level_design_matrix, 

10) 

11from nilearn.glm.tests._testing import block_paradigm, modulated_event_paradigm 

12from nilearn.plotting.matrix._utils import VALID_TRI_VALUES 

13from nilearn.plotting.matrix.matrix_plotting import ( 

14 _sanitize_figure_and_axes, 

15 plot_contrast_matrix, 

16 plot_design_matrix, 

17 plot_design_matrix_correlation, 

18 plot_event, 

19 plot_matrix, 

20) 

21 

22 

@pytest.fixture
def mat():
    """Provide a square 10 x 10 matrix filled with zeros."""
    size = 10
    return np.zeros((size, size))

26 

27 

@pytest.fixture
def labels():
    """Provide the string labels "0" to "9", one per matrix row/column."""
    return list(map(str, range(10)))

31 

32 

33############################################################################## 

34# Some smoke testing for graphics-related code 

35 

36 

@pytest.mark.parametrize(
    "fig,axes",
    [
        ("foo", "bar"),
        (1, 2),
        plt.subplots(1, 1, figsize=(7, 5)),
    ],
)
def test_sanitize_figure_and_axes_error(fig, axes):
    """Check that giving both a figure and axes is rejected."""
    expected_msg = "Parameters figure and axes cannot be specified together."
    with pytest.raises(ValueError, match=expected_msg):
        _sanitize_figure_and_axes(fig, axes)

46 

47 

@pytest.mark.parametrize(
    "fig,axes,expected",
    [
        ((6, 4), None, True),
        (plt.figure(figsize=(3, 2)), None, True),
        (None, None, True),
        (None, plt.subplots(1, 1)[1], False),
    ],
)
def test_sanitize_figure_and_axes(fig, axes, expected):
    """Check returned figure/axes types and the figure-ownership flag."""
    result_fig, result_axes, owns_figure = _sanitize_figure_and_axes(
        fig, axes
    )

    assert isinstance(result_fig, plt.Figure)
    assert isinstance(result_axes, plt.Axes)
    assert owns_figure == expected

62 

63 

@pytest.mark.parametrize(
    "matrix, labels, reorder",
    [
        # 3 labels for a 10 x 10 matrix
        (np.zeros((10, 10)), [0, 1, 2], False),
        # reordering requested without labels
        (np.zeros((10, 10)), None, True),
        # reordering requested with a blank linkage name
        (np.zeros((10, 10)), [str(i) for i in range(10)], " "),
    ],
)
def test_matrix_plotting_errors(matrix, labels, reorder):
    """Test invalid input values for plot_matrix."""
    with pytest.raises(ValueError):
        plot_matrix(matrix, reorder=reorder, labels=labels)

76 

77 

@pytest.mark.parametrize("tri", VALID_TRI_VALUES)
def test_matrix_plotting_with_labels_and_different_tri(mat, labels, tri):
    """Test plot_matrix tick labels when only part of the matrix is drawn.

    Every valid ``tri`` value must still label all rows and columns.
    """
    display = plot_matrix(mat, labels=labels, tri=tri)

    assert isinstance(display, mpl.image.AxesImage)

    # Use the public ``Artist.axes`` property instead of matplotlib's
    # private ``_axes`` attribute.
    axes = display.axes
    axes.set_title("Title")
    assert axes.get_title() == "Title"

    for axis in (axes.xaxis, axes.yaxis):
        ticks = axis.majorTicks
        assert len(ticks) == len(labels)
        for tick, label in zip(ticks, labels):
            assert tick.label1.get_text() == label

90 

91 

@pytest.mark.parametrize("title", ["foo", "foo bar", " ", None])
def test_matrix_plotting_set_title(mat, labels, title):
    """Test setting title with plot_matrix.

    A ``None`` title must leave the axes title empty; any string title must
    be displayed verbatim.
    """
    display = plot_matrix(mat, labels=labels, title=title)

    expected = "" if title is None else title
    # public accessor for the (centered) axes title text
    assert display.axes.get_title() == expected

102 

103 

def test_matrix_plotting_reorder(mat, labels):
    """Test that plot_matrix reordering groups a block of similar items."""
    from itertools import permutations

    # Make a symmetric similarity matrix containing a single block formed by
    # the rows/columns in ``idx``; all off-diagonal in-block values are 1.
    idx = [2, 3, 5]
    for perm in permutations(idx, 2):
        mat[perm] = 1

    # test if reordering with default linkage works
    display = plot_matrix(mat, labels=labels, reorder=True)

    tick_labels = display.axes.get_xticklabels()
    assert len(tick_labels) == len(labels)
    reordered_labels = [int(lbl.get_text()) for lbl in tick_labels]
    # Neither the position of the block nor the order of the items inside
    # it matters, since every in-block similarity is identical.
    assert (
        sorted(reordered_labels[:3]) == idx
        or sorted(reordered_labels[-3:]) == idx
    ), "Clustering does not find block structure."

    plt.close("all")

    # test if reordering with a specific linkage works: the tick labels must
    # still be a permutation of the input labels
    display = plot_matrix(mat, labels=labels, reorder="complete")

    reordered_labels = sorted(
        int(lbl.get_text()) for lbl in display.axes.get_xticklabels()
    )
    assert reordered_labels == sorted(int(lbl) for lbl in labels)

128 

129 

def test_show_design_matrix(tmp_path):
    """Test plot_design_matrix saving to file."""
    frame_times = np.linspace(0.0, 127.0, 128)
    design = make_first_level_design_matrix(
        frame_times, drift_model="polynomial", drift_order=3
    )

    png_file = tmp_path / "dmtx.png"
    returned = plot_design_matrix(design, output_file=png_file)
    assert png_file.exists()
    # nothing is returned when the figure is written to disk
    assert returned is None

    pdf_file = tmp_path / "dmtx.pdf"
    plot_design_matrix(design, output_file=pdf_file)
    assert pdf_file.exists()

145 

146 

@pytest.mark.parametrize("suffix, sep", [(".csv", ","), (".tsv", "\t")])
def test_plot_design_matrix_path_str(tmp_path, suffix, sep):
    """Test plot_design_matrix directly from file."""
    frame_times = np.linspace(0.0, 127.0, 128)
    design = make_first_level_design_matrix(
        frame_times, drift_model="polynomial", drift_order=3
    )
    table_path = (tmp_path / "tmp").with_suffix(suffix)
    design.to_csv(table_path, sep=sep, index=False)

    # both a pathlib.Path and its string form must be accepted
    for source in (table_path, str(table_path)):
        assert plot_design_matrix(source) is not None

164 

165 

def test_show_event_plot(tmp_path):
    """Test plot_event with a DataFrame, a list of DataFrames and file output."""
    n_events = 20
    onset = np.linspace(0, 19.0, n_events)
    duration = np.full(n_events, 0.5)
    trial_idx = np.arange(n_events)

    # 10 conditions, each occurring twice
    trial_idx[10:] -= 10
    condition_ids = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]

    # Add some modulation. The array must be float: with an integer array
    # (np.full(n, 1)) numpy would silently truncate the 0.5 values to 0.
    modulation = np.full(n_events, 1.0)
    modulation[[1, 5, 15]] = 0.5

    trial_type = np.array([condition_ids[i] for i in trial_idx])

    model_event = pd.DataFrame(
        {
            "onset": onset,
            "duration": duration,
            "trial_type": trial_type,
            "modulation": modulation,
        }
    )

    # Test DataFrame
    fig = plot_event(model_event)

    assert fig is not None

    # Test list of DataFrames
    fig = plot_event([model_event, model_event])

    assert fig is not None

    # Test save: nothing is returned when writing to a file
    fig = plot_event(model_event, output_file=tmp_path / "event.png")

    assert (tmp_path / "event.png").exists()
    assert fig is None

    plot_event(model_event, output_file=tmp_path / "event.pdf")

    assert (tmp_path / "event.pdf").exists()

208 

209 

def test_plot_event_error():
    """Test plot_event error when event types outnumber colormap colors."""
    n_events = 20
    onset = np.linspace(0, 19.0, n_events)
    duration = np.full(n_events, 0.5)
    trial_idx = np.arange(n_events)
    # Indices become 0-10 then 1-9, i.e. 11 distinct event types: one more
    # than the 10 colors of the "tab10" colormap, to trigger the cmap error.
    trial_idx[11:] -= 10
    condition_ids = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"]

    # Add some modulation. The array must be float: with an integer array
    # (np.full(n, 1)) numpy would silently truncate the 0.5 values to 0.
    modulation = np.full(n_events, 1.0)
    modulation[[1, 5, 15]] = 0.5

    trial_type = np.array([condition_ids[i] for i in trial_idx])

    model_event = pd.DataFrame(
        {
            "onset": onset,
            "duration": duration,
            "trial_type": trial_type,
            "modulation": modulation,
        }
    )

    with pytest.raises(
        ValueError,
        match="The number of event types is greater than colors in colormap",
    ):
        plot_event(model_event, cmap="tab10")

239 

240 

@pytest.mark.parametrize("suffix, sep", [(".csv", ","), (".tsv", "\t")])
def test_plot_event_path_tsv_csv(tmp_path, suffix, sep):
    """Test plot_event directly from file, as a path, a str, or a list."""
    model_event = block_paradigm()
    filename = (tmp_path / "tmp").with_suffix(suffix)
    model_event.to_csv(filename, sep=sep, index=False)

    # Without output_file a figure is returned, as for DataFrame inputs.
    assert plot_event(filename) is not None
    assert plot_event([filename, str(filename)]) is not None

250 

251 

def test_show_contrast_matrix(tmp_path):
    """Check that plot_contrast_matrix can write PNG and PDF files."""
    frame_times = np.linspace(0.0, 127.0, 128)
    design = make_first_level_design_matrix(
        frame_times, drift_model="polynomial", drift_order=3
    )
    contrast = np.ones(4)

    png_file = tmp_path / "contrast.png"
    returned = plot_contrast_matrix(contrast, design, output_file=png_file)
    assert png_file.exists()
    # nothing is returned when the figure is written to disk
    assert returned is None

    pdf_file = tmp_path / "contrast.pdf"
    plot_contrast_matrix(contrast, design, output_file=pdf_file)
    assert pdf_file.exists()

270 

271 

def test_show_contrast_matrix_axes():
    """Test passing axes to plot_contrast_matrix."""
    frame_times = np.linspace(0, 127 * 1.0, 128)
    dmtx = make_first_level_design_matrix(
        frame_times, drift_model="polynomial", drift_order=3
    )
    contrast = np.ones(4)
    fig, ax = plt.subplots(**constrained_layout_kwargs())

    plot_contrast_matrix(contrast, dmtx, axes=ax)

    # The call above already checks that the layout kwargs are accepted.
    # Verifying the figure's layout engine needs Figure.get_layout_engine,
    # which only exists from matplotlib 3.6 onwards.
    pytest.importorskip("matplotlib", minversion="3.6.0")
    assert "constrained" in fig.get_layout_engine().__class__.__name__.lower()

287 

288 

@pytest.mark.parametrize("cmap", ["RdBu_r", "bwr", "seismic_r"])
def test_plot_design_matrix_correlation(cmap, tmp_path):
    """Smoke test for valid cmaps and output file."""
    design = make_first_level_design_matrix(
        np.linspace(0.0, 127.0, 128), events=modulated_event_paradigm()
    )
    output_file = tmp_path / "corr_mat.png"

    plot_design_matrix_correlation(design, cmap=cmap, output_file=output_file)

    assert output_file.exists()

302 

303 

def test_plot_design_matrix_correlation_smoke_path(tmp_path):
    """Check that plot_design_matrix_correlation works with paths."""
    frame_times = np.linspace(0.0, 127.0, 128)
    design = make_first_level_design_matrix(
        frame_times, events=modulated_event_paradigm()
    )
    tsv_file = tmp_path / "tmp.tsv"
    design.to_csv(tsv_file, sep="\t", index=False)

    # both pathlib.Path and str inputs must be accepted
    for source in (tsv_file, str(tsv_file)):
        plot_design_matrix_correlation(source)

315 

316 

def test_plot_design_matrix_correlation_errors(mat):
    """Test plot_design_matrix_correlation errors."""
    # a string that is not a TSV/CSV file
    with pytest.raises(
        ValueError, match="Tables to load can only be TSV or CSV."
    ):
        plot_design_matrix_correlation("foo")

    # an empty table
    with pytest.raises(ValueError, match="dataframe cannot be empty."):
        plot_design_matrix_correlation(pd.DataFrame())

    # an unsupported colormap name
    with pytest.raises(ValueError, match="cmap must be one of"):
        plot_design_matrix_correlation(pd.DataFrame(mat), cmap="foo")

    # tri="lower" is rejected by this function
    with_event = pd.DataFrame(
        {"event_1": [0, 1], "constant": [1, 1], "drift_1": [0, 1]}
    )
    with pytest.raises(ValueError, match="tri needs to be one of"):
        plot_design_matrix_correlation(with_event, tri="lower")

    # a table holding only "constant" and "drift_1" columns
    without_event = pd.DataFrame({"constant": [1, 1], "drift_1": [0, 1]})
    with pytest.raises(ValueError, match="Nothing left to plot after "):
        plot_design_matrix_correlation(without_event)
338 plot_design_matrix_correlation(dmtx)