Coverage for nilearn/plotting/surface/_matplotlib_backend.py: 0%

312 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1"""Functions specific to "matplotlib" backend for surface visualization 

2functions in :obj:`~nilearn.plotting.surface.surf_plotting`. 

3 

4Any imports from "matplotlib" package, or "matplotlib" engine specific utility 

5functions in :obj:`~nilearn.plotting.surface` should be in this file. 

6""" 

7 

8import itertools 

9from warnings import warn 

10 

11import numpy as np 

12 

13from nilearn import DEFAULT_DIVERGING_CMAP 

14from nilearn._utils import compare_version 

15from nilearn._utils.logger import find_stack_level 

16from nilearn.image import get_data 

17from nilearn.plotting import cm 

18from nilearn.plotting._utils import ( 

19 get_cbar_ticks, 

20 get_colorbar_and_data_ranges, 

21 save_figure_if_needed, 

22) 

23from nilearn.plotting.cm import mix_colormaps 

24from nilearn.plotting.js_plotting_utils import to_color_strings 

25from nilearn.plotting.surface._utils import ( 

26 DEFAULT_HEMI, 

27 check_engine_params, 

28 check_surf_map, 

29 check_surface_plotting_inputs, 

30 get_faces_on_edge, 

31 sanitize_hemi_view, 

32) 

33from nilearn.surface import load_surf_data, load_surf_mesh 

34 

35try: 

36 import matplotlib.pyplot as plt 

37 from matplotlib import __version__ as mpl_version 

38 from matplotlib.cm import ScalarMappable 

39 from matplotlib.colorbar import make_axes 

40 from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba 

41 from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec 

42 from matplotlib.patches import Patch 

43 from mpl_toolkits.mplot3d.art3d import Poly3DCollection 

44except ImportError: 

45 from nilearn.plotting._utils import engine_warning 

46 

47 engine_warning("matplotlib") 

48 

# (elevation, azimuth) camera angles that are identical for every
# hemisphere setting.
_MATPLOTLIB_COMMON_VIEWS = {
    "dorsal": (90, 0),
    "ventral": (270, 0),
    "anterior": (0, 90),
    "posterior": (0, 270),
}

# hemi -> view name -> (elev, azim), as consumed by ``Axes3D.view_init``
# through ``_get_view_plot_surf``.
MATPLOTLIB_VIEWS = {
    "left": {
        "lateral": (0, 180),
        "medial": (0, 0),
        **_MATPLOTLIB_COMMON_VIEWS,
    },
    "right": {
        "lateral": (0, 0),
        "medial": (0, 180),
        **_MATPLOTLIB_COMMON_VIEWS,
    },
    "both": {
        "right": (0, 0),
        "left": (0, 180),
        **_MATPLOTLIB_COMMON_VIEWS,
    },
}

75 

76 

def _adjust_colorbar_and_data_ranges(
    stat_map, vmin=None, vmax=None, symmetric_cbar=None
):
    """Adjust colorbar and data ranges for the 'matplotlib' engine.

    This wrapper currently performs no engine-specific adjustment: all
    arguments are forwarded unchanged to
    :func:`nilearn.plotting._utils.get_colorbar_and_data_ranges`, which
    decides how ``None`` bounds and ``symmetric_cbar`` are interpreted.

    Parameters
    ----------
    stat_map : :obj:`str` or :class:`numpy.ndarray` or None
        Statistical map from which the ranges are derived.

    vmin : :obj:`float` or None, default=None
        Requested lower bound of the data range.

    vmax : :obj:`float` or None, default=None
        Requested upper bound of the data range.

    symmetric_cbar : :obj:`bool` or "auto" or None, default=None
        Forwarded as-is to ``get_colorbar_and_data_ranges``.

    Returns
    -------
    cbar_vmin, cbar_vmax, vmin, vmax
        The colorbar bounds followed by the data bounds, exactly as
        returned by ``get_colorbar_and_data_ranges``.
    """
    return get_colorbar_and_data_ranges(
        stat_map,
        vmin=vmin,
        vmax=vmax,
        symmetric_cbar=symmetric_cbar,
    )

102 

103 

def _adjust_plot_roi_params(params):
    """Fill in 'matplotlib' engine defaults for ROI plotting parameters.

    ``params`` is updated in place and nothing is returned:

    * a missing or ``None`` ``"avg_method"`` becomes ``"median"``;
    * a missing or ``"auto"`` ``"cbar_tick_format"`` becomes ``"%i"``.

    Parameters
    ----------
    params : dict
        Parameter dictionary to update.
    """
    needs_avg_method = params.get("avg_method") is None
    needs_tick_format = params.get("cbar_tick_format", "auto") == "auto"

    if needs_avg_method:
        params["avg_method"] = "median"
    if needs_tick_format:
        params["cbar_tick_format"] = "%i"

121 

122 

