Coverage for excalidraw_mcp/element_factory.py: 98%

148 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-16 08:08 -0700

1"""Element factory for creating and managing Excalidraw elements.""" 

2 

3import logging 

4import uuid 

5from datetime import UTC, datetime 

6from typing import Any 

7 

# Module-level logger. ElementFactory._get_optional_float emits a warning
# through it when a numeric property cannot be converted to float.
logger: logging.Logger = logging.getLogger(__name__)

9 

10 

11class ElementFactory: 

12 """Factory for creating and managing Excalidraw elements.""" 

13 

14 def __init__(self) -> None: 

15 self.default_timestamp = "2025-01-01T00:00:00.000Z" 

16 

17 def create_element(self, element_data: dict[str, Any]) -> dict[str, Any]: 

18 """Create a new element with proper defaults and validation.""" 

19 # Generate unique ID 

20 element_id = str(uuid.uuid4()) 

21 

22 # Get current timestamp 

23 current_time = datetime.now(UTC).isoformat() + "Z" 

24 

25 # Base element structure 

26 element = { 

27 "id": element_id, 

28 "type": element_data.get("type", "rectangle"), 

29 "x": float(element_data.get("x", 0)), 

30 "y": float(element_data.get("y", 0)), 

31 "width": self._get_optional_float(element_data, "width"), 

32 "height": self._get_optional_float(element_data, "height"), 

33 "version": 1, 

34 "createdAt": current_time, 

35 "updatedAt": current_time, 

36 "locked": element_data.get("locked", False), 

37 } 

38 

39 # Add optional properties 

40 self._add_optional_properties(element, element_data) 

41 

42 return element 

43 

44 def prepare_update_data(self, update_data: dict[str, Any]) -> dict[str, Any]: 

45 """Prepare element data for updates.""" 

46 # Remove ID from update data (handled separately) 

47 element_id = update_data.pop("id", None) 

48 if not element_id: 

49 raise ValueError("Element ID is required for updates") 

50 

51 # Increment version and update timestamp 

52 current_time = datetime.now(UTC).isoformat() + "Z" 

53 

54 # Prepare update payload 

55 update_payload = {"id": element_id, "updatedAt": current_time} 

56 

57 # Add provided updates 

58 for key, value in update_data.items(): 

59 if key not in ("createdAt", "version"): # Protect immutable fields 

60 if key in ( 

61 "x", 

62 "y", 

63 "width", 

64 "height", 

65 "strokeWidth", 

66 "opacity", 

67 "roughness", 

68 "fontSize", 

69 ): 

70 update_payload[key] = float(value) if value is not None else value 

71 else: 

72 update_payload[key] = value 

73 

74 return update_payload 

75 

76 def _add_optional_properties( 

77 self, element: dict[str, Any], element_data: dict[str, Any] 

78 ) -> None: 

79 """Add optional properties to element based on type.""" 

80 element_type = element["type"] 

81 

82 # Text properties 

83 if element_type == "text" or element_data.get("text"): 

84 element["text"] = element_data.get("text", "") 

85 element["fontSize"] = self._get_optional_float(element_data, "fontSize", 16) 

86 element["fontFamily"] = element_data.get("fontFamily", "Cascadia, Consolas") 

87 

88 # Visual properties 

89 self._add_visual_properties(element, element_data) 

90 

91 # Shape-specific properties 

92 if element_type in ("rectangle", "ellipse", "diamond"): 

93 self._add_shape_properties(element, element_data) 

94 elif element_type in ("line", "arrow"): 

95 self._add_line_properties(element, element_data) 

96 

97 def _add_visual_properties( 

98 self, element: dict[str, Any], element_data: dict[str, Any] 

99 ) -> None: 

100 """Add visual styling properties.""" 

101 element["strokeColor"] = element_data.get("strokeColor", "#000000") 

102 element["backgroundColor"] = element_data.get("backgroundColor", "#ffffff") 

103 element["strokeWidth"] = self._get_optional_float( 

104 element_data, "strokeWidth", 2 

105 ) 

106 element["opacity"] = self._get_optional_float(element_data, "opacity", 100) 

107 element["roughness"] = self._get_optional_float(element_data, "roughness", 1) 

108 

109 def _add_shape_properties( 

110 self, element: dict[str, Any], element_data: dict[str, Any] 

111 ) -> None: 

112 """Add properties specific to shapes (rectangles, ellipses, diamonds).""" 

113 # Default dimensions for shapes 

114 if element["width"] is None: 

115 element["width"] = 100.0 

116 if element["height"] is None: 

117 element["height"] = 100.0 

118 

119 def _add_line_properties( 

120 self, element: dict[str, Any], element_data: dict[str, Any] 

121 ) -> None: 

122 """Add properties specific to lines and arrows.""" 

123 # Lines typically don't have fill 

124 element["backgroundColor"] = "transparent" 

125 

126 # Default line endpoints (if not provided, create a simple horizontal line) 

127 if element["width"] is None: 

128 element["width"] = 100.0 

129 if element["height"] is None: 

130 element["height"] = 0.0 

131 

132 def _get_optional_float( 

133 self, data: dict[str, Any], key: str, default: float | None = None 

134 ) -> float | None: 

135 """Get an optional float value from data.""" 

136 value = data.get(key, default) 

137 if value is None: 

138 return None 

139 try: 

140 return float(value) 

141 except (ValueError, TypeError): 

142 logger.warning(f"Invalid float value for {key}: {value}") 

143 return default 

144 

