Joshua Rosenkranz committed
Commit 9fca77e
Parents (2): 7a02f91, fbbff12

Merge branch 'main' of https://huggingface.co/ibm-fms/llama-13b-accelerator

Files changed (1): README.md (+23, -17)
README.md CHANGED
@@ -5,17 +5,20 @@ license: llama2
## Description

This model is intended to be used as an accelerator for [llama 13B (chat)](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) and takes inspiration
- from the Medusa architecture and modifies the MLP into a multi-stage MLP, where each stage predicts
- a single token in the draft. Each stage takes as input both a state vector and sampled token embedding
- from the prior stage (the base model can be considered stage 0). The inputs are projected and passed
- through a LayerNorm/GeLU activation, forming a new state vector. This state vector is used to predict
- the next draft token, which, with the new state vector, acts as input for the next stage of prediction.
- We sample multiple tokens at each stage, and emit a tree of candidate suffixes to evaluate in parallel.
+ from the Medusa speculative decoding architecture. This accelerator modifies the MLP into a multi-stage MLP, where each stage predicts
+ a single token in the draft based on both a state vector and sampled token
+ from the prior stage (the base model can be considered stage 0).
+ The state vector from the base model provides contextual information to the accelerator,
+ while conditioning on prior sampled tokens allows it to produce higher-quality draft n-grams.

- ## Code
+ Note: The underlying MLP speculator is a generic architecture that can be trained with any generative model to accelerate inference.
+ Training is light-weight and can be completed in only a few days depending on base model size and speed.

- - Paged Attention KV-Cache / Speculator Implementations: https://github.com/foundation-model-stack/fms-extras
- - Production Server with speculative decoding implementation: https://github.com/tdoublep/text-generation-inference/tree/speculative-decoding
+ ## Repository Links
+
+ 1. [Paged Attention KV-Cache / Speculator](https://github.com/foundation-model-stack/fms-extras)
+ 2. [Production Server with speculative decoding](https://github.com/IBM/text-generation-inference/pull/78)
+ 3. [Speculator training](https://github.com/foundation-model-stack/fms-fsdp/pull/35)

## Samples
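To make the multi-stage structure described above concrete, here is a minimal PyTorch sketch of an MLP speculator. It is illustrative only, with hypothetical class names, and is not the fms-extras implementation linked under Repository Links: each stage combines the prior state vector with the embedding of the previously sampled token, applies a LayerNorm/GeLU projection to form a new state vector, and uses that state to predict the next draft token (the base model acts as stage 0).

```python
# Illustrative sketch only (hypothetical names); the actual speculator lives in fms-extras.
import torch
import torch.nn as nn


class SpeculatorStage(nn.Module):
    """One draft stage: prior state vector + embedding of the previously sampled token
    -> LayerNorm/GeLU projection -> new state vector -> logits for the next draft token."""

    def __init__(self, hidden_dim: int, vocab_size: int):
        super().__init__()
        self.state_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.token_emb = nn.Embedding(vocab_size, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)
        self.act = nn.GELU()
        self.head = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, prev_state: torch.Tensor, prev_token: torch.Tensor):
        state = self.act(self.norm(self.state_proj(prev_state) + self.token_emb(prev_token)))
        return state, self.head(state)


class MultiStageSpeculator(nn.Module):
    """Chain of stages; stage 0 is the base model, whose last hidden state and
    sampled token seed the draft."""

    def __init__(self, hidden_dim: int, vocab_size: int, n_stages: int = 3):
        super().__init__()
        self.stages = nn.ModuleList(
            [SpeculatorStage(hidden_dim, vocab_size) for _ in range(n_stages)]
        )

    @torch.no_grad()
    def draft(self, base_state: torch.Tensor, last_token: torch.Tensor) -> torch.Tensor:
        state, token, out = base_state, last_token, []
        for stage in self.stages:
            state, logits = stage(state, token)
            token = logits.argmax(dim=-1)  # greedy here; sampling several tokens per stage yields a tree of candidate suffixes
            out.append(token)
        return torch.stack(out, dim=-1)  # [batch, n_stages] draft n-gram


# Shapes for a llama-13b-sized setup (hidden size 5120, vocab 32000), purely for illustration:
spec = MultiStageSpeculator(hidden_dim=5120, vocab_size=32000)
state = torch.randn(2, 5120)
last = torch.randint(0, 32000, (2,))
print(spec.draft(state, last).shape)  # torch.Size([2, 3])
```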
 
@@ -28,23 +31,24 @@ _Note: For all samples, your environment must have access to cuda_
#### Setup

```bash
- docker pull docker-eu-public.artifactory.swg-devops.com/res-zrl-snap-docker-local/tgis-os:spec.7
+ docker pull quay.io/wxpe/text-gen-server:speculative-decoding.ecd73c4
docker run -d --rm --gpus all \
  --name my-tgis-server \
  -p 8033:8033 \
  -v /path/to/all/models:/models \
  -e MODEL_NAME=/models/model_weights/llama/13B-F \
- -e SPECULATOR_PATH=/models/speculator_weights/llama/13B-F \
+ -e SPECULATOR_NAME=/models/speculator_weights/llama/llama-13b-accelerator \
  -e FLASH_ATTENTION=true \
  -e PAGED_ATTENTION=true \
  -e DTYPE_STR=float16 \
- docker-eu-public.artifactory.swg-devops.com/res-zrl-snap-docker-local/tgis-os:spec.7
+ quay.io/wxpe/text-gen-server:speculative-decoding.ecd73c4

# check logs and wait for "gRPC server started on port 8033" and "HTTP server started on port 3000"
docker logs my-tgis-server -f

# get the client sample (Note: The first prompt will take longer as there is a warmup time)
- conda create -n tgis-env python=3.11
- conda activate tgis-env
+ conda create -n tgis-client-env python=3.11
+ conda activate tgis-client-env
git clone --branch speculative-decoding --single-branch https://github.com/tdoublep/text-generation-inference.git
cd text-generation-inference/integration_tests
make gen-client
@@ -57,6 +61,8 @@ pip install . --no-cache-dir
python sample_client.py
```

+ _Note: first prompt may be slower as there is a slight warmup time_
+
### Minimal Sample

*To try this out with the fms-native compiled model, please execute the following:*
@@ -65,9 +71,7 @@ python sample_client.py

```bash
git clone https://github.com/foundation-model-stack/fms-extras
- git clone https://github.com/foundation-model-stack/foundation-model-stack
(cd fms-extras && pip install -e .)
- (cd foundation-model-stack && pip install -e .)
pip install transformers==4.35.0 sentencepiece numpy
```
@@ -112,4 +116,6 @@ python fms-extras/scripts/paged_speculative_inference.py \
--speculator_source=hf \
--batch_input \
--compile \
- ```
+ ```
+
+ Sample code can be found [here](https://github.com/foundation-model-stack/fms-extras/blob/main/scripts/paged_speculative_inference.py)
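For context on what the production server and the sample script above do with these drafts, here is a schematic of the draft-then-verify loop used in speculative decoding. It is a simplified sketch with hypothetical function names (`draft_fn`, `verify_fn`), not the fms-extras or TGIS API: the speculator proposes a short draft, the base model scores the whole candidate suffix in one forward pass, and the prefix the base model agrees with (plus one corrected or bonus token) is kept, so the accepted tokens match what the base model would have produced on its own.

```python
# Schematic only; hypothetical helpers, not the fms-extras / TGIS implementation.
from typing import Callable, List, Sequence


def speculative_step(
    prefix: List[int],
    draft_fn: Callable[[Sequence[int]], List[int]],
    verify_fn: Callable[[Sequence[int]], List[int]],
) -> List[int]:
    """One draft-then-verify step.

    draft_fn(prefix) -> short list of speculator-proposed tokens (the draft n-gram).
    verify_fn(seq)   -> base model's greedy next-token prediction at every position
                        of seq, obtained in a single forward pass.
    Returns the tokens accepted this step (always at least one).
    """
    draft = draft_fn(prefix)
    preds = verify_fn(list(prefix) + draft)  # one base-model pass over prefix + draft

    accepted: List[int] = []
    for i, proposed in enumerate(draft):
        target = preds[len(prefix) - 1 + i]  # base model's choice after prefix + draft[:i]
        if proposed == target:
            accepted.append(proposed)        # draft token verified: kept "for free"
        else:
            accepted.append(target)          # mismatch: fall back to the base model's token
            return accepted
    accepted.append(preds[len(prefix) - 1 + len(draft)])  # all matched: bonus token
    return accepted


# Toy check: a "base model" that counts upward and a perfect two-token speculator.
verify = lambda seq: [t + 1 for t in seq]
draft = lambda seq: [seq[-1] + 1, seq[-1] + 2]
print(speculative_step([1, 2, 3], draft, verify))  # [4, 5, 6]: 2 drafts accepted + 1 bonus
```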
 
 
 