def _get_vertexcolor(
    surf_map,
    cmap,
    norm,
    absolute_threshold=None,
    bg_map=None,
    bg_on_data=None,
    darkness=None,
):
    """Get the color of the vertices.

    Parameters
    ----------
    surf_map : array-like
        Per-vertex values to color (assumed 1D, one value per vertex --
        TODO confirm against callers).
    cmap : callable
        Colormap applied to the normalized ``surf_map`` values.
    norm : callable
        Normalization applied to ``surf_map``; its result must expose
        ``.data`` (e.g. the masked array returned by ``Normalize``).
    absolute_threshold : float or None, default=None
        Vertices with ``abs(value) < absolute_threshold`` are made fully
        transparent so only the background shows. None keeps all vertices.
    bg_map : str or array-like or None, default=None
        Background map loaded with ``load_surf_data``. When None, a uniform
        0.5 background is used.
    bg_on_data : bool or None, default=None
        If truthy, kept vertices get alpha 0.7 so the background shows
        through the data.
    darkness : float or None, default=None
        Deprecated multiplier applied to the background intensities; a
        DeprecationWarning is emitted when it is not None.

    Returns
    -------
    The output of ``to_color_strings`` for the mixed per-vertex RGBA colors.
    """
    if bg_map is None:
        bg_data = np.ones(len(surf_map)) * 0.5
    else:
        bg_data = np.copy(load_surf_data(bg_map))
        # Work in floating point: integer backgrounds that are already in
        # [0, 1] are not normalized below, and a colormap treats integer
        # input as lookup-table indices; ``*= darkness`` would also fail
        # to cast back to an integer dtype.
        if not np.issubdtype(bg_data.dtype, np.floating):
            bg_data = bg_data.astype(float)

    # scale background map into [0, 1] if need be
    bg_vmin, bg_vmax = np.min(bg_data), np.max(bg_data)
    if bg_vmin < 0 or bg_vmax > 1:
        bg_norm = Normalize(vmin=bg_vmin, vmax=bg_vmax)
        bg_data = bg_norm(bg_data)

    if darkness is not None:
        bg_data *= darkness
        warn(
            (
                "The `darkness` parameter will be deprecated in release 0.13. "
                "We recommend setting `darkness` to None"
            ),
            DeprecationWarning,
            stacklevel=find_stack_level(),
        )

    bg_colors = plt.get_cmap("Greys")(bg_data)

    # select vertices which are filtered out by the threshold
    if absolute_threshold is None:
        under_threshold = np.zeros_like(surf_map, dtype=bool)
    else:
        under_threshold = np.abs(surf_map) < absolute_threshold

    surf_colors = cmap(norm(surf_map).data)
    # set transparency of voxels under threshold to 0
    surf_colors[under_threshold, 3] = 0
    if bg_on_data:
        # if need be, set transparency of voxels above threshold to 0.7
        # so that background map becomes visible
        surf_colors[~under_threshold, 3] = 0.7

    vertex_colors = mix_colormaps(surf_colors, bg_colors)

    return to_color_strings(vertex_colors)

175 

176 

def _colorbar_from_array(
    array,
    vmin,
    vmax,
    threshold,
    symmetric_cbar=True,
    cmap=DEFAULT_DIVERGING_CMAP,
):
    """Generate a custom colorbar mappable for an array.

    Internal function used by plot_img_on_surf.

    Parameters
    ----------
    array : :class:`numpy.ndarray`
        Data passed, together with ``vmin``, ``vmax`` and
        ``symmetric_cbar``, to ``get_colorbar_and_data_ranges`` to derive
        the data range of the colorbar.

    vmin : :obj:`float` or None
        Lower bound for plotting of stat_map values.

    vmax : :obj:`float` or None
        Upper bound for plotting of stat_map values.

    threshold : :obj:`float` or None
        If None is given, the colorbar is not thresholded.
        If a number is given, it is used to threshold the colorbar:
        colors of absolute values lower than threshold are shown in gray.

    symmetric_cbar : :obj:`bool` or "auto", default=True
        Forwarded to ``get_colorbar_and_data_ranges``.

    cmap : :class:`matplotlib.colors.Colormap` or :obj:`str`, default=DEFAULT_DIVERGING_CMAP
        Colormap object, or name of a colormap known to ``plt.get_cmap``.

    Returns
    -------
    :class:`matplotlib.cm.ScalarMappable`
        Mappable combining a gray-thresholded copy of ``cmap`` with a
        ``Normalize(vmin, vmax)`` norm, suitable for ``Figure.colorbar``.
    """
    _, _, vmin, vmax = get_colorbar_and_data_ranges(
        array,
        vmin=vmin,
        vmax=vmax,
        symmetric_cbar=symmetric_cbar,
    )
    # Accept colormap names (including the default) as well as Colormap
    # objects: the code below needs ``cmap.N`` and ``cmap(i)``.
    cmap = plt.get_cmap(cmap)
    norm = Normalize(vmin=vmin, vmax=vmax)
    cmaplist = [cmap(i) for i in range(cmap.N)]

    if threshold is None:
        threshold = 0.0

    # set colors to gray for absolute values < threshold
    istart = int(norm(-threshold, clip=True) * (cmap.N - 1))
    istop = int(norm(threshold, clip=True) * (cmap.N - 1))
    for i in range(istart, istop):
        cmaplist[i] = (0.5, 0.5, 0.5, 1.0)
    our_cmap = LinearSegmentedColormap.from_list(
        "Custom cmap", cmaplist, cmap.N
    )
    sm = ScalarMappable(cmap=our_cmap, norm=norm)

    # The mappable only drives a colorbar; give it an empty data array
    # through the public API instead of assigning the private ``_A``.
    sm.set_array([])

    return sm

