from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
# Prompt template for h2oai mamba-gpt style models: the instruction is wrapped
# between the <|prompt|> ... </s> markers and the model continues after <|answer|>.
STYLE = "<|prompt|>{instruction}</s><|answer|>"
class MambaGPTTextGenerationPipeline(TextGenerationPipeline):
    """Text-generation pipeline that applies the mamba-gpt chat template.

    Wraps the user's raw instruction in the ``<|prompt|>...</s><|answer|>``
    template before generation, and strips the template markers from the
    decoded output so callers receive only the answer text.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Template applied to every incoming instruction in preprocess().
        self.prompt = STYLE

    def preprocess(
        self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs
    ):
        """Wrap ``prompt_text`` in the chat template, then defer to the parent.

        Parameters and return value are those of
        ``TextGenerationPipeline.preprocess``; only the prompt text is altered.
        """
        prompt_text = self.prompt.format(instruction=prompt_text)
        return super().preprocess(
            prompt_text,
            prefix=prefix,
            handle_long_generation=handle_long_generation,
            **generate_kwargs,
        )

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.FULL_TEXT,
        clean_up_tokenization_spaces=True,
    ):
        """Decode via the parent, then strip the template markers.

        Keeps only the text between the first ``<|answer|>`` marker and any
        following ``<|prompt|>`` marker (a follow-up turn the model may have
        hallucinated). Returns the parent's list of record dicts with
        ``generated_text`` rewritten in place.
        """
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        for rec in records:
            text = rec["generated_text"]
            parts = text.split("<|answer|>")
            # Guard: a truncated generation may lack the answer marker; the
            # previous unconditional parts[1] raised IndexError in that case.
            if len(parts) > 1:
                text = parts[1].strip().split("<|prompt|>")[0].strip()
            rec["generated_text"] = text
        return records