File size: 14,343 Bytes
6257836
dfa7d4d
be0cee3
 
 
dfa7d4d
c5f8f57
 
be0cee3
dfa7d4d
be0cee3
c5f8f57
be0cee3
dfa7d4d
 
 
 
488578d
dfa7d4d
be0cee3
 
dfa7d4d
be0cee3
 
dfa7d4d
 
be0cee3
c5f8f57
be0cee3
 
 
dfa7d4d
 
28eef4f
be0cee3
 
28eef4f
 
 
dfa7d4d
c5f8f57
28eef4f
dfa7d4d
 
be0cee3
 
 
28eef4f
 
 
be0cee3
 
a03842d
28eef4f
072762b
c5f8f57
28eef4f
be0cee3
 
22fbbe5
072762b
 
28eef4f
072762b
be0cee3
28eef4f
be0cee3
28eef4f
 
 
 
c5f8f57
be0cee3
c5f8f57
be0cee3
c5f8f57
be0cee3
 
 
 
 
 
 
 
 
 
488578d
be0cee3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28eef4f
dfa7d4d
be0cee3
28eef4f
be0cee3
 
 
 
 
 
 
 
63e5350
be0cee3
dfa7d4d
be0cee3
dfa7d4d
63e5350
be0cee3
dfa7d4d
 
c5f8f57
 
 
072762b
be0cee3
dfa7d4d
28eef4f
dfa7d4d
be0cee3
28eef4f
be0cee3
c5f8f57
be0cee3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28eef4f
be0cee3
 
28eef4f
be0cee3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5f8f57
 
be0cee3
dfa7d4d
 
be0cee3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75e5c38
 
be0cee3
 
75e5c38
 
 
be0cee3
 
ed9ab57
 
be0cee3
ed9ab57
6f8f97c
 
be0cee3
28eef4f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
import os
import sys
import uuid
from typing import List, Dict, Any, Optional, Tuple

import torch
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf, DictConfig

# --- 1. CONFIGURAÇÃO E IMPORTS ---

# Adiciona o diretório do VINCIE ao path do Python para permitir a importação de seus módulos.
VINCIE_DIR = "/app/VINCIE"
if VINCIE_DIR not in sys.path:
    sys.path.append(VINCIE_DIR)

try:
    from generate import VINCIEGenerator
    from common.config import load_config
    from common.seed import shift_seed
except ImportError as e:
    print(f"FATAL: Não foi possível importar os módulos do VINCIE. "
          f"Verifique se o repositório está em '{VINCIE_DIR}'.")
    raise e

# --- 2. INICIALIZAÇÃO DO MODELO (SINGLETON) ---

# Variáveis globais para armazenar a instância única do modelo e o dispositivo.
MODEL: Optional[VINCIEGenerator] = None
DEVICE: Optional[torch.device] = None

def setup_model():
    """
    Inicializa e configura o modelo VINCIE.
    Esta função é chamada uma vez no início da aplicação para carregar o modelo na GPU.
    """
    global MODEL, DEVICE

    if not torch.cuda.is_available():
        raise RuntimeError("FATAL: Nenhuma GPU compatível com CUDA foi encontrada.")

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError("FATAL: Nenhuma GPU foi detectada pelo PyTorch.")
    
    print(f"INFO: Detectadas {num_gpus} GPUs. A aplicação usará 'cuda:0'.")
    DEVICE = torch.device("cuda:0")
    torch.cuda.set_device(DEVICE)

    config_path = os.path.join(VINCIE_DIR, "configs/generate.yaml")
    print(f"INFO: Carregando e resolvendo configuração de '{config_path}'...")
    config = load_config(config_path, [])
    
    print("INFO: Instanciando VINCIEGenerator...")
    model_instance = VINCIEGenerator(config)
    
    # Executa a sequência de inicialização necessária do VINCIE
    print("INFO: Configurando persistência, modelos e difusão...")
    model_instance.configure_persistence()
    model_instance.configure_models()
    model_instance.configure_diffusion()
    
    if not hasattr(model_instance, 'dit'):
        raise RuntimeError("FATAL: Falha ao inicializar o componente DiT do modelo.")

    # Move todos os componentes para o dispositivo principal
    model_instance.dit.to(DEVICE)
    model_instance.vae.to(DEVICE)
    model_instance.text_encoder.to(DEVICE)
    
    MODEL = model_instance
    print(f"✅ SUCESSO: Modelo VINCIE pronto para uso na GPU {DEVICE}.")

