MohamedRashad commited on
Commit
f7f1ca1
·
1 Parent(s): 041b736

Refactor app.py to improve UI layout and rename weight download function; update import path for AutoEncoder in vae.py

Browse files
Files changed (2) hide show
  1. app.py +29 -12
  2. models/bsq_vae/vae.py +1 -1
app.py CHANGED
@@ -16,13 +16,12 @@ from typing import List
16
 
17
  import cv2
18
  import numpy as np
19
- import pandas as pd
20
  import torch
21
  import torch.nn.functional as F
22
  from PIL import Image, ImageEnhance
23
  import PIL.Image as PImage
24
  from torchvision.transforms.functional import to_tensor
25
- from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast, T5Tokenizer, T5ForConditionalGeneration
26
  from huggingface_hub import hf_hub_download
27
  import gradio as gr
28
  import spaces
@@ -34,7 +33,7 @@ from utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
34
  torch._dynamo.config.cache_size_limit = 64
35
 
36
  # Define a function to download weights if not present
37
- def download_weights(weights_path):
38
  try:
39
  model_file = weights_path / 'infinity_2b_reg.pth'
40
  if not model_file.exists():
@@ -361,7 +360,7 @@ def load_transformer(vae, args):
361
  # Set up paths
362
  weights_path = Path(__file__).parent / 'weights'
363
  weights_path.mkdir(exist_ok=True)
364
- download_weights(weights_path)
365
 
366
  # Device setup
367
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -451,16 +450,34 @@ with gr.Blocks() as demo:
451
  gr.Markdown("<h1><center>Infinity Image Generator</center></h1>")
452
 
453
  with gr.Row():
454
- prompt = gr.Textbox(label="Prompt", value="alien spaceship enterprise")
455
- cfg = gr.Slider(label="CFG", minimum=1, maximum=10, step=0.5, value=3)
456
- tau = gr.Slider(label="Tau", minimum=0.1, maximum=1.0, step=0.1, value=0.5)
457
- h_div_w = gr.Slider(label="Aspect Ratio (Height/Width)", minimum=0.5, maximum=2.0, step=0.1, value=1.0)
458
- seed = gr.Number(label="Seed", value=random.randint(0, 10000))
459
- enable_positive_prompt = gr.Checkbox(label="Enable Positive Prompt", value=False)
 
 
 
 
 
 
 
 
 
 
 
460
 
461
- generate_button = gr.Button("Generate Image")
462
- output_image = gr.Image(label="Generated Image", type="pil")
 
 
 
463
 
 
 
 
 
464
  generate_button.click(
465
  generate_image,
466
  inputs=[prompt, cfg, tau, h_div_w, seed, enable_positive_prompt],
 
16
 
17
  import cv2
18
  import numpy as np
 
19
  import torch
20
  import torch.nn.functional as F
21
  from PIL import Image, ImageEnhance
22
  import PIL.Image as PImage
23
  from torchvision.transforms.functional import to_tensor
24
+ from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast
25
  from huggingface_hub import hf_hub_download
26
  import gradio as gr
27
  import spaces
 
33
  torch._dynamo.config.cache_size_limit = 64
34
 
35
  # Define a function to download weights if not present
36
+ def download_infinity_weights(weights_path):
37
  try:
38
  model_file = weights_path / 'infinity_2b_reg.pth'
39
  if not model_file.exists():
 
360
  # Set up paths
361
  weights_path = Path(__file__).parent / 'weights'
362
  weights_path.mkdir(exist_ok=True)
363
+ download_infinity_weights(weights_path)
364
 
365
  # Device setup
366
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
450
  gr.Markdown("<h1><center>Infinity Image Generator</center></h1>")
451
 
452
  with gr.Row():
453
+ with gr.Column():
454
+ # Prompt Settings
455
+ gr.Markdown("### Prompt Settings")
456
+ prompt = gr.Textbox(label="Prompt", value="alien spaceship enterprise", placeholder="Enter your prompt here...")
457
+ enable_positive_prompt = gr.Checkbox(label="Enable Positive Prompt", value=False, info="Enhance prompts with positive attributes for faces.")
458
+
459
+ # Image Settings
460
+ gr.Markdown("### Image Settings")
461
+ with gr.Row():
462
+ cfg = gr.Slider(label="CFG (Classifier-Free Guidance)", minimum=1, maximum=10, step=0.5, value=3, info="Controls the strength of the prompt.")
463
+ tau = gr.Slider(label="Tau (Temperature)", minimum=0.1, maximum=1.0, step=0.1, value=0.5, info="Controls the randomness of the output.")
464
+ with gr.Row():
465
+ h_div_w = gr.Slider(label="Aspect Ratio (Height/Width)", minimum=0.5, maximum=2.0, step=0.1, value=1.0, info="Set the aspect ratio of the generated image.")
466
+ seed = gr.Number(label="Seed", value=random.randint(0, 10000), info="Set a seed for reproducibility.")
467
+
468
+ # Generate Button
469
+ generate_button = gr.Button("Generate Image", variant="primary")
470
 
471
+ with gr.Column():
472
+ # Output Section
473
+ gr.Markdown("### Generated Image")
474
+ output_image = gr.Image(label="Generated Image", type="pil")
475
+ gr.Markdown("**Tip:** Right-click the image to save it.")
476
 
477
+ # Error Handling
478
+ error_message = gr.Textbox(label="Error Message", visible=False)
479
+
480
+ # Link the generate button to the image generation function
481
  generate_button.click(
482
  generate_image,
483
  inputs=[prompt, cfg, tau, h_div_w, seed, enable_positive_prompt],
models/bsq_vae/vae.py CHANGED
@@ -1,7 +1,7 @@
1
  import argparse
2
  import torch
3
 
4
- from infinity.models.bsq_vae.flux_vqgan import AutoEncoder
5
 
6
  def load_cnn(model, state_dict, prefix, expand=False, use_linear=False):
7
  delete_keys = []
 
1
  import argparse
2
  import torch
3
 
4
+ from models.bsq_vae.flux_vqgan import AutoEncoder
5
 
6
  def load_cnn(model, state_dict, prefix, expand=False, use_linear=False):
7
  delete_keys = []