Coverage for nilearn/plotting/html_stat_map.py: 0%

171 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1"""Visualizing 3D stat maps in a Brainsprite viewer.""" 

2 

3import copy 

4import json 

5import warnings 

6from base64 import b64encode 

7from io import BytesIO 

8from pathlib import Path 

9 

10import matplotlib 

11import numpy as np 

12from matplotlib.image import imsave 

13from nibabel.affines import apply_affine 

14 

15from nilearn import DEFAULT_DIVERGING_CMAP 

16from nilearn._utils import check_niimg_3d, fill_doc 

17from nilearn._utils.extmath import fast_abs_percentile 

18from nilearn._utils.html_document import HTMLDocument 

19from nilearn._utils.logger import find_stack_level 

20from nilearn._utils.niimg import safe_get_data 

21from nilearn._utils.param_validation import check_threshold 

22from nilearn.datasets import load_mni152_template 

23from nilearn.image import get_data, new_img_like, reorder_img, resample_to_img 

24from nilearn.plotting.find_cuts import find_xyz_cut_coords 

25from nilearn.plotting.img_plotting import load_anat 

26from nilearn.plotting.js_plotting_utils import colorscale, get_html_template 

27 

28 

29def _data_to_sprite(data, radiological=False): 

30 """Convert a 3D array into a sprite of sagittal slices. 

31 

32 Parameters 

33 ---------- 

34 data : :class:`numpy.ndarray` 

35 Input data to convert to sprite. 

36 

37 Returns 

38 ------- 

39 sprite : 2D :class:`numpy.ndarray` 

40 If each sagittal slice is nz (height) x ny (width) pixels, the sprite 

41 size is (M x nz) x (N x ny), where M and N are computed to be roughly 

42 equal. All slices are pasted together row by row, from top left to 

43 bottom right. The last row is completed with empty slices. 

44 

45 """ 

46 nx, ny, nz = data.shape 

47 nrows = int(np.ceil(np.sqrt(nx))) 

48 ncolumns = int(np.ceil(nx / float(nrows))) 

49 

50 sprite = np.zeros((nrows * nz, ncolumns * ny)) 

51 indrow, indcol = np.where(np.ones((nrows, ncolumns))) 

52 

53 if radiological: 

54 for xx in range(nx): 

55 sprite[ 

56 (indrow[xx] * nz) : ((indrow[xx] + 1) * nz), 

57 (indcol[xx] * ny) : ((indcol[xx] + 1) * ny), 

58 ] = data[nx - xx - 1, :, ::-1].transpose() 

59 

60 else: 

61 for xx in range(nx): 

62 sprite[ 

63 (indrow[xx] * nz) : ((indrow[xx] + 1) * nz), 

64 (indcol[xx] * ny) : ((indcol[xx] + 1) * ny), 

65 ] = data[xx, :, ::-1].transpose() 

66 

67 return sprite 

68 

69 

70def _threshold_data(data, threshold=None): 

71 """Threshold a data array. 

72 

73 Parameters 

74 ---------- 

75 data : :class:`numpy.ndarray` 

76 Data to apply threshold on. 

77 

78 threshold : :obj:`float`, optional 

79 Threshold to apply to data. 

80 

81 Returns 

82 ------- 

83 data : :class:`numpy.ndarray` 

84 Thresholded data. 

85 

86 mask : :class:`numpy.ndarray` of :obj:`bool` 

87 Boolean mask. 

88 

89 threshold : :obj:`float` 

90 Updated threshold value. 

91 

92 """ 

93 # If threshold is None, do nothing 

94 if threshold is None: 

95 mask = np.full(data.shape, False) 

96 return data, mask, threshold 

97 

98 # Deal with automatic settings of plot parameters 

99 if threshold == "auto": 

100 # Threshold epsilon below a percentile value, to be sure that some 

101 # voxels pass the threshold 

102 threshold = fast_abs_percentile(data) - 1e-5 

103 

104 # Threshold 

