import spaces
import gradio as gr
import torch
torch.jit.script = lambda f: f  # No-op torch.jit.script to dodge a scripting error on a lambda; patch it before t2v_metrics is imported
from t2v_metrics import VQAScore
from functools import lru_cache

# No global model loading or CUDA initialization here (e.g. no
# torch.cuda.is_available() at module scope): on ZeroGPU, the CUDA device is
# only available inside @spaces.GPU-decorated functions.
@lru_cache()
def get_model(model_name):
    # Cache the model per process so it is only loaded once
    return VQAScore(model=model_name, device="cuda")
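
# Because get_model is lru_cache'd per process, each model is loaded at most
# once and repeat calls return the same object; for example:
#
#   m1 = get_model("clip-flant5-xl")
#   m2 = get_model("clip-flant5-xl")
#   assert m1 is m2  # cached: no second load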

@spaces.GPU  # Run this function with a GPU attached (ZeroGPU)
def generate(model_name, image, text):
    # Load the model inside the GPU context
    model_pipe = get_model(model_name)
    print("Image:", image)
    print("Text:", text)
    print("Using model:", model_name)
    try:
        # Perform the model inference
        result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item()
        print("Result:", result)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise
    return result
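
# Hypothetical smoke test (not part of the app): generate() only receives a
# GPU when called through the Space's queue, but with a local CUDA device you
# could sanity-check it roughly like this, reusing a bundled example image:
#
#   score = generate("clip-flant5-xl", "0_imgs/SDXL.jpg", "two dogs playing")
#   print(f"VQAScore: {score:.3f}")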

@spaces.GPU  # Run this function with a GPU attached (ZeroGPU)
def rank_images(model_name, images, text):
    # Load the model inside the GPU context
    model_pipe = get_model(model_name)
    # The gallery passes (filepath, caption) tuples; keep only the filepaths
    images = [image_tuple[0] for image_tuple in images]
    print("Images:", images)
    print("Text:", text)
    print("Using model:", model_name)
    try:
        # Score all images against the prompt in one batch
        results = model_pipe(images=images, texts=[text]).cpu()[:, 0].tolist()
        print("Initial results:", results)
        # Rank images by score, highest first
        ranked_results = sorted(zip(images, results), key=lambda x: x[1], reverse=True)
        # Pair each image with its rank and score for the output gallery
        ranked_images = [
            (img, f"Rank: {rank + 1} - Score: {score:.2f}")
            for rank, (img, score) in enumerate(ranked_results)
        ]
        print("Ranked Results:", ranked_results)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise
    return ranked_images
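
# Shape of a direct call (assumed from the tuple unpacking above): the gallery
# supplies (filepath, caption) pairs, so rank_images could be exercised with:
#
#   ranked = rank_images(
#       "clip-flant5-xxl",
#       [("0_imgs/DALLE3.png", None), ("0_imgs/SDXL.jpg", None)],
#       "Two dogs of different breeds playfully chasing around a tree",
#   )
#   # -> [(filepath, "Rank: 1 - Score: ..."), ...] sorted best-first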
### EXAMPLES ###
example_imgs = ["0_imgs/DALLE3.png",
                "0_imgs/DeepFloyd.jpg",
                "0_imgs/Midjourney.jpg",
                "0_imgs/SDXL.jpg"]
example_prompt0 = "Two dogs of different breeds playfully chasing around a tree"
example_prompt1 = "Two dogs of the same breed playing on the grass"

# Callback that loads a canned example into the input widgets; Gradio maps the
# returned tuple positionally onto the listed outputs
def load_example(model_name, images, prompt):
    return model_name, images, prompt

# Create the VQAScore Ranking demo
with gr.Blocks() as demo_vqascore_ranking:
    gr.Markdown("""
# VQAScore Ranking
This demo ranks a gallery of images by their VQAScore against an input text prompt. Try Examples 1 and 2, or use your own images and prompts.
If you encounter errors, the model may not have loaded onto the GPU properly; retrying usually resolves the issue.
""")
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(["clip-flant5-xxl", "clip-flant5-xl"], value="clip-flant5-xxl", label="Model Name")
            prompt = gr.Textbox(label="Prompt")
            gallery = gr.Gallery(label="Input Image(s)", elem_id="input-gallery", columns=4, allow_preview=True)
            rank_button = gr.Button("Submit")
        with gr.Column():
            ranked_gallery = gr.Gallery(label="Output: Ranked Images with Scores", elem_id="ranked-gallery", columns=4, allow_preview=True)
    rank_button.click(fn=rank_images, inputs=[model_dropdown, gallery, prompt], outputs=ranked_gallery)
    example1_button = gr.Button("Load Example 1")
    example2_button = gr.Button("Load Example 2")
    example1_button.click(fn=lambda: load_example("clip-flant5-xxl", example_imgs, example_prompt0), inputs=[], outputs=[model_dropdown, gallery, prompt])
    example2_button.click(fn=lambda: load_example("clip-flant5-xxl", example_imgs, example_prompt1), inputs=[], outputs=[model_dropdown, gallery, prompt])

# Queue and launch the interface (share=True is not supported on Spaces, so it is omitted)
demo_vqascore_ranking.queue()
demo_vqascore_ranking.launch()