Coverage for nilearn/plotting/tests/test_displays.py: 0%

210 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 

2# vi: set ft=python sts=4 ts=4 sw=4 et: 

3 

4import matplotlib 

5import matplotlib.pyplot as plt 

6import numpy as np 

7import pytest 

8from nibabel import Nifti1Image 

9 

10from nilearn.datasets import load_mni152_template 

11from nilearn.plotting.displays import ( 

12 BaseAxes, 

13 LProjector, 

14 LRProjector, 

15 LYRProjector, 

16 LYRZProjector, 

17 LZRProjector, 

18 LZRYProjector, 

19 MosaicSlicer, 

20 OrthoProjector, 

21 OrthoSlicer, 

22 RProjector, 

23 TiledSlicer, 

24 XProjector, 

25 XSlicer, 

26 XZProjector, 

27 XZSlicer, 

28 YProjector, 

29 YSlicer, 

30 YXProjector, 

31 YXSlicer, 

32 YZProjector, 

33 YZSlicer, 

34 ZProjector, 

35 ZSlicer, 

36) 

37 

# Display-mode names for the slicer classes. Entry i of SLICER_KEYS names the
# class at entry i of SLICERS; the two lists are zipped together by the
# parametrized tests in this module.
SLICER_KEYS = ["ortho", "tiled", "x", "y", "z", "yx", "yz", "mosaic", "xz"]
SLICERS = [
    OrthoSlicer,
    TiledSlicer,
    XSlicer,
    YSlicer,
    ZSlicer,
    YXSlicer,
    YZSlicer,
    MosaicSlicer,
    XZSlicer,
]
# NOTE(review): PROJECTOR_KEYS has 10 entries while PROJECTORS has 14, and the
# two lists are not aligned by position (for instance "lyrz" lines up with
# XProjector and "lr" with LZRYProjector). ``zip(PROJECTORS, PROJECTOR_KEYS)``
# stops at the shorter list, so LZRProjector, LRProjector, LProjector and
# RProjector are never reached by the zipped parametrizations. The
# ``cut_coords`` fixture selects coordinates from these names, so confirm the
# intent before realigning either list.
PROJECTOR_KEYS = [
    "ortho",
    "xz",
    "yz",
    "yx",
    "lyrz",
    "lyr",
    "lzr",
    "lr",
    "l",
    "r",
]
PROJECTORS = [
    OrthoProjector,
    XZProjector,
    YZProjector,
    YXProjector,
    XProjector,
    YProjector,
    ZProjector,
    LZRYProjector,
    LYRZProjector,
    LYRProjector,
    LZRProjector,
    LRProjector,
    LProjector,
    RProjector,
]

78 

79 

def test_base_axes_exceptions():
    """Check the exceptions raised by ``BaseAxes``."""
    direction, coord = "foo", 3
    base_axes = BaseAxes(None, direction, coord)
    # The constructor accepts an invalid direction without complaint
    assert base_axes.direction == direction
    assert base_axes.coord == coord

    transform_msg = "'transform_to_2d' needs to be"
    with pytest.raises(NotImplementedError, match=transform_msg):
        base_axes.transform_to_2d(None, None)

    position_msg = "'draw_position' should be"
    with pytest.raises(NotImplementedError, match=position_msg):
        base_axes.draw_position(None, None)

    direction_msg = "Invalid value for direction"
    with pytest.raises(ValueError, match=direction_msg):
        base_axes.draw_2d(None, None, None)

94 

95 

def test_cut_axes_exception(affine_eye):
    """Check that ``CutAxes.transform_to_2d`` rejects an invalid direction."""
    from nilearn.plotting.displays import CutAxes

    direction, coord = "foo", 2
    cut_axes = CutAxes(None, direction, coord)
    assert cut_axes.direction == direction
    assert cut_axes.coord == coord

    direction_msg = "Invalid value for direction"
    with pytest.raises(ValueError, match=direction_msg):
        cut_axes.transform_to_2d(None, affine_eye)

105 

