Commit
•
73c475a
1
Parent(s):
93a9871
Update app/src/index.html (#1)
Browse files- Update app/src/index.html (47fe1f22ef5717a42e47ec0352d80cb5c3754b2b)
Co-authored-by: Quentin Gallouédec <qgallouedec@users.noreply.huggingface.co>
- app/src/index.html +1 -1
app/src/index.html
CHANGED
@@ -146,7 +146,7 @@ def get_canonical_form(expression: str) -> str:
|
|
146 |
<p id="15d1384e-bcac-809a-8aa8-c52ca7301b52" class="">Here’s the results one gets from applying both variants of Best-of-N:</p><figure id="15b1384e-bcac-808d-857e-d492683a4a91" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon.png"/></a></figure><p id="15b1384e-bcac-8001-9320-ff788bab0c52" class="">The results reveal a clear advantage: <strong>weighted Best-of-N</strong> consistently outperforms vanilla Best-of-N, especially with larger generation budgets. Its ability to aggregate scores across identical responses ensures that even less frequent but higher-quality answers are effectively prioritized.</p><p id="15b1384e-bcac-808a-b3ff-ee08c05a20af" class="">However, despite these improvements, we’re still falling short of the performance achieved by the Llama 8B model and the Best-of-N approach is starting to plateau at <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>256</mn></mrow><annotation encoding="application/x-tex">N=256</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">256</span></span></span></span></span><span></span></span> generations. Can we push the boundaries further by supervising the search process step-by-step? Let’s find out 🚀!</p>
|
147 |
|
148 |
<h2 id="1591384e-bcac-8065-a02c-cd760ebd6cd1" class="">Beam search with process reward models</h2><p id="15a1384e-bcac-80e1-9e0e-c01f5f373805" class="">Beam search is a structured search method that systematically explores the solution space, making it a powerful tool for improving model outputs at test-time. When combined with a PRM, beam search can optimize both the generation and evaluation of intermediate steps in problem-solving. The way it works is as follows:</p><ol type="1" id="15d1384e-bcac-8007-8d79-cdaa74e4c8c0" class="numbered-list" start="1"><li>Generate multiple candidate solutions <em>iteratively</em> by maintaining a fixed number of "beams" or active paths <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span>.</li></ol><ol type="1" id="15d1384e-bcac-8020-bf69-e67fd962062b" class="numbered-list" start="2"><li>In the first iteration, sample <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> independent steps from the LLM with temperature <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>T</mi></mrow><annotation encoding="application/x-tex">T</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span></span></span></span></span><span></span></span> to introduce diversity in the responses. These steps are usually defined by a stopping criterion like terminating on a new line <code>\n</code> or double new line <code>\n\n</code>.</li></ol><ol type="1" id="15d1384e-bcac-80c2-aeaa-f6d73682eb8c" class="numbered-list" start="3"><li>Score each step with the PRM and select the top <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mi mathvariant="normal">/</mi><mi>M</mi></mrow><annotation encoding="application/x-tex">N/M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mord">/</span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> steps as candidates for the next round of generation. Here <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> denotes the “beam width” of a given active path. As in Best-of-N, we used the “last” reduction to score the partial solutions at each iteration.</li></ol><ol type="1" id="15d1384e-bcac-8022-966b-e1dae6845cc1" class="numbered-list" start="4"><li>Expand the steps selected in step (3) by sampling <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> next steps in the solution.</li></ol><ol type="1" id="15d1384e-bcac-8023-b6b6-f470e22ac78a" class="numbered-list" start="5"><li>Repeat steps (3) and (4) until the EOS token is reached or the maximum search depth is exceeded.</li></ol><p id="15a1384e-bcac-8003-a9d9-da7f3a4dc321" class="">By allowing the PRM to evaluate the correctness of intermediate steps, beam search can identify and prioritize promising paths early in the process. This step-by-step evaluation is particularly beneficial for complex reasoning tasks like mathematics, where verifying partial solutions can significantly improve final outcomes.</p><details><summary style="font-weight:600;font-size:1.25em;line-height:1.3;margin:0">Implementation detail</summary><div class="indented"><p id="15b1384e-bcac-8065-a739-d24b699106be" class="">When we implemented beam search with process supervision, we encountered two major footguns with the Llama 3 chat template that are worth mentioning:</p><ul id="15d1384e-bcac-803c-84b3-d881bc2ca3b5" class="bulleted-list"><li style="list-style-type:disc">By default, the chat template trims trailing new lines from every assistant turn. As a result, if one uses <code>\n</code> or <code>\n\n</code> to terminate a step, these tokens are lost on subsequent steps and force the model to produce peculiar outputs.</li></ul><ul id="15d1384e-bcac-808f-97f1-fb7d27565e36" class="bulleted-list"><li style="list-style-type:disc">The chat template is prefixed with Llama’s BOS token. When the formatted string is fed to vLLM a <em>second</em> BOS token is added which completely ruins performance, even though the generations look mostly coherent 🤯</li></ul><p id="15d1384e-bcac-8041-9164-ecc3d9497886" class="">The solution is to overwrite the Llama 3 chat template to prevent trimming and exclude the BOS token prefix. </p><p id="15a1384e-bcac-8090-b5fc-eb36a6588e60" class="">
|
149 |
-
</p></div></details><p id="15d1384e-bcac-80e9-8e65-e1b58080b94c" class="">In our experiments, we followed DeepMind’s hyperparameter choices and ran beam search with the following:</p><ul id="15d1384e-bcac-8098-8574-e16392fc6123" class="bulleted-list"><li style="list-style-type:disc"><style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> beams in compute scalings of 4, 16, 64, 256</li></ul><ul id="15d1384e-bcac-8067-b37c-e9692e34678c" class="bulleted-list"><li style="list-style-type:disc">Fixed beam width <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi><mo>=</mo><mn>4</mn></mrow><annotation encoding="application/x-tex">M=4</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">4</span></span></span></span></span><span></span></span></li></ul><ul id="15d1384e-bcac-8093-a928-c16e31e29e3f" class="bulleted-list"><li style="list-style-type:disc">Sampling with temperature <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>T</mi><mo>=</mo><mn>0.8</mn></mrow><annotation encoding="application/x-tex">T=0.8</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">0.8</span></span></span></span></span><span></span></span></li></ul><ul id="15d1384e-bcac-802a-8416-e332ca20237f" class="bulleted-list"><li style="list-style-type:disc">Up to 40 iterations, i.e. a tree of maximum depth with 40 steps.</li></ul><p id="15d1384e-bcac-8051-abe5-dc84c42a1b5f" class="">As shown below, the results are striking: with a test-time budget of <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>4</mn></mrow><annotation encoding="application/x-tex">N=4</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">4</span></span></span></span></span><span></span></span>, beam search achieves the same accuracy as Best-of-N for <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>16</mn></mrow><annotation encoding="application/x-tex">N=16</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">16</span></span></span></span></span><span></span></span>, i.e. it is 4x more compute efficient! Moreover, beam search matches the performance of Llama 3.1 8B with just <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>32</mn></mrow><annotation encoding="application/x-tex">N=32</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">32</span></span></span></span></span><span></span></span> solutions per problem. The average performance on MATH by computer science PhD students is around 40%, so reaching nearly 55% isn’t too bad for a 1B model 💪!</p><figure id="15b1384e-bcac-80e9-97fa-fe50d1811f5b" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon-beam.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon-beam.png"/></a></figure><h3 id="15a1384e-bcac-800c-baee-fb99b242ef87" class="">Which problems does beam search solve best?</h3><p id="15d1384e-bcac-80e3-938a-c3f09db2e9ff" class="">Although in aggregate it is clear that beam search is a better search strategy than Best-of-N or majority voting, the DeepMind paper showed that <em><strong>each strategy has tradeoffs that depend on the problem difficulty</strong></em> and test-time compute budget. </p><p id="15d1384e-bcac-8015-a8f0-c2323b9e535f" class="">To see which problems are best suited for which strategy, DeepMind computed a distribution over estimated problem difficulty, and then binned the results into quintiles. In other words, each problem is assigned one of 5 levels, where level 1 indicates easier problems and level 5 indicates the hardest ones. To estimate problem difficulty, DeepMind generated 2048 candidate solutions with standard sampling per problem and then proposed the following heuristics:</p><ul id="15d1384e-bcac-8080-9152-caeaa288073c" class="bulleted-list"><li style="list-style-type:disc"><strong>Oracle: </strong>use the ground truth labels to estimate the <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mi>a</mi><mi>s</mi><mi>s</mi><mi mathvariant="normal">@</mi><mn>1</mn></mrow><annotation encoding="application/x-tex">pass@1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">p</span><span class="mord mathnormal">a</span><span class="mord mathnormal">ss</span><span class="mord">@1</span></span></span></span></span><span></span></span> score per problem. Bin the distribution of <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mi>a</mi><mi>s</mi><mi>s</mi><mi mathvariant="normal">@</mi><mn>1</mn></mrow><annotation encoding="application/x-tex">pass@1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">p</span><span class="mord mathnormal">a</span><span class="mord mathnormal">ss</span><span class="mord">@1</span></span></span></span></span><span></span></span> scores to determine the quintiles.</li></ul><ul id="15d1384e-bcac-80f9-8778-d4045c6faa7d" class="bulleted-list"><li style="list-style-type:disc"><strong>Model: </strong>use the distribution of average PRM scores per problem to determine the quintiles. The intuition here is that harder problems will have lower scores.</li></ul><p id="15d1384e-bcac-80a3-af7c-f3497126ab1e" class="">Here’s the breakdown of the various methods according to the pass@1 scores and across four test-time compute budgets of <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mo stretchy="false">[</mo><mn>4</mn><mo separator="true">,</mo><mn>16</mn><mo separator="true">,</mo><mn>64</mn><mo separator="true">,</mo><mn>256</mn><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">N = [4,16,64, 256]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord">4</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">16</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">64</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">256</span><span class="mclose">]</span></span></span></span></span><span></span></span>:</p><figure id="15b1384e-bcac-80ad-9cf3-cf5bcbd3f53b" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-maj-bon-beam.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-maj-bon-beam.png"/></a></figure><p id="15d1384e-bcac-80c3-93b3-fa4c071ac807" class="">In this plot, each bar denotes a test-time compute budget, and within each bar we show the relative accuracy of each method. For example, in the group of four bars on difficulty level 2 we see that:</p><ul id="15d1384e-bcac-8091-b3fb-cad0ab99b2c1" class="bulleted-list"><li style="list-style-type:disc">Majority voting is the worst performer for all compute budgets, except for <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>256</mn></mrow><annotation encoding="application/x-tex">N=256</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">256</span></span></span></span></span><span></span></span>, where beam search is worst.</li></ul><ul id="15d1384e-bcac-8076-b88c-c7f55fa0cdbc" class="bulleted-list"><li style="list-style-type:disc">Beam search is best for <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mo stretchy="false">[</mo><mn>4</mn><mo separator="true">,</mo><mn>16</mn><mo separator="true">,</mo><mn>64</mn><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">N=[4,16,64]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord">4</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">16</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">64</span><span class="mclose">]</span></span></span></span></span><span></span></span>, but Best-of-N is best for <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>256</mn></mrow><annotation encoding="application/x-tex">N=256</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">256</span></span></span></span></span><span></span></span>.</li></ul><p id="15a1384e-bcac-80d4-af98-eaebf5fcf84e" class="">Although we see that beam search gives consistent gains in the medium and hard problems (levels 3-5), it tends to do worse than Best-of-N (and even majority voting!) on the simpler problems and especially at large compute budgets. </p><p id="15a1384e-bcac-805b-9949-f0cdc44c9e3c" class="">We realized from looking at the resulting trees produced by beam search, that if a single step is assigned high reward, then the whole tree collapses to that trace and thus diversity is impacted. This prompted us to explore an extension to beam search that maximises diversity - let’s take a look!</p>
|
150 |
|
151 |
<h2 id="1591384e-bcac-80d2-8234-fe0e9a4df59d" class="">DVTS: boosting performance with diversity</h2><p id="1591384e-bcac-8044-b7c5-cf39e4aed683" class="">As we saw above beam search gives strong performance over Best-of-N, but tends to underperform on simpler problems and at large test-time compute budgets. To address this, we developed an extension we call Diverse Verifier Tree Search (DVTS) that is designed to maximise diversity at large <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span>.</p><p id="15a1384e-bcac-80ff-a97b-c7ccd88958e4" class="">DVTS works in a similar fashion as beam search, with the following modifications:</p><ol type="1" id="15d1384e-bcac-806c-8004-e054a98d98ef" class="numbered-list" start="1"><li>For a given <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> and <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span>, expand the initial set of beams into <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mi mathvariant="normal">/</mi><mi>M</mi></mrow><annotation encoding="application/x-tex">N/M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mord">/</span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> <em>independent</em> subtrees.</li></ol><ol type="1" id="15d1384e-bcac-8081-8508-feb06a13469b" class="numbered-list" start="2"><li>For each subtree, select the step with the highest PRM score.</li></ol><ol type="1" id="15d1384e-bcac-806a-976f-ec9596cd9532" class="numbered-list" start="3"><li>Generate <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> new steps from the nodes selected in step (2) and select the step with the highest PRM score.</li></ol><ol type="1" id="15d1384e-bcac-808e-aa2b-f391ec426953" class="numbered-list" start="4"><li>Repeat step (3) until the EOS token or maximum tree depth is reached.</li></ol><p id="15d1384e-bcac-8087-b916-d9603de035dd" class="">Here’s the results from applying DVTS to Llama 1B:</p><figure id="15b1384e-bcac-801c-a1e7-d4e544826da3" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-all.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-all.png"/></a></figure><p id="15b1384e-bcac-80e1-bc9b-dbdb5738b9f1" class="">As we can see, DVTS provides a complementary strategy to beam search: at small <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> beam search is more effective at finding correct solutions, but at large <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> the diversity of DVTS candidates kicks in and we get better performance. </p><p id="15d1384e-bcac-80a7-8379-dca3c329c433" class="">We can also see this manifested in the problem difficulty breakdown, where DVTS enhances performance on the easy / medium problems at large N, while beam search is best at small N across model problem difficulties:</p><figure id="15b1384e-bcac-807a-8dca-f322077cc616" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-all.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-all.png"/></a></figure>
|
152 |
|
|
|
146 |
<p id="15d1384e-bcac-809a-8aa8-c52ca7301b52" class="">Here’s the results one gets from applying both variants of Best-of-N:</p><figure id="15b1384e-bcac-808d-857e-d492683a4a91" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon.png"/></a></figure><p id="15b1384e-bcac-8001-9320-ff788bab0c52" class="">The results reveal a clear advantage: <strong>weighted Best-of-N</strong> consistently outperforms vanilla Best-of-N, especially with larger generation budgets. Its ability to aggregate scores across identical responses ensures that even less frequent but higher-quality answers are effectively prioritized.</p><p id="15b1384e-bcac-808a-b3ff-ee08c05a20af" class="">However, despite these improvements, we’re still falling short of the performance achieved by the Llama 8B model and the Best-of-N approach is starting to plateau at <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>256</mn></mrow><annotation encoding="application/x-tex">N=256</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">256</span></span></span></span></span><span></span></span> generations. Can we push the boundaries further by supervising the search process step-by-step? Let’s find out 🚀!</p>
|
147 |
|
148 |
<h2 id="1591384e-bcac-8065-a02c-cd760ebd6cd1" class="">Beam search with process reward models</h2><p id="15a1384e-bcac-80e1-9e0e-c01f5f373805" class="">Beam search is a structured search method that systematically explores the solution space, making it a powerful tool for improving model outputs at test-time. When combined with a PRM, beam search can optimize both the generation and evaluation of intermediate steps in problem-solving. The way it works is as follows:</p><ol type="1" id="15d1384e-bcac-8007-8d79-cdaa74e4c8c0" class="numbered-list" start="1"><li>Generate multiple candidate solutions <em>iteratively</em> by maintaining a fixed number of "beams" or active paths <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span>.</li></ol><ol type="1" id="15d1384e-bcac-8020-bf69-e67fd962062b" class="numbered-list" start="2"><li>In the first iteration, sample <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> independent steps from the LLM with temperature <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>T</mi></mrow><annotation encoding="application/x-tex">T</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span></span></span></span></span><span></span></span> to introduce diversity in the responses. These steps are usually defined by a stopping criterion like terminating on a new line <code>\n</code> or double new line <code>\n\n</code>.</li></ol><ol type="1" id="15d1384e-bcac-80c2-aeaa-f6d73682eb8c" class="numbered-list" start="3"><li>Score each step with the PRM and select the top <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mi mathvariant="normal">/</mi><mi>M</mi></mrow><annotation encoding="application/x-tex">N/M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mord">/</span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> steps as candidates for the next round of generation. Here <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> denotes the “beam width” of a given active path. As in Best-of-N, we used the “last” reduction to score the partial solutions at each iteration.</li></ol><ol type="1" id="15d1384e-bcac-8022-966b-e1dae6845cc1" class="numbered-list" start="4"><li>Expand the steps selected in step (3) by sampling <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> next steps in the solution.</li></ol><ol type="1" id="15d1384e-bcac-8023-b6b6-f470e22ac78a" class="numbered-list" start="5"><li>Repeat steps (3) and (4) until the EOS token is reached or the maximum search depth is exceeded.</li></ol><p id="15a1384e-bcac-8003-a9d9-da7f3a4dc321" class="">By allowing the PRM to evaluate the correctness of intermediate steps, beam search can identify and prioritize promising paths early in the process. This step-by-step evaluation is particularly beneficial for complex reasoning tasks like mathematics, where verifying partial solutions can significantly improve final outcomes.</p><details><summary style="font-weight:600;font-size:1.25em;line-height:1.3;margin:0">Implementation detail</summary><div class="indented"><p id="15b1384e-bcac-8065-a739-d24b699106be" class="">When we implemented beam search with process supervision, we encountered two major footguns with the Llama 3 chat template that are worth mentioning:</p><ul id="15d1384e-bcac-803c-84b3-d881bc2ca3b5" class="bulleted-list"><li style="list-style-type:disc">By default, the chat template trims trailing new lines from every assistant turn. As a result, if one uses <code>\n</code> or <code>\n\n</code> to terminate a step, these tokens are lost on subsequent steps and force the model to produce peculiar outputs.</li></ul><ul id="15d1384e-bcac-808f-97f1-fb7d27565e36" class="bulleted-list"><li style="list-style-type:disc">The chat template is prefixed with Llama’s BOS token. When the formatted string is fed to vLLM a <em>second</em> BOS token is added which completely ruins performance, even though the generations look mostly coherent 🤯</li></ul><p id="15d1384e-bcac-8041-9164-ecc3d9497886" class="">The solution is to overwrite the Llama 3 chat template to prevent trimming and exclude the BOS token prefix. </p><p id="15a1384e-bcac-8090-b5fc-eb36a6588e60" class="">
|
149 |
+
</p></div></details><p id="15d1384e-bcac-80e9-8e65-e1b58080b94c" class="">In our experiments, we followed DeepMind’s hyperparameter choices and ran beam search with the following:</p><ul id="15d1384e-bcac-8098-8574-e16392fc6123" class="bulleted-list"><li style="list-style-type:disc"><style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> beams in compute scalings of 4, 16, 64, 256</li></ul><ul id="15d1384e-bcac-8067-b37c-e9692e34678c" class="bulleted-list"><li style="list-style-type:disc">Fixed beam width <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi><mo>=</mo><mn>4</mn></mrow><annotation encoding="application/x-tex">M=4</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">4</span></span></span></span></span><span></span></span></li></ul><ul id="15d1384e-bcac-8093-a928-c16e31e29e3f" class="bulleted-list"><li style="list-style-type:disc">Sampling with temperature <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>T</mi><mo>=</mo><mn>0.8</mn></mrow><annotation encoding="application/x-tex">T=0.8</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">0.8</span></span></span></span></span><span></span></span></li></ul><ul id="15d1384e-bcac-802a-8416-e332ca20237f" class="bulleted-list"><li style="list-style-type:disc">Up to 40 iterations, i.e. a tree of maximum depth with 40 steps.</li></ul><p id="15d1384e-bcac-8051-abe5-dc84c42a1b5f" class="">As shown below, the results are striking: with a test-time budget of <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>4</mn></mrow><annotation encoding="application/x-tex">N=4</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">4</span></span></span></span></span><span></span></span>, beam search achieves the same accuracy as Best-of-N for <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>16</mn></mrow><annotation encoding="application/x-tex">N=16</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">16</span></span></span></span></span><span></span></span>, i.e. it is 4x more compute efficient! Moreover, beam search matches the performance of Llama 3.1 8B with just <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>32</mn></mrow><annotation encoding="application/x-tex">N=32</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">32</span></span></span></span></span><span></span></span> solutions per problem. The average performance on MATH by computer science PhD students is around 40%, so reaching nearly 55% isn’t too bad for a 1B model 💪!</p><figure id="15b1384e-bcac-80e9-97fa-fe50d1811f5b" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon-beam.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-maj-bon-beam.png"/></a></figure><h3 id="15a1384e-bcac-800c-baee-fb99b242ef87" class="">Which problems does beam search solve best?</h3><p id="15d1384e-bcac-80e3-938a-c3f09db2e9ff" class="">Although in aggregate it is clear that beam search is a better search strategy than Best-of-N or majority voting, the DeepMind paper showed that <em><strong>each strategy has tradeoffs that depend on the problem difficulty</strong></em> and test-time compute budget. </p><p id="15d1384e-bcac-8015-a8f0-c2323b9e535f" class="">To see which problems are best suited for which strategy, DeepMind computed a distribution over estimated problem difficulty, and then binned the results into quintiles. In other words, each problem is assigned one of 5 levels, where level 1 indicates easier problems and level 5 indicates the hardest ones. To estimate problem difficulty, DeepMind generated 2048 candidate solutions with standard sampling per problem and then proposed the following heuristics:</p><ul id="15d1384e-bcac-8080-9152-caeaa288073c" class="bulleted-list"><li style="list-style-type:disc"><strong>Oracle: </strong>use the ground truth labels to estimate the <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mi>a</mi><mi>s</mi><mi>s</mi><mi mathvariant="normal">@</mi><mn>1</mn></mrow><annotation encoding="application/x-tex">pass@1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">p</span><span class="mord mathnormal">a</span><span class="mord mathnormal">ss</span><span class="mord">@1</span></span></span></span></span><span></span></span> score per problem. Bin the distribution of <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mi>a</mi><mi>s</mi><mi>s</mi><mi mathvariant="normal">@</mi><mn>1</mn></mrow><annotation encoding="application/x-tex">pass@1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">p</span><span class="mord mathnormal">a</span><span class="mord mathnormal">ss</span><span class="mord">@1</span></span></span></span></span><span></span></span> scores to determine the quintiles.</li></ul><ul id="15d1384e-bcac-80f9-8778-d4045c6faa7d" class="bulleted-list"><li style="list-style-type:disc"><strong>Model: </strong>use the distribution of average PRM scores per problem to determine the quintiles. The intuition here is that harder problems will have lower scores.</li></ul><p id="15d1384e-bcac-80a3-af7c-f3497126ab1e" class="">Here’s the breakdown of the various methods according to the <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mi>a</mi><mi>s</mi><mi>s</mi><mi mathvariant="normal">@</mi><mn>1</mn></mrow><annotation encoding="application/x-tex">pass@1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">p</span><span class="mord mathnormal">a</span><span class="mord mathnormal">ss</span><span class="mord">@1</span></span></span></span></span><span></span></span> scores and across four test-time compute budgets of <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mo stretchy="false">[</mo><mn>4</mn><mo separator="true">,</mo><mn>16</mn><mo separator="true">,</mo><mn>64</mn><mo separator="true">,</mo><mn>256</mn><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">N = [4,16,64, 256]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord">4</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">16</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">64</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">256</span><span class="mclose">]</span></span></span></span></span><span></span></span>:</p><figure id="15b1384e-bcac-80ad-9cf3-cf5bcbd3f53b" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-maj-bon-beam.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-maj-bon-beam.png"/></a></figure><p id="15d1384e-bcac-80c3-93b3-fa4c071ac807" class="">In this plot, each bar denotes a test-time compute budget, and within each bar we show the relative accuracy of each method. For example, in the group of four bars on difficulty level 2 we see that:</p><ul id="15d1384e-bcac-8091-b3fb-cad0ab99b2c1" class="bulleted-list"><li style="list-style-type:disc">Majority voting is the worst performer for all compute budgets, except for <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>256</mn></mrow><annotation encoding="application/x-tex">N=256</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">256</span></span></span></span></span><span></span></span>, where beam search is worst.</li></ul><ul id="15d1384e-bcac-8076-b88c-c7f55fa0cdbc" class="bulleted-list"><li style="list-style-type:disc">Beam search is best for <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mo stretchy="false">[</mo><mn>4</mn><mo separator="true">,</mo><mn>16</mn><mo separator="true">,</mo><mn>64</mn><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">N=[4,16,64]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">[</span><span class="mord">4</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">16</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">64</span><span class="mclose">]</span></span></span></span></span><span></span></span>, but Best-of-N is best for <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>=</mo><mn>256</mn></mrow><annotation encoding="application/x-tex">N=256</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6444em;"></span><span class="mord">256</span></span></span></span></span><span></span></span>.</li></ul><p id="15a1384e-bcac-80d4-af98-eaebf5fcf84e" class="">Although we see that beam search gives consistent gains in the medium and hard problems (levels 3-5), it tends to do worse than Best-of-N (and even majority voting!) on the simpler problems and especially at large compute budgets. </p><p id="15a1384e-bcac-805b-9949-f0cdc44c9e3c" class="">We realized from looking at the resulting trees produced by beam search, that if a single step is assigned high reward, then the whole tree collapses to that trace and thus diversity is impacted. This prompted us to explore an extension to beam search that maximises diversity - let’s take a look!</p>
|
150 |
|
151 |
<h2 id="1591384e-bcac-80d2-8234-fe0e9a4df59d" class="">DVTS: boosting performance with diversity</h2><p id="1591384e-bcac-8044-b7c5-cf39e4aed683" class="">As we saw above beam search gives strong performance over Best-of-N, but tends to underperform on simpler problems and at large test-time compute budgets. To address this, we developed an extension we call Diverse Verifier Tree Search (DVTS) that is designed to maximise diversity at large <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span>.</p><p id="15a1384e-bcac-80ff-a97b-c7ccd88958e4" class="">DVTS works in a similar fashion as beam search, with the following modifications:</p><ol type="1" id="15d1384e-bcac-806c-8004-e054a98d98ef" class="numbered-list" start="1"><li>For a given <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> and <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span>, expand the initial set of beams into <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mi mathvariant="normal">/</mi><mi>M</mi></mrow><annotation encoding="application/x-tex">N/M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mord">/</span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> <em>independent</em> subtrees.</li></ol><ol type="1" id="15d1384e-bcac-8081-8508-feb06a13469b" class="numbered-list" start="2"><li>For each subtree, select the step with the highest PRM score.</li></ol><ol type="1" id="15d1384e-bcac-806a-976f-ec9596cd9532" class="numbered-list" start="3"><li>Generate <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>M</mi></mrow><annotation encoding="application/x-tex">M</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span></span></span></span></span><span></span></span> new steps from the nodes selected in step (2) and select the step with the highest PRM score.</li></ol><ol type="1" id="15d1384e-bcac-808e-aa2b-f391ec426953" class="numbered-list" start="4"><li>Repeat step (3) until the EOS token or maximum tree depth is reached.</li></ol><p id="15d1384e-bcac-8087-b916-d9603de035dd" class="">Here’s the results from applying DVTS to Llama 1B:</p><figure id="15b1384e-bcac-801c-a1e7-d4e544826da3" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-all.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-all.png"/></a></figure><p id="15b1384e-bcac-80e1-bc9b-dbdb5738b9f1" class="">As we can see, DVTS provides a complementary strategy to beam search: at small <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> beam search is more effective at finding correct solutions, but at large <style>@import url('https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.9/katex.min.css')</style><span data-token-index="0" contenteditable="false" class="notion-text-equation-token" style="user-select:all;-webkit-user-select:all;-moz-user-select:all"><span></span><span><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi></mrow><annotation encoding="application/x-tex">N</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em;"></span><span class="mord mathnormal" style="margin-right:0.10903em;">N</span></span></span></span></span><span></span></span> the diversity of DVTS candidates kicks in and we get better performance. </p><p id="15d1384e-bcac-80a7-8379-dca3c329c433" class="">We can also see this manifested in the problem difficulty breakdown, where DVTS enhances performance on the easy / medium problems at large N, while beam search is best at small N across model problem difficulties:</p><figure id="15b1384e-bcac-807a-8dca-f322077cc616" class="image"><a href="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-all.png"><img style="width:707.9891357421875px" src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/levels-all.png"/></a></figure>
|
152 |
|