Coverage for src/dynapydantic/tracking_group.py: 100%
61 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-11 18:47 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-11 18:47 +0000
1"""Base class for dynamic pydantic models"""
3import typing as ty
5import pydantic
6import pydantic.fields
7import pydantic_core
9from .exceptions import AmbiguousDiscriminatorValueError, RegistrationError
def _inject_discriminator_field(
    cls: type[pydantic.BaseModel],
    disc_field: str,
    value: str,
) -> pydantic.fields.FieldInfo:
    """Add a frozen ``Literal[value]`` discriminator field to a model

    The new field is written directly into ``cls.model_fields`` and the
    model is then force-rebuilt so its validation schema picks it up.

    Parameters
    ----------
    cls
        The BaseModel subclass to modify in place
    disc_field
        Name of the discriminator field
    value
        Value of the discriminator field; used both as the field default
        and as the only value permitted by its ``Literal`` annotation

    Returns
    -------
    pydantic.fields.FieldInfo
        The field info stored on ``cls`` under ``disc_field`` after the
        rebuild
    """
    info = pydantic.fields.FieldInfo(
        default=value,
        annotation=ty.Literal[value],
        frozen=True,
    )
    cls.model_fields[disc_field] = info
    # Force a rebuild so the schema reflects the field added above.
    cls.model_rebuild(force=True)
    # Re-read from the class rather than returning `info`, matching what the
    # model holds after the rebuild.
    return cls.model_fields[disc_field]