106 

def test_glass_brain_axes():
    """Exercise marker and line drawing on ``GlassBrainAxes``."""
    from nilearn.plotting.displays import GlassBrainAxes

    glass_axes = GlassBrainAxes(plt.subplot(111), "r", 2)
    glass_axes._add_markers(np.array([[0, 0, 0]]), "g", [10])

    segments = [np.array([[0, 0, 0], [1, 1, 1]])]
    values = np.array([1, 0, 6])

    # A non-positive vmax without vmin is expected to be rejected...
    vmax_msg = "If vmax is set to a non-positive number "
    with pytest.raises(ValueError, match=vmax_msg):
        glass_axes._add_lines(segments, values, None, vmin=None, vmax=-10)
    # ...while a positive one is expected to go through.
    glass_axes._add_lines(segments, values, None, vmin=None, vmax=10)

    # A non-negative vmin without vmax is expected to be rejected...
    vmin_msg = "If vmin is set to a non-negative number "
    with pytest.raises(ValueError, match=vmin_msg):
        glass_axes._add_lines(segments, values, None, vmin=10, vmax=None)
    # ...while a negative vmin, alone or with a negative vmax, is not.
    glass_axes._add_lines(segments, values, None, vmin=-10, vmax=None)
    glass_axes._add_lines(segments, values, None, vmin=-10, vmax=-5)

127 

128 

def test_get_index_from_direction_exception():
    """Check ``_get_index_from_direction`` rejects an invalid direction."""
    from nilearn.plotting.displays._utils import _get_index_from_direction

    expected_msg = "foo is not a valid direction."
    with pytest.raises(ValueError, match=expected_msg):
        _get_index_from_direction("foo")

137 

138 

@pytest.fixture
def img():
    """Provide the MNI152 template, loaded with ``resolution=2``."""
    template = load_mni152_template(resolution=2)
    return template

143 

144 

@pytest.fixture
def cut_coords(name):
    """Return the cut coordinates to use for the display mode ``name``.

    ``"mosaic"`` gets the integer 3; every other mode gets a tuple of
    zeros whose length depends on the mode name.
    """
    if name == "mosaic":
        return 3
    if name in ("yx", "yz", "xz"):
        n_coords = 2
    elif name in ("lyrz", "lyr", "lzr"):
        n_coords = 1
    elif name in ("lr", "l"):
        n_coords = 4
    else:
        n_coords = 3
    return (0,) * n_coords

155 

156 

@pytest.mark.parametrize("display,name", zip(SLICERS, SLICER_KEYS))
def test_display_basics_slicers(display, name, img, cut_coords):
    """Smoke-test every slicer class.

    The slicer is instantiated, then ``add_overlay``, ``title``
    and ``close`` are called in turn.
    """
    slicer = display(cut_coords=cut_coords)
    slicer.add_overlay(img, cmap="gray")
    slicer.title(f"display mode is {name}")
    # The stored cut_coords are not compared for the mosaic mode
    is_mosaic = name == "mosaic"
    if not is_mosaic:
        assert slicer.cut_coords == cut_coords
    assert isinstance(slicer.frame_axes, matplotlib.axes.Axes)
    slicer.close()

171 

172 

@pytest.mark.parametrize("display,name", zip(PROJECTORS, PROJECTOR_KEYS))
def test_display_basics_projectors(display, name, img, cut_coords):
    """Basic smoke tests for all displays (projectors).

    Each object is instantiated, ``add_overlay``, ``title``,
    and ``close`` are then called.

    Parameters
    ----------
    display : class
        Projector class to instantiate, taken from ``PROJECTORS``.
    name : str
        Key paired with ``display`` by ``zip``, taken from
        ``PROJECTOR_KEYS``; it also drives the ``cut_coords`` fixture.
    img : fixture
        Image added as overlay.
    cut_coords : fixture
        Cut coordinates passed to the constructor.
    """
    projector = display(cut_coords=cut_coords)
    projector.add_overlay(img, cmap="gray")
    projector.title(f"display mode is {name}")
    # No "mosaic" key exists in PROJECTOR_KEYS, so the stored cut_coords
    # can be compared unconditionally.
    assert projector.cut_coords == cut_coords
    assert isinstance(projector.frame_axes, matplotlib.axes.Axes)
    projector.close()