105 threshold = check_threshold( 

106 threshold, data, percentile_func=fast_abs_percentile, name="threshold" 

107 ) 

108 

109 if threshold == 0: 

110 mask = data == 0 

111 else: 

112 mask = (data >= -threshold) & (data <= threshold) 

113 data = data * np.logical_not(mask) 

114 if not np.any(mask): 

115 warnings.warn( 

116 f"Threshold given was {threshold}, " 

117 f"but the data has no values below {data.min()}. ", 

118 stacklevel=find_stack_level(), 

119 ) 

120 return data, mask, threshold 

121 

122 

def _save_sprite(
    data,
    output_sprite,
    vmax,
    vmin,
    mask=None,
    cmap="Greys",
    format="png",
    radiological=False,
):
    """Render a 3D array as a sprite image and write it to ``output_sprite``.

    Parameters
    ----------
    data : :class:`numpy.ndarray`
        3D input data, tiled into a 2D sprite of sagittal slices.

    output_sprite : :obj:`str`, path-like or file-like object
        Destination handed to :func:`matplotlib.image.imsave`. Callers in
        this module pass an in-memory :class:`io.BytesIO`.

    vmax, vmin : :obj:`float`
        Upper and lower bounds of the color scaling used by ``imsave``.

    mask : :class:`numpy.ndarray`, optional
        3D array with the same shape as ``data``. It is tiled like
        ``data``, and its non-zero entries are masked out of the sprite
        (drawn with the colormap's "bad" color).

    cmap : :obj:`str` or :class:`matplotlib.colors.Colormap`, \
default='Greys'
        Colormap used to render the sprite.

    format : :obj:`str`, default='png'
        Format to use for output image.

    radiological : :obj:`bool`, default=False
        If True, the sagittal slices are tiled in reverse x order.

    Returns
    -------
    sprite : :class:`numpy.ndarray` or :class:`numpy.ma.MaskedArray`
        The 2D sprite that was saved (a masked array if ``mask`` is given).

    """
    # Create sprite
    sprite = _data_to_sprite(data, radiological)

    # Mask the sprite: the mask is tiled exactly like the data so that
    # each masked voxel lands on the matching sprite pixel.
    if mask is not None:
        mask = _data_to_sprite(mask, radiological)
        sprite = np.ma.array(sprite, mask=mask)

    # Save the sprite
    imsave(
        output_sprite, sprite, vmin=vmin, vmax=vmax, cmap=cmap, format=format
    )

    return sprite

176 

177 

178def _bytes_io_to_base64(handle_io): 

179 """Encode the content of a bytesIO virtual file as base64. 

180 

181 Also closes the file. 

182 

183 Returns 

184 ------- 

185 data 

186 """ 

187 handle_io.seek(0) 

188 data = b64encode(handle_io.read()).decode("utf-8") 

189 handle_io.close() 

190 return data 

191 

192 

def _save_cm(output_cmap, cmap, format="png", n_colors=256):
    """Write a one-pixel-high colorbar strip for ``cmap`` to ``output_cmap``.

    The strip is ``n_colors`` pixels wide and holds the evenly spaced
    values ``0 .. 1`` rendered with ``cmap``.
    """
    ramp = np.arange(0.0, n_colors) / (n_colors - 1.0)
    imsave(output_cmap, ramp[np.newaxis, :], cmap=cmap, format=format)

199 

200 

class StatMapView(HTMLDocument):
    """HTML document holding the interactive brainsprite stat map viewer.

    Returned by :func:`view_img`. It adds nothing to
    :class:`~nilearn._utils.html_document.HTMLDocument`; the subclass only
    gives the viewer a distinct type.
    """

203 

204 

