Coverage for fastblocks/actions/gather/models.py: 41%

305 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-21 04:50 -0700

1"""Model gathering functionality to discover and collect SQLModel and Pydantic models.""" 

2 

3import inspect 

4import typing as t 

5from contextlib import suppress 

6from importlib import import_module 

7from pathlib import Path 

8 

9from acb.adapters import get_adapters, root_path 

10from acb.debug import debug 

11from anyio import Path as AsyncPath 

12 

13from .strategies import GatherStrategy, gather_with_strategy 

14 

15 

16class ModelGatherResult: 

17 def __init__( 

18 self, 

19 *, 

20 sql_models: dict[str, type] | None = None, 

21 nosql_models: dict[str, type] | None = None, 

22 adapter_models: dict[str, dict[str, type]] | None = None, 

23 admin_models: list[type] | None = None, 

24 model_metadata: dict[str, dict[str, t.Any]] | None = None, 

25 errors: list[Exception] | None = None, 

26 ) -> None: 

27 self.sql_models = sql_models if sql_models is not None else {} 

28 self.nosql_models = nosql_models if nosql_models is not None else {} 

29 self.adapter_models = adapter_models if adapter_models is not None else {} 

30 self.admin_models = admin_models if admin_models is not None else [] 

31 self.model_metadata = model_metadata if model_metadata is not None else {} 

32 self.errors = errors if errors is not None else [] 

33 

34 @property 

35 def total_models(self) -> int: 

36 adapter_count = sum(len(models) for models in self.adapter_models.values()) 

37 return len(self.sql_models) + len(self.nosql_models) + adapter_count 

38 

39 @property 

40 def has_errors(self) -> bool: 

41 return len(self.errors) > 0 

42 

43 def get_all_models(self) -> dict[str, type]: 

44 all_models = {} 

45 all_models.update(self.sql_models) 

46 all_models.update(self.nosql_models) 

47 for adapter_models in self.adapter_models.values(): 

48 all_models.update(adapter_models) 

49 return all_models 

50 

51 

async def gather_models(
    *,
    sources: list[str] | None = None,
    patterns: list[str] | None = None,
    include_base: bool = True,
    include_adapters: bool = True,
    include_admin: bool = True,
    base_classes: list[type] | None = None,
    strategy: GatherStrategy | None = None,
) -> ModelGatherResult:
    """Discover model classes from the configured sources.

    Args:
        sources: Source groups to scan. "base", "adapters" and "custom" are
            recognised; defaults to ``["base", "adapters"]``.
        patterns: File-name patterns used when scanning adapter packages;
            defaults to ``["models.py", "_models.py", "_models_*.py"]``.
        include_base: Scan "base" only when this is true AND it is in sources.
        include_adapters: Scan "adapters" only when this is true AND it is in
            sources.
        include_admin: When true, fill ``result.admin_models`` from all
            gathered models.
        base_classes: Classes a model must subclass; defaults to SQLModel and
            Pydantic ``BaseModel`` when importable.
        strategy: Passed to ``gather_with_strategy``; defaults to
            ``GatherStrategy()``.

    Returns:
        A ``ModelGatherResult`` holding the gathered models, their metadata
        and any errors reported by the gather tasks.
    """
    config = _prepare_model_gather_config(sources, patterns, base_classes, strategy)
    result = ModelGatherResult()

    tasks = _build_model_gather_tasks(config, include_base, include_adapters)
    gather_result = await gather_with_strategy(
        tasks,
        config["strategy"],
        cache_key=_model_gather_cache_key(config, include_base, include_adapters),
    )

    _process_model_gather_results(gather_result, config["sources"], result)

    if include_admin:
        admin_models = await _gather_admin_models(result.get_all_models())
        result.admin_models = admin_models

    result.errors.extend(gather_result.errors)
    debug(
        f"Gathered {result.total_models} models from {len(config['sources'])} sources",
    )

    return result