187 

188 

@pytest.mark.parametrize(
    "slicer", [XSlicer, YSlicer, ZSlicer, YXSlicer, YZSlicer, XZSlicer]
)
def test_stacked_slicer(slicer, img, tmp_path):
    """Save a figure to file with each stacked slicer."""
    # Single-direction slicers get an integer, two-direction ones a pair
    if slicer in (XSlicer, YSlicer, ZSlicer):
        cut_coords = 3
    else:
        cut_coords = (3, 3)
    display = slicer.init_with_figure(img=img, cut_coords=cut_coords)
    display.add_overlay(img, cmap="gray")
    # Forcing a layout here, to test the locator code
    output_file = tmp_path / "out.png"
    display.savefig(output_file)
    display.close()

200 

201 

@pytest.mark.parametrize("slicer", [OrthoSlicer, TiledSlicer, MosaicSlicer])
def test_slicer_save_to_file(slicer, img, tmp_path):
    """Save a figure to file with the Ortho, Tiled and Mosaic slicers."""
    if slicer == MosaicSlicer:
        cut_coords = None
    else:
        cut_coords = (0, 0, 0)
    display = slicer.init_with_figure(
        img=img, cut_coords=cut_coords, colorbar=True
    )
    display.add_overlay(img, cmap="gray", colorbar=True)

    assert display.brain_color == (0.5, 0.5, 0.5)
    assert not display.black_bg

    # Forcing a layout here, to test the locator code
    output_file = tmp_path / "out.png"
    display.savefig(output_file)
    display.close()

215 

216 

@pytest.mark.parametrize("cut_coords", [2, 4])
def test_mosaic_slicer_integer_cut_coords(cut_coords, img):
    """Check ``MosaicSlicer`` when ``cut_coords`` is given as an integer.

    Each of the x, y and z directions is expected to hold ``cut_coords``
    cuts.
    """
    mosaic = MosaicSlicer.init_with_figure(img=img, cut_coords=cut_coords)
    mosaic.add_overlay(img, cmap="gray", colorbar=True)
    mosaic.title("mosaic mode")
    for direction in "xyz":
        assert direction in mosaic.cut_coords
        assert len(mosaic.cut_coords[direction]) == cut_coords
    mosaic.close()

227 

228 

@pytest.mark.parametrize("cut_coords", [(4, 5, 2), (1, 1, 1)])
def test_mosaic_slicer_tuple_cut_coords(cut_coords, img):
    """Check ``MosaicSlicer`` when ``cut_coords`` is given as a tuple.

    The number of cuts along x, y and z is expected to match the
    corresponding tuple entry.
    """
    mosaic = MosaicSlicer.init_with_figure(img=img, cut_coords=cut_coords)
    mosaic.add_overlay(img, cmap="gray", colorbar=True)
    mosaic.title("Showing mosaic mode")
    for direction, n_cuts in zip("xyz", cut_coords):
        assert len(mosaic.cut_coords[direction]) == n_cuts
    mosaic.close()

238 

239 

@pytest.mark.parametrize("cut_coords", [None, 5, (1, 1, 1)])
def test_mosaic_slicer_img_none_false(cut_coords, img):
    """Check ``MosaicSlicer`` when the figure is initialized with ``img=None``.

    The image is only supplied afterwards, through ``add_overlay``.
    """
    mosaic = MosaicSlicer.init_with_figure(img=None, cut_coords=cut_coords)
    mosaic.add_overlay(img, cmap="gray", colorbar=True)
    mosaic.close()

248 

249 