236 

237 

def _compute_facecolors(bg_map, faces, n_vertices, darkness, alpha):
    """Compute per-face background RGBA colors for the matplotlib engine.

    Parameters
    ----------
    bg_map : str or array-like or None
        Per-vertex background values loaded with ``load_surf_data``. When
        None, every vertex gets the value 0.5.
    faces : numpy.ndarray
        Mesh faces used to index the per-vertex background values; the
        values of each face's vertices are averaged.
    n_vertices : int
        Expected number of vertices of the mesh.
    darkness : float or None
        Deprecated multiplier of the face intensities (emits a
        DeprecationWarning when not None).
    alpha : float or "auto"
        Factor applied to the alpha channel; "auto" means 0.5 without a
        background map and 1 otherwise.

    Returns
    -------
    numpy.ndarray
        One RGBA row per face, from the ``gray_r`` colormap.

    Raises
    ------
    ValueError
        If ``bg_map`` does not have ``n_vertices`` entries.
    """
    if bg_map is not None:
        vertex_values = np.copy(load_surf_data(bg_map))
        if vertex_values.shape[0] != n_vertices:
            raise ValueError(
                "The bg_map does not have the same number "
                "of vertices as the mesh."
            )
    else:
        vertex_values = np.ones(n_vertices) * 0.5

    # average the vertex values of every face
    face_values = np.mean(vertex_values[faces], axis=1)

    # bring the values into [0, 1] only when they fall outside that range
    lowest = np.min(face_values)
    highest = np.max(face_values)
    if lowest < 0 or highest > 1:
        face_values = Normalize(vmin=lowest, vmax=highest)(face_values)

    if darkness is not None:
        face_values *= darkness
        warn(
            (
                "The `darkness` parameter will be deprecated in release 0.13. "
                "We recommend setting `darkness` to None"
            ),
            DeprecationWarning,
            stacklevel=find_stack_level(),
        )

    rgba = plt.cm.gray_r(face_values)

    if alpha == "auto":
        alpha = 0.5 if bg_map is None else 1
    # scale the alpha channel of the background colors
    rgba[:, 3] = alpha * rgba[:, 3]

    return rgba

280 

281 

def _compute_surf_map_faces(
    surf_map, faces, avg_method, n_vertices, face_colors_size
):
    """Help for plot_surf.

    This function computes the surf map faces using the
    provided averaging method.

    .. note::
        This method is called exclusively when using matplotlib,
        since it only supports plotting face-colour maps and not
        vertex-colour maps.

    Parameters
    ----------
    surf_map : str or array-like
        Per-vertex map, validated by ``check_surf_map`` against
        ``n_vertices``.
    faces : numpy.ndarray
        Mesh faces used to gather the vertex values of each face.
    avg_method : str or callable
        Name of a numpy reduction (looked up with ``getattr(np, ...)`` and
        called with ``axis=1``), or a callable applied to each face's
        vertex values with ``np.apply_along_axis``.
    n_vertices : int
        Number of vertices of the mesh.
    face_colors_size : int
        Expected number of faces; only checked for callable ``avg_method``.

    Returns
    -------
    numpy.ndarray
        One value per face.

    Raises
    ------
    ValueError
        If ``avg_method`` is neither a numpy function name nor a callable,
        or if a callable ``avg_method`` produces an array with the wrong
        shape or a non-numeric dtype.
    """
    surf_map_data = check_surf_map(surf_map, n_vertices)

    # create face values from vertex values by selected avg methods
    error_message = (
        "avg_method should be either "
        "['mean', 'median', 'max', 'min'] "
        "or a custom function"
    )
    if isinstance(avg_method, str):
        try:
            avg_method = getattr(np, avg_method)
        except AttributeError:
            raise ValueError(error_message) from None
        return avg_method(surf_map_data[faces], axis=1)

    if not callable(avg_method):
        raise ValueError(error_message)

    surf_map_faces = np.apply_along_axis(avg_method, 1, surf_map_data[faces])

    # check that surf_map_faces has one value per face
    if surf_map_faces.shape != (face_colors_size,):
        raise ValueError(
            "Array computed with the custom function "
            "from avg_method does not have the correct shape: "
            f"{surf_map_faces.shape} != ({face_colors_size},)"
        )

    # check that dtype is either int or float
    dtype = surf_map_faces.dtype
    if not (
        np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.floating)
    ):
        raise ValueError(
            "Array computed with the custom function "
            "from avg_method should be an array of numbers "
            "(int or float)"
        )
    return surf_map_faces

