Coverage for nilearn/plotting/js_plotting_utils.py: 0%
82 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
"""Helpers for views, i.e. interactive plots from html_surface and \
html_connectome.
"""
5import base64
6import warnings
7from pathlib import Path
8from string import Template
10import matplotlib as mpl
11import matplotlib.pyplot as plt
12import numpy as np
14from nilearn._utils.extmath import fast_abs_percentile
15from nilearn._utils.html_document import ( # noqa: F401
16 HTMLDocument,
17 set_max_img_views_before_warning,
18)
19from nilearn._utils.logger import find_stack_level
20from nilearn._utils.param_validation import check_threshold
21from nilearn.surface import load_surf_mesh
# Presumably the number of image views that may be opened before a warning
# is emitted. NOTE(review): ``set_max_img_views_before_warning`` is imported
# above from nilearn._utils.html_document but marked ``noqa: F401`` (unused in
# this module), so the effective setting may live there; confirm whether
# this constant is still read anywhere.
MAX_IMG_VIEWS_BEFORE_WARNING = 10
def add_js_lib(html, embed_js=True):
    """Add javascript libraries to html template.

    The ``$INSERT_JS_LIBRARIES_HERE`` placeholder of ``html`` is replaced by
    ``<script>`` tags for jquery, plotly and ``surface-plot-utils.js``.

    Parameters
    ----------
    html : str or string.Template
        HTML page containing the ``$INSERT_JS_LIBRARIES_HERE`` placeholder.
        Any other ``$`` placeholders are left untouched
        (``Template.safe_substitute`` is used).
    embed_js : bool, default=True
        If True, jquery and plotly are embedded in the resulting page;
        otherwise, they are loaded via CDNs. ``surface-plot-utils.js`` is
        always embedded.

    Returns
    -------
    str
        The HTML page with the javascript libraries inserted.
    """
    js_dir = Path(__file__).parent / "data" / "js"

    def _read_js(name):
        # Decode explicitly as UTF-8: the default text encoding is
        # platform-dependent (e.g. cp1252 on Windows) and would fail on any
        # non-ASCII character in the bundled files.
        return (js_dir / name).read_text(encoding="utf-8")

    js_utils = _read_js("surface-plot-utils.js")
    if not embed_js:
        js_lib = f"""
    <script
    src="https://ajax.googleapis.com/ajax/libs/jquery/3.6.0/jquery.min.js">
    </script>
    <script src="https://cdn.plot.ly/plotly-gl3d-latest.min.js"></script>
    <script>
    {js_utils}
    </script>
    """
    else:
        jquery = _read_js("jquery.min.js")
        plotly = _read_js("plotly-gl3d-latest.min.js")
        js_lib = f"""
    <script>{jquery}</script>
    <script>{plotly}</script>
    <script>
    {js_utils}
    </script>
    """
    if not isinstance(html, Template):
        html = Template(html)
    return html.safe_substitute({"INSERT_JS_LIBRARIES_HERE": js_lib})
def get_html_template(template_name):
    """Load an HTML template shipped in the package's ``data/html`` folder.

    Parameters
    ----------
    template_name : str
        File name of the template inside ``data/html``.

    Returns
    -------
    string.Template
        Template wrapping the file contents decoded as UTF-8. The raw bytes
        are decoded directly, so line endings are kept exactly as stored.
    """
    html_dir = Path(__file__).parent / "data" / "html"
    raw_bytes = (html_dir / template_name).read_bytes()
    return Template(raw_bytes.decode("utf-8"))
