Coverage for nilearn/plotting/displays/_figures.py: 0%

187 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1import warnings 

2 

3import numpy as np 

4from scipy import linalg 

5from scipy.spatial import distance_matrix 

6 

7from nilearn._utils.helpers import is_kaleido_installed, is_plotly_installed 

8from nilearn._utils.logger import find_stack_level 

9from nilearn.plotting.surface._utils import DEFAULT_HEMI, get_faces_on_edge 

10from nilearn.surface import SurfaceImage 

11from nilearn.surface.surface import get_data, load_surf_data 

12 

13if is_plotly_installed(): 

14 import plotly.graph_objects as go 

15 

16 

class SurfaceFigure:
    """Abstract class for surface figures.

    Parameters
    ----------
    figure : Figure instance or ``None``, optional
        Figure to be wrapped.

    output_file : :obj:`str` or ``None``, optional
        Path to output file.

    hemi : :obj:`str`, default=DEFAULT_HEMI
        Hemisphere(s) represented by the figure. It is only stored here
        (as ``self.hemi``); subclasses decide how to use it.
    """

    def __init__(self, figure=None, output_file=None, hemi=DEFAULT_HEMI):
        self.figure = figure
        self.output_file = output_file
        self.hemi = hemi

    def show(self, *args, **kwargs):
        """Show the figure.

        Concrete subclasses override this method and may take extra,
        backend-specific arguments (for instance a renderer). Any arguments
        are therefore accepted here so that the base class always fails
        with :class:`NotImplementedError` rather than :class:`TypeError`.

        Raises
        ------
        NotImplementedError
            Always, in this abstract base class.
        """
        raise NotImplementedError

    def _check_output_file(self, output_file=None):
        """If an output file is provided, \
        set it as the new default output file.

        Parameters
        ----------
        output_file : :obj:`str` or ``None``, optional
            Path to output file. If ``None``, the previously stored
            ``self.output_file`` is kept.

        Raises
        ------
        ValueError
            If neither ``output_file`` nor ``self.output_file`` is set.
        """
        if output_file is None:
            if self.output_file is None:
                raise ValueError(
                    "You must provide an output file name to save the figure."
                )
        else:
            self.output_file = output_file

    def add_contours(self, *args, **kwargs):
        """Draw boundaries around roi.

        Concrete subclasses override this method with their own parameters
        (ROI map, levels, labels, ...). Any arguments are accepted here so
        that the base class always fails with :class:`NotImplementedError`.

        Raises
        ------
        NotImplementedError
            Always, in this abstract base class.
        """
        raise NotImplementedError

58 

59 

