Improve prediction quality in custom classification models with Amazon Comprehend
Artificial intelligence (AI) and machine learning (ML) have seen widespread adoption across enterprise and government organizations. Processing unstructured data has become easier with the advancements in natural language processing (NLP) and user-friendly AI/ML services like Amazon Textract, Amazon Transcribe, and Amazon Comprehend. Organizations have started to use AI/ML services like Amazon Comprehend to build classification models with their unstructured data to get deep insights that they didn’t have before. Although you can use pre-trained models with minimal effort, without proper data curation and model tuning, you can’t realize the full benefits of AI/ML models.
In this post, we explain how to build and optimize a custom classification model using Amazon Comprehend. We demonstrate this by using Amazon Comprehend custom classification to build a multi-label model, and we provide guidelines on how to prepare the training dataset and tune the model to meet performance metrics such as accuracy, precision, recall, and F1 score. We use Amazon Comprehend model training output artifacts, such as a confusion matrix, to tune model performance and to guide you in improving your training data.
Solution overview
This solution presents an approach to building an optimized custom classification model using Amazon Comprehend. We go through several steps, including data preparation, model creation, model performance metric analysis, and optimizing inference based on our analysis. We use an Amazon SageMaker notebook and the AWS Management Console to complete some of these steps.
We also go through best practices and optimization techniques during data preparation, model building, and model tuning.
Prerequisites
If you don’t have a SageMaker notebook instance, you can create one. For instructions, refer to Create an Amazon SageMaker Notebook Instance.
Prepare the data
For this analysis, we use the Toxic Comment Classification dataset from Kaggle. This dataset contains 6 labels and 158,571 data points. However, fewer than 10% of the data points are positive examples for any given label, and two of the labels have fewer than 1%.
We convert the Kaggle dataset to the Amazon Comprehend two-column CSV format, with multiple labels separated by a pipe (|) delimiter. Amazon Comprehend expects at least one label for each data point, but several data points in this dataset don’t fall under any of the provided labels. We therefore create a new label called clean and assign it to every data point that isn’t toxic. Finally, we split the curated dataset into training and test datasets using an 80/20 split per label.
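The conversion described above can be sketched with only the Python standard library. This is a minimal illustration, not the Data-Preparation notebook’s code: it assumes the Kaggle train.csv header uses the columns comment_text, toxic, severe_toxic, obscene, threat, insult, and identity_hate with 0/1 values, writes the output without a header row, and (as its own choice) collapses whitespace so each comment stays on one line.

```python
import csv
import io

# Label columns as they appear in the Kaggle train.csv header (assumed).
TOXIC_LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
CLEAN_LABEL = "clean"  # synthetic label for rows with no toxic label


def to_comprehend_rows(kaggle_rows):
    """Yield (labels, text) pairs for Comprehend's two-column multi-label CSV."""
    for row in kaggle_rows:
        labels = [name for name in TOXIC_LABELS if row[name] == "1"]
        if not labels:
            labels = [CLEAN_LABEL]  # every data point needs at least one label
        # Collapse newlines and repeated whitespace so each comment stays on one line.
        text = " ".join(row["comment_text"].split())
        yield "|".join(labels), text


def convert(kaggle_csv_text):
    """Convert Kaggle CSV text to Comprehend CSV text (labels first, no header)."""
    reader = csv.DictReader(io.StringIO(kaggle_csv_text))
    out = io.StringIO()
    writer = csv.writer(out, lineterminator="\n")  # quotes fields containing commas
    for labels, text in to_comprehend_rows(reader):
        writer.writerow([labels, text])
    return out.getvalue()
```

A row flagged as both toxic and insult becomes `toxic|insult,"..."`, and a row with no flags becomes `clean,...`.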
We use the Data-Preparation notebook. The following steps use the Kaggle dataset to prepare the data for our model.
- On the SageMaker console, choose Notebook instances in the navigation pane.
- Select the notebook instance you have configured and choose Open Jupyter.
- On the New menu, choose Terminal.
- Run the following commands in the terminal to download the required artifacts for this post:
- Close the terminal window.
You should see three notebooks and the train.csv file.
- Choose the notebook Data-Preparation.ipynb.
- Run all the steps in the notebook.
These steps turn the raw Kaggle dataset into curated training and test datasets. The curated datasets are stored on the notebook instance and in Amazon Simple Storage Service (Amazon S3).
Consider the following data preparation guidelines when dealing with large-scale multi-label datasets:
- Datasets must have a minimum of 10 samples per label.
- Amazon Comprehend accepts a maximum of 100 labels. This is a soft limit that can be increased.
- Ensure the dataset file is correctly formatted with the proper delimiter. Incorrect delimiters can introduce blank labels.
- All the data points must have labels.
- Training and test datasets should have a balanced distribution of data for each label. Don’t split the data purely at random, because a random split might introduce bias into the training and test datasets.
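Two of these guidelines, a per-label split and the label checks, can be sketched in Python as follows. The sketch works on (labels, text) pairs with pipe-joined labels; grouping rows by their exact label combination is one simple stratification strategy assumed here, not necessarily what the notebook does, and the function names are hypothetical.

```python
import random
from collections import Counter, defaultdict


def split_by_label_combination(rows, test_fraction=0.2, seed=42):
    """Split (labels, text) rows so each label combination is divided ~80/20."""
    groups = defaultdict(list)
    for labels, text in rows:
        groups[labels].append((labels, text))

    rng = random.Random(seed)  # fixed seed for a reproducible split
    train, test = [], []
    for key in sorted(groups):
        members = groups[key][:]
        rng.shuffle(members)  # shuffle within a group, never across groups
        n_test = round(len(members) * test_fraction)
        test.extend(members[:n_test])
        train.extend(members[n_test:])
    return train, test


def check_labels(rows, min_per_label=10, max_labels=100):
    """Return per-label counts plus a list of guideline violations."""
    counts = Counter()
    problems = []
    for labels, _text in rows:
        names = labels.split("|")
        if any(not name.strip() for name in names):
            # Usually caused by a wrong or doubled delimiter, or a missing label.
            problems.append(f"blank label in {labels!r}")
        counts.update(name for name in names if name.strip())
    for name, count in sorted(counts.items()):
        if count < min_per_label:
            problems.append(f"label {name!r} has only {count} samples")
    if len(counts) > max_labels:
        problems.append(f"{len(counts)} labels exceeds the limit of {max_labels}")
    return counts, problems
```

Running check_labels on both the training and test outputs before uploading them catches blank labels and thin labels early.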
Build a custom classification model
We use the curated training and test datasets we created during the data preparation step to build our model. The following steps create an Amazon Comprehend multi-label custom classification model:
- On the Amazon Comprehend console, choose Custom classification in the navigation pane.
- Choose Create new model.
- For Model name, enter toxic-classification-model.
- For Version name, enter 1.
- For Annotation and data format, choose Using Multi-label mode.
- For Training dataset, enter the location of the curated training dataset on Amazon S3.
- Choose Customer provided test dataset and enter the location of the curated test data on Amazon S3.
- For Output data, enter the Amazon S3 location.
- For IAM role, select Create an IAM role and specify comprehend-blog as the name suffix.
- Choose Create to start the custom classification model training and model creation.
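If you prefer to script the console steps, the following Python sketch shows roughly equivalent boto3 calls. It assumes boto3 is installed and that you already have an IAM role Amazon Comprehend can assume (the console’s Create an IAM role option has no direct equivalent here); the S3 URIs and role ARN are placeholders. The client is passed in so the functions can be exercised without AWS access.

```python
import time

TERMINAL_STATUSES = ("TRAINED", "TRAINED_WITH_WARNING", "IN_ERROR", "STOPPED")


def build_classifier_request(train_s3_uri, test_s3_uri, output_s3_uri, role_arn,
                             name="toxic-classification-model", version="1"):
    """Keyword arguments for Comprehend CreateDocumentClassifier (multi-label)."""
    return {
        "DocumentClassifierName": name,
        "VersionName": version,
        "DataAccessRoleArn": role_arn,  # must allow Comprehend to read/write these S3 locations
        "InputDataConfig": {
            "DataFormat": "COMPREHEND_CSV",
            "S3Uri": train_s3_uri,
            "TestS3Uri": test_s3_uri,  # customer-provided test dataset
            "LabelDelimiter": "|",     # matches the pipe-joined labels
        },
        "OutputDataConfig": {"S3Uri": output_s3_uri},
        "LanguageCode": "en",
        "Mode": "MULTI_LABEL",
    }


def create_and_wait(client, request, poll_seconds=60, sleep=time.sleep):
    """Start training, then poll until the classifier reaches a terminal status."""
    arn = client.create_document_classifier(**request)["DocumentClassifierArn"]
    while True:
        props = client.describe_document_classifier(DocumentClassifierArn=arn)[
            "DocumentClassifierProperties"]
        if props["Status"] in TERMINAL_STATUSES:
            return arn, props["Status"]
        sleep(poll_seconds)
```

With credentials configured, you might call `create_and_wait(boto3.client("comprehend"), build_classifier_request(...))` and pass your own bucket paths and role ARN. Training can take a while, so a generous polling interval keeps the number of describe calls low.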
The following screenshot shows the custom classification model details on the Amazon Comprehend console.
Tune for model performance
The following screenshot shows the model performance metrics. It includes key metrics like precision, recall, F1 score, accuracy, and more.
After the model is trained, Amazon Comprehend generates an output.tar.gz file that contains the labels from the dataset as well as a confusion matrix for each label. To further tune the model’s prediction performance, you need to understand the prediction probabilities your model assigns to each class. To do this, create an analysis job that reports the scores Amazon Comprehend assigned to each data point.
Complete the following steps to create an analysis job:
- On the Amazon Comprehend console, choose Analysis jobs in the navigation pane.
- Choose Create job.
- For Name, enter toxic_train_data_analysis_job.
- For Analysis type, choose Custom classification.
- For Classification models and flywheels, specify toxic-classification-model.
- For Version, specify 1.
- For Input data S3 location, enter the location of the curated training data file.
- For Input format, choose One document per line.
- For Output data S3 location, enter the location.
- For Access Permissions, select Use an existing IAM Role and pick the role created previously.
- Choose Create job to start the analysis job.
- Choose the analysis job to view its details. Note the job ID under Job details; you use it in a later step.
Repeat these steps to start an analysis job for the curated test data. We use the prediction outputs from both analysis jobs to learn about our model’s prediction probabilities. Note the job IDs of the training and test analysis jobs.
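The two analysis jobs can also be started from Python. This is a sketch assuming a boto3 Comprehend client is passed in; the classifier ARN, S3 URIs, and role ARN are placeholders, and the returned job IDs are the ones the threshold notebook needs.

```python
def start_analysis_job(client, job_name, classifier_arn, input_s3_uri,
                       output_s3_uri, role_arn):
    """Start an asynchronous custom classification (analysis) job; return its JobId."""
    response = client.start_document_classification_job(
        JobName=job_name,
        DocumentClassifierArn=classifier_arn,
        InputDataConfig={
            "S3Uri": input_s3_uri,
            "InputFormat": "ONE_DOC_PER_LINE",  # "One document per line" in the console
        },
        OutputDataConfig={"S3Uri": output_s3_uri},
        DataAccessRoleArn=role_arn,
    )
    return response["JobId"]


def start_train_and_test_jobs(client, classifier_arn, train_uri, test_uri,
                              output_uri, role_arn):
    """Start one job per curated dataset and return {"train": id, "test": id}."""
    return {
        split: start_analysis_job(client, f"toxic_{split}_data_analysis_job",
                                  classifier_arn, uri, output_uri, role_arn)
        for split, uri in (("train", train_uri), ("test", test_uri))
    }
```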
We use the Model-Threshold-Analysis.ipynb notebook to evaluate the outputs at all possible thresholds, scoring them against the prediction probabilities with scikit-learn’s precision_recall_curve function. We also compute the F1 score at each threshold.
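The notebook relies on scikit-learn, but the idea behind the sweep fits in a few lines of plain Python. This dependency-free sketch (hypothetical function names) treats a document as positive when its score is at least the threshold, walks the distinct scores from high to low, and records precision, recall, and F1 at each one, with F1 = 2PR / (P + R).

```python
def f1_by_threshold(y_true, y_score):
    """Return [(threshold, precision, recall, f1)] for every distinct score, high to low."""
    total_pos = sum(1 for y in y_true if y)
    if total_pos == 0:
        raise ValueError("need at least one positive example")
    pairs = sorted(zip(y_score, y_true), key=lambda p: p[0], reverse=True)
    curve = []
    tp = fp = 0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Every example tied at this score becomes positive at the same threshold.
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1]:
                tp += 1
            else:
                fp += 1
            i += 1
        precision = tp / (tp + fp)
        recall = tp / total_pos
        f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
        curve.append((threshold, precision, recall, f1))
    return curve


def best_threshold(y_true, y_score):
    """Threshold with the highest F1 score (ties go to the higher threshold)."""
    threshold, _p, _r, f1 = max(f1_by_threshold(y_true, y_score), key=lambda row: row[3])
    return threshold, f1
```

For example, with labels `[1, 0, 1, 0]` and scores `[0.9, 0.8, 0.4, 0.1]`, the best threshold is 0.4 with an F1 score of 0.8. Calling best_threshold once per label gives one tuned threshold per label.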
The Model-Threshold-Analysis notebook takes the Amazon Comprehend analysis job IDs as input; you can get them from the Amazon Comprehend console. Run all the steps in the notebook to observe the thresholds for all the classes.
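To feed a sweep like the one above, the per-label scores have to be read from each job’s output. This sketch assumes the job’s output.tar.gz contains a predictions.jsonl file in which each line is a JSON object with File, Line, and a Labels list of Name/Score entries; verify this against your own job output. It also assumes a single input file, so line numbers are unique. Downloading the archive (for example with boto3, after looking up the output location with describe_document_classification_job) is left out so the parsing stays testable.

```python
import io
import json
import tarfile
from collections import defaultdict


def read_label_scores(tar_bytes, member_name="predictions.jsonl"):
    """Return {label: {line_index: score}} from an analysis job's output.tar.gz."""
    scores = defaultdict(dict)
    with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode="r:gz") as tar:
        for raw in tar.extractfile(member_name):
            raw = raw.strip()
            if not raw:
                continue
            record = json.loads(raw)
            line_index = int(record["Line"])  # assumes one input file
            for label in record.get("Labels", []):
                scores[label["Name"]][line_index] = label["Score"]
    return dict(scores)


def score_vector(scores, label, n_docs):
    """Scores for one label in line order; a missing score counts as 0.0."""
    per_line = scores.get(label, {})
    return [per_line.get(i, 0.0) for i in range(n_docs)]
```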
Notice how precision goes up as the threshold increases, while recall goes down. To find the balance between the two, we use the F1 score and look for the visible peaks in its curve. Each peak in the F1 score corresponds to a threshold that can improve the model’s performance. Notice that for most labels the peak falls around a threshold of 0.5; the exception is the threat label, whose peak is at a threshold of around 0.04.
We can then apply these thresholds to the specific labels that underperform with the default 0.5 threshold. With the optimized thresholds, the model’s result on the test data for the threat label improves from 0.00 to 0.24. For each label, we use the threshold that gives the maximum F1 score as the cutoff between positive and negative, instead of a single common cutoff (a standard value such as > 0.7) for all labels.
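Applying per-label thresholds at prediction time is a lookup with a fallback. In this sketch, labels without a tuned threshold keep the default 0.5, the 0.04 value for threat is the example from the analysis above, and predict_labels is a hypothetical helper name.

```python
DEFAULT_THRESHOLD = 0.5

# Tuned from the F1 analysis; every other label falls back to the default.
TUNED_THRESHOLDS = {"threat": 0.04}


def predict_labels(label_scores, thresholds, default=DEFAULT_THRESHOLD):
    """Return the labels whose score meets that label's own threshold."""
    return sorted(
        name for name, score in label_scores.items()
        if score >= thresholds.get(name, default)
    )
```

With scores `{"toxic": 0.62, "threat": 0.05, "insult": 0.31}`, the tuned thresholds predict threat and toxic, while the default 0.5 alone would predict only toxic.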
Handling underrepresented classes
Another approach that’s effective for an imbalanced dataset is oversampling. Oversampling the underrepresented class means the model sees that class more often, which emphasizes the importance of those samples. We use the Oversampling-underrepresented.ipynb notebook to optimize the datasets.
For this dataset, we tested how the model’s performance on the evaluation dataset changes as we provide more samples. We use the oversampling technique to increase the occurrence of underrepresented classes to improve the performance.
In this case, we tested with 10, 25, 50, 100, 200, and 500 positive examples. Notice that even though we are repeating data points, emphasizing the underrepresented class this way improves the model’s performance.
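Random oversampling with replacement can be sketched as follows. This is an illustration, not the Oversampling-underrepresented notebook’s code: it works on (labels, text) pairs with pipe-joined labels, and oversample_label is a hypothetical name. It should be applied to the training split only, so that repeated examples never end up in the test data.

```python
import random


def oversample_label(rows, label, target_positives, seed=0):
    """Duplicate random positive rows of `label` until it has target_positives."""
    positives = [row for row in rows if label in row[0].split("|")]
    if not positives:
        raise ValueError(f"no positive examples for {label!r}")
    shortfall = target_positives - len(positives)
    if shortfall <= 0:
        return list(rows)  # already has enough positives
    rng = random.Random(seed)  # fixed seed for reproducible experiments
    # Sampling with replacement repeats existing examples; no new text is created.
    return list(rows) + rng.choices(positives, k=shortfall)
```

To reproduce a comparison like the one above, you could build one training set per target (10, 25, 50, 100, 200, 500) and train a model version on each.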
Cost
With Amazon Comprehend, you pay as you go based on the number of text characters processed. Refer to Amazon Comprehend Pricing for actual costs.
Clean up
When you’re finished experimenting with this solution, delete all the resources deployed in this example to avoid ongoing costs in your account.
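Part of the cleanup can be scripted. This sketch assumes boto3 Comprehend and S3 clients are passed in, and the classifier ARN, bucket, and prefix are placeholders for the resources you created. It deletes the custom classifier and the objects under one S3 prefix; the SageMaker notebook instance still needs to be stopped and deleted separately.

```python
def clean_up(comprehend, s3, classifier_arn, bucket, prefix):
    """Delete the custom classifier and every S3 object under `prefix`.

    Returns the number of S3 objects deleted.
    """
    if not prefix:
        raise ValueError("refusing to delete an entire bucket; pass a prefix")
    comprehend.delete_document_classifier(DocumentClassifierArn=classifier_arn)

    deleted = 0
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # Each page holds at most 1,000 keys, which is also the DeleteObjects limit.
        keys = [{"Key": obj["Key"]} for obj in page.get("Contents", [])]
        if keys:
            s3.delete_objects(Bucket=bucket, Delete={"Objects": keys})
            deleted += len(keys)
    return deleted
```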
Conclusion
In this post, we provided best practices and guidance on data preparation, model tuning using prediction probabilities, and techniques for handling underrepresented data classes. You can use these best practices and techniques to improve the performance metrics of your Amazon Comprehend custom classification model.
For more information about Amazon Comprehend, visit Amazon Comprehend developer resources to find video resources and blog posts, and refer to the Amazon Comprehend FAQs.
About the Authors
Sathya Balakrishnan is a Sr. Customer Delivery Architect in the Professional Services team at AWS, specializing in data and ML solutions. He works with US federal financial clients. He is passionate about building pragmatic solutions to solve customers’ business problems. In his spare time, he enjoys watching movies and hiking with his family.
Prince Mallari is an NLP Data Scientist in the Professional Services team at AWS, specializing in applications of NLP for public sector customers. He is passionate about using ML as a tool to allow customers to be more productive. In his spare time, he enjoys playing video games and developing one with his friends.
Author: Sathya Balakrishnan