Coverage for nilearn/plotting/displays/_axes.py: 0%

202 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1import numbers 

2import warnings 

3 

4import matplotlib.pyplot as plt 

5import numpy as np 

6from matplotlib.colors import Normalize 

7from matplotlib.font_manager import FontProperties 

8from matplotlib.lines import Line2D 

9from matplotlib.patches import FancyArrow 

10from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar 

11 

12from nilearn._utils import fill_doc 

13from nilearn._utils.logger import find_stack_level 

14from nilearn.image import coord_transform 

15from nilearn.plotting.displays._utils import coords_3d_to_2d 

16from nilearn.plotting.glass_brain import plot_brain_schematics 

17 

18 

@fill_doc
class BaseAxes:
    """An MPL axis-like object that displays a 2D view of 3D volumes.

    Parameters
    ----------
    %(ax)s
    direction : {'x', 'y', 'z', 'l', 'r'}
        The directions of the view.

    coord : :obj:`float`
        The coordinate along the direction of the cut.
    %(radiological)s
    """

    def __init__(self, ax, direction, coord, radiological=False):
        self.ax = ax
        self.direction = direction
        self.coord = coord
        # One (xmin, xmax, ymin, ymax) entry per object registered through
        # ``add_object_bounds``.
        self._object_bounds = []
        # Set by ``draw_2d`` to the shape of the transposed 2D data.
        self.shape = None
        self.radiological = radiological

    def transform_to_2d(self, data, affine):
        """Transform a 3D volume into a 2D array.

        Must be implemented by derived classes.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError(
            "'transform_to_2d' needs to be implemented in derived classes"
        )

    def add_object_bounds(self, bounds):
        """Record the bounds of a new object and rescale the axes if needed.

        Parameters
        ----------
        bounds : sequence of 4 :obj:`float`
            The (xmin, xmax, ymin, ymax) bounds of the object.
        """
        old_object_bounds = self.get_object_bounds()
        self._object_bounds.append(bounds)
        new_object_bounds = self.get_object_bounds()

        # Only touch the axis limits when the enclosing box actually changed.
        if new_object_bounds != old_object_bounds:
            self.ax.axis(new_object_bounds)

    def draw_2d(
        self,
        data_2d,
        data_bounds,
        bounding_box,
        type="imshow",
        transparency=None,
        **kwargs,
    ):
        """Draw a 2D array on the axes.

        Parameters
        ----------
        data_2d : 2D :class:`numpy.ndarray`
            The data to draw. A copy is plotted so later changes to the
            array do not affect the figure.

        data_bounds : sequence of three (min, max) pairs
            Extent of the data along each of the three volume axes; the two
            pairs matching ``self.direction`` become the image extent.

        bounding_box : sequence of three (min, max) pairs
            Bounds, in the same layout as ``data_bounds``, registered with
            ``add_object_bounds`` to rescale the axes.

        type : :obj:`str`, default="imshow"
            Name of the method of ``self.ax`` used for drawing.

        transparency : :obj:`float` or None, optional
            Passed to the drawing method as ``alpha``. Any ``alpha`` given in
            ``kwargs`` is overridden (with a warning).

        **kwargs
            Extra keyword arguments for the drawing method. ``origin`` is
            always forced to ``"upper"``.

        Returns
        -------
        The object returned by the drawing method.

        Raises
        ------
        ValueError
            If ``self.direction`` is not a supported direction.
        """
        kwargs["origin"] = "upper"

        if "alpha" in kwargs:
            warnings.warn(
                f"{kwargs['alpha']=} detected in parameters.\n"
                f"Overriding with {transparency=}.\n"
                "To suppress this warning pass "
                "your 'alpha' value "
                "via the 'transparency' parameter.",
                stacklevel=find_stack_level(),
            )
        kwargs["alpha"] = transparency

        # Pick the two (min, max) pairs that map onto the horizontal and
        # vertical axes of the plot for this view direction.
        if self.direction == "y":
            (xmin, xmax), (_, _), (zmin, zmax) = data_bounds
            (xmin_, xmax_), (_, _), (zmin_, zmax_) = bounding_box
        elif self.direction in "xlr":
            (_, _), (xmin, xmax), (zmin, zmax) = data_bounds
            (_, _), (xmin_, xmax_), (zmin_, zmax_) = bounding_box
        elif self.direction == "z":
            (xmin, xmax), (zmin, zmax), (_, _) = data_bounds
            (xmin_, xmax_), (zmin_, zmax_), (_, _) = bounding_box
        else:
            raise ValueError(f"Invalid value for direction {self.direction}")
        ax = self.ax
        # Here we need to do a copy to avoid having the image changing as
        # we change the data
        im = getattr(ax, type)(
            data_2d.copy(), extent=(xmin, xmax, zmin, zmax), **kwargs
        )

        self.add_object_bounds((xmin_, xmax_, zmin_, zmax_))
        self.shape = data_2d.T.shape
        # The bounds of the object do not take into account a possible
        # inversion of the axis. As such, we check that the axis is properly
        # inverted when direction is left
        if self.direction == "l" and not (ax.get_xlim()[0] > ax.get_xlim()[1]):
            ax.invert_xaxis()
        return im

    def get_object_bounds(self):
        """Return the bounds of the objects on this axes.

        Returns
        -------
        :obj:`tuple` of 4 numbers
            (xmin, xmax, ymin, ymax) enclosing every recorded object, or a
            small box around the origin when nothing was recorded yet.
        """
        if not self._object_bounds:
            # Nothing plotted yet
            return -0.01, 0.01, -0.01, 0.01
        xmins, xmaxs, ymins, ymaxs = np.array(self._object_bounds).T
        # Take extrema over both columns so bounds recorded in decreasing
        # order (min > max) are still enclosed.
        xmax = max(xmaxs.max(), xmins.max())
        xmin = min(xmins.min(), xmaxs.min())
        ymax = max(ymaxs.max(), ymins.max())
        ymin = min(ymins.min(), ymaxs.min())

        return xmin, xmax, ymin, ymax

    def draw_left_right(self, size, bg_color, **kwargs):
        """Draw the annotation "L" for left, and "R" for right.

        Nothing is drawn for sagittal-like views ('x', 'l', 'r'). In
        radiological convention the two letters are swapped.

        Parameters
        ----------
        size : :obj:`float`
            Size of the text areas.

        bg_color : matplotlib color: :obj:`str` or (r, g, b) value
            The background color for both text areas.

        **kwargs
            Extra keyword arguments passed to :meth:`matplotlib.axes.Axes.text`.
        """
        if self.direction in "xlr":
            return
        ax = self.ax
        annotation_on_left = "L"
        annotation_on_right = "R"
        if self.radiological:
            annotation_on_left = "R"
            annotation_on_right = "L"

        # Both labels share the same opaque box so they render identically.
        bbox = {
            "boxstyle": "square,pad=0",
            "ec": bg_color,
            "fc": bg_color,
            "alpha": 1,
        }
        ax.text(
            0.1,
            0.95,
            annotation_on_left,
            transform=ax.transAxes,
            horizontalalignment="left",
            verticalalignment="top",
            size=size,
            bbox=dict(bbox),
            **kwargs,
        )

        ax.text(
            0.9,
            0.95,
            annotation_on_right,
            transform=ax.transAxes,
            horizontalalignment="right",
            verticalalignment="top",
            size=size,
            bbox=dict(bbox),
            **kwargs,
        )

    def draw_scale_bar(
        self,
        bg_color,
        size=5.0,
        units="cm",
        fontproperties=None,
        frameon=False,
        loc=4,
        pad=0.1,
        borderpad=0.5,
        sep=5,
        size_vertical=0,
        label_top=False,
        color="black",
        fontsize=None,
        **kwargs,
    ):
        """Add a scale bar annotation to the display.

        Parameters
        ----------
        bg_color : matplotlib color: :obj:`str` or (r, g, b) value
            The background color of the scale bar annotation.

        size : :obj:`float`, default=5.0
            Horizontal length of the scale bar, given in `units`.

        units : :obj:`str`, default='cm'
            Physical units of the scale bar (`'cm'` or `'mm'`). Only `'cm'`
            triggers a conversion (x10) of the bar length; any other value
            is treated as millimetres.

        fontproperties : :class:`~matplotlib.font_manager.FontProperties` or :obj:`dict`, optional
            Font properties for the label text. A dict is converted to a
            FontProperties; a FontProperties is copied, so the caller's
            object is never modified.

        frameon : :obj:`bool`, default=False
            Whether the scale bar is plotted with a border.

        loc : :obj:`int`, default=4
            Location of this scale bar.
            Valid location codes are documented in
            :class:`~mpl_toolkits.axes_grid1.anchored_artists.AnchoredSizeBar`

        pad : :obj:`int` or :obj:`float`, default=0.1
            Padding around the label and scale bar, in fraction of the font
            size.

        borderpad : :obj:`int` or :obj:`float`, default=0.5
            Border padding, in fraction of the font size.

        sep : :obj:`int` or :obj:`float`, default=5
            Separation between the label and the scale bar, in points.

        size_vertical : :obj:`int` or :obj:`float`, default=0
            Vertical length of the size bar, given in `units`.

        label_top : :obj:`bool`, default=False
            If ``True``, the label will be over the scale bar.

        color : :obj:`str`, default='black'
            Color for the scale bar and label.

        fontsize : :obj:`int`, optional
            Label font size (overwrites the size passed in through the
            ``fontproperties`` argument).

        **kwargs :
            Keyworded arguments to pass to
            :class:`~matplotlib.offsetbox.AnchoredOffsetbox`.

        """
        axis = self.ax

        # Work on a private FontProperties: a dict has no ``set_size`` and
        # a caller-supplied object must not be mutated by ``fontsize``.
        if fontproperties is None:
            fontproperties = FontProperties()
        elif isinstance(fontproperties, dict):
            fontproperties = FontProperties(**fontproperties)
        elif isinstance(fontproperties, FontProperties):
            fontproperties = fontproperties.copy()
        if fontsize:
            fontproperties.set_size(fontsize)

        # The bar length is given to ``transData``; data coordinates are
        # presumably millimetres, hence the cm -> mm conversion.
        width_mm = size
        if units == "cm":
            width_mm *= 10

        anchor_size_bar = AnchoredSizeBar(
            axis.transData,
            width_mm,
            f"{size:g}{units}",
            fontproperties=fontproperties,
            frameon=frameon,
            loc=loc,
            pad=pad,
            borderpad=borderpad,
            sep=sep,
            size_vertical=size_vertical,
            label_top=label_top,
            color=color,
            **kwargs,
        )

        if frameon:
            anchor_size_bar.patch.set_facecolor(bg_color)
            anchor_size_bar.patch.set_edgecolor("none")
        axis.add_artist(anchor_size_bar)

    def draw_position(self, size, bg_color, **kwargs):
        """Draw the cut position; must be implemented in derived classes.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError(
            "'draw_position' should be implemented in derived classes"
        )

