LLM fine-tuning
How to fine-tune a model
This guide walks you through the process of fine-tuning models using the aiXplain SDK. Learn how to select datasets and configure fine-tuning settings.
Generic Example (Template)
from aixplain.factories import DatasetFactory, ModelFactory, FinetuneFactory
dataset = DatasetFactory.get("...") # specify Dataset ID
model = ModelFactory.get("...") # specify Model ID
finetune = FinetuneFactory.create(
  "finetuned_model",
  [dataset],
  model
)
finetuned_model = finetune.start()
finetuned_model.check_finetune_status()
FineTune Examples
The following examples cover the supported FineTune use cases. View them in this documentation or by opening their corresponding Google Colab notebooks.
Passthrough: models hosted on third-party infrastructure
Hosted: models hosted on aiXplain's infrastructure
Imports
from aixplain.factories import DatasetFactory, ModelFactory, FinetuneFactory
from aixplain.enums import Function, Language # for search
from aixplain.modules.finetune import Hyperparameters # for hosted models
Select Model & Datasets
Datasets are currently private, so you must first onboard the datasets in the examples below (or similar) to follow along.
See our guide on How to upload a dataset.
- Text generation (passthrough)
- Text generation (hosted)
- Translation
Model
# Choose 'exactly one' model
model_list = ModelFactory.list(
    function=Function.TEXT_GENERATION,
    is_finetunable=True
)["results"]
for model in model_list:
    print(model.__dict__)
selected_model = ModelFactory.get("640b517694bf816d35a59125")
selected_model.__dict__
Dataset
# Choose 'one or more' datasets
dataset_list = DatasetFactory.list(
  function=Function.TEXT_GENERATION,
  page_size=5
)["results"]
for dataset in dataset_list:
  print(dataset.__dict__)
selected_dataset = DatasetFactory.get("6501ea64b61fed7fe5976c49")
selected_dataset.__dict__
Model
# Choose 'exactly one' model
model_list = ModelFactory.list(
    function=Function.TEXT_GENERATION,
    is_finetunable=True
)["results"]
for model in model_list:
    print(model.__dict__)
selected_model = ModelFactory.get("6543cb991f695e72028e9428")
selected_model.__dict__
Dataset
# Choose 'one or more' datasets
dataset_list = DatasetFactory.list(
  function=Function.TEXT_GENERATION,
  page_size=5
)["results"]
for dataset in dataset_list:
  print(dataset.__dict__)
selected_dataset = DatasetFactory.get("65a7f8b1b1087d75e7afea43")
selected_dataset.__dict__
Model
# Choose 'exactly one' model
model_list = ModelFactory.list(
  function=Function.TRANSLATION,
  source_languages=Language.English,
  target_languages=Language.German,
  is_finetunable=True
)["results"]
for model in model_list:
  print(model.__dict__)
selected_model = ModelFactory.get("60ddefae8d38c51c5885eff7")
selected_model.__dict__
Dataset
# Choose 'one or more' datasets
dataset_list = DatasetFactory.list(
  function=Function.TRANSLATION,
  source_languages=Language.English,
  target_languages=Language.German,
)["results"]
for dataset in dataset_list:
  print(dataset.__dict__)
selected_dataset = DatasetFactory.get("6488b198ed71322dfe331c96")
selected_dataset.__dict__
Create a FineTune
Use FinetuneFactory to create a FineTune object, and inspect its cost attribute to check the estimated training, hosting, and inference costs.
- Text generation (passthrough)
- Text generation (hosted)
- Translation
finetune = FinetuneFactory.create(
  "<UNIQUE_FINETUNE_NAME>",
  [selected_dataset],
  selected_model
)
finetune.__dict__
Cost
finetune.cost.to_dict()
prompt_template = """Given the context, generate the continuation:
Context: <<context>>
Continuation: <<continuation>>"""
hyperparameters = Hyperparameters(epochs=2, learning_rate=1e-5)
By default, hosted fine-tuning trains with LoRA (low-rank adaptation).
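The `<<context>>` and `<<continuation>>` placeholders in the prompt template refer to columns of the selected dataset. The actual substitution happens on aiXplain's side during training, but as a rough illustration (the `fill_template` helper below is hypothetical, not part of the SDK), it behaves like:

```python
import re

# Hypothetical illustration (not part of the aiXplain SDK): each <<name>>
# placeholder is replaced with the value of the matching dataset column.
def fill_template(template: str, row: dict) -> str:
    return re.sub(r"<<(\w+)>>", lambda m: str(row[m.group(1)]), template)

example = fill_template(
    "Context: <<context>>\nContinuation: <<continuation>>",
    {"context": "Once upon a time", "continuation": "there was a model."},
)
print(example)
```

Each dataset row thus yields one fully rendered training example.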
finetune = FinetuneFactory.create(
  "<UNIQUE_FINETUNE_NAME>",
  [selected_dataset],
  selected_model,
  prompt_template=prompt_template,
  hyperparameters=hyperparameters,
  train_percentage=90,
  dev_percentage=10
)
finetune.__dict__
Cost
finetune.cost.to_dict()
finetune = FinetuneFactory.create(
  "<UNIQUE_FINETUNE_NAME>",
  [selected_dataset],
  selected_model
)
finetune.__dict__
Cost
finetune.cost.to_dict()
Starting a FineTune
Call the start method to begin fine-tuning and the check_finetune_status method to check its status.
finetune_model = finetune.start()
status = finetune_model.check_finetune_status()
Status can be one of the following: onboarding, onboarded, hidden, training, deleted, enabling, disabled, failed, deleting.
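Only some of these states are terminal. A small helper (hypothetical, not part of the SDK) makes the distinction explicit before writing a polling loop:

```python
# Hypothetical helper, not part of the aiXplain SDK: statuses after which
# polling can stop, because the FineTune either succeeded or will not progress.
TERMINAL_STATUSES = {"onboarded", "failed", "deleted", "disabled"}

def is_finished(status: str) -> bool:
    return status in TERMINAL_STATUSES
```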
You can use a loop to poll the status; checking for "failed" as well prevents the loop from running forever if training errors out.
import time
while status not in ("onboarded", "failed"):
  status = finetune_model.check_finetune_status()
  print(f"Current status: {status}")
  time.sleep(10)
Once onboarded, the fine-tuned model can be used like any other aiXplain model and integrated into your agents, providing customized solutions! 🥳
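After onboarding, inference works the same way as for any stock model via the SDK's run() call. The sketch below assumes a dict-like response with the generated text under "data"; the exact response shape can vary by SDK version, and the `generate` helper is hypothetical, so treat this as a sketch rather than a guaranteed interface:

```python
def generate(model, prompt: str) -> str:
    # run() is the standard aiXplain inference entry point; we assume here
    # that the response is dict-like with the output under "data".
    response = model.run(prompt)
    return response["data"]

# e.g. generate(finetune_model, "Once upon a time") after onboarding completes
```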