def _model_gather_cache_key(
    config: dict[str, t.Any],
    include_base: bool,
    include_adapters: bool,
) -> str:
    """Build a cache key covering every input that changes what is gathered.

    The key used to contain only sources and patterns. Two calls that differed
    only in the include flags or base classes therefore shared one key, so a
    strategy that caches by key could return results gathered for the other call.
    """
    base_names = ",".join(
        f"{getattr(cls, '__module__', '')}.{getattr(cls, '__qualname__', repr(cls))}"
        for cls in config["base_classes"]
    )
    return (
        f"models:{':'.join(config['sources'])}:{':'.join(config['patterns'])}"
        f":base={include_base}:adapters={include_adapters}:classes={base_names}"
    )

84 

85 

def _prepare_model_gather_config(
    sources: list[str] | None,
    patterns: list[str] | None,
    base_classes: list[type] | None,
    strategy: GatherStrategy | None,
) -> dict[str, t.Any]:
    """Replace ``None`` arguments with their defaults and bundle them in a dict.

    The returned dict has the keys ``sources``, ``patterns``, ``base_classes``
    and ``strategy``. Defaults are resolved in that order, so
    ``_get_default_model_base_classes`` and ``GatherStrategy`` are only called
    when their argument is ``None``.
    """
    if sources is None:
        sources = ["base", "adapters"]
    if patterns is None:
        patterns = ["models.py", "_models.py", "_models_*.py"]
    if base_classes is None:
        base_classes = _get_default_model_base_classes()
    if strategy is None:
        strategy = GatherStrategy()

    return {
        "sources": sources,
        "patterns": patterns,
        "base_classes": base_classes,
        "strategy": strategy,
    }

102 

103 

def _get_default_model_base_classes() -> list[type]:
    """Return the model base classes that can be imported in this environment.

    SQLModel's ``SQLModel`` comes first, then Pydantic's ``BaseModel``. A class
    whose package is not installed is skipped silently.
    """
    discovered: list[type] = []

    try:
        from sqlmodel import SQLModel  # type: ignore[import-untyped]
    except ImportError:
        pass
    else:
        discovered.append(SQLModel)

    try:
        from pydantic import BaseModel  # type: ignore[import-untyped]
    except ImportError:
        pass
    else:
        discovered.append(BaseModel)

    return discovered

116 

117 

def _build_model_gather_tasks(
    config: dict[str, t.Any],
    include_base: bool,
    include_adapters: bool,
) -> list[t.Coroutine[t.Any, t.Any, t.Any]]:
    """Create one gathering coroutine per enabled source.

    Coroutines are produced in the fixed order base -> adapters -> custom. A
    source is skipped when it is not listed in ``config["sources"]``, or, for
    base/adapters, when its include flag is false. The coroutines are only
    created here, not awaited.
    """
    requested = config["sources"]
    patterns = config["patterns"]
    base_classes = config["base_classes"]

    wants_base = "base" in requested and include_base
    wants_adapters = "adapters" in requested and include_adapters
    wants_custom = "custom" in requested

    pending: list[t.Coroutine[t.Any, t.Any, t.Any]] = []
    if wants_base:
        pending.append(_gather_base_models(patterns, base_classes))
    if wants_adapters:
        pending.append(_gather_adapter_models(patterns, base_classes))
    if wants_custom:
        pending.append(_gather_custom_models(patterns, base_classes))
    return pending

138 

139 

def _process_model_gather_results(
    gather_result: t.Any,
    sources: list[str],
    result: ModelGatherResult,
) -> None:
    """Merge each successful gather payload into *result*.

    Each payload's source type is first derived from its position via
    ``_get_model_source_type_by_index`` and then cross-checked against the
    payload's keys. Position alone misroutes payloads when a listed source was
    not actually run (e.g. ``include_base=False`` while "base" is in
    *sources*). In that case the adapters payload was treated as "base" and its
    models were silently dropped. The same presumably happens when a failed
    task is missing from ``success``.
    """
    for index, payload in enumerate(gather_result.success):
        positional = _get_model_source_type_by_index(index, sources)
        source_type = _source_type_from_payload(payload, positional)
        _process_single_model_source_result(payload, source_type, result)


