Coverage for src/dynapydantic/tracking_group.py: 100%

61 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-11 18:47 +0000

1"""Base class for dynamic pydantic models""" 

2 

3import typing as ty 

4 

5import pydantic 

6import pydantic.fields 

7import pydantic_core 

8 

9from .exceptions import AmbiguousDiscriminatorValueError, RegistrationError 

10 

11 

def _inject_discriminator_field(
    cls: type[pydantic.BaseModel],
    disc_field: str,
    value: str,
) -> pydantic.fields.FieldInfo:
    """Add a frozen, literal-typed discriminator field to a model class

    The class is modified in place and rebuilt afterwards.

    Parameters
    ----------
    cls
        The BaseModel subclass to modify
    disc_field
        Name under which the discriminator field is stored
    value
        The only value the field accepts; it is also used as the default

    Returns
    -------
    pydantic.fields.FieldInfo
        The entry stored under ``disc_field`` in ``cls.model_fields`` after
        the model has been rebuilt
    """
    literal_type = ty.Literal[value]
    info = pydantic.fields.FieldInfo(
        default=value,
        annotation=literal_type,
        frozen=True,
    )
    cls.model_fields[disc_field] = info

    # Force a rebuild so the model is regenerated with the field added above.
    cls.model_rebuild(force=True)

    # Look the field up again rather than returning ``info`` directly, in case
    # the rebuild replaced the stored object.
    return cls.model_fields[disc_field]

35 

36 

