Coverage for nilearn/plotting/js_plotting_utils.py: 0%

82 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1"""Helps for views, i.e. interactive plots from html_surface and \ 

2html_connectome. 

3""" 

4 

5import base64 

6import warnings 

7from pathlib import Path 

8from string import Template 

9 

10import matplotlib as mpl 

11import matplotlib.pyplot as plt 

12import numpy as np 

13 

14from nilearn._utils.extmath import fast_abs_percentile 

15from nilearn._utils.html_document import ( # noqa: F401 

16 HTMLDocument, 

17 set_max_img_views_before_warning, 

18) 

19from nilearn._utils.logger import find_stack_level 

20from nilearn._utils.param_validation import check_threshold 

21from nilearn.surface import load_surf_mesh 

22 

# Presumably the number of interactive views that may be opened before a
# warning is issued (inferred from the name).
# NOTE(review): ``set_max_img_views_before_warning`` is re-exported from
# ``nilearn._utils.html_document``; confirm whether this module-level
# constant is still read anywhere or only kept for backward compatibility.
MAX_IMG_VIEWS_BEFORE_WARNING = 10

24 

25 

def add_js_lib(html, embed_js=True):
    """Add javascript libraries to html template.

    The ``$INSERT_JS_LIBRARIES_HERE`` placeholder of ``html`` is replaced by
    ``<script>`` tags for jquery, plotly (gl3d bundle) and the package's
    ``surface-plot-utils.js`` helpers.

    If embed_js is True, jquery and plotly are embedded in the resulting
    page; otherwise, they are loaded via CDNs. ``surface-plot-utils.js`` is
    always embedded.

    Parameters
    ----------
    html : :obj:`str` or :class:`string.Template`
        Page containing the ``$INSERT_JS_LIBRARIES_HERE`` placeholder.

    embed_js : :obj:`bool`, default=True
        Whether to inline the libraries read from the package's ``data/js``
        directory instead of referencing CDN URLs.

    Returns
    -------
    :obj:`str`
        The page with the placeholder substituted. Substitution uses
        :meth:`string.Template.safe_substitute`, so any other ``$``
        placeholders are left untouched.
    """
    js_dir = Path(__file__).parent / "data" / "js"

    def _read_js(name):
        # Decode explicitly as UTF-8: the default text encoding is
        # locale-dependent (e.g. cp1252 on Windows) and may fail on
        # non-ASCII characters in the scripts.
        return (js_dir / name).read_text(encoding="utf-8")

    js_utils = _read_js("surface-plot-utils.js")
    if embed_js:
        jquery = _read_js("jquery.min.js")
        plotly = _read_js("plotly-gl3d-latest.min.js")
        lib_tags = f"<script>{jquery}</script>\n<script>{plotly}</script>"
    else:
        lib_tags = """<script
    src="https://ajax.googleapis.com/ajax/libs/jquery/3.6.0/jquery.min.js">
    </script>
    <script src="https://cdn.plot.ly/plotly-gl3d-latest.min.js"></script>"""

    # The utilities script is always inlined, after the libraries it uses.
    js_lib = f"""
    {lib_tags}
    <script>
    {js_utils}
    </script>
    """
    if not isinstance(html, Template):
        html = Template(html)
    return html.safe_substitute({"INSERT_JS_LIBRARIES_HERE": js_lib})

61 

62 

def get_html_template(template_name):
    """Load an HTML template shipped in the package's ``data/html`` folder.

    Parameters
    ----------
    template_name : :obj:`str`
        File name, relative to ``data/html`` next to this module.

    Returns
    -------
    :class:`string.Template`
        The file content decoded as UTF-8. The bytes are read as-is, so line
        endings are not translated.
    """
    html_dir = Path(__file__).parent / "data" / "html"
    raw_bytes = (html_dir / template_name).read_bytes()
    return Template(raw_bytes.decode("utf-8"))

69 

70 

