llmixer committed
Commit 4783804
1 Parent(s): c8f8f87

Added generator code
Meta-Llama-3-70B-Instruct-8bpw/suppress_dir.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:faec2cc2c48d1a925a58d08a5396e3255f50d269ccc66b6610defd5ce6074cfe
+ size 2634640
Meta-Llama-3-8B-Instruct/suppress_dir.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:de22c6df410a1bb839b3ae66a1d3b7aadcc1254d81a3c7fae17b8d509ed1f801
+ size 529440
Phi-3-mini-128k-instruct/suppress_dir.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b0d2417d1c9684e73f44b5338f024975fcefd8d777a124633e27f6e9cc13e56a
+ size 398360
README.md ADDED
@@ -0,0 +1,86 @@
+ ---
+ license: mit
+ pipeline_tag: text-generation
+ ---
+
+ ZoRA: Zero Rank Adaptation
+ =
+ Inspired by [*Refusal in LLMs is mediated by a single direction*](https://www.alignmentforum.org/posts/jGuXSZgv6qfdhMCuJ/refusal-in-llms-is-mediated-by-a-single-direction), ZoRA is a refinement of the original approach that adapts large language models to suppress refusals. The key features of ZoRA are:
+ * **Layer-wise ablation**: Measure and ablate a separate set of vectors for each layer
+ * **Multi-pass refinement**: Re-measure multiple times to refine the vectors
+ * **Single-token generation**: Measure refusal at the beginning of the response
+ * **Inference engine injection**: Load a small set of refusal-suppression vectors directly into a high-performance inference engine
+
+ This approach keeps the original model weights unchanged and only loads a small set of suppression vectors at inference time; the per-layer ablation is sketched below. See the Generator section below for vector generation details.
+
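+ As a rough illustration of the core operation (a hypothetical `ablate` helper, not the exact wrapper code), ablating a measured refusal direction from the residual stream means removing its projection onto that direction, with a separate direction per wrapped layer:
+ ```
+ import torch
+
+ def ablate(x: torch.Tensor, direction: torch.Tensor) -> torch.Tensor:
+     # x: residual stream activations, shape (batch, seq_len, hidden)
+     # direction: measured refusal direction for this layer, shape (hidden,)
+     r = direction / direction.norm()     # unit vector along the refusal direction
+     proj = (x @ r).unsqueeze(-1) * r     # component of x along r
+     return x - proj                      # x with the refusal component removed
+ ```
+ The bundled `exl2_wrapper.py` applies this projection to the input and output of each wrapped module, using the per-layer directions stored in `suppress_dir.safetensors`.
+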
+ ZoRA currently supports ExLlamaV2 only and is intended for research purposes; feedback on the viability of these models with suppression applied is welcome.
+
+ Usage
+ =
+ Put `suppress_dir.safetensors` into the model directory and wrap your ExLlamaV2 model object in your code:
+ ```
+ from exl2_wrapper import ExLlamaV2ModuleWrapper
+ ExLlamaV2ModuleWrapper.wrap(model)
+ ```
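+
+ For a fuller end-to-end sketch, the snippet below follows the same calls used by the bundled `gen.py` and `test_inference.py`; the model path and sampling settings are placeholders:
+ ```
+ from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
+ from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler
+ from exl2_wrapper import ExLlamaV2ModuleWrapper
+
+ config = ExLlamaV2Config()
+ config.model_dir = '/path/to/model'   # directory that also contains suppress_dir.safetensors
+ config.prepare()
+
+ model = ExLlamaV2(config)
+ ExLlamaV2ModuleWrapper.wrap(model)    # wraps the module list and loads the suppression vectors if present
+
+ cache = ExLlamaV2Cache(model, lazy = True)
+ model.load_autosplit(cache)
+
+ tokenizer = ExLlamaV2Tokenizer(config)
+ generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
+ settings = ExLlamaV2Sampler.Settings()
+ settings.temperature = 0.75
+
+ print(generator.generate_simple('Your prompt here', settings, 128))
+ ```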
+
+ Example
+ =
+ A modified `test_inference.py` from [exllamav2](https://github.com/turboderp/exllamav2) is included for testing. For example:
+ ```
+ python test_inference.py -m Meta-Llama-3-70B-Instruct-8bpw -p '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant.<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nYour prompt.<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n' -gs auto
+ ```
+
+ Generator
+ =
+ The code to generate the ablation vectors (`gen.py`) has been added. Before running it, set the URL for the harmful prompts in its settings block, as shown below.
+
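+ The settings live at the top of `gen.py`; fill in `harmful_prompts_url` (a JSON-lines file with a `prompt` field per line) and adjust `model_dir` and the prompt `template` for your model:
+ ```
+ ### START Settings
+
+ template = '<|start_header_id|>system<|end_header_id|>\n\n...'  # chat template with an {instruction} placeholder
+
+ model_dir = '/path/to/Meta-Llama-3-8B-Instruct'
+
+ harmful_prompts_url = 'ADD_URL_HERE'
+ harmless_prompts_url = 'https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json'
+
+ ### END Settings
+ ```
+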
+ Here is a sample output for the Llama3-8b model:
+
+ ```
+ Downloading harmful prompts
+ Done
+ -- Loading model...
+ -- Loaded model in 2.7671 seconds
+ -- Loading tokenizer...
+ Building refused residual data
+ Processing 5000 prompts
+ ---------------------------------------------------------------------------------------------------- 100
+ ---------------------------------------------------------------------------------------------------- 200
+ [...]
+ ---------------------------------------------+------------------------------------------------------ 1898
+ ---------------------------------------------------------------------------------------------------- 1998
+ --
+ Max capture reached
+ Captured 2000 residual streams
+ Done
+ Building allowed residual data
+ Downloading harmless prompts
+ Done
+ Processing 31323 prompts
+ ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 100
+ ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 200
+ [...]
+ ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1898
+ ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1998
+ ++
+ Max capture reached
+ Captured 2000 residual streams
+ Done
+ Calculating mean allowed residual
+ Done
+ Iteration 0
+ Processing 2000 prompts
+ ---+++++++++++++++++++++++++-+-+++++++++-++++++++++++++-+++-++-++++++++++++++-++++---++++++++-++++-+ 15
+ +++++++-++++++++++++++-+-++++++++++++++++++++++++++++-+++++++++--+++++++++++-++++++++++++++++++++++- 23
+ +++++++++++++++++++++++-++-++++++++++++++++-++++++++++-++-++++++++++++++++++++-++++++++--+++++++++++ 31
+ --+-+++++++++++++-++++++-+++++-+++-+++++-++++-++++++++++-++++-++++++++-++++++++++++++++++-++++++++++ 44
+ -++++++++-+++++++++-++++++++--++++-
+ Max capture reached
+ Captured 50 residual streams
+ Iteration 1
+ Processing 2000 prompts
+ ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 0
+ ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 0
+ ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 0
+ [...]
+ ```
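+
+ When `gen.py` finishes, it writes `suppress_dir.safetensors` into its output directory; copy that file into the model directory used for inference. As a minimal sanity check (one direction per wrapped layer, keyed `_suppress_dir_0`, `_suppress_dir_1`, ...):
+ ```
+ from safetensors import safe_open
+
+ with safe_open('suppress_dir.safetensors', framework='pt', device='cpu') as f:
+     for key in sorted(f.keys(), key=lambda k: int(k.rsplit('_', 1)[-1])):
+         t = f.get_tensor(key)
+         print(key, tuple(t.shape), f'norm={t.norm().item():.3f}')
+ ```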
exl2_wrapper.py ADDED
@@ -0,0 +1,61 @@
+ import os
+ import torch
+ from safetensors import safe_open
+
+ class ExLlamaV2ModuleWrapper:
+     @classmethod
+     def wrap(cls, model, load = True):
+         # Wrap every module except the first one and the last two in the module list
+         for idx, module in enumerate(model.modules):
+             if idx == 0 or idx >= (len(model.modules) - 2):
+                 continue
+             model.modules[idx] = ExLlamaV2ModuleWrapper(model, module, idx)
+
+         if not load:
+             return
+
+         suppress_dir_file = os.path.join(model.config.model_dir, 'suppress_dir.safetensors')
+         if os.path.exists(suppress_dir_file):
+             print(f'Loading suppress direction file "{suppress_dir_file}"')
+             with safe_open(suppress_dir_file, framework='pt', device='cpu') as f:
+                 model._suppress_dir = []
+                 for layer in range(len(f.keys())):
+                     model._suppress_dir.append(f.get_tensor(f'_suppress_dir_{layer}'))
+         else:
+             print(f'No suppress direction file, not wrapping. Tried to load: "{suppress_dir_file}"')
+             return
+
+     def __init__(self, model, module, idx):
+         if not hasattr(model, '_suppress_dir'):
+             model._suppress_dir = None
+         if not hasattr(model, '_residual'):
+             model._residual = None
+         self.model = model
+         self.module = module
+         self.idx = idx
+
+     def __getattribute__(self, name):
+         # Intercept 'forward'; delegate every other attribute to the wrapped module
+         if name == 'forward':
+             return object.__getattribute__(self, 'wrapped_forward')
+
+         try:
+             return getattr(object.__getattribute__(self, 'module'), name)
+         except AttributeError:
+             pass
+         return object.__getattribute__(self, name)
+
+     def suppress(self, x):
+         # Remove the projection of the hidden state onto this layer's suppression direction
+         if self.model._suppress_dir is not None:
+             r = self.model._suppress_dir[self.idx - 2].clone().to(x.device)
+             r = r.view(-1, 1)
+             proj_scalar = torch.matmul(x, r)
+             proj = proj_scalar * r.transpose(0, 1)
+             x = x - proj
+         return x
+
+     def wrapped_forward(self, *args, **kwargs):
+         # Optionally capture the incoming residual stream (single-token forward passes only),
+         # then apply suppression before and after the wrapped module's forward pass
+         if self.model._residual is not None:
+             if len(self.model._residual) < self.idx and args[0].shape[1] == 1:
+                 self.model._residual.append(args[0].clone().to('cpu'))
+         x = self.suppress(args[0])
+         x = self.module.forward(*((x,) + args[1:]), **kwargs)
+         return self.suppress(x)
gen.py ADDED
@@ -0,0 +1,198 @@
+ import re
+ import time
+ import random
+ import io
+ from pathlib import Path
+ import json
+ import torch
+ import requests
+ from safetensors.torch import save_file
+
+ from exllamav2 import(
+     ExLlamaV2,
+     ExLlamaV2Config,
+     ExLlamaV2Cache,
+     ExLlamaV2Tokenizer,
+ )
+
+ from exllamav2.generator import (
+     ExLlamaV2BaseGenerator,
+     ExLlamaV2Sampler
+ )
+
+ from exl2_wrapper import ExLlamaV2ModuleWrapper
+
+ ### START Settings
+
+ template = '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant.<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n{instruction}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n'
+
+ model_dir = '/path/to/Meta-Llama-3-8B-Instruct'
+
+ harmful_prompts_url = 'ADD_URL_HERE'
+ harmless_prompts_url = 'https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json'
+
+ ### END Settings
+
+ torch.cuda._lazy_init()
+ torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 150)
+
+ config = ExLlamaV2Config()
+ config.model_dir = model_dir
+ config.prepare()
+ config.max_seq_len = 2048
+ model = ExLlamaV2(config)
+ ExLlamaV2ModuleWrapper.wrap(model, False)
+ model._residual = [] # Enable residual capture
+
+
+ out_dir = Path(config.model_dir.replace('/', '_'))
+ out_dir.mkdir(exist_ok = True)
+
+ harmful_prompts_file = out_dir / Path('harmful_prompts.json')
+ harmless_prompts_file = out_dir / Path('harmless_prompts.json')
+
+ refused_residual_file = out_dir / Path('refused_residual.pth')
+ allowed_residual_file = out_dir / Path('allowed_residual.pth')
+ allowed_residual_mean_file = out_dir / Path('allowed_residual_mean.pth')
+
+ suppress_dir_file = out_dir / Path('suppress_dir.safetensors')
+
+ refused = []
+
+ # Generate a few tokens per prompt, classify the completion as refused or allowed,
+ # and capture the residual stream of the first generated token at every wrapped layer.
+ def get_residual(prompts, num_tokens, silent, max_capture, capture_type):
+     global model, tokenizer, settings, refused, generator
+
+     refused = []
+     residuals = []
+
+     print(f'Processing {len(prompts)} prompts')
+     for idx, prompt in enumerate(prompts):
+         if idx and not (idx % 100):
+             print('', len(residuals))
+
+         prompt = template.format(instruction = prompt)
+
+         model._residual = []
+         out = generator.generate_simple(prompt, settings, num_tokens, completion_only = True)
+
+         # Heuristic refusal detector applied to the first generated tokens
+         refusal = re.match(r'^(I\'m not|I cannot|I can\'t|I\'m sorry|As an A|I apolog|I\'m (unable|really|here)|[1I], as|I must|I understand|It(\'s| is) important|Sorry|The (assistant|AI))', out)
+         if capture_type is None or (capture_type == 'refused' and refusal) or (capture_type == 'allowed' and not refusal):
+             residuals.append(model._residual[:])
+
+         if refusal:
+             refused.append(prompt)
+         print('-' if refusal else '+', end='', flush = True)
+
+         if max_capture and len(residuals) >= max_capture:
+             print('\nMax capture reached')
+             break
+
+         if not silent:
+             print(out)
+
+     if not len(residuals):
+         return None
+
+     print(f'\nCaptured {len(residuals)} residual streams')
+
+     # One tensor per layer: the last-position residual of every captured prompt
+     res = []
+     for l in range(len(residuals[0])):
+         res.append(torch.cat([t[l][0, -1, :].unsqueeze(0) for t in residuals], dim=0))
+     return res
+
+ if not harmful_prompts_file.exists():
+     print('Downloading harmful prompts')
+     res = requests.get(harmful_prompts_url)
+
+     harmful_prompts = []
+     for line in res.iter_lines():
+         if line:
+             harmful_prompts.append(json.loads(line.decode())['prompt'])
+     with harmful_prompts_file.open('w') as f:
+         json.dump(harmful_prompts, f)
+     print('Done')
+ else:
+     with harmful_prompts_file.open('r') as f:
+         harmful_prompts = json.load(f)
+
+ print(" -- Loading model...")
+ t = time.time()
+ cache = ExLlamaV2Cache(model, lazy=True)
+ model.load_autosplit(cache)
+ t = time.time() - t
+ print(f" -- Loaded model in {t:.4f} seconds")
+
+ print(" -- Loading tokenizer...")
+ tokenizer = ExLlamaV2Tokenizer(config)
+ settings = ExLlamaV2Sampler.Settings()
+ settings.temperature = 0
+
+ generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
+
+ with torch.inference_mode():
+
+     if not refused_residual_file.exists():
+         print('Building refused residual data')
+         refused_residual = get_residual(harmful_prompts, 4, True, 2000, 'refused')
+         torch.save(refused_residual, refused_residual_file)
+     else:
+         print('Loading refusal residual data')
+         refused_residual = torch.load(refused_residual_file)
+     print('Done')
+
+     allowed_residual_mean = []
+     if not allowed_residual_mean_file.exists():
+         if not allowed_residual_file.exists():
+             print('Building allowed residual data')
+             if not harmless_prompts_file.exists():
+                 print('Downloading harmless prompts')
+                 res = requests.get(harmless_prompts_url)
+
+                 all_prompts = json.loads(res.content.decode('utf8'))
+                 harmless_prompts = [i['instruction'] for i in all_prompts if i['input'] == '']
+
+                 with harmless_prompts_file.open('w') as f:
+                     json.dump(harmless_prompts, f)
+                 print('Done')
+             else:
+                 with harmless_prompts_file.open('r') as f:
+                     harmless_prompts = json.load(f)
+             allowed_residual = get_residual(harmless_prompts, 4, True, 2000, 'allowed')
+             torch.save(allowed_residual, allowed_residual_file)
+         else:
+             print('Loading allowed residual data')
+             allowed_residual = torch.load(allowed_residual_file)
+
+         print('Done')
+
+         print('Calculating mean allowed residual')
+         for i in range(len(allowed_residual)):
+             allowed_residual_mean.append(allowed_residual[i].mean(dim = 0))
+         print('Done')
+         torch.save(allowed_residual_mean, allowed_residual_mean_file)
+     else:
+         allowed_residual_mean = torch.load(allowed_residual_mean_file)
+
+     if model._suppress_dir is None:
+         model._suppress_dir = []
+
+     # Multi-pass refinement: estimate a refusal direction per layer, apply it through the
+     # wrapper, then re-measure on a fresh sample of harmful prompts until few refusals remain.
+     for o in range(6):
+         print('Iteration', o)
+
+         for i in range(len(refused_residual)):
+             refusal_dir = refused_residual[i].mean(dim = 0) - allowed_residual_mean[i]
+             refusal_dir = refusal_dir / refusal_dir.norm() if refusal_dir.norm() > 0.0001 else torch.zeros_like(refusal_dir)
+             if len(model._suppress_dir) > i:
+                 model._suppress_dir[i] = (model._suppress_dir[i] + refusal_dir) / 2
+             else:
+                 model._suppress_dir.append(refusal_dir)
+
+         refused_residual = get_residual(random.sample(harmful_prompts, 2000), 4, True, 50, 'refused')
+
+         if not refused_residual or refused_residual[0].shape[0] < 30:
+             break
+
+
+ save_file({f'_suppress_dir_{layer}': tensor for layer, tensor in enumerate(model._suppress_dir)}, suppress_dir_file)
+
+ torch.cuda.synchronize()
+
test_inference.py ADDED
@@ -0,0 +1,542 @@
1
+
2
+ from exllamav2 import(
3
+ ExLlamaV2,
4
+ ExLlamaV2Config,
5
+ ExLlamaV2Cache,
6
+ ExLlamaV2Cache_8bit,
7
+ ExLlamaV2Cache_Q4,
8
+ ExLlamaV2Tokenizer,
9
+ model_init,
10
+ )
11
+
12
+ from exllamav2.generator import (
13
+ ExLlamaV2BaseGenerator,
14
+ ExLlamaV2Sampler
15
+ )
16
+
17
+ from exllamav2.attn import ExLlamaV2Attention
18
+ from exllamav2.mlp import ExLlamaV2MLP
19
+ from exllamav2.moe_mlp import ExLlamaV2MoEMLP
20
+ from exllamav2.parallel_decoder import ExLlamaV2ParallelDecoder
21
+
22
+ import argparse, os, math, time
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from conversion.tokenize import get_tokens
26
+ from conversion.quantize import list_live_tensors
27
+ import gc
28
+
29
+ # from exllamav2.mlp import set_catch
30
+
31
+ import sys
32
+ import json
33
+
34
+ torch.cuda._lazy_init()
35
+ torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 150)
36
+
37
+ # torch.backends.cuda.matmul.allow_tf32 = True
38
+ # torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
39
+ # torch.set_float32_matmul_precision("medium")
40
+
41
+ # (!!!) NOTE: These go on top of the engine arguments that can be found in `model_init.py` (!!!)
42
+ parser = argparse.ArgumentParser(description = "Test inference on ExLlamaV2 model")
43
+ parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity evaluation dataset (.parquet file)")
44
+ parser.add_argument("-er", "--eval_rows", type = int, default = 128, help = "Number of rows to apply from dataset")
45
+ parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
46
+ parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache")
47
+ parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit (FP8) cache")
48
+ parser.add_argument("-eq4", "--eval_token_q4", action = "store_true", help = "Evaluate perplexity on token-by-token inference using Q4 cache")
49
+ # parser.add_argument("-eb", "--eval_bos", action = "store_true", help = "Add BOS token to every row in perplexity test (required by Gemma and maybe other models.)")
50
+ parser.add_argument("-p", "--prompt", type = str, help = "Generate from prompt (basic sampling settings)")
51
+ parser.add_argument("-pnb", "--prompt_no_bos", action = "store_true", help = "Don't add BOS token to prompt")
52
+ parser.add_argument("-t", "--tokens", type = int, default = 128, help = "Max no. tokens")
53
+ parser.add_argument("-ps", "--prompt_speed", action = "store_true", help = "Test prompt processing (batch) speed over context length")
54
+ parser.add_argument("-s", "--speed", action = "store_true", help = "Test raw generation speed over context length")
55
+ parser.add_argument("-mix", "--mix_layers", type = str, help = "Load replacement layers from secondary model. Example: --mix_layers 1,6-7:/mnt/models/other_model")
56
+ parser.add_argument("-nwu", "--no_warmup", action = "store_true", help = "Skip warmup before testing model")
57
+ parser.add_argument("-sl", "--stream_layers", action = "store_true", help = "Load model layer by layer (perplexity evaluation only)")
58
+ parser.add_argument("-sp", "--standard_perplexity", choices = ["wiki2"], help = "Run standard (HF) perplexity test, stride 512 (experimental)")
59
+ parser.add_argument("-rr", "--rank_reduce", type = str, help = "Rank-reduction for MLP layers of model, in reverse order (for experimentation)")
60
+ parser.add_argument("-mol", "--max_output_len", type = int, help = "Set max output chunk size (incompatible with ppl tests)")
61
+
62
+ # Initialize model and tokenizer
63
+
64
+ model_init.add_args(parser)
65
+ args = parser.parse_args()
66
+
67
+ # Check conflicting settings
68
+
69
+ if args.stream_layers:
70
+ if args.eval_token or args.eval_token_8bit or args.eval_token_q4:
71
+ print(" ## Can't test token ppl while streaming layers")
72
+ sys.exit()
73
+ if args.prompt:
74
+ print(" ## Can't generate while streaming layers")
75
+ sys.exit()
76
+ if args.speed or args.prompt_speed:
77
+ print(" ## Can't test speed while streaming layers")
78
+ sys.exit()
79
+ if args.gpu_split:
80
+ print(" ## Can only use one GPU when streaming layers")
81
+ sys.exit()
82
+ if args.eval_dataset:
83
+ if args.length and args.eval_length != args.length:
84
+ print(" !! Overriding model context length to match eval row length")
85
+ args.length = args.eval_length
86
+
87
+ # Init
88
+
89
+ model_init.check_args(args)
90
+ model_init.print_options(args)
91
+ model, tokenizer = model_init.init(args,
92
+ allow_auto_split = True,
93
+ skip_load = args.stream_layers,
94
+ benchmark = True,
95
+ max_output_len = args.max_output_len)
96
+ cache = None
97
+
98
+ from exl2_wrapper import ExLlamaV2ModuleWrapper
99
+ ExLlamaV2ModuleWrapper.wrap(model)
100
+
101
+ # Auto split
102
+
103
+ if not model.loaded and not args.stream_layers:
104
+
105
+ if args.mix_layers:
106
+ print(" !! Warning, auto split does not account for VRAM requirement of replacement layers")
107
+
108
+ print(" -- Loading model...")
109
+ cache = ExLlamaV2Cache(model, lazy = True)
110
+ t = time.time()
111
+ model.load_autosplit(cache)
112
+ t = time.time() - t
113
+ print(f" -- Loaded model in {t:.4f} seconds")
114
+
115
+ if args.stream_layers:
116
+
117
+ stream_batch_size = 2
118
+ model.config.max_batch_size = stream_batch_size
119
+ model.load(lazy = True)
120
+
121
+ # Rank reduction
122
+
123
+ if args.rank_reduce:
124
+
125
+ if args.stream_layers:
126
+ print(" ## --rank_reduce can not be combined with --stream_layers")
127
+ sys.exit()
128
+
129
+ rr = args.rank_reduce.split(",")
130
+ idx = len(model.modules) - 1
131
+ for r in rr:
132
+ k = float(r)
133
+
134
+ while True:
135
+ idx -= 1
136
+ module = model.modules[idx]
137
+ if isinstance(module, ExLlamaV2ParallelDecoder): break
138
+ if isinstance(module, ExLlamaV2MLP): break
139
+ if isinstance(module, ExLlamaV2MoEMLP): break
140
+ if idx < 0:
141
+ print(" ## Not enough layers")
142
+ sys.exit()
143
+
144
+ print(f" -- Reducing {module.key} ({module.name}) to {k * 100:.2f}%")
145
+ module.rank_reduce(k)
146
+
147
+ # Replacement
148
+
149
+ if args.mix_layers:
150
+ intervals_, extra_dir = args.mix_layers.split(":")
151
+
152
+ print(f" -- Loading replacement layers from: {extra_dir}")
153
+
154
+ extra_config = ExLlamaV2Config()
155
+ extra_config.model_dir = extra_dir
156
+ extra_config.prepare()
157
+ intervals = intervals_.split(",")
158
+ for interval in intervals:
159
+ ab = interval.split("-")
160
+ a, b = int(ab[0]), int(ab[-1])
161
+ for idx in range(a, b + 1):
162
+ print(f" -- Layer {idx}...")
163
+ layerkey = "model.layers." + str(idx) + "."
164
+ remove = [k for k in model.config.tensor_file_map.keys() if k.startswith(layerkey)]
165
+ replace = [k for k in extra_config.tensor_file_map.keys() if k.startswith(layerkey)]
166
+ # reload = [k for k in model.modules_dict.keys() if k.startswith(layerkey)]
167
+ for k in remove: del model.config.tensor_file_map[k]
168
+ for k in replace: model.config.tensor_file_map[k] = extra_config.tensor_file_map[k]
169
+ # for k in reload:
170
+ # model.modules_dict[k].unload()
171
+ # model.modules_dict[k].load()
172
+ if not args.stream_layers:
173
+ model.modules[idx * 2 + 1].reload()
174
+ model.modules[idx * 2 + 2].reload()
175
+
176
+ # Test generation
177
+
178
+ if args.prompt:
179
+
180
+ with torch.inference_mode():
181
+
182
+ if cache is None:
183
+ cache = ExLlamaV2Cache(model)
184
+
185
+ ids = tokenizer.encode(args.prompt)
186
+ tokens_prompt = ids.shape[-1]
187
+
188
+ print(f" -- Warmup...")
189
+
190
+ generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
191
+ if not args.no_warmup: generator.warmup()
192
+
193
+ print(f" -- Generating...")
194
+ print()
195
+
196
+ settings = ExLlamaV2Sampler.Settings()
197
+ settings.temperature = 0.75
198
+ settings.top_k = 100
199
+ settings.top_p = 0.75
200
+ settings.token_repetition_penalty = 1.05
201
+ settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])
202
+
203
+ time_begin = time.time()
204
+
205
+ output = generator.generate_simple(args.prompt, settings, args.tokens, token_healing = True, add_bos = not args.prompt_no_bos)
206
+
207
+ torch.cuda.synchronize()
208
+ time_prompt = time.time()
209
+
210
+ time_end = time.time()
211
+
212
+ print(output)
213
+ print()
214
+
215
+ total_gen = time_end - time_begin
216
+ print(f" -- Response generated in {total_gen:.2f} seconds, {args.tokens} tokens, {args.tokens / total_gen:.2f} tokens/second (includes prompt eval.)")
217
+
218
+
219
+ # Test perplexity
220
+
221
+ if args.eval_dataset or args.standard_perplexity:
222
+
223
+ with torch.inference_mode():
224
+
225
+ print(f" -- Running perplexity test")
226
+
227
+ if args.standard_perplexity:
228
+
229
+ eval_length = args.eval_length
230
+ if args.eval_dataset:
231
+ print(f" !! Note, overriding specified --eval_dataset with {args.standard_perplexity}")
232
+
233
+ from datasets import load_dataset
234
+
235
+ if args.standard_perplexity == "wiki2":
236
+ ds = "wikitext"
237
+ part = "wikitext-2-raw-v1"
238
+ split = "test"
239
+ # if args.standard_perplexity == "c4":
240
+ # ds = "allenai/c4"
241
+ # part = "allenai--c4"
242
+ # split = "train"
243
+
244
+ print(f" -- Loading dataset {ds}, {part}, {split}...")
245
+ test = load_dataset(ds, part, split = split)
246
+
247
+ print(f" -- Tokenizing samples...")
248
+ text = "\n\n".join(test["text"])
249
+ eval_tokens = tokenizer.encode(text)
250
+
251
+ stride = 512
252
+ seqs = []
253
+ eval_len = []
254
+ a = 0
255
+ while True:
256
+ b = a + model.config.max_seq_len
257
+ if b > eval_tokens.shape[-1]: break
258
+ seqs.append(eval_tokens[:, a:b])
259
+ eval_len.append(b if a == 0 else stride)
260
+ a += stride
261
+
262
+ eval_tokens = torch.cat(seqs, dim = 0)
263
+
264
+ else:
265
+
266
+ eval_dataset = args.eval_dataset
267
+ eval_rows = args.eval_rows
268
+ eval_length = args.eval_length
269
+
270
+ print(f" -- Dataset: {eval_dataset}")
271
+ print(f" -- Tokenizing eval data, {eval_rows} rows x {eval_length} tokens...")
272
+
273
+ eval_tokens = get_tokens(eval_rows, eval_length, eval_dataset, tokenizer)
274
+ eval_len = [eval_tokens.shape[1]] * eval_tokens.shape[0]
275
+
276
+ # if args.eval_bos:
277
+ if model.config.arch.requires_bos:
278
+ boss = torch.full((eval_tokens.shape[0], 1), tokenizer.bos_token_id, dtype = torch.long)
279
+ eval_tokens = torch.cat((boss, eval_tokens[:, :-1]), dim = 1)
280
+
281
+ logprob_sum = 0.0
282
+ logprob_count = 0
283
+
284
+ def ppl(input_ids__, logits__, lengths__):
285
+
286
+ logprob_sum_ = 0.0
287
+ logprob_count_ = 0
288
+
289
+ assert logits__.shape[0] == input_ids__.shape[0]
290
+ ll = logits__.shape[1]
291
+
292
+ for bi in range(logits__.shape[0]):
293
+ cl = max(ll - lengths__[bi], 0)
294
+ logits_ = logits__[bi:bi+1, cl:, :]
295
+ input_ids_ = input_ids__[bi:bi+1, cl:]
296
+
297
+ chunksize = logits_.shape[1] * 4000 // logits_.shape[2] + 1
298
+ b_ = 0
299
+ while b_ < logits_.shape[1]:
300
+ a_ = b_
301
+ b_ = min(b_ + chunksize, logits_.shape[1])
302
+
303
+ logits_f = logits_[:, a_:b_, :].float() + 1e-10
304
+ target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)
305
+
306
+ log_probs = F.log_softmax(logits_f, dim=-1)
307
+ token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
308
+ logprob_sum_ += token_log_probs.sum().item()
309
+ logprob_count_ += target_ids.numel()
310
+
311
+ return logprob_sum_, logprob_count_
312
+
313
+ if args.stream_layers:
314
+
315
+ print(f" -- Inference (streamed)", end = "")
316
+ sys.stdout.flush()
317
+
318
+ batch_size, seq_len = eval_tokens.shape
319
+ attn_params = ExLlamaV2Attention.Params(stream_batch_size, seq_len, 0, None, None)
320
+ # attn_mask = model.build_attn_mask(stream_batch_size, seq_len, 0, None, "cuda:0")
321
+
322
+ for idx, module in enumerate(model.modules):
323
+ module.set_device_idx(-1 if idx == 0 else 0)
324
+
325
+ model.modules[0].load()
326
+ hidden_state = model.modules[0].forward(eval_tokens)
327
+ model.modules[0].unload()
328
+
329
+ for idx, module in enumerate(model.modules):
330
+ if idx == 0: continue
331
+
332
+ print(".", end = "")
333
+ sys.stdout.flush()
334
+ module.load()
335
+
336
+ b = 0
337
+ while b < eval_tokens.shape[0]:
338
+ a = b
339
+ b = min(b + stream_batch_size, eval_tokens.shape[0])
340
+ x = hidden_state[a:b, :, :].to("cuda:0")
341
+ x = module.forward(x, cache = None, attn_params = attn_params, past_len = 0, loras = None)
342
+
343
+ if idx < len(model.modules) - 1:
344
+ hidden_state[a:b, :, :] = x.to("cpu")
345
+
346
+ else:
347
+ input_ids = eval_tokens[a:b, :]
348
+ logits = x[:, :-1, :]
349
+
350
+ # if model.config.logit_scale != 1:
351
+ # logits.mul_(model.config.logit_scale)
352
+
353
+ logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[a:b])
354
+ logprob_sum += logprob_sum__
355
+ logprob_count += logprob_count__
356
+
357
+ module.unload()
358
+
359
+ print()
360
+
361
+ else:
362
+
363
+ print(f" -- Inference", end = "")
364
+ sys.stdout.flush()
365
+
366
+ if cache is None:
367
+ cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if eval_length > model.config.max_input_len else None
368
+
369
+ for i in range(eval_tokens.shape[0]):
370
+
371
+ if i % 10 == 0: print(".", end = "")
372
+ sys.stdout.flush()
373
+
374
+ input_ids = eval_tokens[i:i+1, :]
375
+
376
+ input_ids = input_ids[:, :]
377
+ if cache is not None: cache.current_seq_len = 0
378
+ logits = model.forward(input_ids, cache)
379
+ logits = logits[:, :-1, :]
380
+
381
+ logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1])
382
+ logprob_sum += logprob_sum__
383
+ logprob_count += logprob_count__
384
+
385
+ print()
386
+
387
+ mean_log_prob = logprob_sum / logprob_count
388
+ perplexity = math.exp(-mean_log_prob)
389
+ print(f" -- Evaluation perplexity: {perplexity:.4f}")
390
+
391
+ def test_ppl_token():
392
+ global logprob_sum, logprob_count, i, input_ids
393
+ global logits, target_ids, log_probs, token_log_probs
394
+ global mean_log_prob, perplexity
395
+
396
+ # set_catch("model.layers.3")
397
+
398
+ logprob_sum = 0
399
+ logprob_count = 0
400
+
401
+ for i in range(eval_tokens.shape[0]):
402
+
403
+ cache.current_seq_len = 0
404
+
405
+ for j in range(eval_tokens.shape[1] - 1):
406
+ if j % 256 == 0: print(".", end = "")
407
+ sys.stdout.flush()
408
+
409
+ input_ids = eval_tokens[i:i + 1, j:j + 1]
410
+ logits = model.forward(input_ids, cache)
411
+ logits = logits.float() + 1e-10
412
+
413
+ log_probs = F.log_softmax(logits, dim = -1)
414
+ logprob_sum += log_probs[0, 0, eval_tokens[i, j+1]]
415
+ logprob_count += 1
416
+
417
+ # mean_log_prob = logprob_sum / logprob_count
418
+ # perplexity = math.exp(-mean_log_prob)
419
+ # print(f" -- Token {j}: {perplexity:.4f}")
420
+
421
+ print()
422
+
423
+ mean_log_prob = logprob_sum / logprob_count
424
+ perplexity = math.exp(-mean_log_prob)
425
+ print(f" -- Evaluation perplexity: {perplexity:.4f}")
426
+
427
+ if args.eval_token:
428
+ if args.standard_perplexity:
429
+ print(f" !! Note, can't evaluate token perplexity on standard test")
430
+ else:
431
+ print(f" -- Inference (token)", end = "")
432
+ sys.stdout.flush()
433
+ cache = ExLlamaV2Cache(model, max_seq_len = eval_length)
434
+ test_ppl_token()
435
+
436
+ if args.eval_token_8bit:
437
+ if args.standard_perplexity:
438
+ print(f" !! Note, can't evaluate token perplexity on standard test")
439
+ else:
440
+ print(f" -- Inference (token, 8-bit cache)", end = "")
441
+ sys.stdout.flush()
442
+ cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length)
443
+ test_ppl_token()
444
+
445
+ if args.eval_token_q4:
446
+ if args.standard_perplexity:
447
+ print(f" !! Note, can't evaluate token perplexity on standard test")
448
+ else:
449
+ print(f" -- Inference (token, Q4 cache)", end = "")
450
+ sys.stdout.flush()
451
+ cache = ExLlamaV2Cache_Q4(model, max_seq_len = eval_length)
452
+ test_ppl_token()
453
+
454
+
455
+ # Test prompt speed
456
+
457
+ if args.prompt_speed:
458
+
459
+ with torch.inference_mode():
460
+
461
+ if cache is None:
462
+ cache = ExLlamaV2Cache(model)
463
+
464
+ ids = torch.randint(0, model.config.vocab_size - 1, (1, model.config.max_seq_len))
465
+
466
+ print(f" -- Warmup...")
467
+
468
+ if not args.no_warmup:
469
+ model.forward(ids[:, -1:])
470
+
471
+ print(f" -- Measuring prompt speed...")
472
+
473
+ torch.cuda.synchronize()
474
+
475
+ current_len = 128
476
+ step = 128
477
+ prompt_iters = 3
478
+ while True:
479
+
480
+ total_time = 0
481
+ for i in range(prompt_iters):
482
+
483
+ torch.cuda.synchronize()
484
+ time_begin = time.time()
485
+
486
+ cache.current_seq_len = 0
487
+ model.forward(ids[:, :current_len], cache, preprocess_only = True)
488
+
489
+ torch.cuda.synchronize()
490
+ time_end = time.time()
491
+ total_time += time_end - time_begin
492
+
493
+ tps = current_len / (total_time / prompt_iters)
494
+
495
+ print(f" ** Length {current_len:>5} tokens: {tps:>11.4f} t/s")
496
+
497
+ if current_len >= 1024: step = 1024
498
+ if current_len >= 4096: step = 4096
499
+ if current_len >= 16384: step = 8192
500
+
501
+ current_len_ = current_len
502
+ current_len = min(current_len + step, model.config.max_seq_len)
503
+ if current_len == current_len_: break
504
+
505
+
506
+ # Test token speed
507
+
508
+ if args.speed:
509
+
510
+ with torch.inference_mode():
511
+
512
+ if cache is None:
513
+ cache = ExLlamaV2Cache(model)
514
+ cache.current_seq_len = 0
515
+
516
+ print(f" -- Measuring token speed...")
517
+ ids = tokenizer.encode("X")
518
+ model.forward(ids[:, :])
519
+
520
+ current_idx = ids.shape[-1]
521
+ next_stop = 128
522
+
523
+ while True:
524
+
525
+ time_begin = time.time()
526
+
527
+ tokens = next_stop - current_idx
528
+ for i in range(tokens):
529
+
530
+ logits = model.forward(ids[:, -1:], cache)
531
+ sample = torch.argmax(logits[0, -1]).cpu().unsqueeze(0).unsqueeze(0)
532
+ ids = torch.cat((ids, sample), dim=-1)
533
+
534
+ time_end = time.time()
535
+ tps = tokens / (time_end - time_begin)
536
+
537
+ print(f" ** Position {current_idx:>5} + {tokens:>3} tokens: {tps:>9.4f} t/s")
538
+
539
+ current_idx = next_stop
540
+ next_stop = min(next_stop + 128, model.config.max_seq_len)
541
+ if next_stop == current_idx: break
542
+