def _mask_stat_map(stat_map_img, threshold=None):
    """Load a 3D stat map and build an image of the thresholded-out voxels.

    Parameters
    ----------
    stat_map_img : Niimg-like object
        Statistical map, loaded with ``check_niimg_3d``.

    threshold : :obj:`float`, :obj:`str` or None, optional
        Threshold forwarded to ``_threshold_data``. If None, the returned
        mask is all zeros.

    Returns
    -------
    mask_img : Niimg
        Mask image on the same affine as ``stat_map_img``.

    stat_map_img : Niimg
        The loaded 3D stat map.

    data : :class:`numpy.ndarray`
        Stat map values (from ``safe_get_data(..., ensure_finite=True)``),
        with thresholded-out voxels set to zero.

    threshold : :obj:`float`, :obj:`str` or None
        Threshold as returned by ``_threshold_data`` (or the input value
        when it is None).
    """
    stat_map_img = check_niimg_3d(stat_map_img, dtype="auto")
    data = safe_get_data(stat_map_img, ensure_finite=True)

    if threshold is None:
        # No thresholding requested: an all-zero mask hides nothing.
        mask_data = np.zeros(data.shape)
    else:
        data, mask_data, threshold = _threshold_data(data, threshold)

    mask_img = new_img_like(stat_map_img, mask_data, stat_map_img.affine)
    return mask_img, stat_map_img, data, threshold

231 

232 

def _load_bg_img(stat_map_img, bg_img="MNI152", black_bg="auto", dim="auto"):
    """Load the background image and reorder it with ``reorder_img``.

    Parameters
    ----------
    stat_map_img : Niimg
        Stat map; its grid is used when no background is requested.

    bg_img : Niimg-like object, "MNI152", None or False, default="MNI152"
        Background image. "MNI152" loads the 2mm MNI152 template; None or
        False gives a fully masked volume on the stat map grid.

    black_bg : :obj:`bool` or "auto", default="auto"
        Background color flag, resolved by ``load_anat`` when a background
        image is used (or set to False when there is none).

    dim : :obj:`float` or "auto", default="auto"
        Dimming factor forwarded to ``load_anat``.

    Returns
    -------
    bg_img : Niimg
        Background image after ``reorder_img`` (nearest resampling).

    bg_min, bg_max : :obj:`float`
        Intensity range from ``load_anat`` (0, 0 without background).

    black_bg : :obj:`bool`
        Resolved background color flag.
    """
    if bg_img is None or bg_img is False:
        # No background: a completely masked volume on the stat map grid.
        if black_bg == "auto":
            black_bg = False
        bg_img = new_img_like(
            stat_map_img, np.ma.masked_all(stat_map_img.shape)
        )
        bg_min, bg_max = 0, 0
    else:
        bg_img = (
            load_mni152_template(resolution=2)
            if isinstance(bg_img, str) and bg_img == "MNI152"
            else check_niimg_3d(bg_img)
        )
        # Mask voxels whose value lies in [-1e-6, 1e-6].
        near_zero_masked = np.ma.masked_inside(
            safe_get_data(bg_img, ensure_finite=True), -1e-6, 1e-6, copy=False
        )
        bg_img, black_bg, bg_min, bg_max = load_anat(
            new_img_like(bg_img, near_zero_masked), dim=dim, black_bg=black_bg
        )
    bg_img = reorder_img(bg_img, resample="nearest", copy_header=True)
    return bg_img, bg_min, bg_max, black_bg

268 

269 

def _resample_stat_map(
    stat_map_img, bg_img, mask_img, resampling_interpolation="continuous"
):
    """Resample the stat map and its mask onto the background grid.

    Parameters
    ----------
    stat_map_img : Niimg
        Stat map to resample.

    bg_img : Niimg
        Target image defining the output grid.

    mask_img : Niimg
        Mask to resample (non-zero means hidden).

    resampling_interpolation : :obj:`str`, default="continuous"
        Interpolation used for the stat map.

    Returns
    -------
    stat_map_img : Niimg
        Resampled stat map.

    mask_img : Niimg
        Resampled mask (nearest-neighbor interpolation).
    """
    # TODO set force_resample to True in 0.13.0
    common_kwargs = {"copy_header": True, "force_resample": False}

    resampled_stat_map = resample_to_img(
        stat_map_img,
        bg_img,
        interpolation=resampling_interpolation,
        **common_kwargs,
    )
    # Voxels outside the mask's field of view get fill_value=1, i.e. hidden.
    resampled_mask = resample_to_img(
        mask_img,
        bg_img,
        fill_value=1,
        interpolation="nearest",
        **common_kwargs,
    )
    return resampled_stat_map, resampled_mask