@pytest.mark.parametrize("cut_coords", [(5, 4), (1, 2, 3, 4)])
def test_mosaic_slicer_wrong_inputs(cut_coords):
    """Tests that providing wrong inputs raises a ``ValueError``.

    The parametrized ``cut_coords`` tuples have length 2 and 4, whereas
    the expected error message asks for a tuple of length 3.
    """
    with pytest.raises(
        ValueError,
        match=(
            "The number cut_coords passed does not "
            "match the display_mode. Mosaic plotting "
            "expects tuple of length 3."
        ),
    ):
        MosaicSlicer.init_with_figure(img=None, cut_coords=cut_coords)
        # NOTE(review): if this call shares the ``pytest.raises`` block with
        # the call above (as laid out here), it is never executed once the
        # first call raises -- confirm whether it needs its own block.
        MosaicSlicer(img=None, cut_coords=cut_coords)

263 

264 

@pytest.fixture
def expected_cuts(cut_coords):
    """Return the cuts expected by ``test_demo_mosaic_slicer``.

    The result depends on the ``cut_coords`` value the test was
    parametrized with.
    """
    if cut_coords == (1, 1, 1):
        cuts = {"x": [-40.0], "y": [-30.0], "z": [-30.0]}
    elif cut_coords == 5:
        cuts = {
            "x": [-40.0, -20.0, 0.0, 20.0, 40.0],
            "y": [-30.0, -15.0, 0.0, 15.0, 30.0],
            "z": [-30.0, -3.75, 22.5, 48.75, 75.0],
        }
    else:
        # Any other parametrization gets this fixed set of cuts
        cuts = {"x": [10, 20], "y": [30, 40], "z": [15, 16]}
    return cuts

277 

278 

@pytest.mark.parametrize(
    "cut_coords", [(1, 1, 1), 5, {"x": [10, 20], "y": [30, 40], "z": [15, 16]}]
)
def test_demo_mosaic_slicer(cut_coords, img, expected_cuts):
    """Check the cuts of a ``MosaicSlicer`` built through its constructor.

    ``cut_coords`` is given as a tuple, an integer and a dictionary.
    """
    mosaic = MosaicSlicer(cut_coords=cut_coords)
    mosaic.add_overlay(img, cmap="gray")
    actual_cuts = mosaic.cut_coords
    assert actual_cuts == expected_cuts
    mosaic.close()

288 

289 

@pytest.mark.parametrize("projector", PROJECTORS)
def test_projectors_basic(projector, img, tmp_path):
    """Build each projector from a figure, overlay an image and save it."""
    display = projector.init_with_figure(img=img)
    display.add_overlay(img, cmap="gray")
    output_file = tmp_path / "out.png"
    display.savefig(output_file)
    display.close()

297 

298 

def test_contour_fillings_levels_in_add_contours(img):
    """Exercise ``OrthoSlicer.add_contours`` with filled contours.

    Several combinations of ``levels``, ``colors`` and ``alpha``
    are tried.
    """
    slicer = OrthoSlicer(cut_coords=(0, 0, 0))

    # levels should be at least 2
    # If single levels are passed then we force upper level to be inf
    single_level = [0.0]
    slicer.add_contours(
        img, filled=True, colors="r", alpha=0.2, levels=single_level
    )

    # If two levels are passed, it should be increasing from zero index
    # In this case, we simply omit appending inf
    slicer.add_contours(
        img, filled=True, colors="b", alpha=0.1, levels=[0.0, 0.2]
    )

    # without passing colors and alpha. In this case, default values are
    # chosen from matplotlib
    slicer.add_contours(img, filled=True, levels=[0.0, 0.2])

    # levels with only one value
    # vmin argument is not needed but added because of matplotlib 3.8.0rc1 bug
    # see https://github.com/matplotlib/matplotlib/issues/26531
    slicer.add_contours(img, filled=True, levels=[0.0], vmin=0.0)

    # without passing levels, should work with default levels from
    # matplotlib
    slicer.add_contours(img, filled=True)

    slicer.close()

323 

324 

