# Hugging Face Spaces banner (scraped UI text, kept for reference):
# "Spaces: Running on Zero" — this app runs on ZeroGPU hardware.
import spaces
import gradio as gr
import torch

# Some libraries call torch.jit.script unconditionally at import time, which
# fails on lambdas; replace it with a no-op passthrough before importing them.
torch.jit.script = lambda f: f  # Avoid script error in lambda

from t2v_metrics import VQAScore
from functools import lru_cache

# No global model loading or CUDA initialization here: on ZeroGPU,
# torch.cuda.is_available() must not be called at module scope.
@lru_cache(maxsize=None)
def get_model(model_name):
    """Return a VQAScore pipeline for *model_name* on CUDA, cached per process.

    The original comment promised per-process caching but never applied it;
    ``lru_cache`` ensures each model name is instantiated at most once, so
    repeated requests reuse the already-loaded model.

    Args:
        model_name: Name of the VQAScore checkpoint (e.g. "clip-flant5-xxl").

    Returns:
        A ``VQAScore`` pipeline bound to the CUDA device.
    """
    return VQAScore(model=model_name, device="cuda")
# Decorate the function to use GPU (ZeroGPU allocates CUDA only inside
# @spaces.GPU-decorated calls — the comment implied this decorator was lost).
@spaces.GPU
def generate(model_name, image, text):
    """Score a single (image, text) pair with VQAScore.

    Args:
        model_name: VQAScore checkpoint name to load (cached by get_model).
        image: Path/handle of the input image as provided by Gradio.
        text: Text prompt to score the image against.

    Returns:
        The scalar VQAScore as a Python float.

    Raises:
        RuntimeError: Re-raised after logging if model inference fails
            (e.g. the model did not load on the GPU properly).
    """
    # Load the model inside the GPU context so CUDA init happens per-call.
    model_pipe = get_model(model_name)
    print("Image:", image)
    print("Text:", text)
    print("Using model:", model_name)
    try:
        # Perform the model inference; result is a 1x1 score tensor.
        result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item()
        print("Result:", result)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise  # bare raise preserves the original traceback
    return result
# Decorate the function to use GPU (ZeroGPU allocates CUDA only inside
# @spaces.GPU-decorated calls — the comment implied this decorator was lost).
@spaces.GPU
def rank_images(model_name, images, text):
    """Score a gallery of images against *text* and return them ranked.

    Args:
        model_name: VQAScore checkpoint name to load (cached by get_model).
        images: Gradio gallery value — a list of (path, caption) tuples.
        text: Text prompt to score every image against.

    Returns:
        A list of (image, "Rank: i - Score: s") pairs sorted by score,
        highest first, suitable for a gr.Gallery output.

    Raises:
        RuntimeError: Re-raised after logging if model inference fails.
    """
    # Load the model inside the GPU context so CUDA init happens per-call.
    model_pipe = get_model(model_name)
    # Gallery items arrive as (path, caption) tuples; keep only the paths.
    images = [image_tuple[0] for image_tuple in images]
    print("Images:", images)
    print("Text:", text)
    print("Using model:", model_name)
    try:
        # Perform the model inference on all images against the single prompt.
        results = model_pipe(images=images, texts=[text]).cpu()[:, 0].tolist()
        print("Initial results:", results)
        # Rank results: best score first.
        ranked_results = sorted(zip(images, results), key=lambda x: x[1], reverse=True)
        # Pair images with their scores and rank label for the output gallery.
        ranked_images = [
            (img, f"Rank: {rank + 1} - Score: {score:.2f}")
            for rank, (img, score) in enumerate(ranked_results)
        ]
        print("Ranked Results:", ranked_results)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise  # bare raise preserves the original traceback
    return ranked_images
### EXAMPLES ###
# Four sample generations of the same prompt from different T2I models.
example_imgs = [
    "0_imgs/DALLE3.png",
    "0_imgs/DeepFloyd.jpg",
    "0_imgs/Midjourney.jpg",
    "0_imgs/SDXL.jpg",
]
example_prompt0 = "Two dogs of different breeds playfully chasing around a tree"
example_prompt1 = "Two dogs of the same breed playing on the grass"
# Custom helper wired to the "Load Example" buttons below.
def load_example(model_name, images, prompt):
    """Pass preset values through unchanged, in the order the Gradio
    outputs expect: (model dropdown, gallery, prompt textbox)."""
    return model_name, images, prompt
# Create the second demo: VQAScore Ranking.
# Layout: inputs (model, prompt, gallery, submit) on the left, the ranked
# output gallery on the right; event wiring sits at the Blocks level.
with gr.Blocks() as demo_vqascore_ranking:
    gr.Markdown("""
    # VQAScore Ranking
    This demo ranks a gallery of images by their VQAScores to an input text prompt. Try examples 1 and 2, or use your own images and prompts.
    If you encounter errors, the model may not have loaded on the GPU properly. Retrying usually resolves this issue.
    """)
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(["clip-flant5-xxl", "clip-flant5-xl"], value="clip-flant5-xxl", label="Model Name")
            prompt = gr.Textbox(label="Prompt")
            gallery = gr.Gallery(label="Input Image(s)", elem_id="input-gallery", columns=4, allow_preview=True)
            rank_button = gr.Button("Submit")
        with gr.Column():
            ranked_gallery = gr.Gallery(label="Output: Ranked Images with Scores", elem_id="ranked-gallery", columns=4, allow_preview=True)

    rank_button.click(fn=rank_images, inputs=[model_dropdown, gallery, prompt], outputs=ranked_gallery)

    example1_button = gr.Button("Load Example 1")
    example2_button = gr.Button("Load Example 2")
    example1_button.click(fn=lambda: load_example("clip-flant5-xxl", example_imgs, example_prompt0), inputs=[], outputs=[model_dropdown, gallery, prompt])
    example2_button.click(fn=lambda: load_example("clip-flant5-xxl", example_imgs, example_prompt1), inputs=[], outputs=[model_dropdown, gallery, prompt])

# Launch the interface (queue requests so concurrent GPU calls are serialized).
demo_vqascore_ranking.queue()
demo_vqascore_ranking.launch(share=True)