Occam’s Sheath: A Simpler Approach to AI Safety Guardrails
Tl;dr
Large decoder LLMs (Llama, Gemma, Mistral, etc.) are being tuned to classify undesirable content going into and out of chat-style LLMs. They treat classification as a next-token prediction task, generating a “yes” or “no”. Smaller encoder models (BERT models), however, have proven over the years to be strong sequence classifiers. The model Intel/toxic-prompt-roberta was fine-tuned from RoBERTa-base on both the Jigsaw Unintended Bias and ToxicChat datasets on a Gaudi 2 HPU using optimum-habana. It performed better on the ToxicChat dataset than decoder LLMs with roughly 20 to 64 times as many parameters. This model is currently at the proof-of-concept stage, and more testing and experimentation are required in future work.
Image courtesy of glif/91s-anime-art. Output of model when prompted with “Robot holding a razor”.
Introduction
As LLM capabilities continue to surge, so do concerns about their adversarial and harmful use. The vast amount of data today’s LLMs are pre-trained on has produced multifaceted generative models that have surpassed initial expectations. Self-supervised learning has given LLMs a broad knowledge base that makes them a jack-of-all-trades, but not necessarily a master of any. User prompts containing content such as jailbreaks, criminal planning or self-harm can lead to serious negative consequences if the LLM is not tuned to handle such inputs appropriately. Chat models can even hallucinate when asked how long to cook chicken, providing misinformation that could harm their users. In the arena of chat and question-answering models, a new class of LLMs, known as LLM judges, guardrails or safety guard models, has been created to classify whether the input and/or output of a chatbot is safe. In line with the ever-growing parameter counts of today’s complex decoder LLMs (Llama, GPT, Gemini, Mistral…), many of these safeguard LLMs also have billions of parameters. In this article, we propose an alternative, simpler solution for safeguarding chatbots, one that invites a discussion of the principle of Occam’s Razor: is simpler better?
A brief survey of popular guardrail LLMs
Before arriving at the simpler solution, let’s look at some of the most popular guardrail models currently available. From input syntax and format checkers to output bias neutralization, there is a plethora of guardrail systems with varying capabilities and use cases. For the scope of this article, we will be focusing on LLMs that safeguard against toxic, harmful and undesired language – both on input and output of chat models. The following list touches on a few guardrail LLMs that are fine-tuned or instruction-tuned from state-of-the-art LLM base models.
Llama Guard from Meta [1]
- Base model: Llama 3.1
- Model task: text generation
- Training: fine-tuned on a mixture of hand-picked user prompts from Anthropic and synthetic responses generated by Llama.
- Risk Taxonomy:
  - Violent Crimes
  - Non-Violent Crimes
  - Sex-Related Crimes
  - Child Sexual Exploitation
  - Defamation
  - Specialized Advice
  - Privacy
  - Intellectual Property
  - Indiscriminate Weapons
  - Hate
  - Suicide & Self-Harm
  - Sexual Content
  - Elections
- Model size: 8.03B
WildGuard from Allen Institute for AI [2]
- Base model: Mistral-7B-v0.3
- Model task: text generation
- Training: fine-tuned on WildGuardTrain, a mixture of synthetic and real-world user input prompts.
- Risk Taxonomy:
  - Privacy
    - Sensitive Information (Organization)
    - Private Information (Individual)
    - Copyright Violations
  - Misinformation
    - False or Misleading Information
    - Material Harm by Misinformation
  - Harmful Language
    - Social Stereotypes & Discrimination
    - Violence and Physical Harm
    - Toxic Language & Hate Speech
    - Sexual Content
  - Malicious Uses
    - Cyberattacks
    - Fraud & Assisting Illegal Activities
    - Encouraging Unethical/Unsafe Actions
    - Mental Health & Over-Reliance Crisis
- Model size: 7B
ShieldGemma from Google [3]
- Base model: Gemma 2
- Model task: text generation
- Training: fine-tuned and instruction-tuned on internal synthetic data generated by Gemma, along with public data.
- Risk Taxonomy:
  - Sexually Explicit Information
  - Hate Speech
  - Dangerous Content
  - Harassment
  - Violence
  - Obscenity and Profanity
- Model size: 2.6B
Given that all the guardrail models above are text generation LLMs, they are used in much the same way. They take as input a string (or a Python dictionary in an instruction/chat format) built from a template that contains an instruction/guideline for the guardrail model, the user input prompt and/or the target LLM’s output response. They then generate a string containing a “yes” or “no” classification and, if the content was deemed undesired, the risk category (from the taxonomies above) associated with that classification. Below is a diagram of what this might typically look like in conjunction with the target LLM they are safeguarding.
Designing the guardrail model this way allows for more flexibility in two ways:
- The chat service maintainer can decide if they want to monitor the input prompt, chat model output response or both.
- The maintainer can modify the guideline to focus on certain harms without having to fine-tune the model again.
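To make the template-driven workflow concrete, here is a minimal sketch of how a chat service might assemble a guardrail prompt and read back the verdict. The template wording and the helpers `build_guard_prompt` and `parse_verdict` are hypothetical, as is the assumption that the reply starts with “yes”/“no” (or “safe”/“unsafe”). Each real model (Llama Guard, WildGuard, ShieldGemma) defines its own template and output format in its model card.

```python
# Hypothetical guardrail template; real guard models each ship their own wording.
GUARD_TEMPLATE = (
    "You are a safety classifier. Guideline: {guideline}\n"
    "{conversation}\n"
    "Is the {target} above unsafe according to the guideline? Answer yes or no."
)


def build_guard_prompt(guideline, user_prompt, response=None):
    """Assemble a guard prompt for the user turn, or for the model response if one is given."""
    conversation = f"User: {user_prompt}"
    target = "user message"
    if response is not None:
        # Monitoring the chat model's output instead of only the input prompt.
        conversation += f"\nAssistant: {response}"
        target = "assistant response"
    return GUARD_TEMPLATE.format(guideline=guideline, conversation=conversation, target=target)


def parse_verdict(generated_text):
    """Map the guard model's generated text to True (unsafe) or False (safe)."""
    words = generated_text.strip().lower().split()
    first = words[0].strip(".,:;!") if words else ""
    if first in ("yes", "unsafe"):
        return True
    if first in ("no", "safe"):
        return False
    raise ValueError(f"Unexpected guard output: {generated_text!r}")
```

Swapping the `guideline` string is how a maintainer would refocus the guard on different harms without fine-tuning, and passing `response` switches from input monitoring to output monitoring.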
A simpler alternative
As their respective papers show, these state-of-the-art guardrail models are promisingly robust: they perform well on many public datasets unseen during fine-tuning. Ultimately, though, these large decoder-only models, pre-trained to generate coherent and contextually relevant passages of text, are being instructed to generate a single “yes” or “no”. There is an older family of LLMs, however, that has proven effective and efficient at providing a “yes” or “no” for many different use cases. Instead of using text generation for classification, encoder-only BERT models can tackle the same problem with classical binary classification. Since the arrival of large decoder models, many studies (such as this one) have shown that, in the realm of text classification, bigger does not necessarily mean better.
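For contrast, an encoder classifier does not decode any tokens. The pooled representation of the input (RoBERTa’s `<s>` token, the analogue of BERT’s `[CLS]`) goes through a small head that outputs one logit per class, and a softmax turns those logits into probabilities. The sketch below shows only that last step, in plain Python, with made-up weights and a made-up 4-dimensional embedding standing in for RoBERTa’s 768-dimensional one. It is a simplification: Hugging Face’s RoBERTa classification head also has a hidden dense layer with a tanh before the output projection.

```python
import math


def linear(x, weights, bias):
    """One logit per class: dot(w_c, x) + b_c."""
    return [sum(w_i * x_i for w_i, x_i in zip(w, x)) + b for w, b in zip(weights, bias)]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify(pooled, weights, bias, threshold=0.5):
    """Return (is_toxic, p_toxic) for a 2-class head where class 1 means toxic."""
    p_toxic = softmax(linear(pooled, weights, bias))[1]
    return p_toxic >= threshold, p_toxic


# Toy numbers only: a 4-d "pooled embedding" and a hand-picked 2x4 head.
pooled = [0.2, -1.0, 0.7, 0.1]
W = [[0.5, 0.3, -0.2, 0.0],    # class 0: not toxic
     [-0.4, -0.6, 0.9, 0.2]]   # class 1: toxic
b = [0.0, 0.0]
print(classify(pooled, W, b))
```

A single forward pass yields a calibrated-looking score for every input, and the decision threshold can be tuned afterwards without touching the model.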
The law of the instrument, AKA Maslow’s hammer, states that “if the only tool you have is a hammer, everything looks like a nail.” Perhaps in the current LLM landscape, text generation decoder models have become hammers, and problems like toxic prompt classification might only need a screwdriver.
For the remainder of this article, we will show how fine-tuning RoBERTa-base, a model that is 20 times smaller than ShieldGemma and 56 times smaller than WildGuard, can still prove to be effective in classifying toxic input prompts to chat models.
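The size ratios are simple arithmetic on the parameter counts quoted in this article (125M for RoBERTa-base, 2.61B for ShieldGemma, 7B for WildGuard, 8.03B for Llama Guard 3). The sketch below also estimates the memory needed just to hold the weights, assuming 16-bit (2-byte) weights. It ignores activations, KV cache and framework overhead, so real footprints are larger.

```python
# Parameter counts as quoted in this article.
PARAMS = {
    "RoBERTa-base": 125e6,
    "ShieldGemma": 2.61e9,
    "WildGuard": 7e9,
    "Llama Guard 3": 8.03e9,
}
BYTES_PER_PARAM = 2  # assumption: fp16/bf16 weights


def size_ratio(model, baseline="RoBERTa-base"):
    """How many times larger `model` is than `baseline`."""
    return PARAMS[model] / PARAMS[baseline]


def weight_memory_gb(model):
    """Memory for the weights alone, in gigabytes (1e9 bytes)."""
    return PARAMS[model] * BYTES_PER_PARAM / 1e9


for name in PARAMS:
    print(f"{name:14s} {size_ratio(name):5.1f}x  ~{weight_memory_gb(name):.2f} GB of weights")
```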
Fine-tuning Intel/toxic-prompt-roberta
We chose to fine-tune RoBERTa-base (125M parameters) [4] on two toxicity datasets:
- Jigsaw Unintended Bias in Toxicity Classification
  - Real-world data of comments found online
  - A positive label is defined as “a rude, disrespectful, or unreasonable comment that is somewhat likely to make you leave a discussion or give up on sharing your perspective” [5]
- ToxicChat
  - Real-world user-AI dialogues
  - “Labels are based on the definitions for undesired content in (Zampieri et al., 2019), and the annotators adopt a binary value for toxicity label” [6]
We selected these two datasets with the intention of building a toxic prompt classification model that is robust against demographic biases. Beyond the toxicity label, the Jigsaw dataset also contains target identity information (i.e., the group of people being talked about in the comment) that can help us monitor whether the model is biased towards any demographic. For the scope of this investigation, we only collected the user input text from ToxicChat.
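One way to bring the two datasets into a shared (text, label) format is sketched below. The column names (`comment_text` and `target` for Jigsaw; `user_input` and `toxicity` for ToxicChat) reflect the public releases as we understand them. Binarizing Jigsaw’s fractional `target` score at 0.5 follows the competition’s convention. Treat all of these as assumptions to check against your copy of the data, not as the exact preprocessing used for Intel/toxic-prompt-roberta.

```python
def from_jigsaw(row, threshold=0.5, identity_cols=()):
    """Jigsaw rows carry a fractional toxicity score; binarize it (assumed 0.5 cutoff)."""
    record = {"text": row["comment_text"], "label": int(row["target"] >= threshold)}
    # Keep identity annotations (e.g. "female") so bias can be analyzed later.
    for col in identity_cols:
        record[col] = row.get(col)
    return record


def from_toxicchat(row):
    """ToxicChat rows already carry a 0/1 label; keep only the user prompt."""
    return {"text": row["user_input"], "label": int(row["toxicity"])}


def normalize(rows, converter):
    """Apply a per-dataset converter to every row."""
    return [converter(r) for r in rows]
```

Keeping the converters separate mirrors the decision to fine-tune on each dataset in its own session rather than interleaving them.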
Given that the two datasets come from significantly different domains (comments vs. questions/prompts), we chose to fine-tune the model on each dataset separately instead of interleaving them in a single fine-tuning session. We fine-tuned on the Jigsaw dataset first and on the ToxicChat (TC) dataset second, because TC’s domain is more closely aligned with our model’s objective. This left us with two checkpoints: a model fine-tuned on Jigsaw, and a model fine-tuned on both Jigsaw and TC.
All fine-tuning was executed on a single Intel Gaudi 2 HPU using optimum-habana. The model was trained on each dataset for 3 epochs with a batch size of 32. We also fine-tuned a separate RoBERTa checkpoint on TC only, as a baseline for comparison. We anticipated that our final checkpoint, fine-tuned on both datasets, would perform better than the Jigsaw-only checkpoint on the TC test set and better than the TC-only checkpoint on the Jigsaw test set.
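The training schedule described above (Jigsaw first, then TC, plus a TC-only baseline, each stage run for 3 epochs at batch size 32) can be written down as a small plan. This sketch only models which checkpoint each run starts from. The `Stage` and `sequential_plan` names are hypothetical, and the optimum-habana trainer that actually ran these stages on Gaudi 2 is not reproduced here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Stage:
    name: str            # name of the checkpoint this stage produces
    init_from: str       # checkpoint the stage starts from
    dataset: str
    epochs: int = 3
    batch_size: int = 32


def sequential_plan(base, datasets):
    """Chain stages so each one starts from the previous stage's output."""
    stages, init = [], base
    for ds in datasets:
        name = ds if init == base else f"{init}+{ds}"
        stages.append(Stage(name=name, init_from=init, dataset=ds))
        init = name
    return stages


runs = sequential_plan("roberta-base", ["jigsaw", "tc"])  # jigsaw, then jigsaw+tc
runs += sequential_plan("roberta-base", ["tc"])           # tc-only baseline
for s in runs:
    print(f"{s.name:10s} <- {s.init_from:12s} on {s.dataset} ({s.epochs} epochs, bs={s.batch_size})")
```

The resulting names match the checkpoint naming convention used in the results below.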
Results
Most toxicity classification studies measure model performance with the area under the receiver operating characteristic (ROC) curve and the area under the precision-recall (PR) curve (AUROC and AUPRC, respectively). These threshold-independent metrics are preferred because many toxicity datasets are intentionally imbalanced to mirror the distribution of toxic language in the real world. The plots below use the following naming convention for the checkpoints we fine-tuned:
- jigsaw: Initial checkpoint that was fine-tuned only on Jigsaw Unintended Bias
- tc: Checkpoint fine-tuned only on ToxicChat for baseline comparison
- jigsaw+tc: Final checkpoint (Intel/toxic-prompt-roberta) fine-tuned on Jigsaw and ToxicChat
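Both metrics are threshold-free summaries of how well a model’s toxicity scores rank toxic examples above non-toxic ones. The sketch below computes them from scratch in plain Python so the definitions are explicit: AUROC via the rank-sum (Mann-Whitney) formulation, with tied scores given their average rank, and AUPRC as step-wise average precision. In practice, scikit-learn’s `roc_auc_score` and `average_precision_score` compute the same quantities. Average precision is one of several estimators of the area under the PR curve, and we are assuming it here rather than stating which estimator produced this article’s numbers.

```python
def auroc(labels, scores):
    """AUROC = P(score of a random positive > score of a random negative), ties count half."""
    n = len(scores)
    order = sorted(range(n), key=lambda i: scores[i])
    ranks = [0.0] * n
    i = 0
    while i < n:
        j = i
        # Extend j over a run of tied scores.
        while j + 1 < n and scores[order[j + 1]] == scores[order[i]]:
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1  # average 1-based rank of the tie run
        i = j + 1
    n_pos = sum(labels)
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("AUROC needs both positive and negative examples")
    pos_rank_sum = sum(r for r, y in zip(ranks, labels) if y == 1)
    return (pos_rank_sum - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


def average_precision(labels, scores):
    """Step-wise AUPRC: sum of (recall gain) * precision over descending score thresholds."""
    total_pos = sum(labels)
    if total_pos == 0:
        raise ValueError("average precision needs at least one positive example")
    pairs = sorted(zip(scores, labels), key=lambda t: t[0], reverse=True)
    tp = fp = 0
    prev_recall = ap = 0.0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Consume every example tied at this score so ties share one threshold.
        while i < len(pairs) and pairs[i][0] == threshold:
            tp += pairs[i][1] == 1
            fp += pairs[i][1] == 0
            i += 1
        recall, precision = tp / total_pos, tp / (tp + fp)
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap
```

With heavily imbalanced labels, AUPRC tends to be the more discriminating of the two, because it is not inflated by the many easy true negatives.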
Jigsaw test results
First, let’s look at how our three checkpoints performed on the Jigsaw test dataset. As we had hoped, the ROC and PR curves show that our final model performed better than the baseline TC checkpoint, though worse than the Jigsaw-only checkpoint, which serves as the ceiling. This was to be expected, since the Jigsaw+TC checkpoint’s weights were further updated after the initial Jigsaw fine-tuning to fit the TC training dataset.
The confusion matrices in the diagram below add a bit more detail to the story, however. Our Jigsaw-only checkpoint had the highest true negative rate (0.982), but at the cost of a relatively high false negative rate (0.264). A false negative is the most hazardous outcome in this context because it means the model misclassified a toxic comment as safe. Surprisingly, the final model reduced this false negative rate to 0.164. Something about the TC fine-tuning appears to have steered gradient descent in a direction that made the model more sensitive to toxic comments.
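The rates quoted above are row-normalized confusion-matrix entries. Each true class is divided by its own count, so the true negative and false positive rates sum to 1, as do the true positive and false negative rates. A minimal sketch, assuming hard predictions obtained by thresholding the toxicity score (the 0.5 default is an assumption; the threshold used for the matrices is not stated here):

```python
def rates(labels, scores, threshold=0.5):
    """Row-normalized confusion-matrix rates for binary labels (1 = toxic)."""
    tp = fp = tn = fn = 0
    for y, s in zip(labels, scores):
        pred = 1 if s >= threshold else 0
        if y == 1 and pred == 1:
            tp += 1
        elif y == 1:
            fn += 1  # toxic input let through: the hazardous case
        elif pred == 1:
            fp += 1
        else:
            tn += 1
    pos, neg = tp + fn, tn + fp
    nan = float("nan")
    return {
        "TPR": tp / pos if pos else nan,
        "FNR": fn / pos if pos else nan,
        "TNR": tn / neg if neg else nan,
        "FPR": fp / neg if neg else nan,
    }
```

Lowering the threshold trades a higher false positive rate for a lower false negative rate. A deployment would turn this knob if missed toxic prompts are the costlier error.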
Since the Jigsaw Unintended Bias dataset includes target identities, the bar chart below shows how our final model improved upon the baseline TC model across demographic identities, indicating that the initial fine-tuning on Jigsaw did benefit the final model’s performance. We only included identities that had more than 500 examples in the test set.
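The metric behind the per-identity bar chart is not spelled out here. One common choice for this dataset is the subgroup AUC of Borkan et al. [5], which is the AUROC computed only on examples that mention a given identity. The sketch below assumes that metric and treats an identity as mentioned when its annotation is at least 0.5 (also an assumption). It applies the same more-than-500-examples cutoff, and the row field names are illustrative.

```python
def pairwise_auroc(labels, scores):
    """AUROC by direct pair counting (O(P*N)), ties count half; fine for a sketch."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        return float("nan")  # undefined when a subgroup contains only one class
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def subgroup_aucs(rows, identities, min_examples=500, mention_threshold=0.5):
    """rows: dicts with 'label', 'score' and a float annotation per identity (or None)."""
    results = {}
    for identity in identities:
        subset = [r for r in rows if (r.get(identity) or 0.0) >= mention_threshold]
        if len(subset) <= min_examples:
            continue  # keep only identities with more than `min_examples` test examples
        results[identity] = pairwise_auroc([r["label"] for r in subset],
                                           [r["score"] for r in subset])
    return results
```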
ToxicChat test results
Let’s now turn to how each checkpoint performed on the TC test dataset. In a similar fashion, the final model trained on both datasets performed better than the Jigsaw-only model. Notably, the final model’s AUROC is also slightly better than that of the TC-only model.
The confusion matrices confirm this: the final model has a slightly higher (1.6%) true positive rate. At the other end of the spectrum, the Jigsaw-only model essentially labels everything in the TC test set as not toxic, consistent with our baseline expectations.
Now that we’ve compared the final model with its building-block counterparts, let’s see how it performs relative to some of the 1B+ parameter decoder models in the table below.
Model | Parameters | Precision | Recall | F1 | AUPRC | AUROC |
---|---|---|---|---|---|---|
LlamaGuard1 | 6.74B | 0.481 | 0.795 | 0.599 | 0.626* | - |
LlamaGuard3 | 8.03B | 0.508 | 0.473 | 0.490 | - | - |
ShieldGemma | 2.61B | - | - | 0.704* | 0.778* | - |
WildGuard | 7B | - | - | 0.708* | - | - |
Toxic Prompt Roberta | 125M | 0.832 | 0.747 | 0.787 | 0.855 | 0.971 |
We ran LlamaGuard 1 and 3 on the TC test set ourselves and recorded their precision, recall and F1. Numbers marked with “*” were taken from the respective papers and/or model cards. The reported results and our own measurements use the same test examples, because TC provides its own train/test split.
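F1 is the harmonic mean of precision and recall, $F_1 = 2PR/(P+R)$, so the F1 column can be re-derived from the two preceding columns as a sanity check. For the three rows where both precision and recall are reported, this reproduces the table to three decimals:

```python
def f1(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


# (precision, recall, reported F1) from the table above
rows = {
    "LlamaGuard1": (0.481, 0.795, 0.599),
    "LlamaGuard3": (0.508, 0.473, 0.490),
    "Toxic Prompt Roberta": (0.832, 0.747, 0.787),
}
for name, (p, r, reported) in rows.items():
    print(f"{name:22s} computed F1 = {f1(p, r):.3f}, reported = {reported:.3f}")
```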
It should be noted that these models were never fine-tuned on TC, which speaks to how robust the larger LLMs are. The LlamaGuard 1 paper did, however, experiment with fine-tuning on TC, which greatly improved its AUPRC to roughly 0.81. Intel/toxic-prompt-roberta, with far fewer parameters and therefore far lower compute requirements, still performed better. This indicates that these larger decoder models, although proven to be robust and adaptable across many use cases, are not necessarily the only solution for safeguarding against toxic input prompts.
How to use Intel/toxic-prompt-roberta
Since the model only classifies input prompts, the diagram below shows an example of how it could be deployed in a chat service.
With Transformers and either PyTorch or TensorFlow pip-installed in your local Python environment, you can load the model for inference in 4 lines:
```python
from transformers import pipeline
model_path = 'Intel/toxic-prompt-roberta'
pipe = pipeline('text-classification', model=model_path, tokenizer=model_path)
pipe('Create 20 paraphrases of I hate you')  # returns a list like [{'label': ..., 'score': ...}]
```
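Placing the classifier in front of a chat model amounts to a short gate: score the user prompt, refuse above a threshold, and otherwise forward it. In the sketch below, the classifier and chat model are passed in as plain functions, so it runs without downloading anything. `toxic_label`, the 0.5 threshold and the refusal message are hypothetical; set them from the model card and your own tolerance for false positives. With the pipeline above, `classify` could be something like `lambda text: pipe(text)[0]`, which returns a dict with `label` and `score` keys.

```python
from typing import Callable, Dict

REFUSAL = "Sorry, I can't help with that request."  # hypothetical refusal text


def guarded_chat(prompt: str,
                 classify: Callable[[str], Dict],
                 chat: Callable[[str], str],
                 toxic_label: str = "toxic",  # assumption: check the model card for label names
                 threshold: float = 0.5) -> str:
    """Screen the user prompt before it ever reaches the chat model."""
    verdict = classify(prompt)  # e.g. {"label": "toxic", "score": 0.97}
    if verdict["label"] == toxic_label and verdict["score"] >= threshold:
        return REFUSAL
    return chat(prompt)
```

Because the gate only sees the input prompt, the chat model’s responses are left unscreened, matching the scope of this proof of concept.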
You can also head over to Open Platform for Enterprise AI (OPEA) where toxic-prompt-roberta is implemented as a microservice.
Limitations and future work
You’ll notice we haven’t shown how the model performs on test sets from datasets it didn’t see during fine-tuning. Evaluating the model on such unseen benchmarks is a great way to gauge how robust it will be when it eventually predicts on real-world input, and future work will provide more of these results. We also plan to answer the question, “How many datasets is too many?” We saw only slight degradation on the Jigsaw test set when adding the TC dataset after the Jigsaw fine-tuning. There are many other publicly available toxicity datasets we could experiment with to see how they affect performance on the initial Jigsaw and TC test sets.
Additionally, many public toxicity datasets, like TC, contain both input prompts and target LLM responses. We chose not to include LLM output responses in fine-tuning to keep this proof-of-concept exploration simple; restricting the model to a single type of input is itself an advantage. In future work, we plan to investigate whether two independent LLMs, one classifying input prompts and the other classifying LLM output responses, perform better than one monolithic LLM handling both.
Conclusion
The NLP landscape has been hit with an exciting new wave of versatile decoder LLMs that have proven reliable in many-to-many use cases such as summarization, QnA and chatbots. They have also been successful in text classification via methods such as prompt engineering. These LLMs, however, have high parameter counts relative to their predecessors (e.g., BERT models), which raises the question: “How much better are they at simpler text classification tasks, such as AI safety guardrails?” Are we using Maslow’s hammer to solve toxic prompt classification with the latest large decoder LLMs? Occam’s (philosophical) razor suggests that the simplest solution is usually the best one. Just as a sheath protects us from the harms of a sharp razor blade, this blog suggests that simpler LLM safety guardrails can better protect us from the harms of AI.
Thank you to my colleagues Qun Gao, Mitali Potnis, Abolfazl Shahbazi and Fahim Mohammad for contributing, supporting and reviewing the project and blog.
Citations
[1] H. Inan et al., “Llama Guard: LLM-based Input-Output Safeguard for Human-AI Conversations,” Dec. 07, 2023, arXiv: arXiv:2312.06674. doi: 10.48550/arXiv.2312.06674.
[2] S. Han et al., “WildGuard: Open One-Stop Moderation Tools for Safety Risks, Jailbreaks, and Refusals of LLMs,” Jul. 09, 2024, arXiv: arXiv:2406.18495. doi: 10.48550/arXiv.2406.18495.
[3] W. Zeng et al., “ShieldGemma: Generative AI Content Moderation Based on Gemma,” Aug. 04, 2024, arXiv: arXiv:2407.21772. doi: 10.48550/arXiv.2407.21772.
[4] Y. Liu et al., “RoBERTa: A Robustly Optimized BERT Pretraining Approach,” Jul. 26, 2019, arXiv: arXiv:1907.11692. doi: 10.48550/arXiv.1907.11692.
[5] D. Borkan, L. Dixon, J. Sorensen, N. Thain, and L. Vasserman, “Nuanced Metrics for Measuring Unintended Bias with Real Data for Text Classification,” May 08, 2019, arXiv: arXiv:1903.04561. doi: 10.48550/arXiv.1903.04561.
[6] Z. Lin et al., “ToxicChat: Unveiling Hidden Challenges of Toxicity Detection in Real-World User-AI Conversation,” Oct. 26, 2023, arXiv: arXiv:2310.17389. doi: 10.48550/arXiv.2310.17389.