caveman1 commited on
Commit
c3b06c0
·
verified ·
1 Parent(s): 038f3f2

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +248 -0
  2. google_translate.py +56 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+
5
+ try:
6
+ # переводчик с русского на английский
7
+ from google_translate import TranslatorWithCache
8
+ is_google_translate_installed=True
9
+ translator = TranslatorWithCache()
10
+ except ImportError:
11
+ is_google_translate_installed=False
12
+
13
+
14
+ try:
15
+ from config_ui import Config
16
+ is_config_ui_installed=True
17
+ config = Config()
18
+ device = "cuda" if (config.cuda=="cuda" and torch.cuda.is_available()) else "cpu"
19
+ lang=config.lang
20
+ except ImportError:
21
+ is_config_ui_installed=False
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ lang='EN'
24
+
25
+ try:
26
+ from prompt.portrait_prompt import generate_random_portrait_prompt
27
+ is_rnd_gen_installed=True
28
+ except:
29
+ is_rnd_gen_installed=False
30
+
31
+
32
+ model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
33
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
34
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
35
+
36
+ max_target_length = 256
37
+ prefix = "enhance prompt"
38
+
39
+ def enhance_prompt(prompt, system_prompt, temperature=0.5, repetition_penalty=1.2, seed=-1, is_rnd_seed=True):
40
+
41
+ if is_rnd_seed or seed==-1:
42
+ seed = torch.randint(0, 2**32 - 1, (1,)).item()
43
+
44
+ torch.manual_seed(seed)
45
+
46
+ if is_google_translate_installed:
47
+ # Перевод с русского на английский
48
+ en_prompt = translator.translate_ru2eng(prompt)
49
+ input_text = f"{system_prompt}: {en_prompt}"
50
+ else:
51
+ input_text = f"{system_prompt}: {prompt}"
52
+
53
+ input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
54
+
55
+ # Генерация текста
56
+ outputs = model.generate(
57
+ input_ids,
58
+ max_length=max_target_length,
59
+ num_return_sequences=1,
60
+ do_sample=True,
61
+ temperature=temperature,
62
+ repetition_penalty=repetition_penalty
63
+ )
64
+
65
+ generated_text_en = tokenizer.decode(outputs[0], skip_special_tokens=True)
66
+
67
+ if is_google_translate_installed:
68
+ result_output_ru = translator.translate_eng2ru(generated_text_en)
69
+ else:
70
+ result_output_ru=generated_text_en
71
+
72
+ return seed, generated_text_en, result_output_ru
73
+
74
+ def random_prompt():
75
+ rnd_prompt_str=generate_random_portrait_prompt()
76
+ #rnd_prompt_str=get_random_words()
77
+ return rnd_prompt_str
78
+
79
+ # Функция копирования текста в буфер
80
+ def copy_to_clipboard(text):
81
+ gr.Info("скопировано в буффер обмена" if (lang=="RU") else "copy to clipboard" ,duration=1)
82
+ return None
83
+
84
+ LABELS_EN={"prompt_input": "Input initial prompt:",
85
+ "seed_output": "Seed:",
86
+ "result_output" : "Improved prompt",
87
+ "result_output_ru" : "Improved prompt (in Russian)",
88
+ "generate_button": "Improve prompt",
89
+ "copy_button": "Copy to clipboard",
90
+ "save_button": "Save config",
91
+ "system_prompt" : "System prompt",
92
+ "temperature": "Temperature",
93
+ "repetition_penalty": "Repetition penalty",
94
+ "is_rnd_seed": "Random Seed"
95
+ }
96
+ LABELS=LABELS_EN
97
+
98
+ if is_google_translate_installed:
99
+
100
+ LABELS_RU={"prompt_input": "Введите начальный промпт:",
101
+ "seed_output": "Seed для генерации:",
102
+ "result_output" : "Улучшенный промпт (на английском):",
103
+ "result_output_ru" : "Улучшенный промпт (на русском):",
104
+ "generate_button": "Улучшить промпт",
105
+ "copy_button": "Скопировать в буффер обмена",
106
+ "save_button": "Сохранить настройки",
107
+ "system_prompt": "Системный промпт",
108
+ "temperature": "Температура",
109
+ "repetition_penalty": "Штраф за повторение",
110
+ "is_rnd_seed": "Случайный Seed"
111
+ }
112
+ LABELS=LABELS_EN if lang=="EN" else LABELS_RU
113
+
114
+ if is_google_translate_installed:
115
+ def process_lang(selected_lang):
116
+ global lang
117
+ lang=selected_lang
118
+ if selected_lang == "RU":
119
+ LABELS=LABELS_RU
120
+ message="Вы выбрали русский"
121
+ isVisible=True
122
+ elif selected_lang == "EN":
123
+ LABELS=LABELS_EN
124
+ message="You selected English"
125
+ isVisible=False
126
+ ret = [gr.update(value=LABELS["generate_button"]),
127
+ gr.update(value=LABELS["copy_button"]),
128
+ gr.update(value=LABELS["save_button"]),
129
+ gr.update(label=LABELS["prompt_input"]),
130
+ gr.update(label=LABELS["seed_output"]),
131
+ gr.update(label=LABELS["is_rnd_seed"]),
132
+ gr.update(label=LABELS["result_output"]),
133
+ gr.update(visible=isVisible, label=LABELS["result_output_ru"]),
134
+ gr.update(label=LABELS["system_prompt"]),
135
+ gr.update(label=LABELS["temperature"]),
136
+ gr.update(label=LABELS["repetition_penalty"])
137
+ ]
138
+ return message, *ret
139
+
140
+
141
+ if is_config_ui_installed:
142
+ def save_config():
143
+ global lang,device,isOpenAdvanced ,config, AccordionAdvanced
144
+
145
+ config.set_lang(lang)
146
+ config.set_cuda(str(device))
147
+ isOpenAdvanced=AccordionAdvanced.open
148
+ print(AccordionAdvanced.open)
149
+ config.set_OpenAdvanced=(isOpenAdvanced)
150
+
151
+ # Сохраняем изменения в файл
152
+ config.save()
153
+ return "save config to file" if lang=='EN' else "Конфигурация сохранена в файл"
154
+
155
+ def process_gpu(selected_gpu):
156
+ """Функция для переключения модели между устройствами (CPU / CUDA)"""
157
+ global model, device # Используем глобальные переменные model и device
158
+ device = torch.device(selected_gpu) # Устанавливаем новое устройство
159
+ model = model.to(device) # Переносим модель на новое устройство
160
+ message= f"Модель переключена на устройство: {selected_gpu}" if lang=="RU" else f"Model switched to device: {selected_gpu}"
161
+ return message
162
+
163
+ def set_initial():
164
+ global device
165
+ dev="cpu"
166
+ if str(device) =='cuda':
167
+ device = torch.cuda.current_device()
168
+ device_name = torch.cuda.get_device_name(device)
169
+ device_name = f"GPU: {device_name}"
170
+ dev="cuda"
171
+ else:
172
+ device_name = "use CPU"
173
+ return gr.update(value=lang), gr.update(value=dev), f'{device_name}\nset to "{lang}" language'
174
+
175
+
176
+ # Настройка интерфейса Gradio
177
+ with gr.Blocks(title="Flux Prompt Enhance",
178
+ theme=gr.themes.Default(primary_hue=gr.themes.colors.sky, secondary_hue=gr.themes.colors.indigo),
179
+ analytics_enabled=False, css="footer{display:none !important}") as demo:
180
+ gr.Image(label="header AiCave", value="./static/ai_cave_title.jpg",height="100%",
181
+ show_download_button=False, show_label=False, show_share_button=False,
182
+ interactive=False, show_fullscreen_button=False,)
183
+
184
+ with gr.Row(variant="default"):
185
+ gr.HTML("""
186
+ <h1>Flux Prompt Enhance portable by <a href='https://boosty.to/aicave/donate' style="color: #4AA0E2;">CaveMan</a></h1>
187
+ """)
188
+ with gr.Row(variant="default"):
189
+ # выбор языка UI
190
+ radio_lang = gr.Radio(choices = ["RU", "EN"], show_label = False, container = False, type = "value",
191
+ visible = True if is_google_translate_installed else False)
192
+
193
+ radio_gpu = gr.Radio(choices = ["cuda","cpu"], show_label = False, container = False, type = "value",
194
+ visible = True if torch.cuda.is_available() else False)
195
+ save_button = gr.Button(LABELS["save_button"], visible= True if is_config_ui_installed else False)
196
+
197
+ with gr.Row(variant="default"):
198
+ prompt_input = gr.Textbox(label=LABELS["prompt_input"])
199
+ if is_rnd_gen_installed:
200
+ button_random = gr.Button("", icon="./static/random.png", scale=0, min_width=200)
201
+ button_random.click(fn=random_prompt, outputs=prompt_input)
202
+
203
+
204
+ with gr.Accordion("Advanced:", open=False ) as AccordionAdvanced:
205
+ with gr.Row(variant="default"):
206
+ system_prompt = gr.Textbox(label=LABELS["system_prompt"], interactive=False,value=prefix)
207
+ seed_output = gr.Textbox(label=LABELS["seed_output"], interactive=True,value=502119)
208
+ is_rnd_seed = gr.Checkbox(value=True, label="Random seed", interactive=True)
209
+
210
+ with gr.Row(variant="default"):
211
+ temperature = gr.Slider(label=LABELS["temperature"], interactive=True,value=0.7, minimum=0.1,maximum=1,step=0.1)
212
+ repetition_penalty = gr.Slider(label=LABELS["repetition_penalty"], interactive=True,value=1.2, minimum=0.1,maximum=2,step=0.1)
213
+ #repetition_penalty =
214
+
215
+
216
+ result_output = gr.Textbox(label=LABELS["result_output"], interactive=False)
217
+ result_output_ru = gr.Textbox(label=LABELS["result_output_ru"], interactive=False, visible = False if lang == "EN" else True)
218
+
219
+ #prompt_input.submit(fn=enhance_prompt, inputs=[prompt_input,system_prompt,temperature,repetition_penalty], outputs=[seed_output, result_output, result_output_ru], show_progress=False)
220
+
221
+ # Кнопка генерации
222
+ with gr.Row(variant="default"):
223
+ generate_button = gr.Button(LABELS["generate_button"], variant="primary", size="lg")
224
+ generate_button.click(fn=enhance_prompt, inputs=[prompt_input,system_prompt,temperature,repetition_penalty,seed_output,is_rnd_seed],
225
+ outputs=[seed_output, result_output, result_output_ru])
226
+
227
+ # Кнопка копирования в буфер обмена
228
+ copy_button = gr.Button(LABELS["copy_button"], variant="secondary")
229
+ copy_button.click(fn=copy_to_clipboard, inputs=result_output, outputs=[],js="(text) => navigator.clipboard.writeText(text)")
230
+ with gr.Row(variant="default"):
231
+ log_text = gr.Textbox(label="")
232
+
233
+ if is_config_ui_installed:
234
+ save_button.click(fn=save_config, inputs=[], outputs=log_text)
235
+
236
+ #preload values for lang
237
+ demo.load(set_initial, outputs=[radio_lang, radio_gpu, log_text])
238
+
239
+ if is_google_translate_installed:
240
+ radio_lang.change(process_lang, inputs=radio_lang,
241
+ outputs=[log_text,generate_button, copy_button, save_button, prompt_input, seed_output, is_rnd_seed,
242
+ result_output, result_output_ru,system_prompt, temperature, repetition_penalty])
243
+
244
+ radio_gpu.change(process_gpu, inputs=radio_gpu, outputs=log_text)
245
+
246
+
247
+ # Запуск приложения с прослушиванием на всех интерфейсах и открытием в браузере
248
+ demo.launch()
google_translate.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import time
3
+ import random
4
+ import os
5
+ from googletrans import Translator
6
+ from functools import lru_cache
7
+
8
+ TRANSLATOR_CACHE_SIZE = int(os.getenv("TRANSLATOR_CACHE_SIZE",100))
9
+
10
+ class TranslatorWithCache:
11
+ def __init__(self, cache_size=TRANSLATOR_CACHE_SIZE):
12
+ self.cache_size = cache_size
13
+ self._init_cache()
14
+
15
+ def _init_cache(self):
16
+ self.cached_translate = lru_cache(maxsize=self.cache_size)(self._translate)
17
+
18
+ def has_russian_word(self, text):
19
+ pattern = re.compile(r'[а-яА-ЯёЁ]+')
20
+ return bool(pattern.search(text))
21
+
22
+ def has_english_word(self, text):
23
+ pattern = re.compile(r'[a-zA-Z]+')
24
+ return bool(pattern.search(text))
25
+
26
+ def _translate(self, text: str, src: str, dest: str, retries=3) -> str:
27
+ tr = Translator()
28
+ for attempt in range(retries):
29
+ try:
30
+ return tr.translate(text, src=src, dest=dest).text
31
+ except TypeError as e:
32
+ if attempt < retries - 1:
33
+ delay = random.uniform(1, 2)
34
+ print(f"Ошибка: {e}. Повторная попытка через {delay:.2f} секунд...")
35
+ time.sleep(delay)
36
+ else:
37
+ print(f"Ошибка: {e}. Превышено количество попыток.")
38
+ return text # Возвращаем оригинальный текст, если все попытки провалились.
39
+
40
+ def translate_ru2eng(self, text: str, src='ru', dest='en') -> str:
41
+ if self.has_russian_word(text):
42
+ return self.cached_translate(text, src, dest)
43
+ else:
44
+ return text
45
+
46
+ def translate_eng2ru(self, text: str, src='en', dest='ru') -> str:
47
+ if self.has_english_word(text):
48
+ return self.cached_translate(text, src, dest)
49
+ else:
50
+ return text
51
+
52
+ if __name__ == "__main__":
53
+ translator = TranslatorWithCache(cache_size=100)
54
+ print(translator.translate_ru2eng("Привет, как дела?"))
55
+ print(translator.translate_eng2ru("Hello, how are you?"))
56
+ print(translator.translate_ru2eng("Привет, как дела?")) # Повторный вызов для проверки кэша
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ git+https://github.com/c-goosen/py-googletrans.git
2
+ torch
3
+ transformers
4
+ gradio