Initial GPTQ model commit
Browse files- h2oai_pipeline.py +929 -0
h2oai_pipeline.py
ADDED
@@ -0,0 +1,929 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from transformers import TextGenerationPipeline
|
4 |
+
from transformers.pipelines.text_generation import ReturnType
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
class H2OTextGenerationPipeline(TextGenerationPipeline):
|
11 |
+
def __init__(self, *args, debug=False, chat=False, stream_output=False,
|
12 |
+
sanitize_bot_response=False,
|
13 |
+
use_prompter=True, prompter=None,
|
14 |
+
prompt_type=None, prompt_dict=None,
|
15 |
+
max_input_tokens=2048 - 256, **kwargs):
|
16 |
+
"""
|
17 |
+
HF-like pipeline, but handle instruction prompting and stopping (for some models)
|
18 |
+
:param args:
|
19 |
+
:param debug:
|
20 |
+
:param chat:
|
21 |
+
:param stream_output:
|
22 |
+
:param sanitize_bot_response:
|
23 |
+
:param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
|
24 |
+
:param prompter: prompter, can pass if have already
|
25 |
+
:param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in
|
26 |
+
If use_prompter, then will make prompter and use it.
|
27 |
+
:param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
|
28 |
+
:param max_input_tokens:
|
29 |
+
:param kwargs:
|
30 |
+
"""
|
31 |
+
super().__init__(*args, **kwargs)
|
32 |
+
self.prompt_text = None
|
33 |
+
self.use_prompter = use_prompter
|
34 |
+
self.prompt_type = prompt_type
|
35 |
+
self.prompt_dict = prompt_dict
|
36 |
+
self.prompter = prompter
|
37 |
+
if self.use_prompter:
|
38 |
+
if self.prompter is not None:
|
39 |
+
assert self.prompter.prompt_type is not None
|
40 |
+
else:
|
41 |
+
self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug, chat=chat,
|
42 |
+
stream_output=stream_output)
|
43 |
+
self.human = self.prompter.humanstr
|
44 |
+
self.bot = self.prompter.botstr
|
45 |
+
self.can_stop = True
|
46 |
+
else:
|
47 |
+
self.prompter = None
|
48 |
+
self.human = None
|
49 |
+
self.bot = None
|
50 |
+
self.can_stop = False
|
51 |
+
self.sanitize_bot_response = sanitize_bot_response
|
52 |
+
self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
|
56 |
+
verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))
|
57 |
+
|
58 |
+
if hasattr(tokenizer, 'model_max_length'):
|
59 |
+
# model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
|
60 |
+
model_max_length = tokenizer.model_max_length
|
61 |
+
if max_prompt_length is not None:
|
62 |
+
model_max_length = min(model_max_length, max_prompt_length)
|
63 |
+
# cut at some upper likely limit to avoid excessive tokenization etc
|
64 |
+
# upper bound of 10 chars/token, e.g. special chars sometimes are long
|
65 |
+
if len(prompt_text) > model_max_length * 10:
|
66 |
+
len0 = len(prompt_text)
|
67 |
+
prompt_text = prompt_text[-model_max_length * 10:]
|
68 |
+
if verbose:
|
69 |
+
print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)
|
70 |
+
else:
|
71 |
+
# unknown
|
72 |
+
model_max_length = None
|
73 |
+
|
74 |
+
if model_max_length is not None:
|
75 |
+
num_prompt_tokens = None
|
76 |
+
# can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
|
77 |
+
# For https://github.com/h2oai/h2ogpt/issues/192
|
78 |
+
for trial in range(0, 3):
|
79 |
+
prompt_tokens = tokenizer(prompt_text)['input_ids']
|
80 |
+
num_prompt_tokens = len(prompt_tokens)
|
81 |
+
if num_prompt_tokens > model_max_length:
|
82 |
+
# conservative by using int()
|
83 |
+
chars_per_token = int(len(prompt_text) / num_prompt_tokens)
|
84 |
+
# keep tail, where question is if using langchain
|
85 |
+
prompt_text = prompt_text[-model_max_length * chars_per_token:]
|
86 |
+
if verbose:
|
87 |
+
print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
|
88 |
+
num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
|
89 |
+
else:
|
90 |
+
if verbose:
|
91 |
+
print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
|
92 |
+
break
|
93 |
+
|
94 |
+
# Why Below False: don't limit max_new_tokens more, just rely upon stopping to reach limit of model
|
95 |
+
if False:
|
96 |
+
# if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
|
97 |
+
#
|
98 |
+
assert num_prompt_tokens is not None
|
99 |
+
if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
|
100 |
+
# then give room for prompt
|
101 |
+
fudge = 20
|
102 |
+
else:
|
103 |
+
fudge = 0
|
104 |
+
max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'],
|
105 |
+
model_max_length - (num_prompt_tokens + fudge)))
|
106 |
+
if max_new_tokens < generate_kwargs['max_new_tokens']:
|
107 |
+
if verbose:
|
108 |
+
print("Reduced max_new_tokens from %s -> %s" % (
|
109 |
+
generate_kwargs['max_new_tokens'], max_new_tokens))
|
110 |
+
generate_kwargs['max_new_tokens'] = max_new_tokens
|
111 |
+
return prompt_text
|
112 |
+
|
113 |
+
def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
|
114 |
+
prompt_text = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
|
115 |
+
|
116 |
+
data_point = dict(context='', instruction=prompt_text, input='')
|
117 |
+
if self.prompter is not None:
|
118 |
+
prompt_text = self.prompter.generate_prompt(data_point)
|
119 |
+
self.prompt_text = prompt_text
|
120 |
+
if handle_long_generation is None:
|
121 |
+
# forces truncation of inputs to avoid critical failure
|
122 |
+
handle_long_generation = None # disable with new approaches
|
123 |
+
return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
|
124 |
+
**generate_kwargs)
|
125 |
+
|
126 |
+
def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
|
127 |
+
records = super().postprocess(model_outputs, return_type=return_type,
|
128 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces)
|
129 |
+
for rec in records:
|
130 |
+
if self.use_prompter:
|
131 |
+
outputs = rec['generated_text']
|
132 |
+
outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
|
133 |
+
sanitize_bot_response=self.sanitize_bot_response)
|
134 |
+
elif self.bot and self.human:
|
135 |
+
outputs = rec['generated_text'].split(self.bot)[1].strip().split(self.human)[0].strip()
|
136 |
+
else:
|
137 |
+
outputs = rec['generated_text']
|
138 |
+
rec['generated_text'] = outputs
|
139 |
+
return records
|
140 |
+
|
141 |
+
def _forward(self, model_inputs, **generate_kwargs):
|
142 |
+
if self.can_stop:
|
143 |
+
stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
|
144 |
+
self.tokenizer, self.device,
|
145 |
+
human=self.human, bot=self.bot,
|
146 |
+
model_max_length=self.tokenizer.model_max_length)
|
147 |
+
generate_kwargs['stopping_criteria'] = stopping_criteria
|
148 |
+
# return super()._forward(model_inputs, **generate_kwargs)
|
149 |
+
return self.__forward(model_inputs, **generate_kwargs)
|
150 |
+
|
151 |
+
# FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
|
152 |
+
# FIXME: https://github.com/h2oai/h2ogpt/issues/172
|
153 |
+
def __forward(self, model_inputs, **generate_kwargs):
|
154 |
+
input_ids = model_inputs["input_ids"]
|
155 |
+
attention_mask = model_inputs.get("attention_mask", None)
|
156 |
+
# Allow empty prompts
|
157 |
+
if input_ids.shape[1] == 0:
|
158 |
+
input_ids = None
|
159 |
+
attention_mask = None
|
160 |
+
in_b = 1
|
161 |
+
else:
|
162 |
+
in_b = input_ids.shape[0]
|
163 |
+
prompt_text = model_inputs.pop("prompt_text")
|
164 |
+
|
165 |
+
## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
|
166 |
+
## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
|
167 |
+
# generate_kwargs = copy.deepcopy(generate_kwargs)
|
168 |
+
prefix_length = generate_kwargs.pop("prefix_length", 0)
|
169 |
+
if prefix_length > 0:
|
170 |
+
has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
|
171 |
+
"generation_config" in generate_kwargs
|
172 |
+
and generate_kwargs["generation_config"].max_new_tokens is not None
|
173 |
+
)
|
174 |
+
if not has_max_new_tokens:
|
175 |
+
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
|
176 |
+
generate_kwargs["max_length"] += prefix_length
|
177 |
+
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
|
178 |
+
"generation_config" in generate_kwargs
|
179 |
+
and generate_kwargs["generation_config"].min_new_tokens is not None
|
180 |
+
)
|
181 |
+
if not has_min_new_tokens and "min_length" in generate_kwargs:
|
182 |
+
generate_kwargs["min_length"] += prefix_length
|
183 |
+
|
184 |
+
# BS x SL
|
185 |
+
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
|
186 |
+
out_b = generated_sequence.shape[0]
|
187 |
+
if self.framework == "pt":
|
188 |
+
generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
|
189 |
+
elif self.framework == "tf":
|
190 |
+
from transformers import is_tf_available
|
191 |
+
if is_tf_available():
|
192 |
+
import tensorflow as tf
|
193 |
+
generated_sequence = tf.reshape(generated_sequence,
|
194 |
+
(in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
195 |
+
else:
|
196 |
+
raise ValueError("TF not avaialble.")
|
197 |
+
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
198 |
+
import torch
|
199 |
+
from transformers import StoppingCriteria, StoppingCriteriaList
|
200 |
+
|
201 |
+
|
202 |
+
|
203 |
+
class StoppingCriteriaSub(StoppingCriteria):
|
204 |
+
|
205 |
+
def __init__(self, stops=[], encounters=[], device="cuda", model_max_length=None):
|
206 |
+
super().__init__()
|
207 |
+
assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
|
208 |
+
self.encounters = encounters
|
209 |
+
self.stops = [stop.to(device) for stop in stops]
|
210 |
+
self.num_stops = [0] * len(stops)
|
211 |
+
self.model_max_length = model_max_length
|
212 |
+
|
213 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
214 |
+
for stopi, stop in enumerate(self.stops):
|
215 |
+
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
216 |
+
self.num_stops[stopi] += 1
|
217 |
+
if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
|
218 |
+
# print("Stopped", flush=True)
|
219 |
+
return True
|
220 |
+
if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length:
|
221 |
+
# critical limit
|
222 |
+
return True
|
223 |
+
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
|
224 |
+
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
|
225 |
+
return False
|
226 |
+
|
227 |
+
|
228 |
+
def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:", model_max_length=None):
|
229 |
+
# FIXME: prompt_dict unused currently
|
230 |
+
if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
|
231 |
+
if prompt_type == PromptType.human_bot.name:
|
232 |
+
# encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
|
233 |
+
# stopping only starts once output is beyond prompt
|
234 |
+
# 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
|
235 |
+
stop_words = [human, bot, '\n' + human, '\n' + bot]
|
236 |
+
encounters = [1, 2]
|
237 |
+
elif prompt_type == PromptType.instruct_vicuna.name:
|
238 |
+
# even below is not enough, generic strings and many ways to encode
|
239 |
+
stop_words = [
|
240 |
+
'### Human:',
|
241 |
+
"""
|
242 |
+
### Human:""",
|
243 |
+
"""
|
244 |
+
### Human:
|
245 |
+
""",
|
246 |
+
'### Assistant:',
|
247 |
+
"""
|
248 |
+
### Assistant:""",
|
249 |
+
"""
|
250 |
+
### Assistant:
|
251 |
+
""",
|
252 |
+
]
|
253 |
+
encounters = [1, 2]
|
254 |
+
else:
|
255 |
+
# some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
|
256 |
+
stop_words = ['### End']
|
257 |
+
encounters = [1]
|
258 |
+
stop_words_ids = [
|
259 |
+
tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
|
260 |
+
# handle single token case
|
261 |
+
stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
|
262 |
+
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
|
263 |
+
# avoid padding in front of tokens
|
264 |
+
if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug
|
265 |
+
stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
|
266 |
+
# handle fake \n added
|
267 |
+
stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
|
268 |
+
# build stopper
|
269 |
+
stopping_criteria = StoppingCriteriaList(
|
270 |
+
[StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device,
|
271 |
+
model_max_length=model_max_length)])
|
272 |
+
else:
|
273 |
+
stopping_criteria = StoppingCriteriaList()
|
274 |
+
return stopping_criteria
|
275 |
+
from enum import Enum
|
276 |
+
|
277 |
+
|
278 |
+
class PromptType(Enum):
|
279 |
+
custom = -1
|
280 |
+
plain = 0
|
281 |
+
instruct = 1
|
282 |
+
quality = 2
|
283 |
+
human_bot = 3
|
284 |
+
dai_faq = 4
|
285 |
+
summarize = 5
|
286 |
+
simple_instruct = 6
|
287 |
+
instruct_vicuna = 7
|
288 |
+
instruct_with_end = 8
|
289 |
+
human_bot_orig = 9
|
290 |
+
prompt_answer = 10
|
291 |
+
open_assistant = 11
|
292 |
+
wizard_lm = 12
|
293 |
+
wizard_mega = 13
|
294 |
+
instruct_vicuna2 = 14
|
295 |
+
instruct_vicuna3 = 15
|
296 |
+
wizard2 = 16
|
297 |
+
wizard3 = 17
|
298 |
+
instruct_simple = 18
|
299 |
+
|
300 |
+
|
301 |
+
class DocumentChoices(Enum):
|
302 |
+
All_Relevant = 0
|
303 |
+
All_Relevant_Only_Sources = 1
|
304 |
+
Only_All_Sources = 2
|
305 |
+
Just_LLM = 3
|
306 |
+
|
307 |
+
|
308 |
+
class LangChainMode(Enum):
|
309 |
+
"""LangChain mode"""
|
310 |
+
|
311 |
+
DISABLED = "Disabled"
|
312 |
+
CHAT_LLM = "ChatLLM"
|
313 |
+
LLM = "LLM"
|
314 |
+
ALL = "All"
|
315 |
+
WIKI = "wiki"
|
316 |
+
WIKI_FULL = "wiki_full"
|
317 |
+
USER_DATA = "UserData"
|
318 |
+
MY_DATA = "MyData"
|
319 |
+
GITHUB_H2OGPT = "github h2oGPT"
|
320 |
+
H2O_DAI_DOCS = "DriverlessAI docs"
|
321 |
+
import ast
|
322 |
+
import time
|
323 |
+
from enums import PromptType # also supports imports from this file from other files
|
324 |
+
|
325 |
+
non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
|
326 |
+
|
327 |
+
prompt_type_to_model_name = {
|
328 |
+
'plain': [
|
329 |
+
'EleutherAI/gpt-j-6B',
|
330 |
+
'EleutherAI/pythia-6.9b',
|
331 |
+
'EleutherAI/pythia-12b',
|
332 |
+
'EleutherAI/pythia-12b-deduped',
|
333 |
+
'EleutherAI/gpt-neox-20b',
|
334 |
+
'openlm-research/open_llama_7b_700bt_preview',
|
335 |
+
'decapoda-research/llama-7b-hf',
|
336 |
+
'decapoda-research/llama-13b-hf',
|
337 |
+
'decapoda-research/llama-30b-hf',
|
338 |
+
'decapoda-research/llama-65b-hf',
|
339 |
+
'facebook/mbart-large-50-many-to-many-mmt',
|
340 |
+
'philschmid/bart-large-cnn-samsum',
|
341 |
+
'philschmid/flan-t5-base-samsum',
|
342 |
+
'gpt2',
|
343 |
+
'distilgpt2',
|
344 |
+
'mosaicml/mpt-7b-storywriter',
|
345 |
+
'mosaicml/mpt-7b-instruct', # internal code handles instruct
|
346 |
+
'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
|
347 |
+
'gptj', # internally handles prompting
|
348 |
+
'llama', # plain, or need to choose prompt_type for given TheBloke model
|
349 |
+
'gpt4all_llama', # internally handles prompting
|
350 |
+
],
|
351 |
+
'prompt_answer': [
|
352 |
+
'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
|
353 |
+
'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
|
354 |
+
'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
|
355 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
|
356 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
|
357 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
|
358 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
|
359 |
+
'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
|
360 |
+
'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
|
361 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b',
|
362 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
|
363 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1',
|
364 |
+
],
|
365 |
+
'instruct': [],
|
366 |
+
'instruct_with_end': ['databricks/dolly-v2-12b'],
|
367 |
+
'quality': [],
|
368 |
+
'human_bot': [
|
369 |
+
'h2oai/h2ogpt-oasst1-512-12b',
|
370 |
+
'h2oai/h2ogpt-oasst1-512-20b',
|
371 |
+
'h2oai/h2ogpt-oig-oasst1-256-6_9b',
|
372 |
+
'h2oai/h2ogpt-oig-oasst1-512-6_9b',
|
373 |
+
'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
|
374 |
+
'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
|
375 |
+
'h2oai/h2ogpt-research-oasst1-512-30b',
|
376 |
+
'h2oai/h2ogpt-oasst1-falcon-40b',
|
377 |
+
'h2oai/h2ogpt-oig-oasst1-falcon-40b',
|
378 |
+
],
|
379 |
+
'dai_faq': [],
|
380 |
+
'summarize': [],
|
381 |
+
'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
|
382 |
+
'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
|
383 |
+
'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
|
384 |
+
"open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
|
385 |
+
"wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
|
386 |
+
"wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
|
387 |
+
"instruct_simple": ['JosephusCheung/Guanaco'],
|
388 |
+
}
|
389 |
+
|
390 |
+
inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
391 |
+
inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
392 |
+
|
393 |
+
prompt_types_strings = []
|
394 |
+
for p in PromptType:
|
395 |
+
prompt_types_strings.extend([p.name])
|
396 |
+
|
397 |
+
prompt_types = []
|
398 |
+
for p in PromptType:
|
399 |
+
prompt_types.extend([p.name, p.value, str(p.value)])
|
400 |
+
|
401 |
+
|
402 |
+
def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=False):
|
403 |
+
prompt_dict_error = ''
|
404 |
+
if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict):
|
405 |
+
try:
|
406 |
+
prompt_dict = ast.literal_eval(prompt_dict)
|
407 |
+
except BaseException as e:
|
408 |
+
prompt_dict_error = str(e)
|
409 |
+
if prompt_dict_error:
|
410 |
+
return dict(), prompt_dict_error
|
411 |
+
|
412 |
+
if prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
|
413 |
+
PromptType.custom.name]:
|
414 |
+
promptA = prompt_dict.get('promptA', '')
|
415 |
+
promptB = prompt_dict('promptB', '')
|
416 |
+
PreInstruct = prompt_dict.get('PreInstruct', '')
|
417 |
+
PreInput = prompt_dict.get('PreInput', '')
|
418 |
+
PreResponse = prompt_dict.get('PreResponse', '')
|
419 |
+
terminate_response = prompt_dict.get('terminate_response', None)
|
420 |
+
chat_sep = prompt_dict.get('chat_sep', '\n')
|
421 |
+
humanstr = prompt_dict.get('humanstr', '')
|
422 |
+
botstr = prompt_dict.get('botstr', '')
|
423 |
+
elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
|
424 |
+
PromptType.plain.name]:
|
425 |
+
promptA = promptB = PreInstruct = PreInput = PreResponse = ''
|
426 |
+
terminate_response = []
|
427 |
+
chat_sep = ''
|
428 |
+
humanstr = ''
|
429 |
+
botstr = ''
|
430 |
+
elif prompt_type == 'simple_instruct':
|
431 |
+
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
432 |
+
terminate_response = []
|
433 |
+
chat_sep = '\n'
|
434 |
+
humanstr = ''
|
435 |
+
botstr = ''
|
436 |
+
elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
|
437 |
+
PromptType.instruct.name] + [PromptType.instruct_with_end.value,
|
438 |
+
str(PromptType.instruct_with_end.value),
|
439 |
+
PromptType.instruct_with_end.name]:
|
440 |
+
promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (
|
441 |
+
chat and reduced) else ''
|
442 |
+
promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
|
443 |
+
chat and reduced) else ''
|
444 |
+
|
445 |
+
PreInstruct = """
|
446 |
+
### Instruction:
|
447 |
+
"""
|
448 |
+
|
449 |
+
PreInput = """
|
450 |
+
### Input:
|
451 |
+
"""
|
452 |
+
|
453 |
+
PreResponse = """
|
454 |
+
### Response:
|
455 |
+
"""
|
456 |
+
if prompt_type in [PromptType.instruct_with_end.value, str(PromptType.instruct_with_end.value),
|
457 |
+
PromptType.instruct_with_end.name]:
|
458 |
+
terminate_response = ['### End']
|
459 |
+
else:
|
460 |
+
terminate_response = None
|
461 |
+
chat_sep = '\n'
|
462 |
+
humanstr = PreInstruct
|
463 |
+
botstr = PreResponse
|
464 |
+
elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
|
465 |
+
PromptType.quality.name]:
|
466 |
+
promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (
|
467 |
+
chat and reduced) else ''
|
468 |
+
promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (
|
469 |
+
chat and reduced) else ''
|
470 |
+
|
471 |
+
PreInstruct = """
|
472 |
+
### Instruction:
|
473 |
+
"""
|
474 |
+
|
475 |
+
PreInput = """
|
476 |
+
### Input:
|
477 |
+
"""
|
478 |
+
|
479 |
+
PreResponse = """
|
480 |
+
### Response:
|
481 |
+
"""
|
482 |
+
terminate_response = None
|
483 |
+
chat_sep = '\n'
|
484 |
+
humanstr = PreInstruct # first thing human says
|
485 |
+
botstr = PreResponse # first thing bot says
|
486 |
+
elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
|
487 |
+
PromptType.human_bot.name] + [PromptType.human_bot_orig.value,
|
488 |
+
str(PromptType.human_bot_orig.value),
|
489 |
+
PromptType.human_bot_orig.name]:
|
490 |
+
human = '<human>:'
|
491 |
+
bot = "<bot>:"
|
492 |
+
if reduced or context or prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
|
493 |
+
PromptType.human_bot.name]:
|
494 |
+
preprompt = ''
|
495 |
+
else:
|
496 |
+
cur_date = time.strftime('%Y-%m-%d')
|
497 |
+
cur_time = time.strftime('%H:%M:%S %p %Z')
|
498 |
+
|
499 |
+
PRE_PROMPT = """\
|
500 |
+
Current Date: {}
|
501 |
+
Current Time: {}
|
502 |
+
|
503 |
+
"""
|
504 |
+
preprompt = PRE_PROMPT.format(cur_date, cur_time)
|
505 |
+
start = human
|
506 |
+
promptB = promptA = '%s%s ' % (preprompt, start)
|
507 |
+
|
508 |
+
PreInstruct = ""
|
509 |
+
|
510 |
+
PreInput = None
|
511 |
+
|
512 |
+
if reduced:
|
513 |
+
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
514 |
+
PreResponse = bot + ' '
|
515 |
+
else:
|
516 |
+
# normally LLM adds space after this, because was how trained.
|
517 |
+
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
518 |
+
PreResponse = bot
|
519 |
+
|
520 |
+
terminate_response = [start, PreResponse]
|
521 |
+
chat_sep = '\n'
|
522 |
+
humanstr = human # tag before human talks
|
523 |
+
botstr = bot # tag before bot talks
|
524 |
+
elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
|
525 |
+
PromptType.dai_faq.name]:
|
526 |
+
promptA = ''
|
527 |
+
promptB = 'Answer the following Driverless AI question.\n'
|
528 |
+
|
529 |
+
PreInstruct = """
|
530 |
+
### Driverless AI frequently asked question:
|
531 |
+
"""
|
532 |
+
|
533 |
+
PreInput = None
|
534 |
+
|
535 |
+
PreResponse = """
|
536 |
+
### Driverless AI documentation answer:
|
537 |
+
"""
|
538 |
+
terminate_response = ['\n\n']
|
539 |
+
chat_sep = terminate_response
|
540 |
+
humanstr = PreInstruct
|
541 |
+
botstr = PreResponse
|
542 |
+
elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
|
543 |
+
PromptType.summarize.name]:
|
544 |
+
promptA = promptB = PreInput = ''
|
545 |
+
PreInstruct = '## Main Text\n\n'
|
546 |
+
PreResponse = '\n\n## Summary\n\n'
|
547 |
+
terminate_response = None
|
548 |
+
chat_sep = '\n'
|
549 |
+
humanstr = PreInstruct
|
550 |
+
botstr = PreResponse
|
551 |
+
elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
|
552 |
+
PromptType.instruct_vicuna.name]:
|
553 |
+
promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
|
554 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions." if not (
|
555 |
+
chat and reduced) else ''
|
556 |
+
|
557 |
+
PreInstruct = """
|
558 |
+
### Human:
|
559 |
+
"""
|
560 |
+
|
561 |
+
PreInput = None
|
562 |
+
|
563 |
+
PreResponse = """
|
564 |
+
### Assistant:
|
565 |
+
"""
|
566 |
+
terminate_response = [
|
567 |
+
'### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
568 |
+
chat_sep = '\n'
|
569 |
+
humanstr = PreInstruct
|
570 |
+
botstr = PreResponse
|
571 |
+
elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
|
572 |
+
PromptType.prompt_answer.name]:
|
573 |
+
preprompt = ''
|
574 |
+
prompt_tokens = "<|prompt|>"
|
575 |
+
answer_tokens = "<|answer|>"
|
576 |
+
start = prompt_tokens
|
577 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
578 |
+
PreInstruct = ""
|
579 |
+
PreInput = None
|
580 |
+
PreResponse = answer_tokens
|
581 |
+
eos = '<|endoftext|>' # neox eos
|
582 |
+
terminate_response = [start, PreResponse, eos]
|
583 |
+
chat_sep = eos
|
584 |
+
humanstr = prompt_tokens
|
585 |
+
botstr = answer_tokens
|
586 |
+
elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
|
587 |
+
PromptType.open_assistant.name]:
|
588 |
+
# From added_tokens.json
|
589 |
+
preprompt = ''
|
590 |
+
prompt_tokens = "<|prompter|>"
|
591 |
+
answer_tokens = "<|assistant|>"
|
592 |
+
start = prompt_tokens
|
593 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
594 |
+
PreInstruct = ""
|
595 |
+
PreInput = None
|
596 |
+
PreResponse = answer_tokens
|
597 |
+
pend = "<|prefix_end|>"
|
598 |
+
eos = "</s>"
|
599 |
+
terminate_response = [start, PreResponse, pend, eos]
|
600 |
+
chat_sep = eos
|
601 |
+
humanstr = prompt_tokens
|
602 |
+
botstr = answer_tokens
|
603 |
+
elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
|
604 |
+
PromptType.wizard_lm.name]:
|
605 |
+
# https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
|
606 |
+
preprompt = ''
|
607 |
+
start = ''
|
608 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
609 |
+
PreInstruct = ""
|
610 |
+
PreInput = None
|
611 |
+
PreResponse = "\n\n### Response\n"
|
612 |
+
eos = "</s>"
|
613 |
+
terminate_response = [PreResponse, eos]
|
614 |
+
chat_sep = eos
|
615 |
+
humanstr = promptA
|
616 |
+
botstr = PreResponse
|
617 |
+
elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
|
618 |
+
PromptType.wizard_mega.name]:
|
619 |
+
preprompt = ''
|
620 |
+
start = ''
|
621 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
622 |
+
PreInstruct = """
|
623 |
+
### Instruction:
|
624 |
+
"""
|
625 |
+
PreInput = None
|
626 |
+
PreResponse = """
|
627 |
+
### Assistant:
|
628 |
+
"""
|
629 |
+
terminate_response = [PreResponse]
|
630 |
+
chat_sep = '\n'
|
631 |
+
humanstr = PreInstruct
|
632 |
+
botstr = PreResponse
|
633 |
+
elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
|
634 |
+
PromptType.instruct_vicuna2.name]:
|
635 |
+
promptA = promptB = "" if not (
|
636 |
+
chat and reduced) else ''
|
637 |
+
|
638 |
+
PreInstruct = """
|
639 |
+
HUMAN:
|
640 |
+
"""
|
641 |
+
|
642 |
+
PreInput = None
|
643 |
+
|
644 |
+
PreResponse = """
|
645 |
+
ASSISTANT:
|
646 |
+
"""
|
647 |
+
terminate_response = [
|
648 |
+
'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
649 |
+
chat_sep = '\n'
|
650 |
+
humanstr = PreInstruct
|
651 |
+
botstr = PreResponse
|
652 |
+
elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
|
653 |
+
PromptType.instruct_vicuna3.name]:
|
654 |
+
promptA = promptB = "" if not (
|
655 |
+
chat and reduced) else ''
|
656 |
+
|
657 |
+
PreInstruct = """
|
658 |
+
### User:
|
659 |
+
"""
|
660 |
+
|
661 |
+
PreInput = None
|
662 |
+
|
663 |
+
PreResponse = """
|
664 |
+
### Assistant:
|
665 |
+
"""
|
666 |
+
terminate_response = [
|
667 |
+
'### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
668 |
+
chat_sep = '\n'
|
669 |
+
humanstr = PreInstruct
|
670 |
+
botstr = PreResponse
|
671 |
+
elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
|
672 |
+
PromptType.wizard2.name]:
|
673 |
+
# https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
|
674 |
+
preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
|
675 |
+
start = ''
|
676 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
677 |
+
PreInstruct = """
|
678 |
+
### Instruction:
|
679 |
+
"""
|
680 |
+
PreInput = None
|
681 |
+
PreResponse = """
|
682 |
+
### Response:
|
683 |
+
"""
|
684 |
+
terminate_response = [PreResponse]
|
685 |
+
chat_sep = '\n'
|
686 |
+
humanstr = PreInstruct
|
687 |
+
botstr = PreResponse
|
688 |
+
elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
|
689 |
+
PromptType.wizard3.name]:
|
690 |
+
# https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
|
691 |
+
preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
|
692 |
+
start = ''
|
693 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
694 |
+
PreInstruct = """USER: """
|
695 |
+
PreInput = None
|
696 |
+
PreResponse = """ASSISTANT: """
|
697 |
+
terminate_response = [PreResponse]
|
698 |
+
chat_sep = '\n'
|
699 |
+
humanstr = PreInstruct
|
700 |
+
botstr = PreResponse
|
701 |
+
|
702 |
+
elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value),
|
703 |
+
PromptType.instruct_simple.name]:
|
704 |
+
promptA = '' if not (chat and reduced) else ''
|
705 |
+
promptB = '' if not (chat and reduced) else ''
|
706 |
+
|
707 |
+
PreInstruct = """
|
708 |
+
### Instruction:
|
709 |
+
"""
|
710 |
+
|
711 |
+
PreInput = """
|
712 |
+
### Input:
|
713 |
+
"""
|
714 |
+
|
715 |
+
PreResponse = """
|
716 |
+
### Response:
|
717 |
+
"""
|
718 |
+
terminate_response = None
|
719 |
+
chat_sep = '\n'
|
720 |
+
humanstr = PreInstruct
|
721 |
+
botstr = PreResponse
|
722 |
+
else:
|
723 |
+
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
724 |
+
|
725 |
+
if return_dict:
|
726 |
+
return dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput,
|
727 |
+
PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
|
728 |
+
humanstr=humanstr, botstr=botstr), ''
|
729 |
+
else:
|
730 |
+
return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep, humanstr, botstr
|
731 |
+
|
732 |
+
|
733 |
+
def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
|
734 |
+
context = data_point.get('context')
|
735 |
+
if context is None:
|
736 |
+
context = ''
|
737 |
+
instruction = data_point.get('instruction')
|
738 |
+
input = data_point.get('input')
|
739 |
+
output = data_point.get('output')
|
740 |
+
prompt_type = data_point.get('prompt_type', prompt_type)
|
741 |
+
prompt_dict = data_point.get('prompt_dict', prompt_dict)
|
742 |
+
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
|
743 |
+
promptA, promptB, PreInstruct, PreInput, PreResponse, \
|
744 |
+
terminate_response, chat_sep, humanstr, botstr = get_prompt(prompt_type, prompt_dict, chat, context, reduced)
|
745 |
+
|
746 |
+
prompt = context if not reduced else ''
|
747 |
+
|
748 |
+
if input and promptA:
|
749 |
+
prompt += f"""{promptA}"""
|
750 |
+
elif promptB:
|
751 |
+
prompt += f"""{promptB}"""
|
752 |
+
|
753 |
+
if instruction and PreInstruct is not None and input and PreInput is not None:
|
754 |
+
prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
|
755 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
756 |
+
elif instruction and input and PreInstruct is None and PreInput is not None:
|
757 |
+
prompt += f"""{PreInput}{instruction}
|
758 |
+
{input}"""
|
759 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
760 |
+
elif input and instruction and PreInput is None and PreInstruct is not None:
|
761 |
+
prompt += f"""{PreInstruct}{instruction}
|
762 |
+
{input}"""
|
763 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
764 |
+
elif instruction and PreInstruct is not None:
|
765 |
+
prompt += f"""{PreInstruct}{instruction}"""
|
766 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
767 |
+
elif input and PreInput is not None:
|
768 |
+
prompt += f"""{PreInput}{input}"""
|
769 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
770 |
+
elif input and instruction and PreInput is not None:
|
771 |
+
prompt += f"""{PreInput}{instruction}{input}"""
|
772 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
773 |
+
elif input and instruction and PreInstruct is not None:
|
774 |
+
prompt += f"""{PreInstruct}{instruction}{input}"""
|
775 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
776 |
+
elif input and instruction:
|
777 |
+
# i.e. for simple_instruct
|
778 |
+
prompt += f"""{instruction}: {input}"""
|
779 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
780 |
+
elif input:
|
781 |
+
prompt += f"""{input}"""
|
782 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
783 |
+
elif instruction:
|
784 |
+
prompt += f"""{instruction}"""
|
785 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
786 |
+
|
787 |
+
if PreResponse is not None:
|
788 |
+
prompt += f"""{PreResponse}"""
|
789 |
+
pre_response = PreResponse # Don't use strip
|
790 |
+
else:
|
791 |
+
pre_response = ''
|
792 |
+
|
793 |
+
if output:
|
794 |
+
prompt += f"""{output}"""
|
795 |
+
|
796 |
+
return prompt, pre_response, terminate_response, chat_sep
|
797 |
+
|
798 |
+
|
799 |
+
def inject_chatsep(prompt_type, prompt, chat_sep=None):
|
800 |
+
if chat_sep:
|
801 |
+
# only add new line if structured prompt, while 'plain' is just generation of next tokens from input
|
802 |
+
prompt += chat_sep
|
803 |
+
return prompt
|
804 |
+
|
805 |
+
|
806 |
+
class Prompter(object):
|
807 |
+
def __init__(self, prompt_type, prompt_dict, debug=False, chat=False, stream_output=False, repeat_penalty=True,
|
808 |
+
allowed_repeat_line_length=10):
|
809 |
+
self.prompt_type = prompt_type
|
810 |
+
self.prompt_dict = prompt_dict
|
811 |
+
data_point = dict(instruction='', input='', output='')
|
812 |
+
_, self.pre_response, self.terminate_response, self.chat_sep = \
|
813 |
+
generate_prompt(data_point, self.prompt_type, self.prompt_dict, chat, False)
|
814 |
+
self.debug = debug
|
815 |
+
self.chat = chat
|
816 |
+
self.stream_output = stream_output
|
817 |
+
self.repeat_penalty = repeat_penalty
|
818 |
+
self.allowed_repeat_line_length = allowed_repeat_line_length
|
819 |
+
self.prompt = None
|
820 |
+
context = "" # not for chat context
|
821 |
+
reduced = False # not for chat context
|
822 |
+
self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
|
823 |
+
self.terminate_response, self.chat_sep, self.humanstr, self.botstr = \
|
824 |
+
get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced)
|
825 |
+
|
826 |
+
def generate_prompt(self, data_point):
|
827 |
+
reduced = False
|
828 |
+
prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced)
|
829 |
+
if self.debug:
|
830 |
+
print("prompt: %s" % prompt, flush=True)
|
831 |
+
self.prompt = prompt
|
832 |
+
return prompt
|
833 |
+
|
834 |
+
def get_response(self, outputs, prompt=None, sanitize_bot_response=False):
|
835 |
+
if isinstance(outputs, str):
|
836 |
+
outputs = [outputs]
|
837 |
+
if self.debug:
|
838 |
+
print("output:\n%s" % '\n\n'.join(outputs), flush=True)
|
839 |
+
if prompt is not None:
|
840 |
+
self.prompt = prompt
|
841 |
+
|
842 |
+
def clean_response(response):
|
843 |
+
meaningless_words = ['<pad>', '</s>', '<|endoftext|>']
|
844 |
+
for word in meaningless_words:
|
845 |
+
response = response.replace(word, "")
|
846 |
+
if sanitize_bot_response:
|
847 |
+
from better_profanity import profanity
|
848 |
+
response = profanity.censor(response)
|
849 |
+
response = response.strip("\n")
|
850 |
+
return response
|
851 |
+
|
852 |
+
def clean_repeats(response):
|
853 |
+
lines = response.split('\n')
|
854 |
+
new_lines = []
|
855 |
+
[new_lines.append(line) for line in lines if
|
856 |
+
line not in new_lines or len(line) < self.allowed_repeat_line_length]
|
857 |
+
if self.debug and len(lines) != len(new_lines):
|
858 |
+
print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
|
859 |
+
response = '\n'.join(new_lines)
|
860 |
+
return response
|
861 |
+
|
862 |
+
multi_output = len(outputs) > 1
|
863 |
+
|
864 |
+
for oi, output in enumerate(outputs):
|
865 |
+
if self.prompt_type in [PromptType.plain.value, str(PromptType.plain.value), PromptType.plain.name]:
|
866 |
+
output = clean_response(output)
|
867 |
+
elif prompt is None:
|
868 |
+
# then use most basic parsing like pipeline
|
869 |
+
if self.botstr in output:
|
870 |
+
if self.humanstr:
|
871 |
+
output = clean_response(output.split(self.botstr)[1].strip().split(self.humanstr)[0].strip())
|
872 |
+
else:
|
873 |
+
# i.e. use after bot but only up to next bot
|
874 |
+
output = clean_response(output.split(self.botstr)[1].strip().split(self.botstr)[0].strip())
|
875 |
+
else:
|
876 |
+
# output = clean_response(output.strip())
|
877 |
+
# assume just not printed yet
|
878 |
+
output = ""
|
879 |
+
else:
|
880 |
+
# find first instance of prereponse
|
881 |
+
# prompt sometimes has odd characters, that mutate length,
|
882 |
+
# so can't go by length alone
|
883 |
+
if self.pre_response:
|
884 |
+
outputi = output.find(prompt)
|
885 |
+
if outputi >= 0:
|
886 |
+
output = output[outputi + len(prompt):]
|
887 |
+
allow_terminate = True
|
888 |
+
else:
|
889 |
+
# subtraction is risky due to space offsets sometimes, so only do if necessary
|
890 |
+
output = output[len(prompt) - len(self.pre_response):]
|
891 |
+
# [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
|
892 |
+
if self.pre_response in output:
|
893 |
+
output = output.split(self.pre_response)[1]
|
894 |
+
allow_terminate = True
|
895 |
+
else:
|
896 |
+
if output:
|
897 |
+
print("Failure of parsing or not enough output yet: %s" % output, flush=True)
|
898 |
+
allow_terminate = False
|
899 |
+
else:
|
900 |
+
allow_terminate = True
|
901 |
+
output = output[len(prompt):]
|
902 |
+
# clean after subtract prompt out, so correct removal of pre_response
|
903 |
+
output = clean_response(output).strip()
|
904 |
+
if self.repeat_penalty:
|
905 |
+
output = clean_repeats(output).strip()
|
906 |
+
if self.terminate_response and allow_terminate:
|
907 |
+
finds = []
|
908 |
+
for term in self.terminate_response:
|
909 |
+
finds.append(output.find(term))
|
910 |
+
finds = [x for x in finds if x >= 0]
|
911 |
+
if len(finds) > 0:
|
912 |
+
termi = finds[0]
|
913 |
+
output = output[:termi].strip()
|
914 |
+
else:
|
915 |
+
output = output.strip()
|
916 |
+
else:
|
917 |
+
output = output.strip()
|
918 |
+
if multi_output:
|
919 |
+
# prefix with output counter
|
920 |
+
output = "\n=========== Output %d\n\n" % (1 + oi) + output
|
921 |
+
if oi > 0:
|
922 |
+
# post fix outputs with seperator
|
923 |
+
output += '\n'
|
924 |
+
outputs[oi] = output
|
925 |
+
# join all outputs, only one extra new line between outputs
|
926 |
+
output = '\n'.join(outputs)
|
927 |
+
if self.debug:
|
928 |
+
print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
|
929 |
+
return output
|