def _source_type_from_payload(payload: t.Any, fallback: str) -> str:
    """Return the source type implied by *payload*'s keys, else *fallback*.

    Adapter gatherers return an ``"adapter_models"`` key. Base and custom
    gatherers return ``"sql"``/``"nosql"`` keys and are merged identically, so
    the positional label is kept when it already names one of them.
    """
    if not isinstance(payload, dict):
        return fallback
    if "adapter_models" in payload:
        return "adapters"
    if "sql" in payload or "nosql" in payload:
        return fallback if fallback in ("base", "custom") else "base"
    return fallback

148 

149 

def _get_model_source_type_by_index(index: int, sources: list[str]) -> str:
    """Map the position of a gather task back to the source that produced it.

    Tasks are built in the fixed order base -> adapters -> custom, skipping
    sources that were not requested. The *index*-th task therefore belongs to
    the *index*-th requested source in that order. The previous implementation
    compared *index* against a fixed 3-slot table, so e.g.
    ``sources=["adapters"]`` mapped index 0 to "unknown".

    This function cannot see the ``include_*`` flags, so it assumes every
    requested source produced a task.

    Returns:
        "base", "adapters", "custom", or "unknown" when *index* is out of range.
    """
    requested = [name for name in ("base", "adapters", "custom") if name in sources]
    if 0 <= index < len(requested):
        return requested[index]
    return "unknown"

157 

158 

def _process_single_model_source_result(
    success: dict[str, t.Any],
    source_type: str,
    result: ModelGatherResult,
) -> None:
    """Merge one gatherer payload into *result* according to *source_type*.

    "base" and "custom" payloads carry ``sql``/``nosql``/``metadata`` dicts,
    which are merged directly. "adapters" payloads are handed to
    ``_process_adapter_models``. Any other source type is ignored.
    """
    if source_type == "adapters":
        _process_adapter_models(success, result)
    elif source_type in ("base", "custom"):
        for key, target in (
            ("sql", result.sql_models),
            ("nosql", result.nosql_models),
            ("metadata", result.model_metadata),
        ):
            target.update(success.get(key, {}))

170 

171 

def _process_adapter_models(
    success: dict[str, t.Any],
    result: ModelGatherResult,
) -> None:
    """Merge an adapters payload into *result*.

    The per-adapter mapping is stored as-is in ``result.adapter_models``. Every
    model is also filed by bare name into ``sql_models`` or ``nosql_models``
    (classified with ``_is_sql_model``), and the payload's metadata is merged.
    """
    per_adapter = success.get("adapter_models", {})
    result.adapter_models.update(per_adapter)

    for models in per_adapter.values():
        for model_name, model_class in models.items():
            bucket = (
                result.sql_models
                if _is_sql_model(model_class)
                else result.nosql_models
            )
            bucket[model_name] = model_class

    result.model_metadata.update(success.get("metadata", {}))

186 

187 

async def _gather_base_models(
    patterns: list[str],
    base_classes: list[type],
) -> dict[str, t.Any]:
    """Collect models from the project root's ``models.py`` and ``models/`` dir.

    ``patterns`` is accepted for signature parity with the other gatherers
    but is not used here; the locations scanned are fixed.

    Returns:
        ``{"sql": {...}, "nosql": {...}, "metadata": {...}}``.
    """
    collected: dict[str, t.Any] = {"sql": {}, "nosql": {}, "metadata": {}}
    project_root = Path(root_path)

    await _process_base_models_file(project_root / "models.py", base_classes, collected)
    await _process_base_models_directory(project_root / "models", base_classes, collected)

    return collected

208 

209 

