salomonsky commited on
Commit
8c1a558
1 Parent(s): 99a5876

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -107
app.py CHANGED
@@ -7,33 +7,21 @@ import streamlit as st
7
  from huggingface_hub import InferenceClient, AsyncInferenceClient
8
  from gradio_client import Client, handle_file
9
  import asyncio
10
- from concurrent.futures import ThreadPoolExecutor
11
 
12
  MAX_SEED = np.iinfo(np.int32).max
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
14
  HF_TOKEN_UPSCALER = os.environ.get("HF_TOKEN_UPSCALER")
15
  client = AsyncInferenceClient()
16
- llm_client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
17
  DATA_PATH = Path("./data")
18
  DATA_PATH.mkdir(exist_ok=True)
19
 
20
- def run_async(func):
21
- loop = asyncio.new_event_loop()
22
- asyncio.set_event_loop(loop)
23
- executor = ThreadPoolExecutor(max_workers=1)
24
- result = loop.run_in_executor(executor, func)
25
- return loop.run_until_complete(result)
26
-
27
- def enable_lora(lora_add, basemodel):
28
- return lora_add if lora_add else basemodel
29
-
30
- async def generate_image(combined_prompt, model, width, height, scales, steps, seed):
31
  try:
32
  if seed == -1:
33
  seed = random.randint(0, MAX_SEED)
34
  seed = int(seed)
35
  image = await client.text_to_image(
36
- prompt=combined_prompt, height=height, width=width, guidance_scale=scales,
37
  num_inference_steps=steps, model=model
38
  )
39
  return image, seed
@@ -50,89 +38,48 @@ def get_upscale_finegrain(prompt, img_path, upscale_factor):
50
  except Exception as e:
51
  return None
52
 
53
- def save_prompt(prompt_text, seed):
54
- try:
55
- prompt_file_path = DATA_PATH / f"prompt_{seed}.txt"
56
- with open(prompt_file_path, "w") as prompt_file:
57
- prompt_file.write(prompt_text)
58
- return prompt_file_path
59
- except Exception as e:
60
- st.error(f"Error al guardar el prompt: {e}")
61
- return None
62
-
63
  def save_image(image, seed):
64
- image_path = DATA_PATH / f"generated_image_{seed}.jpg"
65
- image.save(image_path)
66
- return image_path
67
-
68
- async def improve_prompt(prompt, language):
69
  try:
70
- instruction = (
71
- "With this idea, describe in English a detailed txt2img prompt in 500 characters at most, add illumination, atmosphere, cinematic elements, and characters..."
72
- if language == "English"
73
- else "Con esta idea, describe en español un prompt detallado de txt2img en un máximo de 500 caracteres, añadiendo iluminación, atmósfera, elementos cinematográficos y personajes..."
74
- )
75
-
76
- formatted_prompt = f"{prompt}: {instruction}"
77
- response = llm_client.text_generation(formatted_prompt, max_new_tokens=300)
78
- improved_text = response['generated_text'].strip() if 'generated_text' in response else response.strip()
79
- return improved_text[:300] if len(improved_text) > 300 else improved_text
80
  except Exception as e:
81
- return f"Error mejorando el prompt: {e}"
82
-
83
- async def gen(prompt, basemodel, width, height, scales, steps, seed, upscale_factor, process_upscale, lora_model, process_lora, process_enhancer, prompt_language):
84
- model = enable_lora(lora_model, basemodel) if process_lora else basemodel
85
- combined_prompt = prompt # Usar el prompt original por defecto
86
-
87
- if process_enhancer:
88
- improved_prompt = await improve_prompt(prompt, prompt_language)
89
- combined_prompt = f"{prompt} {improved_prompt}"
90
-
91
- if seed == -1:
92
- seed = random.randint(0, MAX_SEED)
93
- seed = int(seed)
94
- progress_bar = st.progress(0)
95
- image, seed = await generate_image(combined_prompt, model, width, height, scales, steps, seed)
96
- progress_bar.progress(50)
97
 
