Coverage for aipyapp/llm/models.py: 47%
87 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
1import os
2from dataclasses import dataclass, field
3from typing import Set, Dict, Any, Optional
4from enum import Enum, auto
6import yaml
7from loguru import logger
class ModelCapability(Enum):
    """Capabilities a model can advertise.

    The registry loader resolves capability strings from the configuration
    by member name (``ModelCapability[name]``), so the member names below
    are the accepted configuration vocabulary. Values come from ``auto()``
    and therefore follow declaration order.
    """

    # Plain text
    TEXT = auto()

    # Image / audio / video / file modalities, split by direction
    IMAGE_INPUT = auto()
    IMAGE_OUTPUT = auto()
    AUDIO_INPUT = auto()
    AUDIO_OUTPUT = auto()
    VIDEO_INPUT = auto()
    VIDEO_OUTPUT = auto()
    FILE_INPUT = auto()
    FILE_OUTPUT = auto()

    # Coding, tool use and reasoning features
    CODE = auto()
    FUNCTION_CALLING = auto()
    REASONING = auto()
    NATIVE_SEARCH = auto()
    EXTENDED_THINKING = auto()
    CODE_EXECUTION = auto()
    STRUCTURED_OUTPUT = auto()
@dataclass(frozen=True)
class ModelInfo:
    """Immutable record describing one model.

    ``frozen=True`` only prevents rebinding the attributes; the set and dict
    values themselves remain mutable. Because the generated ``__hash__``
    hashes every field, an instance holding ``set``/``dict`` values cannot be
    hashed (``hash()`` raises ``TypeError``).
    """

    # Main model name.
    name: str
    # Free-form description of the model.
    description: str
    # Capabilities the model supports.
    capabilities: Set[ModelCapability]
    # Context length as configured (presumably in tokens - confirm).
    context_length: int
    # Company / vendor the model belongs to.
    company: str
    # Alternative names that refer to this same model.
    alias: Set[str] = field(default_factory=set)
    # Any additional, unstructured attributes of the model.
    extra: Optional[Dict[str, Any]] = None

    def has_capability(self, cap: ModelCapability) -> bool:
        """Return ``True`` if *cap* is among this model's capabilities."""
        return cap in self.capabilities