async def _process_base_models_file(
    models_file: Path,
    base_classes: list[type],
    base_models: dict[str, t.Any],
) -> None:
    """Add the models defined in *models_file* to *base_models*, if it exists.

    A failure while extracting is logged via ``debug`` and otherwise ignored.
    """
    file_exists = await AsyncPath(models_file).exists()
    if file_exists:
        try:
            found = await _extract_models_from_file(models_file, base_classes)
            _add_models_to_base_collection(found, str(models_file), base_models)
            debug(f"Found {len(found)} base models in models.py")
        except Exception as exc:
            debug(f"Error gathering base models from models.py: {exc}")

224 

225 

async def _process_base_models_directory(
    models_dir: Path,
    base_classes: list[type],
    base_models: dict[str, t.Any],
) -> None:
    """Recursively scan *models_dir* for ``*.py`` files and collect their models.

    Nothing happens if *models_dir* is missing or is not a directory. Files
    whose name starts with an underscore are skipped. A failure in one file is
    logged via ``debug`` and does not stop the scan.
    """
    directory = AsyncPath(models_dir)
    if not await directory.exists() or not await directory.is_dir():
        return

    async for source_file in directory.rglob("*.py"):
        if source_file.name.startswith("_"):
            continue

        try:
            discovered = await _extract_models_from_file(
                Path(source_file),
                base_classes,
            )
            _add_models_to_base_collection(discovered, str(source_file), base_models)
            if discovered:
                debug(
                    f"Found {len(discovered)} models in {source_file.relative_to(root_path)}",
                )
        except Exception as exc:
            debug(f"Error gathering models from {source_file}: {exc}")

250 

251 

def _add_models_to_base_collection(
    models: dict[str, type],
    source_path: str,
    base_models: dict[str, t.Any],
) -> None:
    """File each model into ``base_models["sql"]`` or ``["nosql"]`` and record metadata.

    Metadata entries are keyed by bare model name, point at *source_path* and
    are tagged ``location="base"``.
    """
    for model_name, model_class in models.items():
        bucket = "sql" if _is_sql_model(model_class) else "nosql"
        base_models[bucket][model_name] = model_class
        base_models["metadata"][model_name] = {
            "source": source_path,
            "type": bucket,
            "location": "base",
        }

268 

269 

async def _gather_adapter_models(
    patterns: list[str],
    base_classes: list[type],
) -> dict[str, t.Any]:
    """Scan every adapter returned by ``get_adapters()`` for model classes.

    Returns:
        ``{"adapter_models": {adapter name: {model name: class}},
        "metadata": {"<adapter>.<model>": info}}``. Adapters that yield no
        models are left out of ``adapter_models``.
    """
    payload: dict[str, t.Any] = {"adapter_models": {}, "metadata": {}}

    for adapter in get_adapters():
        discovered = await _gather_single_adapter_models(
            adapter,
            patterns,
            base_classes,
            payload,
        )
        if not discovered:
            continue
        payload["adapter_models"][adapter.name] = discovered

    return payload

291 

292 

async def _gather_single_adapter_models(
    adapter: t.Any,
    patterns: list[str],
    base_classes: list[type],
    adapter_models: dict[str, t.Any],
) -> dict[str, type]:
    """Collect the models defined in one adapter's package directory.

    Each pattern is resolved against the directory that contains
    ``adapter.path``. Patterns containing ``*`` are globbed; the others are
    treated as a literal file name. The per-pattern helpers record metadata
    into *adapter_models*.

    Returns:
        Model name -> class for every model found. Later patterns overwrite
        earlier ones on name clashes.
    """
    name = adapter.name
    package_dir = adapter.path.parent
    discovered: dict[str, type] = {}

    for pattern in patterns:
        collect = (
            _gather_models_with_glob_pattern
            if "*" in pattern
            else _gather_models_with_exact_pattern
        )
        discovered.update(
            await collect(package_dir, pattern, base_classes, name, adapter_models),
        )

    return discovered