336 

337 

def _get_bounds(data, vmin=None, vmax=None):
    """Return the ``(vmin, vmax)`` bounds to use for ``data``.

    Parameters
    ----------
    data : array-like
        Values whose NaN-ignoring minimum / maximum replace missing bounds.
    vmin, vmax : float or None, default=None
        Explicit bounds; None means "take it from ``data``".

    Returns
    -------
    tuple
        ``(vmin, vmax)``. A degenerate range (``vmin == vmax``) is widened
        by 1 on each side so that downstream rescaling by
        ``vmax - vmin`` never divides by zero. For an all-zero range this
        yields ``(-1, 1)``.
    """
    vmin = np.nanmin(data) if vmin is None else vmin
    vmax = np.nanmax(data) if vmax is None else vmax

    if vmin == vmax:
        # Avoid divide-by-zero warnings / NaN results downstream; this also
        # covers constant non-zero maps, not just all-zero ones.
        vmin, vmax = vmin - 1, vmax + 1

    return vmin, vmax

349 

350 

def _get_cmap(cmap, vmin, vmax, cbar_tick_format, threshold=None):
    """Build the colormap and norm used by plot_surf's matplotlib engine.

    The colormap is resampled into a ``LinearSegmentedColormap`` named
    "Custom cmap". When ``threshold`` is given, the entries covering
    ``-threshold .. threshold`` (under ``Normalize(vmin, vmax)``) are
    replaced with opaque gray. A warning is issued if ``threshold`` is not
    an integer while the colorbar is formatted with ``"%i"``.

    Returns
    -------
    tuple
        ``(colormap, norm)``.
    """
    base = plt.get_cmap(cmap)
    n_colors = base.N
    norm = Normalize(vmin=vmin, vmax=vmax)
    entries = [base(idx) for idx in range(n_colors)]

    if threshold is not None:
        integer_ticks = cbar_tick_format == "%i"
        if integer_ticks and int(threshold) != threshold:
            warn(
                "You provided a non integer threshold "
                "but configured the colorbar to use integer formatting.",
                stacklevel=find_stack_level(),
            )
        # gray out the entries of values whose magnitude is below threshold
        first = int(norm(-threshold, clip=True) * (n_colors - 1))
        last = int(norm(threshold, clip=True) * (n_colors - 1))
        gray = (0.5, 0.5, 0.5, 1.0)
        for idx in range(first, last):
            entries[idx] = gray

    resampled = LinearSegmentedColormap.from_list(
        "Custom cmap", entries, n_colors
    )
    return resampled, norm

375 

376 

def _get_ticks(vmin, vmax, cbar_tick_format, threshold):
    """Compute the tick values of the colorbar for plot_surf.

    With integer formatting (``"%i"``) and a range narrower than four
    units, one tick per integer step from ``vmin`` to ``vmax`` is returned.
    Otherwise five ticks are requested from ``get_cbar_ticks``.
    """
    default_tick_count = 5
    small_integer_range = (
        cbar_tick_format == "%i" and vmax - vmin < default_tick_count - 1
    )
    if small_integer_range:
        return np.arange(vmin, vmax + 1)
    return get_cbar_ticks(vmin, vmax, threshold, default_tick_count)

390 

391 

def _rescale(data, vmin=None, vmax=None):
    """Rescale a copy of ``data`` so that ``vmin`` maps to 0 and ``vmax`` to 1.

    Parameters
    ----------
    data : array-like
        Values to rescale; the input is not modified.
    vmin, vmax : float or None, default=None
        Bounds; missing ones are derived from the data by ``_get_bounds``.

    Returns
    -------
    tuple
        ``(rescaled_copy, vmin, vmax)`` with the bounds actually used.
    """
    data_copy = np.copy(data)
    # Integer (or boolean) input -- e.g. avg_method="max" on an integer map
    # or a custom avg_method returning ints -- cannot hold the fractional
    # result of the in-place operations below, so promote it to float.
    if not np.issubdtype(data_copy.dtype, np.inexact):
        data_copy = data_copy.astype(float)
    # if no vmin/vmax are passed figure them out from data
    vmin, vmax = _get_bounds(data_copy, vmin, vmax)
    data_copy -= vmin
    data_copy /= vmax - vmin
    return data_copy, vmin, vmax

400 

401 

def _threshold(data, threshold, vmin, vmax):
    """Return a boolean mask of the ``data`` entries to display.

    Without a threshold, only NaN entries are dropped. With a threshold,
    entries with ``abs(value) < threshold`` are dropped; in addition,
    values below ``vmin`` are dropped when ``vmin > -threshold``, and values
    above ``vmax`` are dropped when ``vmax < threshold``.
    """
    if threshold is None:
        # no thresholding: filter out NaNs only
        return np.logical_not(np.isnan(data))

    keep = np.abs(data) >= threshold
    below_range_matters = vmin > -threshold
    above_range_matters = vmax < threshold
    if below_range_matters:
        keep = np.logical_and(keep, data >= vmin)
    if above_range_matters:
        keep = np.logical_and(keep, data <= vmax)
    return keep