280 

281 

@fill_doc
class CutAxes(BaseAxes):
    """An MPL axis-like object that displays a cut of 3D volumes.

    Parameters
    ----------
    %(ax)s
    direction : {'x', 'y', 'z'}
        The directions of the view.

    coord : :obj:`float`
        The coordinate along the direction of the cut.
    %(radiological)s
    """

    def transform_to_2d(self, data, affine):
        """Cut the 3D volume into a 2D slice.

        Parameters
        ----------
        data : 3D :class:`~numpy.ndarray`
            The 3D volume to cut.

        affine : 4x4 :class:`~numpy.ndarray`
            The affine of the volume.

        Returns
        -------
        2D :class:`~numpy.ndarray`
            The slice of ``data`` at ``self.coord`` along ``self.direction``,
            rotated by 90 degrees with :func:`numpy.rot90`.

        Raises
        ------
        ValueError
            If ``self.direction`` is not one of 'x', 'y' or 'z'.

        IndexError
            If ``self.coord`` maps to a voxel index outside ``data``.
        """
        if self.direction not in ["x", "y", "z"]:
            raise ValueError(f"Invalid value for direction {self.direction}")
        axis = "xyz".index(self.direction)

        # Only the coordinate along the cut direction matters; the other two
        # are set to 0 and their voxel indices are ignored.
        coords = [0, 0, 0]
        coords[axis] = self.coord
        voxel_indices = [
            int(np.round(c))
            for c in coord_transform(
                coords[0], coords[1], coords[2], np.linalg.inv(affine)
            )
        ]
        index = voxel_indices[axis]

        # A negative index would silently wrap around to the opposite edge
        # of the volume; report it as out of bounds, like a too-large one.
        if index < 0 or index >= data.shape[axis]:
            raise IndexError(
                f"Cut coordinate {self.coord} along '{self.direction}' "
                f"maps to voxel index {index}, outside of the data "
                f"(size {data.shape[axis]})."
            )

        if self.direction == "y":
            cut = np.rot90(data[:, index, :])
        elif self.direction == "x":
            cut = np.rot90(data[index, :, :])
        else:  # "z"
            cut = np.rot90(data[:, :, index])
        return cut

    def draw_position(self, size, bg_color, decimals=False, **kwargs):
        """Draw the cut coordinate in the lower-left corner of the axes.

        Parameters
        ----------
        size : :obj:`float`
            Size of the text area.

        bg_color : matplotlib color: :obj:`str` or (r, g, b) value
            The background color for text area.

        decimals : :obj:`int`, :obj:`str` or :obj:`bool`, default=False
            Number of decimals shown for the coordinate (used as the
            precision of a fixed-point format). If falsy (``False`` or
            ``0``), integer formatting is used.

        **kwargs
            Extra keyword arguments passed to :meth:`matplotlib.axes.Axes.text`.
        """
        if decimals:
            text = f"%s=%.{decimals}f"
            coord = float(self.coord)
        else:
            text = "%s=%i"
            coord = self.coord
        ax = self.ax
        ax.text(
            0,
            0,
            text % (self.direction, coord),
            transform=ax.transAxes,
            horizontalalignment="left",
            verticalalignment="bottom",
            size=size,
            bbox={
                "boxstyle": "square,pad=0",
                "ec": bg_color,
                "fc": bg_color,
                "alpha": 1,
            },
            **kwargs,
        )