326 

327 

async def _gather_models_with_glob_pattern(
    adapter_path: t.Any,
    pattern: str,
    base_classes: list[type],
    adapter_name: str,
    adapter_models: dict[str, t.Any],
) -> dict[str, type]:
    """Collect models from every regular file in *adapter_path* matching *pattern*.

    Metadata for each file's models is stored in *adapter_models*. A failure
    in one file is logged via ``debug`` and the remaining matches are still
    processed.
    """
    collected: dict[str, type] = {}

    async for match in AsyncPath(adapter_path).glob(pattern):
        if not await match.is_file():
            continue

        try:
            found = await _extract_models_from_file(Path(match), base_classes)
            collected.update(found)
            _store_adapter_model_metadata(
                found,
                adapter_name,
                str(match),
                adapter_models,
            )
            if found:
                debug(
                    f"Found {len(found)} models in {adapter_name}/{match.name}",
                )
        except Exception as exc:
            debug(
                f"Error gathering models from {adapter_name}/{match.name}: {exc}",
            )

    return collected

360 

361 

async def _gather_models_with_exact_pattern(
    adapter_path: t.Any,
    pattern: str,
    base_classes: list[type],
    adapter_name: str,
    adapter_models: dict[str, t.Any],
) -> dict[str, type]:
    """Collect models from the single file ``adapter_path / pattern``, if present.

    Metadata is stored in *adapter_models*. Errors are logged via ``debug``;
    any models already collected before the error are still returned.
    """
    target = adapter_path / pattern
    if not await AsyncPath(target).exists():
        return {}

    collected: dict[str, type] = {}
    try:
        found = await _extract_models_from_file(Path(target), base_classes)
        collected.update(found)
        _store_adapter_model_metadata(
            found,
            adapter_name,
            str(target),
            adapter_models,
        )
        if found:
            debug(f"Found {len(found)} models in {adapter_name}/{pattern}")
    except Exception as exc:
        debug(f"Error gathering models from {adapter_name}/{pattern}: {exc}")

    return collected

390 

391 

def _store_adapter_model_metadata(
    models: dict[str, type],
    adapter_name: str,
    source_path: str,
    adapter_models: dict[str, t.Any],
) -> None:
    """Record provenance for each adapter model in ``adapter_models["metadata"]``.

    Entries are keyed ``"<adapter_name>.<model_name>"``, so same-named models
    from different adapters get separate metadata entries.
    """
    for model_name, model_class in models.items():
        qualified_name = f"{adapter_name}.{model_name}"
        model_kind = "sql" if _is_sql_model(model_class) else "nosql"
        adapter_models["metadata"][qualified_name] = {
            "source": source_path,
            "type": model_kind,
            "location": "adapter",
            "adapter": adapter_name,
        }

405 

406 

async def _gather_custom_models(
    patterns: list[str],
    base_classes: list[type],
) -> dict[str, t.Any]:
    """Collect models from ``app/``, ``src/`` and ``custom/`` ``models.py`` files.

    Only these three fixed locations under the project root are checked;
    ``patterns`` is not used.

    Returns:
        ``{"sql": {...}, "nosql": {...}, "metadata": {...}}``.
    """
    collected: dict[str, t.Any] = {"sql": {}, "nosql": {}, "metadata": {}}
    project_root = Path(root_path)

    for folder in ("app", "src", "custom"):
        candidate = project_root / folder / "models.py"
        if await AsyncPath(candidate).exists():
            await _process_custom_models_file(candidate, base_classes, collected)

    return collected

428 

429 

async def _process_custom_models_file(
    custom_path: Path,
    base_classes: list[type],
    custom_models: dict[str, t.Any],
) -> None:
    """Add the models defined in *custom_path* to *custom_models*.

    Any failure is logged via ``debug`` and swallowed.
    """
    try:
        discovered = await _extract_models_from_file(custom_path, base_classes)
        _add_custom_models_to_collection(discovered, str(custom_path), custom_models)
        if discovered:
            debug(f"Found {len(discovered)} custom models in {custom_path}")
    except Exception as exc:
        debug(f"Error gathering custom models from {custom_path}: {exc}")