# --- 3. LÓGICAS DE INFERÊNCIA ---

def _execute_vincie_logic(
    prompt_config: DictConfig,
    steps: int,
    cfg_scale: float,
    seed: int,
    pad_img_placeholder: bool,
    resolution: int
) -> Image.Image:
    """
    Função central e reutilizável que executa a inferência do VINCIE.

    Args:
        prompt_config (DictConfig): Configuração do prompt para o modelo.
        steps (int): Número de passos de difusão.
        cfg_scale (float): Escala de orientação (Classifier-Free Guidance).
        seed (int): Semente para reprodutibilidade.
        pad_img_placeholder (bool): Se deve formatar o prompt com placeholders <IMG>.
        resolution (int): Resolução do lado menor da imagem para processamento.

    Returns:
        Image.Image: A imagem gerada.
    """
    # Salva o estado original da configuração para restaurá-lo depois
    original_config_state = {
        "steps": MODEL.config.diffusion.timesteps.sampling.steps,
        "seed": MODEL.config.generation.seed,
        "pad": MODEL.config.generation.pad_img_placehoder,
        "resolution": MODEL.config.generation.resolution,
    }
    
    try:
        OmegaConf.set_readonly(MODEL.config, False)
        
        # 1. Aplica configurações dinâmicas
        MODEL.config.diffusion.timesteps.sampling.steps = int(steps)
        MODEL.configure_diffusion()  # Recria o sampler com os novos passos
        
        current_seed = seed if seed != -1 else torch.randint(0, 2**32 - 1, (1,)).item()
        MODEL.config.generation.seed = shift_seed(current_seed, 0)
        MODEL.config.generation.pad_img_placehoder = pad_img_placeholder
        MODEL.config.generation.resolution = int(resolution)

        # 2. Prepara as entradas
        text_pos, condition, noise, _, _ = MODEL.prepare_input(
            prompt=prompt_config, repeat_idx=0, device=DEVICE
        )
        
        # 3. Executa a inferência
        with torch.no_grad():
            samples = MODEL.inference(
                noises=[noise],
                conditions=[condition],
                texts_pos=[text_pos],
                texts_neg=[MODEL.config.generation.negative_prompt],
                cfg_scale=cfg_scale
            )
            
        if not samples:
            raise RuntimeError("A inferência do modelo não produziu resultados.")
            
        # 4. Processa a saída para formato de imagem
        output_tensor = samples[0][:, -1, :, :]
        output_image_np = output_tensor.clip(-1, 1).add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy()
        return Image.fromarray(output_image_np)

    finally:
        # 5. Restaura a configuração original para garantir consistência entre chamadas
        OmegaConf.set_readonly(MODEL.config, False)
        MODEL.config.diffusion.timesteps.sampling.steps = original_config_state["steps"]
        MODEL.config.generation.seed = original_config_state["seed"]
        MODEL.config.generation.pad_img_placehoder = original_config_state["pad"]
        MODEL.config.generation.resolution = original_config_state["resolution"]
        OmegaConf.set_readonly(MODEL.config, True)
        MODEL.configure_diffusion() # Restaura o sampler padrão


def run_single_turn_inference(
    input_image: str, prompt: str, aspect_ratio: str, resolution: int, steps: int, cfg_scale: float, seed: int
) -> Image.Image:
    """Handler para a aba 'Edição Simples'."""
    if not all([input_image, prompt]):
        raise gr.Error("É necessário fornecer uma imagem de entrada e um prompt.")
    
    _print_params("Edição Simples", prompt=prompt, aspect_ratio=aspect_ratio, resolution=resolution, steps=steps, cfg_scale=cfg_scale, seed=seed)
    
    prompt_config = OmegaConf.create({
        "index": 0, "img_paths": [input_image], "context": [prompt], "aspect_ratio": aspect_ratio
    })
    
    return _execute_vincie_logic(prompt_config, steps, cfg_scale, seed, pad_img_placeholder=True, resolution=resolution)