366 

367 

@fill_doc
class GlassBrainAxes(BaseAxes):
    """An MPL axis-like object that displays a 2D projection of 3D volumes
    with a schematic view of the brain.

    Parameters
    ----------
    %(ax)s
    direction : {'x', 'y', 'z', 'l', 'r'}
        The directions of the view.

    coord : :obj:`float`
        The coordinate along the direction of the cut.

    plot_abs : :obj:`bool`, default=True
        If set to ``True`` the absolute value of the data will be considered.
    %(radiological)s
    kwargs : :obj:`dict`
        Extra keyword arguments passed to ``plot_brain_schematics``.

    """

    def __init__(
        self, ax, direction, coord, plot_abs=True, radiological=False, **kwargs
    ):
        super().__init__(ax, direction, coord, radiological=radiological)
        self._plot_abs = plot_abs
        if ax is not None:
            object_bounds = plot_brain_schematics(ax, direction, **kwargs)
            self.add_object_bounds(object_bounds)

    def transform_to_2d(self, data, affine):
        """Project the 3D volume onto 2D by taking its maximum along an axis.

        For 'l' and 'r' directions only the corresponding half of the volume
        (split at the voxel x index of the world origin) is projected.

        Parameters
        ----------
        data : 3D :class:`numpy.ndarray`
            The 3D volume.

        affine : 4x4 :class:`numpy.ndarray`
            The affine of the volume.

        Returns
        -------
        2D :class:`numpy.ndarray`
            The projection, rotated by 90 degrees. With ``plot_abs=True`` it
            holds maximum absolute values; otherwise the signed values found
            where the absolute value is maximal.
        """
        max_axis = (
            0 if self.direction in "xlr" else ".yz".index(self.direction)
        )

        # Set unselected brain hemisphere activations aside.
        data_selection = data
        if self.direction in ("l", "r"):
            # Voxel x index of the world origin, used as the hemisphere split.
            x_center, _, _, _ = np.dot(
                np.linalg.inv(affine), np.array([0, 0, 0, 1])
            )
            if self.direction == "l":
                data_selection = data[: int(x_center), :, :]
            else:
                data_selection = data[int(x_center) :, :, :]

        # We need to make sure data_selection is not empty in the x axis
        # This should be the case since we expect images in MNI space
        if data_selection.shape[0] == 0:
            data_selection = data

        if self._plot_abs:
            maximum_intensity_data = np.abs(data_selection).max(axis=max_axis)
        else:
            # Shape of the array we are projecting to.
            new_shape = list(data_selection.shape)
            del new_shape[max_axis]

            # 3D fancy index pointing, for every projected pixel, at the
            # voxel with the largest absolute value along ``max_axis``.
            a1, a2 = np.indices(new_shape)
            inds = [a1, a2]
            inds.insert(max_axis, np.abs(data_selection).argmax(axis=max_axis))

            # Keep the signed values at those positions.
            maximum_intensity_data = data_selection[tuple(inds)]

        # This work around can be removed bumping matplotlib > 2.1.0. See #1815
        # in nilearn for the invention of this work around
        if (
            self.direction == "l"
            and data_selection.min() is np.ma.masked
            and not (self.ax.get_xlim()[0] > self.ax.get_xlim()[1])
        ):
            self.ax.invert_xaxis()

        return np.rot90(maximum_intensity_data)

    def draw_position(self, size, bg_color, **kwargs):
        """Do nothing.

        Drawing the cut position makes no sense here since the view is a
        maximum projection along one axis, not a cut.
        """
        pass

    def _add_markers(self, marker_coords, marker_color, marker_size, **kwargs):
        """Plot markers.

        In the case of 'l' and 'r' directions (for hemispheric projections),
        markers in the coordinate x == 0 are included in both hemispheres.

        Parameters
        ----------
        marker_coords : :class:`numpy.ndarray` of shape (n, 3)
            3D coordinates of the markers.

        marker_color : color or sequence of colors
            A single color string is applied to every marker; otherwise one
            color per marker (or a single-element sequence).

        marker_size : number or sequence of numbers
            Marker size(s), passed to ``scatter`` as ``s``.

        kwargs : :obj:`dict`
            Extra arguments for :meth:`matplotlib.axes.Axes.scatter`;
            ``marker`` defaults to "o" and ``zorder`` to 1000.
        """
        marker_coords_2d = coords_3d_to_2d(marker_coords, self.direction)
        xdata, ydata = marker_coords_2d.T

        # Allow markers only in their respective hemisphere when appropriate
        if self.direction in "lr":
            if not isinstance(marker_color, str) and not isinstance(
                marker_color, np.ndarray
            ):
                marker_color = np.asarray(marker_color)
            xcoords, _, _ = marker_coords.T
            relevant_coords = [
                cidx
                for cidx, xc in enumerate(xcoords)
                if (self.direction == "r" and xc >= 0)
                or (self.direction == "l" and xc <= 0)
            ]
            xdata = xdata[relevant_coords]
            ydata = ydata[relevant_coords]
            # if marker_color is string for example 'red' or 'blue', then
            # we pass marker_color as it is to matplotlib scatter without
            # making any selection in 'l' or 'r' color.
            # More likely that user wants to display all nodes to be in
            # same color.
            if not isinstance(marker_color, str) and len(marker_color) != 1:
                marker_color = marker_color[relevant_coords]

            if not isinstance(marker_size, numbers.Number):
                marker_size = np.asarray(marker_size)[relevant_coords]

        defaults = {"marker": "o", "zorder": 1000}
        for k, v in defaults.items():
            kwargs.setdefault(k, v)

        self.ax.scatter(xdata, ydata, s=marker_size, c=marker_color, **kwargs)

    def _add_lines(
        self,
        line_coords,
        line_values,
        cmap,
        vmin=None,
        vmax=None,
        directed=False,
        **kwargs,
    ):
        """Plot lines.

        In the case of 'l' and 'r' directions (for hemispheric projections),
        only lines whose two endpoints lie in that hemisphere are drawn;
        endpoints at x == 0 count for both hemispheres, as markers do.

        Parameters
        ----------
        line_coords : :obj:`list` of :class:`numpy.ndarray` of shape (2, 3)
            3D coordinates of lines start points and end points.

        line_values : array_like
            Values of the lines.

        cmap : :class:`matplotlib.colors.Colormap` or :obj:`str`
            Colormap used to map ``line_values`` to a color.

        vmin, vmax : :obj:`float`, optional
            Minimum and maximum values used to color lines. If both are
            ``None``, the maximum absolute value of ``line_values`` is used
            as ``vmax`` and its negation as ``vmin``. If only one is given,
            the other is set to its negation.

        directed : :obj:`bool`, default=False
            Add arrows instead of lines if set to ``True``.
            Use this when plotting directed graphs for example.

        kwargs : :obj:`dict`
            Additional arguments to pass to :class:`~matplotlib.lines.Line2D`
            (or :class:`~matplotlib.patches.FancyArrow` when ``directed``).
            They take priority over the computed color, linewidth and zorder.

        Raises
        ------
        ValueError
            If only ``vmax`` is given and is not positive, or only ``vmin``
            is given and is not negative.
        """
        # colormap for colorbar
        self.cmap = cmap
        if vmin is None and vmax is None:
            abs_line_values_max = np.abs(line_values).max()
            vmin = -abs_line_values_max
            vmax = abs_line_values_max
        elif vmin is None:
            if vmax > 0:
                vmin = -vmax
            else:
                raise ValueError(
                    "If vmax is set to a non-positive number "
                    "then vmin needs to be specified"
                )
        elif vmax is None:
            if vmin < 0:
                vmax = -vmin
            else:
                raise ValueError(
                    "If vmin is set to a non-negative number "
                    "then vmax needs to be specified"
                )
        norm = Normalize(vmin=vmin, vmax=vmax)
        # normalization useful for colorbar
        self.norm = norm
        abs_norm = Normalize(vmin=0, vmax=max(abs(vmax), abs(vmin)))
        value_to_color = plt.cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba

        # Allow lines only in their respective hemisphere when appropriate.
        # Midline endpoints (x == 0) are accepted on both sides, consistent
        # with ``_add_markers``.
        if self.direction in "lr":
            relevant_lines = [
                lidx
                for lidx, line in enumerate(line_coords)
                if (
                    self.direction == "r"
                    and line[0, 0] >= 0
                    and line[1, 0] >= 0
                )
                or (
                    self.direction == "l"
                    and line[0, 0] <= 0
                    and line[1, 0] <= 0
                )
            ]
            line_coords = np.array(line_coords)[relevant_lines]
            # asarray so that plain lists (array_like) can be index-selected.
            line_values = np.asarray(line_values)[relevant_lines]

        for start_end_point_3d, line_value in zip(line_coords, line_values):
            start_end_point_2d = coords_3d_to_2d(
                start_end_point_3d, self.direction
            )

            color = value_to_color(line_value)
            abs_line_value = abs(line_value)
            linewidth = 1 + 2 * abs_norm(abs_line_value)
            # Hacky way to put the strongest connections on top of the weakest
            # note sign does not matter hence using 'abs'
            zorder = 10 + 10 * abs_norm(abs_line_value)
            this_kwargs = {
                "color": color,
                "linewidth": linewidth,
                "zorder": zorder,
            }
            # kwargs should have priority over this_kwargs so that the
            # user can override the default logic
            this_kwargs.update(kwargs)
            xdata, ydata = start_end_point_2d.T
            # If directed is True, add an arrow
            if directed:
                dx = xdata[1] - xdata[0]
                dy = ydata[1] - ydata[0]
                # Hack to avoid empty arrows to crash with
                # matplotlib versions older than 3.1
                # This can be removed once support for
                # matplotlib pre 3.1 has been dropped.
                if dx == dy == 0:
                    arrow = FancyArrow(xdata[0], ydata[0], dx, dy)
                else:
                    arrow = FancyArrow(
                        xdata[0],
                        ydata[0],
                        dx,
                        dy,
                        length_includes_head=True,
                        width=linewidth,
                        head_width=3 * linewidth,
                        **this_kwargs,
                    )
                self.ax.add_patch(arrow)
            # Otherwise a line
            else:
                line = Line2D(xdata, ydata, **this_kwargs)
                self.ax.add_line(line)