444 

445 

def _add_custom_models_to_collection(
    models: dict[str, type],
    source_path: str,
    custom_models: dict[str, t.Any],
) -> None:
    """File each model into ``custom_models["sql"]`` or ``["nosql"]`` with metadata.

    Metadata entries are keyed by bare model name, point at *source_path* and
    are tagged ``location="custom"``.
    """
    for model_name, model_class in models.items():
        bucket = "sql" if _is_sql_model(model_class) else "nosql"
        custom_models[bucket][model_name] = model_class
        custom_models["metadata"][model_name] = {
            "source": source_path,
            "type": bucket,
            "location": "custom",
        }

462 

463 

async def _extract_models_from_file(
    file_path: Path,
    base_classes: list[type],
) -> dict[str, type]:
    """Import the module backing *file_path* and return the model classes it defines.

    The module is imported with ``import_module``, which runs its top-level
    code synchronously.

    Args:
        file_path: Python source file to inspect.
        base_classes: Candidate model base classes (see ``_is_valid_model_class``).

    Returns:
        Mapping of attribute name to model class. The mapping is empty when the
        module cannot be imported.

    Raises:
        Exception: Any non-import error raised while importing or scanning the
            module is logged and re-raised.
    """
    module_path = _get_module_path_from_file(file_path)
    debug(f"Extracting models from {file_path} -> {module_path}")

    try:
        module = import_module(module_path)
        return _extract_models_from_module(module, base_classes)
    except ImportError as e:
        # ModuleNotFoundError subclasses ImportError, so this one clause covers
        # both. An unimportable file still counts as "no models", but the reason
        # is now logged instead of being swallowed silently.
        debug(f"Skipping {file_path}: could not import {module_path}: {e}")
        return {}
    except Exception as e:
        debug(f"Error extracting models from {file_path}: {e}")
        raise

480 

481 

482def _get_module_path_from_file(file_path: Path) -> str: 

483 try: 

484 relative_path = file_path.relative_to(root_path) 

485 return str(relative_path).replace("/", ".").removesuffix(".py") 

486 except ValueError: 

487 return file_path.stem 

488 

489 

def _extract_models_from_module(
    module: t.Any,
    base_classes: list[type],
) -> dict[str, type]:
    """Return the public classes of *module* accepted by ``_is_valid_model_class``.

    Names starting with an underscore are ignored. The keys are the attribute
    names under which the classes are bound in the module.
    """
    public_names = (name for name in dir(module) if not name.startswith("_"))
    bound = ((name, getattr(module, name)) for name in public_names)
    return {
        name: value
        for name, value in bound
        if _is_valid_model_class(value, module, base_classes)
    }

506 

507 

def _is_valid_model_class(attr: t.Any, module: t.Any, base_classes: list[type]) -> bool:
    """Decide whether *attr* is a model class defined in *module*.

    Objects that are not classes, and classes imported from another module,
    are rejected. A class is accepted if it is a proper subclass of any entry
    in *base_classes* (the base class itself does not count), or if it has a
    ``__table__`` attribute.
    """
    defined_here = inspect.isclass(attr) and attr.__module__ == module.__name__
    if not defined_here:
        return False
    if any(issubclass(attr, base) and attr != base for base in base_classes):
        return True
    return hasattr(attr, "__table__")

516 

517 

async def _gather_admin_models(
    all_models: dict[str, type],
) -> list[type]:
    """Select the models that carry any admin-related marker attribute.

    A model qualifies if it has ``__admin__``, ``can_create``, ``can_edit`` or
    ``__table__``. The order of *all_models* is preserved.
    """
    markers = ("__admin__", "can_create", "can_edit", "__table__")
    selected = [
        candidate
        for candidate in all_models.values()
        if any(hasattr(candidate, marker) for marker in markers)
    ]

    debug(f"Found {len(selected)} admin-enabled models")

    return selected