def colorscale(
    cmap, values, threshold=None, symmetric_cmap=True, vmax=None, vmin=None
):
    """Normalize a cmap, put it in plotly format, get threshold and range.

    Parameters
    ----------
    cmap : str or matplotlib colormap
        Colormap, resolved through ``plt.get_cmap``.
    values : numpy.ndarray
        Data the color range is fitted to. ``values.min()`` may be called,
        so an array-like exposing ``min`` is expected.
    threshold : optional, default=None
        Converted to an absolute threshold by ``check_threshold`` (which is
        given ``values`` and ``fast_abs_percentile``). The part of the
        colormap covering ``(-abs_threshold, abs_threshold)`` is painted
        gray. No thresholding when None.
    symmetric_cmap : bool, default=True
        Whether the color range is ``[-vmax, vmax]``. Forced to True, with a
        warning, when ``values`` contains negative numbers.
    vmax : float or None, default=None
        Upper bound of the color range; defaults to ``max(abs(values))``.
    vmin : float or None, default=None
        Lower bound of the color range. Ignored, with a warning, when the
        colormap is symmetric; otherwise defaults to ``values.min()``.

    Returns
    -------
    dict
        ``"colors"``: plotly colorscale, a list of
        ``[position, "rgb(r, g, b)"]`` pairs at 100 evenly spaced positions
        in [0, 1] (rounded to 3 decimals); ``"vmin"`` / ``"vmax"``: the color
        range; ``"cmap"``: the (possibly thresholded) colormap; ``"norm"``:
        the ``Normalize`` instance for the range; ``"abs_threshold"``: the
        absolute threshold or None; ``"symmetric_cmap"``: the effective
        symmetry flag.
    """
    base_cmap = plt.get_cmap(cmap)
    magnitudes = np.abs(values)

    # A non-symmetric colormap cannot represent negative data.
    if not symmetric_cmap and values.min() < 0:
        warnings.warn(
            "you have specified symmetric_cmap=False "
            "but the map contains negative values; "
            "setting symmetric_cmap to True",
            stacklevel=find_stack_level(),
        )
        symmetric_cmap = True
    if symmetric_cmap and vmin is not None:
        warnings.warn(
            "vmin cannot be chosen when cmap is symmetric",
            stacklevel=find_stack_level(),
        )
        vmin = None

    # float() avoids a TypeError downstream if vmax is a numpy boolean.
    vmax = float(magnitudes.max() if vmax is None else vmax)
    if symmetric_cmap:
        vmin = -vmax
    elif vmin is None:
        vmin = values.min()
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    n_colors = base_cmap.N
    color_list = [base_cmap(index) for index in range(n_colors)]
    abs_threshold = None
    if threshold is not None:
        abs_threshold = check_threshold(threshold, values, fast_abs_percentile)
        last_index = n_colors - 1
        gray_start = int(norm(-abs_threshold, clip=True) * last_index)
        gray_stop = int(norm(abs_threshold, clip=True) * last_index)
        for index in range(gray_start, gray_stop):
            color_list[index] = (0.5, 0.5, 0.5, 1.0)  # mid gray

    thresholded_cmap = mpl.colors.LinearSegmentedColormap.from_list(
        "Custom cmap", color_list, n_colors
    )
    positions = np.linspace(0, 1, 100)
    rgba_bytes = thresholded_cmap(positions, bytes=True)
    rgb_bytes = np.array(rgba_bytes[:, :3], dtype=int)
    plotly_colors = [
        [np.round(position, 3), f"rgb({red}, {green}, {blue})"]
        for position, (red, green, blue) in zip(positions, rgb_bytes)
    ]
    return {
        "colors": plotly_colors,
        "vmin": vmin,
        "vmax": vmax,
        "cmap": thresholded_cmap,
        "norm": norm,
        "abs_threshold": abs_threshold,
        "symmetric_cmap": symmetric_cmap,
    }
def encode(a):
    """Base64 encode a numpy array.

    Parameters
    ----------
    a : numpy.ndarray
        Array to serialize. Its raw bytes are taken in C (row-major) order
        with the array's own dtype; no dtype information is stored, so the
        reader must supply it (see ``decode``).

    Returns
    -------
    str
        Base64 text of the array's bytes.
    """
    # ``ndarray.tobytes`` has existed since NumPy 1.9; the former
    # ``tostring`` fallback (a deprecated alias) is dead code and is dropped.
    return base64.b64encode(a.tobytes()).decode("utf-8")
def decode(b, dtype):
    """Rebuild a numpy array from Base64 text.

    Parameters
    ----------
    b : str
        Base64 text, as produced by ``encode``.
    dtype : numpy dtype-like
        Element type used to interpret the decoded bytes.

    Returns
    -------
    numpy.ndarray
        1-D array viewing the decoded bytes (``np.frombuffer`` result).
    """
    raw_bytes = base64.b64decode(b.encode("utf-8"))
    return np.frombuffer(raw_bytes, dtype=dtype)
def mesh_to_plotly(mesh):
    """Convert a :term:`mesh` to plotly format.

    Parameters
    ----------
    mesh : any input accepted by ``load_surf_mesh``
        Surface mesh to convert.

    Returns
    -------
    dict
        ``"_x"``, ``"_y"``, ``"_z"``: the vertex coordinates along each of
        the three axes, base64-encoded as little-endian float32.
        ``"_i"``, ``"_j"``, ``"_k"``: the three vertex indices of every
        face, base64-encoded as little-endian int32.
    """
    surf_mesh = load_surf_mesh(mesh)
    # Unpacking into exactly three names requires 3 columns in both arrays.
    coord_x, coord_y, coord_z = (
        encode(axis)
        for axis in np.asarray(surf_mesh.coordinates.T, dtype="<f4")
    )
    face_i, face_j, face_k = (
        encode(corner) for corner in np.asarray(surf_mesh.faces.T, dtype="<i4")
    )
    return {
        "_x": coord_x,
        "_y": coord_y,
        "_z": coord_z,
        "_i": face_i,
        "_j": face_j,
        "_k": face_k,
    }
def to_color_strings(colors):
    """Return a list of colors as hex strings.

    Parameters
    ----------
    colors : sequence of matplotlib colors
        Any color list accepted by ``matplotlib.colors.ListedColormap``.
        Alpha is discarded.

    Returns
    -------
    list of str
        One ``"#rrggbb"`` string per entry of the resulting colormap.
    """
    cmap = mpl.colors.ListedColormap(colors)
    rgb = cmap(np.arange(cmap.N))[:, :3]
    # Round instead of truncating: a channel stored as k / 255 can come back
    # slightly below k after multiplying by 255, and a plain uint8 cast would
    # then give k - 1 (matplotlib.colors.to_hex also rounds).
    rgb = np.asarray(np.round(rgb * 255), dtype="uint8")
    return [f"#{red:02x}{green:02x}{blue:02x}" for red, green, blue in rgb.tolist()]