145 def validate_element_data(self, element_data: dict[str, Any]) -> dict[str, Any]: 

146 """Validate and normalize element data.""" 

147 errors: list[str] = [] 

148 

149 # Required fields 

150 if "type" not in element_data: 

151 errors.append("Element type is required") 

152 

153 # Validate type 

154 self._validate_element_type(element_data, errors) 

155 

156 # Validate coordinates 

157 self._validate_coordinates(element_data, errors) 

158 

159 # Validate dimensions 

160 self._validate_dimensions(element_data, errors) 

161 

162 # Validate colors 

163 self._validate_colors(element_data, errors) 

164 

165 # Validate numeric ranges 

166 self._validate_numeric_ranges(element_data, errors) 

167 

168 if errors: 

169 raise ValueError(f"Element validation failed: {'; '.join(errors)}") 

170 

171 return element_data 

172 

173 def _validate_element_type( 

174 self, element_data: dict[str, Any], errors: list[str] 

175 ) -> None: 

176 """Validate element type.""" 

177 valid_types = [ 

178 "rectangle", 

179 "ellipse", 

180 "diamond", 

181 "text", 

182 "line", 

183 "arrow", 

184 "draw", 

185 "image", 

186 "frame", 

187 "embeddable", 

188 "magicframe", 

189 ] 

190 if element_data.get("type") not in valid_types: 

191 errors.append( 

192 f"Invalid element type. Must be one of: {', '.join(valid_types)}" 

193 ) 

194 

195 def _validate_coordinates( 

196 self, element_data: dict[str, Any], errors: list[str] 

197 ) -> None: 

198 """Validate coordinates.""" 

199 for coord in ("x", "y"): 

200 if coord in element_data: 

201 try: 

202 float(element_data[coord]) 

203 except (ValueError, TypeError): 

204 errors.append(f"Invalid {coord} coordinate: must be a number") 

205 

206 def _validate_dimensions( 

207 self, element_data: dict[str, Any], errors: list[str] 

208 ) -> None: 

209 """Validate dimensions.""" 

210 for dimension in ("width", "height"): 

211 if dimension in element_data and element_data[dimension] is not None: 

212 try: 

213 value = float(element_data[dimension]) 

214 if value < 0: 

215 errors.append(f"Invalid {dimension}: must be non-negative") 

216 except (ValueError, TypeError): 

217 errors.append(f"Invalid {dimension}: must be a number") 

218 

219 def _is_valid_color(self, color: str) -> bool: 

220 """Validate hex color format.""" 

221 # Allow transparent 

222 if color.lower() == "transparent": 

223 return True 

224 

225 # Check hex color format 

226 if color.startswith("#") and len(color) == 7: 

227 try: 

228 int(color[1:], 16) 

229 return True 

230 except ValueError: 

231 return False 

232 

233 # Default case - invalid color format 

234 return False 

235 

236 def _validate_colors(self, element_data: dict[str, Any], errors: list[str]) -> None: 

237 """Validate colors.""" 

238 for color_prop in ("strokeColor", "backgroundColor"): 

239 if color_prop in element_data: 

240 color = element_data[color_prop] 

241 if color and not self._is_valid_color(color): 

242 errors.append(f"Invalid {color_prop}: must be a valid hex color") 

243 

244 def _validate_stroke_width( 

245 self, element_data: dict[str, Any], errors: list[str] 

246 ) -> None: 

247 """Validate stroke width property.""" 

248 if "strokeWidth" in element_data: 

249 try: 

250 stroke_width = float(element_data["strokeWidth"]) 

251 if not (0 <= stroke_width <= 50): 

252 errors.append("strokeWidth must be between 0 and 50") 

253 except (ValueError, TypeError): 

254 errors.append("strokeWidth must be a number") 

255 

256 def _validate_opacity( 

257 self, element_data: dict[str, Any], errors: list[str] 

258 ) -> None: 

259 """Validate opacity property.""" 

260 if "opacity" in element_data: 

261 try: 

262 opacity = float(element_data["opacity"]) 

263 if not (0 <= opacity <= 100): 

264 errors.append("opacity must be between 0 and 100") 

265 except (ValueError, TypeError): 

266 errors.append("opacity must be a number") 

267 

268 def _validate_roughness( 

269 self, element_data: dict[str, Any], errors: list[str] 

270 ) -> None: 

271 """Validate roughness property.""" 

272 if "roughness" in element_data: 

273 try: 

274 roughness = float(element_data["roughness"]) 

275 if not (0 <= roughness <= 3): 

276 errors.append("roughness must be between 0 and 3") 

277 except (ValueError, TypeError): 

278 errors.append("roughness must be a number") 

279 

280 def _validate_font_size( 

281 self, element_data: dict[str, Any], errors: list[str] 

282 ) -> None: 

283 """Validate font size property.""" 

284 if "fontSize" in element_data: 

285 try: 

286 font_size = float(element_data["fontSize"]) 

287 if not (8 <= font_size <= 200): 

288 errors.append("fontSize must be between 8 and 200") 

289 except (ValueError, TypeError): 

290 errors.append("fontSize must be a number") 

291 

292 def _validate_numeric_ranges( 

293 self, element_data: dict[str, Any], errors: list[str] 

294 ) -> None: 

295 """Validate numeric properties are within acceptable ranges.""" 

296 self._validate_stroke_width(element_data, errors) 

297 self._validate_opacity(element_data, errors) 

298 self._validate_roughness(element_data, errors) 

299 self._validate_font_size(element_data, errors)