|
<!DOCTYPE html> |
|
<html> |
|
<head> |
|
<script src="distill.bundle.js" type="module" fetchpriority="high" blocking></script> |
|
<script src="main.bundle.js" type="module" fetchpriority="low" defer></script> |
|
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no"> |
|
<meta charset="utf8"> |
|
<base target="_blank"> |
|
<title>Scaling test-time compute for open models: How we implemented DeepMind’s compute-optimal recipe to solve hard math problems like OpenAI’s o1</title> |
|
<link rel="stylesheet" href="style.css"> |
|
</head> |
|
|
|
<body> |
|
<d-front-matter> |
|
<script id='distill-front-matter' type="text/json">{ |
|
"title": "📈 Scaling test-time compute with open models", |
|
"description": "This blog post presents our work on scaling test-time compute for large language models. We share an open implementation of DeepMind's Scaling Test Time Compute paper and introduce a new method, diverse verifier tree search (DVTS), that improves scaling performance.", |
|
"published": "Dec 16, 2024", |
|
"affiliation": {"name": "Hugging Face"}, |
|
"authors": [ |
|
{ |
|
"author":"Edward Beeching", |
|
"authorURL":"https://huggingface.co/edbeeching" |
|
}, |
|
{ |
|
"author":"Lewis Tunstall", |
|
"authorURL":"https://huggingface.co/lewtun" |
|
}, |
|
{ |
|
"author":"Sasha Rush", |
|
"authorURL":"https://huggingface.co/srush" |
|
} |
|
] |
|
}</script> |
|
</d-front-matter> |
|
|
|
<d-title> |
|
<h1 class="l-page" style="text-align: center; display: none;">Scaling test-time compute</h1> |
|
<div id="title-plot" class="main-plot-container l-page"> |
|
<figure> |
|
<img src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/banner.gif" alt="Scaling test-time compute with open models"> |
|
</figure> |
|
</div> |
|
</d-title> |
|
<d-byline></d-byline> |
|
<d-article> |
|
<d-contents> |
|
</d-contents> |
|
|
|
<p>Over the last few years, the scaling of <em><strong>train-time</strong></em> <strong>compute</strong> has dominated the progress of large language models (LLMs). Although this paradigm has proven to be remarkably effective, the resources needed to pretrain ever larger models are becoming prohibitively expensive, with <a href="https://youtu.be/WXhikNA5PIc?feature=shared">billion-dollar clusters</a> already on the horizon. This trend has sparked significant interest in a complementary approach: <em><strong>test-time compute scaling</strong></em>. Rather than relying on ever-larger pretraining budgets, test-time methods use dynamic inference strategies that allow models to “think longer” on harder problems. A prominent example is <a href="https://openai.com/index/learning-to-reason-with-llms/">OpenAI’s o1 model</a>, which shows consistent improvement on difficult math problems as one increases the amount of test-time compute:</p> |
|
|
|
<figure id="1581384e-bcac-805f-8c2b-dff4509f45cb" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/compute.png.webp"><img style="width:672px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/compute.png.webp"/></a></figure> |
|
|
|
<p id="1591384e-bcac-80ce-a40c-cf00e5a73720" class="">Although we don’t know how o1 was trained, <a href="https://huggingface.co/papers/2408.03314">recent research from DeepMind</a> shows that <em><strong>test-time compute can be scaled</strong></em> <em><strong>optimally</strong></em> through strategies like iterative self-refinement or using a reward model to perform search over the space of solutions. By adaptively allocating test-time compute per prompt, smaller models can rival, and sometimes even outperform, their larger, more resource-intensive counterparts. Scaling test-time compute is especially advantageous when memory is constrained and the available hardware is not sufficient to run a larger model. However, this promising approach was demonstrated with closed-source models, and no implementation details or code were released 😢.</p><p id="1591384e-bcac-80c1-82a7-cc9a41f9fdc4" class="">Over the past months we’ve been diving deep into reverse engineering and reproducing several of these results, and we are happy to finally share some of what we learned. More precisely, in this blog post we’ll cover:</p><ul id="15a1384e-bcac-8049-a51d-ebeb613ae570" class="bulleted-list"><li style="list-style-type:disc"><strong>Compute-optimal scaling: </strong>How we implemented DeepMind’s recipe to boost the mathematical capabilities of open models at test time.</li></ul><ul id="15a1384e-bcac-803f-acef-cdf3efce939b" class="bulleted-list"><li style="list-style-type:disc"><strong>Diverse Verifier Tree Search (DVTS):</strong> An unpublished extension of the verifier-guided tree search technique that we developed. This simple yet effective method improves diversity and delivers better performance, particularly at large test-time compute budgets.</li></ul><ul id="15a1384e-bcac-8069-947b-ef7818ea4d83" class="bulleted-list"><li style="list-style-type:disc">🧭 <strong>Search and Learn: </strong>A lightweight toolkit for implementing search strategies with LLMs, built for speed with vLLM. 
You can check it out <a href="https://github.com/huggingface/search-and-learn">here</a>.</li></ul><p id="15a1384e-bcac-804b-9470-f64abc63b648" class="">So how well does compute-optimal scaling work in practice? Check out this plot where the tiny 1B and 3B Llama Instruct models outperform their much larger 8B and 70B siblings on the challenging MATH-500 benchmark if you give them enough “time to think” 🤯:</p><figure id="15b1384e-bcac-809d-9f85-ff0b5a83e1cc" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-thumbnail.png"><img style="width:707.9921875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-thumbnail.png"/></a></figure><p id="15c1384e-bcac-804f-a4d5-c8f760f28096" class="">In the rest of this blog post, we’ll dive deep into the ingredients behind results like this one and walk you through practical strategies for implementing test-time compute scaling.</p> |
|
|
|
<h2 id="1591384e-bcac-809f-b7ce-d414b4c0df4e" class="">Strategies for test-time compute scaling</h2><p id="1591384e-bcac-8021-a784-d3340af0adb4" class="">There are two main strategies for scaling test-time compute:</p><ul id="1591384e-bcac-8060-8262-d07e6f3cb300" class="bulleted-list"><li style="list-style-type:disc"><strong>Self-Refinement: </strong>Models iteratively refine their own outputs or “thoughts” by identifying and correcting errors in subsequent iterations. While effective on some tasks, this strategy usually requires models to have built-in mechanisms for self-refinement, which can limit its applicability.</li></ul><ul id="1591384e-bcac-8020-a30a-e82811cb36cc" class="bulleted-list"><li style="list-style-type:disc"><strong>Search Against a Verifier: </strong>This approach focuses on generating multiple candidate answers and using a verifier to select the best one. A verifier can be anything from a hard-coded heuristic to a learned reward model, but for the purposes of this blog post we will focus on learned verifiers. Search against a verifier includes techniques such as Best-of-N sampling and tree search. Search strategies are more flexible and can adapt to the difficulty of the problem, although their performance is constrained by the quality of the verifier.</li></ul><p id="1591384e-bcac-801d-82e3-d2dc50cf2b24" class="">In this blog post, we’ll concentrate on search-based methods as they represent a practical and scalable solution for test-time compute optimization. 
In particular, we’ll examine the three strategies illustrated below:</p><figure id="15b1384e-bcac-80df-a57e-e08ddb80ec8c" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/search-strategies.png"><img style="width:700px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/search-strategies.png"/></a></figure><ul id="15b1384e-bcac-807b-a75f-e0b421bd6ee3" class="bulleted-list"><li style="list-style-type:disc"><strong>Best-of-N: </strong>Generate multiple responses per problem and assign scores to each candidate answer, typically using a reward model. Then select the answer with the highest reward (or a weighted variant discussed later). This approach emphasizes answer quality over frequency.</li></ul><ul id="15b1384e-bcac-800c-af0a-d4e4df5a003a" class="bulleted-list"><li style="list-style-type:disc"><strong>Beam search: </strong>A systematic search method that explores the solution space, often combined with a <em><a href="https://huggingface.co/papers/2211.14275">process reward model (PRM)</a></em> to optimise both the sampling and evaluation of intermediate steps in problem-solving. Unlike conventional reward models that produce a single score on the final answer, PRMs provide a <em>sequence </em>of scores, one for each step of the reasoning process. This ability to provide fine-grained feedback makes PRMs a natural fit for search methods with LLMs.</li></ul><ul id="15b1384e-bcac-80c9-91e2-e013fee74ec6" class="bulleted-list"><li style="list-style-type:disc"><strong>Diverse verifier tree search (DVTS): </strong>An extension of beam search we developed that splits the initial beams into independent subtrees, which are then expanded greedily using a PRM. 
This method improves solution diversity and overall performance, particularly with larger test-time compute budgets.</li></ul><p id="15a1384e-bcac-803c-bc89-ed15f18eafdc" class="">With an understanding of the key search strategies, let’s move on to how we evaluated them in practice.</p> |
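<p>To make the Best-of-N bullet above concrete, here is a minimal sketch of the selection step only. It assumes each candidate has already been reduced to a final answer string and a single scalar reward; the function names and the toy rewards are hypothetical and are not taken from our codebase. The weighted variant simply sums the rewards of identical answers before picking the winner.</p>

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Each candidate is (final_answer, reward). In practice the reward would come
# from a reward model; here the numbers are made up for illustration.
Candidate = Tuple[str, float]


def vanilla_best_of_n(candidates: List[Candidate]) -> str:
    """Return the answer of the single highest-reward candidate."""
    answer, _ = max(candidates, key=lambda pair: pair[1])
    return answer


def weighted_best_of_n(candidates: List[Candidate]) -> str:
    """Sum rewards over identical answers and return the answer with the largest total."""
    totals: Dict[str, float] = defaultdict(float)
    for answer, reward in candidates:
        totals[answer] += reward
    return max(totals, key=lambda answer: totals[answer])


toy_candidates = [("42", 0.9), ("41", 0.5), ("41", 0.6), ("7", 0.1)]
vanilla_choice = vanilla_best_of_n(toy_candidates)
weighted_choice = weighted_best_of_n(toy_candidates)
print(vanilla_choice, weighted_choice)
```

<p>In this toy input the two variants disagree: the single best-scored candidate answers <code>42</code>, while the answer <code>41</code> appears twice and accumulates a larger total reward.</p>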
|
|
|
<h2 id="15a1384e-bcac-80e2-8506-e130c2c69407" class="">Experimental setup</h2><figure id="15c1384e-bcac-8096-9f01-f2c0d388ed1d" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/system.png"><img style="width:750px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/system.png"/></a></figure><p id="15a1384e-bcac-802a-b133-e279ff948bd0" class="">As illustrated in the diagram above, our experimental setup involves a pipeline with the following steps:</p><ol type="1" id="15c1384e-bcac-806e-b372-df95fdad14c8" class="numbered-list" start="1"><li>We begin by feeding a math problem to an LLM, which generates <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> <em><strong>partial solutions</strong></em>, e.g. an intermediate step in a derivation.</li></ol><ol type="1" id="15c1384e-bcac-8045-9a74-d1c2ca472aa3" class="numbered-list" start="2"><li>Each step is scored by a PRM, which estimates the probability that the step eventually leads to the correct final answer. 
<ol type="a" id="15c1384e-bcac-80c5-84ea-f8f6c00b8e65" class="numbered-list" start="1"><li>The steps and PRM scores are then used by a given search strategy to select which partial solutions should be further explored to generate the next round of intermediate steps.</li></ol></li></ol><ol type="1" id="15c1384e-bcac-803a-9772-ce27e91553c0" class="numbered-list" start="3"><li>Once the search strategy terminates, the final candidate solutions are ranked by the PRM to produce the final answer.</li></ol><p id="15c1384e-bcac-80af-a31e-fd3330874674" class="">To compare various search strategies, we used the following open models and datasets:</p><ul id="15a1384e-bcac-80d3-afcf-e5741e06845d" class="bulleted-list"><li style="list-style-type:disc"><strong>Model:</strong> We used <code>meta-llama/Llama-3.2-1B-Instruct</code> as our primary model for scaling test-time compute. With 1B parameters, its lightweight nature enables fast iterations, and its unsaturated performance on math benchmarks makes it an ideal choice for highlighting the benefits of scaling.</li></ul><ul id="15a1384e-bcac-807f-81fb-cd43b2273acf" class="bulleted-list"><li style="list-style-type:disc"><strong>Process reward model: </strong>To guide our search strategies, we used <code>RLHFlow/Llama3.1-8B-PRM-Deepseek-Data</code>, an 8B reward model that has been trained using <em>process supervision</em>. Process supervision is a training approach where models receive feedback on each step of their reasoning process, not just the final outcome. 
We picked this model since it belongs to the same model family as our policy and gave better results than other PRMs we tested in this weight class, such as <a href="https://huggingface.co/peiyi9979/math-shepherd-mistral-7b-prm">Math-Shepherd</a>.</li></ul><ul id="15a1384e-bcac-80da-8c95-d6ca809696a8" class="bulleted-list"><li style="list-style-type:disc"><strong>Dataset: </strong>We evaluated on the <a href="https://huggingface.co/datasets/HuggingFaceH4/MATH-500">MATH-500 subset</a> of the <a href="https://arxiv.org/abs/2103.03874">MATH benchmark</a>, a subset released by OpenAI as part of their <a href="https://arxiv.org/abs/2305.20050">research</a> on process supervision. These math problems span seven subjects and are challenging for both humans and most LLMs. Take a look at the dataset viewer below to get a taste for the problem difficulty! |
|
|
|
<iframe |
|
src="https://huggingface.co/datasets/HuggingFaceH4/MATH-500/embed/viewer/default/test" |
|
frameborder="0" |
|
width="100%" |
|
height="560px" |
|
></iframe> |
|
|
|
We tested each search strategy across compute budgets ranging from 1 to 256 generations per prompt and ran the data-generation pipeline with five random seeds to estimate variance across runs. You can find the models and datasets from our analysis in this <a href="https://huggingface.co/collections/HuggingFaceH4/scaling-test-time-compute-with-open-models-675c3b475a0d6eb4528fec23">collection</a>.</p><p id="15b1384e-bcac-80f9-827a-f25d2218c04f" class="">To warm up, we’ll begin with a simple baseline and progressively incorporate additional techniques to improve performance.</p> |
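<p>The three pipeline steps described in this section can be sketched as a small loop. This is only an illustration under stated assumptions, not our implementation: the callables <code>propose_steps</code>, <code>score_partial</code> and <code>is_finished</code> are hypothetical stand-ins for the LLM, the PRM and a stopping check, and keeping the top <code>keep</code> partial solutions each round is just one possible selection rule.</p>

```python
from typing import Callable, List


def prm_guided_search(
    problem: str,
    propose_steps: Callable[[str, int], List[str]],  # stand-in for the LLM
    score_partial: Callable[[str], float],           # stand-in for the PRM
    is_finished: Callable[[str], bool],              # stand-in for a stop check
    n: int = 4,
    keep: int = 2,
    max_rounds: int = 8,
) -> str:
    # Step 1: the LLM proposes n partial solutions (one extra step each).
    frontier = [f"{problem}\n{step}" for step in propose_steps(problem, n)]
    finished: List[str] = []

    for _ in range(max_rounds):
        # Step 2: the PRM scores every partial solution.
        ranked = sorted(frontier, key=score_partial, reverse=True)
        finished.extend(p for p in ranked if is_finished(p))

        # Step 2a: the search strategy picks which partial solutions to extend.
        survivors = [p for p in ranked if not is_finished(p)][:keep]
        if not survivors:
            break
        per_parent = max(1, n // len(survivors))
        frontier = [
            f"{parent}\n{step}"
            for parent in survivors
            for step in propose_steps(parent, per_parent)
        ]

    # Step 3: rank the final candidates with the PRM and return the best one.
    return max(finished or frontier, key=score_partial)


# Toy stand-ins: a "solution" is a list of "add k" steps that should sum to 10.
def total(solution: str) -> int:
    lines = solution.splitlines()
    return sum(int(line.split()[1]) for line in lines if line.startswith("add "))


def toy_propose(prefix: str, k: int) -> List[str]:
    return [f"add {i}" for i in range(1, k + 1)]


def toy_score(solution: str) -> float:
    return -abs(10 - total(solution))


def toy_finished(solution: str) -> bool:
    return total(solution) >= 10


best = prm_guided_search("Reach 10", toy_propose, toy_score, toy_finished)
print(best)
```

<p>Swapping the selection rule in step 2a is what distinguishes one search strategy from another; the generation and scoring steps stay the same.</p>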
|
|
|
<h2 id="1591384e-bcac-801a-9201-cd4f3b8dfe96" class="">Majority voting: a simple baseline</h2><p id="1591384e-bcac-8011-8710-deb8e5601cdb" class="">Majority voting—or <a href="https://huggingface.co/papers/2203.11171">self-consistency decoding</a> if you want to be fancy—is the most straightforward method to aggregate an LLM’s outputs. As the name suggests, for a given math problem we generate <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> candidate solutions and pick the most frequent answer. 
For all our experiments we sampled up to <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>256</mn></mrow><annotation encoding="application/x-tex">N=256</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">256</span></span></span></span></span><span></span></span> candidates with temperature <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>T</mi><mo>=</mo><mn>0.8</mn></mrow><annotation encoding="application/x-tex">T=0.8</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span 
class="mord">0.8</span></span></span></span></span><span></span></span> and generated up to 2048 tokens per problem.</p><p id="15c1384e-bcac-8086-a0e7-e0ca93b5ea94" class="">One quirk with the MATH benchmark is that answers must be formatted in a LaTeX box like <code>\boxed{answer}</code>. We initially tried the following simple system prompt for Llama 3.2 1B:</p><script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/prism.min.js" integrity="sha512-7Z9J3l1+EYfeaPKcGXu3MS/7T+w19WtKQY/n+xzmw4hZhJ9tyYmcUS+4QqAlzhicE5LAfMQSF3iFTK9bQdTxXg==" crossorigin="anonymous" referrerPolicy="no-referrer"></script><link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/themes/prism.min.css" integrity="sha512-tN7Ec6zAFaVSG3TpNAKtk4DOHNpSwKHxxrsiw4GHKESGPs5njn/0sMCUMl2svV4wo4BK/rCP7juYz+zx+l6oeQ==" crossorigin="anonymous" referrerPolicy="no-referrer"/><pre id="15c1384e-bcac-8042-9b96-ff3d615bb9f0" class="code"><code class="language-Python">Please think step by step and put your final answer within \boxed{}.</code></pre><p id="15c1384e-bcac-80d0-bca6-ffed38482a37" class="">but found the resulting accuracy with greedy decoding (<style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>T</mi><mo>=</mo><mn>0</mn></mrow><annotation encoding="application/x-tex">T=0</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" 
style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">0</span></span></span></span></span><span></span></span>) to be far worse than the 30.6% that Meta reported in their <a href="https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices/">release</a>. Luckily, Meta also <a href="https://huggingface.co/datasets/meta-llama/Llama-3.2-1B-Instruct-evals/viewer/Llama-3.2-1B-Instruct-evals__math__details">published</a> the prompts they used for their evals and switching our system prompt to theirs made all the difference:</p><script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/prism.min.js" integrity="sha512-7Z9J3l1+EYfeaPKcGXu3MS/7T+w19WtKQY/n+xzmw4hZhJ9tyYmcUS+4QqAlzhicE5LAfMQSF3iFTK9bQdTxXg==" crossorigin="anonymous" referrerPolicy="no-referrer"></script><link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/themes/prism.min.css" integrity="sha512-tN7Ec6zAFaVSG3TpNAKtk4DOHNpSwKHxxrsiw4GHKESGPs5njn/0sMCUMl2svV4wo4BK/rCP7juYz+zx+l6oeQ==" crossorigin="anonymous" referrerPolicy="no-referrer"/><pre id="15c1384e-bcac-8011-aee6-d7c433df8b5f" class="code"><code class="language-Python">Solve the following math problem efficiently and clearly: |
|
|
|
- For simple problems (2 steps or fewer): |
|
Provide a concise solution with minimal explanation. |
|
|
|
- For complex problems (3 steps or more): |
|
Use this step-by-step format: |
|
|
|
## Step 1: [Concise description] |
|
[Brief explanation and calculations] |
|
|
|
## Step 2: [Concise description] |
|
[Brief explanation and calculations] |
|
|
|
... |
|
|
|
Regardless of the approach, always conclude with: |
|
|
|
Therefore, the final answer is: $\boxed{answer}$. I hope it is correct. |
|
|
|
Where [answer] is just the final number or expression that solves the problem.</code></pre> |
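<p>Because the prompt asks the model to finish with <code>\boxed{answer}</code>, scoring a completion first requires pulling that answer out of the generated text. The post does not show this step, so the following is a minimal sketch under our own assumptions: the helper name <code>extract_last_boxed</code> is hypothetical, and it simply takes the last <code>\boxed{...}</code> in the text and matches braces so that nested LaTeX such as <code>\frac{1}{2}</code> survives intact.</p>

```python
from typing import Optional


def extract_last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, or None if absent."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None

    content_start = start + len(marker)
    depth = 1  # we are inside the opening brace of \boxed{
    for i in range(content_start, len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:i]
    return None  # unbalanced braces, e.g. a truncated generation


completion = (
    "## Step 1: Halve the value\n"
    "Therefore, the final answer is: $\\boxed{\\frac{1}{2}}$. I hope it is correct."
)
boxed = extract_last_boxed(completion)
print(boxed)
```

<p>Returning <code>None</code> for a missing or unbalanced box lets the caller decide how to treat generations that were cut off by the token limit.</p>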
|
|
|
<p id="15c1384e-bcac-8076-9664-caa1b098c89c" class="">One subtlety with evaluating answers to math problems is that strings like <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>1</mn><mi mathvariant="normal">/</mi><msqrt><mn>3</mn></msqrt></mrow><annotation encoding="application/x-tex">1/\sqrt{3}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.1572em;vertical-align:-0.25em;"></span><span class="mord">1/</span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9072em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord">3</span></span></span><span style="top:-2.8672em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg xmlns="http://www.w3.org/2000/svg" width='400em' height='1.08em' viewBox='0 0 400000 1080' preserveAspectRatio='xMinYMin slice'><path d='M95,702 |
|
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14 |
|
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54 |
|
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10 |
|
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429 |
|
c69,-144,104.5,-217.7,106.5,-221 |
|
l0 -0 |
|
c5.3,-9.3,12,-14,20,-14 |
|
H400000v40H845.2724 |
|
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7 |
|
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z |
|
M834 80h400000v40h-400000z'/></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.1328em;"><span></span></span></span></span></span></span></span></span></span><span></span></span> and <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msqrt><mn>3</mn></msqrt><mi mathvariant="normal">/</mi><mn>3</mn></mrow><annotation encoding="application/x-tex">\sqrt{3}/3</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.1572em;vertical-align:-0.25em;"></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9072em;"><span class="svg-align" style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style="padding-left:0.833em;"><span class="mord">3</span></span></span><span style="top:-2.8672em;"><span class="pstrut" style="height:3em;"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em;"><svg xmlns="http://www.w3.org/2000/svg" width='400em' height='1.08em' viewBox='0 0 400000 1080' preserveAspectRatio='xMinYMin slice'><path d='M95,702 |
|
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14 |
|
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54 |
|
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10 |
|
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429 |
|
c69,-144,104.5,-217.7,106.5,-221 |
|
l0 -0 |
|
c5.3,-9.3,12,-14,20,-14 |
|
H400000v40H845.2724 |
|
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7 |
|
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z |
|
M834 80h400000v40h-400000z'/></svg></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.1328em;"><span></span></span></span></span></span><span class="mord">/3</span></span></span></span></span><span></span></span> are distinct but represent mathematically equivalent answers. The standard <a href="https://huggingface.co/papers/2206.14858">way</a> to handle this is to convert a pair of answers to SymPy objects and then check whether subtracting the two objects and applying <code>sympy.simplify</code> gives zero.</p><p id="15c1384e-bcac-8043-a1d3-ffa0c5c3406e" class="">While this approach works well when comparing a small number of candidate answers, we found it was terribly slow when comparing many pairs in a list of <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> candidates; in some cases, slower than generating the candidates in the first place! To deal with this, we first reduced each answer to its canonical form and then computed the frequency of each form to determine the majority vote. 
Expand the detail below if you’re curious about how the code looks.</p><details><summary style="font-weight:600;font-size:1.25em;line-height:1.3;margin:0"><strong>Implementation detail</strong></summary><div class="indented"><p id="15d1384e-bcac-8080-b676-e09848694520" class="">To obtain the canonical form of an algebraic expression, we first convert the LaTeX string to SymPy, apply <code>sympy.simplify</code>, and finally convert back to LaTeX: </p><script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/prism.min.js" integrity="sha512-7Z9J3l1+EYfeaPKcGXu3MS/7T+w19WtKQY/n+xzmw4hZhJ9tyYmcUS+4QqAlzhicE5LAfMQSF3iFTK9bQdTxXg==" crossorigin="anonymous" referrerPolicy="no-referrer"></script><link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/themes/prism.min.css" integrity="sha512-tN7Ec6zAFaVSG3TpNAKtk4DOHNpSwKHxxrsiw4GHKESGPs5njn/0sMCUMl2svV4wo4BK/rCP7juYz+zx+l6oeQ==" crossorigin="anonymous" referrerPolicy="no-referrer"/><pre id="15b1384e-bcac-80c0-96c2-feba0d1cc92e" class="code"><code class="language-Python">from latex2sympy2 import latex2sympy |
|
from sympy import latex, simplify |
|
|
|
def get_canonical_form(expression: str) -> str: |
|
    parsed_expr = latex2sympy(expression) |

    simplified_expr = simplify(parsed_expr) |

    return latex(simplified_expr)</code></pre><p id="15c1384e-bcac-80ea-92cd-d199d25281b4" class="">With this function, we can then iterate over all candidate solutions in a list and keep track of how many times a canonical form has been seen before computing the final majority vote:</p><script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/prism.min.js" integrity="sha512-7Z9J3l1+EYfeaPKcGXu3MS/7T+w19WtKQY/n+xzmw4hZhJ9tyYmcUS+4QqAlzhicE5LAfMQSF3iFTK9bQdTxXg==" crossorigin="anonymous" referrerPolicy="no-referrer"></script><link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/themes/prism.min.css" integrity="sha512-tN7Ec6zAFaVSG3TpNAKtk4DOHNpSwKHxxrsiw4GHKESGPs5njn/0sMCUMl2svV4wo4BK/rCP7juYz+zx+l6oeQ==" crossorigin="anonymous" referrerPolicy="no-referrer"/><pre id="15d1384e-bcac-80f3-9062-c4f44a0d0b8b" class="code"><code class="language-Python">from collections import defaultdict |

from typing import List |



def find_majority_answer(answers: List[str]) -> str: |

    canonical_groups = defaultdict(int) |

    canonical_to_original = {} |



    for answer in answers: |

        canonical_form = get_canonical_form(answer) |



        # Increment count for the canonical form |

        canonical_groups[canonical_form] += 1 |



        # Track the original answer for this canonical form |

        if canonical_form not in canonical_to_original: |

            canonical_to_original[canonical_form] = answer |



    # Find the canonical form with the largest count |

    max_count = max(canonical_groups.values()) |

    for canonical_form, count in canonical_groups.items(): |

        if count == max_count: |

            # Return the first occurring group in case of a tie |

            return canonical_to_original[canonical_form]</code></pre> |
|
|
|
<p id="15d1384e-bcac-804e-a99c-fe5e83313a3d" class="">This approach was significantly faster than checking each pair of solutions independently for equality.</p></div></details> |
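<p>To see why grouping by canonical form is cheaper, note that comparing every pair among N candidates needs up to N(N-1)/2 equality checks (32,640 for N=256), while canonicalising needs only N calls followed by a frequency count. The sketch below shows the same idea in a form that runs without <code>latex2sympy2</code>: as a stand-in for <code>get_canonical_form</code> it uses a hypothetical <code>toy_canonical_form</code> built on Python's <code>fractions.Fraction</code>, so it only understands plain numbers such as <code>0.5</code> or <code>2/4</code>.</p>

```python
from collections import Counter
from fractions import Fraction
from typing import List


def toy_canonical_form(answer: str) -> str:
    """Toy stand-in for get_canonical_form: reduce a plain number to lowest terms."""
    return str(Fraction(answer.replace(" ", "")))


def majority_by_canonical_form(answers: List[str]) -> str:
    # One canonicalisation per answer: N calls instead of up to N*(N-1)/2 comparisons.
    canonical = [toy_canonical_form(a) for a in answers]
    winner, _ = Counter(canonical).most_common(1)[0]
    # Report the first original answer that maps to the winning canonical form.
    return answers[canonical.index(winner)]


toy_answers = ["0.5", "1/2", "3", "2/4", "3"]
majority = majority_by_canonical_form(toy_answers)
print(majority)
```

<p>Here <code>0.5</code>, <code>1/2</code> and <code>2/4</code> all collapse to the same canonical form, so they outvote the two occurrences of <code>3</code> even though no single string appears more than twice.</p>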
|
<br> |
|
<p id="15b1384e-bcac-80f7-83e8-e1d6b360faa4" class="">Here’s how majority voting performs when applied to the generations from Llama 3.2 1B Instruct:</p><figure id="15b1384e-bcac-8072-9987-d80031b97793" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj.png"/></a></figure><p id="15b1384e-bcac-8020-8688-fe1713e92c2b" class="">The results show that majority voting yields a significant improvement over the greedy decoding baseline, but its gains start to plateau after approximately <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>64</mn></mrow><annotation encoding="application/x-tex">N=64</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">64</span></span></span></span></span><span></span></span> generations. This limitation arises because majority voting struggles with problems that require nuanced reasoning and with tasks where errors are consistent across generations. 
If you’re wondering why majority voting is less accurate than the 0-shot CoT baseline for N=1 and N=2, it’s because we sample at T=0.8 rather than decoding greedily, which makes it less likely that the correct answer appears among only a handful of candidates.</p><p id="15b1384e-bcac-8075-8fef-f26f0b8e5559" class="">Given these limitations of majority voting, let’s see how incorporating a reward model can improve performance.</p> |
|
|
|
<h2 id="1591384e-bcac-8098-9db5-f76c9ce00e7a" class="">Beyond majority: Best-of-N</h2><p id="15b1384e-bcac-8019-9b5c-d11bae74628d" class="">Best-of-N is a simple but effective extension of majority voting that uses a reward model to determine the most plausible answer. This method comes in two main variants:</p><ul id="15b1384e-bcac-80b4-aae4-d5e98e29debf" class="bulleted-list"><li style="list-style-type:disc"><strong>Vanilla Best-of-N:</strong> Generate <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> independent responses and select the one with the <em>highest RM reward</em> as the final answer. This picks the single response that the reward model scores highest, but it doesn’t account for consistency across answers.</li></ul><ul id="15b1384e-bcac-8035-a394-fbd954af1984" class="bulleted-list"><li style="list-style-type:disc"><strong>Weighted Best-of-N:</strong> Sum the scores of all responses that share the same final answer and select the answer with the <em>highest total reward</em>. This approach prioritises high-quality answers by boosting their scores through repeated occurrences.
Mathematically, the weighting across answers <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>a</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">a_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span></span></span> is performed as follows:<figure id="15d1384e-bcac-80e5-8d68-fe7bad033482" class="equation"><style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><div class="equation-container"><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msub><mi>a</mi><mrow><mi mathvariant="normal">w</mi><mi mathvariant="normal">e</mi><mi mathvariant="normal">i</mi><mi mathvariant="normal">g</mi><mi mathvariant="normal">h</mi><mi mathvariant="normal">t</mi><mi mathvariant="normal">e</mi><mi 
mathvariant="normal">d</mi></mrow></msub><mo>=</mo><mi>arg</mi><mo></mo><munder><mrow><mi>max</mi><mo></mo></mrow><mi>a</mi></munder><munderover><mo>∑</mo><mrow><mi>i</mi><mo>=</mo><mn>1</mn></mrow><mi>N</mi></munderover><mi mathvariant="double-struck">I</mi><mo stretchy="false">(</mo><msub><mi>a</mi><mi>i</mi></msub><mo>=</mo><mi>a</mi><mo stretchy="false">)</mo><mo>⋅</mo><mrow><mi mathvariant="normal">R</mi><mi mathvariant="normal">M</mi></mrow><mo stretchy="false">(</mo><mi>p</mi><mo separator="true">,</mo><msub><mi>s</mi><mi>i</mi></msub><mo stretchy="false">)</mo><mtext> </mtext><mo separator="true">,</mo></mrow><annotation encoding="application/x-tex">a_\mathrm{weighted} = \arg\max_{a} \sum_{i=1}^{N} \mathbb{I}(a_i = a) \cdot \mathrm{RM}(p, s_i) \,,</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7167em;vertical-align:-0.2861em;"></span><span class="mord"><span class="mord mathnormal">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathrm mtight">weighted</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:3.106em;vertical-align:-1.2777em;"></span><span class="mop">ar<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span 
class="vlist" style="height:0.4306em;"><span style="top:-2.4em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">a</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop">max</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.7em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.8283em;"><span style="top:-1.8723em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mrel mtight">=</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.05em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span><span style="top:-4.3em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.10903em;">N</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.2777em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathbb">I</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span 
class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathrm">RM</span></span><span class="mopen">(</span><span class="mord mathnormal">p</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mpunct">,</span></span></span></span></span></div></figure><p id="15d1384e-bcac-8083-8f2a-d5701df84dcd" class="">where <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" 
style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mrow><mi mathvariant="normal">R</mi><mi mathvariant="normal">M</mi></mrow><mo stretchy="false">(</mo><mi>p</mi><mo separator="true">,</mo><msub><mi>s</mi><mi>i</mi></msub><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mathrm{RM}(p, s_i)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathrm">RM</span></span><span class="mopen">(</span><span class="mord mathnormal">p</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span><span></span></span> is the reward model score of the <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>i</mi></mrow><annotation 
encoding="application/x-tex">i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6595em;"></span><span class="mord mathnormal">i</span></span></span></span></span><span></span></span>-th solution <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>s</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">s_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5806em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span><span></span></span> to problem p.</p></li></ul><p id="15d1384e-bcac-8012-8282-c0ed1215a611" class="">Typically, one uses an outcome reward model (ORM) to get a single, solution-level score. However, to allow a fair comparison with the other search strategies discussed later, we use the same PRM to score the solutions from Best-of-N.
As illustrated below, PRMs produce a <em>cumulative</em> <em>sequence of step-level scores</em> per solution, so we need to perform a reduction over the steps to obtain a single solution-level score: </p><figure id="15d1384e-bcac-80d6-815f-c7d87fe313a6" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/prm-reductions.png"><img style="width:700px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/prm-reductions.png"/></a></figure><p id="15d1384e-bcac-80e7-8d1a-e0aab286f9f4" class="">In the literature, the most common reductions are the following:</p><ul id="15b1384e-bcac-80e4-92b4-e2bc90a9130a" class="bulleted-list"><li style="list-style-type:disc"><strong>Min: </strong>use the minimum score across all steps.</li></ul><ul id="15b1384e-bcac-8073-b4dc-fbfcfc0567bc" class="bulleted-list"><li style="list-style-type:disc"><strong>Prod: </strong>use the product of step-level scores.</li></ul><ul id="15b1384e-bcac-80ed-8cc5-fa6e2ce330fb" class="bulleted-list"><li style="list-style-type:disc"><strong>Last: </strong>use the score of the final step. This score contains the cumulative information from all prior steps, so it effectively treats the PRM as an ORM that can score partial solutions.</li></ul><p id="15b1384e-bcac-80ad-96d1-d313ae3e1954" class="">We experimented with each reduction and found, like DeepMind, that <em><strong>“last” performs best for our choice of task and PRM</strong></em>. We use this aggregation throughout all of our experiments and you can expand the detail below to see how we implemented it, along with the weighting procedure discussed above.</p> |
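<p>The three reductions and the two Best-of-N variants described above can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the code used for the experiments in this post: the function names (<code>reduce_prm_scores</code>, <code>vanilla_best_of_n</code>, <code>weighted_best_of_n</code>) are hypothetical, the scores are made-up toy values, and the final answers are assumed to have already been extracted and normalised to canonical strings so that equivalent answers compare equal.</p>

```python
from collections import defaultdict
from math import prod


def reduce_prm_scores(step_scores, how="last"):
    """Collapse a sequence of step-level PRM scores into one solution-level score."""
    if not step_scores:
        raise ValueError("step_scores must be non-empty")
    if how == "min":
        return min(step_scores)
    if how == "prod":
        return prod(step_scores)
    if how == "last":
        return step_scores[-1]
    raise ValueError(f"unknown reduction: {how}")


def vanilla_best_of_n(answers, scores):
    """Return the answer of the single highest-scoring solution."""
    best_index = max(range(len(answers)), key=lambda i: scores[i])
    return answers[best_index]


def weighted_best_of_n(answers, scores):
    """Sum the scores of solutions sharing an answer; return the answer with the largest total."""
    totals = defaultdict(float)
    for answer, score in zip(answers, scores):
        totals[answer] += score
    return max(totals, key=totals.get)


# Toy example: four sampled solutions, each with made-up step-level PRM scores.
answers = ["4", "5", "4", "7"]
step_scores = [[0.9, 0.6], [0.8, 0.7], [0.7, 0.5], [0.2, 0.1]]
scores = [reduce_prm_scores(s, how="last") for s in step_scores]

vanilla_answer = vanilla_best_of_n(answers, scores)
weighted_answer = weighted_best_of_n(answers, scores)
```

<p>In this toy example the two variants disagree: vanilla Best-of-N returns <code>"5"</code> because one solution has the highest individual score (0.7), whereas weighted Best-of-N returns <code>"4"</code> because the two solutions with that answer have a larger combined score (0.6 + 0.5).</p>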
|
|
|
<p id="15d1384e-bcac-809a-8aa8-c52ca7301b52" class="">Here are the results from applying both variants of Best-of-N:</p><figure id="15b1384e-bcac-808d-857e-d492683a4a91" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon.png"/></a></figure><p id="15b1384e-bcac-8001-9320-ff788bab0c52" class="">The results reveal a clear advantage: <strong>weighted Best-of-N</strong> consistently outperforms vanilla Best-of-N, especially with larger generation budgets. Because it aggregates scores across identical responses, even less frequent but higher-quality answers can be effectively prioritized.</p><p id="15b1384e-bcac-808a-b3ff-ee08c05a20af" class="">However, despite these improvements, we’re still falling short of the performance achieved by the Llama 3.1 8B model, and the Best-of-N approach is starting to plateau at <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>256</mn></mrow><annotation encoding="application/x-tex">N=256</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span
class="mord">256</span></span></span></span></span><span></span></span> generations. Can we push the boundaries further by supervising the search process step-by-step? Let’s find out 🚀!</p> |
|
|
|
<h2 id="1591384e-bcac-8065-a02c-cd760ebd6cd1" class="">Beam search with process reward models</h2><p id="15a1384e-bcac-80e1-9e0e-c01f5f373805" class="">Beam search is a structured search method that systematically explores the solution space, making it a powerful tool for improving model outputs at test-time. When combined with a PRM, beam search can optimize both the generation and evaluation of intermediate steps in problem-solving. The way it works is as follows:</p><ol type="1" id="15d1384e-bcac-8007-8d79-cdaa74e4c8c0" class="numbered-list" start="1"><li>Generate multiple candidate solutions <em>iteratively</em> by maintaining a fixed number of "beams" or active paths <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span>.</li></ol><ol type="1" id="15d1384e-bcac-8020-bf69-e67fd962062b" class="numbered-list" start="2"><li>In the first iteration, sample <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation 
encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> independent steps from the LLM with temperature <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>T</mi></mrow><annotation encoding="application/x-tex">T</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span></span></span></span></span><span></span></span> to introduce diversity in the responses. 
These steps are usually defined by a stopping criterion like terminating on a new line <code>\n</code> or double new line <code>\n\n</code>.</li></ol><ol type="1" id="15d1384e-bcac-80c2-aeaa-f6d73682eb8c" class="numbered-list" start="3"><li>Score each step with the PRM and select the top <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mi mathvariant="normal">/</mi><mi>M</mi></mrow><annotation encoding="application/x-tex">N/M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mord">/</span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> steps as candidates for the next round of generation. Here <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> denotes the “beam width” of a given active path. 
As in Best-of-N, we used the “last” reduction to score the partial solutions at each iteration.</li></ol><ol type="1" id="15d1384e-bcac-8022-966b-e1dae6845cc1" class="numbered-list" start="4"><li>Expand the steps selected in step (3) by sampling <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> next steps in the solution.</li></ol><ol type="1" id="15d1384e-bcac-8023-b6b6-f470e22ac78a" class="numbered-list" start="5"><li>Repeat steps (3) and (4) until the EOS token is reached or the maximum search depth is exceeded.</li></ol><p id="15a1384e-bcac-8003-a9d9-da7f3a4dc321" class="">By allowing the PRM to evaluate the correctness of intermediate steps, beam search can identify and prioritize promising paths early in the process. 
This step-by-step evaluation is particularly beneficial for complex reasoning tasks like mathematics, where verifying partial solutions can significantly improve final outcomes.</p><details><summary style="font-weight:600;font-size:1.25em;line-height:1.3;margin:0">Implementation detail</summary><div class="indented"><p id="15b1384e-bcac-8065-a739-d24b699106be" class="">When we implemented beam search with process supervision, we encountered two major footguns with the Llama 3 chat template that are worth mentioning:</p><ul id="15d1384e-bcac-803c-84b3-d881bc2ca3b5" class="bulleted-list"><li style="list-style-type:disc">By default, the chat template trims trailing new lines from every assistant turn. As a result, if one uses <code>\n</code> or <code>\n\n</code> to terminate a step, these tokens are lost on subsequent steps, which forces the model to produce peculiar outputs.</li></ul><ul id="15d1384e-bcac-808f-97f1-fb7d27565e36" class="bulleted-list"><li style="list-style-type:disc">The chat template is prefixed with Llama’s BOS token. When the formatted string is fed to vLLM, a <em>second</em> BOS token is added, which completely ruins performance, even though the generations look mostly coherent 🤯</li></ul><p id="15d1384e-bcac-8041-9164-ecc3d9497886" class="">The solution is to overwrite the Llama 3 chat template to prevent trimming and exclude the BOS token prefix. </p><p id="15a1384e-bcac-8090-b5fc-eb36a6588e60" class=""> |
|
</p></div></details><p id="15d1384e-bcac-80e9-8e65-e1b58080b94c" class="">In our experiments, we followed DeepMind’s hyperparameter choices and ran beam search with the following:</p><ul id="15d1384e-bcac-8098-8574-e16392fc6123" class="bulleted-list"><li style="list-style-type:disc"><style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> beams in compute scalings of 4, 16, 64, 256</li></ul><ul id="15d1384e-bcac-8067-b37c-e9692e34678c" class="bulleted-list"><li style="list-style-type:disc">Fixed beam width <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi><mo>=</mo><mn>4</mn></mrow><annotation encoding="application/x-tex">M=4</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" 
style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">4</span></span></span></span></span><span></span></span></li></ul><ul id="15d1384e-bcac-8093-a928-c16e31e29e3f" class="bulleted-list"><li style="list-style-type:disc">Sampling with temperature <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>T</mi><mo>=</mo><mn>0.8</mn></mrow><annotation encoding="application/x-tex">T=0.8</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">0.8</span></span></span></span></span><span></span></span></li></ul><ul id="15d1384e-bcac-802a-8416-e332ca20237f" class="bulleted-list"><li style="list-style-type:disc">Up to 40 iterations, i.e. 
a search tree with a maximum depth of 40 steps.</li></ul><p id="15d1384e-bcac-8051-abe5-dc84c42a1b5f" class="">As shown below, the results are striking: with a test-time budget of <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>4</mn></mrow><annotation encoding="application/x-tex">N=4</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">4</span></span></span></span></span><span></span></span>, beam search achieves the same accuracy as Best-of-N for <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>16</mn></mrow><annotation encoding="application/x-tex">N=16</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span
class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">16</span></span></span></span></span><span></span></span>, i.e. it is 4x more compute efficient! Moreover, beam search matches the performance of Llama 3.1 8B with just <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>32</mn></mrow><annotation encoding="application/x-tex">N=32</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">32</span></span></span></span></span><span></span></span> solutions per problem. 
The average performance on MATH by computer science PhD students is around 40%, so reaching nearly 55% isn’t too bad for a 1B model 💪!</p><figure id="15b1384e-bcac-80e9-97fa-fe50d1811f5b" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon-beam.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon-beam.png"/></a></figure><h3 id="15a1384e-bcac-800c-baee-fb99b242ef87" class="">Which problems does beam search solve best?</h3><p id="15d1384e-bcac-80e3-938a-c3f09db2e9ff" class="">Although in aggregate it is clear that beam search is a better search strategy than Best-of-N or majority voting, the DeepMind paper showed that <em><strong>each strategy has tradeoffs that depend on the problem difficulty</strong></em> and test-time compute budget. </p><p id="15d1384e-bcac-8015-a8f0-c2323b9e535f" class="">To see which problems are best suited for which strategy, DeepMind computed a distribution over estimated problem difficulty, and then binned the results into quintiles. In other words, each problem is assigned one of 5 levels, where level 1 indicates easier problems and level 5 indicates the hardest ones. 
To estimate problem difficulty, DeepMind generated 2048 candidate solutions with standard sampling per problem and then proposed the following heuristics:</p><ul id="15d1384e-bcac-8080-9152-caeaa288073c" class="bulleted-list"><li style="list-style-type:disc"><strong>Oracle: </strong>use the ground truth labels to estimate the <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mi>a</mi><mi>s</mi><mi>s</mi><mi mathvariant="normal">@</mi><mn>1</mn></mrow><annotation encoding="application/x-tex">pass@1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">p</span><span class="mord mathnormal">a</span><span class="mord mathnormal">ss</span><span class="mord">@1</span></span></span></span></span><span></span></span> score per problem. 
Bin the distribution of <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mi>a</mi><mi>s</mi><mi>s</mi><mi mathvariant="normal">@</mi><mn>1</mn></mrow><annotation encoding="application/x-tex">pass@1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">p</span><span class="mord mathnormal">a</span><span class="mord mathnormal">ss</span><span class="mord">@1</span></span></span></span></span><span></span></span> scores to determine the quintiles.</li></ul><ul id="15d1384e-bcac-80f9-8778-d4045c6faa7d" class="bulleted-list"><li style="list-style-type:disc"><strong>Model: </strong>use the distribution of average PRM scores per problem to determine the quintiles. 
The intuition here is that harder problems will have lower scores.</li></ul><p id="15d1384e-bcac-80a3-af7c-f3497126ab1e" class="">Here’s the breakdown of the various methods according to the pass@1 scores and across four test-time compute budgets of <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mo stretchy="false">[</mo><mn>4</mn><mo separator="true">,</mo><mn>16</mn><mo separator="true">,</mo><mn>64</mn><mo separator="true">,</mo><mn>256</mn><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">N = [4,16,64, 256]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord">4</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">16</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">64</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">256</span><span class="mclose">]</span></span></span></span></span><span></span></span>:</p><figure id="15b1384e-bcac-80ad-9cf3-cf5bcbd3f53b" class="image"><a 
href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-maj-bon-beam.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-maj-bon-beam.png"/></a></figure><p id="15d1384e-bcac-80c3-93b3-fa4c071ac807" class="">In this plot, each bar denotes a test-time compute budget, and within each bar we show the relative accuracy of each method. For example, in the group of four bars on difficulty level 2 we see that:</p><ul id="15d1384e-bcac-8091-b3fb-cad0ab99b2c1" class="bulleted-list"><li style="list-style-type:disc">Majority voting is the worst performer for all compute budgets, except for <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>256</mn></mrow><annotation encoding="application/x-tex">N=256</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">256</span></span></span></span></span><span></span></span>, where beam search is worst.</li></ul><ul id="15d1384e-bcac-8076-b88c-c7f55fa0cdbc" class="bulleted-list"><li style="list-style-type:disc">Beam search is best for <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" 
class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mo stretchy="false">[</mo><mn>4</mn><mo separator="true">,</mo><mn>16</mn><mo separator="true">,</mo><mn>64</mn><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">N=[4,16,64]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord">4</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">16</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">64</span><span class="mclose">]</span></span></span></span></span><span></span></span>, but Best-of-N is best for <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>256</mn></mrow><annotation encoding="application/x-tex">N=256</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" 
style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">256</span></span></span></span></span><span></span></span>.</li></ul><p id="15a1384e-bcac-80d4-af98-eaebf5fcf84e" class="">Although beam search gives consistent gains on the medium and hard problems (levels 3-5), it tends to do worse than Best-of-N (and even majority voting!) on the simpler problems, especially at large compute budgets. </p><p id="15a1384e-bcac-805b-9949-f0cdc44c9e3c" class="">Looking at the trees produced by beam search, we realized that if a single step is assigned a high reward, the whole tree collapses to that trace and diversity suffers. This prompted us to explore an extension to beam search that maximises diversity. Let’s take a look!</p> |
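<p>Before moving on, the oracle difficulty binning described earlier in this section can be sketched in a few lines. This is a minimal sketch, assuming the per-problem pass@1 scores have already been estimated from sampled solutions. The names <code>pass_at_1</code> and <code>difficulty_levels</code>, the equal-count binning, and the tie-breaking by input order are illustrative assumptions, not the code used for the plots.</p>

```python
from typing import Dict, List


def pass_at_1(is_correct: List[bool]) -> float:
    """Fraction of sampled solutions that match the ground truth."""
    return sum(is_correct) / len(is_correct)


def difficulty_levels(scores: Dict[str, float], num_levels: int = 5) -> Dict[str, int]:
    """Assign each problem a level from 1 (easiest) to num_levels (hardest).

    Problems are ranked by score, highest first, and split into
    equal-count bins (quintiles when num_levels is 5).
    """
    ranked = sorted(scores, key=scores.get, reverse=True)
    return {
        problem: rank * num_levels // len(ranked) + 1
        for rank, problem in enumerate(ranked)
    }


# Hypothetical scores for ten problems, purely for illustration.
toy_scores = {f"p{i}": i / 10 for i in range(10)}
levels = difficulty_levels(toy_scores)
```

<p>The same function covers the model-based heuristic if the average PRM score per problem is passed in place of pass@1, since in both cases a lower score is taken to mean a harder problem.</p>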
|
|
|
<h2 id="1591384e-bcac-80d2-8234-fe0e9a4df59d" class="">DVTS: boosting performance with diversity</h2><p id="1591384e-bcac-8044-b7c5-cf39e4aed683" class="">As we saw above, beam search performs strongly compared to Best-of-N, but tends to underperform on simpler problems and at large test-time compute budgets. To address this, we developed an extension we call Diverse Verifier Tree Search (DVTS) that is designed to maximise diversity at large <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span>.</p><p id="15a1384e-bcac-80ff-a97b-c7ccd88958e4" class="">DVTS works in a similar fashion to beam search, with the following modifications:</p><ol type="1" id="15d1384e-bcac-806c-8004-e054a98d98ef" class="numbered-list" start="1"><li>For a given <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" 
style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> and <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span>, expand the initial set of beams into <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mi mathvariant="normal">/</mi><mi>M</mi></mrow><annotation encoding="application/x-tex">N/M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mord">/</span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> <em>independent</em> subtrees.</li></ol><ol type="1" id="15d1384e-bcac-8081-8508-feb06a13469b" class="numbered-list" start="2"><li>For each subtree, select the step with the highest PRM score.</li></ol><ol type="1" 
id="15d1384e-bcac-806a-976f-ec9596cd9532" class="numbered-list" start="3"><li>Generate <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> new steps from the nodes selected in step (2) and select the step with the highest PRM score.</li></ol><ol type="1" id="15d1384e-bcac-808e-aa2b-f391ec426953" class="numbered-list" start="4"><li>Repeat step (3) until the EOS token or maximum tree depth is reached.</li></ol><p id="15d1384e-bcac-8087-b916-d9603de035dd" class="">Here are the results from applying DVTS to Llama 1B:</p><figure id="15b1384e-bcac-801c-a1e7-d4e544826da3" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-all.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-all.png"/></a></figure><p id="15b1384e-bcac-80e1-bc9b-dbdb5738b9f1" class="">As we can see, DVTS provides a complementary strategy to beam search: at small <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math 
xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> beam search is more effective at finding correct solutions, but at large <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> the diversity of DVTS candidates kicks in and we get better performance. </p><p id="15d1384e-bcac-80a7-8379-dca3c329c433" class="">We can also see this manifested in the problem difficulty breakdown, where DVTS enhances performance on the easy / medium problems at large N, while beam search is best at small N across model problem difficulties:</p><figure id="15b1384e-bcac-807a-8dca-f322077cc616" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-all.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-all.png"/></a></figure> |
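<p>The four DVTS steps listed earlier in this section can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the implementation behind the plots: <code>propose_step</code> and <code>prm_score</code> are hypothetical callables standing in for the LLM sampler and the PRM, the <code>EOS</code> marker string is made up, and the toy functions at the bottom exist only so the example runs end to end.</p>

```python
import random
from typing import Callable, List

EOS = "[EOS]"  # hypothetical end-of-solution marker
Path = List[str]  # a partial solution, one string per reasoning step


def dvts(question: str, n: int, m: int,
         propose_step: Callable[[str, Path], str],
         prm_score: Callable[[str, Path], float],
         max_depth: int = 40) -> List[Path]:
    """Grow n // m independent subtrees, keeping one best step per subtree."""
    if n % m != 0:
        raise ValueError("n must be divisible by m")
    # Step 1: start n // m independent subtrees (each begins empty).
    subtrees: List[Path] = [[] for _ in range(n // m)]
    for _ in range(max_depth):
        # Step 4: stop once every subtree has emitted EOS.
        active = [i for i, p in enumerate(subtrees) if not p or p[-1] != EOS]
        if not active:
            break
        for i in active:
            # Steps 2-3: sample m candidate next steps from this subtree's
            # current node and keep only the highest-scoring one.
            candidates = [subtrees[i] + [propose_step(question, subtrees[i])]
                          for _ in range(m)]
            subtrees[i] = max(candidates, key=lambda p: prm_score(question, p))
    return subtrees


# Toy stand-ins (assumptions): steps are random digits, the "PRM" is their sum.
rng = random.Random(0)

def toy_propose(question: str, path: Path) -> str:
    return EOS if len(path) >= 3 else str(rng.randint(0, 9))

def toy_prm(question: str, path: Path) -> float:
    return float(sum(int(step) for step in path if step != EOS))

solutions = dvts("toy question", n=8, m=4,
                 propose_step=toy_propose, prm_score=toy_prm)
```

<p>Because the subtrees never compete with each other, a single high-scoring step can only take over its own subtree, which is where the extra diversity at large N comes from.</p>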
|
|
|
<h2 id="1591384e-bcac-806b-9dd0-c80a250c7754" class="">The best of all worlds: compute-optimal scaling</h2><p id="1591384e-bcac-80e0-93e6-ceaacc131142" class="">With various search strategies in hand, a natural question is: which one is best? The DeepMind paper proposed a <em><strong>compute-optimal</strong></em> <em><strong>scaling strategy</strong></em> where one selects the search method and hyperparameters <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>θ</mi></mrow><annotation encoding="application/x-tex">\theta</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span></span></span></span></span><span></span></span> that achieves the <em><strong>best performance for a given compute budget </strong></em><em><style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span></em><em><strong>:</strong></em></p><figure 
id="15e1384e-bcac-8054-8afa-ca441a776a05" class="equation"><style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><div class="equation-container"><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><msubsup><mi>θ</mi><mrow><mi>q</mi><mo separator="true">,</mo><msup><mi>a</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow><mo>∗</mo></msubsup><mo stretchy="false">(</mo><mi>N</mi><mo stretchy="false">)</mo><mo>=</mo><mi><munder><mo><mi>arg</mi><mo></mo><mi>max</mi><mo></mo></mo><mi>θ</mi></munder></mi><mrow><mo fence="true">(</mo><msub><mi mathvariant="double-struck">E</mi><mrow><mi>y</mi><mo>∼</mo><mtext>Target</mtext><mo stretchy="false">(</mo><mi>θ</mi><mo separator="true">,</mo><mi>N</mi><mo separator="true">,</mo><mi>q</mi><mo stretchy="false">)</mo></mrow></msub><mrow><mo fence="true">[</mo><msub><mn mathvariant="double-struck">1</mn><mrow><mi>y</mi><mo>=</mo><msup><mi>y</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow></msub><mo fence="true">]</mo></mrow><mo fence="true">)</mo></mrow><mo separator="true">,</mo></mrow><annotation encoding="application/x-tex">\theta_{q,a^*(q)}^*(N) = \underset{\theta}{\arg\max} \left( \mathbb{E}_{y \sim \text{Target}(\theta, N, q)} \left[ \mathbb{1}_{y = y^*(q)} \right] \right),</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.197em;vertical-align:-0.447em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.7387em;"><span style="top:-2.428em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 
size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mpunct mtight">,</span><span class="mord mtight"><span class="mord mathnormal mtight">a</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6183em;"><span style="top:-2.786em;margin-right:0.0714em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.447em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.7965em;vertical-align:-0.9465em;"></span><span class="mord"><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.4306em;"><span style="top:-2.1535em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop"><span class="mop">ar<span 
style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mop">max</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.9465em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size1">(</span></span><span class="mord"><span class="mord mathbb">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.5198em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mrel mtight">∼</span><span class="mord text mtight"><span class="mord mtight">Target</span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight" style="margin-right:0.10903em;">N</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.3552em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size1">[</span></span><span class="mord"><span class="mord">1</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span 
style="top:-2.5198em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="mrel mtight">=</span><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6183em;"><span style="top:-2.786em;margin-right:0.0714em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.3552em;"><span></span></span></span></span></span></span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size1">]</span></span></span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size1">)</span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mpunct">,</span></span></span></span></span></div></figure><p id="15e1384e-bcac-8011-bf10-cc868e25db7c" class="">where <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>y</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow><annotation 
encoding="application/x-tex">y^*(q)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6887em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="mclose">)</span></span></span></span></span><span></span></span> is the ground-truth for question <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>q</mi></mrow><annotation encoding="application/x-tex">q</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span></span></span></span></span><span></span></span> and <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math 
xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msubsup><mi>θ</mi><mrow><mi>q</mi><mo separator="true">,</mo><msup><mi>a</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow><mo>∗</mo></msubsup><mo stretchy="false">(</mo><mi>N</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\theta_{q,a^*(q)}^*(N)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.247em;vertical-align:-0.497em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6887em;"><span style="top:-2.378em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mpunct mtight">,</span><span class="mord mtight"><span class="mord mathnormal mtight">a</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6183em;"><span style="top:-2.786em;margin-right:0.0714em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.497em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mclose">)</span></span></span></span></span><span></span></span> denotes the compute-optimal scaling strategy. Since computing <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msubsup><mi>θ</mi><mrow><mi>q</mi><mo separator="true">,</mo><msup><mi>a</mi><mo>∗</mo></msup><mo stretchy="false">(</mo><mi>q</mi><mo stretchy="false">)</mo></mrow><mo>∗</mo></msubsup><mo stretchy="false">(</mo><mi>N</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\theta_{q,a^*(q)}^*(N)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.247em;vertical-align:-0.497em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.6887em;"><span style="top:-2.378em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mpunct mtight">,</span><span class="mord mtight"><span class="mord mathnormal mtight">a</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.6183em;"><span style="top:-2.786em;margin-right:0.0714em;"><span class="pstrut" style="height:2.5em;"></span><span 
class="sizing reset-size3 size1 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mopen mtight">(</span><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span><span class="mclose mtight">)</span></span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.497em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mclose">)</span></span></span></span></span><span></span></span> directly is somewhat tricky, DeepMind proposed an approximation based on the <em><strong>problem difficulty</strong></em>, i.e. allocate test-time compute according to which search strategy achieves best performance for a given difficulty level.</p><p id="15a1384e-bcac-80c9-a276-d5ea8974c543" class="">For example, on simpler problems and lower compute budgets, it is better to use strategies like Best-of-N, while on harder problems, beam search is the better choice. To implement this, for each method we compute the accuracy for a given difficulty level and test-time compute budget. And voila, we now have our compute-optimal curve!</p><figure id="15b1384e-bcac-80b3-bc58-d20ba41d3950" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-opt.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-opt.png"/></a></figure> |
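<p>The selection step described above can be sketched in a few lines of Python. This is only an illustration, not the implementation behind the plots: the accuracy values are invented, the function names <code>compute_optimal_strategy</code> and <code>compute_optimal_curve</code> are hypothetical, and averaging over difficulty levels assumes the difficulty bins contain equal numbers of problems.</p>

```python
from collections import defaultdict

# Hypothetical accuracies, invented purely for illustration.
# Layout: ACCURACY[method][(difficulty_level, n_generations)] = accuracy
ACCURACY = {
    "best_of_n": {(1, 4): 0.75, (1, 16): 0.875, (5, 4): 0.125, (5, 16): 0.125},
    "beam_search": {(1, 4): 0.5, (1, 16): 0.75, (5, 4): 0.25, (5, 16): 0.375},
}


def compute_optimal_strategy(accuracy):
    """For each (difficulty, budget) cell, keep the method with the highest accuracy."""
    best = {}
    for method, cells in accuracy.items():
        for cell, acc in cells.items():
            if cell not in best or acc > best[cell][1]:
                best[cell] = (method, acc)
    return best


def compute_optimal_curve(best):
    """Average the best per-difficulty accuracy at each budget N.

    Assumes every difficulty bin holds the same number of problems.
    """
    by_budget = defaultdict(list)
    for (_difficulty, n), (_method, acc) in best.items():
        by_budget[n].append(acc)
    return {n: sum(accs) / len(accs) for n, accs in sorted(by_budget.items())}


if __name__ == "__main__":
    best = compute_optimal_strategy(ACCURACY)
    for (difficulty, n), (method, acc) in sorted(best.items()):
        print(f"difficulty={difficulty} N={n}: {method} ({acc:.3f})")
    print(compute_optimal_curve(best))
```

<p>With these made-up numbers, Best-of-N is selected for the easy problems and beam search for the hard ones, mirroring the pattern described above.</p>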

<h2 id="1591384e-bcac-809a-96d2-e928398d159a" class="">Scaling up to larger models</h2><p id="15a1384e-bcac-8078-86d7-f48c2146444e" class="">We also explored scaling up the compute-optimal recipe to Llama 3.2 3B Instruct to see at what point the benefits of the PRM fade in comparison to the policy’s own capacity. To our surprise, compute-optimal scaling works remarkably well, with the 3B model surpassing the performance of Llama 3.1 70B Instruct (22x its size!):</p><figure id="15b1384e-bcac-80b3-bc58-d20ba41d3950" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-opt-3b.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-opt-3b.png"/></a></figure>

<h2 id="15a1384e-bcac-809c-b5e7-eb92dadaebb4" class="">Where to go from here?</h2><p id="15b1384e-bcac-8052-91d7-d6e1f6f66e09" class="">This exploration of test-time compute scaling has revealed both the potential and the challenges of leveraging search-based methods. As we look ahead, several exciting directions emerge:</p>
<ol type="1" id="15b1384e-bcac-8040-91d7-e4236e1530ef" class="numbered-list" start="1"><li><strong>The Power of Strong Verifiers:</strong><p id="15b1384e-bcac-8032-85d2-e0e0283f7e4a" class="">Strong verifiers play a critical role in enhancing performance. However, their current limitations are apparent, as highlighted in benchmarks like <em>ProcessBench</em>. Improving the robustness and generalization of verifiers will be crucial for advancing these methods.</p></li></ol>
<ol type="1" id="15b1384e-bcac-80da-a940-e4ad39ff493d" class="numbered-list" start="2"><li><strong>The Challenge of Self-Verification:</strong><p id="15b1384e-bcac-8077-a5bb-c75fb34b60eb" class="">The ultimate goal, or "holy grail", is achieving self-verification, where models can validate their own outputs autonomously. This approach appears to be what models like o1 are doing, but it remains difficult to implement in practice. Unlike standard supervised fine-tuning (SFT), self-verification demands more nuanced strategies. The recent DeepMind paper on self-verification and <em>SCoRe</em> sheds light on this challenge and offers a pathway for future research.</p></li></ol>
<ol type="1" id="15b1384e-bcac-8038-88f9-dd21d5cf272c" class="numbered-list" start="3"><li><strong>Integrating “Thoughts” into the Process:</strong><p id="15b1384e-bcac-801d-923f-d1ea46844321" class="">Incorporating explicit intermediate steps or “thoughts” during generation could further enhance reasoning and decision-making. By integrating structured reasoning into the search process, we may unlock better performance on complex tasks.</p></li></ol>
<ol type="1" id="15b1384e-bcac-802b-aa65-d84fb31ed476" class="numbered-list" start="4"><li><strong>Search as a Data Generation Tool:</strong><p id="15b1384e-bcac-802a-8488-ef74d5fb0a55" class="">Search can also serve as a powerful data generation process, creating high-quality training datasets. For example, fine-tuning models like Llama 1B on correct traces produced by search could yield significant gains. This on-policy approach resembles techniques like ReST or V-STaR but with the added benefits of search, offering a promising direction for iterative improvement.</p></li></ol>
<ol type="1" id="15b1384e-bcac-804a-8a0e-dcaf06767339" class="numbered-list" start="5"><li><strong>A Call for More PRMs:</strong><p id="15b1384e-bcac-803c-9524-f563cd123121" class="">Open process reward models (PRMs) are relatively rare, limiting their broader application. Developing and sharing more PRMs for different domains is a critical area where the community can contribute significantly.</p></li></ol>
<ol type="1" id="15b1384e-bcac-805f-b97c-d3383a2c9682" class="numbered-list" start="6"><li><strong>Expanding Beyond Verifiable Domains:</strong><p id="15b1384e-bcac-8073-9fea-c717f2d50df5" class="">While current methods excel in domains like math and code, where solutions are inherently verifiable, extending these techniques to other areas remains a major challenge. How can we adapt these strategies for less structured or subjective tasks? This is a vital question for future exploration.</p></li></ol>
<h2 id="15b1384e-bcac-8093-82c6-d9c951dc0bab" class="">Acknowledgements</h2><p id="15c1384e-bcac-80d5-8ad4-d7f6874404c5" class="">We are grateful to Charlie Snell and Aviral Kumar for many discussions about test-time compute scaling and for sharing implementation details from their work. We thank Chun Te Lee for designing the lovely banner and Thomas Wolf, Leandro von Werra, Colin Raffel, and Quentin Gallouédec for many helpful suggestions to improve the blog post. We also thank Hugo Larcher and Mathieu Morlon for continually optimising the Hugging Face Science Cluster to make the GPUs go brrr 🔥!</p>

</d-article>

<d-appendix>
    <d-bibliography src="bibliography.bib"></d-bibliography>
    <style>
        d-appendix .citation {
            font-size: 11px;
            line-height: 15px;
            border-left: 1px solid rgba(0, 0, 0, 0.1);
            padding-left: 18px;
            border: 1px solid rgba(0,0,0,0.1);
            background: rgba(0, 0, 0, 0.02);
            padding: 10px 18px;
            border-radius: 3px;
            color: rgba(150, 150, 150, 1);
            overflow: hidden;
            margin-top: -12px;
            white-space: pre-wrap;
            word-wrap: break-word;
        }
    </style>

    <h3 id="citation">Citation</h3>
    <p>For attribution in academic contexts, please cite this work as</p>
    <pre class="citation short">Beeching, Edward, and Tunstall, Lewis, and Rush, Sasha, "Scaling test-time compute with open models.", 2024.</pre>
    <p>BibTeX citation</p>
    <pre class="citation long">@misc{beeching2024scalingtesttimecompute,
    title={Scaling test-time compute with open models},
    author={Edward Beeching and Lewis Tunstall and Sasha Rush},
    url={https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute},
}</pre>
</d-appendix>

<script>
    const article = document.querySelector('d-article');
    const toc = document.querySelector('d-contents');
    if (toc) {
        const headings = article.querySelectorAll('h2, h3, h4');
        let ToC = `<nav role="navigation" class="l-text figcaption"><h3>Table of contents</h3>`;
        let prevLevel = 0;

        for (const el of headings) {
            // Skip headings inside the title block or marked with a `no-toc` attribute.
            const isInTitle = el.parentElement.tagName === 'D-TITLE';
            // hasAttribute also catches a bare `no-toc`, whose value is the empty string.
            const isException = el.hasAttribute('no-toc');
            if (isInTitle || isException) continue;

            // Derive the anchor id from the heading text.
            el.setAttribute('id', el.textContent.toLowerCase().replaceAll(" ", "_"));
            const link = '<a target="_self" href="' + '#' + el.getAttribute('id') + '">' + el.textContent + '</a>';

            // h2 -> level 0, h3 -> level 1, h4 -> level 2; open or close lists to match.
            const level = el.tagName === 'H2' ? 0 : (el.tagName === 'H3' ? 1 : 2);
            while (prevLevel < level) {
                ToC += '<ul>';
                prevLevel++;
            }
            while (prevLevel > level) {
                ToC += '</ul>';
                prevLevel--;
            }
            if (level === 0)
                ToC += '<div>' + link + '</div>';
            else
                ToC += '<li>' + link + '</li>';
        }

        // Close any lists that are still open.
        while (prevLevel > 0) {
            ToC += '</ul>';
            prevLevel--;
        }
        ToC += '</nav>';
        toc.innerHTML = ToC;
        toc.setAttribute('prerendered', 'true');
        const toc_links = document.querySelectorAll('d-contents > nav a');

        // Highlight the link of the last heading scrolled past the top of the viewport.
        window.addEventListener('scroll', (_event) => {
            if (typeof (headings) != 'undefined' && headings != null && typeof (toc_links) != 'undefined' && toc_links != null) {

                find_active: {
                    for (let i = headings.length - 1; i >= 0; i--) {
                        if (headings[i].getBoundingClientRect().top - 50 <= 0) {
                            if (!toc_links[i].classList.contains("active")) {
                                toc_links.forEach((link, _index) => {
                                    link.classList.remove("active");
                                });
                                toc_links[i].classList.add('active');
                            }
                            break find_active;
                        }
                    }
                    // No heading has been scrolled past yet: clear the highlight.
                    toc_links.forEach((link, _index) => {
                        link.classList.remove("active");
                    });
                }
            }
        });
    }
</script>
</body>
</html>