def run_multi_turn_inference(
    input_image: str, prompts_text: str, steps: int, cfg_scale: float, seed: int, progress=gr.Progress()
) -> List[Image.Image]:
    """Handler para a aba 'Edição em Múltiplos Turnos'."""
    if not all([input_image, prompts_text]):
        raise gr.Error("É necessário fornecer uma imagem de entrada e pelo menos um prompt.")
    
    prompts = [p.strip() for p in prompts_text.splitlines() if p.strip()]
    if not prompts:
        raise gr.Error("Nenhum prompt válido fornecido.")
        
    _print_params("Edição em Múltiplos Turnos", prompts=prompts, steps=steps, cfg_scale=cfg_scale, seed=seed)
    
    output_images_with_paths = []
    
    for i, prompt in enumerate(progress.tqdm(prompts, desc="Processando turnos")):
        print(f"--- Turno {i+1}/{len(prompts)}: {prompt} ---")
        
        image_paths = [input_image] + [path for _, path in output_images_with_paths]
        context_prompts = prompts[:i+1]
        
        prompt_config = OmegaConf.create({
            "index": i, "img_paths": image_paths, "context": context_prompts, "aspect_ratio": "keep_ratio"
        })
        
        turn_seed = seed if seed == -1 else seed + i
        result_image = _execute_vincie_logic(prompt_config, steps, cfg_scale, turn_seed, pad_img_placeholder=True, resolution=512)
        
        temp_path = os.path.join("/tmp", f"{uuid.uuid4()}.png")
        result_image.save(temp_path)
        output_images_with_paths.append((result_image, temp_path))

    return [img for img, _ in output_images_with_paths]


def run_multi_concept_inference(prompt: str, *images: str) -> Image.Image:
    """Handler para a aba 'Composição de Conceitos'."""
    image_paths = [img for img in images if img is not None]
    if not image_paths or not prompt.strip():
        raise gr.Error("É necessário um prompt e pelo menos uma imagem de entrada.")
        
    _print_params("Composição de Conceitos", prompt=prompt, num_images=len(image_paths))

    prefix_prompts = [f"<IMG{i+1}>: " for i in range(len(image_paths) - 1)]
    all_prompts = prefix_prompts + [prompt]

    prompt_config = OmegaConf.create({
        "index": 0, "img_paths": image_paths, "context": all_prompts, "aspect_ratio": "1:1"
    })
    
    return _execute_vincie_logic(prompt_config, steps=50, cfg_scale=7.5, seed=1, pad_img_placeholder=False, resolution=512)


def _print_params(mode: str, **kwargs: Any):
    """Função auxiliar para logging formatado das requisições."""
    log_message = f"\n{'='*50}\nINFO: Nova requisição - Modo: {mode}\n"
    for key, value in kwargs.items():
        log_message += f"  - {key.replace('_', ' ').title()}: {value}\n"
    log_message += f"{'='*50}\n"
    print(log_message)


# --- 4. CONSTRUÇÃO DA INTERFACE GRadio ---