535 

536 

def _is_sql_model(model_class: type) -> bool:
    """Heuristically decide whether *model_class* is backed by a SQL table.

    A class counts as SQL if it defines ``__table__`` or ``__tablename__``.
    Otherwise its module name is searched for "sql", after removing any
    "nosql" substring. Previously a model living in e.g. ``app.nosql_models``
    was misclassified as SQL because "nosql" contains "sql".
    """
    if hasattr(model_class, "__table__") or hasattr(model_class, "__tablename__"):
        return True
    module_name = getattr(model_class, "__module__", None)
    if not module_name:
        return False
    return "sql" in module_name.lower().replace("nosql", "")

543 

544 

async def create_models_namespace(
    gather_result: ModelGatherResult,
) -> t.Any:
    """Build an attribute-access namespace over the gathered models.

    The returned object exposes:

    * ``.sql.<Name>``: each entry of ``gather_result.sql_models``;
    * ``.nosql.<Name>``: each entry of ``gather_result.nosql_models``;
    * ``.<Name>``: any model. Non-SQL entries win on name clashes because
      they are registered last. Real attributes such as ``sql``, ``nosql`` and
      ``get_admin_models`` take precedence, since lookup only falls back to
      models when normal attribute access fails;
    * ``get_admin_models()``: returns ``gather_result.admin_models``.
    """

    class SQLNamespace:
        """Holder whose attributes are the SQL model classes."""

    class NoSQLNamespace:
        """Holder whose attributes are the non-SQL model classes."""

    class ModelsNamespace:
        def __init__(self) -> None:
            self.sql = SQLNamespace()
            self.nosql = NoSQLNamespace()
            self._all_models: dict[str, type] = {}

        def get_admin_models(self) -> list[type]:
            return gather_result.admin_models

        def __getattr__(self, name: str) -> t.Any:
            # Read the registry through ``__dict__`` rather than
            # ``self._all_models``. On an instance created without ``__init__``
            # (e.g. by ``copy.copy``), ``_all_models`` is missing, and looking
            # it up would re-enter ``__getattr__`` until RecursionError.
            all_models = self.__dict__.get("_all_models", {})
            if name in all_models:
                return all_models[name]
            msg = f"Model '{name}' not found"
            raise AttributeError(msg)

    models = ModelsNamespace()

    for model_name, model_class in gather_result.sql_models.items():
        setattr(models.sql, model_name, model_class)
        models._all_models[model_name] = model_class

    for model_name, model_class in gather_result.nosql_models.items():
        setattr(models.nosql, model_name, model_class)
        models._all_models[model_name] = model_class

    debug(f"Created models namespace with {len(models._all_models)} models")

    return models

582 

583 

async def validate_models(
    models: dict[str, type],
) -> dict[str, t.Any]:
    """Run basic structural checks on every model in *models*.

    Returns:
        A report dict with:

        * ``valid_models``: names that passed every check;
        * ``invalid_models``: dicts holding ``"model"`` and either
          ``"issues"`` (failed checks) or ``"error"`` (an exception raised
          while checking);
        * ``warnings``: non-fatal messages;
        * ``total_checked``: number of models inspected.
    """
    report: dict[str, t.Any] = {
        "valid_models": [],
        "invalid_models": [],
        "warnings": [],
        "total_checked": len(models),
    }

    for name, cls in models.items():
        try:
            _validate_single_model(name, cls, report)
        except Exception as exc:
            failure = {"model": name, "error": str(exc)}
            report["invalid_models"].append(failure)

    return report

606 

607 