298 

299 

300def _json_view_params( 

301 shape, 

302 affine, 

303 vmin, 

304 vmax, 

305 cut_slices, 

306 black_bg=False, 

307 opacity=1, 

308 draw_cross=True, 

309 annotate=True, 

310 title=None, 

311 colorbar=True, 

312 value=True, 

313 radiological=False, 

314 show_lr=True, 

315): 

316 """Create a dictionary with all the brainsprite parameters. 

317 

318 Returns 

319 ------- 

320 params 

321 """ 

322 # Set color parameters 

323 if black_bg: 

324 cfont = "#FFFFFF" 

325 cbg = "#000000" 

326 else: 

327 cfont = "#000000" 

328 cbg = "#FFFFFF" 

329 

330 # Deal with limitations of json dump regarding types 

331 if type(vmin).__module__ == "numpy": 

332 vmin = vmin.tolist() # json does not deal with numpy array 

333 if type(vmax).__module__ == "numpy": 

334 vmax = vmax.tolist() # json does not deal with numpy array 

335 

336 params = { 

337 "canvas": "3Dviewer", 

338 "sprite": "spriteImg", 

339 "nbSlice": {"X": shape[0], "Y": shape[1], "Z": shape[2]}, 

340 "overlay": { 

341 "sprite": "overlayImg", 

342 "nbSlice": {"X": shape[0], "Y": shape[1], "Z": shape[2]}, 

343 "opacity": opacity, 

344 }, 

345 "colorBackground": cbg, 

346 "colorFont": cfont, 

347 "crosshair": draw_cross, 

348 "affine": affine.tolist(), 

349 "flagCoordinates": annotate, 

350 "title": title, 

351 "flagValue": value, 

352 "numSlice": { 

353 "X": cut_slices[0] - 1, 

354 "Y": cut_slices[1] - 1, 

355 "Z": cut_slices[2] - 1, 

356 }, 

357 "radiological": radiological, 

358 "showLR": show_lr, 

359 } 

360 

361 if colorbar: 

362 params["colorMap"] = {"img": "colorMap", "min": vmin, "max": vmax} 

363 return params 

364 

365 

366def _json_view_size(params, width_view=600): 

367 """Define the size of the viewer. 

368 

369 Returns 

370 ------- 

371 width_view 

372 

373 height_view 

374 """ 

375 # slices_width = sagittal_width (y) + coronal_width (x) + axial_width (x) 

376 slices_width = params["nbSlice"]["Y"] + 2 * params["nbSlice"]["X"] 

377 

378 # slices_height = max of sagittal_height (z), coronal_height (z), and 

379 # axial_height (y). 

380 # Also add 20% extra height for annotation and margin 

381 slices_height = np.max([params["nbSlice"]["Y"], params["nbSlice"]["Z"]]) 

382 slices_height = 1.20 * slices_height 

383 

384 # Get the final size of the viewer 

385 ratio = slices_height / slices_width 

386 height_view = np.ceil(ratio * width_view) 

387 

388 return width_view, height_view 

389 

390 

def _get_bg_mask_and_cmap(bg_img, black_bg):
    """Return the background mask and gray colormap for ``_json_view_data``.

    Parameters
    ----------
    bg_img : Niimg
        Background image whose data may be a masked array.

    black_bg : :obj:`bool`
        Selects the color used for masked ("bad") background voxels.

    Returns
    -------
    bg_mask : :class:`numpy.ndarray` of :obj:`bool`
        Full boolean mask of the background data.

    bg_cmap : :class:`matplotlib.colors.Colormap`
        Copy of the "gray" colormap with its bad color set.
    """
    gray_cmap = copy.copy(matplotlib.pyplot.get_cmap("gray"))
    # Masked voxels take the same color as the viewer background.
    gray_cmap.set_bad("black" if black_bg else "white")
    return np.ma.getmaskarray(get_data(bg_img)), gray_cmap