def test_user_given_cmap_with_colorbar(img):
    """Check ``OrthoSlicer.add_overlay`` with a cmap given by name."""
    slicer = OrthoSlicer(cut_coords=(0, 0, 0))
    cmap_name = "Paired"
    slicer.add_overlay(img, cmap=cmap_name, colorbar=True)
    slicer.close()

330 

331 

@pytest.mark.parametrize("display", [OrthoSlicer, LYRZProjector])
def test_data_complete_mask(affine_eye, display):
    """Check plotting of an all-zero image.

    This is a special case due to matplotlib 2.1.0: when the data is
    completely masked, then we have plotting issues. See similar issue
    #9280 reported in matplotlib. This function tests the patch added
    for this particular issue.
    """
    # data is completely masked
    zero_img = Nifti1Image(np.zeros((10, 20, 30)), affine_eye)

    if display == OrthoSlicer:
        n_cuts = 3
    else:
        n_cuts = 4
    plotter = display(cut_coords=(0,) * n_cuts)
    plotter.add_overlay(zero_img)
    plotter.close()

347 

348 

def test_add_markers_cut_coords_is_none():
    """Check ``add_markers`` when every cut coordinate is ``None``.

    This case is used when coords are placed on glass brain.
    """
    no_cuts = (None, None, None)
    slicer = OrthoSlicer(cut_coords=no_cuts)
    marker_coords = [(0, 0, 2)]
    slicer.add_markers(marker_coords)
    slicer.close()

357 

358 

def test_annotations():
    """Exercise ``display.annotate()``.

    The second call passes several of the scale bar keyword arguments.
    """
    slicer = OrthoSlicer(cut_coords=(None, None, None))
    slicer.annotate(size=10, left_right=True, positions=False)

    scalebar_options = {
        "scalebar": True,
        "scale_size": 2.5,
        "scale_units": "cm",
        "scale_loc": 3,
        "frameon": True,
    }
    slicer.annotate(
        size=12, left_right=False, positions=False, **scalebar_options
    )
    slicer.close()

377 

378 

def test_position_annotation_with_decimals():
    """Annotate positions with ``decimals=2``."""
    origin = (0, 0, 0)
    slicer = OrthoSlicer(cut_coords=origin)
    slicer.annotate(positions=True, decimals=2)
    slicer.close()

384 

385 

@pytest.mark.parametrize("node_color", ["red", ["red", "blue"]])
def test_add_graph_with_node_color_as_string(node_color):
    """Call ``display.add_graph()`` with one color name or a list of them."""
    projector = LZRYProjector(cut_coords=(0, 0, 0, 0))
    adjacency = np.array([[0, 3], [3, 0]])
    coords = [[-53.60, -62.80, 36.64], [23.87, 0.31, 69.42]]
    projector.add_graph(adjacency, coords, node_color=node_color)
    projector.close()

395 

@pytest.mark.parametrize(
    "threshold,vmin,vmax,expected_results",
    [
        (None, None, None, [[-2, -1, 0], [0, 1, 2]]),
        (0.5, None, None, [[-2, -1, np.nan], [np.nan, 1, 2]]),
        (1, 0, None, [[np.nan, np.nan, np.nan], [np.nan, np.nan, 2]]),
        (1, None, 1, [[-2, np.nan, np.nan], [np.nan, np.nan, np.nan]]),
        (0, 0, 0, [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]]),
    ],
)
def test_threshold(threshold, vmin, vmax, expected_results):
    """Check ``OrthoSlicer._threshold`` against hand-written results.

    ``np.nan`` entries in ``expected_results`` stand for values expected
    to be masked.
    """
    values = np.array([[-2, -1, 0], [0, 1, 2]], dtype=float)
    thresholded = OrthoSlicer._threshold(values, threshold, vmin, vmax)
    expected = np.ma.masked_invalid(expected_results)
    assert np.ma.allequal(thresholded, expected)

413 

414 