class TrackingGroup(pydantic.BaseModel):
    """Tracker for pydantic models

    A group maps discriminator values to model classes. Models are added via
    the ``register`` decorator or ``register_model``, and ``union`` builds a
    union type over everything registered so far.

    Attributes
    ----------
    name
        Human-readable name of the group, used in error messages
    discriminator_field
        Name of the field whose default value identifies each model
    plugin_entry_point
        Optional entry point group that ``load_plugins`` loads
    discriminator_value_generator
        Optional callable producing a discriminator value from a class when
        none is given explicitly
    models
        Mapping of discriminator value to registered model class
    """

    name: str = pydantic.Field(
        description=(
            "Name of the tracking group. This is for human display, so it "
            "doesn't technically need to be globally unique, but it should be "
            "meaningfully named, as it will be used in error messages."
        ),
    )
    discriminator_field: str = pydantic.Field(
        description="Name of the discriminator field",
    )
    plugin_entry_point: str | None = pydantic.Field(
        None,
        description=(
            "If given, then plugins packages will be supported through this "
            "Python entrypoint. The entrypoint can either be a function, "
            "which will be called, or simply a module, which will be "
            "imported. In either case, models found along the import path of "
            "the entrypoint will be registered. If the entrypoint is a "
            "function, additional models may be declared in the function."
        ),
    )
    discriminator_value_generator: ty.Callable[[type], str] | None = pydantic.Field(
        None,
        description=(
            "A callable that produces default values for the discriminator field"
        ),
    )
    models: dict[str, type[pydantic.BaseModel]] = pydantic.Field(
        {},
        description="The tracked models",
    )

    def load_plugins(self) -> None:
        """Load plugins to discover/register additional models

        Does nothing when ``plugin_entry_point`` is None. Otherwise every
        entry point in that group is loaded, and any loaded object that is
        callable is called with no arguments.
        """
        if self.plugin_entry_point is None:
            return

        from importlib.metadata import entry_points  # noqa: PLC0415

        for ep in entry_points().select(group=self.plugin_entry_point):
            plugin = ep.load()
            if callable(plugin):
                plugin()

    def register(
        self,
        discriminator_value: str | None = None,
    ) -> ty.Callable[[type], type]:
        """Register a model into this group (decorator)

        Parameters
        ----------
        discriminator_value
            Value for the discriminator field. If not given, then
            discriminator_value_generator must be non-None or the
            discriminator field must be declared by hand.

        Returns
        -------
        Callable[[type], type]
            A class decorator that registers the decorated model and returns
            it unchanged (apart from an injected discriminator field).

        Raises
        ------
        RegistrationError
            Raised by the decorator if no discriminator value can be
            determined, or if the value is already used by another model.
        AmbiguousDiscriminatorValueError
            Raised by the decorator if the model already declares the
            discriminator field with a default that differs from
            ``discriminator_value``.
        """

        def _wrapper(cls: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]:
            disc = self.discriminator_field
            field = cls.model_fields.get(disc)
            if field is None:
                # The model does not declare the field: inject it, preferring
                # the explicit value over the generator.
                if discriminator_value is not None:
                    _inject_discriminator_field(cls, disc, discriminator_value)
                elif self.discriminator_value_generator is not None:
                    _inject_discriminator_field(
                        cls,
                        disc,
                        self.discriminator_value_generator(cls),
                    )
                else:
                    msg = (
                        f"unable to determine a discriminator value for "
                        f'{cls.__name__} in tracking group "{self.name}". No '
                        "value was passed to register(), "
                        "discriminator_value_generator was None and the "
                        f'"{disc}" field was not defined.'
                    )
                    raise RegistrationError(msg)
            elif (
                discriminator_value is not None and field.default != discriminator_value
            ):
                # The field was declared by hand AND a conflicting value was
                # passed to register(); refuse to guess which one wins.
                msg = (
                    f"the discriminator value for {cls.__name__} was "
                    f'ambiguous, it was set to "{discriminator_value}" via '
                    f'register() and "{field.default}" via the discriminator '
                    f"field ({self.discriminator_field})."
                )
                raise AmbiguousDiscriminatorValueError(msg)

            self._register_with_discriminator_field(cls)
            # Return the class so the decorated name stays bound to the model.
            return cls

        return _wrapper

    def register_model(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the given model into this group

        Parameters
        ----------
        cls
            The model to register

        Raises
        ------
        RegistrationError
            If the model does not declare the discriminator field and
            discriminator_value_generator is None, or if the discriminator
            value is missing or already used by another model.
        """
        disc = self.discriminator_field
        if cls.model_fields.get(disc) is None:
            if self.discriminator_value_generator is not None:
                _inject_discriminator_field(
                    cls,
                    disc,
                    self.discriminator_value_generator(cls),
                )
            else:
                msg = (
                    f"unable to determine a discriminator value for "
                    f'{cls.__name__} in tracking group "{self.name}", '
                    "discriminator_value_generator was None and the "
                    f'"{disc}" field was not defined.'
                )
                raise RegistrationError(msg)

        self._register_with_discriminator_field(cls)

    def _register_with_discriminator_field(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the model with the default of the discriminator field

        Parameters
        ----------
        cls
            The class to register, must have the discriminator field set with
            a unique default value in the group.

        Raises
        ------
        RegistrationError
            If the discriminator field has no default value, or if its value
            is already registered to a different class.
        """
        disc = self.discriminator_field
        field = cls.model_fields.get(disc)
        value = field.default
        # PydanticUndefined is a sentinel, so compare by identity; ``==`` would
        # call an arbitrary ``__eq__`` on whatever default the model declared.
        if value is pydantic_core.PydanticUndefined:
            msg = (
                f"{cls.__name__}.{disc} had no default value, it must "
                "have one which is unique among all tracked models."
            )
            raise RegistrationError(msg)

        # Registering the same class again under the same value is allowed.
        if (other := self.models.get(value)) is not None and other is not cls:
            msg = (
                f'Cannot register {cls.__name__} under the "{value}" '
                f"identifier, which is already in use by {other.__name__}."
            )
            raise RegistrationError(msg)

        self.models[value] = cls

    def union(self, *, annotated: bool = True) -> ty.GenericAlias:
        """Return the union of all registered models

        Parameters
        ----------
        annotated
            If True (the default), each member is tagged with its
            discriminator value and the union is annotated with a pydantic
            field that discriminates on ``discriminator_field``. If False, a
            plain union of the registered classes is returned.

        Returns
        -------
        The union type built from the models registered at call time.
        """
        # NOTE(review): with no registered models the union is subscripted
        # with an empty tuple, which typing.Union is expected to reject --
        # confirm callers register at least one model before calling this.
        return (
            ty.Annotated[
                ty.Union[  # noqa: UP007
                    tuple(
                        ty.Annotated[x, pydantic.Tag(v)] for v, x in self.models.items()
                    )
                ],
                pydantic.Field(discriminator=self.discriminator_field),
            ]
            if annotated
            else ty.Union[tuple(self.models.values())]  # noqa: UP007
        )

203 )