license: llama2

Description

This model is intended to be used as an accelerator (a speculative-decoding draft model, or "speculator") for Llama 13B (chat).

It takes inspiration from the Medusa architecture and modifies the MLP into a multi-stage MLP, where each stage predicts a single token in the draft. Each stage takes two inputs: a state vector and the embedding of the token sampled by the prior stage (the base model can be considered stage 0). The inputs are projected and passed through a LayerNorm/GeLU activation, forming a new state vector. This state vector is used to predict the next draft token, which, together with the new state vector, acts as input for the next stage of prediction. We sample multiple tokens at each stage and emit a tree of candidate suffixes to evaluate in parallel.
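The stage-by-stage data flow described above can be illustrated with a toy, pure-Python sketch. Everything here is an illustrative assumption rather than the fms-extras implementation: the tiny dimensions, the random weights, the names `Stage` and `draft_tree`, and the use of deterministic top-k per stage in place of sampling.

```python
import math
import random

# Toy sizes: illustrative assumptions, far smaller than the real speculator.
EMB_DIM = 8    # width of the state vector and token embeddings
VOCAB = 16     # vocabulary size
N_STAGES = 3   # one stage per draft token

rng = random.Random(0)

def rand_matrix(rows, cols):
    return [[rng.gauss(0.0, 0.3) for _ in range(cols)] for _ in range(rows)]

def matvec(m, v):
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def layer_norm(v, eps=1e-5):
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]

def gelu(x):
    # exact GeLU via the Gaussian CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

class Stage:
    """One speculator stage: (state, prior token) -> (new state, next-token logits)."""

    def __init__(self):
        self.emb = rand_matrix(VOCAB, EMB_DIM)        # embeds the prior stage's token
        self.w_state = rand_matrix(EMB_DIM, EMB_DIM)  # projects the incoming state
        self.w_emb = rand_matrix(EMB_DIM, EMB_DIM)    # projects the token embedding
        self.head = rand_matrix(VOCAB, EMB_DIM)       # predicts this stage's draft token

    def __call__(self, state, token):
        projected = [a + b for a, b in zip(matvec(self.w_state, state),
                                           matvec(self.w_emb, self.emb[token]))]
        new_state = [gelu(x) for x in layer_norm(projected)]  # LayerNorm, then GeLU
        return new_state, matvec(self.head, new_state)

def top_k(logits, k):
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]

def draft_tree(stages, base_state, base_token, ks):
    """Expand every branch; stage i keeps ks[i] tokens per parent (top-k, not sampling)."""
    frontier = [([], base_state, base_token)]
    for stage, k in zip(stages, ks):
        next_frontier = []
        for suffix, state, token in frontier:
            new_state, logits = stage(state, token)
            for tok in top_k(logits, k):
                # the chosen token and the new state feed the next stage
                next_frontier.append((suffix + [tok], new_state, tok))
        frontier = next_frontier
    return [suffix for suffix, _, _ in frontier]

stages = [Stage() for _ in range(N_STAGES)]
# Stage 0 is the base model: stand-ins for its last hidden state and sampled token.
base_state = [rng.gauss(0.0, 1.0) for _ in range(EMB_DIM)]
candidates = draft_tree(stages, base_state, base_token=3, ks=[4, 3, 2])
print(len(candidates))  # prints 24 (4 * 3 * 2 candidate suffixes)
```

With `ks=[4, 3, 2]` the tree holds 24 distinct suffixes of length 3, which the base model would then score in one parallel pass. The per-stage branching factors control how large that tree, and therefore the verification batch, becomes.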

The underlying implementation of the paged-attention KV cache and the speculator can be found in https://github.com/foundation-model-stack/fms-extras. A production implementation built on fms-extras can be found in https://github.com/tdoublep/text-generation-inference/tree/speculative-decoding.

Samples

Note: For all samples, your environment must have access to CUDA.
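For the local (non-Docker) sample, a quick pre-flight check can confirm that a GPU is visible. This is a minimal sketch assuming PyTorch is the runtime (fms-extras builds on it); the helper name `cuda_available` is hypothetical, and it simply returns `False` when PyTorch is not installed.

```python
def cuda_available():
    """True if PyTorch is installed and can see at least one CUDA device."""
    try:
        import torch  # assumption: PyTorch is the runtime used by fms-extras
    except ImportError:
        return False
    return bool(torch.cuda.is_available() and torch.cuda.device_count() > 0)

print("CUDA visible:", cuda_available())
```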

Production Server Sample

To try this out in a production-like environment, use the pre-built Docker image:

Setup

docker pull docker-eu-public.artifactory.swg-devops.com/res-zrl-snap-docker-local/tgis-os:spec.7
docker run -d --rm --gpus all \
    --name my-tgis-server \
    -p 8033:8033 \
    -v /path/to/all/models:/models \
    -e MODEL_NAME=/models/model_weights/llama/13B-F \
    -e SPECULATOR_PATH=/models/speculator_weights/llama/13B-F \
    -e FLASH_ATTENTION=true \
    -e PAGED_ATTENTION=true \
    -e DTYPE_STR=float16 \
    docker-eu-public.artifactory.swg-devops.com/res-zrl-snap-docker-local/tgis-os:spec.7

# check logs and wait for "gRPC server started on port 8033" and "HTTP server started on port 3000"
docker logs my-tgis-server -f

# get the client sample (note: the first prompt takes longer because of server warmup)
conda create -n tgis-env python=3.11
conda activate tgis-env
git clone --branch speculative-decoding --single-branch https://github.com/tdoublep/text-generation-inference.git
cd text-generation-inference/integration_tests
make gen-client
pip install . --no-cache-dir
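Instead of watching `docker logs -f` by hand, the readiness check from the setup steps can be scripted. A minimal sketch, assuming the container name `my-tgis-server` and the two log messages quoted above; the helper names `docker_logs` and `wait_until_ready` are hypothetical. The log source is injectable so the polling logic can be exercised without Docker.

```python
import subprocess
import time

# Messages the server prints once it is ready (quoted from the setup comments).
READY_MESSAGES = (
    "gRPC server started on port 8033",
    "HTTP server started on port 3000",
)

def docker_logs(container="my-tgis-server"):
    """Return the container's combined stdout/stderr log text via `docker logs`."""
    result = subprocess.run(
        ["docker", "logs", container],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # server logs may go to either stream
        text=True,
        check=False,
    )
    return result.stdout

def wait_until_ready(fetch_logs=docker_logs, messages=READY_MESSAGES,
                     timeout_s=900.0, poll_s=5.0):
    """Poll the logs until every readiness message appears; False on timeout."""
    deadline = time.monotonic() + timeout_s
    while True:
        logs = fetch_logs()
        if all(msg in logs for msg in messages):
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)
```

Calling `wait_until_ready()` blocks until both messages appear in the container logs, after which `sample_client.py` can be run.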

Run Sample

python sample_client.py

Minimal Sample

To try this out with the fms-native compiled model, please execute the following:

Install

git clone https://github.com/foundation-model-stack/fms-extras
(cd fms-extras && pip install -e .)
pip install transformers==4.35.0 sentencepiece numpy
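Because the install step pins `transformers==4.35.0`, it can help to confirm the active environment matches before running the sample. A minimal sketch using only the standard library; the helper name `check_install` is hypothetical, and it checks only the packages named in the install command.

```python
from importlib.metadata import PackageNotFoundError, version

PINNED = {"transformers": "4.35.0"}       # version pinned in the install step
UNPINNED = ("sentencepiece", "numpy")     # installed without a version pin

def check_install(pinned=PINNED, unpinned=UNPINNED):
    """Return a list of human-readable problems; an empty list means all good."""
    problems = []
    for name, want in pinned.items():
        try:
            have = version(name)
        except PackageNotFoundError:
            problems.append(f"{name} is not installed (sample pins {want})")
            continue
        if have != want:
            problems.append(f"{name}=={have} is installed, but the sample pins {want}")
    for name in unpinned:
        try:
            version(name)
        except PackageNotFoundError:
            problems.append(f"{name} is not installed")
    return problems

for problem in check_install():
    print(problem)
```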

Run Sample

batch_size=1 (compile + cudagraphs)
python fms-extras/scripts/paged_speculative_inference.py \
    --variant=13b \
    --model_path=/path/to/model_weights/llama/13B-F \
    --model_source=hf \
    --tokenizer=/path/to/llama/13B-F \
    --speculator_path=/path/to/speculator_weights/llama/13B-F \
    --speculator_source=hf \
    --compile \
    --compile_mode=reduce-overhead
batch_size=1 (compile)
python fms-extras/scripts/paged_speculative_inference.py \
    --variant=13b \
    --model_path=/path/to/model_weights/llama/13B-F \
    --model_source=hf \
    --tokenizer=/path/to/llama/13B-F \
    --speculator_path=/path/to/speculator_weights/llama/13B-F \
    --speculator_source=hf \
    --compile
batch_size=4 (compile)
python fms-extras/scripts/paged_speculative_inference.py \
    --variant=13b \
    --model_path=/path/to/model_weights/llama/13B-F \
    --model_source=hf \
    --tokenizer=/path/to/llama/13B-F \
    --speculator_path=/path/to/speculator_weights/llama/13B-F \
    --speculator_source=hf \
    --batch_input \
    --compile
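The three commands above differ only in `--batch_input` and `--compile_mode`, so they can be generated from one helper when sweeping configurations. A minimal sketch: the helper name `build_command` is hypothetical, it emits only flags that appear in the samples, and the paths are the same placeholders used above.

```python
import shlex

SCRIPT = "fms-extras/scripts/paged_speculative_inference.py"

def build_command(model_path, tokenizer, speculator_path,
                  batch_input=False, compile_mode=None):
    """Assemble the paged_speculative_inference.py argv used in the samples."""
    cmd = [
        "python", SCRIPT,
        "--variant=13b",
        f"--model_path={model_path}",
        "--model_source=hf",
        f"--tokenizer={tokenizer}",
        f"--speculator_path={speculator_path}",
        "--speculator_source=hf",
    ]
    if batch_input:
        cmd.append("--batch_input")  # used by the batch_size=4 sample
    cmd.append("--compile")          # all three samples compile the model
    if compile_mode is not None:
        cmd.append(f"--compile_mode={compile_mode}")
    return cmd

MODEL = "/path/to/model_weights/llama/13B-F"
TOKENIZER = "/path/to/llama/13B-F"
SPECULATOR = "/path/to/speculator_weights/llama/13B-F"

variants = {
    "batch_size=1 (compile + cudagraphs)": build_command(
        MODEL, TOKENIZER, SPECULATOR, compile_mode="reduce-overhead"),
    "batch_size=1 (compile)": build_command(MODEL, TOKENIZER, SPECULATOR),
    "batch_size=4 (compile)": build_command(
        MODEL, TOKENIZER, SPECULATOR, batch_input=True),
}
for label, cmd in variants.items():
    print(f"# {label}")
    print(shlex.join(cmd))
    # To launch instead of printing: subprocess.run(cmd, check=True)
```

Building the argument list, rather than a single shell string, avoids quoting problems when model paths contain spaces.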