Coverage for nilearn/plotting/tests/test_html_stat_map.py: 0%

193 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1import base64 

2import warnings 

3from io import BytesIO 

4 

5import numpy as np 

6import pytest 

7from matplotlib import pyplot as plt 

8from nibabel import Nifti1Image 

9 

10from nilearn import datasets, image 

11from nilearn.image import get_data, new_img_like 

12from nilearn.plotting.html_stat_map import ( 

13 StatMapView, 

14 _bytes_io_to_base64, 

15 _data_to_sprite, 

16 _get_bg_mask_and_cmap, 

17 _get_cut_slices, 

18 _json_view_data, 

19 _json_view_params, 

20 _json_view_size, 

21 _json_view_to_html, 

22 _load_bg_img, 

23 _mask_stat_map, 

24 _resample_stat_map, 

25 _save_cm, 

26 _save_sprite, 

27 _threshold_data, 

28 view_img, 

29) 

30from nilearn.plotting.js_plotting_utils import colorscale 

31 

32 

def _check_html(html_view, title=None):
    """Assert that ``html_view`` is a StatMapView containing expected markup.

    Parameters
    ----------
    html_view : StatMapView
        Viewer object to inspect.

    title : str or None, optional
        When not None, the rendered page must contain
        ``<title>{title}</title>``.
    """
    assert isinstance(html_view, StatMapView)
    rendered = str(html_view)
    for snippet in ("var brain =", "overlayImg"):
        assert snippet in rendered
    if title is None:
        return
    assert f"<title>{title}</title>" in rendered

40 

41 

def _simulate_img(affine=None):
    """Build an 8 x 8 x 8 image that is zero except for one voxel.

    Parameters
    ----------
    affine : numpy.ndarray or None, optional
        Affine of the returned image. Defaults to ``np.eye(4)``.

    Returns
    -------
    img : Nifti1Image
        Image wrapping ``data`` with ``affine``.

    data : numpy.ndarray
        Array of shape (8, 8, 8) holding 1 at voxel (4, 4, 4) and 0
        everywhere else.
    """
    spot = np.zeros((8, 8, 8))
    spot[4, 4, 4] = 1
    img = Nifti1Image(spot, np.eye(4) if affine is None else affine)
    return img, spot

57 

58 

def _check_affine(affine):
    """Assert that ``affine`` is positive, isotropic and near-diagonal.

    The three diagonal scaling terms must be equal and ``affine[0, 0]``
    must be positive. Each column of the 3 x 3 linear part must contain
    exactly one entry whose magnitude exceeds 1e-3.
    """
    scale = affine[1, 1]
    assert affine[0, 0] == scale
    assert affine[2, 2] == scale
    assert affine[0, 0] > 0

    linear, _ = image.resampling.to_matrix_vector(affine)
    large_entries_per_column = (np.abs(linear) > 0.001).sum(axis=0)
    assert np.all(large_entries_per_column == 1), (
        "the affine transform was not near-diagonal"
    )

69 

70 

def test_data_to_sprite():
    """Check the mosaic produced by _data_to_sprite for a centred cube."""
    volume = np.zeros((8, 8, 8))
    volume[2:6, 2:6, 2:6] = 1
    sprite = _data_to_sprite(volume)

    # Expected result: a 3 x 3 grid of 8 x 8 tiles, each tile either
    # empty or holding a centred 4 x 4 square of ones.
    empty = np.zeros((8, 8))
    square = np.zeros((8, 8))
    square[2:6, 2:6] = 1
    gtruth = np.block(
        [
            [empty, empty, square],
            [square, square, square],
            [empty, empty, empty],
        ]
    )

    assert sprite.shape == gtruth.shape, "shape of sprite not as expected"
    assert (sprite == gtruth).all(), "simulated sprite not as expected"

93 

94 

def test_threshold_data():
    """Check the mask and thresholded data returned by _threshold_data."""
    values = np.arange(-3, 4)

    # An 'auto' threshold must leave at least one element unmasked.
    data_t, mask, _ = _threshold_data(values, threshold="auto")
    expected_mask = np.array([False, True, True, True, True, True, False])
    expected_data = np.array([-3, 0, 0, 0, 0, 0, 3])
    assert (mask == expected_mask).all()
    assert (data_t == expected_data).all()

    # threshold=None keeps every element.
    data_t, mask, _ = _threshold_data(values, threshold=None)
    assert not np.any(mask)
    assert np.all(data_t == values)

    # Positive and zero thresholds.
    for threshold, expected in (
        (1, [False, False, True, True, True, False, False]),
        (0, [False, False, False, True, False, False, False]),
    ):
        _, mask, _ = _threshold_data(values, threshold=threshold)
        assert (mask == np.array(expected)).all()

    # A threshold below every value masks nothing.
    _, mask, _ = _threshold_data(np.arange(3, 10), threshold=2)
    assert (mask == np.full(7, False)).all()