@pytest.mark.parametrize("transparency", [None, 0, 0.5, 1])
@pytest.mark.parametrize("display,name", zip(SLICERS, SLICER_KEYS))
def test_display_slicers_transparency(
    display, img, name, cut_coords, transparency
):
    """Test several valid transparency values.

    Also make sure warning is thrown that alpha value is overridden.

    Parameters
    ----------
    display : class
        Slicer class to instantiate, taken from ``SLICERS``.
    img : fixture
        Image added as overlay.
    name : str
        Display-mode key paired with ``display``.
    cut_coords : fixture
        Cut coordinates passed to the constructor.
    transparency : None or number
        Value forwarded to ``add_overlay``.
    """
    slicer = display(cut_coords=cut_coords)
    with pytest.warns(UserWarning, match="Overriding with"):
        slicer.add_overlay(
            img, cmap=plt.cm.gray, transparency=transparency, alpha=0.5
        )
    slicer.title(f"display mode is {name}")
    # Close the display, as the other display tests in this module do
    slicer.close()

430 

431 

@pytest.mark.parametrize("transparency", [-2, 10])
@pytest.mark.parametrize("display,name", zip(SLICERS, SLICER_KEYS))
def test_display_slicers_transparency_warning(
    display, img, name, cut_coords, transparency
):
    """Test several invalid transparency values throw warnings.

    Parameters
    ----------
    display : class
        Slicer class to instantiate, taken from ``SLICERS``.
    img : fixture
        Image added as overlay.
    name : str
        Display-mode key paired with ``display``.
    cut_coords : fixture
        Cut coordinates passed to the constructor.
    transparency : number
        Invalid value forwarded to ``add_overlay``.
    """
    slicer = display(cut_coords=cut_coords)
    with pytest.warns(UserWarning, match="Setting it to"):
        slicer.add_overlay(img, cmap=plt.cm.gray, transparency=transparency)
    slicer.title(f"display mode is {name}")
    # Close the display, as the other display tests in this module do
    slicer.close()

442 

443 

@pytest.mark.parametrize("transparency", [None, 0, 0.5, 1])
@pytest.mark.parametrize("display,name", zip(PROJECTORS, PROJECTOR_KEYS))
def test_display_projectors_transparency(
    display, img, name, cut_coords, transparency
):
    """Test several valid transparency values.

    Also make sure warning is thrown that alpha value is overridden.

    Parameters
    ----------
    display : class
        Projector class to instantiate, taken from ``PROJECTORS``.
    img : fixture
        Image added as overlay.
    name : str
        Key paired with ``display`` by ``zip``, taken from
        ``PROJECTOR_KEYS``.
    cut_coords : fixture
        Cut coordinates passed to the constructor.
    transparency : None or number
        Value forwarded to ``add_overlay``.
    """
    projector = display(cut_coords=cut_coords)
    with pytest.warns(UserWarning, match="Overriding with"):
        projector.add_overlay(
            img, cmap=plt.cm.gray, transparency=transparency, alpha=0.5
        )
    projector.title(f"display mode is {name}")
    # Close the display, as the other display tests in this module do
    projector.close()

459 

460 

@pytest.mark.parametrize("transparency", [-2, 10])
@pytest.mark.parametrize("display,name", zip(PROJECTORS, PROJECTOR_KEYS))
def test_display_projectors_transparency_warning(
    display, img, name, cut_coords, transparency
):
    """Test several invalid transparency values throw warnings.

    Parameters
    ----------
    display : class
        Projector class to instantiate, taken from ``PROJECTORS``.
    img : fixture
        Image added as overlay.
    name : str
        Key paired with ``display`` by ``zip``, taken from
        ``PROJECTOR_KEYS``.
    cut_coords : fixture
        Cut coordinates passed to the constructor.
    transparency : number
        Invalid value forwarded to ``add_overlay``.
    """
    projector = display(cut_coords=cut_coords)
    with pytest.warns(UserWarning, match="Setting it to"):
        projector.add_overlay(
            img, cmap=plt.cm.gray, transparency=transparency
        )
    projector.title(f"display mode is {name}")
    # Close the display, as the other display tests in this module do
    projector.close()