caveman1 commited on
Commit
920e1f5
·
verified ·
1 Parent(s): c3b06c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +248 -248
app.py CHANGED
@@ -1,248 +1,248 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
-
5
- try:
6
- # переводчик с русского на английский
7
- from google_translate import TranslatorWithCache
8
- is_google_translate_installed=True
9
- translator = TranslatorWithCache()
10
- except ImportError:
11
- is_google_translate_installed=False
12
-
13
-
14
- try:
15
- from config_ui import Config
16
- is_config_ui_installed=True
17
- config = Config()
18
- device = "cuda" if (config.cuda=="cuda" and torch.cuda.is_available()) else "cpu"
19
- lang=config.lang
20
- except ImportError:
21
- is_config_ui_installed=False
22
- device = "cuda" if torch.cuda.is_available() else "cpu"
23
- lang='EN'
24
-
25
- try:
26
- from prompt.portrait_prompt import generate_random_portrait_prompt
27
- is_rnd_gen_installed=True
28
- except:
29
- is_rnd_gen_installed=False
30
-
31
-
32
- model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
33
- tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
34
- model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
35
-
36
- max_target_length = 256
37
- prefix = "enhance prompt"
38
-
39
- def enhance_prompt(prompt, system_prompt, temperature=0.5, repetition_penalty=1.2, seed=-1, is_rnd_seed=True):
40
-
41
- if is_rnd_seed or seed==-1:
42
- seed = torch.randint(0, 2**32 - 1, (1,)).item()
43
-
44
- torch.manual_seed(seed)
45
-
46
- if is_google_translate_installed:
47
- # Перевод с русского на английский
48
- en_prompt = translator.translate_ru2eng(prompt)
49
- input_text = f"{system_prompt}: {en_prompt}"
50
- else:
51
- input_text = f"{system_prompt}: {prompt}"
52
-
53
- input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
54
-
55
- # Генерация текста
56
- outputs = model.generate(
57
- input_ids,
58
- max_length=max_target_length,
59
- num_return_sequences=1,
60
- do_sample=True,
61
- temperature=temperature,
62
- repetition_penalty=repetition_penalty
63
- )
64
-
65
- generated_text_en = tokenizer.decode(outputs[0], skip_special_tokens=True)
66
-
67
- if is_google_translate_installed:
68
- result_output_ru = translator.translate_eng2ru(generated_text_en)
69
- else:
70
- result_output_ru=generated_text_en
71
-
72
- return seed, generated_text_en, result_output_ru
73
-
74
- def random_prompt():
75
- rnd_prompt_str=generate_random_portrait_prompt()
76
- #rnd_prompt_str=get_random_words()
77
- return rnd_prompt_str
78
-
79
- # Функция копирования текста в буфер
80
- def copy_to_clipboard(text):
81
- gr.Info("скопировано в буффер обмена" if (lang=="RU") else "copy to clipboard" ,duration=1)
82
- return None
83
-
84
- LABELS_EN={"prompt_input": "Input initial prompt:",
85
- "seed_output": "Seed:",
86
- "result_output" : "Improved prompt",
87
- "result_output_ru" : "Improved prompt (in Russian)",
88
- "generate_button": "Improve prompt",
89
- "copy_button": "Copy to clipboard",
90
- "save_button": "Save config",
91
- "system_prompt" : "System prompt",
92
- "temperature": "Temperature",
93
- "repetition_penalty": "Repetition penalty",
94
- "is_rnd_seed": "Random Seed"
95
- }
96
- LABELS=LABELS_EN
97
-
98
- if is_google_translate_installed:
99
-
100
- LABELS_RU={"prompt_input": "Введите начальный промпт:",
101
- "seed_output": "Seed для генерации:",
102
- "result_output" : "Улучшенный промпт (на английском):",
103
- "result_output_ru" : "Улучшенный промпт (на русском):",
104
- "generate_button": "Улучшить промпт",
105
- "copy_button": "Скопировать в буффер обмена",
106
- "save_button": "Сохранить настройки",
107
- "system_prompt": "Системный промпт",
108
- "temperature": "Температура",
109
- "repetition_penalty": "Штраф за повторение",
110
- "is_rnd_seed": "Случайный Seed"
111
- }
112
- LABELS=LABELS_EN if lang=="EN" else LABELS_RU
113
-
114
- if is_google_translate_installed:
115
- def process_lang(selected_lang):
116
- global lang
117
- lang=selected_lang
118
- if selected_lang == "RU":
119
- LABELS=LABELS_RU
120
- message="Вы выбрали русский"
121
- isVisible=True
122
- elif selected_lang == "EN":
123
- LABELS=LABELS_EN
124
- message="You selected English"
125
- isVisible=False
126
- ret = [gr.update(value=LABELS["generate_button"]),
127
- gr.update(value=LABELS["copy_button"]),
128
- gr.update(value=LABELS["save_button"]),
129
- gr.update(label=LABELS["prompt_input"]),
130
- gr.update(label=LABELS["seed_output"]),
131
- gr.update(label=LABELS["is_rnd_seed"]),
132
- gr.update(label=LABELS["result_output"]),
133
- gr.update(visible=isVisible, label=LABELS["result_output_ru"]),
134
- gr.update(label=LABELS["system_prompt"]),
135
- gr.update(label=LABELS["temperature"]),
136
- gr.update(label=LABELS["repetition_penalty"])
137
- ]
138
- return message, *ret
139
-
140
-
141
- if is_config_ui_installed:
142
- def save_config():
143
- global lang,device,isOpenAdvanced ,config, AccordionAdvanced
144
-
145
- config.set_lang(lang)
146
- config.set_cuda(str(device))
147
- isOpenAdvanced=AccordionAdvanced.open
148
- print(AccordionAdvanced.open)
149
- config.set_OpenAdvanced=(isOpenAdvanced)
150
-
151
- # Сохраняем изменения в файл
152
- config.save()
153
- return "save config to file" if lang=='EN' else "Конфигурация сохранена в файл"
154
-
155
- def process_gpu(selected_gpu):
156
- """Функция для переключения модели между устройствами (CPU / CUDA)"""
157
- global model, device # Используем глобальные переменные model и device
158
- device = torch.device(selected_gpu) # Устанавливаем новое устройство
159
- model = model.to(device) # Переносим модель на новое устройство
160
- message= f"Модель переключена на устройство: {selected_gpu}" if lang=="RU" else f"Model switched to device: {selected_gpu}"
161
- return message
162
-
163
- def set_initial():
164
- global device
165
- dev="cpu"
166
- if str(device) =='cuda':
167
- device = torch.cuda.current_device()
168
- device_name = torch.cuda.get_device_name(device)
169
- device_name = f"GPU: {device_name}"
170
- dev="cuda"
171
- else:
172
- device_name = "use CPU"
173
- return gr.update(value=lang), gr.update(value=dev), f'{device_name}\nset to "{lang}" language'
174
-
175
-
176
- # Настройка интерфейса Gradio
177
- with gr.Blocks(title="Flux Prompt Enhance",
178
- theme=gr.themes.Default(primary_hue=gr.themes.colors.sky, secondary_hue=gr.themes.colors.indigo),
179
- analytics_enabled=False, css="footer{display:none !important}") as demo:
180
- gr.Image(label="header AiCave", value="./static/ai_cave_title.jpg",height="100%",
181
- show_download_button=False, show_label=False, show_share_button=False,
182
- interactive=False, show_fullscreen_button=False,)
183
-
184
- with gr.Row(variant="default"):
185
- gr.HTML("""
186
- <h1>Flux Prompt Enhance portable by <a href='https://boosty.to/aicave/donate' style="color: #4AA0E2;">CaveMan</a></h1>
187
- """)
188
- with gr.Row(variant="default"):
189
- # выбор языка UI
190
- radio_lang = gr.Radio(choices = ["RU", "EN"], show_label = False, container = False, type = "value",
191
- visible = True if is_google_translate_installed else False)
192
-
193
- radio_gpu = gr.Radio(choices = ["cuda","cpu"], show_label = False, container = False, type = "value",
194
- visible = True if torch.cuda.is_available() else False)
195
- save_button = gr.Button(LABELS["save_button"], visible= True if is_config_ui_installed else False)
196
-
197
- with gr.Row(variant="default"):
198
- prompt_input = gr.Textbox(label=LABELS["prompt_input"])
199
- if is_rnd_gen_installed:
200
- button_random = gr.Button("", icon="./static/random.png", scale=0, min_width=200)
201
- button_random.click(fn=random_prompt, outputs=prompt_input)
202
-
203
-
204
- with gr.Accordion("Advanced:", open=False ) as AccordionAdvanced:
205
- with gr.Row(variant="default"):
206
- system_prompt = gr.Textbox(label=LABELS["system_prompt"], interactive=False,value=prefix)
207
- seed_output = gr.Textbox(label=LABELS["seed_output"], interactive=True,value=502119)
208
- is_rnd_seed = gr.Checkbox(value=True, label="Random seed", interactive=True)
209
-
210
- with gr.Row(variant="default"):
211
- temperature = gr.Slider(label=LABELS["temperature"], interactive=True,value=0.7, minimum=0.1,maximum=1,step=0.1)
212
- repetition_penalty = gr.Slider(label=LABELS["repetition_penalty"], interactive=True,value=1.2, minimum=0.1,maximum=2,step=0.1)
213
- #repetition_penalty =
214
-
215
-
216
- result_output = gr.Textbox(label=LABELS["result_output"], interactive=False)
217
- result_output_ru = gr.Textbox(label=LABELS["result_output_ru"], interactive=False, visible = False if lang == "EN" else True)
218
-
219
- #prompt_input.submit(fn=enhance_prompt, inputs=[prompt_input,system_prompt,temperature,repetition_penalty], outputs=[seed_output, result_output, result_output_ru], show_progress=False)
220
-
221
- # Кнопка генерации
222
- with gr.Row(variant="default"):
223
- generate_button = gr.Button(LABELS["generate_button"], variant="primary", size="lg")
224
- generate_button.click(fn=enhance_prompt, inputs=[prompt_input,system_prompt,temperature,repetition_penalty,seed_output,is_rnd_seed],
225
- outputs=[seed_output, result_output, result_output_ru])
226
-
227
- # Кнопка копирования в буфер обмена
228
- copy_button = gr.Button(LABELS["copy_button"], variant="secondary")
229
- copy_button.click(fn=copy_to_clipboard, inputs=result_output, outputs=[],js="(text) => navigator.clipboard.writeText(text)")
230
- with gr.Row(variant="default"):
231
- log_text = gr.Textbox(label="")
232
-
233
- if is_config_ui_installed:
234
- save_button.click(fn=save_config, inputs=[], outputs=log_text)
235
-
236
- #preload values for lang
237
- demo.load(set_initial, outputs=[radio_lang, radio_gpu, log_text])
238
-
239
- if is_google_translate_installed:
240
- radio_lang.change(process_lang, inputs=radio_lang,
241
- outputs=[log_text,generate_button, copy_button, save_button, prompt_input, seed_output, is_rnd_seed,
242
- result_output, result_output_ru,system_prompt, temperature, repetition_penalty])
243
-
244
- radio_gpu.change(process_gpu, inputs=radio_gpu, outputs=log_text)
245
-
246
-
247
- # Запуск приложения с прослушиванием на всех интерфейсах и открытием в браузере
248
- demo.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+
5
+ try:
6
+ # переводчик с русского на английский
7
+ from google_translate import TranslatorWithCache
8
+ is_google_translate_installed=True
9
+ translator = TranslatorWithCache()
10
+ except ImportError:
11
+ is_google_translate_installed=False
12
+
13
+
14
+ try:
15
+ from config_ui import Config
16
+ is_config_ui_installed=True
17
+ config = Config()
18
+ device = "cuda" if (config.cuda=="cuda" and torch.cuda.is_available()) else "cpu"
19
+ lang=config.lang
20
+ except ImportError:
21
+ is_config_ui_installed=False
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ lang='EN'
24
+
25
+ try:
26
+ from prompt.portrait_prompt import generate_random_portrait_prompt
27
+ is_rnd_gen_installed=True
28
+ except:
29
+ is_rnd_gen_installed=False
30
+
31
+
32
+ model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
33
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
34
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
35
+
36
+ max_target_length = 256
37
+ prefix = "enhance prompt"
38
+
39
+ def enhance_prompt(prompt, system_prompt, temperature=0.5, repetition_penalty=1.2, seed=-1, is_rnd_seed=True):
40
+
41
+ if is_rnd_seed or seed==-1:
42
+ seed = torch.randint(0, 2**32 - 1, (1,)).item()
43
+
44
+ torch.manual_seed(seed)
45
+
46
+ if is_google_translate_installed:
47
+ # Перевод с русского на английский
48
+ en_prompt = translator.translate_ru2eng(prompt)
49
+ input_text = f"{system_prompt}: {en_prompt}"
50
+ else:
51
+ input_text = f"{system_prompt}: {prompt}"
52
+
53
+ input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
54
+
55
+ # Генерация текста
56
+ outputs = model.generate(
57
+ input_ids,
58
+ max_length=max_target_length,
59
+ num_return_sequences=1,
60
+ do_sample=True,
61
+ temperature=temperature,
62
+ repetition_penalty=repetition_penalty
63
+ )
64
+
65
+ generated_text_en = tokenizer.decode(outputs[0], skip_special_tokens=True)
66
+
67
+ if is_google_translate_installed:
68
+ result_output_ru = translator.translate_eng2ru(generated_text_en)
69
+ else:
70
+ result_output_ru=generated_text_en
71
+
72
+ return seed, generated_text_en, result_output_ru
73
+
74
+ def random_prompt():
75
+ rnd_prompt_str=generate_random_portrait_prompt()
76
+ #rnd_prompt_str=get_random_words()
77
+ return rnd_prompt_str
78
+
79
+ # Функция копирования текста в буфер
80
+ def copy_to_clipboard(text):
81
+ gr.Info("скопировано в буффер обмена" if (lang=="RU") else "copy to clipboard" ,duration=1)
82
+ return None
83
+
84
+ LABELS_EN={"prompt_input": "Input initial prompt:",
85
+ "seed_output": "Seed:",
86
+ "result_output" : "Improved prompt",
87
+ "result_output_ru" : "Improved prompt (in Russian)",
88
+ "generate_button": "Improve prompt",
89
+ "copy_button": "Copy to clipboard",
90
+ "save_button": "Save config",
91
+ "system_prompt" : "System prompt",
92
+ "temperature": "Temperature",
93
+ "repetition_penalty": "Repetition penalty",
94
+ "is_rnd_seed": "Random Seed"
95
+ }
96
+ LABELS=LABELS_EN
97
+
98
+ if is_google_translate_installed:
99
+
100
+ LABELS_RU={"prompt_input": "Введите начальный промпт:",
101
+ "seed_output": "Seed для генерации:",
102
+ "result_output" : "Улучшенный промпт (на английском):",
103
+ "result_output_ru" : "Улучшенный промпт (на русском):",
104
+ "generate_button": "Улучшить промпт",
105
+ "copy_button": "Скопировать в буффер обмена",
106
+ "save_button": "Сохранить настройки",
107
+ "system_prompt": "Системный промпт",
108
+ "temperature": "Температура",
109
+ "repetition_penalty": "Штраф за повторение",
110
+ "is_rnd_seed": "Случайный Seed"
111
+ }
112
+ LABELS=LABELS_EN if lang=="EN" else LABELS_RU
113
+
114
+ if is_google_translate_installed:
115
+ def process_lang(selected_lang):
116
+ global lang
117
+ lang=selected_lang
118
+ if selected_lang == "RU":
119
+ LABELS=LABELS_RU
120
+ message="Вы выбрали русский"
121
+ isVisible=True
122
+ elif selected_lang == "EN":
123
+ LABELS=LABELS_EN
124
+ message="You selected English"
125
+ isVisible=False
126
+ ret = [gr.update(value=LABELS["generate_button"]),
127
+ gr.update(value=LABELS["copy_button"]),
128
+ gr.update(value=LABELS["save_button"]),
129
+ gr.update(label=LABELS["prompt_input"]),
130
+ gr.update(label=LABELS["seed_output"]),
131
+ gr.update(label=LABELS["is_rnd_seed"]),
132
+ gr.update(label=LABELS["result_output"]),
133
+ gr.update(visible=isVisible, label=LABELS["result_output_ru"]),
134
+ gr.update(label=LABELS["system_prompt"]),
135
+ gr.update(label=LABELS["temperature"]),
136
+ gr.update(label=LABELS["repetition_penalty"])
137
+ ]
138
+ return message, *ret
139
+
140
+
141
+ if is_config_ui_installed:
142
+ def save_config():
143
+ global lang,device,isOpenAdvanced ,config, AccordionAdvanced
144
+
145
+ config.set_lang(lang)
146
+ config.set_cuda(str(device))
147
+ isOpenAdvanced=AccordionAdvanced.open
148
+ print(AccordionAdvanced.open)
149
+ config.set_OpenAdvanced=(isOpenAdvanced)
150
+
151
+ # Сохраняем изменения в файл
152
+ config.save()
153
+ return "save config to file" if lang=='EN' else "Конфигурация сохранена в файл"
154
+
155
+ def process_gpu(selected_gpu):
156
+ """Функция для переключения модели между устройствами (CPU / CUDA)"""
157
+ global model, device # Используем глобальные переменные model и device
158
+ device = torch.device(selected_gpu) # Устанавливаем новое устройство
159
+ model = model.to(device) # Переносим модель на новое устройство
160
+ message= f"Модель переключена на устройство: {selected_gpu}" if lang=="RU" else f"Model switched to device: {selected_gpu}"
161
+ return message
162
+
163
+ def set_initial():
164
+ global device
165
+ dev="cpu"
166
+ if str(device) =='cuda':
167
+ device = torch.cuda.current_device()
168
+ device_name = torch.cuda.get_device_name(device)
169
+ device_name = f"GPU: {device_name}"
170
+ dev="cuda"
171
+ else:
172
+ device_name = "use CPU"
173
+ return gr.update(value=lang), gr.update(value=dev), f'{device_name}\nset to "{lang}" language'
174
+
175
+
176
+ # Настройка интерфейса Gradio
177
+ with gr.Blocks(title="Flux Prompt Enhance",
178
+ theme=gr.themes.Default(primary_hue=gr.themes.colors.sky, secondary_hue=gr.themes.colors.indigo),
179
+ analytics_enabled=False, css="footer{display:none !important}") as demo:
180
+ gr.Image(label="header AiCave", value="./static/ai_cave_title.jpg",height="100%",
181
+ show_download_button=False, show_label=False, show_share_button=False,
182
+ interactive=False, show_fullscreen_button=False,)
183
+
184
+ with gr.Row(variant="default"):
185
+ gr.HTML("""
186
+ <h1>Flux Prompt Enhance portable by <a href='https://boosty.to/aicave/donate' style="color: #4AA0E2;">CaveMan</a></h1>
187
+ """)
188
+ with gr.Row(variant="default"):
189
+ # выбор языка UI
190
+ radio_lang = gr.Radio(choices = ["RU", "EN"], show_label = False, container = False, type = "value",
191
+ visible = True if is_google_translate_installed else False)
192
+
193
+ radio_gpu = gr.Radio(choices = ["cuda","cpu"], show_label = False, container = False, type = "value",
194
+ visible = True if torch.cuda.is_available() else False)
195
+ save_button = gr.Button(LABELS["save_button"], visible= True if is_config_ui_installed else False)
196
+
197
+ with gr.Row(variant="default"):
198
+ prompt_input = gr.Textbox(label=LABELS["prompt_input"])
199
+ if is_rnd_gen_installed:
200
+ button_random = gr.Button("", icon="./static/random.png", scale=0, min_width=200)
201
+ button_random.click(fn=random_prompt, outputs=prompt_input)
202
+
203
+
204
+ with gr.Accordion("Advanced:", open=False ) as AccordionAdvanced:
205
+ with gr.Row(variant="default"):
206
+ system_prompt = gr.Textbox(label=LABELS["system_prompt"], interactive=False,value=prefix)
207
+ seed_output = gr.Textbox(label=LABELS["seed_output"], interactive=True,value=502119)
208
+ is_rnd_seed = gr.Checkbox(value=True, label="Random seed", interactive=True)
209
+
210
+ with gr.Row(variant="default"):
211
+ temperature = gr.Slider(label=LABELS["temperature"], interactive=True,value=0.7, minimum=0.1,maximum=1,step=0.1)
212
+ repetition_penalty = gr.Slider(label=LABELS["repetition_penalty"], interactive=True,value=1.2, minimum=0.1,maximum=1.9,step=0.1)
213
+ #repetition_penalty =
214
+
215
+
216
+ result_output = gr.Textbox(label=LABELS["result_output"], interactive=False)
217
+ result_output_ru = gr.Textbox(label=LABELS["result_output_ru"], interactive=False, visible = False if lang == "EN" else True)
218
+
219
+ #prompt_input.submit(fn=enhance_prompt, inputs=[prompt_input,system_prompt,temperature,repetition_penalty], outputs=[seed_output, result_output, result_output_ru], show_progress=False)
220
+
221
+ # Кнопка генерации
222
+ with gr.Row(variant="default"):
223
+ generate_button = gr.Button(LABELS["generate_button"], variant="primary", size="lg")
224
+ generate_button.click(fn=enhance_prompt, inputs=[prompt_input,system_prompt,temperature,repetition_penalty,seed_output,is_rnd_seed],
225
+ outputs=[seed_output, result_output, result_output_ru])
226
+
227
+ # Кнопка копирования в буфер обмена
228
+ copy_button = gr.Button(LABELS["copy_button"], variant="secondary")
229
+ copy_button.click(fn=copy_to_clipboard, inputs=result_output, outputs=[],js="(text) => navigator.clipboard.writeText(text)")
230
+ with gr.Row(variant="default"):
231
+ log_text = gr.Textbox(label="")
232
+
233
+ if is_config_ui_installed:
234
+ save_button.click(fn=save_config, inputs=[], outputs=log_text)
235
+
236
+ #preload values for lang
237
+ demo.load(set_initial, outputs=[radio_lang, radio_gpu, log_text])
238
+
239
+ if is_google_translate_installed:
240
+ radio_lang.change(process_lang, inputs=radio_lang,
241
+ outputs=[log_text,generate_button, copy_button, save_button, prompt_input, seed_output, is_rnd_seed,
242
+ result_output, result_output_ru,system_prompt, temperature, repetition_penalty])
243
+
244
+ radio_gpu.change(process_gpu, inputs=radio_gpu, outputs=log_text)
245
+
246
+
247
+ # Запуск приложения с прослушиванием на всех интерфейсах и открытием в браузере
248
+ demo.launch()