Coverage for aipyapp/llm/models.py: 47%

87 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 12:02 +0200

import os
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import AbstractSet, Any, Dict, Optional, Set

import yaml
from loguru import logger

8 

class ModelCapability(Enum):
    """Closed set of capability flags a model entry can declare.

    Members are resolved by name from the ``capabilities`` list of each model in
    the YAML config (``ModelCapability[name]``), so member names must match the
    strings written there. Values come from ``auto()`` and depend only on the
    declaration order below.
    """

    # Text.
    TEXT = auto()

    # Modalities, as input/output pairs.
    IMAGE_INPUT = auto()
    IMAGE_OUTPUT = auto()
    AUDIO_INPUT = auto()
    AUDIO_OUTPUT = auto()
    VIDEO_INPUT = auto()
    VIDEO_OUTPUT = auto()
    FILE_INPUT = auto()
    FILE_OUTPUT = auto()

    # Other features.
    CODE = auto()
    FUNCTION_CALLING = auto()
    REASONING = auto()
    NATIVE_SEARCH = auto()
    EXTENDED_THINKING = auto()
    CODE_EXECUTION = auto()
    STRUCTURED_OUTPUT = auto()

26 

@dataclass(frozen=True)
class ModelInfo:
    """Immutable description of a single model entry from the models config.

    ``capabilities`` and ``alias`` accept any iterable of values and are stored
    as ``frozenset``. This makes the instance genuinely immutable and lets the
    ``__hash__`` generated by ``frozen=True`` work; with plain ``set`` fields,
    hashing an instance raises ``TypeError``. ``extra`` holds any additional
    per-model config keys. It still takes part in ``==`` but is excluded from
    the hash, because dicts are unhashable.
    """

    name: str
    description: str
    capabilities: AbstractSet[ModelCapability]
    context_length: int
    company: str
    alias: AbstractSet[str] = field(default_factory=frozenset)
    # hash=False: a dict cannot be hashed; equal instances still hash equal
    # because the hash covers a subset of the compared fields.
    extra: Optional[Dict[str, Any]] = field(default=None, hash=False)

    def __post_init__(self):
        # frozen=True forbids normal attribute assignment, so normalize through
        # object.__setattr__ (the documented way to set fields on a frozen dataclass).
        object.__setattr__(self, 'capabilities', frozenset(self.capabilities))
        object.__setattr__(self, 'alias', frozenset(self.alias))

    def has_capability(self, cap: ModelCapability) -> bool:
        """Return True if *cap* is one of this model's declared capabilities."""
        return cap in self.capabilities

39 

