File size: 3,986 Bytes
4a0cc75
0ce8592
4a0cc75
0ce8592
8e4035c
0ce8592
 
 
 
0a5fba5
0ce8592
 
8e4035c
d5e510d
 
8e4035c
 
d5e510d
8e4035c
d5e510d
 
0ce8592
 
8e4035c
 
4a0cc75
0a5fba5
 
 
 
 
 
 
 
 
 
 
 
 
 
4a0cc75
 
d5e510d
4a0cc75
d5e510d
 
 
 
4a0cc75
0a5fba5
d5e510d
 
 
 
 
8e4035c
0ce8592
 
 
 
0a5fba5
d5e510d
0ce8592
 
 
 
 
4a0cc75
0ce8592
 
 
 
88a3ed8
8e4035c
 
 
 
0a5fba5
ff6b31d
 
d5e510d
0ce8592
8e4035c
d5e510d
0a5fba5
 
d5e510d
ff6b31d
 
0ce8592
d5e510d
 
0a5fba5
 
 
 
 
 
 
 
d5e510d
4a0cc75
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import functools
import io
import os
import random
import time

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

import dom

# Number of avatar candidates generated per request.
NUM_IMAGES = 2

# Device selection: prefer CUDA, then Apple MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Model / API configuration.
# Image generation goes through the hosted HF inference API (FLUX.1-dev);
# the token is read from the `api_token` environment variable.
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
headers = {"Authorization": f"Bearer {os.getenv('api_token')}"}
# Local image-description model (moondream2), pinned to a fixed revision.
model_id_image_description = "vikhyatk/moondream2"
revision = "2024-08-26"

# To measure model performance, this decorator prints to the terminal the
# execution time of the methods that use the models, so we can study how long
# each model stays active.
def measure_performance(func):
    """Wrap *func* so each call logs its arguments and wall-clock duration.

    Returns the wrapped function's result unchanged; only adds the two
    print statements around the call.
    """
    @functools.wraps(func)  # preserve __name__/__doc__ so logs and introspection stay correct
    def wrapper(*args, **kwargs):
        print(f"Starting execution of '{func.__name__}' with 'args={args}, kwargs={kwargs}'")
        # perf_counter is monotonic — immune to system clock adjustments,
        # unlike time.time(), so the duration can never come out negative.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - start
        print(f"Execution time of '{func.__name__}' with 'args={args}, kwargs={kwargs}': {duration:.4f} seconds")
        return result
    return wrapper

# Pick a compute dtype: bfloat16 on GPU (memory/throughput win), float32 elsewhere.
# NOTE(review): torch_dtype is computed here but never passed to from_pretrained
# below — confirm whether it should be forwarded (torch_dtype kwarg) or removed.
torch_dtype = torch.float32
if torch.cuda.is_available():
    torch_dtype = torch.bfloat16  # GPU optimization

# Persistent model load: the description model and its tokenizer are loaded
# once at import time (downloads on first run), so every request reuses them.
print("Cargando modelo de descripción de imágenes...")
model_description = AutoModelForCausalLM.from_pretrained(model_id_image_description, trust_remote_code=True, revision=revision)
tokenizer_description = AutoTokenizer.from_pretrained(model_id_image_description, revision=revision)

@measure_performance
def generate_description(image_path):
    """Describe the image at *image_path* using the moondream2 model.

    Opens the file, encodes it, and asks the model for an avatar-oriented
    description. Returns the description string.
    """
    # Context manager closes the file handle (original leaked it).
    with Image.open(image_path) as image_test:
        enc_image = model_description.encode_image(image_test)
        description = model_description.answer_question(enc_image, "Describe this image to create an avatar", tokenizer_description)
    return description

def query(payload):
    """POST *payload* to the HF inference API and return the raw response bytes.

    Raises requests.HTTPError on a non-2xx status so callers fail fast with a
    clear error instead of handing error-JSON bytes to Image.open, and
    requests.Timeout if the API takes longer than 120s (image generation is
    slow, but the original had no timeout at all and could hang forever).
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=120)
    response.raise_for_status()
    return response.content

@measure_performance
def generate_image_by_description(description, avatar_style=None):
    """Generate NUM_IMAGES avatar candidates for *description* via the API.

    A random seed per request makes the candidates differ. Returns a list of
    PIL.Image objects.
    """
    # The prompt is loop-invariant — build it once instead of per iteration.
    prompt = f"Create a pigeon profile avatar. Use the following description: {description}."
    if avatar_style:
        prompt += f" Use {avatar_style} style."

    images = []
    for _ in range(NUM_IMAGES):
        image_bytes = query({"inputs": prompt, "parameters": {"seed": random.randint(0, 1000)}})
        images.append(Image.open(io.BytesIO(image_bytes)))
    return images

def process_and_generate(image, avatar_style):
    """Full pipeline: describe the uploaded image, then render avatars from it."""
    return generate_image_by_description(generate_description(image), avatar_style)

# Build the Gradio UI: left column takes the pigeon photo and style choice,
# right column shows the generated avatars; extra tabs hold prose from `dom`.
with gr.Blocks(js=dom.generate_title) as demo:        
    with gr.Row():
        with gr.Column(scale=2, min_width=300):
            # Input side: uploaded image (passed to the pipeline as a file path),
            # a clickable example, and an optional rendering style.
            selected_image = gr.Image(type="filepath", label="Upload an Image of the Pigeon", height=300)
            example_image = gr.Examples(["./examples/pigeon.webp"], label="Example Images", inputs=[selected_image])
            avatar_style = gr.Radio(
                ["Realistic", "Pixel Art", "Imaginative", "Cartoon"], 
                label="(optional) Select the avatar style:",
                value="Pixel Art"
            )
            generate_button = gr.Button("Generate Avatar", variant="primary")
        with gr.Column(scale=2, min_width=300):
            # Output side: gallery holding the generated avatar candidates.
            generated_image = gr.Gallery(type="pil", label="Generated Avatar", height=300)
            
        # Wire the button to the describe-then-generate pipeline.
        generate_button.click(process_and_generate, inputs=[selected_image, avatar_style], outputs=generated_image)
    with gr.Tab(label="Description"):
        gr.Markdown(dom.generate_markdown)
        gr.Markdown(dom.models)
    with gr.Tab(label="Documentation"):
        # NOTE(review): "doccumentation" is misspelled but presumably matches the
        # attribute name defined in dom.py — confirm there before renaming either side.
        gr.Markdown(dom.doccumentation)

        
        

# Start the Gradio server (blocking call).
demo.launch()