kiln_ai.adapters.fine_tune.base_finetune

from abc import ABC, abstractmethod
from typing import Literal

from pydantic import BaseModel

from kiln_ai.adapters.ml_model_list import built_in_models
from kiln_ai.datamodel import DatasetSplit, FinetuneDataStrategy, FineTuneStatusType
from kiln_ai.datamodel import Finetune as FinetuneModel
from kiln_ai.utils.name_generator import generate_memorable_name


class FineTuneStatus(BaseModel):
    """
    The status of a fine-tune, including a user friendly message.
    """

    status: FineTuneStatusType
    message: str | None = None


class FineTuneParameter(BaseModel):
    """
    A parameter for a fine-tune. Hyperparameters, etc.
    """

    name: str
    type: Literal["string", "int", "float", "bool"]
    description: str
    optional: bool = True


TYPE_MAP = {
    "string": str,
    "int": int,
    "float": float,
    "bool": bool,
}


class BaseFinetuneAdapter(ABC):
    """
    A base class for fine-tuning adapters.
    """

    def __init__(
        self,
        datamodel: FinetuneModel,
    ):
        self.datamodel = datamodel

    @classmethod
    async def create_and_start(
        cls,
        dataset: DatasetSplit,
        provider_id: str,
        provider_base_model_id: str,
        train_split_name: str,
        system_message: str,
        thinking_instructions: str | None,
        data_strategy: FinetuneDataStrategy,
        parameters: dict[str, str | int | float | bool] = {},
        name: str | None = None,
        description: str | None = None,
        validation_split_name: str | None = None,
    ) -> tuple["BaseFinetuneAdapter", FinetuneModel]:
        """
        Create and start a fine-tune.
        """

        cls.check_valid_provider_model(provider_id, provider_base_model_id)

        if not dataset.id:
            raise ValueError("Dataset must have an id")

        if train_split_name not in dataset.split_contents:
            raise ValueError(f"Train split {train_split_name} not found in dataset")

        if (
            validation_split_name
            and validation_split_name not in dataset.split_contents
        ):
            raise ValueError(
                f"Validation split {validation_split_name} not found in dataset"
            )

        # Default name if not provided
        if name is None:
            name = generate_memorable_name()

        cls.validate_parameters(parameters)
        parent_task = dataset.parent_task()
        if parent_task is None or not parent_task.path:
            raise ValueError("Dataset must have a parent task with a path")

        datamodel = FinetuneModel(
            name=name,
            description=description,
            provider=provider_id,
            base_model_id=provider_base_model_id,
            dataset_split_id=dataset.id,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
            parameters=parameters,
            system_message=system_message,
            thinking_instructions=thinking_instructions,
            parent=parent_task,
            data_strategy=data_strategy,
        )

        adapter = cls(datamodel)
        await adapter._start(dataset)

        datamodel.save_to_file()

        return adapter, datamodel

    @abstractmethod
    async def _start(self, dataset: DatasetSplit) -> None:
        """
        Start the fine-tune.
        """
        pass

    @abstractmethod
    async def status(self) -> FineTuneStatus:
        """
        Get the status of the fine-tune.
        """
        pass

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        """
        Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
        """
        return []

    @classmethod
    def validate_parameters(
        cls, parameters: dict[str, str | int | float | bool]
    ) -> None:
        """
        Validate the parameters for this fine-tune.
        """
        # Check required parameters and parameter types
        available_parameters = cls.available_parameters()
        for parameter in available_parameters:
            if not parameter.optional and parameter.name not in parameters:
                raise ValueError(f"Parameter {parameter.name} is required")
            elif parameter.name in parameters:
                # check parameter is correct type
                expected_type = TYPE_MAP[parameter.type]
                value = parameters[parameter.name]

                # Strict type checking for numeric types
                if expected_type is float and not isinstance(value, float):
                    raise ValueError(
                        f"Parameter {parameter.name} must be a float, got {type(value)}"
                    )
                elif expected_type is int and not isinstance(value, int):
                    raise ValueError(
                        f"Parameter {parameter.name} must be an integer, got {type(value)}"
                    )
                elif not isinstance(value, expected_type):
                    raise ValueError(
                        f"Parameter {parameter.name} must be type {expected_type}, got {type(value)}"
                    )

        allowed_parameters = [p.name for p in available_parameters]
        for parameter_key in parameters:
            if parameter_key not in allowed_parameters:
                raise ValueError(f"Parameter {parameter_key} is not available")

    @classmethod
    def check_valid_provider_model(
        cls, provider_id: str, provider_base_model_id: str
    ) -> None:
        """
        Check if the provider and base model are valid.
        """
        for model in built_in_models:
            for provider in model.providers:
                if (
                    provider.name == provider_id
                    and provider.provider_finetune_id == provider_base_model_id
                ):
                    return
        raise ValueError(
            f"Provider {provider_id} with base model {provider_base_model_id} is not available"
        )

class FineTuneStatus(pydantic.main.BaseModel):

The status of a fine-tune, including a user friendly message.

status: FineTuneStatusType
message: str | None
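
For illustration, a status object is just a small Pydantic model. A minimal sketch, assuming the kiln_ai package is installed and that FineTuneStatusType is a standard Enum (as the annotations above suggest):

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatus
from kiln_ai.datamodel import FineTuneStatusType

# Build a status for each possible state; message is optional and defaults to None.
for state in FineTuneStatusType:
    s = FineTuneStatus(status=state)
    print(s.status, s.message)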

class FineTuneParameter(pydantic.main.BaseModel):

A parameter for a fine-tune. Hyperparameters, etc.

name: str
type: Literal['string', 'int', 'float', 'bool']
description: str
optional: bool
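
Adapters describe their tunable options with FineTuneParameter instances, typically returned from available_parameters(). An illustrative sketch; the parameter names ("epochs", "learning_rate") are hypothetical, not options any specific provider requires:

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter

EXAMPLE_PARAMETERS = [
    FineTuneParameter(
        name="epochs",
        type="int",
        description="Number of training epochs",
    ),
    FineTuneParameter(
        name="learning_rate",
        type="float",
        description="Optimizer learning rate",
        optional=False,  # required: validate_parameters() will insist it is present
    ),
]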

TYPE_MAP = {'string': <class 'str'>, 'int': <class 'int'>, 'float': <class 'float'>, 'bool': <class 'bool'>}
class BaseFinetuneAdapter(abc.ABC):

A base class for fine-tuning adapters.

datamodel
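
Concrete providers subclass BaseFinetuneAdapter and implement the two abstract coroutines (_start and status), optionally declaring hyperparameters via available_parameters. A minimal sketch of that shape; ExampleFinetuneAdapter, its "epochs" parameter, and the FineTuneStatusType.completed member are assumptions for illustration, and a real adapter would call its provider's API where the comments indicate:

from kiln_ai.adapters.fine_tune.base_finetune import (
    BaseFinetuneAdapter,
    FineTuneParameter,
    FineTuneStatus,
)
from kiln_ai.datamodel import DatasetSplit, FineTuneStatusType


class ExampleFinetuneAdapter(BaseFinetuneAdapter):
    """Hypothetical adapter sketch; not a real Kiln provider."""

    async def _start(self, dataset: DatasetSplit) -> None:
        # A real adapter would format and upload the training split here,
        # create the provider-side job, and record its id on self.datamodel.
        ...

    async def status(self) -> FineTuneStatus:
        # A real adapter would query the provider job and map its state onto
        # FineTuneStatusType; "completed" is assumed to be a member here.
        return FineTuneStatus(
            status=FineTuneStatusType.completed,
            message="Example only",
        )

    @classmethod
    def available_parameters(cls) -> list[FineTuneParameter]:
        return [
            FineTuneParameter(
                name="epochs",
                type="int",
                description="Number of training epochs (illustrative)",
            ),
        ]
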
@classmethod
async def create_and_start(
    cls,
    dataset: kiln_ai.datamodel.DatasetSplit,
    provider_id: str,
    provider_base_model_id: str,
    train_split_name: str,
    system_message: str,
    thinking_instructions: str | None,
    data_strategy: kiln_ai.datamodel.FinetuneDataStrategy,
    parameters: dict[str, str | int | float | bool] = {},
    name: str | None = None,
    description: str | None = None,
    validation_split_name: str | None = None,
) -> tuple[BaseFinetuneAdapter, kiln_ai.datamodel.Finetune]:

Create and start a fine-tune.
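
A hedged usage sketch, assuming an existing DatasetSplit (with a "train" split and a parent task) and the hypothetical ExampleFinetuneAdapter above; the provider and model ids are placeholders, not real catalogue entries:

async def launch(dataset, data_strategy):
    # dataset: a saved DatasetSplit with a "train" split and a parent task.
    # data_strategy: a FinetuneDataStrategy value.
    # Note: with these placeholder ids, check_valid_provider_model() would raise;
    # a real call uses a provider/model pair present in built_in_models.
    adapter, finetune = await ExampleFinetuneAdapter.create_and_start(
        dataset=dataset,
        provider_id="example_provider",               # placeholder
        provider_base_model_id="example-base-model",  # placeholder
        train_split_name="train",
        system_message="You are a helpful assistant.",
        thinking_instructions=None,
        data_strategy=data_strategy,
        parameters={"epochs": 3},  # checked by validate_parameters()
        validation_split_name=None,
    )
    # On success the Finetune datamodel has been saved under the dataset's parent task.
    return adapter, finetune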

@abstractmethod
async def status(self) -> FineTuneStatus:

Get the status of the fine-tune.
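
Once a job has been started, status() can be polled. A small sketch reusing an adapter instance such as the one returned by create_and_start above:

from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter


async def report(adapter: BaseFinetuneAdapter) -> None:
    status = await adapter.status()
    # status.status is a FineTuneStatusType; message may be None.
    print(status.status, status.message or "(no message)")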

@classmethod
def available_parameters(cls) -> list[FineTuneParameter]:

Returns a list of parameters that can be provided for this fine-tune. Includes hyperparameters, etc.
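
The list drives which hyperparameters a caller may pass. A sketch introspecting the hypothetical adapter above:

for p in ExampleFinetuneAdapter.available_parameters():
    print(p.name, p.type, "optional" if p.optional else "required", "-", p.description)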

@classmethod
def validate_parameters(cls, parameters: dict[str, str | int | float | bool]) -> None:

Validate the parameters for this fine-tune.
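
Given the single optional "epochs" int parameter declared by the hypothetical ExampleFinetuneAdapter above, validation behaves as follows:

ExampleFinetuneAdapter.validate_parameters({"epochs": 3})  # passes
ExampleFinetuneAdapter.validate_parameters({})             # passes; "epochs" is optional

try:
    ExampleFinetuneAdapter.validate_parameters({"epochs": "three"})
except ValueError as err:
    print(err)  # Parameter epochs must be an integer, got <class 'str'>

try:
    ExampleFinetuneAdapter.validate_parameters({"batch_size": 8})
except ValueError as err:
    print(err)  # Parameter batch_size is not available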

@classmethod
def check_valid_provider_model(cls, provider_id: str, provider_base_model_id: str) -> None:

Check if the provider and base model are valid.
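
The check scans built_in_models for a provider whose provider_finetune_id matches the requested base model; unknown pairs raise. A sketch with placeholder ids that will not be found:

from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter

try:
    BaseFinetuneAdapter.check_valid_provider_model(
        "example_provider", "example-base-model"  # placeholder ids
    )
except ValueError as err:
    print(err)  # Provider example_provider with base model example-base-model is not available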