import os

import torch
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType, Chat

from stopping import get_stopping
from prompter import Prompter, convert_messages_and_extract_images, get_prompt


class H2OTextGenerationPipeline(TextGenerationPipeline):
    def __init__(self, *args, debug=False, chat=False, stream_output=False,
                 sanitize_bot_response=False,
                 use_prompter=True, prompter=None,
                 context='', iinput='',
                 chat_conversation=[],
                 user_prompt_for_fake_system_prompt=None,
                 prompt_type=None, prompt_dict=None,
                 max_input_tokens=2048 - 256,
                 base_model=None,
                 stop=None,
                 truncation_generation=None,
                 max_time=None,

                 image_file=None,
                 image_control=None,
                 images_num_max=None,
                 image_resolution=None,
                 image_format=None,
                 rotate_align_resize_image=None,
                 video_frame_period=None,
                 image_batch_image_prompt=None,
                 image_batch_final_prompt=None,
                 image_batch_stream=None,
                 visible_vision_models=None,
                 video_file=None,

                 verbose=False,
                 **kwargs):
        """
        HF-like pipeline, but handles instruction prompting and stopping (for some models).

        :param args:
        :param debug:
        :param chat:
        :param stream_output:
        :param sanitize_bot_response:
        :param use_prompter: Whether to use a prompter.  If prompt_type is passed, a prompter is created.
        :param prompter: prompter instance; can be passed if one already exists
        :param prompt_type: prompt_type, e.g. human_bot.  See the prompt_type-to-model mapping in prompter.py.
                            If use_prompter is set, a prompter is created from this and used.
        :param prompt_dict: dict from get_prompt(..., return_dict=True), for prompt_type='custom'
        :param max_input_tokens:
        :param kwargs:
        """
        super().__init__(*args, **kwargs)
        self.prompt_text = None
        self.use_prompter = use_prompter
        self.prompts = []
        self.prompt_type = prompt_type
        self.prompt_dict = prompt_dict
        self.prompter = prompter
        self.context = context
        self.iinput = iinput
        self.chat_conversation = chat_conversation
        self.user_prompt_for_fake_system_prompt = user_prompt_for_fake_system_prompt
        self.debug = debug
        if self.use_prompter:
            if self.prompter is not None:
                assert self.prompter.prompt_type is not None
            else:
                self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug,
                                         stream_output=stream_output, tokenizer=self.tokenizer,
                                         base_model=base_model)
            self.human = self.prompter.humanstr
            self.bot = self.prompter.botstr
            self.can_stop = True
        else:
            self.prompter = None
            self.human = None
            self.bot = None
            self.can_stop = False
        self.stop = stop
        self.sanitize_bot_response = sanitize_bot_response
        self.max_input_tokens = max_input_tokens
        self.base_model = base_model
        self.verbose = verbose
        self.truncation_generation = truncation_generation
        self.max_time = max_time

        self.image_file = image_file
        self.image_control = image_control
        self.images_num_max = images_num_max
        self.image_resolution = image_resolution
        self.image_format = image_format
        self.rotate_align_resize_image = rotate_align_resize_image
        self.video_frame_period = video_frame_period
        self.image_batch_image_prompt = image_batch_image_prompt
        self.image_batch_final_prompt = image_batch_final_prompt
        self.image_batch_stream = image_batch_stream
        self.visible_vision_models = visible_vision_models
        self.video_file = video_file
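
    # The pipeline is typically constructed with an already-built Prompter
    # (use_prompter=True, prompter=...); see the usage sketch at the end of
    # this file for a minimal standalone construction.
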
    @staticmethod
    def get_token_count(x, tokenizer):
        # tokenizer may be a raw tokenizer exposing .encode() or a callable
        # tokenizer/processor; normalize whatever it returns to a token count
        if hasattr(tokenizer, 'encode'):
            tokens = tokenizer.encode(x)
        else:
            tokens = tokenizer(x)
        if isinstance(tokens, dict) and 'input_ids' in tokens:
            tokens = tokens['input_ids']
        if isinstance(tokens, list):
            n_tokens = len(tokens)
        elif len(tokens.shape) == 2:
            n_tokens = tokens.shape[1]
        elif len(tokens.shape) == 1:
            n_tokens = tokens.shape[0]
        else:
            raise RuntimeError("Cannot handle tokens: %s" % tokens)
        return n_tokens
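    # Illustrative sketch (not executed here): get_token_count() works with either
    # tokenizer style, e.g. assuming any HF tokenizer such as 'gpt2':
    #
    #     from transformers import AutoTokenizer
    #     tok = AutoTokenizer.from_pretrained('gpt2')
    #     H2OTextGenerationPipeline.get_token_count("hello world", tok)  # roughly 2 tokens
    #
    # The shape branches cover tokenizers configured to return 1-D or 2-D tensors
    # instead of plain lists of ids.
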
    @staticmethod
    def limit_prompt(prompt_text, tokenizer, max_prompt_length=None, buffer=256):
        if prompt_text is None:
            prompt_text = ''
        verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))

        if hasattr(tokenizer, 'model_max_length'):
            # the tokenizer knows the model's context length
            model_max_length = int(tokenizer.model_max_length)
            if max_prompt_length is not None:
                model_max_length = int(min(model_max_length, max_prompt_length))
                buffer = 0
            # first cut on characters, assuming a generous upper bound of
            # 10 chars/token, to avoid tokenizing an excessively long string
            if model_max_length == 0:
                len0 = len(prompt_text)
                prompt_text = ''
                if verbose:
                    print("Cut input: %s -> %s chars" % (len0, len(prompt_text)), flush=True)
            elif len(prompt_text) > model_max_length * 10:
                len0 = len(prompt_text)
                prompt_text = prompt_text[-model_max_length * 10:]
                if verbose:
                    print("Cut input: %s -> %s chars" % (len0, len(prompt_text)), flush=True)
        elif max_prompt_length is not None:
            model_max_length = max_prompt_length
        else:
            # unknown model limit, cannot truncate by tokens
            model_max_length = None

        num_prompt_tokens = None
        if model_max_length is not None:
            # iteratively truncate from the left, keeping the tail of the prompt,
            # until the token count fits
            for trial in range(0, 5):
                if prompt_text:
                    num_prompt_tokens = H2OTextGenerationPipeline.get_token_count(prompt_text, tokenizer)
                else:
                    num_prompt_tokens = 0
                if num_prompt_tokens > model_max_length and num_prompt_tokens > 0:
                    # estimate characters per token from this prompt itself
                    chars_per_token = len(prompt_text) / num_prompt_tokens
                    model_max_length_with_buffer = max(0, model_max_length - buffer)
                    keep_chars = int(model_max_length_with_buffer * chars_per_token)
                    # note: prompt_text[-0:] would keep everything, so handle 0 explicitly
                    prompt_text = prompt_text[-keep_chars:] if keep_chars > 0 else ''
                    if verbose:
                        print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
                            num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
                else:
                    if verbose:
                        print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
                    break
            if num_prompt_tokens is not None and num_prompt_tokens > model_max_length and model_max_length > 0:
                print(
                    "Failed to reduce %s tokens with %s chars: %s" % (num_prompt_tokens, len(prompt_text), prompt_text),
                    flush=True)

        return prompt_text, num_prompt_tokens
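    # Worked example (illustrative numbers): with model_max_length=2048 and
    # buffer=256, a prompt of 30000 chars that tokenizes to 10000 tokens gives
    # chars_per_token = 3.0, so the prompt is cut to its last
    # int((2048 - 256) * 3.0) = 5376 characters before the count is re-checked.
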
    def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
        prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)

        data_point = dict(context=self.context, instruction=prompt_text, input=self.iinput)
        if self.prompter is not None and not self.image_file:
            prompt_text = self.prompter.generate_prompt(data_point,
                                                        chat_conversation=self.chat_conversation,
                                                        user_prompt_for_fake_system_prompt=self.user_prompt_for_fake_system_prompt,
                                                        )

        self.prompt_text = prompt_text
        self.prompts.append(prompt_text)
        if handle_long_generation is None:
            # intentionally a no-op: input truncation is handled by limit_prompt()
            # above rather than by the HF "hole" strategy
            handle_long_generation = None
        return self._preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
                                **generate_kwargs)
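    # For example (the exact templates live in prompter.py and differ by model):
    # with prompt_type='human_bot', raw user text like "Why is the sky blue?" is
    # wrapped into roughly "<human>: Why is the sky blue?\n<bot>:" before
    # tokenization, so generation starts at the bot turn.
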
    def _preprocess(
            self,
            prompt_text,
            prefix="",
            handle_long_generation=None,
            add_special_tokens=False,
            truncation=None,
            padding=False,
            max_length=None,
            **generate_kwargs,
    ):
        if self.image_file:
            from transformers.image_utils import load_image
            images = [load_image(x) for x in self.image_file]

            from transformers import AutoProcessor
            processor = AutoProcessor.from_pretrained(self.base_model)

            # build a chat history ending with the current (text, images) turn
            history = self.chat_conversation.copy()
            history.append([(prompt_text, images), None])

            messages, images = convert_messages_and_extract_images(history)
            prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
            inputs = processor(text=prompt, images=images, return_tensors="pt")

            raise NotImplementedError("Not functioning yet.")
        elif isinstance(prompt_text, Chat):
            inputs = self.tokenizer.apply_chat_template(
                prompt_text.messages,
                truncation=truncation,
                padding=padding,
                max_length=max_length,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors=self.framework,
            )
        else:
            inputs = self.tokenizer(
                prefix + prompt_text,
                truncation=truncation,
                padding=padding,
                max_length=max_length,
                add_special_tokens=add_special_tokens,
                return_tensors=self.framework,
            )
        inputs["prompt_text"] = prompt_text

        if handle_long_generation == "hole":
            cur_len = inputs["input_ids"].shape[-1]
            if "max_new_tokens" in generate_kwargs:
                new_tokens = generate_kwargs["max_new_tokens"]
            else:
                new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
                if new_tokens < 0:
                    raise ValueError("We cannot infer how many new tokens are expected")
            if cur_len + new_tokens > self.tokenizer.model_max_length:
                keep_length = self.tokenizer.model_max_length - new_tokens
                if keep_length <= 0:
                    raise ValueError(
                        "We cannot use `hole` to handle this generation: the number of desired tokens exceeds the"
                        " model's max length"
                    )

                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]

        return inputs
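    # "hole" handling sketch (illustrative numbers): with cur_len=4000 prompt
    # tokens, max_new_tokens=256 and tokenizer.model_max_length=4096, then
    # keep_length = 4096 - 256 = 3840, so only the last 3840 prompt tokens are
    # kept and prompt + generation fits within the model's context window.
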
    def _postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True,
                     conditional_type=False):
        generated_sequence = model_outputs["generated_sequence"][0]
        input_ids = model_outputs["input_ids"]
        prompt_text = model_outputs["prompt_text"]
        generated_sequence = generated_sequence.numpy().tolist()
        records = []
        for sequence in generated_sequence:
            if return_type == ReturnType.TENSORS:
                record = {"generated_token_ids": sequence}
            elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
                # decode the full sequence, then optionally strip the decoded prompt
                text = self.tokenizer.decode(
                    sequence,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                )
                if conditional_type:
                    # conditional (e.g. encoder-decoder style) outputs do not include the prompt
                    all_text = text
                else:
                    if input_ids is None:
                        prompt_length = 0
                    else:
                        prompt_length = len(
                            self.tokenizer.decode(
                                input_ids[0],
                                skip_special_tokens=True,
                                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                            )
                        )

                    if return_type == ReturnType.FULL_TEXT:
                        all_text = prompt_text + text[prompt_length:]
                    else:
                        all_text = text[prompt_length:]

                record = {"generated_text": all_text}
            records.append(record)

        return records
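    # ReturnType sketch: if the prompt decodes to "Q: 2+2?\n" and the full decoded
    # sequence is "Q: 2+2?\nA: 4", then FULL_TEXT returns the original prompt plus
    # "A: 4", while NEW_TEXT returns only "A: 4" (the decoded prompt prefix is
    # measured by length and sliced off).
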
    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
        conditional_type = hasattr(self.model, 'conditional_type') and self.model.conditional_type
        records = self._postprocess(model_outputs, return_type=return_type,
                                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                                    conditional_type=conditional_type)
        key = 'generated_text'
        for rec in records:
            if self.use_prompter:
                outputs = rec[key]
                if return_type == ReturnType.NEW_TEXT:
                    output_with_prompt = outputs
                    prompt = None
                    only_new_text = True
                elif conditional_type:
                    if self.prompter.botstr:
                        prompt = self.prompter.botstr
                        output_with_prompt = prompt + outputs
                        only_new_text = False
                    else:
                        prompt = None
                        output_with_prompt = outputs
                        only_new_text = True
                else:
                    output_with_prompt = outputs
                    prompt = self.prompt_text
                    only_new_text = False
                outputs = self.prompter.get_response(output_with_prompt, prompt=prompt,
                                                     only_new_text=only_new_text,
                                                     sanitize_bot_response=self.sanitize_bot_response)
            elif self.bot and self.bot in rec[key]:
                if self.human:
                    outputs = rec[key].split(self.bot)[-1].split(self.human)[0]
                else:
                    outputs = rec[key].split(self.bot)[-1]
            else:
                outputs = rec[key]
            rec[key] = outputs
            if self.debug:
                print("prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs), flush=True)
            if hasattr(self.model, 'memory') and hasattr(self.model.memory, 'reset'):
                self.model.memory.reset()

        return records
    def _forward(self, model_inputs, **generate_kwargs):
        stop = []
        if generate_kwargs.get('stop'):
            stop += generate_kwargs['stop']
        if self.stop:
            stop += self.stop
        stop = sorted(set(stop))
        if self.can_stop or stop:
            self.stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
                                                  self.tokenizer, self.device,
                                                  self.base_model,
                                                  human=self.human, bot=self.bot,
                                                  model_max_length=self.tokenizer.model_max_length,
                                                  prompter=self.prompter,
                                                  stop=stop,
                                                  truncation_generation=self.truncation_generation,
                                                  max_time=self.max_time)
            generate_kwargs['stopping_criteria'] = self.stopping_criteria
        # 'stop' is not a valid argument to model.generate(), so always remove it
        generate_kwargs.pop('stop', None)

        return self.__forward(model_inputs, **generate_kwargs)
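    # For example, a call like pipe("...", stop=['\n<human>:']) merges the
    # per-call stop strings with self.stop before building the stopping criteria;
    # the raw 'stop' kwarg itself never reaches model.generate().
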
    def __forward(self, model_inputs, **generate_kwargs):
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # allow empty prompts
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]
        prompt_text = model_inputs.pop("prompt_text")

        # if there is a prefix, adjust the expected generation length to account for it
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].max_new_tokens is not None
            )
            if not has_max_new_tokens:
                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
                generate_kwargs["max_length"] += prefix_length
            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].min_new_tokens is not None
            )
            if not has_min_new_tokens and "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        # make generation reproducible for a given seed; 'seed' is not a valid
        # generate() kwarg, so pop it here
        seed = generate_kwargs.pop('seed', 1234)
        torch.manual_seed(seed)
        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
        out_b = generated_sequence.shape[0]
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            from transformers import is_tf_available
            if is_tf_available():
                import tensorflow as tf
                generated_sequence = tf.reshape(generated_sequence,
                                                (in_b, out_b // in_b, *generated_sequence.shape[1:]))
            else:
                raise ValueError("TF not available.")
        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
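

# Usage sketch (illustrative only, not executed on import).  The model name and
# prompt_type below are assumptions for demonstration; any causal LM with a
# prompt_type known to prompter.py should work the same way.  This is a minimal
# sketch of wiring the pipeline up by hand, not the canonical construction path.
def _example_usage(base_model='h2oai/h2ogpt-4096-llama2-7b-chat'):
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForCausalLM.from_pretrained(base_model)

    pipe = H2OTextGenerationPipeline(model=model, tokenizer=tokenizer,
                                     base_model=base_model,
                                     prompt_type='llama2',  # assumed to match base_model
                                     max_input_tokens=4096 - 256,
                                     sanitize_bot_response=True)
    # generation kwargs (max_new_tokens, do_sample, seed, stop, ...) are passed
    # through __call__ into _forward
    res = pipe("Why is the sky blue?", max_new_tokens=128, do_sample=False)
    return res[0]['generated_text']


if __name__ == '__main__':
    print(_example_usage())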
|