125 

126 

def test_save_sprite(rng):
    """Round-trip a sprite through _save_sprite and _bytes_io_to_base64."""
    # Random volume whose mask covers everything except a one-voxel border.
    shape = (7, 5, 4)
    data = rng.uniform(size=140).reshape(shape)
    mask = np.zeros(shape, dtype=int)
    mask[1:-1, 1:-1, 1:-1] = 1

    buffer = BytesIO()
    _save_sprite(data, buffer, vmin=0, vmax=1, mask=mask, format="png")
    encoded = _bytes_io_to_base64(buffer)

    # Decode the base64 payload back into an image array.
    img = plt.imread(BytesIO(base64.b64decode(encoded)), format="png")

    masked_sprite = np.ma.array(
        _data_to_sprite(data),
        mask=_data_to_sprite(mask),
    )
    expected = plt.get_cmap("Greys")(plt.Normalize(0, 1)(masked_sprite))
    assert np.allclose(img, expected, atol=0.1)

151 

152 

@pytest.mark.parametrize("cmap", ["tab10", "cold_hot"])
@pytest.mark.parametrize("n_colors", [7, 20])
def test_save_cmap(cmap, n_colors):
    """Round-trip a colormap through _save_cm and _bytes_io_to_base64."""
    buffer = BytesIO()
    _save_cm(buffer, cmap, format="png", n_colors=n_colors)
    png_bytes = base64.b64decode(_bytes_io_to_base64(buffer))
    img = plt.imread(BytesIO(png_bytes), format="png")

    # Sampling the same colormap at n_colors evenly spaced points must
    # match the decoded image within an absolute tolerance of 0.1.
    expected = plt.get_cmap(cmap)(np.linspace(0, 1, n_colors))
    assert np.allclose(img, expected, atol=0.1)

170 

171 

def test_mask_stat_map():
    """Check the mask produced by _mask_stat_map for None and 0 thresholds.

    Both calls use the same simulated input image. The input is not
    shadowed by the image returned from the first call, so the second
    check does not depend on the first call's output.
    """
    # Simple simulated data with a single non-zero voxel.
    stat_img, data = _simulate_img()

    # Without a threshold nothing should be masked.
    mask_img, _, _, _ = _mask_stat_map(stat_img, threshold=None)
    assert np.max(get_data(mask_img)) == 0

    # With a zero threshold, exactly the zero-valued voxels are masked.
    mask_img, _, _, _ = _mask_stat_map(stat_img, threshold=0)
    assert np.all((data == 0) == get_data(mask_img))

183 

184 

def test_load_bg_img(affine_eye):
    """Check _load_bg_img output affines for a flipped, sheared input.

    Whether ``bg_img`` is None or left at its default, the returned
    background must have a positive, isotropic, near-diagonal affine.
    """
    # Copy the fixture before editing so the shared array is never mutated
    # in place (assumes affine_eye is an ndarray, as its indexing implies).
    affine = affine_eye.copy()
    affine[0, 0] = -1
    affine[0, 1] = 0.1
    img, _ = _simulate_img(affine)

    # Explicitly empty background.
    bg_img, _, _, _ = _load_bg_img(img, bg_img=None)
    _check_affine(bg_img.affine)

    # Default background.
    bg_img, _, _, _ = _load_bg_img(img)
    _check_affine(bg_img.affine)

202 

203 

def test_get_bg_mask_and_cmap():
    """Check that the background is not masked when black_bg is False.

    Non-regression test for issue #3120, where the bg image was masked
    with the MNI template mask.
    """
    img, _ = _simulate_img()
    mask, _ = _get_bg_mask_and_cmap(img, False)
    expected = np.zeros(img.shape, dtype=bool)
    assert np.all(mask == expected)

210 

211 

