Added generator code
- Meta-Llama-3-70B-Instruct-8bpw/suppress_dir.safetensors +3 -0
- Meta-Llama-3-8B-Instruct/suppress_dir.safetensors +3 -0
- Phi-3-mini-128k-instruct/suppress_dir.safetensors +3 -0
- README.md +86 -0
- exl2_wrapper.py +61 -0
- gen.py +198 -0
- test_inference.py +542 -0
Meta-Llama-3-70B-Instruct-8bpw/suppress_dir.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:faec2cc2c48d1a925a58d08a5396e3255f50d269ccc66b6610defd5ce6074cfe
size 2634640
Meta-Llama-3-8B-Instruct/suppress_dir.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:de22c6df410a1bb839b3ae66a1d3b7aadcc1254d81a3c7fae17b8d509ed1f801
size 529440
Phi-3-mini-128k-instruct/suppress_dir.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b0d2417d1c9684e73f44b5338f024975fcefd8d777a124633e27f6e9cc13e56a
size 398360
README.md
ADDED
@@ -0,0 +1,86 @@
---
license: mit
pipeline_tag: text-generation
---

ZoRA: Zero Rank Adaptation
=
Inspired by [*Refusal in LLMs is mediated by a single direction*](https://www.alignmentforum.org/posts/jGuXSZgv6qfdhMCuJ/refusal-in-llms-is-mediated-by-a-single-direction), ZoRA is a refinement of the original approach that adapts large language models to suppress refusals. Its key features are:
* **Layer-wise ablation**: Measure and ablate a separate set of vectors for each layer
* **Multi-pass refinement**: Re-measure multiple times to refine the vectors
* **Single-token generation**: Measure refusal at the beginning of the response
* **Inference engine injection**: Load a small set of refusal-suppression vectors directly into a high-performance inference engine

This approach keeps the original model weights unchanged; only a small set of suppression vectors is loaded alongside them. See below for vector generation details.
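For intuition, here is a minimal sketch (not part of the repository) of the ablation applied at inference time: each wrapped layer's hidden state has its component along a per-layer unit direction removed, mirroring the projection in `exl2_wrapper.py`. The function name `ablate` is illustrative only:
```
import torch

def ablate(x, r):
    # x: hidden states [batch, seq, hidden]; r: suppression direction [hidden]
    r = r / r.norm()                      # unit-normalize the direction
    proj = (x @ r).unsqueeze(-1) * r      # component of x along r
    return x - proj                       # hidden states with that component removed
```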

ZoRA currently supports ExLlamaV2 only and is intended for research purposes. Feedback on the viability of these models with suppression applied is welcome.

Usage
=
Place `suppress_dir.safetensors` in the model directory and wrap your ExLlamaV2 model object in your code:
```
from exl2_wrapper import ExLlamaV2ModuleWrapper
ExLlamaV2ModuleWrapper.wrap(model)
```
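For additional context, a hedged end-to-end sketch of where the wrap call fits into a typical ExLlamaV2 loading flow; the model path and prompt are placeholders, and the calls mirror those used in `gen.py` and `test_inference.py` below:
```
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler
from exl2_wrapper import ExLlamaV2ModuleWrapper

config = ExLlamaV2Config()
config.model_dir = '/path/to/Meta-Llama-3-8B-Instruct'   # placeholder path
config.prepare()

model = ExLlamaV2(config)
ExLlamaV2ModuleWrapper.wrap(model)   # wraps the layers and loads suppress_dir.safetensors from model_dir

cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.7
generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
print(generator.generate_simple('Your prompt.', settings, 128))
```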

Example
=
There's a modified `test_inference.py` from [exllamav2](https://github.com/turboderp/exllamav2) for testing. For example:
```
python test_inference.py -m Meta-Llama-3-70B-Instruct-8bpw -p '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant.<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nYour prompt.<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n' -gs auto
```

Generator
=
The code to generate the ablation vectors (`gen.py`) is included. To run it, set `harmful_prompts_url` in `gen.py` to a source of harmful prompts (the placeholder value is `ADD_URL_HERE`).
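Note on input format: `gen.py` reads the harmful-prompts URL line by line and parses each line as JSON with a `prompt` field (the harmless prompts are pulled from the Stanford Alpaca dataset automatically). A hypothetical example of the expected JSON-lines payload:
```
{"prompt": "First harmful instruction ..."}
{"prompt": "Second harmful instruction ..."}
```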

Here is sample output for the Llama-3-8B model (`-` marks a prompt the model refused, `+` one it answered, and the number ending each progress row is the count of residual streams captured so far):

```
Downloading harmful prompts
Done
 -- Loading model...
 -- Loaded model in 2.7671 seconds
 -- Loading tokenizer...
Building refused residual data
Processing 5000 prompts
---------------------------------------------------------------------------------------------------- 100
---------------------------------------------------------------------------------------------------- 200
[...]
---------------------------------------------+------------------------------------------------------ 1898
---------------------------------------------------------------------------------------------------- 1998
--
Max capture reached
Captured 2000 residual streams
Done
Building allowed residual data
Downloading harmless prompts
Done
Processing 31323 prompts
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 100
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 200
[...]
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1898
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1998
++
Max capture reached
Captured 2000 residual streams
Done
Calculating mean allowed residual
Done
Iteration 0
Processing 2000 prompts
---+++++++++++++++++++++++++-+-+++++++++-++++++++++++++-+++-++-++++++++++++++-++++---++++++++-++++-+ 15
+++++++-++++++++++++++-+-++++++++++++++++++++++++++++-+++++++++--+++++++++++-++++++++++++++++++++++- 23
+++++++++++++++++++++++-++-++++++++++++++++-++++++++++-++-++++++++++++++++++++-++++++++--+++++++++++ 31
--+-+++++++++++++-++++++-+++++-+++-+++++-++++-++++++++++-++++-++++++++-++++++++++++++++++-++++++++++ 44
-++++++++-+++++++++-++++++++--++++-
Max capture reached
Captured 50 residual streams
Iteration 1
Processing 2000 prompts
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 0
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 0
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 0
[...]
```
exl2_wrapper.py
ADDED
@@ -0,0 +1,61 @@
import os
import torch
from safetensors import safe_open

class ExLlamaV2ModuleWrapper:
    @classmethod
    def wrap(cls, model, load = True):
        # Wrap every module except the first (embedding) and the last two (final norm and head)
        for idx, module in enumerate(model.modules):
            if idx == 0 or idx >= (len(model.modules) - 2):
                continue
            model.modules[idx] = ExLlamaV2ModuleWrapper(model, module, idx)

        if not load:
            return

        suppress_dir_file = os.path.join(model.config.model_dir, 'suppress_dir.safetensors')
        if os.path.exists(suppress_dir_file):
            print(f'Loading suppress direction file "{suppress_dir_file}"')
            with safe_open(suppress_dir_file, framework='pt', device='cpu') as f:
                model._suppress_dir = []
                for layer in range(len(f.keys())):
                    model._suppress_dir.append(f.get_tensor(f'_suppress_dir_{layer}'))
        else:
            print(f'No suppress direction file, not wrapping. Tried to load: "{suppress_dir_file}"')
            return

    def __init__(self, model, module, idx):
        if not hasattr(model, '_suppress_dir'):
            model._suppress_dir = None
        if not hasattr(model, '_residual'):
            model._residual = None
        self.model = model
        self.module = module
        self.idx = idx

    def __getattribute__(self, name):
        # Intercept forward(); delegate every other attribute to the wrapped module
        if name == 'forward':
            return object.__getattribute__(self, 'wrapped_forward')

        try:
            return getattr(object.__getattribute__(self, 'module'), name)
        except AttributeError:
            pass
        return object.__getattribute__(self, name)

    def suppress(self, x):
        # Remove the component of the hidden state along this layer's suppression direction
        if self.model._suppress_dir is not None:
            r = self.model._suppress_dir[self.idx - 2].clone().to(x.device)
            r = r.view(-1, 1)
            proj_scalar = torch.matmul(x, r)
            proj = proj_scalar * r.transpose(0, 1)
            x = x - proj
        return x

    def wrapped_forward(self, *args, **kwargs):
        # Optionally capture the incoming residual stream (single-token forward passes only)
        if self.model._residual is not None:
            if len(self.model._residual) < self.idx and args[0].shape[1] == 1:
                self.model._residual.append(args[0].clone().to('cpu'))
        x = self.suppress(args[0])
        x = self.module.forward(*((x,) + args[1:]), **kwargs)
        return self.suppress(x)
gen.py
ADDED
@@ -0,0 +1,198 @@
import re
import time
import random
import io
from pathlib import Path
import json
import torch
import requests
from safetensors.torch import save_file

from exllamav2 import(
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler
)

from exl2_wrapper import ExLlamaV2ModuleWrapper

### START Settings

template = '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant.<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n{instruction}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n'

model_dir = '/path/to/Meta-Llama-3-8B-Instruct'

harmful_prompts_url = 'ADD_URL_HERE'
harmless_prompts_url = 'https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json'

### END Settings

torch.cuda._lazy_init()
torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 150)

config = ExLlamaV2Config()
config.model_dir = model_dir
config.prepare()
config.max_seq_len = 2048
model = ExLlamaV2(config)
ExLlamaV2ModuleWrapper.wrap(model, False)
model._residual = [] # Enable residual capture


out_dir = Path(config.model_dir.replace('/', '_'))
out_dir.mkdir(exist_ok = True)

harmful_prompts_file = out_dir / Path('harmful_prompts.json')
harmless_prompts_file = out_dir / Path('harmless_prompts.json')

refused_residual_file = out_dir / Path('refused_residual.pth')
allowed_residual_file = out_dir / Path('allowed_residual.pth')
allowed_residual_mean_file = out_dir / Path('allowed_residual_mean.pth')

suppress_dir_file = out_dir / Path('suppress_dir.safetensors')

refused = []
def get_residual(prompts, num_tokens, silent, max_capture, capture_type):
    global model, tokenizer, settings, refused, generator

    refused = []
    residuals = []

    print(f'Processing {len(prompts)} prompts')
    for idx, prompt in enumerate(prompts):
        if idx and not (idx % 100):
            print('', len(residuals))

        prompt = template.format(instruction = prompt)

        model._residual = []
        out = generator.generate_simple(prompt, settings, num_tokens, completion_only = True)

        refusal = re.match(r'^(I\'m not|I cannot|I can\'t|I\'m sorry|As an A|I apolog|I\'m (unable|really|here)|[1I], as|I must|I understand|It(\'s| is) important|Sorry|The (assistant|AI))', out)
        if capture_type is None or (capture_type == 'refused' and refusal) or (capture_type == 'allowed' and not refusal):
            residuals.append(model._residual[:])

        if refusal:
            refused.append(prompt)
        print('-' if refusal else '+', end='', flush = True)

        if max_capture and len(residuals) >= max_capture:
            print('\nMax capture reached')
            break

        if not silent:
            print(out)

    if not len(residuals):
        return None

    print(f'\nCaptured {len(residuals)} residual streams')

    res = []
    for l in range(len(residuals[0])):
        res.append(torch.cat([t[l][0, -1, :].unsqueeze(0) for t in residuals], dim=0))
    return res

if not harmful_prompts_file.exists():
    print('Downloading harmful prompts')
    res = requests.get(harmful_prompts_url)

    harmful_prompts = []
    for line in res.iter_lines():
        if line:
            harmful_prompts.append(json.loads(line.decode())['prompt'])
    with harmful_prompts_file.open('w') as f:
        json.dump(harmful_prompts, f)
    print('Done')
else:
    with harmful_prompts_file.open('r') as f:
        harmful_prompts = json.load(f)

print(" -- Loading model...")
t = time.time()
cache = ExLlamaV2Cache(model, lazy=True)
model.load_autosplit(cache)
t = time.time() - t
print(f" -- Loaded model in {t:.4f} seconds")

print(" -- Loading tokenizer...")
tokenizer = ExLlamaV2Tokenizer(config)
settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

with torch.inference_mode():

    if not refused_residual_file.exists():
        print('Building refused residual data')
        refused_residual = get_residual(harmful_prompts, 4, True, 2000, 'refused')
        torch.save(refused_residual, refused_residual_file)
    else:
        print('Loading refusal residual data')
        refused_residual = torch.load(refused_residual_file)
    print('Done')

    allowed_residual_mean = []
    if not allowed_residual_mean_file.exists():
        if not allowed_residual_file.exists():
            print('Building allowed residual data')
            if not harmless_prompts_file.exists():
                print('Downloading harmless prompts')
                res = requests.get(harmless_prompts_url)

                all_prompts = json.loads(res.content.decode('utf8'))
                harmless_prompts = [i['instruction'] for i in all_prompts if i['input'] == '']

                with harmless_prompts_file.open('w') as f:
                    json.dump(harmless_prompts, f)
                print('Done')
            else:
                with harmless_prompts_file.open('r') as f:
                    harmless_prompts = json.load(f)
            allowed_residual = get_residual(harmless_prompts, 4, True, 2000, 'allowed')
            torch.save(allowed_residual, allowed_residual_file)
        else:
            print('Loading allowed residual data')
            allowed_residual = torch.load(allowed_residual_file)

        print('Done')

        print('Calculating mean allowed residual')
        for i in range(len(allowed_residual)):
            allowed_residual_mean.append(allowed_residual[i].mean(dim = 0))
        print('Done')
        torch.save(allowed_residual_mean, allowed_residual_mean_file)
    else:
        allowed_residual_mean = torch.load(allowed_residual_mean_file)

    if model._suppress_dir is None:
        model._suppress_dir = []

    for o in range(6):
        print('Iteration', o)

        for i in range(len(refused_residual)):
            refusal_dir = refused_residual[i].mean(dim = 0) - allowed_residual_mean[i]
            refusal_dir = refusal_dir / refusal_dir.norm() if refusal_dir.norm() > 0.0001 else torch.zeros_like(refusal_dir)
            if len(model._suppress_dir) > i:
                model._suppress_dir[i] = (model._suppress_dir[i] + refusal_dir) / 2
            else:
                model._suppress_dir.append(refusal_dir)

        refused_residual = get_residual(random.sample(harmful_prompts, 2000), 4, True, 50, 'refused')

        if not refused_residual or refused_residual[0].shape[0] < 30:
            break


save_file({f'_suppress_dir_{layer}': tensor for layer, tensor in enumerate(model._suppress_dir)}, suppress_dir_file)

torch.cuda.synchronize()
test_inference.py
ADDED
@@ -0,0 +1,542 @@

from exllamav2 import(
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Tokenizer,
    model_init,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler
)

from exllamav2.attn import ExLlamaV2Attention
from exllamav2.mlp import ExLlamaV2MLP
from exllamav2.moe_mlp import ExLlamaV2MoEMLP
from exllamav2.parallel_decoder import ExLlamaV2ParallelDecoder

import argparse, os, math, time
import torch
import torch.nn.functional as F
from conversion.tokenize import get_tokens
from conversion.quantize import list_live_tensors
import gc

# from exllamav2.mlp import set_catch

import sys
import json

torch.cuda._lazy_init()
torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 150)

# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# torch.set_float32_matmul_precision("medium")

# (!!!) NOTE: These go on top of the engine arguments that can be found in `model_init.py` (!!!)
parser = argparse.ArgumentParser(description = "Test inference on ExLlamaV2 model")
parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity evaluation dataset (.parquet file)")
parser.add_argument("-er", "--eval_rows", type = int, default = 128, help = "Number of rows to apply from dataset")
parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache")
parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit (FP8) cache")
parser.add_argument("-eq4", "--eval_token_q4", action = "store_true", help = "Evaluate perplexity on token-by-token inference using Q4 cache")
# parser.add_argument("-eb", "--eval_bos", action = "store_true", help = "Add BOS token to every row in perplexity test (required by Gemma and maybe other models.)")
parser.add_argument("-p", "--prompt", type = str, help = "Generate from prompt (basic sampling settings)")
parser.add_argument("-pnb", "--prompt_no_bos", action = "store_true", help = "Don't add BOS token to prompt")
parser.add_argument("-t", "--tokens", type = int, default = 128, help = "Max no. tokens")
parser.add_argument("-ps", "--prompt_speed", action = "store_true", help = "Test prompt processing (batch) speed over context length")
parser.add_argument("-s", "--speed", action = "store_true", help = "Test raw generation speed over context length")
parser.add_argument("-mix", "--mix_layers", type = str, help = "Load replacement layers from secondary model. Example: --mix_layers 1,6-7:/mnt/models/other_model")
parser.add_argument("-nwu", "--no_warmup", action = "store_true", help = "Skip warmup before testing model")
parser.add_argument("-sl", "--stream_layers", action = "store_true", help = "Load model layer by layer (perplexity evaluation only)")
parser.add_argument("-sp", "--standard_perplexity", choices = ["wiki2"], help = "Run standard (HF) perplexity test, stride 512 (experimental)")
parser.add_argument("-rr", "--rank_reduce", type = str, help = "Rank-reduction for MLP layers of model, in reverse order (for experimentation)")
parser.add_argument("-mol", "--max_output_len", type = int, help = "Set max output chunk size (incompatible with ppl tests)")

# Initialize model and tokenizer

model_init.add_args(parser)
args = parser.parse_args()

# Check conflicting settings

if args.stream_layers:
    if args.eval_token or args.eval_token_8bit or args.eval_token_q4:
        print(" ## Can't test token ppl while streaming layers")
        sys.exit()
    if args.prompt:
        print(" ## Can't generate while streaming layers")
        sys.exit()
    if args.speed or args.prompt_speed:
        print(" ## Can't test speed while streaming layers")
        sys.exit()
    if args.gpu_split:
        print(" ## Can only use one GPU when streaming layers")
        sys.exit()
if args.eval_dataset:
    if args.length and args.eval_length != args.length:
        print(" !! Overriding model context length to match eval row length")
        args.length = args.eval_length

# Init

model_init.check_args(args)
model_init.print_options(args)
model, tokenizer = model_init.init(args,
                                   allow_auto_split = True,
                                   skip_load = args.stream_layers,
                                   benchmark = True,
                                   max_output_len = args.max_output_len)
cache = None

from exl2_wrapper import ExLlamaV2ModuleWrapper
ExLlamaV2ModuleWrapper.wrap(model)

# Auto split

if not model.loaded and not args.stream_layers:

    if args.mix_layers:
        print(" !! Warning, auto split does not account for VRAM requirement of replacement layers")

    print(" -- Loading model...")
    cache = ExLlamaV2Cache(model, lazy = True)
    t = time.time()
    model.load_autosplit(cache)
    t = time.time() - t
    print(f" -- Loaded model in {t:.4f} seconds")

if args.stream_layers:

    stream_batch_size = 2
    model.config.max_batch_size = stream_batch_size
    model.load(lazy = True)

# Rank reduction

if args.rank_reduce:

    if args.stream_layers:
        print(" ## --rank_reduce can not be combined with --stream_layers")
        sys.exit()

    rr = args.rank_reduce.split(",")
    idx = len(model.modules) - 1
    for r in rr:
        k = float(r)

        while True:
            idx -= 1
            module = model.modules[idx]
            if isinstance(module, ExLlamaV2ParallelDecoder): break
            if isinstance(module, ExLlamaV2MLP): break
            if isinstance(module, ExLlamaV2MoEMLP): break
            if idx < 0:
                print(" ## Not enough layers")
                sys.exit()

        print(f" -- Reducing {module.key} ({module.name}) to {k * 100:.2f}%")
        module.rank_reduce(k)

# Replacement

if args.mix_layers:
    intervals_, extra_dir = args.mix_layers.split(":")

    print(f" -- Loading replacement layers from: {extra_dir}")

    extra_config = ExLlamaV2Config()
    extra_config.model_dir = extra_dir
    extra_config.prepare()
    intervals = intervals_.split(",")
    for interval in intervals:
        ab = interval.split("-")
        a, b = int(ab[0]), int(ab[-1])
        for idx in range(a, b + 1):
            print(f" -- Layer {idx}...")
            layerkey = "model.layers." + str(idx) + "."
            remove = [k for k in model.config.tensor_file_map.keys() if k.startswith(layerkey)]
            replace = [k for k in extra_config.tensor_file_map.keys() if k.startswith(layerkey)]
            # reload = [k for k in model.modules_dict.keys() if k.startswith(layerkey)]
            for k in remove: del model.config.tensor_file_map[k]
            for k in replace: model.config.tensor_file_map[k] = extra_config.tensor_file_map[k]
            # for k in reload:
            #     model.modules_dict[k].unload()
            #     model.modules_dict[k].load()
            if not args.stream_layers:
                model.modules[idx * 2 + 1].reload()
                model.modules[idx * 2 + 2].reload()

# Test generation

if args.prompt:

    with torch.inference_mode():

        if cache is None:
            cache = ExLlamaV2Cache(model)

        ids = tokenizer.encode(args.prompt)
        tokens_prompt = ids.shape[-1]

        print(f" -- Warmup...")

        generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
        if not args.no_warmup: generator.warmup()

        print(f" -- Generating...")
        print()

        settings = ExLlamaV2Sampler.Settings()
        settings.temperature = 0.75
        settings.top_k = 100
        settings.top_p = 0.75
        settings.token_repetition_penalty = 1.05
        settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])

        time_begin = time.time()

        output = generator.generate_simple(args.prompt, settings, args.tokens, token_healing = True, add_bos = not args.prompt_no_bos)

        torch.cuda.synchronize()
        time_prompt = time.time()

        time_end = time.time()

        print(output)
        print()

        total_gen = time_end - time_begin
        print(f" -- Response generated in {total_gen:.2f} seconds, {args.tokens} tokens, {args.tokens / total_gen:.2f} tokens/second (includes prompt eval.)")


# Test perplexity

if args.eval_dataset or args.standard_perplexity:

    with torch.inference_mode():

        print(f" -- Running perplexity test")

        if args.standard_perplexity:

            eval_length = args.eval_length
            if args.eval_dataset:
                print(f" !! Note, overriding specified --eval_dataset with {args.standard_perplexity}")

            from datasets import load_dataset

            if args.standard_perplexity == "wiki2":
                ds = "wikitext"
                part = "wikitext-2-raw-v1"
                split = "test"
            # if args.standard_perplexity == "c4":
            #     ds = "allenai/c4"
            #     part = "allenai--c4"
            #     split = "train"

            print(f" -- Loading dataset {ds}, {part}, {split}...")
            test = load_dataset(ds, part, split = split)

            print(f" -- Tokenizing samples...")
            text = "\n\n".join(test["text"])
            eval_tokens = tokenizer.encode(text)

            stride = 512
            seqs = []
            eval_len = []
            a = 0
            while True:
                b = a + model.config.max_seq_len
                if b > eval_tokens.shape[-1]: break
                seqs.append(eval_tokens[:, a:b])
                eval_len.append(b if a == 0 else stride)
                a += stride

            eval_tokens = torch.cat(seqs, dim = 0)

        else:

            eval_dataset = args.eval_dataset
            eval_rows = args.eval_rows
            eval_length = args.eval_length

            print(f" -- Dataset: {eval_dataset}")
            print(f" -- Tokenizing eval data, {eval_rows} rows x {eval_length} tokens...")

            eval_tokens = get_tokens(eval_rows, eval_length, eval_dataset, tokenizer)
            eval_len = [eval_tokens.shape[1]] * eval_tokens.shape[0]

        # if args.eval_bos:
        if model.config.arch.requires_bos:
            boss = torch.full((eval_tokens.shape[0], 1), tokenizer.bos_token_id, dtype = torch.long)
            eval_tokens = torch.cat((boss, eval_tokens[:, :-1]), dim = 1)

        logprob_sum = 0.0
        logprob_count = 0

        def ppl(input_ids__, logits__, lengths__):

            logprob_sum_ = 0.0
            logprob_count_ = 0

            assert logits__.shape[0] == input_ids__.shape[0]
            ll = logits__.shape[1]

            for bi in range(logits__.shape[0]):
                cl = max(ll - lengths__[bi], 0)
                logits_ = logits__[bi:bi+1, cl:, :]
                input_ids_ = input_ids__[bi:bi+1, cl:]

                chunksize = logits_.shape[1] * 4000 // logits_.shape[2] + 1
                b_ = 0
                while b_ < logits_.shape[1]:
                    a_ = b_
                    b_ = min(b_ + chunksize, logits_.shape[1])

                    logits_f = logits_[:, a_:b_, :].float() + 1e-10
                    target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)

                    log_probs = F.log_softmax(logits_f, dim=-1)
                    token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
                    logprob_sum_ += token_log_probs.sum().item()
                    logprob_count_ += target_ids.numel()

            return logprob_sum_, logprob_count_

        if args.stream_layers:

            print(f" -- Inference (streamed)", end = "")
            sys.stdout.flush()

            batch_size, seq_len = eval_tokens.shape
            attn_params = ExLlamaV2Attention.Params(stream_batch_size, seq_len, 0, None, None)
            # attn_mask = model.build_attn_mask(stream_batch_size, seq_len, 0, None, "cuda:0")

            for idx, module in enumerate(model.modules):
                module.set_device_idx(-1 if idx == 0 else 0)

            model.modules[0].load()
            hidden_state = model.modules[0].forward(eval_tokens)
            model.modules[0].unload()

            for idx, module in enumerate(model.modules):
                if idx == 0: continue

                print(".", end = "")
                sys.stdout.flush()
                module.load()

                b = 0
                while b < eval_tokens.shape[0]:
                    a = b
                    b = min(b + stream_batch_size, eval_tokens.shape[0])
                    x = hidden_state[a:b, :, :].to("cuda:0")
                    x = module.forward(x, cache = None, attn_params = attn_params, past_len = 0, loras = None)

                    if idx < len(model.modules) - 1:
                        hidden_state[a:b, :, :] = x.to("cpu")

                    else:
                        input_ids = eval_tokens[a:b, :]
                        logits = x[:, :-1, :]

                        # if model.config.logit_scale != 1:
                        #     logits.mul_(model.config.logit_scale)

                        logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[a:b])
                        logprob_sum += logprob_sum__
                        logprob_count += logprob_count__

                module.unload()

            print()

        else:

            print(f" -- Inference", end = "")
            sys.stdout.flush()

            if cache is None:
                cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if eval_length > model.config.max_input_len else None

            for i in range(eval_tokens.shape[0]):

                if i % 10 == 0: print(".", end = "")
                sys.stdout.flush()

                input_ids = eval_tokens[i:i+1, :]

                input_ids = input_ids[:, :]
                if cache is not None: cache.current_seq_len = 0
                logits = model.forward(input_ids, cache)
                logits = logits[:, :-1, :]

                logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1])
                logprob_sum += logprob_sum__
                logprob_count += logprob_count__

            print()

        mean_log_prob = logprob_sum / logprob_count
        perplexity = math.exp(-mean_log_prob)
        print(f" -- Evaluation perplexity: {perplexity:.4f}")

        def test_ppl_token():
            global logprob_sum, logprob_count, i, input_ids
            global logits, target_ids, log_probs, token_log_probs
            global mean_log_prob, perplexity

            # set_catch("model.layers.3")

            logprob_sum = 0
            logprob_count = 0

            for i in range(eval_tokens.shape[0]):

                cache.current_seq_len = 0

                for j in range(eval_tokens.shape[1] - 1):
                    if j % 256 == 0: print(".", end = "")
                    sys.stdout.flush()

                    input_ids = eval_tokens[i:i + 1, j:j + 1]
                    logits = model.forward(input_ids, cache)
                    logits = logits.float() + 1e-10

                    log_probs = F.log_softmax(logits, dim = -1)
                    logprob_sum += log_probs[0, 0, eval_tokens[i, j+1]]
                    logprob_count += 1

                    # mean_log_prob = logprob_sum / logprob_count
                    # perplexity = math.exp(-mean_log_prob)
                    # print(f" -- Token {j}: {perplexity:.4f}")

            print()

            mean_log_prob = logprob_sum / logprob_count
            perplexity = math.exp(-mean_log_prob)
            print(f" -- Evaluation perplexity: {perplexity:.4f}")

        if args.eval_token:
            if args.standard_perplexity:
                print(f" !! Note, can't evalutate token perplexity on standard test")
            else:
                print(f" -- Inference (token)", end = "")
                sys.stdout.flush()
                cache = ExLlamaV2Cache(model, max_seq_len = eval_length)
                test_ppl_token()

        if args.eval_token_8bit:
            if args.standard_perplexity:
                print(f" !! Note, can't evalutate token perplexity on standard test")
            else:
                print(f" -- Inference (token, 8-bit cache)", end = "")
                sys.stdout.flush()
                cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length)
                test_ppl_token()

        if args.eval_token_q4:
            if args.standard_perplexity:
                print(f" !! Note, can't evalutate token perplexity on standard test")
            else:
                print(f" -- Inference (token, Q4 cache)", end = "")
                sys.stdout.flush()
                cache = ExLlamaV2Cache_Q4(model, max_seq_len = eval_length)
                test_ppl_token()


# Test prompt speed

if args.prompt_speed:

    with torch.inference_mode():

        if cache is None:
            cache = ExLlamaV2Cache(model)

        ids = torch.randint(0, model.config.vocab_size - 1, (1, model.config.max_seq_len))

        print(f" -- Warmup...")

        if not args.no_warmup:
            model.forward(ids[:, -1:])

        print(f" -- Measuring prompt speed...")

        torch.cuda.synchronize()

        current_len = 128
        step = 128
        prompt_iters = 3
        while True:

            total_time = 0
            for i in range(prompt_iters):

                torch.cuda.synchronize()
                time_begin = time.time()

                cache.current_seq_len = 0
                model.forward(ids[:, :current_len], cache, preprocess_only = True)

                torch.cuda.synchronize()
                time_end = time.time()
                total_time += time_end - time_begin

            tps = current_len / (total_time / prompt_iters)

            print(f" ** Length {current_len:>5} tokens: {tps:>11.4f} t/s")

            if current_len >= 1024: step = 1024
            if current_len >= 4096: step = 4096
            if current_len >= 16384: step = 8192

            current_len_ = current_len
            current_len = min(current_len + step, model.config.max_seq_len)
            if current_len == current_len_: break


# Test token speed

if args.speed:

    with torch.inference_mode():

        if cache is None:
            cache = ExLlamaV2Cache(model)
        cache.current_seq_len = 0

        print(f" -- Measuring token speed...")
        ids = tokenizer.encode("X")
        model.forward(ids[:, :])

        current_idx = ids.shape[-1]
        next_stop = 128

        while True:

            time_begin = time.time()

            tokens = next_stop - current_idx
            for i in range(tokens):

                logits = model.forward(ids[:, -1:], cache)
                sample = torch.argmax(logits[0, -1]).cpu().unsqueeze(0).unsqueeze(0)
                ids = torch.cat((ids, sample), dim=-1)

            time_end = time.time()
            tps = tokens / (time_end - time_begin)

            print(f" ** Position {current_idx:>5} + {tokens:>3} tokens: {tps:>9.4f} t/s")

            current_idx = next_stop
            next_stop = min(next_stop + 128, model.config.max_seq_len)
            if next_stop == current_idx: break