Coverage for src/dynapydantic/subclass_tracking_model.py: 100%

37 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-11 18:47 +0000

1"""Base class for dynamic pydantic models""" 

2 

3import typing as ty 

4 

5import pydantic 

6 

7from .exceptions import ConfigurationError 

8from .tracking_group import TrackingGroup 

9 

10 

def direct_children_of_base_in_mro(derived: type, base: type) -> list[type]:
    """Collect the classes in ``derived``'s MRO that inherit directly from ``base``.

    Parameters
    ----------
    derived
        Class whose method resolution order is scanned.
    base
        Class whose immediate subclasses are being looked for.

    Returns
    -------
    Every class in ``derived.__mro__`` that lists ``base`` among its direct
    bases, in MRO order. ``base`` itself is never part of the result.
    """

    def _inherits_directly(candidate: type) -> bool:
        # A class "directly" inherits from base when base appears in its own
        # __bases__ tuple (not merely somewhere further up its MRO).
        return candidate is not base and base in candidate.__bases__

    return list(filter(_inherits_directly, derived.__mro__))

26 

27 

class SubclassTrackingModel(pydantic.BaseModel):
    """Subclass-tracking BaseModel.

    A class that inherits *directly* from ``SubclassTrackingModel`` becomes a
    tracking root. Its ``TrackingGroup`` is stored on ``__DYNAPYDANTIC__``,
    and it is given the static methods ``union`` and
    ``registered_subclasses``. It also gets ``load_plugins`` when the group
    has a ``plugin_entry_point``.

    The root's tracking group comes from one of two places:

    - a ``tracking_config`` class attribute holding a ``TrackingGroup``
      instance, or
    - keyword arguments in the root's class declaration, which are validated
      into a ``TrackingGroup``.

    Every class further down the hierarchy is registered with the tracking
    group of each root found in its MRO, unless it is declared with a truthy
    ``exclude_from_union``.
    """

    def __init_subclass__(
        cls,
        *args,
        exclude_from_union: bool | None = None,
        **kwargs,
    ) -> None:
        """Standard Python subclass hook, run while the class is being created.

        Consumes the class-declaration keywords this library understands:

        - ``exclude_from_union`` is captured by the signature and discarded.
        - Any keyword whose name is a ``TrackingGroup`` field is filtered out.

        The remaining arguments are forwarded to the next
        ``__pydantic_init_subclass__`` in the MRO. Root setup and subclass
        registration happen in ``__pydantic_init_subclass__``, not here.
        """
        # Intercept any kwargs that are intended for TrackingGroup
        # NOTE(review): this forwards to super().__pydantic_init_subclass__
        # rather than super().__init_subclass__. Consequences:
        # - __init_subclass__ hooks of classes after this one in the MRO
        #   (e.g. typing.Generic or user mixins) are never called.
        # - The next __pydantic_init_subclass__ runs here, in addition to the
        #   call pydantic's metaclass presumably makes once the class is
        #   fully built.
        # - Unrecognised class keywords are presumably accepted silently
        #   instead of raising TypeError.
        # Confirm this is intentional before changing it.
        super().__pydantic_init_subclass__(
            *args,
            **{k: v for k, v in kwargs.items() if k not in TrackingGroup.model_fields},
        )

    @classmethod
    def __pydantic_init_subclass__(
        cls,
        *args,
        exclude_from_union: bool | None = None,
        **kwargs,
    ) -> None:
        """Pydantic subclass hook: set up tracking roots and register subclasses.

        Parameters
        ----------
        exclude_from_union
            If truthy on a non-root subclass, the subclass is not registered
            with any tracking group. It is not read for root classes.
        **kwargs
            For root classes (direct subclasses of ``SubclassTrackingModel``)
            without a ``TrackingGroup`` ``tracking_config`` attribute, these
            keywords are merged over a default
            ``{"name": "<ClassName>-subclasses"}`` and validated into the
            root's ``TrackingGroup``. Forwarding to the parent
            ``__pydantic_init_subclass__`` differs by class kind:

            - root classes forward only the keywords that are not
              ``TrackingGroup`` fields;
            - all other subclasses forward every keyword.

        Raises
        ------
        ConfigurationError
            If a root class has no ``TrackingGroup`` ``tracking_config`` and
            its keywords do not validate into a ``TrackingGroup``.
        """
        if SubclassTrackingModel in cls.__bases__:
            # Root class: it owns a TrackingGroup instead of being tracked.
            # Intercept any kwargs that are intended for TrackingGroup
            super().__pydantic_init_subclass__(
                *args,
                **{
                    k: v
                    for k, v in kwargs.items()
                    if k not in TrackingGroup.model_fields
                },
            )

            # An explicit tracking_config takes precedence. When it is
            # present, any TrackingGroup keywords in the class declaration
            # go unused.
            if isinstance(getattr(cls, "tracking_config", None), TrackingGroup):
                cls.__DYNAPYDANTIC__ = cls.tracking_config
            else:
                try:
                    # Keywords from the class declaration override the default
                    # name because the right-hand side of `|` wins.
                    # NOTE(review): kwargs is unfiltered here, so keywords meant
                    # for other hooks also reach model_validate. Whether they
                    # are rejected or ignored depends on TrackingGroup's model
                    # config; verify.
                    cls.__DYNAPYDANTIC__: TrackingGroup = TrackingGroup.model_validate(
                        {"name": f"{cls.__name__}-subclasses"} | kwargs,
                    )
                except pydantic.ValidationError as e:
                    # Re-raise as the library's own error type, embedding the
                    # pydantic validation details in the message.
                    msg = (
                        "SubclassTrackingModel subclasses must either have a "
                        "tracking_config: ClassVar[dynapydantic.TrackingGroup] "
                        "member or pass kwargs sufficient to construct a "
                        "dynapydantic.TrackingGroup in the class declaration. "
                        "The latter approach produced the following "
                        f"ValidationError:\n{e}"
                    )
                    raise ConfigurationError(msg) from e

            # Promote the tracking group's methods to the parent class.
            # The closures below look up cls.__DYNAPYDANTIC__ on every call,
            # so they always delegate to the group currently stored on this
            # root class.
            if cls.__DYNAPYDANTIC__.plugin_entry_point is not None:

                def _load_plugins() -> None:
                    """Load plugins to register more models"""
                    cls.__DYNAPYDANTIC__.load_plugins()

                cls.load_plugins = staticmethod(_load_plugins)

            def _union(*, annotated: bool = True) -> ty.GenericAlias:
                """Get the union of all tracked subclasses

                Parameters
                ----------
                annotated
                    Whether this should be an annotated union for usage as a
                    pydantic field annotation, or a plain typing.Union for a
                    regular type annotation.
                """
                return cls.__DYNAPYDANTIC__.union(annotated=annotated)

            cls.union = staticmethod(_union)

            def _subclasses() -> dict[str, type[cls]]:
                """Return a mapping of discriminator values to registered models"""
                return cls.__DYNAPYDANTIC__.models

            cls.registered_subclasses = staticmethod(_subclasses)

            # A root class is not registered with any tracking group.
            return

        # Non-root subclass: every keyword is forwarded to the parent hook.
        super().__pydantic_init_subclass__(*args, **kwargs)

        if exclude_from_union:
            return

        # Register with every root this class descends from. Each direct
        # subclass of SubclassTrackingModel in the MRO went through the root
        # branch above, so it has a __DYNAPYDANTIC__ group. Multiple
        # inheritance can yield several such roots.
        supers = direct_children_of_base_in_mro(cls, SubclassTrackingModel)
        for base in supers:
            base.__DYNAPYDANTIC__.register_model(cls)