414 

415 

def _threshold_and_rescale(data, threshold, vmin, vmax):
    """Rescale ``data`` and compute its display mask for plot_surf.

    The mask is computed from the original (not rescaled) values, using
    the bounds chosen during rescaling.

    Returns
    -------
    tuple
        ``(rescaled_data, mask, vmin, vmax)``.
    """
    rescaled, lower, upper = _rescale(data, vmin, vmax)
    mask = _threshold(data, threshold, lower, upper)
    return rescaled, mask, lower, upper

423 

424 

def _check_figure_axes_inputs(figure, axes):
    """Validate optional matplotlib ``figure`` and ``axes`` arguments.

    Raises
    ------
    ValueError
        If ``figure`` is neither None nor a ``plt.Figure``, or if ``axes``
        is neither None nor a ``plt.Axes``.
    """
    if figure is not None:
        if not isinstance(figure, plt.Figure):
            raise ValueError(
                "figure argument should be None or a 'matplotlib.pyplot.Figure'."
            )
    if axes is not None:
        if not isinstance(axes, plt.Axes):
            raise ValueError(
                "axes argument should be None or a 'matplotlib.pyplot.Axes'."
            )

435 

436 

def _get_view_plot_surf(hemi, view):
    """Resolve ``hemi`` and ``view`` into matplotlib camera angles.

    ``view`` is first passed through ``sanitize_hemi_view``. A non-string
    result is returned unchanged (presumably an explicit ``(elev, azim)``
    pair -- TODO confirm). A string result is looked up in
    ``MATPLOTLIB_VIEWS[hemi]``.

    Raises
    ------
    ValueError
        If ``hemi`` is "both" and ``view`` is "lateral" or "medial".
    """
    view = sanitize_hemi_view(hemi, view)
    if not isinstance(view, str):
        return view

    if hemi == "both" and view in ("lateral", "medial"):
        raise ValueError(
            "Invalid view definition: when hemi is 'both', "
            "view cannot be 'lateral' or 'medial'.\n"
            "Maybe you meant 'left' or 'right'?"
        )
    return MATPLOTLIB_VIEWS[hemi][view]

451 

452 