400 

401 

def _json_view_data(
    bg_img,
    stat_map_img,
    mask_img,
    bg_min,
    bg_max,
    black_bg,
    colors,
    cmap,
    colorbar,
    radiological,
):
    """Render the viewer sprites and return them as base64 PNG strings.

    Parameters
    ----------
    bg_img, stat_map_img, mask_img : Niimg
        Background, stat map and mask on a common grid.

    bg_min, bg_max : :obj:`float`
        Color range of the background sprite.

    black_bg : :obj:`bool`
        Background color flag (sets the color of masked background voxels).

    colors : :obj:`dict`
        Must provide ``"vmin"``, ``"vmax"`` (stat map range) and
        ``"cmap"`` (used for the colorbar image).

    cmap : :obj:`str` or Colormap
        Colormap used for the stat map sprite.

    colorbar : :obj:`bool`
        If False, ``"cm_base64"`` is an empty string.

    radiological : :obj:`bool`
        Forwarded to the sprite builder.

    Returns
    -------
    json_view : :obj:`dict`
        Base64 sprites plus ``"params"``, ``"js_jquery"`` and
        ``"js_brainsprite"`` keys initialized to None.
    """

    def _encode_sprite(sprite_data, vmax, vmin, sprite_mask, sprite_cmap):
        # Render one sprite into an in-memory PNG and return it as base64.
        buffer = BytesIO()
        _save_sprite(
            sprite_data,
            buffer,
            vmax,
            vmin,
            sprite_mask,
            sprite_cmap,
            "png",
            radiological,
        )
        return _bytes_io_to_base64(buffer)

    bg_data = safe_get_data(bg_img, ensure_finite=True).astype(float)
    bg_mask, bg_cmap = _get_bg_mask_and_cmap(bg_img, black_bg)
    bg_base64 = _encode_sprite(bg_data, bg_max, bg_min, bg_mask, bg_cmap)

    stat_map_base64 = _encode_sprite(
        safe_get_data(stat_map_img, ensure_finite=True),
        colors["vmax"],
        colors["vmin"],
        safe_get_data(mask_img, ensure_finite=True),
        cmap,
    )

    if colorbar:
        cm_buffer = BytesIO()
        _save_cm(cm_buffer, colors["cmap"], "png")
        cm_base64 = _bytes_io_to_base64(cm_buffer)
    else:
        cm_base64 = ""

    return {
        "bg_base64": bg_base64,
        "stat_map_base64": stat_map_base64,
        "cm_base64": cm_base64,
        # Placeholders: params and JS sources are filled in by the callers.
        "params": None,
        "js_jquery": None,
        "js_brainsprite": None,
    }

473 

474 

def _json_view_to_html(json_view, width_view=600):
    """Fill the brainsprite HTML template with the viewer data.

    Parameters
    ----------
    json_view : :obj:`dict`
        Viewer data whose ``"params"`` entry is the parameter dictionary
        (it must contain ``"title"`` and ``"nbSlice"``). The dict is
        modified in place: ``"params"`` is replaced by its JSON string,
        and the page title and JavaScript sources are added.

    width_view : :obj:`int`, default=600
        Width of the viewer in pixels; the height is derived from it.

    Returns
    -------
    html_view : :class:`StatMapView`
        The HTML viewer.
    """
    # Fix the size of the viewer
    width, height = _json_view_size(json_view["params"], width_view)

    # Populate all missing keys with html-ready data
    json_view["INSERT_PAGE_TITLE_HERE"] = (
        json_view["params"]["title"] or "Slice viewer"
    )
    json_view["params"] = json.dumps(json_view["params"])

    # Read the bundled JavaScript explicitly as UTF-8 so the result does
    # not depend on the platform's default (locale) encoding.
    js_dir = Path(__file__).parent / "data" / "js"
    json_view["js_jquery"] = (js_dir / "jquery.min.js").read_text(
        encoding="utf-8"
    )
    json_view["js_brainsprite"] = (js_dir / "brainsprite.min.js").read_text(
        encoding="utf-8"
    )

    # Load the html template, and plug in all the data
    html_view = get_html_template("stat_map_template.html")
    html_view = html_view.safe_substitute(json_view)

    return StatMapView(html_view, width=width, height=height)