class TrackingGroup(pydantic.BaseModel):
    """Tracker for pydantic models

    Models are added with the ``register`` decorator or ``register_model``
    and are stored in ``models`` keyed by the default value of their
    discriminator field. ``union`` builds a union over everything that has
    been registered.
    """

    name: str = pydantic.Field(
        description=(
            "Name of the tracking group. This is for human display, so it "
            "doesn't technically need to be globally unique, but it should be "
            "meaningfully named, as it will be used in error messages."
        ),
    )
    discriminator_field: str = pydantic.Field(
        description="Name of the discriminator field",
    )
    plugin_entry_point: str | None = pydantic.Field(
        None,
        description=(
            "If given, then plugins packages will be supported through this "
            "Python entrypoint. The entrypoint can either be a function, "
            "which will be called, or simply a module, which will be "
            "imported. In either case, models found along the import path of "
            "the entrypoint will be registered. If the entrypoint is a "
            "function, additional models may be declared in the function."
        ),
    )
    discriminator_value_generator: ty.Callable[[type], str] | None = pydantic.Field(
        None,
        description=(
            "A callable that produces default values for the discriminator field"
        ),
    )
    models: dict[str, type[pydantic.BaseModel]] = pydantic.Field(
        {},
        description="The tracked models",
    )

    def load_plugins(self) -> None:
        """Load plugins to discover/register additional models

        Every entry point in the ``plugin_entry_point`` group is loaded. If
        the loaded object is callable, it is then called. Does nothing when
        ``plugin_entry_point`` is None.
        """
        if self.plugin_entry_point is None:
            return

        from importlib.metadata import entry_points  # noqa: PLC0415

        # Select by group directly (Python 3.10+) instead of materialising
        # every installed entry point and filtering afterwards.
        for ep in entry_points(group=self.plugin_entry_point):
            plugin = ep.load()
            if callable(plugin):
                plugin()

    def register(
        self,
        discriminator_value: str | None = None,
    ) -> ty.Callable[[type], type]:
        """Register a model into this group (decorator)

        Parameters
        ----------
        discriminator_value
            Value for the discriminator field. If not given, then
            discriminator_value_generator must be non-None or the
            discriminator field must be declared by hand.

        Returns
        -------
        Callable
            A class decorator that registers the class and returns that same
            class object.

        Raises
        ------
        RegistrationError
            When the decorator is applied and no discriminator value can be
            determined, or the value is missing or already taken.
        AmbiguousDiscriminatorValueError
            When the decorator is applied and ``discriminator_value`` differs
            from the default of an explicitly declared discriminator field.
        """

        def _wrapper(cls: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]:
            disc = self.discriminator_field
            field = cls.model_fields.get(disc)
            if field is None:
                # No hand-declared field: inject one from the explicit value,
                # falling back to the group's generator.
                if discriminator_value is not None:
                    _inject_discriminator_field(cls, disc, discriminator_value)
                elif self.discriminator_value_generator is not None:
                    _inject_discriminator_field(
                        cls,
                        disc,
                        self.discriminator_value_generator(cls),
                    )
                else:
                    msg = (
                        f"unable to determine a discriminator value for "
                        f'{cls.__name__} in tracking group "{self.name}". No '
                        "value was passed to register(), "
                        "discriminator_value_generator was None and the "
                        f'"{disc}" field was not defined.'
                    )
                    raise RegistrationError(msg)
            elif (
                discriminator_value is not None and field.default != discriminator_value
            ):
                # The value was given both ways and they disagree.
                msg = (
                    f"the discriminator value for {cls.__name__} was "
                    f'ambiguous, it was set to "{discriminator_value}" via '
                    f'register() and "{field.default}" via the discriminator '
                    f"field ({disc})."
                )
                raise AmbiguousDiscriminatorValueError(msg)

            self._register_with_discriminator_field(cls)
            return cls

        return _wrapper

    def register_model(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the given model into this group

        If the model lacks the discriminator field, one is injected using
        ``discriminator_value_generator``.

        Parameters
        ----------
        cls
            The model to register

        Raises
        ------
        RegistrationError
            If the field is missing and there is no generator, or the
            discriminator value is missing or already taken.
        """
        disc = self.discriminator_field
        if cls.model_fields.get(disc) is None:
            if self.discriminator_value_generator is not None:
                _inject_discriminator_field(
                    cls,
                    disc,
                    self.discriminator_value_generator(cls),
                )
            else:
                msg = (
                    f"unable to determine a discriminator value for "
                    f'{cls.__name__} in tracking group "{self.name}", '
                    "discriminator_value_generator was None and the "
                    f'"{disc}" field was not defined.'
                )
                raise RegistrationError(msg)

        self._register_with_discriminator_field(cls)

    def _register_with_discriminator_field(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the model with the default of the discriminator field

        Parameters
        ----------
        cls
            The class to register, must have the discriminator field set with
            a unique default value in the group.

        Raises
        ------
        RegistrationError
            If the discriminator field is missing or has no default, or its
            default is already registered to a different class.
        """
        disc = self.discriminator_field
        field = cls.model_fields.get(disc)
        # A missing field is reported the same way as a field with no default,
        # rather than failing with an AttributeError on None.
        value = pydantic_core.PydanticUndefined if field is None else field.default
        # PydanticUndefined is a singleton sentinel: compare by identity so an
        # arbitrary default's custom __eq__ is never invoked.
        if value is pydantic_core.PydanticUndefined:
            msg = (
                f"{cls.__name__}.{disc} had no default value, it must "
                "have one which is unique among all tracked models."
            )
            raise RegistrationError(msg)

        # Re-registering the very same class under its value is allowed.
        if (other := self.models.get(value)) is not None and other is not cls:
            msg = (
                f'Cannot register {cls.__name__} under the "{value}" '
                f"identifier, which is already in use by {other.__name__}."
            )
            raise RegistrationError(msg)

        self.models[value] = cls

    def union(self, *, annotated: bool = True) -> ty.Any:
        """Return the union of all registered models

        Parameters
        ----------
        annotated
            If True (the default), each member is wrapped in
            ``Annotated[model, pydantic.Tag(value)]``. The union is annotated
            with ``pydantic.Field(discriminator=...)`` using this group's
            discriminator field. If False, a plain ``Union`` of the
            registered model classes is returned.

        Returns
        -------
        Any
            A typing construct suitable for use as an annotation. This is an
            ``Annotated`` alias, a ``Union``, or the bare class when exactly
            one model is registered and ``annotated`` is False.

        Raises
        ------
        TypeError
            If no models have been registered (a ``Union`` needs at least
            one member).
        """
        if not self.models:
            msg = (
                f'tracking group "{self.name}" has no registered models, '
                "so no union can be formed."
            )
            raise TypeError(msg)

        if not annotated:
            return ty.Union[tuple(self.models.values())]  # noqa: UP007

        tagged = tuple(
            ty.Annotated[model, pydantic.Tag(value)]
            for value, model in self.models.items()
        )
        return ty.Annotated[
            ty.Union[tagged],  # noqa: UP007
            pydantic.Field(discriminator=self.discriminator_field),
        ]