import gradio as gr
import torch
from t2v_metrics import VQAScore, list_all_vqascore_models

# Replace torch.jit.script with an identity function so anything that would
# have been scripted is left as plain Python (original note: "Avoid script
# error in lambda").
torch.jit.script = lambda f: f

# Import the optional `spaces` helper *before* the model is placed on CUDA
# below. NOTE(review): on Hugging Face ZeroGPU hardware `spaces` is expected
# to be imported before CUDA is initialised -- confirm for the deployment.
try:
    from spaces import GPU
except ImportError:
    def GPU(duration):
        """Stand-in for ``spaces.GPU``: return a decorator that changes nothing."""
        def decorator(func):
            return func
        return decorator


def update_model(model_name):
    """Build a ``VQAScore`` pipeline for *model_name* on the ``"cuda"`` device."""
    return VQAScore(model=model_name, device="cuda")


# Module-level cache of the active pipeline and the name it was built from.
# Both are rebound together by `_ensure_model`.
cur_model_name = "clip-flant5-xl"
model_pipe = update_model(cur_model_name)


def _ensure_model(model_name):
    """Return the cached pipeline, rebuilding it only when the name changes.

    The new pipeline is built *before* the cache is updated, so if loading
    raises, ``cur_model_name`` and ``model_pipe`` still describe the previous
    model instead of disagreeing with each other.
    """
    global model_pipe, cur_model_name
    if model_name != cur_model_name:
        new_pipe = update_model(model_name)
        model_pipe = new_pipe
        cur_model_name = model_name
    return model_pipe


@GPU(duration=20)
def generate(model_name, image, text):
    """Score a single image against a text prompt.

    Args:
        model_name: Name of the VQAScore model to use; switches the cached
            pipeline if it differs from the current one.
        image: Image to score (the UI supplies a file path).
        text: Prompt to compare the image with.

    Returns:
        The score for the (image, text) pair as a Python number.

    Raises:
        RuntimeError: Re-raised unchanged after being logged if inference fails.
    """
    pipe = _ensure_model(model_name)

    print("Image:", image)  # Debug: Print image path
    print("Text:", text)  # Debug: Print text input
    print("Using model:", model_name)

    try:
        # Output is indexed [image][text]; one image and one text -> [0][0].
        result = pipe(images=[image], texts=[text]).cpu()[0][0].item()
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise
    print("Result:", result)
    return result


@GPU(duration=20)
def rank_images(model_name, images, text):
    """Order gallery images from best to worst match for a text prompt.

    Args:
        model_name: Name of the VQAScore model to use; switches the cached
            pipeline if it differs from the current one.
        images: Gallery contents. Items may be bare images/paths or
            ``(image, caption)`` pairs; only the image part is scored.
        text: Prompt every image is compared with.

    Returns:
        The images sorted by descending score. An empty or missing gallery
        yields an empty list.

    Raises:
        RuntimeError: Re-raised unchanged after being logged if inference fails.
    """
    pipe = _ensure_model(model_name)

    print("Images:", images)  # Debug: Print image paths
    print("Text:", text)  # Debug: Print text input
    print("Using model:", model_name)

    # Nothing uploaded: there is nothing to score or rank.
    if not images:
        return []

    # NOTE(review): newer Gradio versions hand gallery input over as
    # (image, caption) pairs -- unwrap those; confirm against the installed
    # Gradio version.
    candidates = [
        item[0] if isinstance(item, (tuple, list)) else item
        for item in images
    ]

    try:
        # The scorer returns an images x texts matrix (see the t2v_metrics
        # README), so a single prompt column is enough. Repeating the prompt
        # once per image would compute N x N scores only to read column 0.
        results = pipe(images=candidates, texts=[text]).cpu()[:, 0].tolist()
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise

    ranked_results = sorted(
        zip(candidates, results), key=lambda pair: pair[1], reverse=True
    )
    print("Ranked Results:", ranked_results)
    return [img for img, _score in ranked_results]
# --- Tab 1: score one image against one prompt -----------------------------
scoring_inputs = [
    gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"),
    gr.Image(type="filepath"),
    gr.Textbox(label="Prompt"),
]
demo_vqascore = gr.Interface(
    fn=generate,
    inputs=scoring_inputs,
    outputs="number",
    title="VQAScore",
    description="This model evaluates the similarity between an image and a text prompt.",
)

# --- Tab 2: rank a gallery of images against one prompt --------------------
ranking_inputs = [
    gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"),
    gr.Gallery(label="Generated Images"),
    gr.Textbox(label="Prompt"),
]
ranking_output = gr.Gallery(label="Ranked Images")
demo_vqascore_ranking = gr.Interface(
    fn=rank_images,
    inputs=ranking_inputs,
    outputs=ranking_output,
    title="VQAScore Ranking",
    description="This model ranks a gallery of images based on their similarity to a text prompt.",
)

# --- Expose both demos as tabs of a single app -----------------------------
tabbed_interface = gr.TabbedInterface(
    [demo_vqascore, demo_vqascore_ranking],
    ["VQAScore", "VQAScore Ranking"],
)

# Turn on request queuing, then start serving the app.
tabbed_interface.queue()
tabbed_interface.launch()