Spaces: MohamedRashad (Running on Zero)

Commit f7f1ca1 · 1 Parent(s): 041b736
MohamedRashad committed

Refactor app.py to improve UI layout and rename weight download function; update import path for AutoEncoder in vae.py

Files changed:
- app.py: +29 -12
- models/bsq_vae/vae.py: +1 -1
app.py CHANGED

@@ -16,13 +16,12 @@ from typing import List
 
 import cv2
 import numpy as np
-import pandas as pd
 import torch
 import torch.nn.functional as F
 from PIL import Image, ImageEnhance
 import PIL.Image as PImage
 from torchvision.transforms.functional import to_tensor
-from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast
 from huggingface_hub import hf_hub_download
 import gradio as gr
 import spaces

@@ -34,7 +33,7 @@ from utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
 torch._dynamo.config.cache_size_limit = 64
 
 # Define a function to download weights if not present
-def … (old function name not captured in this view)
+def download_infinity_weights(weights_path):
     try:
         model_file = weights_path / 'infinity_2b_reg.pth'
         if not model_file.exists():

@@ -361,7 +360,7 @@ def load_transformer(vae, args):
 # Set up paths
 weights_path = Path(__file__).parent / 'weights'
 weights_path.mkdir(exist_ok=True)
-… (old line 364 not captured in this view)
+download_infinity_weights(weights_path)
 
 # Device setup
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@@ -451,16 +450,34 @@ with gr.Blocks() as demo:
     gr.Markdown("<h1><center>Infinity Image Generator</center></h1>")
 
     with gr.Row():
-… (old lines 454-459, previous layout, not captured in this view)
+        with gr.Column():
+            # Prompt Settings
+            gr.Markdown("### Prompt Settings")
+            prompt = gr.Textbox(label="Prompt", value="alien spaceship enterprise", placeholder="Enter your prompt here...")
+            enable_positive_prompt = gr.Checkbox(label="Enable Positive Prompt", value=False, info="Enhance prompts with positive attributes for faces.")
+
+            # Image Settings
+            gr.Markdown("### Image Settings")
+            with gr.Row():
+                cfg = gr.Slider(label="CFG (Classifier-Free Guidance)", minimum=1, maximum=10, step=0.5, value=3, info="Controls the strength of the prompt.")
+                tau = gr.Slider(label="Tau (Temperature)", minimum=0.1, maximum=1.0, step=0.1, value=0.5, info="Controls the randomness of the output.")
+            with gr.Row():
+                h_div_w = gr.Slider(label="Aspect Ratio (Height/Width)", minimum=0.5, maximum=2.0, step=0.1, value=1.0, info="Set the aspect ratio of the generated image.")
+                seed = gr.Number(label="Seed", value=random.randint(0, 10000), info="Set a seed for reproducibility.")
+
+            # Generate Button
+            generate_button = gr.Button("Generate Image", variant="primary")
 
-… (old lines 461-462 not captured in this view)
+        with gr.Column():
+            # Output Section
+            gr.Markdown("### Generated Image")
+            output_image = gr.Image(label="Generated Image", type="pil")
+            gr.Markdown("**Tip:** Right-click the image to save it.")
 
+    # Error Handling
+    error_message = gr.Textbox(label="Error Message", visible=False)
+
+    # Link the generate button to the image generation function
     generate_button.click(
         generate_image,
         inputs=[prompt, cfg, tau, h_div_w, seed, enable_positive_prompt],
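Note on the renamed helper: the diff above captures only the first lines of download_infinity_weights. Below is a minimal sketch of what such a helper could look like; the repo_id is a placeholder assumption (only the filename 'infinity_2b_reg.pth' and the hf_hub_download import appear in this commit), and the try block mirrors the lines visible in the diff.

from pathlib import Path
from huggingface_hub import hf_hub_download

def download_infinity_weights(weights_path: Path) -> None:
    # Sketch: download the checkpoint into ./weights only if it is missing.
    try:
        model_file = weights_path / 'infinity_2b_reg.pth'
        if not model_file.exists():
            hf_hub_download(
                repo_id='your-org/infinity-2b',  # placeholder repo id, not from this commit
                filename='infinity_2b_reg.pth',
                local_dir=str(weights_path),
            )
    except Exception as e:
        print(f'Failed to download weights: {e}')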
models/bsq_vae/vae.py CHANGED

@@ -1,7 +1,7 @@
 import argparse
 import torch
 
-from … (old import path not captured in this view)
+from models.bsq_vae.flux_vqgan import AutoEncoder
 
 def load_cnn(model, state_dict, prefix, expand=False, use_linear=False):
     delete_keys = []
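With the updated absolute import, vae.py resolves AutoEncoder through the repository's package path (models.bsq_vae.flux_vqgan), so code launched from the repo root, as the Space launches app.py, can import it without relative-path tricks. A small illustrative sketch follows; the sys.path handling is only needed when running from outside the repo root and is an assumption, not part of this commit.

import sys
from pathlib import Path

# Illustrative: make the repository root importable; app.py at the root
# already satisfies this, so the commit's import works there as-is.
repo_root = Path('/path/to/repo')  # placeholder path
sys.path.insert(0, str(repo_root))

from models.bsq_vae.flux_vqgan import AutoEncoder  # the path this commit switches to
from models.bsq_vae.vae import load_cnn            # helper defined in vae.py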