Coverage for nilearn/plotting/displays/_figures.py: 0%
187 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1import warnings
3import numpy as np
4from scipy import linalg
5from scipy.spatial import distance_matrix
7from nilearn._utils.helpers import is_kaleido_installed, is_plotly_installed
8from nilearn._utils.logger import find_stack_level
9from nilearn.plotting.surface._utils import DEFAULT_HEMI, get_faces_on_edge
10from nilearn.surface import SurfaceImage
11from nilearn.surface.surface import get_data, load_surf_data
13if is_plotly_installed():
14 import plotly.graph_objects as go
class SurfaceFigure:
    """Abstract class for surface figures.

    Subclasses wrap an engine-specific figure object and are expected to
    override :meth:`show` and :meth:`add_contours`, which only raise
    :class:`NotImplementedError` here.

    Parameters
    ----------
    figure : Figure instance or ``None``, optional
        Figure to be wrapped.

    output_file : :obj:`str` or ``None``, optional
        Path to output file.

    hemi : :obj:`str`, optional
        Hemisphere(s) the figure represents; stored as ``self.hemi``.
        Defaults to ``DEFAULT_HEMI``.
    """

    def __init__(self, figure=None, output_file=None, hemi=DEFAULT_HEMI):
        self.figure, self.output_file, self.hemi = figure, output_file, hemi

    def show(self):
        """Show the figure.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def _check_output_file(self, output_file=None):
        """Resolve which output file to use when saving the figure.

        A non-``None`` ``output_file`` becomes the new default stored in
        ``self.output_file``; otherwise a default must already be stored.

        Parameters
        ----------
        output_file : :obj:`str` or ``None``, optional
            Path to output file.

        Raises
        ------
        ValueError
            If both ``output_file`` and ``self.output_file`` are ``None``.
        """
        if output_file is not None:
            self.output_file = output_file
            return
        if self.output_file is None:
            raise ValueError(
                "You must provide an output file name to save the figure."
            )

    def add_contours(self):
        """Draw boundaries around roi.

        Must be implemented by subclasses.
        """
        raise NotImplementedError
class PlotlySurfaceFigure(SurfaceFigure):
    """Implementation of a surface figure obtained with `plotly` engine.

    Parameters
    ----------
    figure : Plotly figure instance or ``None``, optional
        Plotly figure instance to be used.

    output_file : :obj:`str` or ``None``, optional
        Output file path.

    hemi : :obj:`str`, optional
        Hemisphere(s) shown in the figure. :meth:`add_contours` uses it to
        select data from a :obj:`~nilearn.surface.SurfaceImage`
        (``"left"``, ``"right"`` or ``"both"``).
        Defaults to ``DEFAULT_HEMI``.

    Attributes
    ----------
    figure : Plotly figure instance
        Plotly figure. Use this attribute to access the underlying
        plotly figure for further customization and use plotly
        functionality.

    output_file : :obj:`str`
        Output file path.

    hemi : :obj:`str`
        Hemisphere(s) shown in the figure.
    """

    @property
    def _faces(self):
        """Vertex indices of the faces of the first trace, as (n_faces, 3).

        Built from the trace's ``i``, ``j`` and ``k`` arrays, so the first
        trace is expected to be the surface mesh. The array is rebuilt on
        every access; callers should keep a local reference.
        """
        return np.vstack(
            [self.figure._data[0].get(d) for d in ["i", "j", "k"]]
        ).T

    @property
    def _coords(self):
        """Vertex coordinates of the first trace, as (n_vertices, 3).

        Built from the trace's ``x``, ``y`` and ``z`` arrays. The array is
        rebuilt on every access; callers should keep a local reference.
        """
        return np.vstack(
            [self.figure._data[0].get(d) for d in ["x", "y", "z"]]
        ).T

    def __init__(self, figure=None, output_file=None, hemi=DEFAULT_HEMI):
        if not is_plotly_installed():
            raise ImportError(
                "Plotly is required to use `PlotlySurfaceFigure`."
            )

        if figure is not None and not isinstance(figure, go.Figure):
            raise TypeError(
                "`PlotlySurfaceFigure` accepts only plotly Figure objects."
            )
        super().__init__(figure=figure, output_file=output_file, hemi=hemi)

    def show(self, renderer="browser"):
        """Show the figure.

        Parameters
        ----------
        renderer : :obj:`str`, default='browser'
            Plotly renderer to be used.

        Returns
        -------
        figure : Plotly figure instance or ``None``
            The wrapped figure (``None`` if no figure is wrapped, in which
            case nothing is shown).
        """
        if self.figure is not None:
            self.figure.show(renderer=renderer)
        return self.figure

    def savefig(self, output_file=None):
        """Save the figure to file.

        Parameters
        ----------
        output_file : :obj:`str` or ``None``, optional
            Path to output file. If given, it also becomes the new default
            ``self.output_file``.

        Raises
        ------
        ImportError
            If ``kaleido`` is not installed.
        ValueError
            If no output file is given and none was stored before.
        """
        if not is_kaleido_installed():
            raise ImportError(
                "`kaleido` is required to save plotly figures to disk."
            )
        self._check_output_file(output_file=output_file)
        if self.figure is not None:
            self.figure.write_image(self.output_file)

    def add_contours(
        self,
        roi_map,
        levels=None,
        labels=None,
        lines=None,
        elevation=0.1,
    ):
        """Draw boundaries around roi.

        Parameters
        ----------
        roi_map : :obj:`str` or :class:`numpy.ndarray` or :obj:`list` of \
                  :class:`numpy.ndarray` or \
                  :obj:`~nilearn.surface.SurfaceImage`
            ROI map to be displayed on the surface
            mesh, can be a file (valid formats are .gii, .mgz, .nii,
            .nii.gz, or FreeSurfer specific files such as .annot or .label),
            or a Numpy array with a value for each vertex of the surf_mesh.
            The value at each vertex is one inside the ROI and zero outside
            the ROI, or an :obj:`int` giving the label number for atlases.

        levels : :obj:`list` of :obj:`int`, or :obj:`None`, default=None
            A :obj:`list` of indices of the regions that are to be outlined.
            Every index needs to correspond to one index in roi_map.
            If :obj:`None`, all regions in roi_map are used.

        labels : :obj:`list` of :obj:`str` or :obj:`None`, default=None
            A :obj:`list` of labels for the individual regions of interest.
            Provide :obj:`None` as list entry to skip showing the label of
            that region. If :obj:`None`, no labels are used.

        lines : :obj:`list` of :obj:`dict` giving the properties of the \
                contours, or :obj:`None`, default=None
            For valid keys, see :attr:`plotly.graph_objects.Scatter3d.line`.
            If length 1, the properties defined in that element will be used
            to draw all requested contours. The caller's list is not
            modified.

        elevation : :obj:`float`, default=0.1
            Controls how high above the face each boundary should be placed.
            0.0 implies directly on boundary, and higher values are farther
            above the face. This is useful for avoiding overlap of surface
            and boundary.

        Raises
        ------
        ValueError
            If ``levels`` and ``labels``, or ``levels`` and ``lines``, do not
            have the same length.

        Warnings
        --------
        Warns when a vertex of a region is isolated, i.e. none of the faces
        it belongs to contains another vertex of that region; it will not be
        included in the roi contour.

        Notes
        -----
        Regions are traced by connecting the centroids of non-isolated
        faces (triangles).
        """
        if isinstance(roi_map, SurfaceImage):
            assert len(roi_map.shape) == 1 or roi_map.shape[1] == 1
            if self.hemi in ["left", "right"]:
                roi_map = roi_map.data.parts[self.hemi]
            elif self.hemi == "both":
                roi_map = get_data(roi_map)

        # Load before deriving the default levels: ``np.unique`` on the raw
        # input would return the file name itself for path inputs.
        roi = load_surf_data(roi_map)

        if levels is None:
            levels = np.unique(roi)
        if labels is None:
            labels = [f"Region {i}" for i, _ in enumerate(levels)]
        if lines is None:
            lines = [None] * len(levels)
        elif len(lines) == 1 and len(levels) > 1:
            # Build a new list: ``lines *= n`` would mutate the caller's list.
            lines = list(lines) * len(levels)
        if len(levels) != len(labels):
            raise ValueError(
                "levels and labels need to be either the same length or None."
            )
        if len(levels) != len(lines):
            raise ValueError(
                "levels and lines need to be either the same length or None."
            )

        # ``_faces`` rebuilds its array on every access, so read it once.
        faces = self._faces
        traces = []
        for level, label, line in zip(levels, labels, lines):
            parc_idx = np.where(roi == level)[0]

            # Warn when the edge faces exclude vertices in the parcellation.
            # A parcellation vertex is isolated when none of the faces it
            # belongs to contains another parcellation vertex (this does
            # not assume any particular number of faces per vertex).
            verts_per_face = np.isin(faces, parc_idx).sum(axis=1)
            in_single_faces = np.isin(parc_idx, faces[verts_per_face == 1])
            in_shared_faces = np.isin(parc_idx, faces[verts_per_face > 1])
            isolated_v = parc_idx[in_single_faces & ~in_shared_faces]
            # Test the size, not ``any(isolated_v)``: vertex 0 is falsy.
            if isolated_v.size > 0:
                warnings.warn(
                    f"{label=} contains isolated vertices: "
                    f"{isolated_v.tolist()}. These will not be included "
                    "in ROI boundary line.",
                    stacklevel=find_stack_level(),
                )

            sorted_vertices = self._get_sorted_edge_centroids(
                parc_idx=parc_idx, elevation=elevation
            )

            traces.append(
                go.Scatter3d(
                    x=sorted_vertices[:, 0],
                    y=sorted_vertices[:, 1],
                    z=sorted_vertices[:, 2],
                    mode="lines",
                    line=line,
                    name=label,
                )
            )
        self.figure.add_traces(data=traces)

    def _get_sorted_edge_centroids(self, parc_idx, elevation=0.1):
        """Identify which vertices lie on the outer edge of a parcellation.

        Parameters
        ----------
        parc_idx : :class:`numpy.ndarray`
            Indices of the vertices of the region to be plotted.

        elevation : :obj:`float`
            Controls how high above the face each centroid should be placed.
            0.0 implies directly on boundary, and higher values are farther
            above the face. This is useful for avoiding overlap of surface
            and boundary.

        Returns
        -------
        sorted_vertices : :class:`numpy.ndarray`
            (n_points, 3) x, y, z coordinates of the points that trace the
            region of interest, in drawing order. Each contour is closed by
            repeating its first point; when a region has several contours,
            consecutive contours are separated by a row of ``None`` values,
            which plotly treats as a break in the line.

        Notes
        -----
        For each face on the edge of a region

        1. Get a centroid for each face (parallel to the triangle plane),
           weighted 2:1 toward the corners inside the region.
        2. Find the xyz coordinate that is normal to the triangle face
           (at distance `elevation`).
        3. Arrange the centroids in a good order for plotting: greedily
           jump to a nearby centroid whose face shares an edge and a region
           vertex with the current face, without crossing the region edge.
        """
        # Both properties rebuild their arrays on every access; read once.
        faces = self._faces
        coords = self._coords
        # O(1) membership tests (``x in ndarray`` scans the whole array).
        parc_set = set(np.asarray(parc_idx).ravel().tolist())

        # Mask indicating faces whose centroids will compose the boundary.
        edge_faces = get_faces_on_edge(faces=faces, parc_idx=parc_idx)

        # Per edge face: elevated centroid, in-plane 2D segments along its
        # region edges, corner coordinates, and its region vertex indices.
        centroids = []
        segments = []
        vs = []
        idxs = []
        for on_edge, face in zip(edge_faces, faces):
            if not on_edge:
                continue
            tri = (coords[face[0]], coords[face[1]], coords[face[2]])
            inside = [v in parc_set for v in face]

            # the xyz coordinate is weighted toward the roi boundary (2:1)
            weights = [2 if is_in else 1 for is_in in inside]
            centroid = np.array(
                [
                    np.average([corner[axis] for corner in tri], weights=weights)
                    for axis in range(3)
                ]
            )
            centroids.append(
                self._project_above_face(centroid, *tri, elevation=elevation)
            )

            # Segment slots follow the corner pairs (0, 1), (0, 2), (2, 1);
            # a slot is filled only when both corners are in the region.
            segs = [None] * 3
            for slot, (a, b) in enumerate(((0, 1), (0, 2), (2, 1))):
                if inside[a] and inside[b]:
                    segs[slot] = self._transform_coord_to_plane(
                        tri[a], *tri
                    ) + self._transform_coord_to_plane(tri[b], *tri)
            segments.append(tuple(segs))
            vs.append(tri)
            idxs.append([v for v, is_in in zip(face, inside) if is_in])

        centroids = np.array(centroids)

        # Next, sort centroids along boundary
        # Start with the first vertex
        current_vertex = 0
        sorted_vertices = [centroids[0]]
        visited = np.zeros(len(centroids), dtype=bool)
        visited[current_vertex] = True
        last_distance = np.inf
        # index of the first centroid of the contour being traced
        prev_first = 0

        # Loop over the remaining vertices in order of distance from the
        # current vertex
        for _ in range(1, len(centroids)):
            remaining_vertices = np.flatnonzero(~visited)
            remaining_distances = distance_matrix(
                centroids[current_vertex].reshape(1, -1),
                centroids[remaining_vertices],
            )
            # Occasionally, the next closest centroid is one that would
            # cause a loop. This is common when a vertex is a neighbor
            # of only one other vertex in the roi (the loop encircles
            # this corner vertex). So, the next added centroid is one
            # that may be slightly farther away -- if the one that is
            # farther away has fewer vertices within the roi.

            # from the current vertex, there are only at most 5 options
            # that will be good jumps
            n_jumps_remaining = min(5, max(0, len(remaining_vertices) - 1))
            smallest_idx = np.argpartition(
                remaining_distances.squeeze(), n_jumps_remaining
            )[:n_jumps_remaining]
            current_face = vs[current_vertex]
            xy1 = self._transform_coord_to_plane(
                centroids[current_vertex], *current_face
            )
            next_index = -1
            for attempt in np.argsort(remaining_distances[0, smallest_idx]):
                candidate = smallest_idx[attempt]
                candidate_vertex = remaining_vertices[candidate]
                # the candidate face must share a region vertex ...
                if not any(
                    v in idxs[current_vertex] for v in idxs[candidate_vertex]
                ):
                    continue
                # ... and an edge (two coincident corners) ...
                n_shared = sum(
                    np.all(np.isclose(v, v2))
                    for v in vs[candidate_vertex]
                    for v2 in current_face
                )
                if n_shared < 2:
                    continue
                # ... and the jump must not cross a region edge of the
                # current face (tested in the current face's plane).
                xy2 = self._transform_coord_to_plane(
                    centroids[candidate_vertex], *current_face
                )
                if any(
                    seg is not None
                    and self._do_segs_intersect(*xy1, *xy2, *seg)
                    for seg in segments[current_vertex]
                ):
                    continue
                # NOTE(review): there is no ``break``, so when several
                # candidates pass, the last (farthest) one is kept. This is
                # preserved to keep existing contours; confirm the intent.
                next_index = candidate

            # if none of those five worked, then just pick the next nearest
            if next_index == -1:
                next_index = np.argmin(remaining_distances)

            closest_vertex = remaining_vertices[next_index]

            # some regions have multiple, non-isolated vertices
            # this block detects that by checking whether the next
            # vertex is very far away
            if remaining_distances[0, next_index] > last_distance * 3:
                # close the current contour
                # add triple of None, which is parsed by plotly
                # as a signal to start a new closed contour
                sorted_vertices.extend(
                    (centroids[prev_first], np.array([None] * 3))
                )
                # start the new contour
                prev_first = closest_vertex

            visited[closest_vertex] = True
            sorted_vertices.append(centroids[closest_vertex])

            # Move to the closest vertex and repeat the process
            current_vertex = closest_vertex
            last_distance = remaining_distances[0, next_index]

        # append the first one again to close the outline
        sorted_vertices.append(centroids[prev_first])

        return np.asarray(sorted_vertices)

    @staticmethod
    def _project_above_face(point, t0, t1, t2, elevation=0.1):
        """Given 3d coordinates `point`, report coordinates that define \
        the closest point that is `elevation` above (normal to) \
        the plane defined by vertices `t0`, `t1`, and `t2`.

        The normal is ``cross(t1 - t0, t2 - t0)`` normalized, so the side of
        the plane depends on the winding order of the three vertices.
        """
        u = t1 - t0
        v = t2 - t0
        # unit vector normal to plane; divide out of place so that
        # integer-valued coordinates do not fail on in-place division
        n = np.cross(u, v)
        n = n / np.linalg.norm(n)
        p_ = point - t0

        # remove the component of ``point - t0`` along the normal ...
        p_normal = np.dot(p_, n) * n
        p_tangent = p_ - p_normal

        # ... then lift the in-plane point by ``elevation`` along the normal
        closest_point = p_tangent + t0
        return closest_point + elevation * n

    @staticmethod
    def _do_segs_intersect(x1, y1, x2, y2, x3, y3, x4, y4):
        """Check whether line segments intersect.

        Parameters
        ----------
        x1 : :obj:`float` or :obj:`int`
            First coordinate of first segment beginning.
        y1 : :obj:`float` or :obj:`int`
            Second coordinate of first segment beginning.
        x2 : :obj:`float` or :obj:`int`
            First coordinate of first segment end.
        y2 : :obj:`float` or :obj:`int`
            Second coordinate of first segment end.
        x3 : :obj:`float` or :obj:`int`
            First coordinate of second segment beginning.
        y3 : :obj:`float` or :obj:`int`
            Second coordinate of second segment beginning.
        x4 : :obj:`float` or :obj:`int`
            First coordinate of second segment end.
        y4 : :obj:`float` or :obj:`int`
            Second coordinate of second segment end.

        Returns
        -------
        check: :obj:`bool`
            True if segments intersect, otherwise False. (Nearly) parallel
            segments, including collinear overlapping ones, give False.

        Notes
        -----
        Implements an algorithm described here
        https://en.wikipedia.org/w/index.php?title=Intersection_(geometry)&oldid=1215046212#Two_line_segments
        """
        a1 = x1 - x2
        b1 = x3 - x4
        c1 = x1 - x3
        a2 = y1 - y2
        b2 = y3 - y4
        c2 = y1 - y3
        d = a1 * b2 - a2 * b1
        # zero determinant: the supporting lines are parallel
        if np.isclose(d, 0):
            return False
        # t and u are the parameters of the intersection point along each
        # segment; it lies on both segments iff both are within [0, 1]
        t = (c1 * b2 - c2 * b1) / d
        u = (c1 * a2 - c2 * a1) / d
        return 0 <= t <= 1 and 0 <= u <= 1

    @staticmethod
    def _transform_coord_to_plane(v, t0, t1, t2):
        """Given 3d point `v`, return 2D coordinates of its projection \
        onto the plane defined by vertices `t0`, `t1`, and `t2`.

        The coordinates are expressed in the orthonormal basis of the plane
        returned by :func:`scipy.linalg.orth`, as a tuple of two values.
        Points projected onto the same triangle's plane can therefore be
        compared with 2D geometry.
        """
        A = linalg.orth(np.column_stack((t1 - t0, t2 - t0)))
        normal = np.cross(A[:, 0], A[:, 1])
        normal /= np.linalg.norm(normal)
        B = np.column_stack((A, normal))
        Bp = np.linalg.inv(B)
        # projector onto the plane spanned by A (drops the normal component)
        P = B @ np.diag((1, 1, 0)) @ Bp
        return tuple(Bp[:2, :] @ (t0 + P @ (v - t0)))