def colorscale(
    cmap, values, threshold=None, symmetric_cmap=True, vmax=None, vmin=None
):
    """Normalize a cmap, put it in plotly format, get threshold and range.

    Parameters
    ----------
    cmap : :obj:`str` or matplotlib colormap
        Anything accepted by :func:`matplotlib.pyplot.get_cmap`.

    values : array-like
        Data the colormap is fitted to. Converted with
        :func:`numpy.asanyarray`, so ndarray subclasses (e.g. masked
        arrays) are kept as-is.

    threshold : optional
        Forwarded to ``check_threshold`` together with
        ``fast_abs_percentile``. Colormap entries lying between
        ``-abs_threshold`` and ``abs_threshold`` (as located by the range
        normalization) are replaced by gray. ``None`` disables thresholding.

    symmetric_cmap : :obj:`bool`, default=True
        If True, the range is ``[-vmax, vmax]``. Forced to True, with a
        warning, when ``values`` contains negative numbers.

    vmax : :obj:`float` or None, optional
        Upper bound of the range. Defaults to ``max(abs(values))``.

    vmin : :obj:`float` or None, optional
        Lower bound of the range. Ignored, with a warning, when the colormap
        is symmetric. Defaults to ``min(values)``.

    Returns
    -------
    :obj:`dict`
        - ``"colors"``: plotly colorscale, a list of 100
          ``[position, "rgb(r, g, b)"]`` pairs sampled evenly on ``[0, 1]``;
        - ``"vmin"``, ``"vmax"``: the range, as Python floats;
        - ``"cmap"``: the (possibly grayed) colormap;
        - ``"norm"``: the :class:`matplotlib.colors.Normalize` for the range;
        - ``"abs_threshold"``: value returned by ``check_threshold``, or None;
        - ``"symmetric_cmap"``: the value actually used.
    """
    cmap = plt.get_cmap(cmap)
    # ``.min()`` below requires an array; asanyarray keeps ndarray
    # subclasses (and returns ndarrays unchanged).
    values = np.asanyarray(values)
    abs_values = np.abs(values)
    if not symmetric_cmap and (values.min() < 0):
        warnings.warn(
            "you have specified symmetric_cmap=False "
            "but the map contains negative values; "
            "setting symmetric_cmap to True",
            stacklevel=find_stack_level(),
        )
        symmetric_cmap = True
    if symmetric_cmap and vmin is not None:
        warnings.warn(
            "vmin cannot be chosen when cmap is symmetric",
            stacklevel=find_stack_level(),
        )
        vmin = None
    if vmax is None:
        vmax = abs_values.max()
    # cast to float to avoid TypeError if vmax is a numpy boolean
    vmax = float(vmax)
    if symmetric_cmap:
        vmin = -vmax
    if vmin is None:
        vmin = values.min()
    # Cast like vmax so both returned bounds are plain floats (numpy integer
    # scalars, e.g. from integer data, are not JSON serializable).
    vmin = float(vmin)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    cmaplist = [cmap(i) for i in range(cmap.N)]
    abs_threshold = None
    if threshold is not None:
        abs_threshold = check_threshold(threshold, values, fast_abs_percentile)
        # Positions of -threshold and +threshold on the colormap index
        # scale; clip=True keeps them within [0, N - 1].
        istart = int(norm(-abs_threshold, clip=True) * (cmap.N - 1))
        istop = int(norm(abs_threshold, clip=True) * (cmap.N - 1))
        gray = (0.5, 0.5, 0.5, 1.0)  # just an average gray color
        for i in range(istart, istop):
            cmaplist[i] = gray
    our_cmap = mpl.colors.LinearSegmentedColormap.from_list(
        "Custom cmap", cmaplist, cmap.N
    )
    # Sample the colormap at 100 evenly spaced positions for plotly.
    x = np.linspace(0, 1, 100)
    rgb = our_cmap(x, bytes=True)[:, :3]
    rgb = np.array(rgb, dtype=int)
    colors = [
        [np.round(i, 3), f"rgb({col[0]}, {col[1]}, {col[2]})"]
        for i, col in zip(x, rgb)
    ]
    return {
        "colors": colors,
        "vmin": vmin,
        "vmax": vmax,
        "cmap": our_cmap,
        "norm": norm,
        "abs_threshold": abs_threshold,
        "symmetric_cmap": symmetric_cmap,
    }

