Perform batch transforms with Amazon SageMaker JumpStart Text2Text Generation large language models

Today we are excited to announce that you can now perform batch transforms with Amazon SageMaker JumpStart large language models (LLMs) for Text2Text Generation. Batch transforms are useful in situations where responses don't need to be real time, so you can run inference in bulk over large datasets. In a batch transform, a batch job takes a dataset as batch input along with a pre-trained model and outputs a prediction for each data point in the dataset. Batch transform is cost-effective because, unlike real-time hosted endpoints that use persistent hardware, batch transform clusters are torn down when the job is complete, so the hardware is only used for the duration of the batch job.

In some use cases, real-time inference requests can be grouped in small batches for batch processing to create real-time or near-real-time responses. For example, if you need to process a continuous stream of data with low latency and high throughput, invoking a real-time endpoint for each request separately would require more resources and can take longer to process all the requests because the requests are processed serially. A better approach is to group some of the requests and call the real-time endpoint in batch inference mode, which processes your requests in one forward pass of the model and returns the bulk response in real time or near-real time. The latency of the response depends on how many requests you group together and on the instance memory size, so you can tune the batch size to meet your business requirements for latency and throughput. We call this real-time batch inference because it combines the concept of batching while still providing real-time responses. With real-time batch inference, you can achieve a balance between low latency and high throughput, enabling you to process large volumes of data in a timely and efficient manner.
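To make this concrete, the following is a minimal sketch, not from the original notebook, of grouping a stream of requests into micro-batches before invoking a real-time endpoint. The endpoint name, micro-batch size, and sample stream are illustrative assumptions; the text_inputs and batch_size payload fields follow the Text2Text payload format shown later in this post:

import json

import boto3

# Illustrative micro-batching sketch. ENDPOINT_NAME and MICRO_BATCH_SIZE are
# hypothetical values for demonstration, not values from this post.
ENDPOINT_NAME = "my-text2text-endpoint"  # hypothetical endpoint
MICRO_BATCH_SIZE = 8

client = boto3.client("runtime.sagemaker")

def infer_micro_batch(batch):
    """Send a list of text inputs to the endpoint in a single invocation."""
    payload = {"text_inputs": batch, "batch_size": len(batch)}
    response = client.invoke_endpoint(
        EndpointName=ENDPOINT_NAME,
        ContentType="application/json",
        Body=json.dumps(payload).encode("utf-8"),
    )
    return json.loads(response["Body"].read())

# Group a stream of requests into micro-batches and flush any remainder
incoming_requests = ["Briefly summarize this text: ..."] * 20  # sample stream
buffer = []
for request in incoming_requests:
    buffer.append(request)
    if len(buffer) == MICRO_BATCH_SIZE:
        results = infer_micro_batch(buffer)
        buffer = []
if buffer:
    results = infer_micro_batch(buffer)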

JumpStart batch transform for Text2Text Generation models lets you pass batch hyperparameters through environment variables, which can further increase throughput and minimize latency.

JumpStart provides pretrained, open-source models for a wide range of problem types to help you get started with machine learning (ML). You can incrementally train and tune these models before deployment. JumpStart also provides solution templates that set up infrastructure for common use cases, and executable example notebooks for ML with Amazon SageMaker. You can access the pre-trained models, solution templates, and examples through the JumpStart landing page in Amazon SageMaker Studio. You can also access JumpStart models using the SageMaker Python SDK.

In this post, we demonstrate how to use the state-of-the-art pre-trained Text2Text FLAN T5 models from Hugging Face for batch transform and real-time batch inference.

Solution overview

The notebook showing batch transform of pre-trained Text2Text FLAN T5 models from Hugging Face is available in the following GitHub repository. This notebook uses data from the Hugging Face cnn_dailymail dataset for a text summarization task using the SageMaker SDK.

The following are the key steps for implementing batch transform and real-time batch inference:

  1. Set up prerequisites.
  2. Select a pre-trained model.
  3. Retrieve artifacts for the model.
  4. Specify batch transform job hyperparameters.
  5. Prepare data for the batch transform.
  6. Run the batch transform job.
  7. Evaluate the summarization using a ROUGE (Recall-Oriented Understudy for Gisting Evaluation) score.
  8. Perform real-time batch inference.

Set up prerequisites

Before you run the notebook, you must complete some initial setup steps. Let’s set up the SageMaker execution role so it has permissions to run AWS services on your behalf:

import boto3
import sagemaker
from sagemaker.session import Session

sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()
aws_region = boto3.Session().region_name
sess = sagemaker.Session()

Select a pre-trained model

We use the huggingface-text2text-flan-t5-large model as a default model. Optionally, you can retrieve the list of available Text2Text models on JumpStart and choose your preferred model. This method provides a straightforward way to select different model IDs using the same notebook. For demonstration purposes, we use the huggingface-text2text-flan-t5-large model:

model_id, model_version = (
    "huggingface-text2text-flan-t5-large",
    "*",
)
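As a sketch of the optional model-listing step, the JumpStart notebook utilities in the SageMaker Python SDK can list the available models; the filter string here is an assumption based on the task identifier embedded in the model IDs (for example, huggingface-text2text-...):

from sagemaker.jumpstart.notebook_utils import list_jumpstart_models

# Retrieve all Text2Text Generation models available in JumpStart.
# The filter value "task == text2text" is an assumption based on the
# task component of the JumpStart model IDs.
filter_value = "task == text2text"
text2text_models = list_jumpstart_models(filter=filter_value)
print(text2text_models)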

Retrieve artifacts for the model

With SageMaker, we can perform inference on the pre-trained model, even without fine-tuning it first on a new dataset. We start by retrieving the deploy_image_uri and model_uri for the pre-trained model:

inference_instance_type = "ml.p3.2xlarge"

# Retrieve the inference docker container uri. This is the base HuggingFace container image for the default model above.
deploy_image_uri = image_uris.retrieve(
region=None,
framework=None, # automatically inferred from model_id
image_scope="inference",
model_id=model_id,
model_version=model_version,
instance_type=inference_instance_type,
)

# Retrieve the model uri.
model_uri = model_uris.retrieve(
model_id=model_id, model_version=model_version, model_scope="inference"
)

#Create the SageMaker model instance
model = Model(
image_uri=deploy_image_uri,
model_data=model_uri,
role=aws_role,
predictor_cls=Predictor)

Specify batch transform job hyperparameters

You can pass any subset of hyperparameters as environment variables to the batch transform job. You can also pass these hyperparameters in a JSON payload. However, if you set environment variables for hyperparameters as the following code shows, then the advanced hyperparameters from the individual examples in the JSON Lines payload will not be used. If you want to use the hyperparameters from the payload, set the hyper_params_dict parameter to None instead.

# Specify the batch job hyperparameters here. If you want to treat each example's hyperparameters differently, pass hyper_params_dict as None.
hyper_params = {"batch_size": 4, "max_length": 50, "top_k": 50, "top_p": 0.95, "do_sample": True}
hyper_params_dict = {"HYPER_PARAMS": str(hyper_params)}

Prepare data for batch transform

Now we’re ready to load the cnn_dailymail dataset from Hugging Face:

from datasets import load_dataset

cnn_test = load_dataset('cnn_dailymail', '3.0.0', split='test')

We go over each data entry and create the input data in the required format. We create an articles.jsonl file as a test data file containing articles that need to be summarized as input payload. As we create this file, we append the prompt "Briefly summarize this text:" to each test input row. If you want to have different hyperparameters for each test input, you can append those hyperparameters as part of creating the dataset.

We create highlights.jsonl as the ground truth file containing highlights of each article stored in the test file articles.jsonl. We store both test files in an Amazon Simple Storage Service (Amazon S3) bucket. See the following code:

import json

# You can specify a prompt here
prompt = "Briefly summarize this text: "
# Provide the test data and the ground truth file names
test_data_file_name = "articles.jsonl"
test_reference_file_name = 'highlights.jsonl'

test_articles = []
test_highlights = []

# We will go over each data entry and create the data in the required input format as described above
for id, test_entry in enumerate(cnn_test):
    article = test_entry['article']
    highlights = test_entry['highlights']
    # Create a payload like this if you want to have different hyperparameters for each test input
    # payload = {"id": id, "text_inputs": f"{prompt}{article}", "max_length": 100, "temperature": 0.95}
    # Note that if you specify hyperparameters for each payload individually, you may want to ensure that hyper_params_dict is set to None instead
    payload = {"id": id, "text_inputs": f"{prompt}{article}"}
    test_articles.append(payload)
    test_highlights.append({"id": id, "highlights": highlights})

with open(test_data_file_name, "w") as outfile:
    for entry in test_articles:
        outfile.write("%s\n" % json.dumps(entry))

with open(test_reference_file_name, "w") as outfile:
    for entry in test_highlights:
        outfile.write("%s\n" % json.dumps(entry))

# Uploading the data
s3 = boto3.client("s3")
s3.upload_file(test_data_file_name, output_bucket, output_prefix + "/batch_input/articles.jsonl")

Run the batch transform job

When you start a batch transform job, SageMaker launches the compute resources needed to process the data, including CPU or GPU instances depending on the selected instance type. During the job, SageMaker automatically provisions and manages the instances, storage, and networking resources required to process the data. When the batch transform job is complete, SageMaker automatically cleans up these resources: the instances and storage used during the job are stopped and removed, freeing up resources and minimizing cost. See the following code:

# Creating the batch transformer object
batch_transformer = model.transformer(
    instance_count=1,
    instance_type=inference_instance_type,
    output_path=s3_output_data_path,
    assemble_with="Line",
    accept="text/csv",
    max_payload=1,
    env=hyper_params_dict,
)

# Making the predictions on the input data
batch_transformer.transform(
    s3_input_data_path, content_type="application/jsonlines", split_type="Line"
)

batch_transformer.wait()

The following is one example record from the articles.jsonl test file. Note that each record in this file has an ID that matches a record in the predict.jsonl file, which contains the summarized record output by the Hugging Face Text2Text model. Similarly, the ground truth file has a matching ID for each data record. The matching ID across the test file, ground truth file, and output file links input records with output records for easy interpretation of the results.

The following is the example input record provided for summarization:

{"id": 0, "text_inputs": "Briefly summarize this text: (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court's treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What's objectionable is the attempts to undermine international justice, not Palestine's decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court's decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. 
The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report."}

The following is the predicted output with summarization:

{'id': 0, 'generated_texts': ['The Palestinian Authority officially became a member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories.']}

The following is the ground truth summarization for model evaluation purposes:

{"id": 0, "highlights": "Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June .nIsrael and the United States opposed the move, which could open the door to war crimes investigations against Israelis ."}

Next, we use the ground truth and predicted outputs for model evaluation.

Evaluate the model using a ROUGE score

ROUGE, or Recall-Oriented Understudy for Gisting Evaluation, is a set of metrics and a software package used for evaluating automatic summarization and machine translation in natural language processing. The metrics compare an automatically produced summary or translation against a reference (human-produced) summary or translation or a set of references.
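For intuition, here is a minimal sketch, not part of the original notebook, of how ROUGE-1 can be computed from unigram overlap. The evaluate library used in the next code block handles this (plus stemming, aggregation, and the other ROUGE variants) for you:

from collections import Counter

def rouge1_f1(prediction, reference):
    """Toy ROUGE-1: F1 over the unigram overlap between prediction and reference."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    # Clipped count of unigrams that appear in both summaries
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

# Example: 5 of 6 unigrams are shared, so the score is 5/6
print(rouge1_f1("the cat sat on the mat", "the cat lay on the mat"))  # 0.8333...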

In the following code, we combine the predicted and original summaries by joining them on the common key id and use this to compute the ROUGE score:

import ast

import evaluate
import pandas as pd

# Downloading the predictions
s3.download_file(
    output_bucket, output_prefix + "/batch_output/" + "articles.jsonl.out", "predict.jsonl"
)

with open('predict.jsonl', 'r') as json_file:
    json_list = list(json_file)

# Creating the prediction list for the dataframe
predict_dict_list = []
for predict in json_list:
    if len(predict) > 1:
        predict_dict = ast.literal_eval(predict)
        predict_dict_req = {"id": predict_dict["id"], "prediction": predict_dict["generated_texts"][0]}
        predict_dict_list.append(predict_dict_req)

# Creating the predictions dataframe
predict_df = pd.DataFrame(predict_dict_list)

test_highlights_df = pd.DataFrame(test_highlights)

# Combining the predict dataframe with the original summarization on id to compute the ROUGE score
df_merge = test_highlights_df.merge(predict_df, on="id", how="left")

rouge = evaluate.load('rouge')
results = rouge.compute(predictions=list(df_merge["prediction"]), references=list(df_merge["highlights"]))
print(results)
{'rouge1': 0.32749078992945646, 'rouge2': 0.126038645005132, 'rougeL': 0.22764277967933363, 'rougeLsum': 0.28162915746368966}

Perform real-time batch inference

Next, we show you how to run real-time batch inference on the endpoint by providing the inputs as a list. We use the same model ID and dataset as earlier, except we take a few records from the test dataset and use them to invoke a real-time endpoint.

The following code shows how to create and deploy a real-time endpoint for real-time batch inference:

from sagemaker.utils import name_from_base

endpoint_name = name_from_base(f"jumpstart-example-{model_id}")
# Deploy the model. Note that we need to pass the Predictor class when we deploy
# the model through the Model class to be able to run inference through the SageMaker API.
model_predictor = model.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    predictor_cls=Predictor,
    endpoint_name=endpoint_name,
)

Next, we prepare our input payload. For this, we use the data that we prepared earlier, extract the first 10 test inputs, and append the hyperparameters we want to use to the text inputs. We provide this payload to the real-time invoke_endpoint call. The response payload is then returned as a list of responses. See the following code:

# Provide all the text inputs to the model as a list
text_inputs = [entry["text_inputs"] for entry in test_articles[0:10]]

# The information about the different parameters is provided above
payload = {
    "text_inputs": text_inputs,
    "max_length": 50,
    "num_return_sequences": 1,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
    "batch_size": 4,
}


def query_endpoint_with_json_payload(encoded_json, endpoint_name):
    client = boto3.client("runtime.sagemaker")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType="application/json", Body=encoded_json
    )
    return response


query_response = query_endpoint_with_json_payload(
    json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
)


def parse_response_multiple_texts(query_response):
    model_predictions = json.loads(query_response["Body"].read())
    return model_predictions


generated_text_list = parse_response_multiple_texts(query_response)
print(*generated_text_list, sep='\n')

Clean up

After you have tested the endpoint, make sure you delete the SageMaker inference endpoint and delete the model to avoid incurring charges.
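A minimal sketch of that cleanup, assuming the model_predictor created above (delete_model and delete_endpoint are methods of the SageMaker Python SDK Predictor class):

# Delete the SageMaker model and endpoint to stop incurring charges
model_predictor.delete_model()
model_predictor.delete_endpoint()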

Conclusion

In this post, we performed a batch transform to showcase the Hugging Face Text2Text Generation model for summarization tasks. Batch transform is advantageous for obtaining inferences from large datasets without requiring a persistent endpoint. We linked input records with inferences to aid in result interpretation, and we used the ROUGE score to compare the test data summarization with the model-generated summarization.

Additionally, we demonstrated real-time batch inference, where you can send a small batch of data to a real-time endpoint to achieve a balance between latency and throughput for scenarios like streaming input data. Real-time batch inference helps increase throughput for real-time requests.

Try out the batch transform with Text2Text Generation models in SageMaker today and let us know your feedback!


About the authors

Hemant Singh is a Machine Learning Engineer with experience in Amazon SageMaker JumpStart and Amazon SageMaker built-in algorithms. He got his masters from Courant Institute of Mathematical Sciences and B.Tech from IIT Delhi. He has experience in working on a diverse range of machine learning problems within the domain of natural language processing, computer vision, and time series analysis.

Rachna Chadha is a Principal Solutions Architect AI/ML in Strategic Accounts at AWS. Rachna is an optimist who believes that the ethical and responsible use of AI can improve society in future and bring economic and social prosperity. In her spare time, Rachna likes spending time with her family, hiking, and listening to music.

Dr. Ashish Khetan is a Senior Applied Scientist with Amazon SageMaker built-in algorithms and helps develop machine learning algorithms. He got his PhD from University of Illinois Urbana-Champaign. He is an active researcher in machine learning and statistical inference, and has published many papers in NeurIPS, ICML, ICLR, JMLR, ACL, and EMNLP conferences.
