RLHF 101: A Technical Dive into RLHF

Community Article Published December 11, 2024

This blog delves into the full training pipeline of the Reinforcement Learning from Human Feedback (RLHF) framework. We will explore every stage — from data generation and reward model inference to the final training of a large language model (LLM). Our goal is to ensure that everything is fully reproducible by providing all the necessary code and the exact specifications of the environments used. By the end of this post, you should be equipped to train any model with any instruction dataset using the algorithm of your choice.

Preliminary: Setup & Environment

We will use the following setup for this tutorial:

  • Dataset: UltraFeedback, a well-curated dataset consisting of general chat prompts.
  • Base Model: Llama-3-8B-it, a state-of-the-art instruction-tuned model.
  • Reward Model: Armo, a robust reward model optimized for evaluating the generated outputs.
  • Training Algorithm: REBEL, a state-of-the-art algorithm tailored for efficient RLHF optimization.

To get started, clone our repo, which contains all the resources required for this tutorial:

git clone https://github.com/ZhaolinGao/REBEL
cd REBEL

We use two separate environments for different stages of the pipeline:

  • vllm: Handles data generation, leveraging the efficient vllm library.
  • rebel: Used for training the RLHF model.

You can install both environments using the provided YAML files:

conda env create -f ./envs/rebel_env.yml
conda env create -f ./envs/vllm_env.yml

Part 1: Data Generation

In this section, we will load the base model using vllm for fast inference, prepare the dataset, and generate multiple responses for each prompt in the dataset. The complete code for this part is available here.

Activate the vllm environment:

conda activate vllm

First, load the base model and tokenizer using vllm:

from transformers import AutoTokenizer
from vllm import LLM
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=8,
)

Here, tensor_parallel_size sets the number of GPUs the model is sharded across with tensor parallelism.

Next, load the UltraFeedback dataset:

from datasets import load_dataset
dataset = load_dataset("allenai/ultrafeedback_binarized_cleaned_train", split='train')

You can select a subset of the dataset using dataset.select. For example, to select the first 10,000 rows:

dataset = dataset.select(range(10000))

Alternatively, you can split the dataset into chunks using dataset.shard for implementations like SPPO where each iteration only trains on one of the chunks.

Now, let's prepare the dataset for generation. The Llama model uses special tokens to distinguish prompts and responses. For example:

<|begin_of_text|><|start_header_id|>user<|end_header_id|>

What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Therefore, for every prompt in the dataset, we need to convert it from plain text into this format before generating:

def get_message(instruction):
    message = [
        {"role": "user", "content": instruction},
    ]
    return message
prompts = [tokenizer.apply_chat_template(get_message(row['prompt']), tokenize=False, add_generation_prompt=True) for row in dataset]
  • get_message wraps the plain-text prompt in a single-message chat list, marking it as coming from the user.
  • tokenizer.apply_chat_template adds the required special tokens and, with add_generation_prompt=True, appends the assistant header (<|start_header_id|>assistant<|end_header_id|>\n\n) so that the model continues with its response.
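To make the formatting concrete, here is a minimal pure-Python sketch that builds the same single-turn string by hand. The helper format_llama3_prompt is hypothetical and only mirrors the Llama 3 template for one user turn as we understand it (BOS once, role header, trimmed content, end-of-turn token); in practice, keep using tokenizer.apply_chat_template.

```python
def format_llama3_prompt(instruction: str, add_generation_prompt: bool = True) -> str:
    """Hypothetical helper: hand-built Llama 3 prompt for a single user turn."""
    # <|begin_of_text|> (BOS) is prepended once, before the first message
    text = "<|begin_of_text|>"
    # Each message: role header, blank line, trimmed content, end-of-turn token
    text += "<|start_header_id|>user<|end_header_id|>\n\n" + instruction.strip() + "<|eot_id|>"
    if add_generation_prompt:
        # Open an assistant turn so the model continues with its response
        text += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return text


print(format_llama3_prompt("What is France's capital?"))
```

Printing the result reproduces the example format shown earlier, ending with the open assistant header.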

Finally, we can generate the responses using vllm with the prompts we just formatted. We are going to generate 5 responses per prompt:

import torch
import random
import numpy as np
from vllm import SamplingParams

def set_seed(seed=5775709):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

for p in range(5):
    set_seed(p * 50)
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.9,
        max_tokens=2048,
        seed=p * 50,
    )
    response = llm.generate(prompts, sampling_params)
    output = list(map(lambda x: x.outputs[0].text, response))
    dataset = dataset.add_column(f"response_{p}", output)
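  • temperature=0.8 and top_p=0.9 are common settings that control the diversity of the sampled responses.
  • set_seed, together with the seed argument of SamplingParams, makes generation reproducible while using a different seed (0, 50, 100, 150, 200) for each of the five responses.
  • llm.generate produces one completion per prompt, and the texts are added to the dataset as a new column response_{p} with dataset.add_column.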
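For intuition about what temperature=0.8 and top_p=0.9 do, the sketch below applies both to a toy next-token distribution in pure Python. It is a generic illustration of temperature scaling followed by nucleus (top-p) filtering, not vLLM's implementation; the function name temperature_top_p and the toy logits are made up.

```python
import math


def temperature_top_p(logits, temperature=0.8, top_p=0.9):
    """Return the renormalized sampling distribution after temperature and top-p."""
    # Temperature < 1 sharpens the distribution, > 1 flattens it
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the smallest set of most-likely tokens whose cumulative mass reaches top_p
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = set(), 0.0
    for i in order:
        kept.add(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Zero out everything outside the nucleus and renormalize
    mass = sum(probs[i] for i in kept)
    return [probs[i] / mass if i in kept else 0.0 for i in range(len(probs))]


toy_logits = [2.0, 1.0, 0.5, -1.0, -3.0]
print([round(p, 3) for p in temperature_top_p(toy_logits)])
```

With these toy logits the two least likely tokens fall outside the nucleus and can never be sampled, while the remaining three keep their relative odds.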

You can run the complete script with:

python ./src/ultrafeedback_largebatch/generate.py --world_size NUM_GPU --output_repo OUTPUT_REPO

Part 2: Reward Model Inference

In this part, we will calculate reward scores for the responses generated in Part 1. The complete code for this part is available here.

Activate the rebel environment:

conda activate rebel

To begin, we'll initialize the Armo reward model pipeline. This reward model is a fine-tuned sequence classification model that assigns a scalar reward score to a given dialogue based on its quality. The implementation is as follows:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from typing import Dict, List

class ArmoRMPipeline:
    def __init__(self, model_id, device_map="cuda", torch_dtype=torch.bfloat16, truncation=True, trust_remote_code=False, max_length=4096):
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            use_fast=True,
        )
        self.truncation = truncation
        self.device = self.model.device
        self.max_length = max_length

    def __call__(self, messages: List[Dict[str, str]]) -> float:
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            padding=True,
            truncation=self.truncation,
            max_length=self.max_length,
        ).to(self.device)
        with torch.no_grad():
            output = self.model(input_ids)
            score = output.score.float().item()
        return score

rm = ArmoRMPipeline("RLHFlow/ArmoRM-Llama3-8B-v0.1", trust_remote_code=True)

Now we can compute the reward scores for all five responses to every prompt:

def get_message(instruction, response):
    return [{"role": "user", "content": instruction}, {"role": "assistant", "content": response}]

rewards = {}
for i in range(5):
    rewards[f"response_{i}_reward"] = []
    for row in dataset:
        reward = rm(get_message(row['prompt'], row[f'response_{i}']))
        rewards[f"response_{i}_reward"].append(reward)
for k, v in rewards.items():
    dataset = dataset.add_column(k, v)
  • get_message formats the user prompt and assistant response into a list of dictionaries.
  • rm computes a reward score for each response in the dataset.

You can run the complete script with:

python ./src/ultrafeedback_largebatch/rank.py --input_repo INPUT_REPO
  • INPUT_REPO is the saved repo from Part 1 that contains the generated responses.

Part 3: Filter and Tokenize

In this part, we’ll walk through preparing the dataset for training: filtering out excessively long prompts and responses to avoid out-of-memory (OOM) errors, selecting the best and worst responses for training, and removing pairs whose chosen and rejected responses are identical. The complete code for this part is available here.

Let's first initialize two different tokenizers where one pads from the right and one pads from the left:

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer_left = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", padding_side='left')
tokenizer_left.add_special_tokens({"pad_token": "[PAD]"})

Since the Llama 3 tokenizer does not define a pad token, we add one manually with tokenizer.add_special_tokens.

These two tokenizers let us pad the prompt from the left and the response from the right so that they meet in the middle. By combining left-padded prompts with right-padded responses, we ensure that:

  • Prompts and responses meet at a consistent position, so every response starts at the same index (1,024). This is what allows Part 4 to slice out the response logits with a single fixed offset.
  • Relative position embeddings remain correct for model training.

Here’s an example format:

[PAD] ... [PAD] <|begin_of_text|><|start_header_id|>user<|end_header_id|>

PROMPT<|eot_id|><|start_header_id|>assistant<|end_header_id|>


RESPONSE<|eot_id|>[PAD] ... [PAD]

We want to ensure that the length of

[PAD] ... [PAD] <|begin_of_text|><|start_header_id|>user<|end_header_id|>

PROMPT<|eot_id|><|start_header_id|>assistant<|end_header_id|>

is the same for all prompts, and the length of

RESPONSE<|eot_id|>[PAD] ... [PAD]

is the same for all responses.
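The layout can be sketched in pure Python with made-up token IDs: left-pad every prompt to a fixed length, right-pad every response, and concatenate. The helper pad_pair, the PAD value, and the toy lengths are illustrative assumptions (the real pipeline uses 1,024 and 2,048 with the tokenizers above); the point is that the response always begins at index MAXLEN_PROMPT, whatever the prompt length.

```python
PAD = -1  # hypothetical pad id for illustration


def pad_pair(prompt_ids, response_ids, maxlen_prompt, maxlen_response):
    """Left-pad the prompt and right-pad the response so they meet at a fixed index."""
    assert len(prompt_ids) <= maxlen_prompt and len(response_ids) <= maxlen_response
    left = [PAD] * (maxlen_prompt - len(prompt_ids)) + list(prompt_ids)
    right = list(response_ids) + [PAD] * (maxlen_response - len(response_ids))
    return left + right


MAXLEN_PROMPT, MAXLEN_RESPONSE = 6, 4
short = pad_pair([101, 102], [7, 8, 9], MAXLEN_PROMPT, MAXLEN_RESPONSE)
long_ = pad_pair([101, 102, 103, 104, 105], [7], MAXLEN_PROMPT, MAXLEN_RESPONSE)
# In both sequences the response begins at position MAXLEN_PROMPT
print(short[MAXLEN_PROMPT], long_[MAXLEN_PROMPT])  # prints: 7 7
```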

Next, we filter out prompts longer than 1,024 tokens and responses exceeding 2,048 tokens. The filtering process uses a helper function to create message templates:

def get_message(instruction=None, response=None):

    assert instruction is not None or response is not None

    if response is None:
        message = [
            {"role": "user", "content": instruction},
        ]
    elif instruction is None:
        message = [
            {"role": "assistant", "content": response}
        ]
    else:
        message = [
            {"role": "user", "content": instruction},
            {"role": "assistant", "content": response}
        ]

    return message

dataset = dataset.filter(lambda row: tokenizer.apply_chat_template(get_message(row['prompt']), tokenize=True, add_generation_prompt=True, return_tensors='pt').shape[-1] <= 1024)
for i in range(5):
    dataset = dataset.filter(lambda row: tokenizer.apply_chat_template(get_message(response=row[f'response_{i}']), tokenize=True, add_generation_prompt=False, return_tensors='pt')[:, 5:].shape[-1] <= 2048)

Note that we skip the first five tokens of each response when counting lengths. This excludes the header tokens (i.e. <|begin_of_text|><|start_header_id|>assistant<|end_header_id|>\n\n), so we count only the actual response plus the EOS token (<|eot_id|>) at the end.

Now we can tokenize the prompts with left padding to a maximum length of 1,024 tokens:

llama_prompt_tokens = []
for row in dataset:
    llama_prompt_token = tokenizer_left.apply_chat_template(
            get_message(row['prompt']), 
            add_generation_prompt=True,
            tokenize=True,
            padding='max_length',
            max_length=1024,
    )
    assert len(llama_prompt_token) == 1024
    assert (llama_prompt_token[0] == 128000 or llama_prompt_token[0] == 128256) and llama_prompt_token[-1] == 271
    llama_prompt_tokens.append(llama_prompt_token)
dataset = dataset.add_column("llama_prompt_tokens", llama_prompt_tokens)

The assertions check that every tokenized prompt has exactly 1,024 tokens, starts with either the [PAD] token (ID 128256) or <|begin_of_text|> (ID 128000), and ends with the \n\n token (ID 271).

Then, for each prompt, we select the responses with the highest and lowest rewards as the chosen and rejected responses, and tokenize them with right padding:

chosen, reject, llama_chosen_tokens, llama_reject_tokens, chosen_reward, reject_reward = [], [], [], [], [], []

for row in dataset:

    all_rewards = [row[f"response_{i}_reward"] for i in range(5)]
    chosen_idx, reject_idx = np.argmax(all_rewards), np.argmin(all_rewards)

    chosen.append(row[f"response_{chosen_idx}"])
    reject.append(row[f"response_{reject_idx}"])

    llama_chosen_token = tokenizer.apply_chat_template(
            get_message(response=row[f"response_{chosen_idx}"]),
            add_generation_prompt=False,
            tokenize=True,
            padding='max_length',
            max_length=2048+5,
    )[5:]
    llama_chosen_tokens.append(llama_chosen_token)
    chosen_reward.append(row[f"response_{chosen_idx}_reward"])
    assert len(llama_chosen_token) == 2048
    assert llama_chosen_token[-1] == 128009 or llama_chosen_token[-1] == 128256

    llama_reject_token = tokenizer.apply_chat_template(
            get_message(response=row[f"response_{reject_idx}"]),
            add_generation_prompt=False,
            tokenize=True,
            padding='max_length',
            max_length=2048+5,
    )[5:]
    llama_reject_tokens.append(llama_reject_token)
    reject_reward.append(row[f"response_{reject_idx}_reward"])
    assert len(llama_reject_token) == 2048
    assert llama_reject_token[-1] == 128009 or llama_reject_token[-1] == 128256

dataset = dataset.add_column("chosen", chosen)
dataset = dataset.add_column("chosen_reward", chosen_reward)
dataset = dataset.add_column("llama_chosen_tokens", llama_chosen_tokens)
dataset = dataset.add_column("reject", reject)
dataset = dataset.add_column("reject_reward", reject_reward)
dataset = dataset.add_column("llama_reject_tokens", llama_reject_tokens)

Again, the assertions check that every tokenized response has exactly 2,048 tokens and ends with either the [PAD] token (ID 128256) or <|eot_id|> (ID 128009).

Finally, we filter out rows where the chosen and reject responses are the same:

dataset = dataset.filter(lambda row: row['chosen'] != row['reject'])

and split the dataset into a training set and a test set of 1,000 prompts:

dataset = dataset.train_test_split(test_size=1000, shuffle=True)
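The selection and de-duplication steps can be summarized in a small pure-Python sketch on made-up rows; select_pair and the toy data are hypothetical. Like np.argmax and np.argmin, Python's max and min return the first index on ties, so a prompt whose responses all receive the same reward yields chosen == reject and is removed by the filter above.

```python
def select_pair(responses, rewards):
    """Return (chosen, reject, chosen_reward, reject_reward) for one prompt."""
    # First index of the highest / lowest reward, matching np.argmax / np.argmin on ties
    chosen_idx = max(range(len(rewards)), key=rewards.__getitem__)
    reject_idx = min(range(len(rewards)), key=rewards.__getitem__)
    return responses[chosen_idx], responses[reject_idx], rewards[chosen_idx], rewards[reject_idx]


rows = [
    {"responses": ["a", "b", "c"], "rewards": [0.1, 0.9, -0.2]},
    {"responses": ["x", "x", "x"], "rewards": [0.5, 0.5, 0.5]},  # all tied
]
pairs = [select_pair(r["responses"], r["rewards"]) for r in rows]
# Drop pairs whose chosen and reject texts are identical
pairs = [p for p in pairs if p[0] != p[1]]
print(pairs)
```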

You can run the complete script with:

python ./src/ultrafeedback_largebatch/filter_tokenize.py --input_repo INPUT_REPO
  • INPUT_REPO is the saved repo from Part 2 that contains the rewards for each response.

Part 4: Training with REBEL

At each iteration $t$ of REBEL, we aim to solve the following square-loss regression problem:

$$\theta_{t+1}=\arg\min_{\theta\in\Theta}\sum_{(x, y, y')\in \mathcal{D}_t}\left(\frac{1}{\eta} \left(\ln \frac{\pi_\theta(y|x)}{\pi_{\theta_t}(y|x)} - \ln \frac{\pi_\theta(y'|x)}{\pi_{\theta_t}(y'|x)}\right) - \left(r(x, y) - r(x, y')\right)\right)^2$$

where $\eta$ is a hyperparameter, $\theta$ denotes the model parameters, $x$ is the prompt, $\mathcal{D}_t$ is the dataset we collected in the previous three parts, $y$ and $y'$ are two responses to $x$, $\pi_\theta(y|x)$ is the probability of generating response $y$ given prompt $x$ under the parameterized policy $\pi_\theta$, and $r(x, y)$ is the reward of response $y$ for prompt $x$ obtained in Part 2.

In this tutorial, we demonstrate a single iteration of REBEL ($t=0$) starting from the base model $\pi_{\theta_0}$. For multi-iteration training, you can repeat Parts 1 through 4, initializing each iteration with the model trained in the previous one.

The complete code for this part is available here. To enable full-parameter training on 8 GPUs, we use the Accelerate library with DeepSpeed ZeRO Stage 3 by running:

accelerate launch --config_file accelerate_cfgs/deepspeed_config_stage_3.yaml --main-process-port 29080 --num_processes 8 src/ultrafeedback_largebatch/rebel.py --task.input_repo INPUT_REPO --output_dir OUTPUT_DIR
  • INPUT_REPO is the saved repo from Part 3 that contains the tokenized prompts and responses.
  • OUTPUT_DIR is the directory to save the models.

Step 1: Initialization & Loading

We start by initializing the batch size for distributed training:

args.world_size = accelerator.num_processes
args.batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
args.rebel.num_updates = args.total_episodes // args.batch_size
  • args.world_size is the number of GPUs (processes) we are using.
  • args.local_batch_size is the number of examples each GPU processes per optimizer update, summed over its gradient-accumulation steps.
  • args.batch_size is the effective global batch size, i.e. the number of examples consumed per optimizer update across all GPUs.
  • args.rebel.num_updates is the total number of optimizer updates, and args.total_episodes is the number of training examples to process. Typically, we set args.total_episodes to the size of the training set, which corresponds to one epoch.
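As a worked example with hypothetical settings (8 GPUs, per_device_train_batch_size=4, gradient_accumulation_steps=8, and a 60,000-example training set; none of these values come from the repo's configs), the same formulas can be evaluated with plain integers. The function batch_sizes is a hypothetical stand-in for the args bookkeeping.

```python
def batch_sizes(num_processes, per_device_train_batch_size, gradient_accumulation_steps, total_episodes):
    """Mirror the batch-size bookkeeping above with plain integers."""
    world_size = num_processes
    batch_size = world_size * per_device_train_batch_size * gradient_accumulation_steps
    local_batch_size = per_device_train_batch_size * gradient_accumulation_steps
    num_updates = total_episodes // batch_size  # floor: a trailing partial batch is dropped
    return world_size, batch_size, local_batch_size, num_updates


print(batch_sizes(8, 4, 8, 60_000))  # (8, 256, 32, 234)
```

Here each update consumes 256 examples (32 per GPU), giving 234 updates; the last 96 examples of the epoch do not fill a full batch and are not used.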

Next, we load the model and tokenizer, ensuring dropout layers are disabled:

tokenizer = AutoTokenizer.from_pretrained(
                args.base_model, 
                padding_side='right',
                trust_remote_code=True,
            )
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
policy = AutoModelForCausalLM.from_pretrained(
            args.base_model,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        )
disable_dropout_in_model(policy)

Step 2: Precomputing Log Probabilities

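To facilitate hyperparameter tuning (e.g., of the learning rate and $\eta$), we precompute and save the log probabilities of $\pi_{\theta_t}$ once. They depend only on the fixed model from the previous iteration, so every tuning run can reuse them. The first time the following code runs, the try block fails because the input_repo + '_logprob' dataset, which holds the chosen_logprob and reject_logprob columns, does not exist yet; the except branch then loads the Part 3 dataset and sets compute_log = True: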

# Columns produced in Part 3, plus the cached reference logprobs added in this step
base_columns = ["llama_prompt_tokens",
                "llama_chosen_tokens", "chosen_reward",
                "llama_reject_tokens", "reject_reward"]
logprob_columns = base_columns + ["chosen_logprob", "reject_logprob"]

compute_log = False
try:
    # Reuse the logprobs pushed by an earlier run, if they exist
    dataset = load_dataset(args.task.input_repo + '_logprob', split='train')
    dataset = dataset.with_format("torch", columns=logprob_columns)
    temp_dataloader = DataLoader(dataset, batch_size=args.local_batch_size, shuffle=True)
    validation_dataset = load_dataset(args.task.input_repo + '_logprob', split='test')
    validation_dataset = validation_dataset.with_format("torch", columns=logprob_columns)
except Exception:
    # First run: fall back to the Part 3 dataset and compute the logprobs below
    dataset = load_dataset(args.task.input_repo, split='train')
    dataset = dataset.with_format("torch", columns=base_columns)
    temp_dataloader = DataLoader(dataset, batch_size=args.local_batch_size, shuffle=True)
    validation_dataset = load_dataset(args.task.input_repo, split='test')
    validation_dataset = validation_dataset.with_format("torch", columns=base_columns)
    compute_log = True

Here, the temporary dataloader temp_dataloader exists only so that the model can be prepared with accelerator.prepare before the logprobs are computed, which enables multi-GPU inference.

Now we can compute and save the logprobs:

def gather_logprob(args, model, tokenizer, query, response, device):

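    # Concatenate the left-padded prompt and right-padded response into one sequence
    query_response = torch.cat((query, response), dim=-1).long().to(device).unsqueeze(0)
    response = response.long().to(device).unsqueeze(0)
    attention_mask = query_response != tokenizer.pad_token_id
    # Replace [PAD] with EOS before the forward pass: the added pad id lies outside the
    # base model's embedding table, and padded positions are masked out anyway
    input_ids = torch.masked_fill(query_response, ~attention_mask, tokenizer.eos_token_id)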
    with torch.no_grad():
        output = model(
                    input_ids=input_ids, 
                    attention_mask=attention_mask,
                    return_dict=True,
                 )
        logits = output.logits[:, args.task.maxlen_prompt - 1 : -1]
        logits /= args.task.temperature + 1e-7
        all_logprob = F.log_softmax(logits, dim=-1)
        logprob = torch.gather(all_logprob, 2, input_ids[:, args.task.maxlen_prompt:].unsqueeze(-1)).squeeze(-1)
        # Index of the last non-pad response token (first_true_indices is a helper from the repo)
        sequence_length = first_true_indices(response == tokenizer.pad_token_id) - 1
        seq_mask = torch.arange(args.task.maxlen, device=device).unsqueeze(0).expand_as(response) <= sequence_length.unsqueeze(1)
        
        return (logprob * seq_mask).sum(-1)


def gather_all_logprob(args, process_idx, policy, tokenizer, dataset, device):

    batch_size = len(dataset) // args.world_size + 1
    start_idx = batch_size * process_idx

    # make batch size same for accelerator.gather
    if start_idx + batch_size > len(dataset):
        start_idx = len(dataset) - batch_size

    chosen_logprob, reject_logprob, index = [], [], []

    with torch.no_grad():
        for i in tqdm(range(start_idx, start_idx + batch_size)):

            chosen_logprob.append(gather_logprob(args, policy, tokenizer, dataset[i]["llama_prompt_tokens"], dataset[i]["llama_chosen_tokens"], device))
            reject_logprob.append(gather_logprob(args, policy, tokenizer, dataset[i]["llama_prompt_tokens"], dataset[i]["llama_reject_tokens"], device))
            index.append(i)

        chosen_logprob = torch.cat(chosen_logprob)
        reject_logprob = torch.cat(reject_logprob)
        index = torch.LongTensor(index).to(device)

    chosen_logprob = accelerator.gather(chosen_logprob).cpu().tolist()
    reject_logprob = accelerator.gather(reject_logprob).cpu().tolist()
    index = accelerator.gather(index).cpu().tolist()

    chosen_logprobs = [0] * len(dataset)
    reject_logprobs = [0] * len(dataset)

    for i, data_i in enumerate(index):
        chosen_logprobs[data_i] = chosen_logprob[i]
        reject_logprobs[data_i] = reject_logprob[i]
        
    return chosen_logprobs, reject_logprobs

if compute_log:
    accelerator.print('gathering validation logprob')
    chosen_logprob, reject_logprob = gather_all_logprob(args, accelerator.process_index, accelerator.unwrap_model(policy), tokenizer, validation_dataset, device)
    validation_dataset = validation_dataset.add_column("chosen_logprob", chosen_logprob)
    validation_dataset = validation_dataset.add_column("reject_logprob", reject_logprob)
    validation_dataset = validation_dataset.with_format("torch", columns=["llama_prompt_tokens", 
                                                                            "llama_chosen_tokens", "chosen_reward", "chosen_logprob",
                                                                            "llama_reject_tokens", "reject_reward", "reject_logprob"])

    accelerator.print('gathering logprob')
    chosen_logprob, reject_logprob = gather_all_logprob(args, accelerator.process_index, accelerator.unwrap_model(policy), tokenizer, dataset, device)
    dataset = dataset.add_column("chosen_logprob", chosen_logprob)
    dataset = dataset.add_column("reject_logprob", reject_logprob)
    dataset = dataset.with_format("torch", columns=["llama_prompt_tokens", 
                                                    "llama_chosen_tokens", "chosen_reward", "chosen_logprob",
                                                    "llama_reject_tokens", "reject_reward", "reject_logprob"])
    if accelerator.is_main_process:
        temp = DatasetDict({
            "train" : dataset,
            "test"  : validation_dataset,
        })
        temp.push_to_hub(args.task.input_repo + '_logprob')
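  • gather_logprob computes the sequence-level log probability $\ln \pi_{\theta_t}(y|x)$ of one response given its prompt, summing the per-token log probabilities over the non-padding response tokens.
  • gather_all_logprob parallelizes this computation by giving each of the args.world_size processes an equally sized segment of len(dataset) // args.world_size + 1 rows. Any segment that would run past the end of the dataset is shifted back so that it ends at the last row, so trailing segments may overlap; overlapping rows are simply computed twice and written to the same index. Equal segment sizes ensure that accelerator.gather, which concatenates tensors from all processes, works properly.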
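The partitioning rule in gather_all_logprob can be checked in isolation. The pure-Python sketch below (segment_ranges is a hypothetical name) reproduces its index arithmetic and shows that every process gets a segment of the same size and that the segments jointly cover every row, with a few rows computed twice.

```python
def segment_ranges(num_rows, world_size):
    """Reproduce gather_all_logprob's per-process [start, end) index ranges."""
    batch_size = num_rows // world_size + 1
    ranges = []
    for process_idx in range(world_size):
        start = batch_size * process_idx
        # Shift any segment that would run past the end so it ends at the last row
        if start + batch_size > num_rows:
            start = num_rows - batch_size
        ranges.append((start, start + batch_size))
    return ranges


ranges = segment_ranges(num_rows=10, world_size=4)
print(ranges)
```

With 10 rows and 4 processes, each segment has 3 rows and the last one is shifted back to [7, 10), so rows 7 and 8 are processed by two different processes.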

Step 3: Training

Looking again at the REBEL objective, the only quantity still missing for training is $\ln \pi_\theta(y|x)$ under the current policy. We can compute it with:

output = policy(
    input_ids=input_ids, 
    attention_mask=attention_mask,
    return_dict=True,
    output_hidden_states=True,
)
logits = output.logits[:, args.task.maxlen_prompt - 1 : -1]
logits /= args.task.temperature + 1e-7
new_all_logprobs = F.log_softmax(logits, dim=-1)
new_logprobs = torch.gather(new_all_logprobs, 2, input_ids[:, args.task.maxlen_prompt:].unsqueeze(-1)).squeeze(-1)
new_logprobs = (new_logprobs * mb_seq_mask).sum(-1)
  • output.logits contains the logits over the whole vocabulary for every position in input_ids.
  • output.logits[:, args.task.maxlen_prompt - 1 : -1] keeps only the logits that score the response tokens. The slice is shifted by one because the logits at position $p$ predict the token at position $p+1$.
  • We divide the logits by args.task.temperature so that the log probabilities correspond to the temperature-scaled distribution used for sampling.
  • torch.gather picks out the log probability of each actual response token.
  • mb_seq_mask masks out the padding positions, so the sum runs only over real response tokens.
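To see the shift-by-one, gather, and masking steps without a GPU, here is a pure-Python sketch on a toy three-token vocabulary. sequence_logprob is a hypothetical helper: logits are a list of per-position score lists, and the response is assumed to start at index maxlen_prompt, as guaranteed by the padding scheme from Part 3.

```python
import math


def log_softmax(row):
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def sequence_logprob(logits, input_ids, maxlen_prompt, response_mask, temperature=1.0):
    """Sum of log p(token) over unmasked response tokens.

    logits[p] scores the token at position p + 1, so the response tokens at
    positions maxlen_prompt .. len-1 are scored by logits[maxlen_prompt - 1 : -1].
    """
    shifted = logits[maxlen_prompt - 1 : -1]
    targets = input_ids[maxlen_prompt:]
    total = 0.0
    for row, token, keep in zip(shifted, targets, response_mask):
        if keep:  # padding positions contribute nothing
            total += log_softmax([v / temperature for v in row])[token]
    return total


toy_logits = [
    [0.0, 0.0, 0.0],  # scores position 1 (a prompt token, not used)
    [1.0, 0.0, 2.0],  # scores position 2 (the first response token)
    [0.0, 3.0, 0.0],  # scores position 3 (padding, masked out)
    [0.0, 0.0, 0.0],  # last position scores beyond the sequence, dropped
]
toy_ids = [0, 1, 2, 0]  # 2 prompt tokens, then 1 response token and 1 pad
print(sequence_logprob(toy_logits, toy_ids, maxlen_prompt=2, response_mask=[True, False]))
```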

Step 4: Loss Computation

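Finally, we can compute the loss. Each minibatch stacks the chosen responses in its first args.per_device_train_batch_size rows and the rejected responses in the remaining rows, and mb_logprobs holds the matching precomputed log probabilities of $\pi_{\theta_t}$ from Step 2: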

ratio_logprob = new_logprobs - mb_logprobs
ratio_logprob = ratio_logprob[:args.per_device_train_batch_size] - ratio_logprob[args.per_device_train_batch_size:]
reg_diff = ratio_logprob - args.rebel.eta * (mb_chosen_reward - mb_reject_reward)
loss = (reg_diff ** 2).mean()
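Multiplying the term inside the square of the REBEL objective by $\eta$ (args.rebel.eta) gives exactly reg_diff above; this rescales each squared error by $\eta^2$ and leaves the minimizer unchanged. The pure-Python sketch below checks the equivalence for one (x, y, y') pair with made-up log probabilities and rewards; rebel_pair_loss and paper_objective_term are hypothetical names.

```python
def rebel_pair_loss(new_chosen, ref_chosen, new_reject, ref_reject,
                    chosen_reward, reject_reward, eta):
    """Per-pair loss as computed in the training code (already scaled by eta)."""
    # ln pi_theta(y|x)/pi_theta_t(y|x) - ln pi_theta(y'|x)/pi_theta_t(y'|x)
    ratio = (new_chosen - ref_chosen) - (new_reject - ref_reject)
    reg_diff = ratio - eta * (chosen_reward - reject_reward)
    return reg_diff ** 2


def paper_objective_term(new_chosen, ref_chosen, new_reject, ref_reject,
                         chosen_reward, reject_reward, eta):
    """The summand of the REBEL objective as written in the equation."""
    ratio = (new_chosen - ref_chosen) - (new_reject - ref_reject)
    return (ratio / eta - (chosen_reward - reject_reward)) ** 2


example = dict(new_chosen=-120.0, ref_chosen=-125.0, new_reject=-140.0, ref_reject=-138.0,
               chosen_reward=0.15, reject_reward=0.05, eta=1e-3)
# The two values differ only by the constant factor eta ** 2
print(rebel_pair_loss(**example), paper_objective_term(**example))
```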

Performance

A single iteration of the four parts above substantially improves the base model on AlpacaEval 2.0 and ArenaHard, while its MT-Bench score stays about the same:

| Model | AlpacaEval 2.0 LC Win Rate | AlpacaEval 2.0 Win Rate | MT-Bench Average | ArenaHard |
|---|---|---|---|---|
| Llama-3-8B-it | 22.9 | 22.6 | 8.10 | 22.3 |
| REBEL-Llama-3-Armo-iter_1 | 48.3 | 41.8 | 8.13 | 34.5 |

Takeaway

In this post, we outlined the pipeline for RLHF, covering the entire process from data generation to the actual training phase. While we focused specifically on the REBEL algorithm, this pipeline is versatile and can be readily adapted to other methods such as DPO or SimPO. The necessary components for these methods are already included except for the specific loss formulation.
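As an illustration of how only the loss would change, here is a pure-Python sketch of the standard DPO loss for one pair, written in terms of the same four sequence log probabilities used by REBEL. This is a swapped-in technique, not part of the REBEL repo; dpo_pair_loss and beta are hypothetical names, and the sketch assumes the precomputed log probabilities from Step 2 serve as DPO's reference model. The reward values themselves are unused; only which response was chosen matters.

```python
import math


def dpo_pair_loss(new_chosen, ref_chosen, new_reject, ref_reject, beta=0.1):
    """-log sigmoid(beta * (chosen log-ratio - reject log-ratio))."""
    margin = beta * ((new_chosen - ref_chosen) - (new_reject - ref_reject))
    # -log(sigmoid(m)) == log(1 + exp(-m)); branch on the sign to avoid overflow
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


print(dpo_pair_loss(-120.0, -125.0, -140.0, -138.0))
```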

If you find this implementation useful, please consider citing our work:

@misc{gao2024rebel,
      title={REBEL: Reinforcement Learning via Regressing Relative Rewards}, 
      author={Zhaolin Gao and Jonathan D. Chang and Wenhao Zhan and Owen Oertell and Gokul Swamy and Kianté Brantley and Thorsten Joachims and J. Andrew Bagnell and Jason D. Lee and Wen Sun},
      year={2024},
      eprint={2404.16767},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}