Spaces:
Running
Running
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
try: | |
# переводчик с русского на английский | |
from google_translate import TranslatorWithCache | |
is_google_translate_installed=True | |
translator = TranslatorWithCache() | |
except ImportError: | |
is_google_translate_installed=False | |
try: | |
from config_ui import Config | |
is_config_ui_installed=True | |
config = Config() | |
device = "cuda" if (config.cuda=="cuda" and torch.cuda.is_available()) else "cpu" | |
lang=config.lang | |
except ImportError: | |
is_config_ui_installed=False | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
lang='EN' | |
try: | |
from prompt.portrait_prompt import generate_random_portrait_prompt | |
is_rnd_gen_installed=True | |
except: | |
is_rnd_gen_installed=False | |
model_checkpoint = "gokaygokay/Flux-Prompt-Enhance" | |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) | |
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device) | |
max_target_length = 256 | |
prefix = "enhance prompt" | |
def enhance_prompt(prompt, system_prompt, temperature=0.5, repetition_penalty=1.2, seed=-1, is_rnd_seed=True): | |
if is_rnd_seed or seed==-1: | |
seed = torch.randint(0, 2**32 - 1, (1,)).item() | |
torch.manual_seed(seed) | |
if is_google_translate_installed: | |
# Перевод с русского на английский | |
en_prompt = translator.translate_ru2eng(prompt) | |
input_text = f"{system_prompt}: {en_prompt}" | |
else: | |
input_text = f"{system_prompt}: {prompt}" | |
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device) | |
# Генерация текста | |
outputs = model.generate( | |
input_ids, | |
max_length=max_target_length, | |
num_return_sequences=1, | |
do_sample=True, | |
temperature=temperature, | |
repetition_penalty=repetition_penalty | |
) | |
generated_text_en = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
if is_google_translate_installed: | |
result_output_ru = translator.translate_eng2ru(generated_text_en) | |
else: | |
result_output_ru=generated_text_en | |
return seed, generated_text_en, result_output_ru | |
def random_prompt(): | |
rnd_prompt_str=generate_random_portrait_prompt() | |
#rnd_prompt_str=get_random_words() | |
return rnd_prompt_str | |
# Функция копирования текста в буфер | |
def copy_to_clipboard(text): | |
gr.Info("скопировано в буффер обмена" if (lang=="RU") else "copy to clipboard" ,duration=1) | |
return None | |
LABELS_EN={"prompt_input": "Input initial prompt:", | |
"seed_output": "Seed:", | |
"result_output" : "Improved prompt", | |
"result_output_ru" : "Improved prompt (in Russian)", | |
"generate_button": "Improve prompt", | |
"copy_button": "Copy to clipboard", | |
"save_button": "Save config", | |
"system_prompt" : "System prompt", | |
"temperature": "Temperature", | |
"repetition_penalty": "Repetition penalty", | |
"is_rnd_seed": "Random Seed" | |
} | |
LABELS=LABELS_EN | |
if is_google_translate_installed: | |
LABELS_RU={"prompt_input": "Введите начальный промпт:", | |
"seed_output": "Seed для генерации:", | |
"result_output" : "Улучшенный промпт (на английском):", | |
"result_output_ru" : "Улучшенный промпт (на русском):", | |
"generate_button": "Улучшить промпт", | |
"copy_button": "Скопировать в буффер обмена", | |
"save_button": "Сохранить настройки", | |
"system_prompt": "Системный промпт", | |
"temperature": "Температура", | |
"repetition_penalty": "Штраф за повторение", | |
"is_rnd_seed": "Случайный Seed" | |
} | |
LABELS=LABELS_EN if lang=="EN" else LABELS_RU | |
if is_google_translate_installed: | |
def process_lang(selected_lang): | |
global lang | |
lang=selected_lang | |
if selected_lang == "RU": | |
LABELS=LABELS_RU | |
message="Вы выбрали русский" | |
isVisible=True | |
elif selected_lang == "EN": | |
LABELS=LABELS_EN | |
message="You selected English" | |
isVisible=False | |
ret = [gr.update(value=LABELS["generate_button"]), | |
gr.update(value=LABELS["copy_button"]), | |
gr.update(value=LABELS["save_button"]), | |
gr.update(label=LABELS["prompt_input"]), | |
gr.update(label=LABELS["seed_output"]), | |
gr.update(label=LABELS["is_rnd_seed"]), | |
gr.update(label=LABELS["result_output"]), | |
gr.update(visible=isVisible, label=LABELS["result_output_ru"]), | |
gr.update(label=LABELS["system_prompt"]), | |
gr.update(label=LABELS["temperature"]), | |
gr.update(label=LABELS["repetition_penalty"]) | |
] | |
return message, *ret | |
if is_config_ui_installed: | |
def save_config(): | |
global lang,device,isOpenAdvanced ,config, AccordionAdvanced | |
config.set_lang(lang) | |
config.set_cuda(str(device)) | |
isOpenAdvanced=AccordionAdvanced.open | |
print(AccordionAdvanced.open) | |
config.set_OpenAdvanced=(isOpenAdvanced) | |
# Сохраняем изменения в файл | |
config.save() | |
return "save config to file" if lang=='EN' else "Конфигурация сохранена в файл" | |
def process_gpu(selected_gpu): | |
"""Функция для переключения модели между устройствами (CPU / CUDA)""" | |
global model, device # Используем глобальные переменные model и device | |
device = torch.device(selected_gpu) # Устанавливаем новое устройство | |
model = model.to(device) # Переносим модель на новое устройство | |
message= f"Модель переключена на устройство: {selected_gpu}" if lang=="RU" else f"Model switched to device: {selected_gpu}" | |
return message | |
def set_initial(): | |
global device | |
dev="cpu" | |
if str(device) =='cuda': | |
device = torch.cuda.current_device() | |
device_name = torch.cuda.get_device_name(device) | |
device_name = f"GPU: {device_name}" | |
dev="cuda" | |
else: | |
device_name = "use CPU" | |
return gr.update(value=lang), gr.update(value=dev), f'{device_name}\nset to "{lang}" language' | |
# Настройка интерфейса Gradio | |
with gr.Blocks(title="Flux Prompt Enhance", | |
theme=gr.themes.Default(primary_hue=gr.themes.colors.sky, secondary_hue=gr.themes.colors.indigo), | |
analytics_enabled=False, css="footer{display:none !important}") as demo: | |
gr.Image(label="header AiCave", value="./static/ai_cave_title.jpg",height="100%", | |
show_download_button=False, show_label=False, show_share_button=False, | |
interactive=False, show_fullscreen_button=False,) | |
with gr.Row(variant="default"): | |
gr.HTML(""" | |
<h1>Flux Prompt Enhance portable by <a href='https://boosty.to/aicave/donate' style="color: #4AA0E2;">CaveMan</a></h1> | |
""") | |
with gr.Row(variant="default"): | |
# выбор языка UI | |
radio_lang = gr.Radio(choices = ["RU", "EN"], show_label = False, container = False, type = "value", | |
visible = True if is_google_translate_installed else False) | |
radio_gpu = gr.Radio(choices = ["cuda","cpu"], show_label = False, container = False, type = "value", | |
visible = True if torch.cuda.is_available() else False) | |
save_button = gr.Button(LABELS["save_button"], visible= True if is_config_ui_installed else False) | |
with gr.Row(variant="default"): | |
prompt_input = gr.Textbox(label=LABELS["prompt_input"]) | |
if is_rnd_gen_installed: | |
button_random = gr.Button("", icon="./static/random.png", scale=0, min_width=200) | |
button_random.click(fn=random_prompt, outputs=prompt_input) | |
with gr.Accordion("Advanced:", open=False ) as AccordionAdvanced: | |
with gr.Row(variant="default"): | |
system_prompt = gr.Textbox(label=LABELS["system_prompt"], interactive=False,value=prefix) | |
seed_output = gr.Textbox(label=LABELS["seed_output"], interactive=True,value=502119) | |
is_rnd_seed = gr.Checkbox(value=True, label="Random seed", interactive=True) | |
with gr.Row(variant="default"): | |
temperature = gr.Slider(label=LABELS["temperature"], interactive=True,value=0.7, minimum=0.1,maximum=1,step=0.1) | |
repetition_penalty = gr.Slider(label=LABELS["repetition_penalty"], interactive=True,value=1.2, minimum=0.1,maximum=2,step=0.1) | |
#repetition_penalty = | |
result_output = gr.Textbox(label=LABELS["result_output"], interactive=False) | |
result_output_ru = gr.Textbox(label=LABELS["result_output_ru"], interactive=False, visible = False if lang == "EN" else True) | |
#prompt_input.submit(fn=enhance_prompt, inputs=[prompt_input,system_prompt,temperature,repetition_penalty], outputs=[seed_output, result_output, result_output_ru], show_progress=False) | |
# Кнопка генерации | |
with gr.Row(variant="default"): | |
generate_button = gr.Button(LABELS["generate_button"], variant="primary", size="lg") | |
generate_button.click(fn=enhance_prompt, inputs=[prompt_input,system_prompt,temperature,repetition_penalty,seed_output,is_rnd_seed], | |
outputs=[seed_output, result_output, result_output_ru]) | |
# Кнопка копирования в буфер обмена | |
copy_button = gr.Button(LABELS["copy_button"], variant="secondary") | |
copy_button.click(fn=copy_to_clipboard, inputs=result_output, outputs=[],js="(text) => navigator.clipboard.writeText(text)") | |
with gr.Row(variant="default"): | |
log_text = gr.Textbox(label="") | |
if is_config_ui_installed: | |
save_button.click(fn=save_config, inputs=[], outputs=log_text) | |
#preload values for lang | |
demo.load(set_initial, outputs=[radio_lang, radio_gpu, log_text]) | |
if is_google_translate_installed: | |
radio_lang.change(process_lang, inputs=radio_lang, | |
outputs=[log_text,generate_button, copy_button, save_button, prompt_input, seed_output, is_rnd_seed, | |
result_output, result_output_ru,system_prompt, temperature, repetition_penalty]) | |
radio_gpu.change(process_gpu, inputs=radio_gpu, outputs=log_text) | |
# Запуск приложения с прослушиванием на всех интерфейсах и открытием в браузере | |
demo.launch() | |