class PlotlySurfaceFigure(SurfaceFigure):
    """Implementation of a surface figure obtained with `plotly` engine.

    Parameters
    ----------
    figure : Plotly figure instance or ``None``, optional
        Plotly figure instance to be used.

    output_file : :obj:`str` or ``None``, optional
        Output file path.

    hemi : :obj:`str`, default=DEFAULT_HEMI
        Hemisphere(s) shown in the figure. :meth:`add_contours` uses it to
        select data from a :obj:`~nilearn.surface.SurfaceImage`
        (``"left"``, ``"right"`` or ``"both"``).

    Attributes
    ----------
    figure : Plotly figure instance
        Plotly figure. Use this attribute to access the underlying
        plotly figure for further customization and use plotly
        functionality.

    output_file : :obj:`str`
        Output file path.

    hemi : :obj:`str`
        Hemisphere(s) shown in the figure.

    """

    @property
    def _faces(self):
        """(n_faces, 3) vertex indices (``i``, ``j``, ``k``) of the first \
        trace of the figure (assumed to be the surface mesh)."""
        return np.vstack(
            [self.figure._data[0].get(d) for d in ["i", "j", "k"]]
        ).T

    @property
    def _coords(self):
        """(n_vertices, 3) ``x``, ``y``, ``z`` coordinates of the first \
        trace of the figure (assumed to be the surface mesh)."""
        return np.vstack(
            [self.figure._data[0].get(d) for d in ["x", "y", "z"]]
        ).T

    def __init__(self, figure=None, output_file=None, hemi=DEFAULT_HEMI):
        if not is_plotly_installed():
            raise ImportError(
                "Plotly is required to use `PlotlySurfaceFigure`."
            )

        if figure is not None and not isinstance(figure, go.Figure):
            raise TypeError(
                "`PlotlySurfaceFigure` accepts only plotly Figure objects."
            )
        super().__init__(figure=figure, output_file=output_file, hemi=hemi)

    def show(self, renderer="browser"):
        """Show the figure.

        Parameters
        ----------
        renderer : :obj:`str`, default='browser'
            Plotly renderer to be used.

        Returns
        -------
        The wrapped plotly figure (``None`` if no figure is set).

        """
        if self.figure is not None:
            self.figure.show(renderer=renderer)
        return self.figure

    def savefig(self, output_file=None):
        """Save the figure to file.

        Parameters
        ----------
        output_file : :obj:`str` or ``None``, optional
            Path to output file. If ``None``, the stored ``output_file``
            is used.

        Raises
        ------
        ImportError
            If ``kaleido`` is not installed.
        ValueError
            If no output file is available.
        """
        if not is_kaleido_installed():
            raise ImportError(
                "`kaleido` is required to save plotly figures to disk."
            )
        self._check_output_file(output_file=output_file)
        if self.figure is not None:
            self.figure.write_image(self.output_file)

    def add_contours(
        self,
        roi_map,
        levels=None,
        labels=None,
        lines=None,
        elevation=0.1,
    ):
        """Draw boundaries around roi.

        Parameters
        ----------
        roi_map : :obj:`str` or :class:`numpy.ndarray` or :obj:`list` of \
                  :class:`numpy.ndarray` or\
                  :obj:`~nilearn.surface.SurfaceImage`
            ROI map to be displayed on the surface
            mesh, can be a file (valid formats are .gii, .mgz, .nii,
            .nii.gz, or FreeSurfer specific files such as .annot or .label),
            or a Numpy array with a value for each vertex of the surf_mesh.
            The value at each vertex is one inside the ROI and zero outside
            the ROI, or an :obj:`int` giving the label number for atlases.
            A :obj:`~nilearn.surface.SurfaceImage` must hold a single map;
            the data of the hemisphere(s) given by ``hemi`` is used.

        levels : :obj:`list` of :obj:`int`, or :obj:`None`, default=None
            A :obj:`list` of indices of the regions that are to be outlined.
            Every index needs to correspond to one index in roi_map.
            If :obj:`None`, every distinct value found in roi_map is used
            (including any background value).

        labels : :obj:`list` of :obj:`str` or :obj:`None`, default=None
            A :obj:`list` of labels for the individual regions of interest.
            Provide :obj:`None` as list entry to skip showing the label of
            that region. If :obj:`None`, regions are named
            ``"Region 0"``, ``"Region 1"``, ...

        lines : :obj:`list` of :obj:`dict` giving the properties of the \
                contours, or :obj:`None`, default=None
            For valid keys, see :attr:`plotly.graph_objects.Scatter3d.line`.
            If length 1, the properties defined in that element will be used
            to draw all requested contours. The list passed by the caller is
            not modified.

        elevation : :obj:`float`, default=0.1
            Controls how high above the face each boundary should be placed.
            0.0 implies directly on boundary, and higher values are farther
            above the face. This is useful for avoiding overlap of surface
            and boundary.

        Raises
        ------
        ValueError
            If ``levels`` and ``labels``, or ``levels`` and ``lines``,
            do not have the same length.

        Warns
        -----
        UserWarning
            When a vertex is isolated; it will not be included in the
            roi contour.

        Notes
        -----
        Regions are traced by connecting the centroids of non-isolated
        faces (triangles).
        """
        if isinstance(roi_map, SurfaceImage):
            # NOTE(review): assert is stripped under ``python -O``; kept to
            # preserve the exception type callers may rely on.
            assert len(roi_map.shape) == 1 or roi_map.shape[1] == 1
            if self.hemi in ["left", "right"]:
                roi_map = roi_map.data.parts[self.hemi]
            elif self.hemi == "both":
                roi_map = get_data(roi_map)

        # Load first, so that the default ``levels`` come from the actual
        # vertex values even when ``roi_map`` is a file path.
        roi = load_surf_data(roi_map)

        if levels is None:
            levels = np.unique(roi)
        if labels is None:
            labels = [f"Region {i}" for i, _ in enumerate(levels)]
        if lines is None:
            lines = [None] * len(levels)
        elif len(lines) == 1 and len(levels) > 1:
            # Build a new list: extending in place would mutate the
            # caller's argument.
            lines = list(lines) * len(levels)
        if len(levels) != len(labels):
            raise ValueError(
                "levels and labels need to be either the same length or None."
            )
        if len(levels) != len(lines):
            raise ValueError(
                "levels and lines need to be either the same length or None."
            )

        # The face array is rebuilt on each property access: fetch it once.
        faces = self._faces

        traces = []
        for level, label, line in zip(levels, labels, lines):
            parc_idx = np.where(roi == level)[0]

            # warn when the edge faces exclude vertices in parcellation
            # a vertex is isolated when it is a vertex of 6 faces that
            # each have only one vertex in the parcellation
            verts_per_face = np.isin(faces, parc_idx).sum(axis=1)
            faces_w_one_v = np.flatnonzero(verts_per_face == 1)
            unique_v, unique_v_counts = np.unique(
                faces[faces_w_one_v], return_counts=True
            )
            isolated_v = unique_v[unique_v_counts == 6]
            # Test the size, not the values: vertex index 0 is falsy but
            # is still a valid isolated vertex.
            if isolated_v.size > 0:
                warnings.warn(
                    f"""{label=} contains isolated vertices:
                    {isolated_v.tolist()}. These will not be included in ROI
                    boundary line.""",
                    stacklevel=find_stack_level(),
                )

            sorted_vertices = self._get_sorted_edge_centroids(
                parc_idx=parc_idx, elevation=elevation
            )

            traces.append(
                go.Scatter3d(
                    x=sorted_vertices[:, 0],
                    y=sorted_vertices[:, 1],
                    z=sorted_vertices[:, 2],
                    mode="lines",
                    line=line,
                    name=label,
                )
            )
        self.figure.add_traces(data=traces)

    def _get_sorted_edge_centroids(self, parc_idx, elevation=0.1):
        """Return the centroids of a region's edge faces, ordered to trace \
        its boundary.

        Parameters
        ----------
        parc_idx : :class:`numpy.ndarray`
            Indices of the vertices of the region to be plotted.

        elevation : :obj:`float`
            Controls how high above the face each centroid should be placed.
            0.0 implies directly on boundary, and higher values are farther
            above the face. This is useful for avoiding overlap of surface
            and boundary.

        Returns
        -------
        sorted_vertices : :class:`numpy.ndarray`
            (n_vertices, 3) x,y,z coordinates of vertices that trace region
            of interest. Separate closed contours are delimited by a row of
            ``None`` (interpreted by plotly as a break in the line). An empty
            (0, 3) array is returned when no face lies on the region edge.

        Notes
        -----
        For each face on the edge of a region
        1. Get a centroid for each face (parallel to the triangle plane)
        2. Find the xyz coordinate that is normal to the triangle face
           (at distance `elevation`)
        3. Arrange the centroids in a good order for plotting
        """
        # Properties rebuild full arrays on every access: fetch them once.
        faces = self._faces
        coords = self._coords
        # Set of ROI vertex indices, for constant-time membership tests.
        in_parc = set(np.asarray(parc_idx).tolist())

        # Mask indicating faces whose centroids will compose the boundary.
        edge_faces = get_faces_on_edge(faces=faces, parc_idx=parc_idx)

        # Gather, for each edge face: its elevated centroid, the in-plane
        # ROI edge segments, its corner coordinates and its ROI vertices.
        centroids = []
        segments = []
        vs = []
        idxs = []
        for is_edge, face in zip(edge_faces, faces):
            if not is_edge:
                continue
            t0, t1, t2 = coords[face[0]], coords[face[1]], coords[face[2]]
            tri = (t0, t1, t2)
            inside = [f in in_parc for f in face]

            # the xyz coordinate is weighted toward the roi boundary (2:1)
            weights = [2 if is_in else 1 for is_in in inside]
            centroid = np.average(np.stack(tri), axis=0, weights=weights)
            centroids.append(
                self._project_above_face(
                    centroid, t0, t1, t2, elevation=elevation
                )
            )

            # Triangle edges whose both ends are in the ROI, expressed in
            # in-plane 2D coordinates as (x_a, y_a, x_b, y_b).
            segs = [None] * 3
            for slot, (a, b) in enumerate(((0, 1), (0, 2), (2, 1))):
                if inside[a] and inside[b]:
                    segs[slot] = self._transform_coord_to_plane(
                        tri[a], t0, t1, t2
                    ) + self._transform_coord_to_plane(tri[b], t0, t1, t2)
            segments.append(tuple(segs))
            vs.append(tri)
            idxs.append([f for f, is_in in zip(face, inside) if is_in])

        if not centroids:
            # No face on the region edge (presumably an empty region or one
            # covering the whole mesh): nothing to trace.
            return np.empty((0, 3))

        centroids = np.array(centroids)

        # Next, sort centroids along boundary, greedily, starting with the
        # first one.
        current_vertex = 0
        unvisited = np.ones(len(centroids), dtype=bool)
        unvisited[current_vertex] = False
        last_distance = np.inf
        prev_first = 0

        sorted_vertices = [centroids[0]]

        for _ in range(1, len(centroids)):
            remaining_vertices = np.flatnonzero(unvisited)
            # Row 0 keeps this 1-D even when a single candidate remains.
            remaining_distances = distance_matrix(
                centroids[current_vertex].reshape(1, -1),
                centroids[remaining_vertices],
            )[0]
            # Occasionally, the next closest centroid is one that would
            # cause a loop. This is common when a vertex is a neighbor
            # of only one other vertex in the roi (the loop encircles
            # this corner vertex). So, the next added centroid is the
            # nearest one that passes the checks in `_is_valid_jump`,
            # even if it is slightly farther away.

            # from the current vertex, there are only at most 5 options
            # that will be good jumps
            n_jumps_remaining = min(5, max(0, len(remaining_vertices) - 1))
            smallest_idx = np.argpartition(
                remaining_distances, n_jumps_remaining
            )[:n_jumps_remaining]
            xy1 = self._transform_coord_to_plane(
                centroids[current_vertex], *vs[current_vertex]
            )
            next_index = -1
            # Try candidates from nearest to farthest; keep the first valid.
            for attempt in np.argsort(remaining_distances[smallest_idx]):
                candidate_idx = smallest_idx[attempt]
                if self._is_valid_jump(
                    current_vertex,
                    remaining_vertices[candidate_idx],
                    xy1,
                    centroids,
                    segments,
                    vs,
                    idxs,
                ):
                    next_index = candidate_idx
                    break

            # if none of those five worked, then just pick the next nearest
            if next_index == -1:
                next_index = np.argmin(remaining_distances)

            closest_vertex = remaining_vertices[next_index]
            distance = remaining_distances[next_index]

            # some regions have multiple, non-isolated vertices
            # this block detects that by checking whether the next
            # vertex is very far away
            if distance > last_distance * 3:
                # close the current contour
                # add triple of None, which is parsed by plotly
                # as a signal to start a new closed contour
                sorted_vertices.extend(
                    (centroids[prev_first], np.array([None] * 3))
                )
                # start the new contour
                prev_first = closest_vertex

            unvisited[closest_vertex] = False
            sorted_vertices.append(centroids[closest_vertex])

            # Move to the closest vertex and repeat the process
            current_vertex = closest_vertex
            last_distance = distance

        # append the first one again to close the outline
        sorted_vertices.append(centroids[prev_first])

        return np.asarray(sorted_vertices)

    def _is_valid_jump(
        self, current, candidate, xy1, centroids, segments, vs, idxs
    ):
        """Tell whether the contour may go from edge face `current` to \
        edge face `candidate`.

        The two faces must share a ROI vertex and a triangle edge, and the
        jump (in the plane of `current`, starting at `xy1`) must not cross
        any ROI segment of `current`.
        """
        # The faces must share at least one vertex inside the ROI.
        if all(v not in idxs[current] for v in idxs[candidate]):
            return False
        # They must also share an edge: at least two coinciding corners.
        shared = 0
        for v in vs[candidate]:
            for v2 in vs[current]:
                shared += np.all(np.isclose(v, v2))
        if shared < 2:
            return False
        # Finally, the jump must not cross the ROI boundary.
        xy2 = self._transform_coord_to_plane(
            centroids[candidate], *vs[current]
        )
        return not any(
            e is not None and self._do_segs_intersect(*xy1, *xy2, *e)
            for e in segments[current]
        )

    @staticmethod
    def _project_above_face(point, t0, t1, t2, elevation=0.1):
        """Given 3d coordinates `point`, report coordinates that define \
        the closest point that is `elevation` above (normal to) \
        the plane defined by vertices `t0`, `t1`, and `t2`.
        """
        u = t1 - t0
        v = t2 - t0
        # Unit vector normal to the plane. Not divided in place, so that
        # integer-typed coordinates do not raise a casting error.
        n = np.cross(u, v)
        n = n / np.linalg.norm(n)
        p_ = point - t0

        p_normal = np.dot(p_, n) * n
        p_tangent = p_ - p_normal

        closest_point = p_tangent + t0
        return closest_point + elevation * n

    @staticmethod
    def _do_segs_intersect(x1, y1, x2, y2, x3, y3, x4, y4):
        """Check whether line segments intersect.

        Parameters
        ----------
        x1 : :obj:`float` or :obj:`int`
            First coordinate of first segment beginning.
        y1 : :obj:`float` or :obj:`int`
            Second coordinate of first segment beginning.
        x2 : :obj:`float` or :obj:`int`
            First coordinate of first segment end.
        y2 : :obj:`float` or :obj:`int`
            Second coordinate of first segment end.
        x3 : :obj:`float` or :obj:`int`
            First coordinate of second segment beginning.
        y3 : :obj:`float` or :obj:`int`
            Second coordinate of second segment beginning.
        x4 : :obj:`float` or :obj:`int`
            First coordinate of second segment end.
        y4 : :obj:`float` or :obj:`int`
            Second coordinate of second segment end.

        Returns
        -------
        check: :obj:`bool`
            True if segments intersect, otherwise False. (Nearly) parallel
            segments are reported as not intersecting.

        Notes
        -----
        Implements an algorithm described here
        https://en.wikipedia.org/w/index.php?title=Intersection_(geometry)&oldid=1215046212#Two_line_segments
        """
        a1 = x1 - x2
        b1 = x3 - x4
        c1 = x1 - x3
        a2 = y1 - y2
        b2 = y3 - y4
        c2 = y1 - y3
        d = a1 * b2 - a2 * b1
        if np.isclose(d, 0):
            return False
        # Parametric positions of the intersection along each segment.
        t = (c1 * b2 - c2 * b1) / d
        u = (c1 * a2 - c2 * a1) / d
        return 0 <= t <= 1 and 0 <= u <= 1

    @staticmethod
    def _transform_coord_to_plane(v, t0, t1, t2):
        """Given 3d point `v`, find closest point on plane defined \
        by vertices `t0`, `t1`, and `t2`, and return its two coordinates \
        in an orthonormal basis of that plane.
        """
        # Orthonormal basis of the plane, completed by its unit normal.
        A = linalg.orth(np.column_stack((t1 - t0, t2 - t0)))
        normal = np.cross(A[:, 0], A[:, 1])
        normal /= np.linalg.norm(normal)
        B = np.column_stack((A, normal))
        Bp = np.linalg.inv(B)
        # Orthogonal projector onto the plane directions.
        P = B @ np.diag((1, 1, 0)) @ Bp
        return tuple(Bp[:2, :] @ (t0 + P @ (v - t0)))