501 

502 

def _get_cut_slices(stat_map_img, cut_coords=None, threshold=None):
    """Convert world cut coordinates into voxel coordinates of the image.

    Parameters
    ----------
    stat_map_img : Niimg
        Image whose affine defines the voxel grid.

    cut_coords : sequence of 3 :obj:`float` or None, default=None
        World (x, y, z) coordinates of the cut. If None, they are computed
        with ``find_xyz_cut_coords``.

    threshold : :obj:`float` or None, default=None
        Passed to ``find_xyz_cut_coords`` as ``activation_threshold`` when
        ``cut_coords`` is None.

    Returns
    -------
    cut_slices : :class:`numpy.ndarray`
        Voxel coordinates of the cut (not rounded).

    Raises
    ------
    ValueError
        If ``cut_coords`` is not a list of three world coordinates.
    """
    # Select coordinates for the cut
    if cut_coords is None:
        cut_coords = find_xyz_cut_coords(
            stat_map_img, activation_threshold=threshold
        )

    # Invert the affine outside the try block: a singular affine raises
    # LinAlgError, which must not be reported as a bad cut_coords.
    inverse_affine = np.linalg.inv(stat_map_img.affine)

    # Convert cut coordinates into cut slices
    try:
        cut_slices = apply_affine(inverse_affine, cut_coords)
    except ValueError as err:
        # Wrong number of coordinates.
        raise ValueError(
            "The input given for display_mode='ortho' "
            "needs to be a list of 3d world coordinates in (x, y, z). "
            f"You provided cut_coords={cut_coords}"
        ) from err
    except IndexError as err:
        # A single scalar was given instead of a sequence.
        raise ValueError(
            "The input given for display_mode='ortho' "
            "needs to be a list of 3d world coordinates in (x, y, z). "
            f"You provided single cut, cut_coords={cut_coords}"
        ) from err

    return cut_slices

534 

535 

