kiln_ai.datamodel.dataset_split
Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
1""" 2Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split. 3""" 4 5import math 6import random 7from enum import Enum 8from typing import TYPE_CHECKING, Callable 9 10from pydantic import BaseModel, Field, model_validator 11 12from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel 13from kiln_ai.datamodel.task_run import TaskRun 14 15if TYPE_CHECKING: 16 from kiln_ai.datamodel.task import Task 17 18 19# Define the type alias for clarity 20""" 21A function that takes a TaskRun and returns a boolean indicating whether the task run should be included in the split. 22 23Several filters are defined below like AllDatasetFilter, HighRatingDatasetFilter, etc. 24""" 25DatasetFilter = Callable[[TaskRun], bool] 26 27 28def AllDatasetFilter(_: TaskRun) -> bool: 29 return True 30 31 32def HighRatingDatasetFilter(task_run: TaskRun) -> bool: 33 if task_run.output is None: 34 return False 35 if task_run.repaired_output is not None: 36 # Repairs always considered high quality 37 return True 38 if task_run.output.rating is None: 39 return False 40 return task_run.output.rating.is_high_quality() 41 42 43def ThinkingModelDatasetFilter(task_run: TaskRun) -> bool: 44 """ 45 A filter that returns True if the task has intermediate outputs we can training a 'thinking' model on (reasoning or chain of thought) 46 """ 47 return task_run.has_thinking_training_data() 48 49 50def ThinkingModelHighRatedFilter(task_run: TaskRun) -> bool: 51 """ 52 A filter that returns True if the task has thinking data and the output is high quality 53 """ 54 return ThinkingModelDatasetFilter(task_run) and HighRatingDatasetFilter(task_run) 55 56 57class DatasetFilterType(str, Enum): 58 """Dataset filter names.""" 59 60 ALL = "all" 61 HIGH_RATING = "high_rating" 62 THINKING_MODEL = "thinking_model" 63 THINKING_MODEL_HIGH_RATED = "thinking_model_high_rated" 64 65 66dataset_filters = { 67 DatasetFilterType.ALL: AllDatasetFilter, 68 DatasetFilterType.HIGH_RATING: HighRatingDatasetFilter, 69 DatasetFilterType.THINKING_MODEL: ThinkingModelDatasetFilter, 70 DatasetFilterType.THINKING_MODEL_HIGH_RATED: ThinkingModelHighRatedFilter, 71} 72 73 74class DatasetSplitDefinition(BaseModel): 75 """ 76 A definition of a split in a dataset. 77 78 Example: name="train", description="The training set", percentage=0.8 (80% of the dataset) 79 """ 80 81 name: str = NAME_FIELD 82 description: str | None = Field( 83 default=None, 84 description="A description of the dataset for you and your team. 
Not used in training.", 85 ) 86 percentage: float = Field( 87 ge=0.0, 88 le=1.0, 89 description="The percentage of the dataset that this split represents (between 0 and 1).", 90 ) 91 92 93AllSplitDefinition: list[DatasetSplitDefinition] = [ 94 DatasetSplitDefinition(name="all", percentage=1.0) 95] 96Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [ 97 DatasetSplitDefinition(name="train", percentage=0.8), 98 DatasetSplitDefinition(name="test", percentage=0.2), 99] 100Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [ 101 DatasetSplitDefinition(name="train", percentage=0.6), 102 DatasetSplitDefinition(name="test", percentage=0.2), 103 DatasetSplitDefinition(name="val", percentage=0.2), 104] 105Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [ 106 DatasetSplitDefinition(name="train", percentage=0.8), 107 DatasetSplitDefinition(name="test", percentage=0.1), 108 DatasetSplitDefinition(name="val", percentage=0.1), 109] 110 111 112class DatasetSplit(KilnParentedModel): 113 """ 114 A collection of task runs, with optional splits (train, test, validation). 115 116 Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks. 117 118 Maintains a list of IDs for each split, to avoid data duplication. 119 """ 120 121 name: str = NAME_FIELD 122 description: str | None = Field( 123 default=None, 124 description="A description of the dataset for you and your team. Not used in training.", 125 ) 126 splits: list[DatasetSplitDefinition] = Field( 127 default_factory=list, 128 description="The splits in the dataset.", 129 ) 130 split_contents: dict[str, list[str]] = Field( 131 description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.", 132 ) 133 filter: DatasetFilterType | None = Field( 134 default=None, 135 description="The filter used to build the dataset.", 136 ) 137 138 @model_validator(mode="after") 139 def validate_split_percentages(self) -> "DatasetSplit": 140 total = sum(split.percentage for split in self.splits) 141 if not math.isclose(total, 1.0, rel_tol=1e-9): 142 raise ValueError(f"The sum of split percentages must be 1.0 (got {total})") 143 return self 144 145 @classmethod 146 def from_task( 147 cls, 148 name: str, 149 task: "Task", 150 splits: list[DatasetSplitDefinition], 151 filter_type: DatasetFilterType = DatasetFilterType.ALL, 152 description: str | None = None, 153 ): 154 """ 155 Build a dataset split from a task. 
156 """ 157 filter = dataset_filters[filter_type] 158 split_contents = cls.build_split_contents(task, splits, filter) 159 return cls( 160 parent=task, 161 name=name, 162 description=description, 163 splits=splits, 164 split_contents=split_contents, 165 filter=filter_type, 166 ) 167 168 @classmethod 169 def build_split_contents( 170 cls, 171 task: "Task", 172 splits: list[DatasetSplitDefinition], 173 filter: DatasetFilter, 174 ) -> dict[str, list[str]]: 175 valid_ids = [] 176 for task_run in task.runs(): 177 if filter(task_run): 178 valid_ids.append(task_run.id) 179 180 # Shuffle and split by split percentage 181 random.shuffle(valid_ids) 182 split_contents = {} 183 start_idx = 0 184 remaining_items = len(valid_ids) 185 186 # Handle all splits except the last one 187 for split in splits[:-1]: 188 split_size = round(len(valid_ids) * split.percentage) 189 split_contents[split.name] = valid_ids[start_idx : start_idx + split_size] 190 start_idx += split_size 191 remaining_items -= split_size 192 193 # Last split gets all remaining items (for rounding) 194 if splits: 195 split_contents[splits[-1].name] = valid_ids[start_idx:] 196 197 return split_contents 198 199 def parent_task(self) -> "Task | None": 200 # inline import to avoid circular import 201 from kiln_ai.datamodel import Task 202 203 if not isinstance(self.parent, Task): 204 return None 205 return self.parent 206 207 def missing_count(self) -> int: 208 """ 209 Returns: 210 int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset 211 """ 212 parent = self.parent_task() 213 if parent is None: 214 raise ValueError("DatasetSplit has no parent task") 215 216 runs = parent.runs(readonly=True) 217 all_ids = set(run.id for run in runs) 218 all_ids_in_splits = set() 219 for ids in self.split_contents.values(): 220 all_ids_in_splits.update(ids) 221 missing = all_ids_in_splits - all_ids 222 return len(missing)
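`DatasetFilter` is just a callable from `TaskRun` to `bool`, so `build_split_contents` also accepts custom filters beyond the named ones in `dataset_filters`. A sketch under that assumption (the `RepairedOnlyFilter` below is illustrative, not part of the module):

```python
from kiln_ai.datamodel.dataset_split import AllSplitDefinition, DatasetSplit
from kiln_ai.datamodel.task_run import TaskRun


def RepairedOnlyFilter(task_run: TaskRun) -> bool:
    # Hypothetical filter: keep only runs whose output was repaired.
    return task_run.repaired_output is not None


# Assumption: `task` as in the previous sketch. Note that from_task() only
# accepts a DatasetFilterType, so a custom callable must go through
# build_split_contents directly.
contents = DatasetSplit.build_split_contents(task, AllSplitDefinition, RepairedOnlyFilter)
```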
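To make the rounding behavior of `build_split_contents` concrete: every split except the last takes `round(n * percentage)` of the shuffled IDs, and the last split absorbs the remainder, so the counts always sum to `n`. A self-contained sketch of that arithmetic (the `split_sizes` helper is hypothetical, not part of the module):

```python
def split_sizes(n: int, percentages: list[float]) -> list[int]:
    # Mirrors build_split_contents: all but the last split are rounded;
    # the last split takes whatever remains, so no IDs are dropped.
    sizes = [round(n * p) for p in percentages[:-1]]
    sizes.append(n - sum(sizes))
    return sizes


print(split_sizes(10, [0.6, 0.2, 0.2]))  # [6, 2, 2]
print(split_sizes(7, [0.8, 0.1, 0.1]))   # [6, 1, 0] -- small splits can round to empty
```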
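Because a frozen split stores only run IDs, deleting runs from the task afterwards leaves dangling IDs behind; `missing_count()` reports how many. A hypothetical pre-training guard:

```python
# Assumption: `split` was built earlier with DatasetSplit.from_task(...).
if split.missing_count() > 0:
    raise ValueError(
        f"{split.missing_count()} task runs were deleted after this split was frozen"
    )
```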