Spaces:
Runtime error
Runtime error
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from transformers import pipeline, set_seed | |
import random | |
import re | |
from .singleton import Singleton | |
# Prefer the GPU whenever torch can see one, otherwise fall back to the CPU.
# NOTE(review): none of the visible code moves models or inputs to `device`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class Models(object):
    """Lazily loaded holder for the prompt-generation models.

    Nothing is loaded at construction time. The first read of one of the
    attributes listed below reaches ``__getattr__``, which calls the matching
    ``load_*`` method and stores the result on the instance. Later reads find
    the stored value through normal attribute lookup.

    Attributes (loaded on demand):
        microsoft_model, microsoft_tokenizer: the microsoft/Promptist model and
            the gpt2 tokenizer used with it (always loaded together).
        mj_pipe: text-generation pipeline for
            succinctly/text2image-prompt-generator.
        gpt2_650k_pipe: text-generation pipeline for
            Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator.
        gpt_neo_125m: text-generation pipeline for
            DrishtiSharma/StableDiffusion-Prompt-Generator-GPT-Neo-125M.
    """

    # Shared object handed out by instance(); created on the first call.
    _instance = None

    @classmethod
    def instance(cls):
        """Return the shared Models object, creating it on first use.

        The module-level code calls ``Models.instance()``, so the accessor is
        defined on the class itself rather than relying on an external decorator.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __getattr__(self, item):
        """Load a model the first time its attribute is read.

        Python only calls this when normal lookup fails, which here means the
        requested model has not been loaded yet. Any other name raises
        AttributeError, so ``hasattr``/``getattr(obj, name, default)`` work
        normally instead of recursing.

        Raises:
            AttributeError: if ``item`` is not one of the known model attributes.
        """
        if item in ('microsoft_model', 'microsoft_tokenizer'):
            self.microsoft_model, self.microsoft_tokenizer = self.load_microsoft_model()
        elif item == 'mj_pipe':
            self.mj_pipe = self.load_mj_pipe()
        elif item == 'gpt2_650k_pipe':
            self.gpt2_650k_pipe = self.load_gpt2_650k_pipe()
        elif item == 'gpt_neo_125m':
            # Stored under its own name: the lookup below depends on it, and
            # writing it anywhere else would clobber another model.
            self.gpt_neo_125m = self.load_gpt_neo_125m()
        else:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return self.__dict__[item]

    def load_gpt_neo_125m(self):
        """Build the GPT-Neo-125M Stable Diffusion prompt-generation pipeline."""
        return pipeline('text-generation', model='DrishtiSharma/StableDiffusion-Prompt-Generator-GPT-Neo-125M')

    def load_gpt2_650k_pipe(self):
        """Build the gpt2-650k Stable Diffusion prompt-generation pipeline."""
        return pipeline('text-generation', model='Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator')

    def load_mj_pipe(self):
        """Build the succinctly text2image prompt-generation pipeline."""
        return pipeline('text-generation', model='succinctly/text2image-prompt-generator')

    def load_microsoft_model(self):
        """Load microsoft/Promptist together with a gpt2 tokenizer.

        Returns:
            A ``(model, tokenizer)`` tuple. The tokenizer pads with its EOS
            token on the left, as is usual for generation with a decoder-only
            model.
        """
        prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        return prompter_model, tokenizer


# Module-wide shared holder; models load lazily on first attribute access.
models = Models.instance()
def rand_length(min_length: int = 60, max_length: int = 90) -> int:
    """Pick a random length from the inclusive range [min_length, max_length].

    An inverted range (``min_length > max_length``) is not treated as an
    error: ``max_length`` is returned as-is and no random number is drawn.
    """
    return max_length if min_length > max_length else random.randint(min_length, max_length)
def generate_prompt(
        plain_text,
        min_length=60,
        max_length=90,
        num_return_sequences=8,
        model_name='microsoft',
):
    """Turn ``plain_text`` into image-generation prompts with the chosen model.

    Args:
        plain_text: the user's description to expand.
        min_length, max_length: bounds for the random generation length.
        num_return_sequences: how many prompts to ask the model for.
        model_name: ``'gpt2_650k'``, ``'gpt_neo_125m'`` or ``'mj'``; any other
            value falls back to microsoft/Promptist.

    Returns:
        The generated prompts joined by newlines.
    """
    if model_name in ('gpt2_650k', 'gpt_neo_125m'):
        # Both Hub pipelines share one generic driver; only the pipe differs.
        # The attribute is read only for the selected model, so only that one loads.
        selected_pipe = models.gpt2_650k_pipe if model_name == 'gpt2_650k' else models.gpt_neo_125m
        return generate_prompt_pipe(
            selected_pipe,
            prompt=plain_text,
            min_length=min_length,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
        )

    if model_name == 'mj':
        return generate_prompt_mj(
            text_in_english=plain_text,
            num_return_sequences=num_return_sequences,
            min_length=min_length,
            max_length=max_length,
        )

    # Default path: Promptist beam search, one beam per requested sequence.
    return generate_prompt_microsoft(
        plain_text=plain_text,
        min_length=min_length,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        num_beams=num_return_sequences,
    )
def generate_prompt_microsoft(
        plain_text,
        min_length=60,
        max_length=90,
        num_beams=8,
        num_return_sequences=8,
        length_penalty=-1.0
) -> str:
    """Rewrite ``plain_text`` into image prompts with microsoft/Promptist.

    The model is queried with ``"<stripped text> Rephrase:"`` using beam search
    (no sampling). ``max_new_tokens`` is drawn by ``rand_length(min_length,
    max_length)``.

    Args:
        plain_text: the description to rephrase.
        min_length, max_length: bounds for the random ``max_new_tokens``.
        num_beams: beam width passed to ``generate``.
        num_return_sequences: number of beams to return.
        length_penalty: passed straight to ``generate``.

    Returns:
        One rephrased prompt per returned sequence, joined by newlines, with
        the echoed query removed.
    """
    tokenizer = models.microsoft_tokenizer
    # Build the query once. The same string must be used both to prompt the
    # model and to strip its echo from the decoded output; otherwise
    # surrounding whitespace in plain_text leaves the query in every result.
    query = plain_text.strip() + " Rephrase:"
    input_ids = tokenizer(query, return_tensors="pt").input_ids
    eos_id = tokenizer.eos_token_id
    outputs = models.microsoft_model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=rand_length(min_length, max_length),
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        eos_token_id=eos_id,
        pad_token_id=eos_id,
        length_penalty=length_penalty
    )
    output_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return "\n".join(text.replace(query, "").strip() for text in output_texts)
def generate_prompt_pipe(pipe, prompt: str, min_length=60, max_length: int = 255, num_return_sequences: int = 8) -> str:
    """Collect distinct prompts from a text-generation pipeline.

    ``pipe`` is called up to six times as ``pipe(prompt, max_new_tokens=...,
    num_return_sequences=...)`` and must return a list of dicts with a
    ``'generated_text'`` key. Each generated text is cut at its first '.' or
    first newline, whichever comes first, and stripped. Empty results are
    dropped and duplicates are removed. Sampling stops once at least
    ``num_return_sequences`` distinct prompts have been collected.

    Args:
        pipe: a text-generation callable, e.g. a transformers pipeline.
        prompt: the seed text passed to ``pipe``.
        min_length, max_length: bounds for the random ``max_new_tokens``.
        num_return_sequences: sequences requested per call, and the target
            number of distinct prompts.

    Returns:
        The distinct prompts, one per line, in the order they were first
        generated. This may be more than ``num_return_sequences`` lines,
        because every sequence of the final call is kept.
    """
    def get_valid_prompt(text: str) -> str:
        # First sentence or first line, whichever is shorter. On a tie both
        # are the same prefix of text, so min() returning the first is fine.
        dot_split = text.split('.')[0]
        n_split = text.split('\n')[0]
        return min(dot_split, n_split, key=len)

    # A dict keeps insertion order, so this is an ordered de-duplicating set
    # (list(set(...)) gave a different order on every run).
    unique = {}
    for _ in range(6):
        results = pipe(
            prompt,
            max_new_tokens=rand_length(min_length, max_length),
            num_return_sequences=num_return_sequences,
        )
        for result in results:
            # Strip before de-duplicating so whitespace-only variants collapse.
            candidate = get_valid_prompt(result['generated_text']).strip()
            if candidate:
                unique[candidate] = None
        if len(unique) >= num_return_sequences:
            break
    return "\n".join(unique)
def generate_prompt_mj(text_in_english: str, num_return_sequences: int = 8, min_length=60, max_length=90) -> str:
    """Generate prompts with the succinctly/text2image-prompt-generator pipeline.

    A random seed in [100, 1000000] is passed to transformers' ``set_seed``
    once per call. ``models.mj_pipe`` is then sampled up to six times, until
    at least one prompt survives filtering.

    A generated sequence (stripped) is kept only if it differs from the input,
    is more than four characters longer than it, and does not end with ':',
    '-' or '—'. From the kept text, every whitespace-delimited token
    containing a '.' is removed (presumably to drop URLs and file names;
    confirm intent), along with all '<' and '>' characters. Lines left blank
    by this cleaning are dropped.

    Args:
        text_in_english: the seed text for the pipeline.
        num_return_sequences: sequences requested per pipeline call.
        min_length, max_length: bounds for the random ``max_new_tokens``.

    Returns:
        The kept prompts joined by newlines, or '' if all six attempts yield
        nothing usable.
    """
    seed = random.randint(100, 1000000)
    set_seed(seed)
    result = ""
    for _ in range(6):
        sequences = models.mj_pipe(
            text_in_english,
            max_new_tokens=rand_length(min_length, max_length),
            num_return_sequences=num_return_sequences,
        )
        kept = []
        for sequence in sequences:
            line = sequence['generated_text'].strip()
            if (
                line != text_in_english
                and len(line) > len(text_in_english) + 4
                and not line.endswith((':', '-', '—'))
            ):
                kept.append(line)
        # Use \S rather than [^ ]: [^ ] also matches '\n', so a dotted token
        # could swallow a line break and merge two prompts into one.
        cleaned = re.sub(r'\S+\.\S+', '', "\n".join(kept))
        cleaned = cleaned.replace('<', '').replace('>', '')
        # Drop lines emptied by the cleaning so they cannot count as a result.
        stripped_lines = (part.strip() for part in cleaned.split("\n"))
        result = "\n".join(part for part in stripped_lines if part)
        if result:
            break
    return result