class ModelRegistry:
    """Registry of ModelInfo entries loaded from a YAML config, queryable by name or alias.

    The YAML maps company name -> model name -> spec. Each spec must provide
    ``description``, ``capabilities`` (a list of ModelCapability member names)
    and ``context_length``, and may provide ``alias`` (a list, or a single string).
    Any other spec keys are kept in ``ModelInfo.extra``.
    """

    # Spec keys consumed into dedicated ModelInfo fields; all other keys go to ``extra``.
    _CORE_KEYS = ('description', 'capabilities', 'context_length', 'alias')

    def __init__(self, config_path: str):
        self.models: Dict[str, ModelInfo] = {}
        self.alias_map: Dict[str, str] = {}
        self.logger = logger.bind(src=self.__class__.__name__)
        self._load_from_yaml(config_path)

    def _load_from_yaml(self, path: str):
        """Parse *path* and merge its models and aliases into this registry.

        Nothing is merged unless the whole file parses, so a malformed config
        cannot leave the registry half-populated.
        """
        models, alias_map = self._parse_yaml(path)
        self._merge(models, alias_map, path)

    def _parse_yaml(self, path: str):
        """Read *path* and return ``(models, alias_map)`` without modifying ``self``.

        Raises:
            OSError: if the file cannot be opened.
            yaml.YAMLError: if the file is not valid YAML.
            KeyError: if a spec lacks a required key or names an unknown capability.
        """
        with open(path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document; treat that as "no models".
            data = yaml.safe_load(f) or {}

        models: Dict[str, ModelInfo] = {}
        alias_map: Dict[str, str] = {}
        for company, models_dict in data.items():
            # A company key with nothing under it loads as None.
            for model_name, spec in (models_dict or {}).items():
                aliases = self._parse_aliases(spec.get('alias'))
                models[model_name] = ModelInfo(
                    name=model_name,  # the YAML key is the model's canonical name
                    description=spec['description'],
                    capabilities=self._parse_capabilities(model_name, spec['capabilities']),
                    context_length=spec['context_length'],
                    company=company,
                    alias=aliases,
                    extra={key: val for key, val in spec.items() if key not in self._CORE_KEYS},
                )
                for a in aliases:
                    previous = alias_map.get(a)
                    if previous is not None and previous != model_name:
                        # Last definition wins, but make the collision visible.
                        self.logger.warning(
                            f"Alias {a!r} of {model_name!r} overrides the same alias of {previous!r}"
                        )
                    alias_map[a] = model_name
        return models, alias_map

    @staticmethod
    def _parse_aliases(raw: Any) -> Set[str]:
        """Normalize the optional ``alias`` spec value to a set of strings."""
        if raw is None:
            return set()
        if isinstance(raw, str):
            # A single alias written as a scalar; set("abc") would split it into characters.
            return {raw}
        return set(raw)

    @staticmethod
    def _parse_capabilities(model_name: str, names: Any) -> Set[ModelCapability]:
        """Map capability names from the config to ModelCapability members.

        Raises:
            KeyError: naming both the unknown capability and the model it came from.
        """
        caps = set()
        for cap in names:
            try:
                caps.add(ModelCapability[cap])
            except KeyError as err:
                raise KeyError(f"Unknown capability {cap!r} for model {model_name!r}") from err
        return caps

    def _merge(self, models: Dict[str, 'ModelInfo'], alias_map: Dict[str, str], path: str):
        """Merge parsed entries into the live dicts and log the resulting model count."""
        # Update in place rather than rebinding, so a dict returned by all_models() stays current.
        self.models.update(models)
        self.alias_map.update(alias_map)
        self.logger.info(f"Loaded {len(self.models)} models from {path}")

    def get_model_info(self, model_name: str) -> Optional[ModelInfo]:
        """Return the ModelInfo for a canonical name or alias, or None if unknown.

        Aliases are resolved first, so an alias that equals another model's name
        takes precedence over that name.
        """
        main_name = self.alias_map.get(model_name, model_name)
        return self.models.get(main_name)

    def get_models_by_company(self, company: str) -> Dict[str, ModelInfo]:
        """Return ``{name: ModelInfo}`` for every model whose company equals *company*."""
        return {k: v for k, v in self.models.items() if v.company == company}

    def all_models(self) -> Dict[str, ModelInfo]:
        """Return the internal ``{name: ModelInfo}`` dict (not a copy)."""
        return self.models

    def reload(self, config_path: str):
        """Replace all models and aliases with those loaded from *config_path*.

        The new file is parsed before anything is cleared. If loading fails, the
        registry keeps its previous contents and the exception propagates.
        """
        models, alias_map = self._parse_yaml(config_path)
        self.models.clear()
        self.alias_map.clear()
        self._merge(models, alias_map, config_path)

81 

# ================== Manual smoke test ==================
if __name__ == "__main__":
    # Assumes the config lives at ../res/models.yaml relative to this module's directory.
    yaml_path = os.path.join(os.path.dirname(__file__), '../res/models.yaml')
    yaml_path = os.path.abspath(yaml_path)
    print(f"加载模型配置: {yaml_path}")
    registry = ModelRegistry(yaml_path)

    # Look up a model by its canonical name.
    info = registry.get_model_info('gpt-4o')
    print("gpt-4o:", info)
    # Look up a model through an alias.
    info2 = registry.get_model_info('gpt-4o-2024-05-13')
    print("gpt-4o-2024-05-13 (alias):", info2)
    # Query capabilities on the alias-resolved entry. Image understanding and
    # image generation correspond to IMAGE_INPUT and IMAGE_OUTPUT.
    if info2:
        print("gpt-4o-2024-05-13 是否支持图片理解:", info2.has_capability(ModelCapability.IMAGE_INPUT))
        print("gpt-4o-2024-05-13 是否支持图片生成:", info2.has_capability(ModelCapability.IMAGE_OUTPUT))
    # List all models of one company.
    openai_models = registry.get_models_by_company('OpenAI')
    print("OpenAI models:", list(openai_models.keys()))
    # List every model name.
    print("所有模型:", list(registry.all_models().keys()))
    # Print each model with its capabilities (sorted, since set order is not stable).
    print("\n所有模型及其能力:")
    for name, model in registry.all_models().items():
        print(f"- {name}: {sorted(c.name for c in model.capabilities)}")
    # An unknown model name yields None.
    not_found = registry.get_model_info('not-exist-model')
    print("not-exist-model:", not_found)
    # An unknown alias yields None as well.
    not_found_alias = registry.get_model_info('not-exist-alias')
    print("not-exist-alias:", not_found_alias)
    # Print each model's aliases.
    print("\n每个模型的 alias:")
    for name, model in registry.all_models().items():
        print(f"- {name}: {sorted(model.alias) if model.alias else '无'}")