Accelerating Mixtral MoE fine-tuning on Amazon SageMaker with QLoRA
Companies across various scales and industries are using large language models (LLMs) to develop generative AI applications that provide innovative experiences for customers and employees. However, building or fine-tuning these pre-trained LLMs on extensive datasets demands substantial computational resources and engineering effort. With the increase in sizes of these pre-trained LLMs, the model customization process becomes complex, time-consuming, and often prohibitively expensive for most organizations that lack the necessary infrastructure and skilled talent.
In this post, we demonstrate how you can address these challenges by using the fully managed environment of Amazon SageMaker Training jobs to fine-tune the Mixtral 8x7B model using PyTorch Fully Sharded Data Parallel (FSDP) and Quantized Low Rank Adaptation (QLoRA).
We guide you through a step-by-step implementation of model fine-tuning on the GEM/viggo dataset, employing the QLoRA fine-tuning strategy on a single ml.p4d.24xlarge worker node (providing 8 NVIDIA A100 40 GB GPUs).
Business challenge
Today’s businesses are looking to adopt a variety of LLMs to enhance business applications. Primarily, they’re looking for foundation models (FMs) that are open source (that is, with model weights that work without modification from the start) and that offer computational efficiency and versatility. Mistral’s Mixtral 8x7B model, released with open weights under the Apache 2.0 license, is one of the models that has gained popularity with large enterprises due to the high performance that it offers across various tasks. Mixtral employs a sparse mixture of experts (SMoE) architecture, selectively activating only a subset of its parameters for each input token. This architecture allows the model to use only about 13B (roughly 28%) of its 46.7B total parameters for each token during inference, making it both high performing and efficient.
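To make the sparse-activation idea concrete, the following is a minimal, pure-Python sketch of top-k expert routing. It is an illustration, not Mixtral's actual implementation: the function name route_top_k, the eight router scores, and the choice of k=2 are assumptions made for this example. The last lines check the active-parameter fraction from the two parameter counts quoted above.

```python
import math


def route_top_k(router_logits, k=2):
    """Return {expert_index: weight} for the k highest-scoring experts.

    The softmax is taken over the selected logits only, so the
    returned weights sum to 1 and all other experts stay inactive.
    """
    top = sorted(range(len(router_logits)),
                 key=lambda i: router_logits[i], reverse=True)[:k]
    peak = max(router_logits[i] for i in top)  # subtract for numerical stability
    exps = {i: math.exp(router_logits[i] - peak) for i in top}
    total = sum(exps.values())
    return {i: e / total for i, e in exps.items()}


# Hypothetical router scores for a single token over 8 experts.
logits = [0.1, 2.0, -1.0, 0.5, 1.5, -0.3, 0.0, 0.2]
weights = route_top_k(logits, k=2)
print("selected experts:", sorted(weights))

# Fraction of parameters used per token, from the counts in the text.
active_fraction = 13 / 46.7
print(f"active fraction: {active_fraction:.1%}")
```

Only the selected experts run their feed-forward computation for that token, which is why the per-token compute tracks the active parameter count rather than the total.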
These FMs work well for many use cases but lack domain-specific information, which limits their performance at certain tasks. Businesses therefore use fine-tuning strategies to adapt these large FMs to specific domains, improving performance on targeted applications. Due to the growing number of model parameters and the increasing context lengths of these modern LLMs, this process is memory intensive and requires advanced AI expertise to align and optimize the models effectively. The cost of provisioning and managing the infrastructure increases the overall cost of ownership of the end-to-end solution.
In the upcoming section, we discuss how you can cost-effectively build such a solution with advanced memory optimization techniques using Amazon SageMaker.
Solution overview
To address the memory challenges of fine-tuning LLMs such as Mixtral, we will adopt the QLoRA method. As shown in the following diagram, QLoRA freezes the original model’s weights and adds low-rank trainable parameters to the transformer layers. QLoRA further uses quantization to represent the actual model’s weights in a compact, optimized format such as 4-bit NormalFloat (NF4), effectively compressing the model and reducing its memory footprint. This enables training and fine-tuning these LLMs even on systems with limited memory while maintaining performance comparable to half-precision fine-tuning. QLoRA’s support for double quantization and paged optimizers reduces the memory footprint further by quantizing the quantization constants and effectively handling any sudden memory demands.
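As a rough, back-of-the-envelope illustration of why 4-bit storage matters, the following sketch estimates the memory needed for the base model weights alone. It assumes the 46.7B parameter count quoted earlier and ignores quantization constants, LoRA adapters, activations, and optimizer state, so real usage will be higher. The helper name weight_memory_gb is illustrative.

```python
TOTAL_PARAMS = 46.7e9  # total parameter count quoted earlier in this post


def weight_memory_gb(num_params, bits_per_param):
    """Memory needed to store the weights alone, in gigabytes (10^9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


bf16_gb = weight_memory_gb(TOTAL_PARAMS, 16)  # half-precision storage
nf4_gb = weight_memory_gb(TOTAL_PARAMS, 4)    # 4-bit storage

print(f"BF16 weights: {bf16_gb:.1f} GB")
print(f"4-bit weights: {nf4_gb:.2f} GB (before quantization constants)")
print(f"Aggregate GPU memory on 8 x 40 GB GPUs: {8 * 40} GB")
```

The 4x reduction in weight storage is what leaves headroom on the GPUs for activations, adapters, and optimizer state.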
During the forward pass computation of this architecture, the 4-bit weights get dequantized to bfloat16 (BF16) precision. The LoRA adapters, on the other hand, operate on BF16 precision data throughout. The outputs of the two paths (the frozen base layer and the LoRA adapter) are then added together element-wise to produce the final result, denoted as h.
During the backward pass of the model, the gradients are computed with respect to only the LoRA parameters, not the original base model weights. Although the dequantized original weights are used in calculations, the original 4-bit quantized weights of the base model remain unchanged.
To implement this architecture, we will use the Hugging Face Parameter-Efficient Fine-Tuning (PEFT) library, which integrates directly with bitsandbytes. This way, the QLoRA fine-tuning technique can be adopted with just a few lines of code.
QLoRA operates on a large FM. In the figure below, X denotes the input tokens of the training data, W is the existing model weights (quantized), and Wa, Wb are the segments of the adapters added by QLoRA. The original model’s weights (W) are frozen, and QLoRA adds adapters (Wa, Wb), which are low-rank trainable parameters, onto the existing transformer layer.
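The forward pass described above can be sketched in a few lines of pure Python using the same symbols (X, W, Wa, Wb, h). This is a toy illustration under stated assumptions: it uses simple uniform absmax quantization to signed 4-bit integers as a stand-in for NF4, tiny hand-picked matrices, rank 1 adapters, and it omits the LoRA scaling factor. All function names are hypothetical.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def quantize_4bit(w):
    """Uniform absmax quantization to signed 4-bit integers in [-7, 7].

    A simplified stand-in for NF4, used here only for illustration.
    """
    scale = max(abs(v) for row in w for v in row) / 7
    return [[round(v / scale) for v in row] for row in w], scale


def dequantize(q, scale):
    """Map the stored integers back to floating point for computation."""
    return [[v * scale for v in row] for row in q]


def qlora_forward(x, q, scale, wa, wb):
    """h = X * dequant(W) + (X * Wa) * Wb, added element-wise."""
    base = matmul(x, dequantize(q, scale))
    adapter = matmul(matmul(x, wa), wb)
    return [[b + a for b, a in zip(br, ar)] for br, ar in zip(base, adapter)]


X = [[1.0, 2.0, -1.0]]                      # one token, hidden size 3
W = [[0.7, -0.2, 0.1], [0.0, 0.3, -0.7], [0.5, 0.5, 0.5]]
W_q, W_scale = quantize_4bit(W)             # frozen 4-bit base weights
Wa = [[0.1], [0.2], [0.3]]                  # 3 x 1 trainable adapter
Wb = [[0.0, 0.0, 0.0]]                      # 1 x 3 trainable adapter, zero init

h = qlora_forward(X, W_q, W_scale, Wa, Wb)

# Simulate a training update: only the adapter changes, W_q does not.
Wb_trained = [[0.5, 0.0, -0.5]]
h_trained = qlora_forward(X, W_q, W_scale, Wa, Wb_trained)
print(h, h_trained)
```

With Wb initialized to zeros, the adapter path contributes nothing and h equals the base model output, so fine-tuning starts from the pre-trained behavior.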
Although QLoRA helps optimize memory during fine-tuning, we will use Amazon SageMaker Training to spin up a resilient training cluster, manage orchestration, and monitor the cluster for failures. By offloading the management and maintenance of the training cluster to SageMaker, we reduce both training time and our total cost of ownership (TCO). Using this approach, you can focus on developing and refining the model while using the fully managed training infrastructure provided by SageMaker Training.
Implementation details
We spin up the cluster by calling the SageMaker control plane through APIs, the AWS Command Line Interface (AWS CLI), or the SageMaker AWS SDK. In response, SageMaker spins up training jobs with the requested number and type of compute instances. In our example, we use one ml.p4d.24xlarge compute instance.
To take complete advantage of this multi-GPU cluster, we use the recent support for combining QLoRA with PyTorch FSDP. Although QLoRA reduces computational requirements and memory footprint, FSDP, a data/model parallelism technique, shards the model across all eight GPUs (one ml.p4d.24xlarge), making training even more efficient. Hugging Face PEFT is where the integration happens, and you can read more about it in the PEFT documentation.
QLoRA adapters are added to the linear layers in the model. The layers (for example, transformer layers, gate networks, and feed-forward networks) together form the entire model, which FSDP shards across our cluster, as shown in the following diagram (the shards appear as small blue blocks).
The following architecture diagram shows how you can use SageMaker Training to have the SageMaker Control Plane spin up a resilient training job cluster. SageMaker downloads the training image from Amazon Elastic Container Registry (Amazon ECR) and will use Amazon Simple Storage Service (Amazon S3) as an input training data source and to store training artifacts.
To put this solution into practice, complete the following prerequisites and steps.
Prerequisites
To implement the solution, you need to have the following prerequisites in place:
- Create a Hugging Face User Access Token and get access to the gated repo mistralai/Mixtral-8x7B-v0.1 on Hugging Face.
- (Optional) Create a Weights & Biases API key to access the Weights & Biases dashboard for logging and monitoring. This is recommended if you’d like to visualize model training specific metrics.
- Request a service quota at Service Quotas for 1x ml.p4d.24xlarge on Amazon SageMaker. To request a service quota increase, on the AWS Service Quotas console, navigate to AWS services, Amazon SageMaker, and choose ml.p4d.24xlarge for training job usage.
- Create an AWS Identity and Access Management (IAM) role with managed policies AmazonSageMakerFullAccess and AmazonEC2FullAccess to give required access to SageMaker to run the examples.
This role is for demonstration purposes only. You need to adjust it to your specific security requirements for production. Adhere to the principle of least privilege while defining IAM policies in production.
- (Optional) Create an Amazon SageMaker Studio domain (see Quick setup to Amazon SageMaker) to access Jupyter notebooks with the preceding role. (You can use JupyterLab in your local setup too)
- Clone the GitHub repository with the assets for this deployment. This repository consists of a notebook that references training assets.
The 15_mixtral_finetune_qlora directory contains the training scripts that you might need to deploy this sample.
Next, we will run the finetune-mixtral.ipynb notebook to fine-tune the Mixtral 8x7B model using QLoRA on SageMaker. Check out the notebook for more details on each step. In the next section, we walk through the key components of the fine-tuning execution.
Solution walkthrough
To perform the solution, follow the steps in the next sections.
Step 1: Set up required libraries
Install the relevant Hugging Face and SageMaker libraries:
Step 2: Load dataset
In this example, we use the GEM/viggo dataset from Hugging Face. This is a data-to-text generation dataset in the video game domain. The dataset is clean and organized with about 5,000 data points, and the responses are more conversational than information seeking. This type of dataset is ideal for extracting meaningful information from customer reviews. For example, an ecommerce application such as Amazon.com could use a similarly formatted dataset for fine-tuning a model for natural language processing (NLP) analysis to gauge interest in products sold. The results can be used for recommendation engines. Thus, this dataset is a good candidate for fine-tuning LLMs. To learn more about the viggo dataset, check out this research paper.
Load the dataset and convert it to the required prompt structure. The prompt is constructed with the following elements:
- Target sentence – Think of this as the final review. In the dataset, this is target.
- Meaning representation – Think of this as a deconstructed review, broken down by attributes such as inform, request, or give_opinion. In the dataset, this is meaning_representation.
Running the following cell gives us the train_set and test_set (training split and testing split, respectively) with structured prompts. We use the Python map function to structure the dataset splits according to our prompt.
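The prompt construction described above might look like the following sketch. The prompt wording, the function names build_prompt and load_and_format, and the sample record are assumptions made for illustration. Only the two field names, target and meaning_representation, come from the dataset description above. The loading helper requires the Hugging Face datasets package and network access, and depending on your datasets version it may need extra arguments.

```python
def build_prompt(sample):
    """Turn one record into a single training string under the key 'text'."""
    return {
        "text": (
            "Given a target sentence, construct the underlying "
            "meaning representation.\n\n"
            f"### Target sentence:\n{sample['target']}\n\n"
            f"### Meaning representation:\n{sample['meaning_representation']}"
        )
    }


# Hypothetical record (not taken from the dataset) with the two fields.
example = {
    "target": "Have you played any good racing games lately?",
    "meaning_representation": "request(genre[racing])",
}
print(build_prompt(example)["text"])


def load_and_format():
    """Load GEM/viggo and apply build_prompt to the train and test splits."""
    from datasets import load_dataset  # imported lazily: optional dependency

    dataset = load_dataset("GEM/viggo")
    train_set = dataset["train"].map(build_prompt)
    test_set = dataset["test"].map(build_prompt)
    return train_set, test_set
```

Keeping the formatting logic in a plain function makes it easy to reuse the exact same prompt at inference time.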
Upload the dataset to Amazon S3. This step is crucial because the dataset stored in Amazon S3 will serve as the input data channel for the SageMaker training cluster. SageMaker will efficiently manage the process of distributing this data across the training cluster, allowing each node to access the necessary information for model training.
We analyze the distribution of prompt tokens to determine the maximum sequence length required for training our model in the upcoming steps.
The following graph shows the prompt tokens plotted. The x-axis is the length of the prompts, and the y-axis is the number of times that length occurs in the training dataset (frequency). We use this to determine the maximum sequence length and pad the rest of the data points accordingly. The maximum number of words in our example is 173.
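The length analysis can be sketched as follows. This is a simplified illustration: it uses whitespace splitting as a stand-in tokenizer and a handful of made-up prompts, and the function name length_histogram is hypothetical. To count real tokens, you could pass a function that calls the model tokenizer instead.

```python
from collections import Counter


def length_histogram(prompts, tokenize=str.split):
    """Count how many prompts have each token length.

    `tokenize` defaults to whitespace splitting, which is only a rough
    proxy for a model tokenizer.
    """
    return Counter(len(tokenize(p)) for p in prompts)


# Hypothetical prompts standing in for the formatted training split.
prompts = ["a b c", "a b", "a b c", "a b c d e"]
hist = length_histogram(prompts)

# The longest observed prompt bounds the sequence length used for padding.
max_seq_length = max(hist)
print(sorted(hist.items()), max_seq_length)
```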
Step 3: Configure the parameters for SFTTrainer for the fine-tuning task
We use TrlParser to parse the hyperparameters in a YAML file that are required to configure the SFTTrainer API for fine-tuning the model. This approach offers flexibility because we can also override the arguments specified in the config file by explicitly passing them through the command line interface.
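The precedence rule described above (values from a config file act as defaults, and command line arguments override them) can be illustrated with the standard library alone. This sketch is not TrlParser's implementation: it uses argparse, a plain dictionary stands in for the parsed YAML file, and the hyperparameter names and values are hypothetical.

```python
import argparse


def parse_with_config(config, cli_args):
    """Parse cli_args, falling back to values from config when a flag is absent."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--learning_rate", type=float)
    parser.add_argument("--num_train_epochs", type=int)
    parser.add_argument("--max_seq_length", type=int)
    # Config values become the defaults; explicit flags still win.
    parser.set_defaults(**config)
    return vars(parser.parse_args(cli_args))


# Stand-in for values loaded from a YAML config file (hypothetical).
config = {"learning_rate": 2e-4, "num_train_epochs": 1, "max_seq_length": 512}

# Only the epoch count is overridden on the command line.
args = parse_with_config(config, ["--num_train_epochs", "3"])
print(args)
```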
Step 4: Review the launch script
You are now prepared to fine-tune the model using a combination of PyTorch FSDP and QLoRA. We’ve prepared a script called launch_fsdp_qlora.py that will perform the tasks mentioned in the following steps. The following is a quick review of the key points in this script before launching the training job.
- Load the dataset from a JSON file located at the specified path, using the load_dataset function to prepare it for model training.
- Prepare the tokenizer and the model.
We employ the BitsAndBytes library to configure 4-bit quantization settings for our model, enabling memory-efficient loading and computation.
By setting parameters such as load_in_4bit and bnb_4bit_use_double_quant to True, we enable a dramatic reduction in model size without significant loss in performance. The nf4 quantization type, coupled with bfloat16 compute and storage data types, allows for nuanced control over the quantization process, striking an optimal balance between model compression and accuracy preservation. This configuration enables the deployment of massive models on resource-constrained hardware, making advanced AI more accessible and practical for a wide range of applications.
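The quantization settings named above could be expressed roughly as follows. This is a sketch rather than the exact code in the launch script: the names QUANT_KWARGS and build_quantization_config are hypothetical, and the helper assumes the transformers package (with PyTorch and bitsandbytes) is installed. The data types are given as strings, which BitsAndBytesConfig resolves to the corresponding torch data types.

```python
# Settings taken from the description above.
QUANT_KWARGS = {
    "load_in_4bit": True,                  # store base weights in 4 bits
    "bnb_4bit_use_double_quant": True,     # also quantize the quantization constants
    "bnb_4bit_quant_type": "nf4",          # 4-bit NormalFloat
    "bnb_4bit_compute_dtype": "bfloat16",  # dequantize to BF16 for computation
    "bnb_4bit_quant_storage": "bfloat16",  # storage data type for the packed weights
}


def build_quantization_config():
    """Create the quantization config; needs transformers and torch installed."""
    from transformers import BitsAndBytesConfig  # imported lazily: optional dependency

    return BitsAndBytesConfig(**QUANT_KWARGS)


print(QUANT_KWARGS)
```

The resulting object is typically passed as the quantization_config argument when loading the base model with from_pretrained.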
- Initiate the training process using SFTTrainer from the Transformer Reinforcement Learning (TRL) library to fine-tune the model. The SFTTrainer simplifies the process of supervised fine-tuning for LLMs, which makes it efficient to adapt pre-trained models to specific tasks or domains.
We use the LoraConfig class from Hugging Face’s PEFT library to configure and add LoRA parameters (also called “adapters”) to the model.
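A LoRA configuration along these lines might look like the following sketch. The rank, alpha, and dropout values are illustrative assumptions, not the values used in the launch script, and the names LORA_KWARGS, build_lora_config, and lora_param_count are hypothetical. The small helper shows why adapters are cheap: a rank r adapter on a linear layer adds r * (d_in + d_out) trainable parameters.

```python
# Illustrative settings; r, lora_alpha, and lora_dropout are assumptions.
LORA_KWARGS = {
    "r": 8,                          # adapter rank
    "lora_alpha": 16,                # scaling numerator (scale = lora_alpha / r)
    "lora_dropout": 0.05,
    "target_modules": "all-linear",  # attach adapters to the linear layers
    "bias": "none",
    "task_type": "CAUSAL_LM",
}


def build_lora_config():
    """Create the adapter config; needs the peft package installed."""
    from peft import LoraConfig  # imported lazily: optional dependency

    return LoraConfig(**LORA_KWARGS)


def lora_param_count(d_in, d_out, r):
    """Trainable parameters added by one rank-r adapter on a d_in x d_out layer."""
    return r * (d_in + d_out)


# Example: one hypothetical 4096 x 4096 projection with rank 8.
added = lora_param_count(4096, 4096, LORA_KWARGS["r"])
frozen = 4096 * 4096
print(f"{added} trainable vs {frozen} frozen ({added / frozen:.2%})")
```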
Step 5: Fine-tune your model
To fine-tune your model, follow the steps in the next sections.
Launch the training job
You are now ready to launch the training. We use the SageMaker Training estimator, which uses torchrun to initiate distributed training.
The SageMaker estimator simplifies the training process by automating several key tasks in this example:
- The SageMaker estimator spins up a training cluster of one ml.p4d.24xlarge instance. SageMaker handles the setup and management of these compute instances, which reduces your TCO.
- This estimator also uses one of the pre-built containers managed by SageMaker, PyTorch, which includes an optimized compiled version of the PyTorch framework and its required dependencies and GPU-specific libraries for accelerated computations.
The training process generates trained adapters that will be saved in a default S3 bucket named sagemaker-<region name>-<account_id> for this job.
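A training job launch of the kind described above can be sketched with the SageMaker Python SDK as follows. Treat this as an outline under stated assumptions rather than the notebook's exact code: the source directory, framework and Python versions, and the channel name train are assumptions, and the names ESTIMATOR_KWARGS, launch_training, and default_bucket_name are hypothetical. Running launch_training requires the sagemaker package, AWS credentials, and the service quota from the prerequisites.

```python
ESTIMATOR_KWARGS = {
    "entry_point": "launch_fsdp_qlora.py",
    "source_dir": "./scripts",          # assumption: location of the training scripts
    "instance_type": "ml.p4d.24xlarge",
    "instance_count": 1,
    "framework_version": "2.2.0",       # assumption: PyTorch container version
    "py_version": "py310",              # assumption
    # Ask SageMaker to start the script with torchrun on every GPU.
    "distribution": {"torch_distributed": {"enabled": True}},
}


def default_bucket_name(region, account_id):
    """Default bucket naming pattern described above."""
    return f"sagemaker-{region}-{account_id}"


def launch_training(role_arn, train_s3_uri, hyperparameters=None):
    """Create the PyTorch estimator and start the training job."""
    from sagemaker.pytorch import PyTorch  # imported lazily: optional dependency

    estimator = PyTorch(
        role=role_arn,
        hyperparameters=hyperparameters or {},
        **ESTIMATOR_KWARGS,
    )
    # The dataset uploaded to Amazon S3 becomes an input data channel.
    estimator.fit({"train": train_s3_uri})
    return estimator


print(default_bucket_name("us-east-1", "123456789012"))
```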
Monitor your training run
You can monitor training metrics, such as loss and learning rate, for your training run through the Weights & Biases dashboard. The following figures show the results of the training run, where we track GPU utilization and GPU memory utilization.
The example is optimized to use GPU memory to its maximum capacity. Note that increasing the batch size any further will lead to CUDA Out of Memory errors.
The following graph shows the GPU memory utilization (for all eight GPUs) during the training process. You can also observe the GPU memory utilization for any given point in time.
The following graph shows the GPU compute utilization (for all eight GPUs) during the training process. You can also observe the GPU compute utilization for any given point in time.
Step 6: Merge the trained adapter with the base model for inference
Merge the trained LoRA adapter with the base model. After the merge is complete, run inference to find the results. Specifically, look at how the new fine-tuned and merged model performs compared to the original unmodified Mixtral-8x7b model. The example does both the adapter merge and the inference in the same launch script, merge_model_adapter.py.
Before launching the training job, review the key components of the merge script:
Use the Hugging Face Transformers library. Specifically, use AutoModelForCausalLM to load a PEFT model from a specified Hugging Face model directory (mistralai/Mixtral-8x7B-v0.1). We load the model with low CPU memory utilization (low_cpu_mem_usage=True) to reduce peak CPU memory while the weights are loaded, and we’ve also used automatic device mapping (device_map="auto") while offloading the model to a designated folder to manage resource constraints.
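The merge step can be sketched as follows. This outline swaps in AutoPeftModelForCausalLM from the PEFT library, which loads the base model and the saved adapter together, so it may differ from what merge_model_adapter.py does. The offload folder path and the names LOAD_KWARGS and merge_adapter are hypothetical, and the helper assumes torch, transformers, and peft are installed with enough memory available for the full model.

```python
LOAD_KWARGS = {
    "low_cpu_mem_usage": True,        # limit peak CPU memory while loading
    "device_map": "auto",             # place layers on available devices
    "offload_folder": "/tmp/offload", # hypothetical path for offloaded weights
}


def merge_adapter(adapter_dir, output_dir):
    """Load base model plus adapter, fold the adapter in, and save the result."""
    import torch                                # imported lazily: optional dependency
    from peft import AutoPeftModelForCausalLM   # imported lazily: optional dependency

    model = AutoPeftModelForCausalLM.from_pretrained(
        adapter_dir,
        torch_dtype=torch.bfloat16,
        **LOAD_KWARGS,
    )
    # Fold the low-rank updates into the base weights and drop the adapter wrappers.
    merged = model.merge_and_unload()
    merged.save_pretrained(output_dir, safe_serialization=True)
    return merged


print(LOAD_KWARGS)
```

After merging, the saved model no longer depends on PEFT at inference time and can be loaded like any other causal language model checkpoint.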
After the model is merged, send inference requests to generate responses.
Step 7: Launch the SageMaker training job to merge the adapter
Run the following script as part of the SageMaker training job.
First, explore the adapters that were saved as part of the training run.
Create and run the PyTorch estimator to configure the training job.
Here’s the target sentence (key prompt) to generate model inference results:
Ground truth inference (data label):
Original model inference (that is, meaning representation):
Fine-tuned model inference result (that is, meaning representation):
The preceding results compare the inference results of the fine-tuned model against both the ground truth and the inference results of the original unmodified Mixtral 8x7B model. You can observe that the fine-tuned model provides more details and better representation of the meaning than the base model. Run systematic evaluation to quantify the fine-tuned model’s improvements for your production workloads.
Clean up
To clean up your resources to avoid incurring any more charges, follow these steps:
- Delete any unused SageMaker Studio resources.
- (Optional) Delete the SageMaker Studio domain.
- Verify that your training job isn’t running anymore. To do so, on your SageMaker console, choose Training and check Training jobs.
To learn more about cleaning up your provisioned resources, check out Clean up.
Conclusion
In this post, we provided you with a step-by-step guide to fine-tune the Mixtral 8x7B MoE model with QLoRA. We used SageMaker Training jobs together with the Hugging Face PEFT package for QLoRA and bitsandbytes for quantization to perform the fine-tuning task. The fine-tuning was conducted using the quantized model loaded on a single compute instance, which eliminates the need for a larger cluster. As observed, the model performance improved with just 50 epochs.
To learn more about Mistral on AWS and to find more examples, check out the mistral-on-aws GitHub repository. To get started, check out the notebook on the mixtral_finetune_qlora GitHub repository. To learn more about generative AI on AWS, check out Generative AI on AWS, Amazon Bedrock, and Amazon SageMaker.
About the Authors
Aman Shanbhag is an Associate Specialist Solutions Architect on the ML Frameworks team at Amazon Web Services, where he helps customers and partners with deploying ML training and inference solutions at scale. Before joining AWS, Aman graduated from Rice University with degrees in computer science, mathematics, and entrepreneurship.
Kanwaljit Khurmi is an AI/ML Principal Solutions Architect at Amazon Web Services. He works with AWS product teams, engineering, and customers to provide guidance and technical assistance for improving the value of their hybrid ML solutions when using AWS. Kanwaljit specializes in helping customers with containerized and machine learning applications.
Nishant Karve is a Sr. Solutions Architect aligned with the healthcare and life sciences (HCLS) domain. He collaborates with large HCLS customers for their generative AI initiatives and guides them from ideation to production.
Author: Aman Shanbhag