def _plot_surf(
    surf_mesh,
    surf_map=None,
    bg_map=None,
    hemi=DEFAULT_HEMI,
    view=None,
    cmap=None,
    symmetric_cmap=None,
    colorbar=True,
    avg_method=None,
    threshold=None,
    alpha=None,
    bg_on_data=False,
    darkness=0.7,
    vmin=None,
    vmax=None,
    cbar_vmin=None,
    cbar_vmax=None,
    cbar_tick_format="auto",
    title=None,
    title_font_size=None,
    output_file=None,
    axes=None,
    figure=None,
):
    """Implement 'matplotlib' backend code for
    `~nilearn.plotting.surface.surf_plotting.plot_surf` function.

    Draws ``surf_mesh`` with ``plot_trisurf`` on a 3D axes (created if not
    given). When ``surf_map`` is given, each face is colored from the
    ``avg_method`` reduction of its vertex values, thresholded and mixed
    with per-face background colors derived from ``bg_map``, and an
    optional vertical colorbar is added.

    ``symmetric_cmap`` and ``title_font_size`` are not implemented for this
    engine; they are handed to ``check_engine_params``.

    Returns
    -------
    The result of ``save_figure_if_needed(figure, output_file)``.
    ``_plot_surf_contours`` uses this return value as the figure when no
    output file is requested.

    Raises
    ------
    AttributeError
        If the provided ``axes`` has no ``view_init`` (i.e. is not 3D).
    """
    parameters_not_implemented_in_matplotlib = {
        "symmetric_cmap": symmetric_cmap,
        "title_font_size": title_font_size,
    }

    check_engine_params(parameters_not_implemented_in_matplotlib, "matplotlib")

    # adjust values: resolve None / "auto" to this engine's defaults
    avg_method = "mean" if avg_method is None else avg_method
    alpha = "auto" if alpha is None else alpha
    cbar_tick_format = (
        "%.2g" if cbar_tick_format == "auto" else cbar_tick_format
    )
    # Leave space for colorbar
    figsize = [4.7, 5] if colorbar else [4, 5]

    coords, faces = load_surf_mesh(surf_mesh)

    # global coordinate range, used below for both the x and y limits
    limits = [coords.min(), coords.max()]

    # Get elevation and azimut from view
    elev, azim = _get_view_plot_surf(hemi, view)

    # if no cmap is given, set to matplotlib default
    if cmap is None:
        cmap = plt.get_cmap(plt.rcParamsDefault["image.cmap"])
    # if cmap is given as string, translate to matplotlib cmap
    elif isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    # initiate figure and 3d axes
    if axes is None:
        if figure is None:
            figure = plt.figure(figsize=figsize)
        axes = figure.add_axes((0, 0, 1, 1), projection="3d")
    elif figure is None:
        figure = axes.get_figure()
    axes.set_xlim(*limits)
    axes.set_ylim(*limits)

    # view_init only exists on 3D axes: replace the bare AttributeError
    # with a message explaining how to create a 3D axes.
    try:
        axes.view_init(elev=elev, azim=azim)
    except AttributeError:
        raise AttributeError(
            "'Axes' object has no attribute 'view_init'.\n"
            "Remember that the projection must be '3d'.\n"
            "For example:\n"
            "\t plt.subplots(subplot_kw={'projection': '3d'})"
        )
    except Exception as e:  # pragma: no cover
        raise e

    axes.set_axis_off()

    # plot mesh without data
    p3dcollec = axes.plot_trisurf(
        coords[:, 0],
        coords[:, 1],
        coords[:, 2],
        triangles=faces,
        linewidth=0.1,
        antialiased=False,
        color="white",
    )

    # reduce viewing distance to remove space around mesh
    axes.set_box_aspect(None, zoom=1.3)

    # per-face background RGBA colors
    bg_face_colors = _compute_facecolors(
        bg_map, faces, coords.shape[0], darkness, alpha
    )
    # NOTE(review): face colors are only applied inside this branch, so
    # without surf_map the mesh keeps the colors from plot_trisurf and
    # bg_face_colors is unused -- confirm this is intended.
    if surf_map is not None:
        # one value per face, reduced with avg_method
        surf_map_faces = _compute_surf_map_faces(
            surf_map,
            faces,
            avg_method,
            coords.shape[0],
            bg_face_colors.shape[0],
        )
        # rescale to [0, 1] for the colormap; kept_indices masks the faces
        # that survive the threshold; vmin / vmax are the bounds used
        surf_map_faces, kept_indices, vmin, vmax = _threshold_and_rescale(
            surf_map_faces, threshold, vmin, vmax
        )

        surf_map_face_colors = cmap(surf_map_faces)
        # set transparency of voxels under threshold to 0
        surf_map_face_colors[~kept_indices, 3] = 0
        if bg_on_data:
            # if need be, set transparency of voxels above threshold to 0.7
            # so that background map becomes visible
            surf_map_face_colors[kept_indices, 3] = 0.7

        face_colors = mix_colormaps(surf_map_face_colors, bg_face_colors)

        if colorbar:
            # colorbar range defaults to the data bounds used above
            cbar_vmin = cbar_vmin if cbar_vmin is not None else vmin
            cbar_vmax = cbar_vmax if cbar_vmax is not None else vmax
            ticks = _get_ticks(
                cbar_vmin, cbar_vmax, cbar_tick_format, threshold
            )
            # colormap with the sub-threshold band grayed out
            our_cmap, norm = _get_cmap(
                cmap, vmin, vmax, cbar_tick_format, threshold
            )
            bounds = np.linspace(cbar_vmin, cbar_vmax, our_cmap.N)

            # we need to create a proxy mappable
            proxy_mappable = ScalarMappable(cmap=our_cmap, norm=norm)
            proxy_mappable.set_array(surf_map_faces)
            # the colorbar axes and colorbar are stored as private figure
            # attributes, presumably so callers can reach them later --
            # TODO confirm
            figure._colorbar_ax, _ = make_axes(
                axes,
                location="right",
                fraction=0.15,
                shrink=0.5,
                pad=0.0,
                aspect=10.0,
            )
            figure._cbar = figure.colorbar(
                proxy_mappable,
                cax=figure._colorbar_ax,
                ticks=ticks,
                boundaries=bounds,
                spacing="proportional",
                format=cbar_tick_format,
                orientation="vertical",
            )

        # fix floating point bug causing highest to sometimes surpass 1
        # (for example 1.0000000000000002)
        face_colors[face_colors > 1] = 1

        p3dcollec.set_facecolors(face_colors)
        p3dcollec.set_edgecolors(face_colors)

    if title is not None:
        axes.set_title(title)

    return save_figure_if_needed(figure, output_file)

617 

618 