def _validate_single_model(
    model_name: str,
    model_class: type,
    validation: dict[str, t.Any],
) -> None:
    """Check one model and record the outcome in the *validation* report.

    The checks run in this order: table/collection definition, duplicate
    name (which only adds a warning), then inheritance. The collected
    problems then decide whether the model is listed as valid or invalid.
    """
    problems: list[str] = []

    _check_model_definition(model_class, problems)
    _check_duplicate_model_name(model_name, validation)
    inheritance_ok = _check_model_inheritance(model_class, problems)

    _categorize_model_validation_result(
        model_name,
        inheritance_ok,
        problems,
        validation,
    )

622 

623 

def _check_model_definition(model_class: type, issues: list[str]) -> None:
    """Append an issue unless the class declares a table or collection.

    Any one of ``__table__``, ``__tablename__`` or ``__collection__`` is enough.
    """
    markers = ("__table__", "__tablename__", "__collection__")
    if any(hasattr(model_class, marker) for marker in markers):
        return
    issues.append("Missing table/collection definition")

634 

635 

def _check_duplicate_model_name(model_name: str, validation: dict[str, t.Any]) -> None:
    """Add a warning if *model_name* is already listed in ``valid_models``."""
    if model_name not in validation["valid_models"]:
        return
    validation["warnings"].append(f"Duplicate model name: {model_name}")

639 

640 

def _check_model_inheritance(model_class: type, issues: list[str]) -> bool:
    """Return True if *model_class* exposes ``__bases__`` or ``__mro__``.

    Otherwise an issue is appended and False is returned. Every real class
    has both attributes, so only non-class objects fail this check.
    """
    has_hierarchy = hasattr(model_class, "__bases__") or hasattr(model_class, "__mro__")
    if has_hierarchy:
        return True
    issues.append("Invalid model inheritance")
    return False

646 

647 

def _categorize_model_validation_result(
    model_name: str,
    is_valid: bool,
    issues: list[str],
    validation: dict[str, t.Any],
) -> None:
    """Record *model_name* as valid, or as invalid together with its issues.

    A model is valid only when *is_valid* is true and *issues* is empty.
    """
    if not is_valid or issues:
        validation["invalid_models"].append({"model": model_name, "issues": issues})
        return
    validation["valid_models"].append(model_name)

663 

664 

def get_model_info(
    model_class: type,
    metadata: dict[str, t.Any] | None = None,
) -> dict[str, t.Any]:
    """Summarise a model class for display or debugging.

    Args:
        model_class: The model class to describe.
        metadata: Optional extra entries merged into the result. They are
            applied before introspection and may override the computed
            ``name``/``module``/``type`` keys.

    Returns:
        A dict with ``name``, ``module``, ``type`` ("sql"/"nosql"),
        ``attributes`` (public names whose value's class name contains
        "Column") and ``relationships`` (class name contains "Relationship").
        It also holds ``table_name`` or ``collection_name`` when the class
        defines one.
    """
    info: dict[str, t.Any] = {
        "name": model_class.__name__,
        "module": model_class.__module__,
        "type": "sql" if _is_sql_model(model_class) else "nosql",
        "attributes": [],
        "relationships": [],
    }

    if metadata:
        info.update(metadata)

    for attr_name in dir(model_class):
        if attr_name.startswith("_"):
            continue
        try:
            attr = getattr(model_class, attr_name)
        except Exception:
            # ``dir()`` can list descriptors that raise when accessed on the
            # class itself. Skip them instead of aborting the whole summary.
            continue
        kind_name = attr.__class__.__name__
        if "Column" in kind_name:
            info["attributes"].append(attr_name)
        elif "Relationship" in kind_name:
            info["relationships"].append(attr_name)

    if hasattr(model_class, "__table__"):
        info["table_name"] = model_class.__table__.name
    elif hasattr(model_class, "__tablename__"):
        info["table_name"] = model_class.__tablename__
    elif hasattr(model_class, "__collection__"):
        info["collection_name"] = model_class.__collection__

    return info

697 return info