@fill_doc
def view_img(
    stat_map_img,
    bg_img="MNI152",
    cut_coords=None,
    colorbar=True,
    title=None,
    threshold=1e-6,
    annotate=True,
    draw_cross=True,
    black_bg="auto",
    cmap=DEFAULT_DIVERGING_CMAP,
    symmetric_cmap=True,
    dim="auto",
    vmax=None,
    vmin=None,
    resampling_interpolation="continuous",
    width_view=600,
    opacity=1,
    radiological=False,
    show_lr=True,
):
    """Interactive html viewer of a statistical map, with optional background.

    Parameters
    ----------
    stat_map_img : Niimg-like object
        See :ref:`extracting_data`.
        The statistical map image. Can be either a 3D volume or a 4D volume
        with exactly one time point.
    %(bg_img)s
        If nothing is specified, the MNI152 template will be used.
        To turn off background image, just pass "bg_img=False".
        Default='MNI152'.

    cut_coords : None, or a :obj:`tuple` of :obj:`float`, default=None
        The :term:`MNI` coordinates of the point where the cut is performed
        as a 3-tuple: (x, y, z). If None is given, the cuts are calculated
        automatically.

    colorbar : :obj:`bool`, default=True
        If True, display a colorbar on top of the plots.
    %(title)s
    threshold : :obj:`str`, number or None, default=1e-06
        If None is given, the image is not thresholded.
        If a string of the form "90%%" is given, use the 90-th percentile of
        the absolute value in the image.
        If a number is given, it is used to threshold the image:
        values below the threshold (in absolute value) are plotted
        as transparent. If "auto" is given, the threshold is determined
        automatically.

    annotate : :obj:`bool`, default=True
        If annotate is True, current cuts are added to the viewer.
    %(draw_cross)s
    black_bg : :obj:`bool` or 'auto', default='auto'
        If True, the background of the image is set to be black.
        Otherwise, a white background is used.
        If set to auto, an educated guess is made to find if the background
        is white or black.
    %(cmap)s
        default="RdBu_r"
    symmetric_cmap : :obj:`bool`, default=True
        True: make colormap symmetric (ranging from -vmax to vmax).
        False: the colormap will go from the minimum of the volume to vmax.
        Set it to False if you are plotting a positive volume, e.g. an atlas
        or an anatomical image.
    %(dim)s
        Default='auto'.
    vmax : :obj:`float`, or None, default=None
        max value for mapping colors.
        If vmax is None and symmetric_cmap is True, vmax is the max
        absolute value of the volume.
        If vmax is None and symmetric_cmap is False, vmax is the max
        value of the volume.

    vmin : :obj:`float`, or None, default=None
        min value for mapping colors.
        If `symmetric_cmap` is `True`, `vmin` is always equal to `-vmax` and
        cannot be chosen.
        If `symmetric_cmap` is `False`, `vmin` is equal to the min of the
        image, or 0 when a threshold is used.
    %(resampling_interpolation)s
        Default='continuous'.

    width_view : :obj:`int`, default=600
        Width of the viewer in pixels.

    opacity : :obj:`float` in [0,1], default=1
        The level of opacity of the overlay (0: transparent, 1: opaque).

    radiological : :obj:`bool`, default=False
        If True, display the sections in radiological convention: the
        sagittal slices are stored in reverse x order in the sprites and
        the viewer is told to use radiological orientation.

    show_lr : :obj:`bool`, default=True
        Whether the viewer displays the left/right orientation labels
        (forwarded to brainsprite as its ``showLR`` option).

    Returns
    -------
    html_view : the html viewer object.
        It can be saved as an html page `html_view.save_as_html('test.html')`,
        or opened in a browser `html_view.open_in_browser()`.
        If the output is not requested and the current environment is a Jupyter
        notebook, the viewer will be inserted in the notebook.

    See Also
    --------
    nilearn.plotting.plot_stat_map:
        static plot of brain volume, on a single or multiple planes.
    nilearn.plotting.view_connectome:
        interactive 3d view of a connectome.
    nilearn.plotting.view_markers:
        interactive plot of colored markers.
    nilearn.plotting.view_surf, nilearn.plotting.view_img_on_surf:
        interactive view of statistical maps or surface atlases on the cortical
        surface.

    """
    # Prepare the color map and thresholding
    mask_img, stat_map_img, data, threshold = _mask_stat_map(
        stat_map_img, threshold
    )
    colors = colorscale(
        cmap,
        data.ravel(),
        threshold=threshold,
        symmetric_cmap=symmetric_cmap,
        vmax=vmax,
        vmin=vmin,
    )

    # Prepare the data for the cuts
    bg_img, bg_min, bg_max, black_bg = _load_bg_img(
        stat_map_img, bg_img, black_bg, dim
    )
    stat_map_img, mask_img = _resample_stat_map(
        stat_map_img, bg_img, mask_img, resampling_interpolation
    )
    cut_slices = _get_cut_slices(stat_map_img, cut_coords, threshold)

    # Now create a json-like object for the viewer, and converts in html
    json_view = _json_view_data(
        bg_img,
        stat_map_img,
        mask_img,
        bg_min,
        bg_max,
        black_bg,
        colors,
        cmap,
        colorbar,
        radiological,
    )

    json_view["params"] = _json_view_params(
        stat_map_img.shape,
        stat_map_img.affine,
        colors["vmin"],
        colors["vmax"],
        cut_slices,
        black_bg,
        opacity,
        draw_cross,
        annotate,
        title,
        colorbar,
        value=False,
        radiological=radiological,
        show_lr=show_lr,
    )

    html_view = _json_view_to_html(json_view, width_view)

    return html_view