Over the last few years, the scaling of train-time compute has dominated the progress of large language models (LLMs).
Although we don’t know how o1 was trained, recent research from DeepMind
Over the past months we’ve been diving deep into reverse engineering and reproducing several of these results, and we’re finally happy to share some of what we’ve learned. More precisely, in this blog post we’ll cover:
So how well does compute-optimal scaling work in practice? Check out this plot where the tiny 1B and 3B Llama Instruct models outperform their much larger 8B and 70B siblings on the challenging MATH-500 benchmark if you give them enough “time to think” 🤯:
In the rest of this blog post, we’ll dive deep into the ingredients behind results like this one and walk you through practical strategies for implementing test-time compute scaling.
There are two main strategies for scaling test-time compute:
In this blog post, we’ll concentrate on search-based methods as they represent a practical and scalable solution for test-time compute optimization. In particular, we’ll examine the three strategies illustrated below:
With an understanding of the key search strategies, let’s move on to how we evaluated them in practice.
As illustrated in the diagram above, our experimental setup involves a pipeline with the following steps:
To compare various search strategies, we used the following open models and datasets:
- meta-llama/Llama-3.2-1B-Instruct as our primary model for scaling test-time compute. With 1B parameters, its lightweight nature enables fast iterations, and its unsaturated performance on math benchmarks makes it an ideal choice for highlighting the benefits of scaling.
- RLHFlow/Llama3.1-8B-PRM-Deepseek-Data, an 8B reward model that has been trained using process supervision. Process supervision is a training approach where models receive feedback on each step of their reasoning process, not just the final outcome. We picked this model since it belongs to the same model family as our policy and gave better results than the other PRMs we tested in this weight class, such as Math-Shepherd.

We tested each search strategy across compute budgets ranging from 1 to 256 generations per prompt and ran the data-generation pipeline with five random seeds to estimate variance across runs. You can find the models and datasets from our analysis in this Hugging Face collection.
To warm up, we’ll begin with a simple baseline and progressively incorporate additional techniques to improve performance.
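Majority voting (or self-consistency decoding) is the simplest way to aggregate LLM outputs: given \(N\) candidate solutions to a math problem, we pick the most frequent final answer.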
One quirk with the MATH benchmark is that answers must be formatted in a LaTeX box like \boxed{answer}. We initially tried the following simple system prompt for Llama 3.2 1B:
Please think step by step and put your final answer within \boxed{}.
but found the resulting accuracy with greedy decoding (\(T=0\)) to be far worse than the 30.6% that Meta reported in their release. Luckily, Meta also published the prompts they used for their evals and switching our system prompt to theirs made all the difference:
Solve the following math problem efficiently and clearly:
- For simple problems (2 steps or fewer):
Provide a concise solution with minimal explanation.
- For complex problems (3 steps or more):
Use this step-by-step format:
## Step 1: [Concise description]
[Brief explanation and calculations]
## Step 2: [Concise description]
[Brief explanation and calculations]
...
Regardless of the approach, always conclude with:
Therefore, the final answer is: $\boxed{answer}$. I hope it is correct.
Where [answer] is just the final number or expression that solves the problem.
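Because the graded answer lives inside \boxed{...} and may itself contain nested braces (e.g. \boxed{\frac{1}{2}}), a plain regular expression is not enough to pull it out of a completion. The following is a minimal sketch of a brace-matching extractor; `extract_boxed_answer` is a hypothetical helper, not part of any library, and it assumes the last \boxed{...} in the completion holds the final answer.

```python
from typing import Optional


def extract_boxed_answer(completion: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in `completion`, or None.

    Braces are matched by tracking nesting depth, so nested LaTeX such as
    \\boxed{\\frac{1}{2}} is extracted in full.
    """
    marker = "\\boxed{"
    start = completion.rfind(marker)
    if start == -1:
        return None
    i = start + len(marker)  # first character inside the box
    depth = 1
    for j in range(i, len(completion)):
        ch = completion[j]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return completion[i:j]
    return None  # unbalanced braces: no complete answer found
```

For example, applied to "Therefore, the final answer is: $\boxed{\frac{1}{2}}$. I hope it is correct." it returns \frac{1}{2}.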
One subtlety with evaluating answers to math problems is that strings like \(1/\sqrt{3}\) and \(\sqrt{3}/3\) are distinct but represent mathematically equivalent answers. The standard way to handle this is to convert both answers to SymPy objects and check whether applying sympy.simplify to their difference gives zero.
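As a minimal illustration of this pairwise check, the sketch below skips LaTeX parsing and assumes both answers are already written in SymPy’s own syntax, so it can use `sympy.sympify` directly; `are_equivalent` is a hypothetical helper name.

```python
from sympy import simplify, sympify


def are_equivalent(a: str, b: str) -> bool:
    """Check whether two answers (in SymPy syntax) are mathematically equal.

    Both strings are parsed with sympify, and their difference is simplified;
    a result of zero means the expressions are equivalent.
    """
    return simplify(sympify(a) - sympify(b)) == 0


print(are_equivalent("1/sqrt(3)", "sqrt(3)/3"))  # True
print(are_equivalent("1/sqrt(3)", "sqrt(3)/2"))  # False
```

Note that simplify is heuristic, so some equivalent pairs may not reduce to zero, and every comparison pays the full cost of simplification.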
While this approach works well when comparing a small number of candidate answers, we found it was terribly slow when comparing many pairs in a list of \(N\) candidates; in some cases, slower than generating the candidates in the first place! To deal with this, we first reduced each answer to its canonical form and then counted the frequency of each form to determine the majority vote.
To obtain the canonical form of an algebraic expression, we first convert the LaTeX string to SymPy, apply sympy.simplify, and finally convert back to LaTeX:
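from latex2sympy2 import latex2sympy
from sympy import latex, simplify

def get_canonical_form(expression: str) -> str:
    # Parse the LaTeX answer into a SymPy expression
    parsed_expr = latex2sympy(expression)
    # Simplify so that equivalent answers map to the same expression
    simplified_expr = simplify(parsed_expr)
    # Convert back to LaTeX so the result can be used as a dictionary key
    return latex(simplified_expr)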
With this function, we can then iterate over all candidate solutions in a list and keep track of how many times each canonical form has been seen before computing the final majority vote:
from collections import defaultdict

def find_majority_answer(answers: list[str]) -> str:
    canonical_groups = defaultdict(int)
    canonical_to_original = {}

    for answer in answers:
        canonical_form = get_canonical_form(answer)

        # Increment count for the canonical form
        canonical_groups[canonical_form] += 1

        # Track the original answer for this canonical form
        if canonical_form not in canonical_to_original:
            canonical_to_original[canonical_form] = answer

    # Find the canonical form with the largest count
    max_count = max(canonical_groups.values())
    for canonical_form, count in canonical_groups.items():
        if count == max_count:
            # Return the first occurring group in case of a tie
            return canonical_to_original[canonical_form]
This approach was significantly faster than checking each pair of solutions independently for equality.
Here’s how majority voting performs when applied to the generations from Llama 3.2 1B Instruct:
The results show that majority voting yields a significant improvement over the greedy decoding baseline, but its gains start to plateau after approximately \(N=64\) generations. This limitation arises because majority voting struggles with problems that require nuanced reasoning or tasks where errors are consistent across generations. If you’re also wondering why the majority voting accuracy is worse than the 0-shot CoT baseline for \(N=1\) and \(2\), that’s because we sample at \(T=0.8\), which makes it less likely we produce the correct answer among a handful of candidates.
Building on the limitations of majority voting, let’s see how incorporating a reward model can enhance performance.
Best-of-N is a simple but effective extension to majority voting that uses a reward model to determine the most plausible answer. This method comes in two main variants:
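- Vanilla Best-of-N: generate \(N\) independent responses, score each one with the reward model, and select the one with the highest score.
- Weighted Best-of-N: aggregate the scores of all responses that share the same final answer, i.e. the same expression in \boxed{answer}, and select the answer with the highest total reward.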
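The difference between vanilla and weighted Best-of-N can be made concrete with a short sketch. It assumes each candidate has already been reduced to a (final answer, reward) pair, for example by extracting the boxed answer, canonicalizing it, and scoring the solution with the reward model; the function names are hypothetical. One way to write the weighted selection is $$\hat{a} = \underset{a}{\arg\max} \sum_{i=1}^{N} \mathbb{1}_{a_i = a} \cdot \text{RM}(p, s_i),$$ where \(s_i\) is the \(i\)-th solution to problem \(p\), \(a_i\) its final answer, and \(\text{RM}(p, s_i)\) its reward.

```python
from collections import defaultdict


def vanilla_best_of_n(candidates: list[tuple[str, float]]) -> str:
    """Return the answer of the single highest-reward candidate."""
    answer, _ = max(candidates, key=lambda pair: pair[1])
    return answer


def weighted_best_of_n(candidates: list[tuple[str, float]]) -> str:
    """Sum rewards over candidates with the same answer and return the best answer."""
    totals: dict[str, float] = defaultdict(float)
    for answer, reward in candidates:
        totals[answer] += reward
    return max(totals, key=totals.get)


candidates = [("1/2", 0.9), ("1/3", 0.6), ("1/3", 0.5)]
print(vanilla_best_of_n(candidates))   # 1/2
print(weighted_best_of_n(candidates))  # 1/3
```

In this toy example, a single high-scoring answer wins under the vanilla variant, while two moderately scored copies of the same answer win under the weighted one.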
Typically, one uses an outcome reward model (ORM) to get a single, solution-level score. But to allow for a fair comparison with the other search strategies discussed later, we use the same PRM to score the solutions from Best-of-N. As illustrated below, PRMs produce a cumulative sequence of step-level scores per solution, so we need to perform a reduction over the steps to obtain a single solution-level score:
In the literature, the most common reductions are the following:
We experimented with each reduction and found, like DeepMind, that “last” performs best for our choice of task and PRM. We use this aggregation throughout all of our experiments.
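Commonly used reductions include taking the minimum step score, the product of the step scores, or the score of the last step. The sketch below implements these three options over a list of step-level PRM scores; the function name and the reduction labels are our own choices for illustration, not the authors’ implementation.

```python
import math
from typing import Literal

Reduction = Literal["min", "prod", "last"]


def aggregate_step_scores(step_scores: list[float], reduction: Reduction = "last") -> float:
    """Reduce per-step PRM scores (each in [0, 1]) to one solution-level score."""
    if not step_scores:
        raise ValueError("need at least one step score")
    if reduction == "min":
        # A solution is only as good as its weakest step
        return min(step_scores)
    if reduction == "prod":
        # Treat step scores as independent probabilities of correctness
        return math.prod(step_scores)
    if reduction == "last":
        # The last step's score is conditioned on the whole solution
        return step_scores[-1]
    raise ValueError(f"unknown reduction: {reduction}")
```

For step scores [0.9, 0.5, 0.8], the three reductions give 0.5, 0.36 and 0.8 respectively, which shows how strongly the choice can reorder candidates.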
Here are the results from applying both variants of Best-of-N:
The results reveal a clear advantage: weighted Best-of-N consistently outperforms vanilla Best-of-N, especially with larger generation budgets. Its ability to aggregate scores across identical responses ensures that even less frequent but higher-quality answers are effectively prioritized.
However, despite these improvements, we’re still falling short of the performance achieved by the Llama 8B model and the Best-of-N approach is starting to plateau at \(N=256\) generations. Can we push the boundaries further by supervising the search process step-by-step? Let’s find out 🚀!
Beam search is a structured search method that systematically explores the solution space, making it a powerful tool for improving model outputs at test-time. When combined with a PRM, beam search can optimize both the generation and evaluation of intermediate steps in problem-solving. The way it works is as follows:
Each step is delimited by a stopping criterion such as a new line \n or a double new line \n\n.

By allowing the PRM to evaluate the correctness of intermediate steps, beam search can identify and prioritize promising paths early in the process. This step-by-step evaluation is particularly beneficial for complex reasoning tasks like mathematics, where verifying partial solutions can significantly improve final outcomes.
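To make the procedure concrete, here is a minimal, self-contained sketch of PRM-guided beam search. It is not the authors’ implementation: `generate_step`, `score`, and `is_done` are hypothetical stand-ins for the policy, the PRM, and a completion check, and the sketch assumes a common scheme where the top \(N/M\) partial solutions are kept at each iteration and each is expanded with \(M\) sampled next steps. The default values are illustrative only.

```python
import random
from typing import Callable

# Hypothetical interfaces (not a real library API):
#   generate_step(prefix, rng) -> next reasoning step proposed by the policy
#   score(steps) -> PRM score of a partial solution (higher is better)
#   is_done(steps) -> True once a solution is complete (e.g. it contains \boxed{})
GenerateFn = Callable[[list[str], random.Random], str]
ScoreFn = Callable[[list[str]], float]
DoneFn = Callable[[list[str]], bool]


def beam_search(
    generate_step: GenerateFn,
    score: ScoreFn,
    is_done: DoneFn,
    n: int = 8,
    m: int = 4,
    max_iterations: int = 10,
    seed: int = 0,
) -> list[str]:
    """PRM-guided beam search over reasoning steps.

    n is the number of active paths per iteration (the budget N) and
    m is the number of continuations sampled from each kept path.
    """
    rng = random.Random(seed)
    # First iteration: sample n independent first steps
    beams = [[generate_step([], rng)] for _ in range(n)]
    for _ in range(max_iterations - 1):
        # Keep the n // m highest-scoring partial solutions
        kept = sorted(beams, key=score, reverse=True)[: max(1, n // m)]
        if all(is_done(b) for b in kept):
            break
        beams = []
        for b in kept:
            if is_done(b):
                beams.append(b)  # completed solutions are carried over unchanged
            else:
                # Expand each unfinished path with m sampled next steps
                beams.extend(b + [generate_step(b, rng)] for _ in range(m))
    return max(beams, key=score)
```

With toy callables, for instance a `generate_step` that samples from a fixed list of strings and a `score` that counts "good" steps, the function runs as is; a real setup would sample from the LLM until the next step delimiter and score the concatenated steps with the PRM.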
When we implemented beam search with process supervision, we encountered two major footguns with the Llama 3 chat template that are worth mentioning:
Since we use \n or \n\n to terminate a step, and the template trims whitespace from message content, these tokens are lost on subsequent steps and force the model to produce peculiar outputs.

The solution is to overwrite the Llama 3 chat template to prevent trimming and exclude the BOS token prefix.
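The sketch below illustrates the idea without a Jinja template. The hypothetical helper `render_for_continuation` writes out Llama-3-style turn markers by hand, leaves message content untrimmed so step delimiters survive, and emits no BOS token, on the assumption that the tokenizer adds it once when encoding. In practice the same effect would come from passing a modified template to the tokenizer, for example through the `chat_template` argument of `apply_chat_template` in transformers.

```python
# Llama 3 turn format: header tokens around the role, then content, then <|eot_id|>
HEADER = "<|start_header_id|>{role}<|end_header_id|>\n\n"
EOT = "<|eot_id|>"


def render_for_continuation(messages: list[dict[str, str]]) -> str:
    """Hypothetical helper: build a Llama-3-style prompt for step-wise generation.

    - Message content is NOT trimmed, so "\\n\\n" step delimiters survive.
    - No BOS token is added here; the tokenizer is assumed to add it once.
    - If the last message is a partial assistant solution, its turn is left
      open (no <|eot_id|>) so the model continues it.
    """
    parts = []
    for i, msg in enumerate(messages):
        parts.append(HEADER.format(role=msg["role"]) + msg["content"])
        is_partial_answer = i == len(messages) - 1 and msg["role"] == "assistant"
        if not is_partial_answer:
            parts.append(EOT)
    if messages[-1]["role"] != "assistant":
        # Open a new assistant turn for the first step
        parts.append(HEADER.format(role="assistant"))
    return "".join(parts)
```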
In our experiments, we followed DeepMind’s hyperparameter choices
As shown below, the results are striking: with a test-time budget of \(N=4\), beam search achieves the same accuracy as Best-of-N for \(N=16\), i.e. it is 4x more compute efficient! Similarly, with \(N=16\), beam search achieves the same accuracy as Best-of-N for \(N=256\), making it 16x more compute efficient at larger \(N\). Moreover, beam search matches the performance of Llama 3.1 8B with just \(N=32\) solutions per problem. The average performance on MATH by computer science PhD students is around 40%, so reaching nearly 55% isn’t too bad for a 1B model 💪!
Although in aggregate it is clear that beam search is a better search strategy than Best-of-N or majority voting, the DeepMind paper showed that each strategy has tradeoffs that depend on the problem difficulty and test-time compute budget.
To see which problems are best suited for which strategy, DeepMind computed a distribution over estimated problem difficulty, and then binned the results into quintiles. In other words, each problem is assigned one of 5 levels, where level 1 indicates easier problems and level 5 indicates the hardest ones. To estimate problem difficulty, DeepMind generated 2048 candidate solutions with standard sampling per problem and then proposed the following heuristics:
The pass@k metric measures the probability, computed over a set of problems, that at least one of the top \(k\) generated outputs for each problem contains the correct solution. In practice, computing pass@k naively leads to high variance; for example, if we compute pass@1 from a single completion per problem, we can get significantly different values from repeated evaluations due to sampling. To combat this, OpenAI's Codex paper
However, computing the estimator directly suffers from numerical instabilities, so in practice one uses the following simplified form:
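For reference, the unbiased estimator from the Codex paper draws \(n \geq k\) samples per problem, counts the number \(c\) of correct ones, and computes $$\text{pass@}k = \mathbb{E}_{\text{problems}}\left[1 - \frac{\binom{n-c}{k}}{\binom{n}{k}}\right],$$ while the numerically stable form expands the ratio of binomial coefficients into a product: $$1 - \frac{\binom{n-c}{k}}{\binom{n}{k}} = 1 - \prod_{i=n-c+1}^{n}\left(1 - \frac{k}{i}\right).$$ The sketch below implements the per-problem estimate and then, as an assumption about one simple way to form the quintiles, ranks problems by estimated pass@1 and splits them into five equally sized levels (level 1 = easiest). Both helper names are hypothetical.

```python
import math


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate for one problem with c correct out of n samples."""
    if n - c < k:
        # Every size-k subset of the samples contains a correct one
        return 1.0
    # Numerically stable product form of 1 - C(n-c, k) / C(n, k)
    return 1.0 - math.prod(1.0 - k / i for i in range(n - c + 1, n + 1))


def bin_by_difficulty(pass_at_1: dict[str, float], n_levels: int = 5) -> dict[str, int]:
    """Assign each problem a level from 1 (easiest) to n_levels (hardest)."""
    # Rank problems from highest to lowest estimated pass@1
    ranked = sorted(pass_at_1, key=pass_at_1.__getitem__, reverse=True)
    return {pid: rank * n_levels // len(ranked) + 1 for rank, pid in enumerate(ranked)}
```

For \(k = 1\) the product telescopes to \(c/n\), so with 512 correct answers out of 2048 samples, pass_at_k(2048, 512, 1) is 0.25 up to floating-point error.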
Here’s the breakdown of the various methods according to the pass@1 scores and across four test-time compute budgets of \(N \in \{4, 16, 64, 256\}\):
In this plot, each bar denotes a test-time compute budget, and within each bar we show the relative accuracy of each method. For example, in the group of four bars on difficulty level 2 we see that:
Although we see that beam search gives consistent gains in the medium and hard problems (levels 3-5), it tends to do worse than Best-of-N (and even majority voting!) on the simpler problems and especially at large compute budgets.
Looking at the trees produced by beam search, we realized that if a single step is assigned a high reward, the whole tree collapses to that trace, which reduces diversity. This prompted us to explore an extension to beam search that maximises diversity. Let’s take a look!
As we saw above, beam search gives strong performance over Best-of-N, but tends to underperform on simpler problems and at large test-time compute budgets. To address this, we developed an extension we call Diverse Verifier Tree Search (DVTS) that is designed to maximise diversity at large \(N\).
DVTS works in a similar fashion to beam search, with the following modifications:
Here are the results from applying DVTS to Llama 1B:
As we can see, DVTS provides a complementary strategy to beam search: at small \(N\) beam search is more effective at finding correct solutions, but at large \(N\) the diversity of DVTS candidates kicks in and we get better performance.
We can also see this manifested in the problem difficulty breakdown, where DVTS enhances performance on the easy / medium problems at large \(N\), while beam search is best at small \(N\) across all problem difficulties:
With various search strategies in hand, a natural question is: which one is best? In the DeepMind paper, they proposed a compute-optimal scaling strategy where one selects the search method and hyperparameters \(\theta\) that achieve the best performance for a given compute budget \(N\): $$\theta_{q,a^*(q)}^*(N) = \underset{\theta}{\arg\max} \left( \mathbb{E}_{y \sim \text{Target}(\theta, N, q)} \left[ \mathbb{1}_{y = y^*(q)} \right] \right),$$ where \(y^*(q)\) is the ground truth for question \(q\) and \(\theta_{q,a^*(q)}^*(N)\) denotes the compute-optimal scaling strategy. Since computing \(\theta_{q,a^*(q)}^*(N)\) directly is somewhat tricky, DeepMind proposed an approximation based on the problem difficulty, i.e. allocate test-time compute according to which search strategy achieves the best performance for a given difficulty level.
For example, on simpler problems and lower compute budgets, it is better to use strategies like Best-of-N, while on harder problems, beam search is the better choice. To implement this, for each method we compute the accuracy for a given difficulty level and test-time compute budget. And voila, we now have our compute-optimal curve!
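The selection step itself is an argmax over methods for each (difficulty level, budget) cell. A minimal sketch, assuming the accuracies have already been measured and stored in a dictionary keyed by (method, level, budget); this data layout and both function names are hypothetical.

```python
def compute_optimal(
    accuracy: dict[tuple[str, int, int], float],
) -> dict[tuple[int, int], tuple[str, float]]:
    """For each (difficulty level, budget N), pick the method with the best accuracy.

    `accuracy` maps (method, level, budget) -> accuracy on problems of that level.
    Returns (level, budget) -> (best method, its accuracy).
    """
    best: dict[tuple[int, int], tuple[str, float]] = {}
    for (method, level, budget), acc in accuracy.items():
        cell = (level, budget)
        if cell not in best or acc > best[cell][1]:
            best[cell] = (method, acc)
    return best


def compute_optimal_curve(
    best: dict[tuple[int, int], tuple[str, float]],
) -> dict[int, float]:
    """Overall accuracy per budget, assuming equally sized difficulty levels."""
    per_budget: dict[int, list[float]] = {}
    for (_, budget), (_, acc) in best.items():
        per_budget.setdefault(budget, []).append(acc)
    return {budget: sum(accs) / len(accs) for budget, accs in sorted(per_budget.items())}
```

Averaging the per-level winners gives one point on the compute-optimal curve per budget; weighting the levels equally relies on the quintiles containing the same number of problems.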
We also explored scaling up the compute-optimal recipe to Llama 3.2 3B Instruct to see at what point the benefits of the PRM fade in comparison to the policy’s own capacity. To our surprise, compute-optimal scaling works remarkably well, with the 3B model surpassing the performance of Llama 3.1 70B Instruct (22x its size!):
This exploration of test-time compute scaling has revealed both the potential and the challenges of leveraging search-based methods. As we look ahead, several exciting directions emerge:
We'd love to hear from you on your ideas or feedback in the discussions tab!
We are grateful to Charlie Snell and Aviral Kumar for many discussions about test-time compute scaling and for sharing implementation details from their work. We thank Chun Te Lee for designing the lovely banner and Thomas Wolf, Leandro von Werra, Colin Raffel, and Quentin Gallouédec for many helpful suggestions to improve the blog post. We also thank Hugo Larcher and Mathieu Morlon for continually optimising the Hugging Face Science Cluster to make the GPUs go brrr 🔥!
For attribution in academic contexts, please cite this work as
Beeching, Edward, and Tunstall, Lewis, and Rush, Sasha, "Scaling test-time compute with open models.", 2024.
BibTeX citation
@misc{beeching2024scalingtesttimecompute, title={Scaling test-time compute with open models}, author={Edward Beeching and Lewis Tunstall and Sasha Rush}, url={https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute}, }