def create_ui():
    """Cria e retorna a interface Gradio completa com todas as abas e controles."""
    with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="VINCIE Playground") as demo:
        gr.Markdown("# 🖼️ **VINCIE Playground**\nExplore as diferentes capacidades do modelo VINCIE.")

        with gr.Accordion("Opções Avançadas (para Abas 1 e 2)", open=False):
            steps_input = gr.Slider(label="Passos de Inferência", minimum=10, maximum=100, step=1, value=50)
            cfg_scale_input = gr.Slider(label="Escala de Orientação (CFG)", minimum=1.0, maximum=15.0, step=0.5, value=7.5)
            seed_input = gr.Number(label="Semente (Seed)", value=-1, precision=0, info="Use -1 para aleatório.")

        with gr.Tabs():
            # Aba 1: Edição Simples
            with gr.TabItem("Edição Simples"):
                with gr.Row():
                    with gr.Column(scale=1):
                        single_turn_img_in = gr.Image(type="filepath", label="Imagem de Entrada")
                        single_turn_prompt = gr.Textbox(lines=2, label="Prompt de Edição")
                        with gr.Accordion("Opções de Imagem", open=True):
                             aspect_ratio_input = gr.Dropdown(label="Aspect Ratio", choices=["keep_ratio", "1:1", "16:9", "9:16", "4:3", "3:4"], value="keep_ratio")
                             resolution_input = gr.Slider(label="Resolução (lado menor)", minimum=256, maximum=1024, step=64, value=512)
                        single_turn_button = gr.Button("Gerar", variant="primary")
                    with gr.Column(scale=1):
                        single_turn_img_out = gr.Image(label="Resultado", interactive=False, height=512)
                gr.Examples([["/app/VINCIE/assets/woman_pineapple.png", "Adicione uma coroa na cabeça da mulher."]], [single_turn_img_in, single_turn_prompt])

            # Aba 2: Edição em Múltiplos Turnos
            with gr.TabItem("Edição em Múltiplos Turnos"):
                with gr.Row():
                    with gr.Column(scale=1):
                        multi_turn_img_in = gr.Image(type="filepath", label="Imagem de Entrada")
                        multi_turn_prompts = gr.Textbox(lines=5, label="Prompts (um por linha)", placeholder="Turno 1: faça isso\nTurno 2: agora mude aquilo...")
                        multi_turn_button = gr.Button("Gerar Sequência", variant="primary")
                    with gr.Column(scale=1):
                        multi_turn_gallery_out = gr.Gallery(label="Resultados dos Turnos", columns=3, height="auto")

            # Aba 3: Composição de Conceitos
            with gr.TabItem("Composição de Conceitos"):
                gr.Markdown("Faça o upload de até 6 imagens (`<IMG0>` a `<IMG5>`) e escreva um prompt que as combine para gerar uma nova imagem (`<IMG6>`).")
                with gr.Row():
                    concept_inputs = [gr.Image(type="filepath", label=f"Imagem {i} (<IMG{i}>)") for i in range(6)]
                concept_prompt = gr.Textbox(lines=4, label="Prompt de Composição Final", value="Baseado em <IMG0>, <IMG1> e <IMG2>, crie um retrato de uma família com o pai de <IMG0>, a mãe de <IMG1> e o cachorro de <IMG2> em um parque. Saída <IMG6>:")
                concept_button = gr.Button("Compor Imagem", variant="primary")
                concept_img_out = gr.Image(label="Resultado da Composição", interactive=False, height=512)
                gr.Examples(
                    [[
                        "Baseado em <IMG0>, <IMG1> e <IMG2>, crie um retrato de uma família com o pai de <IMG0>, a mãe de <IMG1> e o cachorro de <IMG2> em um parque. Saída <IMG6>:",
                        "/app/VINCIE/assets/father.png", "/app/VINCIE/assets/mother.png", "/app/VINCIE/assets/dog1.png"
                    ]],
                    [concept_prompt] + concept_inputs
                )

        # Conecta os botões às suas respectivas funções de backend
        single_turn_button.click(fn=run_single_turn_inference, inputs=[single_turn_img_in, single_turn_prompt, aspect_ratio_input, resolution_input, steps_input, cfg_scale_input, seed_input], outputs=[single_turn_img_out])
        multi_turn_button.click(fn=run_multi_turn_inference, inputs=[multi_turn_img_in, multi_turn_prompts, steps_input, cfg_scale_input, seed_input], outputs=[multi_turn_gallery_out])
        concept_button.click(fn=run_multi_concept_inference, inputs=[concept_prompt] + concept_inputs, outputs=[concept_img_out])
        
    return demo

# --- 5. PONTO DE ENTRADA DA APLICAÇÃO ---

if __name__ == "__main__":
    setup_model()
    ui = create_ui()
    
    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    server_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    enable_queue = os.environ.get("GRADIO_ENABLE_QUEUE", "True").lower() == "true"

    print(f"INFO: Lançando a interface Gradio em http://{server_name}:{server_port}")
    if enable_queue:
        ui.queue()

    ui.launch(server_name=server_name, server_port=server_port)