def _plot_surf_contours(
    surf_mesh=None,
    roi_map=None,
    hemi=DEFAULT_HEMI,
    levels=None,
    labels=None,
    colors=None,
    legend=False,
    cmap="tab20",
    title=None,
    output_file=None,
    axes=None,
    figure=None,
    **kwargs,
):
    """Implement 'matplotlib' backend code for
    `~nilearn.plotting.surface.surf_plotting.plot_surf_contours` function.

    For every ROI level, the faces returned by ``get_faces_on_edge``
    (presumably the faces bordering the ROI) are recolored on the first
    collection of a 3D axes. The surface is drawn with ``_plot_surf``
    first when no figure / axes is given or when the axes holds no
    ``Poly3DCollection`` yet.

    Parameters
    ----------
    surf_mesh :
        Mesh loaded with ``load_surf_mesh``.
    roi_map :
        ROI data loaded with ``load_surf_data`` (one value per vertex).
    levels : sequence or None, default=None
        ROI values to outline; defaults to the unique values of the loaded
        ROI data.
    labels : sequence or None, default=None
        Legend label per level; defaults to no labels.
    colors : sequence or None, default=None
        Color per level; defaults to evenly spaced colors of ``cmap``.
    legend : bool, default=False
        Draw a figure legend for the labeled levels.
    kwargs :
        Forwarded to ``_plot_surf`` when the surface has to be drawn.

    Returns
    -------
    The result of ``save_figure_if_needed(figure, output_file)``.

    Raises
    ------
    ValueError
        If figure / axes have the wrong type, the axes are not 3D, an entry
        of ``colors`` is not a valid matplotlib color, or ``levels``,
        ``labels`` and ``colors`` have different lengths.
    """
    _check_figure_axes_inputs(figure, axes)

    if figure is None and axes is None:
        figure = _plot_surf(surf_mesh, hemi=hemi, **kwargs)
        axes = figure.axes[0]
    elif figure is None:
        figure = axes.get_figure()
    elif axes is None:
        axes = figure.axes[0]

    if axes.name != "3d":
        raise ValueError("Axes must be 3D.")

    # test if axes contains Poly3DCollection, if not initialize surface
    if not axes.collections or not isinstance(
        axes.collections[0], Poly3DCollection
    ):
        _ = _plot_surf(surf_mesh, hemi=hemi, axes=axes, **kwargs)

    # Load the ROI data before deriving default levels, so that they come
    # from the actual values even when roi_map is given as a file path.
    roi = load_surf_data(roi_map)

    if levels is None:
        levels = np.unique(roi)

    if labels is None:
        labels = [None] * len(levels)

    if colors is None:
        n_levels = len(levels)
        cmap = plt.get_cmap(cmap)
        norm = Normalize(vmin=0, vmax=n_levels)
        colors = [cmap(norm(color_i)) for color_i in range(n_levels)]
    else:
        try:
            colors = [to_rgba(color, alpha=1.0) for color in colors]
        except ValueError as err:
            raise ValueError(
                "All elements of colors need to be either a"
                " matplotlib color string or RGBA values."
            ) from err

    if not (len(levels) == len(labels) == len(colors)):
        raise ValueError(
            "Levels, labels, and colors "
            "argument need to be either the same length or None."
        )

    _, faces = load_surf_mesh(surf_mesh)

    collection = axes.collections[0]
    # Matplotlib 3.3.3 renamed the private attributes _facecolors3d /
    # _edgecolors3d to _facecolor3d / _edgecolor3d.
    if compare_version(mpl_version, "<", "3.3.3"):
        facecolor_attr, edgecolor_attr = "_facecolors3d", "_edgecolors3d"
    else:
        facecolor_attr, edgecolor_attr = "_facecolor3d", "_edgecolor3d"

    patch_list = []
    for level, color, label in zip(levels, colors, labels):
        roi_indices = np.where(roi == level)[0]
        faces_outside = get_faces_on_edge(faces, roi_indices)
        getattr(collection, facecolor_attr)[faces_outside] = color
        # set_edgecolor rebinds the 3D edge color array, so it is looked
        # up again (not cached) after this call
        if getattr(collection, edgecolor_attr).size == 0:
            collection.set_edgecolor(getattr(collection, facecolor_attr))
        getattr(collection, edgecolor_attr)[faces_outside] = color
        if label and legend:
            patch_list.append(Patch(color=color, label=label))

    # plot legend only if indicated and labels provided
    if legend and any(lbl is not None for lbl in labels):
        figure.legend(handles=patch_list)

    # without an explicit title, reuse the figure's existing suptitle text
    # as the axes title
    if title is None and hasattr(figure._suptitle, "_text"):
        title = figure._suptitle._text
    if title:
        axes.set_title(title)

    return save_figure_if_needed(figure, output_file)

719 

720 

