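"""Gradio demo for VQAScore on Hugging Face ZeroGPU Spaces.

Scores and ranks a gallery of images against a text prompt using the
clip-flant5 models from the t2v_metrics package.
"""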
import spaces
import gradio as gr
import torch
torch.jit.script = lambda f: f  # No-op torch.jit.script: scripting fails on lambdas used in some model code
from t2v_metrics import VQAScore
from functools import lru_cache

# No global model loading or CUDA initialization: on ZeroGPU Spaces, CUDA may
# only be touched inside @spaces.GPU-decorated functions, so avoid even
# torch.cuda.is_available() at module scope.

@lru_cache()
def get_model(model_name):
    # Cache one VQAScore instance per model name for the lifetime of the process
    return VQAScore(model=model_name, device="cuda")
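
# Note: caching is keyed on model_name, so switching models in the UI keeps
# each previously loaded model resident in memory for the rest of the process.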

# Score a single image against a single prompt (not wired into the UI below)
@spaces.GPU  # Acquire a ZeroGPU device for the duration of this call
def generate(model_name, image, text):
    # Load (or retrieve the cached) model inside the GPU context
    model_pipe = get_model(model_name)
    
    print("Image:", image)
    print("Text:", text)
    print("Using model:", model_name)
    
    try:
        # Score matrix has shape (num_images, num_texts) = (1, 1); take the scalar
        result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item()
        print("Result:", result)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise  # re-raise, preserving the original traceback
    
    return result
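
# Example direct call (hypothetical arguments):
#   score = generate("clip-flant5-xl", "0_imgs/SDXL.jpg", "Two dogs playing")
# VQAScore is the model's probability of answering "Yes" to a question derived
# from the text, so the returned score should lie in [0, 1].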

@spaces.GPU  # Acquire a ZeroGPU device for the duration of this call
def rank_images(model_name, images, text):
    # Load (or retrieve the cached) model inside the GPU context
    model_pipe = get_model(model_name)

    images = [image_tuple[0] for image_tuple in images]  # gr.Gallery yields (image, caption) tuples
    print("Images:", images)
    print("Text:", text)
    print("Using model:", model_name)
    
    try:
        # Score every image against the single prompt; output shape is (num_images, 1)
        results = model_pipe(images=images, texts=[text]).cpu()[:, 0].tolist()
        print("Initial results:", results)
        # Rank results
        ranked_results = sorted(zip(images, results), key=lambda x: x[1], reverse=True)
        # Pair images with their scores and rank
        ranked_images = [
            (img, f"Rank: {rank + 1} - Score: {score:.2f}")
            for rank, (img, score) in enumerate(ranked_results)
        ]
        print("Ranked Results:", ranked_results)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise  # re-raise, preserving the original traceback
    
    return ranked_images
    

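# rank_images returns a list of (image, caption) pairs, which gr.Gallery
# renders directly as captioned thumbnails.
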
### EXAMPLES ###
example_imgs = ["0_imgs/DALLE3.png", 
                "0_imgs/DeepFloyd.jpg",
                "0_imgs/Midjourney.jpg",
                "0_imgs/SDXL.jpg"]
example_prompt0 = "Two dogs of different breeds playfully chasing around a tree"
example_prompt1 = "Two dogs of the same breed playing on the grass"
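# Example image paths are resolved relative to the app's working directory
# (the repository root when running on Spaces).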


# Helper that returns example values to populate the input components
def load_example(model_name, images, prompt):
    return model_name, images, prompt

# Build the VQAScore ranking demo
with gr.Blocks() as demo_vqascore_ranking:
    gr.Markdown("""
    # VQAScore Ranking
    This demo ranks a gallery of images by their VQAScore against an input text prompt. Try examples 1 and 2, or use your own images and prompts.
    If you encounter an error, the model may not have loaded onto the GPU properly; retrying usually resolves the issue.
    """)
    
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(["clip-flant5-xxl", "clip-flant5-xl"], value="clip-flant5-xxl", label="Model Name")
            prompt = gr.Textbox(label="Prompt")
            gallery = gr.Gallery(label="Input Image(s)", elem_id="input-gallery", columns=4, allow_preview=True)
            rank_button = gr.Button("Submit")
            
        with gr.Column():
            ranked_gallery = gr.Gallery(label="Output: Ranked Images with Scores", elem_id="ranked-gallery", columns=4, allow_preview=True)
            rank_button.click(fn=rank_images, inputs=[model_dropdown, gallery, prompt], outputs=ranked_gallery)

            example1_button = gr.Button("Load Example 1")
            example2_button = gr.Button("Load Example 2")
            example1_button.click(fn=lambda: load_example("clip-flant5-xxl", example_imgs, example_prompt0), inputs=[], outputs=[model_dropdown, gallery, prompt])
            example2_button.click(fn=lambda: load_example("clip-flant5-xxl", example_imgs, example_prompt1), inputs=[], outputs=[model_dropdown, gallery, prompt])
    
# Launch the interface with request queuing enabled
demo_vqascore_ranking.queue()
demo_vqascore_ranking.launch(share=True)  # note: share=True is ignored when running on Spaces