98
- if isinstance(image, str) and image.startswith("Error"):
99
- progress_bar.empty()
100
- return [image, None, combined_prompt]
 
101
 
102
- image_path = save_image(image, seed)
103
- prompt_file_path = save_prompt(combined_prompt, seed)
 
104
 
105
- if process_upscale:
106
- upscale_image_path = get_upscale_finegrain(combined_prompt, image_path, upscale_factor)
107
- if upscale_image_path:
108
- upscale_image = Image.open(upscale_image_path)
109
- upscale_image.save(DATA_PATH / f"upscale_image_{seed}.jpg", format="JPEG")
110
- progress_bar.progress(100)
111
- image_path.unlink()
112
- return [str(DATA_PATH / f"upscale_image_{seed}.jpg"), str(prompt_file_path)]
113
  else:
114
- progress_bar.empty()
115
- return [str(image_path), str(prompt_file_path)]
116
- else:
117
- progress_bar.progress(100)
118
- return [str(image_path), str(prompt_file_path)]
119
 
120
  def main():
121
  st.set_page_config(layout="wide")
122
- st.title("FLUX with enhancer and upscaler with LORA model training")
123
 
124
- prompt = st.sidebar.text_input("Descripción de la imagen", max_chars=200)
125
- process_enhancer = st.sidebar.checkbox("Mejorar Prompt", value=True)
126
- prompt_language = st.sidebar.selectbox("Idioma para mejorar el prompt", ["English", "Spanish"])
127
- basemodel = st.sidebar.selectbox("Modelo Base", ["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-DEV"])
128
- lora_model = st.sidebar.selectbox("LORA Realismo", ["Shakker-Labs/FLUX.1-dev-LoRA-add-details", "XLabs-AI/flux-RealismLora"])
129
- format_option = st.sidebar.selectbox("Formato", ["9:16", "16:9"])
130
- process_lora = st.sidebar.checkbox("Procesar LORA", value=True)
131
- process_upscale = st.sidebar.checkbox("Procesar Escalador", value=True)
132
- upscale_factor = st.sidebar.selectbox("Factor de Escala", [2, 4, 8], index=0)
133
- scales = st.sidebar.slider("Escalado", 1, 20, 10)
134
- steps = st.sidebar.slider("Pasos", 1, 100, 20)
135
- seed = st.sidebar.number_input("Semilla", value=-1)
136
 
137
  if format_option == "9:16":
138
  width = 720
@@ -141,27 +88,35 @@ def main():
141
  width = 1280
142
  height = 720
143
 
144
- if st.sidebar.button("Generar Imagen"):
145
- with st.spinner("Mejorando y generando imagen..."):
146
- result = asyncio.run(gen(prompt, basemodel, width, height, scales, steps, seed, upscale_factor, process_upscale, lora_model, process_lora, process_enhancer, prompt_language))
147
- image_paths = result[0]
148
- prompt_file = result[1]
149
-
150
- st.write(f"Image paths: {image_paths}")
151
-
152
- if image_paths:
153
- if Path(image_paths).exists():
154
- st.image(image_paths, caption="Imagen Generada")
155
- else:
156
- st.error("El archivo de imagen no existe.")
157
-
158
- if prompt_file and Path(prompt_file).exists():
159
- prompt_text = Path(prompt_file).read_text()
160
- st.write(f"Prompt utilizado: {prompt_text}")
161
  else:
162
- st.write("El archivo del prompt no está disponible.")
163
- else:
164
- st.error("No se pudo generar la imagen.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  if __name__ == "__main__":
167
- main()
 
7
  from huggingface_hub import InferenceClient, AsyncInferenceClient
8
  from gradio_client import Client, handle_file
9
  import asyncio
 
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
13
  HF_TOKEN_UPSCALER = os.environ.get("HF_TOKEN_UPSCALER")
14
  client = AsyncInferenceClient()
 
15
  DATA_PATH = Path("./data")
16
  DATA_PATH.mkdir(exist_ok=True)
17
 
18
+ async def generate_image(prompt, model, width, height, scales, steps, seed):
 
 
 
 
 
 
 
 
 
 
19
  try:
20
  if seed == -1:
21
  seed = random.randint(0, MAX_SEED)
22
  seed = int(seed)
23
  image = await client.text_to_image(
24
+ prompt=prompt, height=height, width=width, guidance_scale=scales,
25
  num_inference_steps=steps, model=model
26
  )
27
  return image, seed
 
38
  except Exception as e:
39
  return None
40
 
 
 
 
 
 
 
 
 
 
 
41
  def save_image(image, seed):
 
 
 
 
 
42
  try:
43
+ image_path = DATA_PATH / f"image_{seed}.jpg"
44
+ image.save(image_path, format="JPEG")
45
+ return image_path
 
 
 
 
 
 
 
46
  except Exception as e:
47
+ st.error(f"Error al guardar la imagen: {e}")
48
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ def get_storage():
51
+ files = [file for file in DATA_PATH.glob("*.jpg") if file.is_file()]
52
+ files.sort(key=lambda x: x.stat().st_mtime, reverse=True)
53
+ return [str(file.resolve()) for file in files]
54
 
55
+ def get_prompts():
56
+ prompt_files = [file for file in DATA_PATH.glob("*.txt") if file.is_file()]
57
+ return {file.stem.replace("prompt_", ""): file for file in prompt_files}
58
 
59
+ def delete_image(image_path):
60
+ try:
61
+ if Path(image_path).exists():
62
+ Path(image_path).unlink()
63
+ st.success(f"Imagen {image_path} borrada.")
 
 
 
64
  else:
65
+ st.error("El archivo de imagen no existe.")
66
+ except Exception as e:
67
+ st.error(f"Error al borrar la imagen: {e}")
 
 
68
 
69
  def main():
70
  st.set_page_config(layout="wide")
71
+ st.title("Generación de Imágenes")
72
 
73
+ prompt = st.text_input("Descripción de la imagen", max_chars=200)
74
+
75
+ with st.expander("Opciones avanzadas", expanded=False):
76
+ basemodel = st.selectbox("Modelo Base", ["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-DEV"])
77
+ format_option = st.selectbox("Formato", ["9:16", "16:9"])
78
+ process_upscale = st.checkbox("Procesar Escalador", value=True)
79
+ upscale_factor = st.selectbox("Factor de Escala", [2, 4, 8], index=0)
80
+ scales = st.slider("Escalado", 1, 20, 10)
81
+ steps = st.slider("Pasos", 1, 100, 20)
82
+ seed = st.number_input("Semilla", value=-1)
 
 
83
 
84
  if format_option == "9:16":
85
  width = 720
 
88
  width = 1280
89
  height = 720
90
 
91
+ if st.button("Generar Imagen"):
92
+ with st.spinner("Generando imagen..."):
93
+ image, seed = await generate_image(prompt, basemodel, width, height, scales, steps, seed)
94
+
95
+ if isinstance(image, str) and image.startswith("Error"):
96
+ st.error(image)
 
 
 
 
 
 
 
 
 
 
 
97
  else:
98
+ image_path = save_image(image, seed)
99
+ if image_path:
100
+ st.image(image_path, caption="Imagen Generada")
101
+ st.success("Imagen generada y guardada.")
102
+
103
+ # Mostrar galería de imágenes
104
+ files = get_storage()
105
+ prompts = get_prompts()
106
+
107
+ st.subheader("Galería de Imágenes")
108
+ cols = st.columns(3)
109
+
110
+ for idx, file in enumerate(files):
111
+ with cols[idx % 3]:
112
+ image = Image.open(file)
113
+ prompt_text = prompts.get(Path(file).stem.replace("image_", ""), "No disponible")
114
+
115
+ st.image(image, caption=f"Imagen {idx + 1}")
116
+ st.write(f"Prompt: {prompt_text}")
117
+
118
+ if st.button(f"Borrar Imagen {idx + 1}", key=f"delete_{idx}"):
119
+ delete_image(file)
120
 
121
  if __name__ == "__main__":
122
+ main()