127 

128 

def encode(a):
    """Base64 encode a numpy array.

    Parameters
    ----------
    a : :class:`numpy.ndarray` or array-like
        Objects exposing ``tobytes`` (ndarrays and their subclasses) are
        encoded via their own ``tobytes``; other inputs, such as lists, are
        first converted with :func:`numpy.asarray`.

    Returns
    -------
    :obj:`str`
        The Base64 text of the array's raw bytes (C order, in the array's
        own dtype and byte order).
    """
    # ``ndarray.tobytes`` exists in every supported NumPy. The former
    # ``tostring`` fallback (for NumPy < 1.9) was dead code, and
    # ``ndarray.tostring`` was removed in NumPy 2.0.
    if not hasattr(a, "tobytes"):
        a = np.asarray(a)
    return base64.b64encode(a.tobytes()).decode("utf-8")

137 

138 

def decode(b, dtype):
    """Turn Base64 text back into a one-dimensional numpy array.

    This is the inverse of ``encode``.

    Parameters
    ----------
    b : :obj:`str`
        Base64 text.

    dtype : numpy dtype or dtype specifier
        Element type used to interpret the decoded bytes.

    Returns
    -------
    :class:`numpy.ndarray`
        Array created with :func:`numpy.frombuffer` over the decoded bytes;
        it is read-only because it views an immutable ``bytes`` object.
    """
    raw_bytes = base64.b64decode(b.encode("utf-8"))
    return np.frombuffer(raw_bytes, dtype=dtype)

142 

143 

def mesh_to_plotly(mesh):
    """Convert a :term:`mesh` to plotly format.

    Parameters
    ----------
    mesh : anything accepted by ``load_surf_mesh``

    Returns
    -------
    :obj:`dict`
        Keys ``"_x"``, ``"_y"`` and ``"_z"`` hold the Base64-encoded
        transposed coordinates (little-endian float32), one string per row.
        Keys ``"_i"``, ``"_j"`` and ``"_k"`` hold the Base64-encoded
        transposed faces (little-endian int32), one string per row.
    """
    surf_mesh = load_surf_mesh(mesh)

    # Unpacking into exactly three names fails loudly when the transposed
    # arrays do not have three rows.
    coord_rows = np.asarray(surf_mesh.coordinates.T, dtype="<f4")
    x, y, z = (encode(row) for row in coord_rows)

    face_rows = np.asarray(surf_mesh.faces.T, dtype="<i4")
    i, j, k = (encode(row) for row in face_rows)

    return {"_x": x, "_y": y, "_z": z, "_i": i, "_j": j, "_k": k}

158 

159 

def to_color_strings(colors):
    """Convert a sequence of matplotlib colors to ``#rrggbb`` hex strings.

    Parameters
    ----------
    colors : sequence of matplotlib colors
        Anything accepted by :class:`matplotlib.colors.ListedColormap`.

    Returns
    -------
    :obj:`list` of :obj:`str`
        One ``"#rrggbb"`` string per colormap entry. Alpha is dropped, and
        each channel is scaled by 255 and then truncated (not rounded) by
        the cast to ``uint8``.
    """
    listed_cmap = mpl.colors.ListedColormap(colors)
    rgb = listed_cmap(np.arange(listed_cmap.N))[:, :3]
    channels = np.asarray(rgb * 255, dtype="uint8")
    return [
        f"#{int(red):02x}{int(green):02x}{int(blue):02x}"
        for red, green, blue in channels
    ]