Coverage for src/dynapydantic/subclass_tracking_model.py: 100%
37 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-11 18:47 +0000
1"""Base class for dynamic pydantic models"""
3import typing as ty
5import pydantic
7from .exceptions import ConfigurationError
8from .tracking_group import TrackingGroup
def direct_children_of_base_in_mro(derived: type, base: type) -> list[type]:
    """Collect the immediate subclasses of ``base`` found in ``derived``'s MRO.

    Parameters
    ----------
    derived
        Class whose method resolution order is scanned; ``derived`` itself
        is part of that order and is therefore a candidate too.
    base
        Class whose direct children (classes listing it in ``__bases__``)
        are sought.

    Returns
    -------
    The matching classes, in MRO order. ``base`` itself is never returned.
    """
    matches: list[type] = []
    for candidate in derived.__mro__:
        if candidate is base:
            continue
        if base in candidate.__bases__:
            matches.append(candidate)
    return matches
def _non_tracking_kwargs(kwargs: dict[str, ty.Any]) -> dict[str, ty.Any]:
    """Return a copy of ``kwargs`` without keys that are TrackingGroup fields.

    Class-declaration kwargs that configure a dynapydantic.TrackingGroup are
    consumed by SubclassTrackingModel, so they are stripped before the
    remaining kwargs are forwarded to other subclass hooks up the MRO.
    """
    return {
        key: value
        for key, value in kwargs.items()
        if key not in TrackingGroup.model_fields
    }


class SubclassTrackingModel(pydantic.BaseModel):
    """Subclass-tracking BaseModel

    A class that inherits *directly* from SubclassTrackingModel becomes a
    tracking root. Its dynapydantic.TrackingGroup is either taken from a
    ``tracking_config: ClassVar[dynapydantic.TrackingGroup]`` attribute or
    validated from the keyword arguments of the class declaration (with
    ``name`` defaulting to ``"<ClassName>-subclasses"``). The root gains the
    static methods ``union`` and ``registered_subclasses``, plus
    ``load_plugins`` when the group has a ``plugin_entry_point``.

    Every other (indirect) subclass is registered with the tracking root(s)
    in its MRO unless it is declared with a truthy ``exclude_from_union``.
    """

    def __init_subclass__(
        cls,
        *args,
        exclude_from_union: bool | None = None,
        **kwargs,
    ) -> None:
        """Subclass hook

        Runs while ``type.__new__`` creates the class, i.e. before pydantic
        has finished building the model. ``exclude_from_union`` and the
        TrackingGroup kwargs are only meaningful to
        ``__pydantic_init_subclass__``, so they are dropped here and the
        remaining kwargs are passed on to the next ``__init_subclass__`` in
        the MRO.
        """
        # exclude_from_union is declared only so that it is not forwarded;
        # it is acted upon in __pydantic_init_subclass__.
        # Continue the regular __init_subclass__ chain. The pydantic hook must
        # not be called from here: pydantic calls it itself once the model is
        # complete, and calling it here would skip cooperative
        # __init_subclass__ implementations further up the MRO.
        super().__init_subclass__(*args, **_non_tracking_kwargs(kwargs))

    @classmethod
    def __pydantic_init_subclass__(
        cls,
        *args,
        exclude_from_union: bool | None = None,
        **kwargs,
    ) -> None:
        """Pydantic subclass hook

        For a direct subclass of SubclassTrackingModel (a tracking root),
        attach its TrackingGroup as ``__DYNAPYDANTIC__`` and promote the
        group's methods onto the class. The root itself is not registered.
        For any other subclass, register it with every tracking root in its
        MRO unless ``exclude_from_union`` is truthy.

        Raises
        ------
        ConfigurationError
            If a tracking root has no ``tracking_config`` TrackingGroup and
            the declaration kwargs do not validate as a TrackingGroup.
        """
        # TrackingGroup kwargs are consumed here; forward only the rest, the
        # same way __init_subclass__ does.
        super().__pydantic_init_subclass__(*args, **_non_tracking_kwargs(kwargs))

        if SubclassTrackingModel in cls.__bases__:
            config = getattr(cls, "tracking_config", None)
            if isinstance(config, TrackingGroup):
                cls.__DYNAPYDANTIC__ = config
            else:
                try:
                    cls.__DYNAPYDANTIC__ = TrackingGroup.model_validate(
                        {"name": f"{cls.__name__}-subclasses"} | kwargs,
                    )
                except pydantic.ValidationError as e:
                    msg = (
                        "SubclassTrackingModel subclasses must either have a "
                        "tracking_config: ClassVar[dynapydantic.TrackingGroup] "
                        "member or pass kwargs sufficient to construct a "
                        "dynapydantic.TrackingGroup in the class declaration. "
                        "The latter approach produced the following "
                        f"ValidationError:\n{e}"
                    )
                    raise ConfigurationError(msg) from e

            # Promote the tracking group's methods to the root class. The
            # closures capture the root ``cls`` and are stored as
            # staticmethods, so subclasses inherit them unchanged and calling
            # them through a subclass still uses the root's group.
            if cls.__DYNAPYDANTIC__.plugin_entry_point is not None:

                def _load_plugins() -> None:
                    """Load plugins to register more models"""
                    cls.__DYNAPYDANTIC__.load_plugins()

                cls.load_plugins = staticmethod(_load_plugins)

            # Annotated as Any: a typing Union/Annotated special form is not a
            # types.GenericAlias, which the previous annotation claimed.
            def _union(*, annotated: bool = True) -> ty.Any:
                """Get the union of all tracked subclasses

                Parameters
                ----------
                annotated
                    Whether this should be an annotated union for usage as a
                    pydantic field annotation, or a plain typing.Union for a
                    regular type annotation.

                Returns
                -------
                Whatever the root's ``TrackingGroup.union`` returns for
                ``annotated``.
                """
                return cls.__DYNAPYDANTIC__.union(annotated=annotated)

            cls.union = staticmethod(_union)

            def _subclasses() -> dict[str, type[cls]]:
                """Return a mapping of discriminator values to registered model"""
                return cls.__DYNAPYDANTIC__.models

            cls.registered_subclasses = staticmethod(_subclasses)
            # A tracking root is never registered in its own union.
            return

        if exclude_from_union:
            return

        # Register with every tracking root (direct child of
        # SubclassTrackingModel) present in this class's MRO.
        for root in direct_children_of_base_in_mro(cls, SubclassTrackingModel):
            root.__DYNAPYDANTIC__.register_model(cls)