kiln_ai.datamodel.dataset_split

Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
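
A minimal usage sketch (the task variable is illustrative, assumed to be an existing kiln_ai Task with saved runs):

    from kiln_ai.datamodel.dataset_split import (
        DatasetFilterType,
        DatasetSplit,
        Train80Test20SplitDefinition,
    )

    # Freeze the task's current runs into a repeatable 80/20 train/test split,
    # keeping only runs whose output is rated high quality.
    split = DatasetSplit.from_task(
        name="finetune split",
        task=task,
        splits=Train80Test20SplitDefinition,
        filter_type=DatasetFilterType.HIGH_RATING,
    )
    train_ids = split.split_contents["train"]  # list of TaskRun IDs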

  1"""
  2Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
  3"""
  4
  5import math
  6import random
  7from enum import Enum
  8from typing import TYPE_CHECKING, Callable
  9
 10from pydantic import BaseModel, Field, model_validator
 11
 12from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel
 13from kiln_ai.datamodel.task_run import TaskRun
 14
 15if TYPE_CHECKING:
 16    from kiln_ai.datamodel.task import Task
 17
 18
 19# Define the type alias for clarity
 20"""
 21A function that takes a TaskRun and returns a boolean indicating whether the task run should be included in the split.
 22
 23Several filters are defined below, such as AllDatasetFilter and HighRatingDatasetFilter.
 24"""
 25DatasetFilter = Callable[[TaskRun], bool]
 26
 27
 28def AllDatasetFilter(_: TaskRun) -> bool:
 29    return True
 30
 31
 32def HighRatingDatasetFilter(task_run: TaskRun) -> bool:
 33    if task_run.output is None:
 34        return False
 35    if task_run.repaired_output is not None:
 36        # Repairs always considered high quality
 37        return True
 38    if task_run.output.rating is None:
 39        return False
 40    return task_run.output.rating.is_high_quality()
 41
 42
 43def ThinkingModelDatasetFilter(task_run: TaskRun) -> bool:
 44    """
 45    A filter that returns True if the task run has intermediate outputs we can train a 'thinking' model on (reasoning or chain of thought)
 46    """
 47    return task_run.has_thinking_training_data()
 48
 49
 50def ThinkingModelHighRatedFilter(task_run: TaskRun) -> bool:
 51    """
 52    A filter that returns True if the task run has thinking data and the output is high quality
 53    """
 54    return ThinkingModelDatasetFilter(task_run) and HighRatingDatasetFilter(task_run)
 55
 56
 57class DatasetFilterType(str, Enum):
 58    """Dataset filter names."""
 59
 60    ALL = "all"
 61    HIGH_RATING = "high_rating"
 62    THINKING_MODEL = "thinking_model"
 63    THINKING_MODEL_HIGH_RATED = "thinking_model_high_rated"
 64
 65
 66dataset_filters = {
 67    DatasetFilterType.ALL: AllDatasetFilter,
 68    DatasetFilterType.HIGH_RATING: HighRatingDatasetFilter,
 69    DatasetFilterType.THINKING_MODEL: ThinkingModelDatasetFilter,
 70    DatasetFilterType.THINKING_MODEL_HIGH_RATED: ThinkingModelHighRatedFilter,
 71}
 72
 73
 74class DatasetSplitDefinition(BaseModel):
 75    """
 76    A definition of a split in a dataset.
 77
 78    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
 79    """
 80
 81    name: str = NAME_FIELD
 82    description: str | None = Field(
 83        default=None,
 84        description="A description of the dataset split for you and your team. Not used in training.",
 85    )
 86    percentage: float = Field(
 87        ge=0.0,
 88        le=1.0,
 89        description="The percentage of the dataset that this split represents (between 0 and 1).",
 90    )
 91
 92
 93AllSplitDefinition: list[DatasetSplitDefinition] = [
 94    DatasetSplitDefinition(name="all", percentage=1.0)
 95]
 96Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
 97    DatasetSplitDefinition(name="train", percentage=0.8),
 98    DatasetSplitDefinition(name="test", percentage=0.2),
 99]
100Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
101    DatasetSplitDefinition(name="train", percentage=0.6),
102    DatasetSplitDefinition(name="test", percentage=0.2),
103    DatasetSplitDefinition(name="val", percentage=0.2),
104]
105Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [
106    DatasetSplitDefinition(name="train", percentage=0.8),
107    DatasetSplitDefinition(name="test", percentage=0.1),
108    DatasetSplitDefinition(name="val", percentage=0.1),
109]
110
111
112class DatasetSplit(KilnParentedModel):
113    """
114    A collection of task runs, with optional splits (train, test, validation).
115
116    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
117
118    Maintains a list of IDs for each split, to avoid data duplication.
119    """
120
121    name: str = NAME_FIELD
122    description: str | None = Field(
123        default=None,
124        description="A description of the dataset for you and your team. Not used in training.",
125    )
126    splits: list[DatasetSplitDefinition] = Field(
127        default_factory=list,
128        description="The splits in the dataset.",
129    )
130    split_contents: dict[str, list[str]] = Field(
131        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
132    )
133    filter: DatasetFilterType | None = Field(
134        default=None,
135        description="The filter used to build the dataset.",
136    )
137
138    @model_validator(mode="after")
139    def validate_split_percentages(self) -> "DatasetSplit":
140        total = sum(split.percentage for split in self.splits)
141        if not math.isclose(total, 1.0, rel_tol=1e-9):
142            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
143        return self
144
145    @classmethod
146    def from_task(
147        cls,
148        name: str,
149        task: "Task",
150        splits: list[DatasetSplitDefinition],
151        filter_type: DatasetFilterType = DatasetFilterType.ALL,
152        description: str | None = None,
153    ):
154        """
155        Build a dataset split from a task.
156        """
157        filter = dataset_filters[filter_type]
158        split_contents = cls.build_split_contents(task, splits, filter)
159        return cls(
160            parent=task,
161            name=name,
162            description=description,
163            splits=splits,
164            split_contents=split_contents,
165            filter=filter_type,
166        )
167
168    @classmethod
169    def build_split_contents(
170        cls,
171        task: "Task",
172        splits: list[DatasetSplitDefinition],
173        filter: DatasetFilter,
174    ) -> dict[str, list[str]]:
175        valid_ids = []
176        for task_run in task.runs():
177            if filter(task_run):
178                valid_ids.append(task_run.id)
179
180        # Shuffle and split by split percentage
181        random.shuffle(valid_ids)
182        split_contents = {}
183        start_idx = 0
184        remaining_items = len(valid_ids)
185
186        # Handle all splits except the last one
187        for split in splits[:-1]:
188            split_size = round(len(valid_ids) * split.percentage)
189            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
190            start_idx += split_size
191            remaining_items -= split_size
192
193        # Last split gets all remaining items (for rounding)
194        if splits:
195            split_contents[splits[-1].name] = valid_ids[start_idx:]
196
197        return split_contents
198
199    def parent_task(self) -> "Task | None":
200        # inline import to avoid circular import
201        from kiln_ai.datamodel import Task
202
203        if not isinstance(self.parent, Task):
204            return None
205        return self.parent
206
207    def missing_count(self) -> int:
208        """
209        Returns:
210            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
211        """
212        parent = self.parent_task()
213        if parent is None:
214            raise ValueError("DatasetSplit has no parent task")
215
216        runs = parent.runs(readonly=True)
217        all_ids = set(run.id for run in runs)
218        all_ids_in_splits = set()
219        for ids in self.split_contents.values():
220            all_ids_in_splits.update(ids)
221        missing = all_ids_in_splits - all_ids
222        return len(missing)
DatasetFilter = typing.Callable[[kiln_ai.datamodel.TaskRun], bool]

A function that takes a TaskRun and returns a boolean indicating whether the task run should be included in the split. Several filters are defined in this module, such as AllDatasetFilter and HighRatingDatasetFilter.
def AllDatasetFilter(_: kiln_ai.datamodel.TaskRun) -> bool:

A filter that includes every task run.
def HighRatingDatasetFilter(task_run: kiln_ai.datamodel.TaskRun) -> bool:

A filter that returns True if the task run's output is rated high quality. Runs with a repaired output are always considered high quality.
def ThinkingModelDatasetFilter(task_run: kiln_ai.datamodel.TaskRun) -> bool:

A filter that returns True if the task run has intermediate outputs we can train a 'thinking' model on (reasoning or chain of thought).

def ThinkingModelHighRatedFilter(task_run: kiln_ai.datamodel.TaskRun) -> bool:

A filter that returns True if the task run has thinking data and the output is high quality.

class DatasetFilterType(builtins.str, enum.Enum):

Dataset filter names.

ALL = <DatasetFilterType.ALL: 'all'>
HIGH_RATING = <DatasetFilterType.HIGH_RATING: 'high_rating'>
THINKING_MODEL = <DatasetFilterType.THINKING_MODEL: 'thinking_model'>
THINKING_MODEL_HIGH_RATED = <DatasetFilterType.THINKING_MODEL_HIGH_RATED: 'thinking_model_high_rated'>
dataset_filters = {
    DatasetFilterType.ALL: AllDatasetFilter,
    DatasetFilterType.HIGH_RATING: HighRatingDatasetFilter,
    DatasetFilterType.THINKING_MODEL: ThinkingModelDatasetFilter,
    DatasetFilterType.THINKING_MODEL_HIGH_RATED: ThinkingModelHighRatedFilter,
}

Maps each DatasetFilterType to its corresponding filter function.
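
A small sketch of resolving and applying a filter (the runs list is illustrative):

    filter_fn = dataset_filters[DatasetFilterType.HIGH_RATING]
    high_quality_runs = [run for run in runs if filter_fn(run)]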
class DatasetSplitDefinition(pydantic.main.BaseModel):

A definition of a split in a dataset.

Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)

name: str
description: str | None
percentage: float
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

AllSplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='all', description=None, percentage=1.0)]
Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.8), DatasetSplitDefinition(name='test', description=None, percentage=0.2)]
Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.6), DatasetSplitDefinition(name='test', description=None, percentage=0.2), DatasetSplitDefinition(name='val', description=None, percentage=0.2)]
Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [DatasetSplitDefinition(name='train', description=None, percentage=0.8), DatasetSplitDefinition(name='test', description=None, percentage=0.1), DatasetSplitDefinition(name='val', description=None, percentage=0.1)]
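
Custom split definitions follow the same pattern; DatasetSplit validates that the percentages sum to 1.0. A sketch of a hypothetical 90/10 split:

    Train90Test10SplitDefinition = [
        DatasetSplitDefinition(name="train", percentage=0.9),
        DatasetSplitDefinition(name="test", percentage=0.1),
    ]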
class DatasetSplit(kiln_ai.datamodel.basemodel.KilnParentedModel):

A collection of task runs, with optional splits (train, test, validation).

Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.

Maintains a list of IDs for each split, to avoid data duplication.

name: str
description: str | None
splits: list[DatasetSplitDefinition]
split_contents: dict[str, list[str]]
filter: DatasetFilterType | None
@model_validator(mode='after')
def validate_split_percentages(self) -> DatasetSplit:

Validates that the split percentages sum to 1.0; raises a ValueError otherwise.
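
For example, percentages that do not sum to 1.0 fail at construction time (a sketch; pydantic surfaces the ValueError as a ValidationError):

    DatasetSplit(
        name="bad split",
        splits=[DatasetSplitDefinition(name="train", percentage=0.5)],
        split_contents={"train": []},
    )
    # raises: The sum of split percentages must be 1.0 (got 0.5)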
@classmethod
def from_task( cls, name: str, task: kiln_ai.datamodel.Task, splits: list[DatasetSplitDefinition], filter_type: DatasetFilterType = <DatasetFilterType.ALL: 'all'>, description: str | None = None):

Build a dataset split from a task: applies the selected filter to the task's runs, shuffles the surviving run IDs, and partitions them according to the split definitions.
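
A sketch of a three-way split restricted to runs with thinking data (the task variable is illustrative):

    split = DatasetSplit.from_task(
        name="reasoning split",
        task=task,
        splits=Train60Test20Val20SplitDefinition,
        filter_type=DatasetFilterType.THINKING_MODEL,
        description="60/20/20 split over runs with reasoning traces",
    )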

@classmethod
def build_split_contents( cls, task: kiln_ai.datamodel.Task, splits: list[DatasetSplitDefinition], filter: Callable[[kiln_ai.datamodel.TaskRun], bool]) -> dict[str, list[str]]:

Filters the task's runs, shuffles the IDs of the runs that pass, and assigns them to the splits by percentage. The last split receives all remaining IDs, absorbing any rounding error.
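
A worked example of the remainder behavior: with 25 valid runs and the 60/20/20 definition, the first two splits receive round(25 * 0.6) = 15 and round(25 * 0.2) = 5 IDs, and the final split gets the remaining 5.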
def parent_task(self) -> kiln_ai.datamodel.Task | None:

Returns the parent Task, or None if no parent is set or the parent is not a Task. (The Task import happens inline to avoid a circular import.)
def missing_count(self) -> int:

Returns the number of task run IDs persisted in this dataset split that no longer exist in the dataset. Raises a ValueError if the split has no parent task.
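
A usage sketch (assuming split was built earlier with from_task and persisted):

    if split.missing_count() > 0:
        print("Some runs referenced by this split no longer exist in the dataset")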

model_config = {'validate_assignment': True}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
