Merge branch 'main' of https://huggingface.co/ibm-fms/llama-13b-accelerator
README.md
## Description

This model is intended to be used as an accelerator for [llama 13B (chat)](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) and takes inspiration from the Medusa speculative decoding architecture. This accelerator modifies the MLP into a multi-stage MLP, where each stage predicts a single token in the draft based on both a state vector and sampled token from the prior stage (the base model can be considered stage 0). The state vector from the base model provides contextual information to the accelerator, while conditioning on prior sampled tokens allows it to produce higher-quality draft n-grams.
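To make the multi-stage structure concrete, below is a minimal, illustrative PyTorch sketch of a speculator of this flavor. It is not the fms-extras implementation; the dimensions, layer choices, and names are assumptions for illustration only.

```python
import torch
import torch.nn as nn

class ToyMLPSpeculator(nn.Module):
    """Illustrative multi-stage MLP speculator (not the fms-extras implementation).

    Stage i predicts draft token i from (a) the state vector carried over from the
    previous stage (stage 0 = the base model) and (b) the embedding of the token
    sampled at the previous stage.
    """

    def __init__(self, vocab_size=32000, hidden_dim=5120, n_stages=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        # One small MLP per stage: combines the prior state with the prior token embedding.
        self.stage_mlps = nn.ModuleList(
            [nn.Sequential(nn.Linear(2 * hidden_dim, hidden_dim), nn.GELU()) for _ in range(n_stages)]
        )
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_dim, vocab_size) for _ in range(n_stages)]
        )

    @torch.no_grad()
    def draft(self, base_state, last_token):
        """base_state: [batch, hidden_dim] state vector from the base model.
        last_token: [batch] token id sampled by the base model.
        Returns a [batch, n_stages] tensor of (greedily) drafted tokens."""
        state, token = base_state, last_token
        draft_tokens = []
        for mlp, head in zip(self.stage_mlps, self.heads):
            state = mlp(torch.cat([state, self.embed(token)], dim=-1))
            token = head(state).argmax(dim=-1)  # next draft token, conditioned on the prior one
            draft_tokens.append(token)
        return torch.stack(draft_tokens, dim=-1)

# Toy usage with random stand-in inputs.
spec = ToyMLPSpeculator(vocab_size=100, hidden_dim=32, n_stages=3)
state = torch.randn(2, 32)           # stand-in for the base model's state vector
last = torch.randint(0, 100, (2,))   # stand-in for the last sampled token
print(spec.draft(state, last).shape)  # torch.Size([2, 3])
```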

Note: The underlying MLP speculator is a generic architecture that can be trained with any generative model to accelerate inference. Training is light-weight and can be completed in only a few days depending on base model size and speed.
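For background on how such draft n-grams are consumed, the sketch below shows generic greedy verification in speculative decoding: the base model scores the whole draft in a single forward pass and only the agreeing prefix is kept. This is illustrative only and is not the fms-extras/TGIS implementation.

```python
import torch

@torch.no_grad()
def verify_draft(base_model, input_ids, draft_tokens):
    """Generic greedy-verification sketch (not the fms-extras/TGIS implementation).

    base_model maps token ids [1, seq] -> logits [1, seq, vocab].
    input_ids: [1, prompt_len] tokens accepted so far.
    draft_tokens: [n_draft] tokens proposed by the speculator.
    Returns the agreeing prefix of the draft plus the base model's own next token,
    all obtained from one forward pass over the candidate sequence.
    """
    candidate = torch.cat([input_ids, draft_tokens.unsqueeze(0)], dim=-1)
    logits = base_model(candidate)
    preds = logits.argmax(dim=-1)[0]                     # greedy choice after every position
    base_choices = preds[input_ids.shape[-1] - 1 : -1]   # base model's choice at each draft slot
    matches = (base_choices == draft_tokens).long()
    n_accept = int(matches.cumprod(dim=0).sum())         # length of the agreeing prefix
    bonus = preds[input_ids.shape[-1] - 1 + n_accept].unsqueeze(0)
    return torch.cat([draft_tokens[:n_accept], bonus])

def toy_base_model(ids, vocab=100):
    # Fake base model whose greedy next-token choice is always (last input token + 1).
    logits = torch.zeros(ids.shape[0], ids.shape[-1], vocab)
    return logits.scatter(-1, ((ids + 1) % vocab).unsqueeze(-1), 1.0)

prompt = torch.tensor([[5, 6, 7]])
draft = torch.tensor([8, 9, 42])  # the first two draft tokens agree with the toy model
print(verify_draft(toy_base_model, prompt, draft))  # tensor([ 8,  9, 10])
```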

## Repository Links

1. [Paged Attention KV-Cache / Speculator](https://github.com/foundation-model-stack/fms-extras)
2. [Production Server with speculative decoding](https://github.com/IBM/text-generation-inference/pull/78)
3. [Speculator training](https://github.com/foundation-model-stack/fms-fsdp/pull/35)

## Samples

_Note: For all samples, your environment must have access to cuda_
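A quick way to confirm the environment can actually see a GPU before running either sample:

```python
import torch

# The samples below assume a CUDA-capable GPU is visible to PyTorch.
assert torch.cuda.is_available(), "These samples require access to cuda"
print(torch.cuda.get_device_name(0))
```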

#### Setup

```bash
docker pull quay.io/wxpe/text-gen-server:speculative-decoding.ecd73c4
docker run -d --rm --gpus all \
    --name my-tgis-server \
    -p 8033:8033 \
    -v /path/to/all/models:/models \
    -e MODEL_NAME=/models/model_weights/llama/13B-F \
    -e SPECULATOR_NAME=/models/speculator_weights/llama/llama-13b-accelerator \
    -e FLASH_ATTENTION=true \
    -e PAGED_ATTENTION=true \
    -e DTYPE_STR=float16 \
    quay.io/wxpe/text-gen-server:speculative-decoding.ecd73c4

# check logs and wait for "gRPC server started on port 8033" and "HTTP server started on port 3000"
docker logs my-tgis-server -f

# get the client sample (Note: The first prompt will take longer as there is a warmup time)
conda create -n tgis-client-env python=3.11
conda activate tgis-client-env
git clone --branch speculative-decoding --single-branch https://github.com/tdoublep/text-generation-inference.git
cd text-generation-inference/integration_tests
make gen-client
pip install . --no-cache-dir

python sample_client.py
```

_Note: first prompt may be slower as there is a slight warmup time_
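The `SPECULATOR_NAME` path above assumes this repository's weights are already present under the mounted `/models` volume. Below is a hedged sketch of fetching them with `huggingface_hub`; the local directory layout is an assumption chosen to match the mount used above.

```python
# Sketch: download the accelerator (speculator) weights so they can be mounted into the
# TGIS container. The local directory is an assumption mirroring the
# -v /path/to/all/models:/models mount and the SPECULATOR_NAME path above.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="ibm-fms/llama-13b-accelerator",
    local_dir="/path/to/all/models/speculator_weights/llama/llama-13b-accelerator",
)
print(f"speculator weights downloaded to: {local_dir}")
```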

### Minimal Sample

*To try this out with the fms-native compiled model, please execute the following:*

```bash
git clone https://github.com/foundation-model-stack/fms-extras
(cd fms-extras && pip install -e .)
pip install transformers==4.35.0 sentencepiece numpy
```

```bash
python fms-extras/scripts/paged_speculative_inference.py \
    ... \
    --speculator_source=hf \
    --batch_input \
    --compile \
```

Sample code can be found [here](https://github.com/foundation-model-stack/fms-extras/blob/main/scripts/paged_speculative_inference.py)
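The `--compile` flag and the warmup notes above are consistent with PyTorch 2 compilation, where the first call pays a one-time compilation cost; the snippet below is a generic illustration of that behavior, not the script's internals.

```python
import torch
import torch.nn as nn

# Generic illustration (assumption: the script's --compile flag enables PyTorch 2
# compilation). torch.compile compiles lazily, so the first call is slower and later
# calls reuse the compiled graph -- matching the warmup notes above.
model = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
compiled = torch.compile(model)

x = torch.randn(4, 16)
y = compiled(x)   # first call triggers compilation (slow)
y = compiled(x)   # subsequent calls are fast
print(y.shape)    # torch.Size([4, 16])
```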