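"""Prompt-generation helpers for text-to-image models.

Lazily loads several Hugging Face text-generation models behind a singleton
(microsoft/Promptist, succinctly/text2image-prompt-generator,
Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator, and
DrishtiSharma/StableDiffusion-Prompt-Generator-GPT-Neo-125M) and exposes
generate_prompt() as the main entry point.
"""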
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline, set_seed
import random
import re
from .singleton import Singleton

# Computed for convenience but currently unused by the loaders below.
device = "cuda" if torch.cuda.is_available() else "cpu"


@Singleton
class Models(object):

    def __getattr__(self, item):
        # Called only when normal attribute lookup fails: load the requested
        # model/pipeline lazily, cache it on the instance, and return it.
        if item in ('microsoft_model', 'microsoft_tokenizer'):
            self.microsoft_model, self.microsoft_tokenizer = self.load_microsoft_model()
        elif item == 'mj_pipe':
            self.mj_pipe = self.load_mj_pipe()
        elif item == 'gpt2_650k_pipe':
            self.gpt2_650k_pipe = self.load_gpt2_650k_pipe()
        elif item == 'gpt_neo_125m':
            self.gpt_neo_125m = self.load_gpt_neo_125m()
        else:
            # Unknown attributes must raise, or the getattr() below would
            # re-enter __getattr__ and recurse forever.
            raise AttributeError(item)
        return getattr(self, item)

    @classmethod
    def load_gpt_neo_125m(cls):
        return pipeline('text-generation', model='DrishtiSharma/StableDiffusion-Prompt-Generator-GPT-Neo-125M')

    @classmethod
    def load_gpt2_650k_pipe(cls):
        return pipeline('text-generation', model='Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator')

    @classmethod
    def load_mj_pipe(cls):
        return pipeline('text-generation', model='succinctly/text2image-prompt-generator')

    @classmethod
    def load_microsoft_model(cls):
        prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        return prompter_model, tokenizer


# Module-level singleton handle; attribute access (e.g. models.mj_pipe)
# triggers lazy loading in Models.__getattr__.
models = Models.instance()


def rand_length(min_length: int = 60, max_length: int = 90) -> int:
    if min_length > max_length:
        return max_length

    return random.randint(min_length, max_length)


def generate_prompt(
        plain_text,
        min_length=60,
        max_length=90,
        num_return_sequences=8,
        model_name='microsoft',
):
    if model_name == 'gpt2_650k':
        return generate_prompt_pipe(
            models.gpt2_650k_pipe,
            prompt=plain_text,
            min_length=min_length,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
        )
    elif model_name == 'gpt_neo_125m':
        return generate_prompt_pipe(
            models.gpt_neo_125m,
            prompt=plain_text,
            min_length=min_length,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
        )
    elif model_name == 'mj':
        return generate_prompt_mj(
            text_in_english=plain_text,
            num_return_sequences=num_return_sequences,
            min_length=min_length,
            max_length=max_length,
        )
    else:
        return generate_prompt_microsoft(
            plain_text=plain_text,
            min_length=min_length,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            num_beams=num_return_sequences,
        )


def generate_prompt_microsoft(
        plain_text,
        min_length=60,
        max_length=90,
        num_beams=8,
        num_return_sequences=8,
        length_penalty=-1.0
) -> str:
    prompt = plain_text.strip() + " Rephrase:"
    input_ids = models.microsoft_tokenizer(prompt, return_tensors="pt").input_ids
    eos_id = models.microsoft_tokenizer.eos_token_id

    outputs = models.microsoft_model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=rand_length(min_length, max_length),
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        eos_token_id=eos_id,
        pad_token_id=eos_id,
        length_penalty=length_penalty
    )
    output_texts = models.microsoft_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # Strip the echoed "<input> Rephrase:" prefix so only the generation remains.
    result = [output_text.replace(prompt, "").strip() for output_text in output_texts]

    return "\n".join(result)


def generate_prompt_pipe(pipe, prompt: str, min_length=60, max_length: int = 255, num_return_sequences: int = 8) -> str:
    def get_valid_prompt(text: str) -> str:
        # Truncate at the first sentence or line break, whichever comes first.
        dot_split = text.split('.')[0]
        n_split = text.split('\n')[0]
        return min(dot_split, n_split, key=len)

    output = set()
    for _ in range(6):
        # Sample repeatedly (up to six rounds) until enough unique prompts
        # have been collected.
        output |= {
            get_valid_prompt(result['generated_text']) for result in
            pipe(
                prompt,
                max_new_tokens=rand_length(min_length, max_length),
                num_return_sequences=num_return_sequences
            )
        }
        if len(output) >= num_return_sequences:
            break

    # valid_prompt = get_valid_prompt(models.gpt2_650k_pipe(prompt, max_length=max_length)[0]['generated_text'])
    return "\n".join([o.strip() for o in output])


def generate_prompt_mj(text_in_english: str, num_return_sequences: int = 8, min_length=60, max_length=90) -> str:
    seed = random.randint(100, 1000000)
    set_seed(seed)

    result = ""
    for _ in range(6):
        sequences = models.mj_pipe(
            text_in_english,
            max_new_tokens=rand_length(min_length, max_length),
            num_return_sequences=num_return_sequences
        )
        lines = []
        for sequence in sequences:
            line = sequence['generated_text'].strip()
            # Keep lines that meaningfully extend the input and don't end
            # mid-phrase on a colon or dash.
            if line != text_in_english and len(line) > (len(text_in_english) + 4) \
                    and not line.endswith((':', '-', '—')):
                lines.append(line)

        result = "\n".join(lines)
        # Remove dotted tokens (e.g. file names or domains) and angle brackets.
        result = re.sub(r'[^ ]+\.[^ ]+', '', result)
        result = result.replace('<', '').replace('>', '')
        if result != '':
            break
    return result
    # return result, "\n".join(translate_en2zh(line) for line in result.split("\n") if len(line) > 0)
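

# A minimal usage sketch (assumptions: run as a script, with the models
# available locally or downloadable from the Hugging Face Hub; the input
# text below is made up for illustration).
if __name__ == '__main__':
    for name in ('microsoft', 'mj', 'gpt2_650k', 'gpt_neo_125m'):
        print(f"--- {name} ---")
        print(generate_prompt('a cat sitting on a windowsill', model_name=name))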