def test_resample_stat_map(affine_eye):
    """Check that _resample_stat_map puts stat map and mask on the bg grid."""
    bg_img, data = _simulate_img()

    # Same data on a grid with voxel size 2 and a small shear term.
    stat_affine = 2 * affine_eye
    stat_affine[3, 3] = 1
    stat_affine[0, 1] = 0.1
    stat_map_img = Nifti1Image(data, stat_affine)
    mask_img = new_img_like(stat_map_img, data > 0, stat_map_img.affine)

    stat_map_img, mask_img = _resample_stat_map(
        stat_map_img, bg_img, mask_img, resampling_interpolation="nearest"
    )

    # Both outputs must carry a positive, isotropic, near-diagonal affine...
    for resampled in (stat_map_img, mask_img):
        _check_affine(resampled.affine)

    # ...whose voxel size matches the background.
    bg_voxel_size = bg_img.affine[0, 0]
    assert stat_map_img.affine[0, 0] == bg_voxel_size, (
        "stat_map_img was not resampled at the resolution of background"
    )
    assert mask_img.affine[0, 0] == bg_voxel_size, (
        "mask_img was not resampled at the resolution of background"
    )

241 

242 

def test_json_view_params(affine_eye):
    """Smoke-test _json_view_params and check one overlay parameter."""
    view_kwargs = {
        "shape": [4, 4, 4],
        "affine": affine_eye,
        "vmin": 0,
        "vmax": 1,
        "cut_slices": [1, 1, 1],
        "black_bg": True,
        "opacity": 0.5,
        "draw_cross": False,
        "annotate": True,
        "title": "A test",
        "colorbar": True,
        "value": True,
    }
    params = _json_view_params(**view_kwargs)

    # Only check that a structure came back and that opacity propagated.
    assert params["overlay"]["opacity"] == view_kwargs["opacity"]

263 

264 

def test_json_view_size():
    """Check the viewer width/height derived from minimal sprite params."""
    sprite_params = {"nbSlice": {"X": 4, "Y": 4, "Z": 4}}
    width, height = _json_view_size(sprite_params)

    # Three 4-pixel views side by side are 12 pixels wide and 4 tall; the
    # height is then scaled by 1.2 for annotations and margins.
    expected_width = 600
    expected_height = np.ceil(1.2 * 4 / 12 * expected_width)

    assert width == expected_width, "html viewer does not have expected width"
    assert height == expected_height, (
        "html viewer does not have expected height"
    )

279 

280 

def _get_data_and_json_view(black_bg, cbar, radiological):
    """Build JSON view data for a one-spot stat map over a one-spot bg.

    Parameters
    ----------
    black_bg, cbar, radiological : bool
        Forwarded to _json_view_data as ``black_bg``, ``colorbar`` and
        ``radiological``.

    Returns
    -------
    data : numpy.ndarray
        Array backing the simulated stat map.

    json_view : mapping
        Object returned by _json_view_data.
    """
    bg_img, _ = _simulate_img()
    stat_map_img, data = _simulate_img()
    mask_img = new_img_like(stat_map_img, data > 0, stat_map_img.affine)

    # Symmetric cold_hot colour scale for the stat map values.
    colors = colorscale(
        "cold_hot", data.ravel(), threshold=0, symmetric_cmap=True, vmax=1
    )

    json_view = _json_view_data(
        bg_img,
        stat_map_img,
        mask_img,
        bg_min=0,
        bg_max=1,
        black_bg=black_bg,
        colors=colors,
        cmap="cold_hot",
        colorbar=cbar,
        radiological=radiological,
    )
    return data, json_view

308 

309 

@pytest.mark.parametrize("black_bg", [True, False])
@pytest.mark.parametrize("cbar", [True, False])
@pytest.mark.parametrize("radiological", [True, False])
def test_json_view_data(black_bg, cbar, radiological):
    """Check that _json_view_data returns its base64-encoded fields."""
    _, json_view = _get_data_and_json_view(black_bg, cbar, radiological)
    for key in ("bg_base64", "stat_map_base64", "cm_base64"):
        assert isinstance(json_view[key], str)

319 

320 

@pytest.mark.parametrize("black_bg", [True, False])
@pytest.mark.parametrize("cbar", [True, False])
@pytest.mark.parametrize("radiological", [True, False])
def test_json_view_to_html(affine_eye, black_bg, cbar, radiological):
    """Check that _json_view_to_html renders a viewer for every combination.

    The parametrized ``black_bg`` and ``cbar`` values are forwarded to both
    the sprite data and the view parameters, so the two stay consistent
    and every combination is actually exercised.
    """
    data, json_view = _get_data_and_json_view(black_bg, cbar, radiological)
    json_view["params"] = _json_view_params(
        data.shape,
        affine_eye,
        vmin=0,
        vmax=1,
        cut_slices=[1, 1, 1],
        black_bg=black_bg,
        opacity=1,
        draw_cross=True,
        annotate=False,
        title="test",
        colorbar=cbar,
        radiological=radiological,
    )

    html_view = _json_view_to_html(json_view)
    _check_html(html_view)