def _plot_img_on_surf(
    surf,
    surf_mesh,
    stat_map,
    texture,
    hemis,
    modes,
    bg_on_data=False,
    inflate=False,
    output_file=None,
    title=None,
    colorbar=True,
    vmin=None,
    vmax=None,
    threshold=None,
    symmetric_cbar=None,
    cmap=DEFAULT_DIVERGING_CMAP,
    cbar_tick_format=None,
    **kwargs,
):
    """Implement 'matplotlib' backend code for
    `~nilearn.plotting.surface.surf_plotting.plot_img_on_surf` function.

    Builds a GridSpec with a title row on top, ``len(modes)`` rows by
    ``len(hemis)`` columns of 3D surface panels, and a bottom row holding
    an optional horizontal colorbar.

    Parameters (as used in this function)
    ----------
    surf : mapping
        Indexed by hemisphere name to get the mesh of each panel.
    surf_mesh : mapping
        Indexed with ``f"curv_{hemi}"`` (when ``inflate``) or
        ``f"sulc_{hemi}"`` to get the background map of each panel.
    stat_map :
        Image whose data (via ``get_data``) sets the colorbar range.
    texture : mapping
        Indexed by hemisphere name to get the map plotted in each panel.
    hemis, modes : sequences
        Hemispheres (grid columns) and views (grid rows).
    symmetric_cbar : bool, "auto" or None, default=None
        None is treated as "auto".
    cbar_tick_format : str or None, default=None
        None is treated as "%i".
    kwargs :
        ``symmetric_cmap`` is popped (default True); everything else is
        forwarded to ``_plot_surf``.

    Returns
    -------
    ``(fig, axes)`` when ``output_file`` is None, where ``axes`` lists the
    3D axes followed, if drawn, by the colorbar axes. Otherwise the figure
    is saved to ``output_file``, closed, and None is returned.
    """
    if symmetric_cbar is None:
        symmetric_cbar = "auto"
    if cbar_tick_format is None:
        cbar_tick_format = "%i"
    symmetric_cmap = kwargs.pop("symmetric_cmap", True)

    # relative heights of the colorbar row and (optional) title row
    cbar_h = 0.25
    title_h = 0.25 * (title is not None)
    w, h = plt.figaspect((len(modes) + cbar_h + title_h) / len(hemis))
    fig = plt.figure(figsize=(w, h), constrained_layout=False)
    height_ratios = [title_h] + [1.0] * len(modes) + [cbar_h]
    grid = GridSpec(
        len(modes) + 2,
        len(hemis),
        left=0.0,
        right=1.0,
        bottom=0.0,
        top=1.0,
        height_ratios=height_ratios,
        hspace=0.0,
        wspace=0.0,
    )
    axes = []

    # one panel per (mode, hemi) pair, filled row by row
    for i, (mode, hemi) in enumerate(itertools.product(modes, hemis)):
        bg_map = None
        # By default, add curv sign background map if mesh is inflated,
        # sulc depth background map otherwise
        if inflate:
            curv_map = load_surf_data(surf_mesh[f"curv_{hemi}"])
            # map curvature sign {-1, 0, 1} to {0.25, 0.5, 0.75}
            curv_sign_map = (np.sign(curv_map) + 1) / 4 + 0.25
            bg_map = curv_sign_map
        else:
            sulc_map = surf_mesh[f"sulc_{hemi}"]
            bg_map = sulc_map

        # offset by len(hemis) to skip the title row of the grid
        ax = fig.add_subplot(grid[i + len(hemis)], projection="3d")
        axes.append(ax)

        # Starting from this line until _plot_surf included is actually
        # plot_surf_stat_map, but to avoid cyclic import
        # the code is duplicated here
        stat_map_iter = texture[hemi]
        surf_mesh_iter = surf[hemi]

        stat_map_iter, surf_mesh_iter, bg_map = check_surface_plotting_inputs(
            stat_map_iter,
            surf_mesh_iter,
            hemi,
            bg_map,
            map_var_name="img_on_surf",
        )
        loaded_stat_map = load_surf_data(stat_map_iter)

        # derive symmetric vmin, vmax and colorbar limits depending on
        # symmetric_cbar settings.
        # NOTE(review): vmin / vmax are rebound here, so later panels pass
        # the range derived for an earlier panel back in; cbar_vmin and
        # cbar_vmax are not used afterwards.
        cbar_vmin, cbar_vmax, vmin, vmax = _adjust_colorbar_and_data_ranges(
            loaded_stat_map,
            vmin=vmin,
            vmax=vmax,
            symmetric_cbar=symmetric_cbar,
        )
        _plot_surf(
            surf_mesh=surf_mesh_iter,
            surf_map=loaded_stat_map,
            bg_map=bg_map,
            hemi=hemi,
            view=mode,
            cmap=cmap,
            symmetric_cmap=symmetric_cmap,
            colorbar=False,  # Colorbar created externally.
            threshold=threshold,
            bg_on_data=bg_on_data,
            vmin=vmin,
            vmax=vmax,
            axes=ax,
            **kwargs,
        )

        # We increase this value to better position the camera of the
        # 3D projection plot. The default value makes meshes look too
        # small.
        ax.set_box_aspect(None, zoom=1.3)

    if colorbar:
        # thresholded mappable built from the full image data and the
        # vmin / vmax left by the loop above
        sm = _colorbar_from_array(
            get_data(stat_map),
            vmin,
            vmax,
            threshold,
            symmetric_cbar=symmetric_cbar,
            cmap=plt.get_cmap(cmap),
        )

        # colorbar occupies the middle cell of a 3x3 split of the last row
        cbar_grid = GridSpecFromSubplotSpec(3, 3, grid[-1, :])
        cbar_ax = fig.add_subplot(cbar_grid[1])
        axes.append(cbar_ax)
        # Get custom ticks to set in colorbar
        ticks = _get_ticks(vmin, vmax, cbar_tick_format, threshold)
        fig.colorbar(
            sm,
            cax=cbar_ax,
            orientation="horizontal",
            ticks=ticks,
            format=cbar_tick_format,
        )

    if title is not None:
        # anchor the title on the lower edge of the title row
        fig.suptitle(title, y=1.0 - title_h / sum(height_ratios), va="bottom")

    if output_file is None:
        return fig, axes
    fig.savefig(output_file, bbox_inches="tight")
    plt.close(fig)