How to fine-tune a model
This guide walks you through fine-tuning models with the aiXplain SDK: how to select models and datasets, configure fine-tuning settings, and launch a FineTune job. To learn more about fine-tuning, refer to this guide.
Generic Example (Template)
from aixplain.factories import DatasetFactory, ModelFactory, FinetuneFactory

dataset = DatasetFactory.get("...")  # specify the dataset ID
model = ModelFactory.get("...")      # specify the model ID

finetune = FinetuneFactory.create(
    "finetuned_model",
    [dataset],
    model
)

finetuned_model = finetune.start()
finetuned_model.check_finetune_status()
FineTune Examples
The following examples cover the four supported FineTune use cases. View them in this documentation or by opening their corresponding Google Colab notebooks.
- Passthrough: models hosted on third-party infrastructure
- Hosted: models hosted on aiXplain's infrastructure
Imports
from aixplain.factories import DatasetFactory, ModelFactory, FinetuneFactory
from aixplain.enums import Function, Language # for search
from aixplain.modules.finetune import Hyperparameters # for hosted models
Select Model & Datasets
Datasets are currently private, so you must first onboard the datasets in the examples below (or similar) to follow along.
See our guide on How to upload a dataset.
- Text generation (passthrough)
- Text generation (hosted)
- Translation
- Search
Model
# Choose exactly one model
model_list = ModelFactory.list(
    function=Function.TEXT_GENERATION,
    is_finetunable=True
)["results"]

for model in model_list:
    print(model.__dict__)

selected_model = ModelFactory.get("640b517694bf816d35a59125")
selected_model.__dict__
Dataset
# Choose one or more datasets
dataset_list = DatasetFactory.list(
    function=Function.TEXT_GENERATION,
    page_size=5
)["results"]

for dataset in dataset_list:
    print(dataset.__dict__)

selected_dataset = DatasetFactory.get("6501ea64b61fed7fe5976c49")
selected_dataset.__dict__
Model
# Choose exactly one model
model_list = ModelFactory.list(
    function=Function.TEXT_GENERATION,
    is_finetunable=True
)["results"]

for model in model_list:
    print(model.__dict__)

selected_model = ModelFactory.get("6543cb991f695e72028e9428")
selected_model.__dict__
Dataset
# Choose one or more datasets
dataset_list = DatasetFactory.list(
    function=Function.TEXT_GENERATION,
    page_size=5
)["results"]

for dataset in dataset_list:
    print(dataset.__dict__)

selected_dataset = DatasetFactory.get("65a7f8b1b1087d75e7afea43")
selected_dataset.__dict__
Model
# Choose exactly one model
model_list = ModelFactory.list(
    function=Function.TRANSLATION,
    source_languages=Language.English,
    target_languages=Language.German,
    is_finetunable=True
)["results"]

for model in model_list:
    print(model.__dict__)

selected_model = ModelFactory.get("60ddefae8d38c51c5885eff7")
selected_model.__dict__
Dataset
# Choose one or more datasets
dataset_list = DatasetFactory.list(
    function=Function.TRANSLATION,
    source_languages=Language.English,
    target_languages=Language.German
)["results"]

for dataset in dataset_list:
    print(dataset.__dict__)

selected_dataset = DatasetFactory.get("6488b198ed71322dfe331c96")
selected_dataset.__dict__
Model
# Choose exactly one model
model_list = ModelFactory.list(
    function=Function.SEARCH,
    is_finetunable=True
)["results"]

for model in model_list:
    print(model.__dict__)

selected_model = ModelFactory.get("6499cc946eb5633de15d82a1")
selected_model.__dict__
Dataset
# Choose one or more datasets
dataset_list = DatasetFactory.list(
    function=Function.SEARCH,
    page_size=5
)["results"]

for dataset in dataset_list:
    print(dataset.__dict__)

selected_dataset = DatasetFactory.get("6508642bf0fee9770971168d")
selected_dataset.__dict__
Create a FineTune
Use FinetuneFactory to create a FineTune object, and the cost method to check the estimated training, hosting, and inference costs.
- Text generation (passthrough)
- Text generation (hosted)
- Translation
- Search
finetune = FinetuneFactory.create(
    "<UNIQUE_FINETUNE_NAME>",
    [selected_dataset],
    selected_model
)
finetune.__dict__
Cost
finetune.cost.to_dict()
prompt_template = """Given the context, generate the continuation:
Context: <<context>>
Continuation: <<continuation>>"""
hyperparameters = Hyperparameters(epochs=2, learning_rate=1e-5)
By default, hosted models are fine-tuned using LoRA.
finetune = FinetuneFactory.create(
    "<UNIQUE_FINETUNE_NAME>",
    [selected_dataset],
    selected_model,
    prompt_template=prompt_template,
    hyperparameters=hyperparameters,
    train_percentage=90,
    dev_percentage=10
)
finetune.__dict__
Cost
finetune.cost.to_dict()
finetune = FinetuneFactory.create(
    "<UNIQUE_FINETUNE_NAME>",
    [selected_dataset],
    selected_model
)
finetune.__dict__
Cost
finetune.cost.to_dict()
finetune = FinetuneFactory.create(
    "<UNIQUE_FINETUNE_NAME>",
    [selected_dataset],
    selected_model
)
finetune.__dict__
Cost
finetune.cost.to_dict()
Starting a FineTune
Call the start method to begin fine-tuning, and the check_finetune_status method to check its status.
finetune_model = finetune.start()
status = finetune_model.check_finetune_status()
Status can be one of the following: onboarding, onboarded, hidden, training, deleted, enabling, disabled, failed, deleting.
You can use a loop to check the status.
import time

while status != "onboarded":
    status = finetune_model.check_finetune_status()
    print(f"Current status: {status}")
    time.sleep(10)
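Note that this simple loop never exits if the job ends in a terminal state such as failed. A more defensive sketch is shown below; the helper name is hypothetical, and treating failed, deleted, and disabled as terminal states is an assumption based on the status list above.

```python
import time

def wait_until_onboarded(check_status, poll_seconds=10, timeout_seconds=3600):
    """Poll a zero-argument status callable until it returns 'onboarded'.

    check_status can be finetune_model.check_finetune_status; terminal
    failure states raise instead of looping forever (assumed terminal set).
    """
    deadline = time.monotonic() + timeout_seconds
    while time.monotonic() < deadline:
        status = check_status()
        if status == "onboarded":
            return status
        if status in ("failed", "deleted", "disabled"):
            raise RuntimeError(f"Fine-tune reached terminal state: {status}")
        time.sleep(poll_seconds)
    raise TimeoutError("Fine-tune did not onboard within the timeout")
```

Usage would then be `wait_until_onboarded(finetune_model.check_finetune_status)`, with the timeout tuned to the expected training duration.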
Once the status is onboarded, you can use the fine-tuned model like any other and integrate it into your agents, providing customized solutions! 🥳