344 

345 

def test_get_cut_slices(affine_eye):
    """Check automatic and manual cut slice selection in _get_cut_slices."""
    img, data = _simulate_img()

    # Automatic selection lands on the single non-zero voxel.
    cut_slices = _get_cut_slices(img, cut_coords=None, threshold=None)
    assert (cut_slices == [4, 4, 4]).all()

    # A single number is not accepted as cut_coords.
    with pytest.raises(ValueError):
        _get_cut_slices(img, cut_coords=4, threshold=None)

    # Coordinates can be specified manually.
    cut_slices = _get_cut_slices(img, cut_coords=[2, 2, 2], threshold=None)
    assert (cut_slices == [2, 2, 2]).all()

    # Changing the voxel size must not change which voxel is cut.
    # ``2 * affine_eye`` also doubles the homogeneous [3, 3] entry; reset it
    # to 1 so the affine is a valid homogeneous transform.
    affine = 2 * affine_eye
    affine[3, 3] = 1
    img = Nifti1Image(data, affine)
    cut_slices = _get_cut_slices(img, cut_coords=None, threshold=None)
    assert (cut_slices == [4, 4, 4]).all()

367 

368 

@pytest.mark.parametrize(
    "params, warning_msg",
    [
        (
            {"threshold": 2.0, "vmax": 4.0},
            "The given float value must not exceed .*",
        ),
        (
            {"symmetric_cmap": False},
            "'partition' will ignore the 'mask' of the MaskedArray *",
        ),
    ],
)
def test_view_img_3d_warnings(params, warning_msg):
    """Test warning when viewing 3D images."""
    # Fake functional image: the MNI template resampled to voxel size 3.
    template = datasets.load_mni152_template(resolution=2)
    img = image.resample_img(
        template,
        target_affine=3 * np.eye(3),
        copy_header=True,
        force_resample=True,
    )

    # With no background image, no warning is expected.
    # NOTE(review): catch_warnings(record=True) without
    # warnings.simplefilter("always") can miss warnings that were already
    # issued from the same location; consider adding it.
    with warnings.catch_warnings(record=True) as caught:
        view_img(img, bg_img=None)
    assert not caught

    with pytest.warns(UserWarning, match=warning_msg):
        html_view = view_img(img, **params)

    _check_html(html_view)

403 

404 

def test_view_img_3d_warnings_more():
    """Test warning when viewing 3D images.

    Has more precise checks on the output: the page title must be the
    default "Slice viewer" unless a title is passed explicitly.
    """
    # Fake functional image: the MNI template resampled to voxel size 3.
    template = datasets.load_mni152_template(resolution=2)
    img = image.resample_img(
        template,
        target_affine=3 * np.eye(3),
        copy_header=True,
        force_resample=True,
    )

    cases = (
        ({}, "Slice viewer"),
        ({"threshold": "95%", "title": "SOME_TITLE"}, "SOME_TITLE"),
    )
    for view_kwargs, expected_title in cases:
        with pytest.warns(
            UserWarning,
            match="'partition' will ignore the 'mask' of the MaskedArray",
        ):
            html_view = view_img(img, **view_kwargs)

        _check_html(html_view, title=expected_title)

435 

436 

@pytest.mark.parametrize(
    "params",
    [
        {"threshold": 2.0, "vmax": 4.0},
        {"threshold": 1e6},
        {"width_view": 1000},
    ],
)
def test_view_img_4d_warnings(params):
    """Test warning when viewing 4D images."""
    # Fake functional image: the MNI template resampled to voxel size 3.
    template = datasets.load_mni152_template(resolution=2)
    img = image.resample_img(
        template,
        target_affine=3 * np.eye(3),
        copy_header=True,
        force_resample=True,
    )

    # Add a trailing singleton axis to get a single-volume 4D image.
    volume_4d = get_data(img)[:, :, :, np.newaxis]
    img_4d = image.new_img_like(img, volume_4d)
    assert len(img_4d.shape) == 4

    with pytest.warns(
        UserWarning,
        match="'partition' will ignore the 'mask' of the MaskedArray",
    ):
        html_view = view_img(img_4d, **params)

    _check_html(html_view)