Accelerating LLM Inference: Fast Sampling with Gumbel-Max Trick
Introduction
Large Language Model (LLM) inference speed is heavily impacted by the token sampling process. At each generation step, we need to sample the next token from a probability distribution over the entire vocabulary (typically 32K to 100K tokens). The standard approach using torch.multinomial has become a notable bottleneck in the inference pipeline.
The Problem with Traditional LLM Sampling
The traditional sampling process in LLM inference looks like this:
- Get logits from the model
- Apply softmax to convert logits to probabilities
- Use torch.multinomial to sample from the probability distribution (sketched in code below)
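As a reference point, here is a minimal sketch of that traditional pipeline; the randomly generated logits stand in for a model's output:

import torch

batch_size, vocab_size = 32, 32000
logits = torch.randn(batch_size, vocab_size)  # stand-in for the model's output

# Traditional sampling: softmax over the full vocabulary, then multinomial
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # shape: (batch_size, 1)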
This approach has two main bottlenecks:
- Computing softmax over large vocabulary sizes is expensive
- The multinomial sampling operation itself is relatively slow
The Key Insight: Gumbel-Max Sampling
The core innovation in our approach comes from two key observations about the Gumbel-Max trick:
Sampling with the Gumbel-Max trick is mathematically equivalent to categorical sampling from the softmax distribution over the logits:
# Instead of:
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)

# We can do:
gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits)))
next_token = torch.argmax(logits + gumbel_noise, dim=-1)
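This equivalence is easy to check empirically. The snippet below is a quick sanity check (not part of the original post): it draws many samples both ways over a tiny vocabulary and compares the empirical token frequencies against softmax(logits).

import torch

torch.manual_seed(0)
vocab_size, num_draws = 8, 200_000
logits = torch.randn(vocab_size)
probs = torch.softmax(logits, dim=-1)

# Categorical sampling via softmax + multinomial
multinomial_counts = torch.bincount(
    torch.multinomial(probs, num_draws, replacement=True), minlength=vocab_size
)

# Gumbel-Max sampling: argmax over logits + fresh Gumbel noise per draw
gumbel_noise = -torch.log(-torch.log(torch.rand(num_draws, vocab_size)))
gumbel_counts = torch.bincount(
    torch.argmax(logits + gumbel_noise, dim=-1), minlength=vocab_size
)

# All three rows should be close to each other
print(probs)
print(multinomial_counts.float() / num_draws)
print(gumbel_counts.float() / num_draws)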
The Critical Optimization: Gumbel noise can be pre-computed:
- The noise tensor is independent of the logits
- We can prepare it before receiving model outputs (see the sketch after this list)
- This removes it from the critical path of token generation
- We avoid computing softmax entirely
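To make the pre-computation concrete, here is one possible sketch of a single decode step (the generate_step helper, the vocab_size argument, and the assumption that model returns logits of shape (batch, seq_len, vocab_size) are illustrative, not code from the repository): the noise is drawn before the forward pass, so once the logits arrive only an add and an argmax remain.

import torch

@torch.no_grad()
def generate_step(model, input_ids, vocab_size):
    # The Gumbel noise depends only on the shape, not on the logits,
    # so it can be drawn before the model produces any output
    batch_size = input_ids.shape[0]
    noise = -torch.log(-torch.log(
        torch.rand(batch_size, vocab_size, device=input_ids.device)
    ))

    # Forward pass (assumed to return logits of shape (batch, seq_len, vocab_size))
    logits = model(input_ids)[:, -1, :]

    # No softmax, no multinomial: just an add and an argmax
    return torch.argmax(logits + noise, dim=-1)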
Performance Results on A100
Our benchmarks on an A100 80GB GPU show significant speedups across different scales; a generic timing harness is sketched after the numbers below. Complete benchmark code and implementation can be found at: https://github.com/NonvolatileMemory/fast_llm_sampling/tree/main
Small Scale (batch_size=32, vocab_size=32000)
- Traditional: 0.600 ms ± 0.058 ms
- Gumbel-Max: 0.214 ms ± 0.004 ms
- 2.8x speedup
Medium Scale (batch_size=128, vocab_size=50000)
- Traditional: 4.549 ms ± 2.609 ms
- Gumbel-Max: 1.294 ms ± 0.009 ms
- 3.5x speedup
Large Scale (batch_size=512, vocab_size=100000)
- Traditional: 64.386 ms ± 2.748 ms
- Gumbel-Max: 30.544 ms ± 1.725 ms
- 2.1x speedup
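For reference, numbers like these are typically collected with CUDA events and warmup iterations. The harness below is a generic sketch that assumes a CUDA device is available; it is not the benchmark script from the linked repository.

import torch

def benchmark(fn, logits, warmup=10, iters=100):
    # Warm-up runs to exclude one-time setup costs
    for _ in range(warmup):
        fn(logits)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times_ms = []
    for _ in range(iters):
        start.record()
        fn(logits)
        end.record()
        torch.cuda.synchronize()
        times_ms.append(start.elapsed_time(end))  # milliseconds

    times_ms = torch.tensor(times_ms)
    return times_ms.mean().item(), times_ms.std().item()

logits = torch.randn(32, 32000, device="cuda")
print(benchmark(lambda x: torch.multinomial(torch.softmax(x, dim=-1), num_samples=1), logits))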
Implementation Details
The key to an efficient implementation is proper noise pre-computation:
import torch


class GumbelSampler:
    def __init__(self, batch_size, vocab_size, device):
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.device = device
        # Pre-compute the Gumbel noise for the first sampling step
        self.noise = self._prepare_gumbel_noise(device)

    def _prepare_gumbel_noise(self, device):
        # Standard Gumbel(0, 1) noise: g = -log(-log(u)), u ~ Uniform(0, 1)
        uniform_noise = torch.rand(self.batch_size, self.vocab_size, device=device)
        return -torch.log(-torch.log(uniform_noise))

    def sample(self, logits):
        # Direct sampling without softmax: argmax over logits + Gumbel noise
        next_token = torch.argmax(logits + self.noise, dim=-1)
        # Refresh the noise after the argmax has been issued, so the next step
        # uses fresh, independent noise without adding work between the arrival
        # of the logits and the sampled token
        self.noise = self._prepare_gumbel_noise(self.device)
        return next_token
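A minimal usage sketch; random logits stand in for real model outputs:

device = "cuda" if torch.cuda.is_available() else "cpu"
sampler = GumbelSampler(batch_size=32, vocab_size=32000, device=device)

for _ in range(4):  # a few decoding steps
    logits = torch.randn(32, 32000, device=device)  # stand-in for model logits
    next_token = sampler.sample(logits)  # shape: (32,)
    print(next_token[:8])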