Fine-tune Meta Llama 3.1 models using torchtune on Amazon SageMaker


This post is co-written with Meta’s PyTorch team.

In today’s rapidly evolving AI landscape, businesses are constantly seeking ways to use advanced large language models (LLMs) for their specific needs. Although foundation models (FMs) offer impressive out-of-the-box capabilities, true competitive advantage often lies in deep model customization through fine-tuning. However, fine-tuning LLMs for complex tasks typically requires advanced AI expertise to align and optimize them effectively. Recognizing this challenge, the PyTorch team developed torchtune, a PyTorch-native library that simplifies authoring, fine-tuning, and experimenting with LLMs, making it more accessible to a broader range of users and applications.

In this post, AWS collaborates with Meta’s PyTorch team to showcase how you can use PyTorch’s torchtune library to fine-tune Meta Llama-like architectures in the fully managed environment provided by Amazon SageMaker Training. We demonstrate this through a step-by-step implementation of model fine-tuning, inference, quantization, and evaluation. We perform the steps on a Meta Llama 3.1 8B model using the LoRA fine-tuning strategy on a single p4d.24xlarge worker node (providing 8 NVIDIA A100 GPUs).

Before we dive into the step-by-step guide, we explore the performance of our technical stack by fine-tuning a Meta Llama 3.1 8B model across various configurations and instance types.

As the following chart shows, we found that a single p4d.24xlarge delivers 70% higher performance than two g5.48xlarge instances (each with 8 NVIDIA A10G GPUs) at almost 47% lower price. We have therefore optimized the example in this post for a p4d.24xlarge configuration. However, you can use the same code to run single-node or multi-node training on different instance configurations by changing the parameters passed to the SageMaker estimator. You could further reduce the training times shown in the following graph by using SageMaker managed warm pools and by accessing pre-downloaded models on Amazon Elastic File System (Amazon EFS).

Challenges with fine-tuning LLMs

Generative AI models offer many promising business use cases. However, to maintain the factual accuracy and relevance of these LLMs for specific business domains, fine-tuning is required. Due to the growing number of model parameters and the increasing context length of modern LLMs, this process is memory intensive. To address these challenges, fine-tuning strategies like LoRA (Low-Rank Adaptation) and QLoRA (Quantized Low-Rank Adaptation) limit the number of trainable parameters by adding low-rank parallel structures to the transformer layers. This enables you to train LLMs even on systems with limited memory, such as commodity GPUs. However, these techniques increase complexity: new dependencies have to be handled, and training recipes and hyperparameters need to be adapted to them.
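To make the "low-rank parallel structure" concrete, here is a minimal pure-Python sketch of a LoRA-adapted linear layer. The pretrained weight W stays frozen, and a trainable pair of small matrices A (rank × in) and B (out × rank) adds a scaled correction (alpha / rank) · B(Ax) alongside it. This is an illustration under stated assumptions (plain lists instead of tensors, a simple Gaussian initialization for A, a hypothetical class name), not torchtune's implementation, which is a PyTorch module. B starts at zero, so the adapted layer initially reproduces the base layer exactly.

```python
import random


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


class LoRALinearSketch:
    """Frozen linear layer W plus a trainable low-rank update (alpha / rank) * B @ A."""

    def __init__(self, weight, rank, alpha, seed=0):
        rng = random.Random(seed)
        out_dim, in_dim = len(weight), len(weight[0])
        self.weight = weight          # pretrained weight, shape (out, in); never updated
        self.scale = alpha / rank
        # A: (rank, in), small random values; B: (out, rank), zeros, so the
        # adapter adds nothing until B is trained.
        self.lora_a = [[rng.gauss(0.0, 0.01) for _ in range(in_dim)] for _ in range(rank)]
        self.lora_b = [[0.0] * rank for _ in range(out_dim)]

    def trainable_params(self):
        # rank * in (for A) + out * rank (for B)
        return sum(len(r) for r in self.lora_a) + sum(len(r) for r in self.lora_b)

    def __call__(self, x):
        base = matvec(self.weight, x)                         # frozen path
        update = matvec(self.lora_b, matvec(self.lora_a, x))  # low-rank parallel path
        return [b + self.scale * u for b, u in zip(base, update)]


W = [[1.0, 2.0, 0.0, 1.0],
     [0.0, 1.0, 3.0, -1.0]]           # out=2, in=4 -> 8 frozen weights
layer = LoRALinearSketch(W, rank=1, alpha=2)
print(layer([1.0, 1.0, 1.0, 1.0]))     # [4.0, 3.0], identical to W @ x at init
print(layer.trainable_params())        # 6
```

Only A and B receive gradients and optimizer state; for real layers with thousands of input and output features and a small rank, rank × (in + out) is a tiny fraction of in × out.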

What businesses need today are user-friendly training recipes for these popular fine-tuning techniques: recipes that abstract the end-to-end tuning process and address common pitfalls in an opinionated way.

How does torchtune help?

torchtune is a PyTorch-native library that aims to democratize and streamline the fine-tuning process for LLMs. By doing so, it makes it straightforward for researchers, developers, and organizations to adapt these powerful LLMs to their specific needs and constraints. It provides training recipes for a variety of fine-tuning techniques, which can be configured through YAML files. The recipes implement common fine-tuning methods (full-weight, LoRA, QLoRA) as well as other common tasks like inference and evaluation. They automatically apply a set of important features (FSDP, activation checkpointing, gradient accumulation, mixed precision) and are specific to a given model family (such as Meta Llama 3/3.1 or Mistral) as well as compute environment (single-node vs. multi-node).

Additionally, torchtune integrates with major libraries and frameworks like Hugging Face datasets, EleutherAI’s Eval Harness, and Weights & Biases. This helps address the requirements of the generative AI fine-tuning lifecycle, from data ingestion and multi-node fine-tuning to inference and evaluation. The following diagram shows a visualization of the steps we describe in this post.

Refer to the installation instructions and PyTorch documentation to learn more about torchtune and its concepts.

Solution overview

This post demonstrates the use of SageMaker Training for running torchtune recipes through task-specific training jobs on separate compute clusters. SageMaker Training is a comprehensive, fully managed ML service that enables scalable model training. It provides flexible compute resource selection, support for custom libraries, a pay-as-you-go pricing model, and self-healing capabilities. By managing workload orchestration, health checks, and infrastructure, SageMaker helps reduce training time and total cost of ownership.

The solution architecture incorporates the following key components to enhance security and efficiency in fine-tuning workflows:

  • Security enhancement – Training jobs are run within private subnets of your virtual private cloud (VPC), significantly improving the security posture of machine learning (ML) workflows.
  • Efficient storage solution – Amazon EFS is used to accelerate model storage and access across various phases of the ML workflow.
  • Customizable environment – We use custom containers in training jobs. The support in SageMaker for custom containers allows you to package all necessary dependencies, specialized frameworks, and libraries into a single artifact, providing full control over your ML environment.

The following diagram illustrates the solution architecture. Users initiate the process by calling the SageMaker control plane through APIs or command line interface (CLI) or using the SageMaker SDK for each individual step. In response, SageMaker spins up training jobs with the requested number and type of compute instances to run specific tasks. Each step defined in the diagram accesses torchtune recipes from an Amazon Simple Storage Service (Amazon S3) bucket and uses Amazon EFS to save and access model artifacts across different stages of the workflow.

By decoupling every torchtune step, we achieve a balance between flexibility and integration, allowing for both independent execution of steps and the potential for automating this process using seamless pipeline integration.

In this use case, we fine-tune a Meta Llama 3.1 8B model with LoRA. Subsequently, we run model inference, and optionally quantize and evaluate the model using torchtune and SageMaker Training.

Recipes, configs, datasets, and prompt templates are completely configurable and allow you to align torchtune to your requirements. To demonstrate this, we use a custom prompt template in this use case and combine it with the open source dataset Samsung/samsum from the Hugging Face hub.

We fine-tune the model using torchtune’s multi-device LoRA recipe (lora_finetune_distributed) and a SageMaker-customized version of the default Meta Llama 3.1 8B config (llama3_1/8B_lora).

Prerequisites

You need to complete the following prerequisites before you can run the SageMaker Jupyter notebooks:

  1. Create a Hugging Face access token to get access to the gated repo meta-llama/Meta-Llama-3.1-8B on Hugging Face.
  2. Create a Weights & Biases API key to access the Weights & Biases dashboard for logging and monitoring.
  3. Request a SageMaker service quota for 1x ml.p4d.24xlarge and 1x ml.g5.2xlarge.
  4. Create an AWS Identity and Access Management (IAM) role with managed policies AmazonSageMakerFullAccess, AmazonEC2FullAccess, AmazonElasticFileSystemFullAccess, and AWSCloudFormationFullAccess to give required access to SageMaker to run the examples. (This is for demonstration purposes. You should adjust this to your specific security requirements for production.)
  5. Create an Amazon SageMaker Studio domain (see Quick setup to Amazon SageMaker) to access Jupyter notebooks with the preceding role. Refer to the instructions to set permissions for Docker build.
  6. Log in to the notebook console and clone the GitHub repo:
$ git clone https://github.com/aws-samples/sagemaker-distributed-training-workshop.git
$ cd sagemaker-distributed-training-workshop/13-torchtune
  7. Run the notebook ipynb to set up the VPC and Amazon EFS using an AWS CloudFormation stack.

Review torchtune configs

The following figure illustrates the steps in our workflow.

You can look up the torchtune configs for your use case directly with the tune CLI. For this post, we provide modified config files aligned with the SageMaker directory structure:

sh-4.2$ cd config/
sh-4.2$ ls -ltr
-rw-rw-r-- 1 ec2-user ec2-user 1151 Aug 26 18:34 config_l3.1_8b_gen_orig.yaml
-rw-rw-r-- 1 ec2-user ec2-user 1172 Aug 26 18:34 config_l3.1_8b_gen_trained.yaml
-rw-rw-r-- 1 ec2-user ec2-user  644 Aug 26 18:49 config_l3.1_8b_quant.yaml
-rw-rw-r-- 1 ec2-user ec2-user 2223 Aug 28 14:53 config_l3.1_8b_lora.yaml
-rw-rw-r-- 1 ec2-user ec2-user 1223 Sep  4 14:28 config_l3.1_8b_eval_trained.yaml
-rw-rw-r-- 1 ec2-user ec2-user 1213 Sep  4 14:29 config_l3.1_8b_eval_original.yaml

torchtune uses these config files to select and configure the components (think models and tokenizers) during the execution of the recipes.
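Each `_component_` entry is a dotted import path to a class or builder function, and its sibling keys become that callable's arguments. The following is a simplified, hypothetical sketch of this pattern using only the standard library; in the real recipes, torchtune's own `config.instantiate` does this job with more features. Here `datetime.timedelta` stands in for a torchtune builder so the example runs without torchtune installed.

```python
import importlib


def instantiate(node, *args, **overrides):
    """Build an object from a config node that names it with `_component_`.

    Simplified sketch: remaining keys become keyword arguments, nested nodes that
    have their own `_component_` are built recursively, and call-site overrides win.
    """
    node = dict(node)  # copy so the caller's config is not mutated
    module_path, _, attr_name = node.pop("_component_").rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    kwargs = {
        key: instantiate(value) if isinstance(value, dict) and "_component_" in value else value
        for key, value in node.items()
    }
    kwargs.update(overrides)
    return target(*args, **kwargs)


# Standard-library stand-in for a YAML node such as:
#   tokenizer:
#     _component_: torchtune.models.llama3.llama3_tokenizer
#     path: /opt/ml/input/data/model/hf-model/original/tokenizer.model
cfg = {"_component_": "datetime.timedelta", "hours": 1, "minutes": 30}
print(instantiate(cfg))                # 1:30:00
print(instantiate(cfg, minutes=45))    # 1:45:00
```

This indirection is what lets you swap models, tokenizers, or datasets by editing YAML rather than recipe code.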

Build the container

As part of our example, we create a custom container to provide custom libraries like torch nightlies and torchtune. Complete the following steps:

sh-4.2$ cat Dockerfile
# Set the default value for the REGION build argument
ARG REGION=us-west-2
# Account ID of the registry that hosts the SageMaker PyTorch image (supply it as a build argument)
ARG ACCOUNTID
# SageMaker PyTorch image for TRAINING
FROM ${ACCOUNTID}.dkr.ecr.${REGION}.amazonaws.com/pytorch-training:2.3.0-gpu-py311-cu121-ubuntu20.04-sagemaker
# Uninstall existing PyTorch packages
RUN pip uninstall torch torchvision transformer-engine -y
# Install pinned releases of PyTorch, torchao, and torchvision
RUN pip install --force-reinstall torch==2.4.1 torchao==0.4.0 torchvision==0.19.1

Run the 1_build_container.ipynb notebook up to the following command, which builds the image from this Dockerfile and pushes it to your Amazon ECR repository:

!sm-docker build . --repository accelerate:latest

sm-docker is a CLI tool designed for building Docker images in SageMaker Studio using AWS CodeBuild. We install the library as part of the notebook.

Next, we will run the 2_torchtune-llama3_1.ipynb notebook for all fine-tuning workflow tasks.

For every task, we review three artifacts:

  • torchtune configuration file
  • SageMaker task config with compute and torchtune recipe details
  • SageMaker task output

Run the fine-tuning task

In this section, we walk through the steps to run and monitor the fine-tuning task.

Run the fine-tuning job

The following code shows a shortened torchtune recipe configuration highlighting a few key components of the file for a fine-tuning job:

  • Model component including LoRA rank configuration
  • Meta Llama 3 tokenizer to tokenize the data
  • Checkpointer to read and write checkpoints
  • Dataset component to load the dataset
sh-4.2$ cat config_l3.1_8b_lora.yaml
# Model Arguments
model:
  _component_: torchtune.models.llama3_1.lora_llama3_1_8b
  lora_attn_modules: ['q_proj', 'v_proj']
  lora_rank: 8
  lora_alpha: 16

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /opt/ml/input/data/model/hf-model/original/tokenizer.model

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_files: [
    consolidated.00.pth
  ]
  …

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.samsum_dataset
  train_on_input: True
batch_size: 13

# Training
epochs: 1
gradient_accumulation_steps: 2

... and more ...
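The LoRA settings in this config determine how many parameters are actually trained. The sketch below counts them, assuming the publicly documented Llama 3.1 8B shapes (hidden size 4096, 32 layers, 8 key/value heads of dimension 128), which are not stated in this post, and assuming adapters are applied only to the modules listed in lora_attn_modules (the MLP and output-layer settings are elided from the shortened config). Each adapted linear layer with in and out features adds rank × (in + out) parameters.

```python
# Assumed Llama 3.1 8B shapes (public model config; not stated in this post)
HIDDEN_DIM = 4096
NUM_LAYERS = 32
HEAD_DIM = 128
NUM_KV_HEADS = 8  # grouped-query attention: k/v projections are narrower than q

# (in_features, out_features) for each attention projection, keyed by torchtune's names
ATTN_SHAPES = {
    "q_proj": (HIDDEN_DIM, HIDDEN_DIM),
    "k_proj": (HIDDEN_DIM, NUM_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN_DIM, NUM_KV_HEADS * HEAD_DIM),
    "output_proj": (HIDDEN_DIM, HIDDEN_DIM),
}


def lora_trainable_params(lora_attn_modules, lora_rank):
    """Each adapted layer adds A (rank x in) and B (out x rank): rank * (in + out)."""
    per_layer = sum(lora_rank * sum(ATTN_SHAPES[name]) for name in lora_attn_modules)
    return per_layer * NUM_LAYERS


print(lora_trainable_params(["q_proj", "v_proj"], lora_rank=8))  # 3407872
```

Under these assumptions, rank 8 on q_proj and v_proj trains about 3.4 million parameters, a small fraction of the roughly 8 billion frozen base weights, which keeps gradients and optimizer state small.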

We use Weights & Biases for logging and monitoring our training jobs, which helps us track our model’s performance:

metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  …

Next, we define a SageMaker task that is passed to our utility function create_pytorch_estimator. This function creates a SageMaker PyTorch estimator with all the defined parameters.

In the task, we use the lora_finetune_distributed recipe (launched with torchrun) with the config config_l3.1_8b_lora.yaml on an ml.p4d.24xlarge instance. Setting use_downloaded_model to "false" makes the job download the base model from Hugging Face before fine-tuning it. The image_uri parameter defines the URI of the custom container.

sagemaker_tasks={
    "fine-tune":{
        "hyperparameters":{
            "tune_config_name":"config_l3.1_8b_lora.yaml",
            "tune_action":"fine-tune",
            "use_downloaded_model":"false",
            "tune_recipe":"lora_finetune_distributed"
            },
        "instance_count":1,
        "instance_type":"ml.p4d.24xlarge",
        "image_uri":"<accountid>.dkr.ecr.<region>.amazonaws.com/accelerate:latest"
    },
    # ... and more ...
}
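The create_pytorch_estimator helper lives in the sample repository, and its contents are not shown in this post. As a rough, hypothetical sketch of what such a helper might do, the following code maps one sagemaker_tasks entry onto keyword arguments for the SageMaker Python SDK's PyTorch estimator. The function names, the entry_point and source_dir values, and the role ARN are placeholders; subnets and security_group_ids correspond to the private VPC from the solution overview, and keep_alive_period_in_seconds is the estimator argument that enables a SageMaker managed warm pool.

```python
def build_estimator_kwargs(hyperparameters, instance_count, instance_type,
                           image_uri=None, *, role, subnets=None,
                           security_group_ids=None, entry_point="launch.py",
                           source_dir="scripts", keep_alive_period_in_seconds=None):
    """Map one sagemaker_tasks entry to PyTorch estimator keyword arguments.

    entry_point and source_dir defaults are hypothetical placeholders.
    """
    kwargs = {
        "entry_point": entry_point,
        "source_dir": source_dir,
        "role": role,
        "instance_count": instance_count,
        "instance_type": instance_type,
        # Training-job hyperparameters are passed as strings; strip stray
        # whitespace such as a trailing space in a config file name.
        "hyperparameters": {k: str(v).strip() for k, v in hyperparameters.items()},
    }
    if image_uri:
        kwargs["image_uri"] = image_uri  # custom container built earlier
    if subnets and security_group_ids:
        # Run the job inside private VPC subnets
        kwargs["subnets"] = subnets
        kwargs["security_group_ids"] = security_group_ids
    if keep_alive_period_in_seconds:
        # Keep instances in a SageMaker managed warm pool between jobs
        kwargs["keep_alive_period_in_seconds"] = keep_alive_period_in_seconds
    return kwargs


def sketch_create_pytorch_estimator(**task_and_settings):
    """Hypothetical stand-in for the repository's create_pytorch_estimator helper."""
    from sagemaker.pytorch import PyTorch  # requires the SageMaker Python SDK
    return PyTorch(**build_estimator_kwargs(**task_and_settings))

# Usage (requires AWS credentials and an execution role):
# estimator = sketch_create_pytorch_estimator(**sagemaker_tasks["fine-tune"], role=role_arn)
# estimator.fit()
```

The SDK import is deferred so that the argument-building logic can be checked without the SDK installed, and the "true"/"false" strings in the tasks reflect that hyperparameters reach the container as strings.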

To create and run the task, run the following code:

Task="fine-tune"
estimator=create_pytorch_estimator(**sagemaker_tasks[Task])
execute_task(estimator)

The following code shows the task output and reported status:

# Refer-Output

2024-08-16 17:45:32 Starting - Starting the training job...
...
...

1|140|Loss: 1.4883038997650146:  99%|█████████▉| 141/142 [06:26<00:02,  2.47s/it]
1|141|Loss: 1.4621509313583374:  99%|█████████▉| 141/142 [06:26<00:02,  2.47s/it]

Training completed with code: 0
2024-08-26 14:19:09,760 sagemaker-training-toolkit INFO     Reporting training SUCCESS

The final model is saved to Amazon EFS, which makes it available without download time penalties.

Monitor the fine-tuning job

You can monitor various metrics such as loss and learning rate for your training run through the Weights & Biases dashboard. The following figures show the results of the training run where we tracked GPU utilization, GPU memory utilization, and loss curve.

In the following graph, note that torchtune optimizes memory usage by having only rank 0 initially load the model into CPU memory; rank 0 is therefore responsible for loading the model weights from the checkpoint.
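A back-of-the-envelope calculation shows why loading on rank 0 only matters. Assuming roughly 8.03 billion parameters stored in bf16 (2 bytes each), neither of which is stated in this post, the following sketch compares the host memory needed if each of the 8 ranks loaded its own full copy of the weights with loading a single copy on rank 0.

```python
# Assumptions (not stated in this post): ~8.03e9 parameters stored in bf16.
NUM_PARAMS = 8.03e9
BYTES_PER_PARAM = 2     # bf16
NUM_RANKS = 8           # one process per A100 on an ml.p4d.24xlarge


def checkpoint_gb(num_params=NUM_PARAMS, bytes_per_param=BYTES_PER_PARAM):
    """Approximate size of one full copy of the weights in decimal gigabytes."""
    return num_params * bytes_per_param / 1e9


every_rank = NUM_RANKS * checkpoint_gb()  # each process reads its own full copy
rank0_only = checkpoint_gb()              # only rank 0 materializes the full weights
print(f"every rank loads a copy: ~{every_rank:.0f} GB of CPU memory")  # ~128 GB
print(f"only rank 0 loads it:    ~{rank0_only:.0f} GB of CPU memory")  # ~16 GB
```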

The example is optimized to use GPU memory to its maximum capacity. Increasing the batch size further will lead to CUDA out-of-memory (OOM) errors.
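If you need to lower batch_size to avoid OOM errors, you can compensate with gradient_accumulation_steps. Assuming batch_size is per device, as in torchtune's distributed recipes (worth verifying for your torchtune version), each optimizer step sees batch_size × gradient_accumulation_steps × number of GPUs samples. The sketch below computes this for the config used here and picks the accumulation steps that preserve the same effective batch size on a different number of GPUs; the helper names are hypothetical.

```python
def effective_batch_size(per_device_batch, grad_accum_steps, num_devices):
    """Samples that contribute to one optimizer step in data-parallel training."""
    return per_device_batch * grad_accum_steps * num_devices


def grad_accum_for(target_effective, per_device_batch, num_devices):
    """Accumulation steps that keep the effective batch size unchanged."""
    per_pass = per_device_batch * num_devices
    if target_effective % per_pass:
        raise ValueError(f"{target_effective} is not a multiple of {per_pass}")
    return target_effective // per_pass


# Values from config_l3.1_8b_lora.yaml on one ml.p4d.24xlarge (8 GPUs)
target = effective_batch_size(per_device_batch=13, grad_accum_steps=2, num_devices=8)
print(target)                                                      # 208
# Same effective batch on a hypothetical 4-GPU node, same per-device batch
print(grad_accum_for(target, per_device_batch=13, num_devices=4))  # 4
```

Raising accumulation steps keeps the number of samples per optimizer step constant at the cost of more forward and backward passes per step.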

The run took about 13 minutes to complete for one epoch, resulting in the loss curve shown in the following graph.

Run the model generation task

In the next step, we use the previously fine-tuned model weights to generate the answer to a sample prompt and compare it to the base model.

The following code shows the configuration of the generate recipe config_l3.1_8b_gen_trained.yaml. The following are key parameters:

  • FullModelMetaCheckpointer – We use this to load the trained model checkpoint meta_model_0.pt from Amazon EFS
  • CustomTemplate.SummarizeTemplate – We use this to format the prompt for inference
# torchtune - trained model generation config - config_l3.1_8b_gen_trained.yaml
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
  
checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /opt/ml/input/data/model/
  checkpoint_files: [
    meta_model_0.pt
  ]
  …

# Generation arguments; defaults taken from gpt-fast
instruct_template: CustomTemplate.SummarizeTemplate

... and more ...

Next, we configure the SageMaker task to run on a single ml.g5.2xlarge instance:

import json

prompt=r'{"dialogue":"Amanda: I baked  cookies. Do you want some?\r\nJerry: Sure \r\nAmanda: I will bring you tomorrow :-)"}'

sagemaker_tasks={
    "generate_inference_on_trained":{
        "hyperparameters":{
            "tune_config_name":"config_l3.1_8b_gen_trained.yaml",
            "tune_action":"generate-trained",
            "use_downloaded_model":"true",
            "prompt":json.dumps(prompt)
            },
        "instance_count":1,
        "instance_type":"ml.g5.2xlarge",
        "image_uri":"<accountid>.dkr.ecr.<region>.amazonaws.com/accelerate:latest"
    }
}
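The CustomTemplate.SummarizeTemplate used by the generation config turns a sample like this prompt into the text the model sees; its exact wording is not shown in this post. The following hypothetical sketch shows the general shape of such a template: a class-level format string with a `{dialogue}` slot and a format method that reads the dialogue column, optionally renamed through a column map. The template text is an assumption, modeled loosely on the "---" and "Summary:" markers visible in the generation output.

```python
import json


class SummarizeTemplateSketch:
    """Hypothetical dialogue-summarization prompt template (wording is assumed)."""

    template = "Summarize this dialogue:\n{dialogue}\n---\nSummary:\n"

    @classmethod
    def format(cls, sample, column_map=None):
        # column_map lets a dataset store the dialogue under a different column name
        key = (column_map or {}).get("dialogue", "dialogue")
        return cls.template.format(dialogue=sample[key])


prompt = r'{"dialogue":"Amanda: I baked  cookies. Do you want some?\r\nJerry: Sure \r\nAmanda: I will bring you tomorrow :-)"}'
sample = json.loads(prompt)  # the JSON escapes \r\n become real line breaks
print(SummarizeTemplateSketch.format(sample))
```

Ending the prompt with "Summary:" cues the model to continue with the summary itself, which is why the fine-tuned model's output follows that marker.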

In the output of the SageMaker task, we see the model summary output and some stats like tokens per second:

# Refer-Output
...
Amanda: I baked  cookies. Do you want some?\r\nJerry: Sure \r\nAmanda: I will bring you tomorrow :-)

Summary:
Amanda baked cookies. She will bring some to Jerry tomorrow.

INFO:torchtune.utils.logging:Time for inference: 1.71 sec total, 7.61 tokens/sec
INFO:torchtune.utils.logging:Memory used: 18.32 GB

... and more ...

We can also run inference on the base model using the original model artifact consolidated.00.pth:

# torchtune - original model generation config - config_l3.1_8b_gen_orig.yaml
…  
checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /opt/ml/input/data/model/hf-model/original/
  checkpoint_files: [
    consolidated.00.pth
  ]
  
... and more ...

The following output comes from the base model, run with the SageMaker task generate_inference_on_original. Subjectively, the fine-tuned model performs better than the base model: its summary also mentions that Amanda baked the cookies.

# Refer-Output 
---
Summary:
Jerry tells Amanda he wants some cookies. Amanda says she will bring him some cookies tomorrow.

... and more ...

Run the model quantization task

To speed up the inference and decrease the model artifact size, we can apply post-training quantization. torchtune relies on torchao for post-training quantization.

We configure the recipe to use Int8DynActInt4WeightQuantizer, which refers to int8 dynamic per token activation quantization combined with int4 grouped per axis weight quantization. For more details, refer to the torchao implementation.

# torchtune model quantization config - config_l3.1_8b_quant.yaml
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  …

quantizer:
  _component_: torchtune.utils.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
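To illustrate what int8 dynamic per-token activation quantization combined with int4 grouped weight quantization means, here is a simplified pure-Python sketch. Each weight row is split into groups of groupsize values that share one scale and are rounded into the signed 4-bit range, while each activation vector (one token) gets its own int8 scale computed at run time. This is an assumption-laden illustration (symmetric scales everywhere, and a group size of 4 instead of 256 so the numbers stay readable), not torchao's implementation, which may differ in details such as activation zero points.

```python
def quantize_int4_grouped(row, group_size):
    """Symmetric int4 quantization of one weight row with one scale per group."""
    q_values, scales = [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        scale = max(abs(w) for w in group) / 7 or 1.0  # map the largest |w| to 7
        scales.append(scale)
        q_values.extend(max(-8, min(7, round(w / scale))) for w in group)
    return q_values, scales


def dequantize_int4_grouped(q_values, scales, group_size):
    """Recover approximate weights: each integer times its group's scale."""
    return [q * scales[i // group_size] for i, q in enumerate(q_values)]


def quantize_int8_per_token(token):
    """Dynamic symmetric int8 quantization: the scale is computed per token at run time."""
    scale = max(abs(x) for x in token) / 127 or 1.0
    return [max(-128, min(127, round(x / scale))) for x in token], scale


weights = [0.12, -0.5, 0.33, 0.05, -0.07, 0.9, -0.2, 0.01]
q, scales = quantize_int4_grouped(weights, group_size=4)  # groupsize: 256 in the config
approx = dequantize_int4_grouped(q, scales, group_size=4)
print(q, [round(s, 4) for s in scales])
print("max abs error:", max(abs(a - w) for a, w in zip(approx, weights)))

acts, act_scale = quantize_int8_per_token([0.5, -1.0, 0.25])
print(acts, act_scale)
```

Smaller groups track local weight magnitudes more closely (lower rounding error) but store more scales; groupsize trades accuracy against artifact size.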

We again use a single ml.g5.2xlarge instance and use SageMaker warm pool configuration to speed up the spin-up time for the compute nodes:

sagemaker_tasks={
    "quantize_trained_model":{
        "hyperparameters":{
            "tune_config_name":"config_l3.1_8b_quant.yaml",
            "tune_action":"run-quant",
            "use_downloaded_model":"true"
            },
        "instance_count":1,
        "instance_type":"ml.g5.2xlarge",
        "image_uri":"<accountid>.dkr.ecr.<region>.amazonaws.com/accelerate:latest"
    }
}

In the output, we see the location and size of the quantized model checkpoint, as well as the memory used during quantization:

#Refer-Output
...

linear: layers.31.mlp.w1, in=4096, out=14336
linear: layers.31.mlp.w2, in=14336, out=4096
linear: layers.31.mlp.w3, in=4096, out=14336
linear: output, in=4096, out=128256
INFO:torchtune.utils.logging:Time for quantization: 7.40 sec
INFO:torchtune.utils.logging:Memory used: 22.97 GB
INFO:torchtune.utils.logging:Model checkpoint of size 8.79 GB saved to /opt/ml/input/data/model/quantized/meta_model_0-8da4w.pt

... and more ...

You can run model inference on the quantized model meta_model_0-8da4w.pt by updating the inference-specific configurations.

Run the model evaluation task

Finally, let’s evaluate our fine-tuned model in an objective manner by running an evaluation on the validation portion of our dataset.

torchtune integrates with EleutherAI’s evaluation harness and provides the eleuther_eval recipe.

For our evaluation, we use a custom task for the evaluation harness that scores the dialogue summaries with ROUGE metrics.
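ROUGE-N measures n-gram overlap between a generated summary and a reference summary, and ROUGE-L uses their longest common subsequence; both are commonly reported as an F1 score. The following simplified sketch shows the core computation. It is not the scorer used by the custom evaluation task, whose tokenization, stemming, and scaling are not shown in this post (the harness results report scores on a 0 to 100 scale), and the reference summary in the example is illustrative only.

```python
from collections import Counter


def _tokens(text):
    # Simplified tokenization: lowercase, keep alphanumerics, split on everything else
    return "".join(c.lower() if c.isalnum() else " " for c in text).split()


def _f1(overlap, n_pred, n_ref):
    if overlap == 0:
        return 0.0
    precision, recall = overlap / n_pred, overlap / n_ref
    return 2 * precision * recall / (precision + recall)


def rouge_n(prediction, reference, n=1):
    def ngrams(toks):
        return Counter(tuple(toks[i:i + n]) for i in range(len(toks) - n + 1))
    pred, ref = ngrams(_tokens(prediction)), ngrams(_tokens(reference))
    overlap = sum((pred & ref).values())  # clipped n-gram matches
    return _f1(overlap, sum(pred.values()), sum(ref.values()))


def rouge_l(prediction, reference):
    a, b = _tokens(prediction), _tokens(reference)
    # Longest common subsequence length via dynamic programming
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            dp[i + 1][j + 1] = dp[i][j] + 1 if x == y else max(dp[i][j + 1], dp[i + 1][j])
    return _f1(dp[-1][-1], len(a), len(b))


reference = "Amanda baked cookies and will bring Jerry some tomorrow."  # illustrative
fine_tuned = "Amanda baked cookies. She will bring some to Jerry tomorrow."
base = "Jerry tells Amanda he wants some cookies. Amanda says she will bring him some cookies tomorrow."
for name, pred in [("fine-tuned", fine_tuned), ("base", base)]:
    print(name, round(100 * rouge_n(pred, reference), 1), round(100 * rouge_l(pred, reference), 1))
```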

The recipe configuration points the evaluation harness to our custom evaluation task:

# torchtune trained model evaluation config - config_l3.1_8b_eval_trained.yaml

model:
...

include_path: "/opt/ml/input/data/config/tasks"
tasks: ["samsum"]
...

The following code is the SageMaker task that we run on a single ml.p4d.24xlarge instance:

sagemaker_tasks={
    "evaluate_trained_model":{
        "hyperparameters":{
            "tune_config_name":"config_l3.1_8b_eval_trained.yaml",
            "tune_action":"run-eval",
            "use_downloaded_model":"true",
            },
        "instance_count":1,
        "instance_type":"ml.p4d.24xlarge",
    }
}

Run the model evaluation on ml.p4d.24xlarge:

Task="evaluate_trained_model"
estimator=create_pytorch_estimator(**sagemaker_tasks[Task])
execute_task(estimator)

The following tables show the task output for the fine-tuned model as well as the base model.

The following output is for the fine-tuned model.

 

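|Tasks |Version|Filter|n-shot|Metric|Direction| Value |±|Stderr|
|------|------:|------|------|------|---------|------:|-|------|
|samsum|      2|none  |None  |rouge1|         |45.8661|±|N/A   |
|      |       |none  |None  |rouge2|         |23.6071|±|N/A   |
|      |       |none  |None  |rougeL|         |37.1828|±|N/A   |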

The following output is for the base model.

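|Tasks |Version|Filter|n-shot|Metric|Direction| Value |±|Stderr|
|------|------:|------|------|------|---------|------:|-|------|
|samsum|      2|none  |None  |rouge1|         |33.6109|±|N/A   |
|      |       |none  |None  |rouge2|         |13.0929|±|N/A   |
|      |       |none  |None  |rougeL|         |26.2371|±|N/A   |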

Our fine-tuned model achieves a ROUGE-1 score of approximately 46 on the summarization task, approximately 12 points better than the baseline (a relative improvement of about 36%).
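The per-metric gains can be computed directly from the two evaluation tables, both in absolute ROUGE points and relative to the base model:

```python
# Scores copied from the evaluation tables (ROUGE on a 0-100 scale)
fine_tuned = {"rouge1": 45.8661, "rouge2": 23.6071, "rougeL": 37.1828}
base = {"rouge1": 33.6109, "rouge2": 13.0929, "rougeL": 26.2371}

for metric in fine_tuned:
    delta = fine_tuned[metric] - base[metric]
    relative = delta / base[metric]
    print(f"{metric}: +{delta:.2f} points ({relative:.1%} relative)")
# rouge1: +12.26 points (36.5% relative)
# rouge2: +10.51 points (80.3% relative)
# rougeL: +10.95 points (41.7% relative)
```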

Clean up

Complete the following steps to clean up your resources:

  1. Delete any unused SageMaker Studio resources.
  2. Optionally, delete the SageMaker Studio domain.
  3. Delete the CloudFormation stack to delete the VPC and Amazon EFS resources.

Conclusion

In this post, we discussed how you can fine-tune Meta Llama-like architectures using various fine-tuning strategies on your preferred compute and libraries, using custom dataset prompt templates with torchtune and SageMaker. This architecture gives you a flexible way of running fine-tuning jobs that are optimized for GPU memory and performance. We demonstrated this by fine-tuning a Meta Llama 3.1 model using P4 and G5 instances on SageMaker, and we used observability tools like Weights & Biases to monitor the loss curve as well as CPU and GPU utilization.

We encourage you to use SageMaker training capabilities and PyTorch’s torchtune library to fine-tune Meta Llama-like architectures for your specific business use cases. To stay informed about upcoming releases and new features, refer to the torchtune GitHub repo and the official Amazon SageMaker training documentation.

Special thanks to Kartikay Khandelwal (Software Engineer at Meta), Eli Uriegas (Engineering Manager at Meta), Raj Devnath (Sr. Product Manager Technical at AWS) and Arun Kumar Lokanatha (Sr. ML Solution Architect at AWS) for their support to the launch of this post.


About the Authors

Kanwaljit Khurmi is a Principal Solutions Architect at Amazon Web Services. He works with AWS customers to provide guidance and technical assistance, helping them improve the value of their solutions when using AWS. Kanwaljit specializes in helping customers with containerized and machine learning applications.

Roy Allela is a Senior AI/ML Specialist Solutions Architect at AWS. He helps AWS customers, from small startups to large enterprises, train and deploy large language models efficiently on AWS.

Matthias Reso is a Partner Engineer at PyTorch working on open source, high-performance model optimization, distributed training (FSDP), and inference. He is a co-maintainer of llama-recipes and TorchServe.

Trevor Harvey is a Principal Specialist in Generative AI at Amazon Web Services (AWS) and an AWS Certified Solutions Architect – Professional. He serves as a voting member of the PyTorch Foundation Governing Board, where he contributes to the strategic advancement of open-source deep learning frameworks. At AWS, Trevor works with customers to design and implement machine learning solutions and leads go-to-market strategies for generative AI services.

Author: Kanwaljit Khurmi