class ModelRegistry:
    """In-memory registry of models loaded from a YAML configuration file.

    Layout read by the loader::

        <company>:
          <model name>:
            description: ...
            capabilities: [<ModelCapability member name>, ...]
            context_length: ...
            alias: [...]            # optional
            <any other key>: ...    # collected into ``ModelInfo.extra``
    """

    # Keys mapped onto dedicated ModelInfo fields; every other key goes to ``extra``.
    _RESERVED_KEYS = ('description', 'capabilities', 'context_length', 'alias')

    def __init__(self, config_path: str):
        """Create the registry and load it from the YAML file at *config_path*.

        Raises whatever opening or parsing the file raises (e.g. ``OSError``,
        ``KeyError`` for a missing required key or unknown capability name).
        """
        self.models: Dict[str, ModelInfo] = {}
        self.alias_map: Dict[str, str] = {}
        self.logger = logger.bind(src=self.__class__.__name__)
        self._load_from_yaml(config_path)

    def _parse_yaml(self, path: str) -> tuple:
        """Parse *path* and return ``(models, alias_map)`` without touching ``self``.

        Building into fresh dicts lets callers decide when to publish the
        result, so a malformed file cannot leave the registry half-populated.
        """
        with open(path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)

        models: Dict[str, ModelInfo] = {}
        alias_map: Dict[str, str] = {}
        # safe_load() yields None for an empty document, and a company key
        # without a body is None as well; treat both as "no entries".
        for company, models_dict in (data or {}).items():
            for model_name, spec in (models_dict or {}).items():
                raw_alias = spec.get('alias') or []
                if isinstance(raw_alias, str):
                    # A scalar alias would otherwise be split into characters by set().
                    raw_alias = [raw_alias]
                aliases = set(raw_alias)
                models[model_name] = ModelInfo(
                    name=model_name,  # the mapping key is used directly as the name
                    description=spec['description'],
                    capabilities={ModelCapability[cap] for cap in spec['capabilities']},
                    context_length=spec['context_length'],
                    company=company,
                    alias=aliases,
                    extra={
                        key: value
                        for key, value in spec.items()
                        if key not in self._RESERVED_KEYS
                    },
                )
                for a in aliases:
                    previous = alias_map.get(a)
                    if previous is not None and previous != model_name:
                        # Last definition still wins (as before), but no longer silently.
                        self.logger.warning(
                            f"Alias {a!r} of {previous!r} is overridden by {model_name!r}"
                        )
                    alias_map[a] = model_name
        return models, alias_map

    def _install(self, models, alias_map, path: str):
        """Merge parsed entries into the registry in place and log the total."""
        # Mutate in place so the dict handed out by all_models() stays current.
        self.models.update(models)
        self.alias_map.update(alias_map)
        self.logger.info(f"Loaded {len(self.models)} models from {path}")

    def _load_from_yaml(self, path: str):
        """Load models from the YAML file at *path* into the registry."""
        models, alias_map = self._parse_yaml(path)
        self._install(models, alias_map, path)

    def get_model_info(self, model_name: str) -> Optional[ModelInfo]:
        """Return the model registered under *model_name* or one of its aliases.

        The alias table is consulted first; returns ``None`` when nothing matches.
        """
        main_name = self.alias_map.get(model_name, model_name)
        return self.models.get(main_name)

    def get_models_by_company(self, company: str) -> Dict[str, ModelInfo]:
        """Return a new dict of the models whose ``company`` equals *company*."""
        return {k: v for k, v in self.models.items() if v.company == company}

    def all_models(self) -> Dict[str, ModelInfo]:
        """Return the registry's own name -> model dict (not a copy)."""
        return self.models

    def reload(self, config_path: str):
        """Replace the registry contents with those of *config_path*.

        The file is parsed before anything is cleared, so if reading or
        parsing fails the previously loaded models remain available.
        """
        models, alias_map = self._parse_yaml(config_path)
        self.models.clear()
        self.alias_map.clear()
        self._install(models, alias_map, config_path)
# ================== Manual smoke test ==================
if __name__ == "__main__":
    # Assumes the configuration lives at ../res/models.yaml relative to this file.
    yaml_path = os.path.join(os.path.dirname(__file__), '../res/models.yaml')
    yaml_path = os.path.abspath(yaml_path)
    print(f"加载模型配置: {yaml_path}")
    registry = ModelRegistry(yaml_path)

    # Look up a model by its main name.
    info = registry.get_model_info('gpt-4o')
    print("gpt-4o:", info)
    # Look up the same model through an alias.
    info2 = registry.get_model_info('gpt-4o-2024-05-13')
    print("gpt-4o-2024-05-13 (alias):", info2)
    # Capability checks on the alias result. ModelCapability has no
    # IMAGE_UNDERSTAND / IMAGE_GENERATE members; the image members are
    # IMAGE_INPUT and IMAGE_OUTPUT.
    if info2:
        print("gpt-4o-2024-05-13 是否支持图片理解:", info2.has_capability(ModelCapability.IMAGE_INPUT))
        print("gpt-4o-2024-05-13 是否支持图片生成:", info2.has_capability(ModelCapability.IMAGE_OUTPUT))
    # All models belonging to one company.
    openai_models = registry.get_models_by_company('OpenAI')
    print("OpenAI models:", list(openai_models.keys()))
    # Every registered model.
    print("所有模型:", list(registry.all_models().keys()))
    # Print each model together with its capabilities.
    print("\n所有模型及其能力:")
    for name, model in registry.all_models().items():
        print(f"- {name}: {[c.name for c in model.capabilities]}")
    # Unknown model name.
    not_found = registry.get_model_info('not-exist-model')
    print("not-exist-model:", not_found)
    # Unknown alias.
    not_found_alias = registry.get_model_info('not-exist-alias')
    print("not-exist-alias:", not_found_alias)
    # Print the aliases of every model.
    print("\n每个模型的 alias:")
    for name, model in registry.all_models().items():
        print(f"- {name}: {sorted(model.alias) if model.alias else '无'}")