ItzRoBeerT committed
Commit
0ce8592
1 Parent(s): 36af654

Update app.py

Files changed (1)
  1. app.py +29 -19
app.py CHANGED
@@ -1,8 +1,14 @@
 import gradio as gr
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
-from diffusers import StableDiffusionPipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
+import io
 from PIL import Image
+import requests
+import random
+import dom
+import os
+
+NUM_IMAGES = 2
 
 # Device configuration
 device = "cpu"
@@ -12,7 +18,8 @@ elif torch.backends.mps.is_available():
     device = "mps"
 
 # Model configuration
-model_id_image = "sd-legacy/stable-diffusion-v1-5"
+API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
+headers = {"Authorization": f"Bearer {os.getenv('api_token')}"}
 model_id_image_description = "vikhyatk/moondream2"
 revision = "2024-08-26"
 
@@ -25,44 +32,47 @@ print("Loading image description model...")
 model_description = AutoModelForCausalLM.from_pretrained(model_id_image_description, trust_remote_code=True, revision=revision)
 tokenizer_description = AutoTokenizer.from_pretrained(model_id_image_description, revision=revision)
 
-print("Loading Stable Diffusion model...")
-pipe_sd = StableDiffusionPipeline.from_pretrained(model_id_image, torch_dtype=torch_dtype)
-pipe_sd = pipe_sd.to(device)
-
-# Options to optimize memory
-pipe_sd.enable_attention_slicing()
-if device == "cuda":
-    pipe_sd.enable_sequential_cpu_offload()  # Gradually free memory for small GPUs
-
 def generate_description(image_path):
     image_test = Image.open(image_path)
     enc_image = model_description.encode_image(image_test)
     description = model_description.answer_question(enc_image, "Describe this image to create an avatar", tokenizer_description)
     return description
 
+def query(payload):
+    response = requests.post(API_URL, headers=headers, json=payload)
+    return response.content
+
 def generate_image_by_description(description, avatar_style=None):
-    prompt = f"Create a pigeon profile avatar. Use the following description: {description}."
-    if avatar_style:
-        prompt += f" Use {avatar_style} style."
+    images = []
+    for _ in range(NUM_IMAGES):
+        prompt = f"Create a pigeon profile avatar. Use the following description: {description}."
+        if avatar_style:
+            prompt += f" Use {avatar_style} style."
 
-    result = pipe_sd(prompt)
-    return result.images[0]
+        image_bytes = query({"inputs": prompt, "parameters": {"seed": random.randint(0, 1000)}})
+        image = Image.open(io.BytesIO(image_bytes))
+        images.append(image)
+    print(images)
+    return images
 
 def process_and_generate(image, avatar_style):
     description = generate_description(image)
     return generate_image_by_description(description, avatar_style)
 
-with gr.Blocks() as demo:
+with gr.Blocks(js=dom.generate_title) as demo:
+    with gr.Row():
+        gr.Markdown(dom.generate_markdown)
     with gr.Row():
         with gr.Column(scale=2, min_width=300):
             selected_image = gr.Image(type="filepath", label="Upload an Image of the Pigeon", height=300)
+            example_image = gr.Examples(["./examples/pigeon.webp"], label="Example Images", inputs=[selected_image])
             avatar_style = gr.Radio(
                 ["Realistic", "Pixel Art", "Imaginative", "Cartoon"],
                 label="(optional) Select the avatar style:"
             )
             generate_button = gr.Button("Generate Avatar", variant="primary")
         with gr.Column(scale=2, min_width=300):
-            generated_image = gr.Image(type="numpy", label="Generated Avatar", height=300)
+            generated_image = gr.Gallery(type="pil", label="Generated Avatar", height=300)
 
     generate_button.click(process_and_generate, inputs=[selected_image, avatar_style], outputs=generated_image)
 
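For reference, the image-generation path after this commit is just an HTTP POST to the Hugging Face serverless Inference API followed by decoding the returned bytes with PIL; the local StableDiffusionPipeline is dropped entirely. Below is a minimal standalone sketch of that pattern, under some assumptions: the text_to_image helper name and the raise_for_status() check are illustrative additions, while API_URL and the api_token environment variable are taken from the commit.

import io
import os

import requests
from PIL import Image

API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
HEADERS = {"Authorization": f"Bearer {os.getenv('api_token')}"}  # same env var the Space reads

def text_to_image(prompt, seed=0):
    # POST the prompt; on success the endpoint responds with raw image bytes.
    response = requests.post(
        API_URL,
        headers=HEADERS,
        json={"inputs": prompt, "parameters": {"seed": seed}},
    )
    # Not in the commit: fail loudly on 4xx/5xx (e.g. while the model is still loading)
    # instead of handing an error payload to PIL.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))

if __name__ == "__main__":
    text_to_image("Create a pigeon profile avatar. Use Pixel Art style.", seed=42).save("avatar.png")

Passing a different seed on each request, as generate_image_by_description does with random.randint, is what keeps the NUM_IMAGES = 2 calls from returning near-identical avatars.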
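The new import dom refers to a helper module that ships with the Space and is not part of this diff; the commit only reveals that it exposes generate_title, passed as the js= argument of gr.Blocks (so a JavaScript snippet run when the page loads), and generate_markdown, rendered with gr.Markdown. A purely hypothetical placeholder with that shape, for anyone reproducing app.py without the rest of the repository (the actual title and markdown text will differ):

# dom.py -- hypothetical stand-in; the real module is not shown in this commit.

# JavaScript executed once on page load via gr.Blocks(js=...).
generate_title = """
() => { document.title = "Pigeon Avatar Generator"; }
"""

# Markdown rendered at the top of the demo via gr.Markdown(...).
generate_markdown = """
# Pigeon Avatar Generator
Upload a pigeon photo to generate avatars in different styles.
"""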