File size: 7,618 Bytes
3166b97
 
07e7ef5
3166b97
 
 
 
 
 
 
e13c42d
 
6efb87d
 
 
 
595acb5
6efb87d
3166b97
6efb87d
 
 
4466abf
6efb87d
4466abf
6efb87d
 
 
 
 
e13c42d
6efb87d
 
 
3166b97
6efb87d
 
 
 
 
 
3166b97
 
 
 
 
76c8ebf
3166b97
 
 
 
 
 
6efb87d
 
 
 
3166b97
 
 
76c8ebf
3166b97
6efb87d
3166b97
 
 
6efb87d
3166b97
 
 
 
 
 
 
 
 
e2468cc
3166b97
07e7ef5
595acb5
e2468cc
 
3166b97
 
e13c42d
4860fd0
6efb87d
4860fd0
 
 
 
3166b97
 
 
 
6efb87d
3166b97
 
 
 
76c8ebf
3166b97
e13c42d
 
3166b97
 
 
6efb87d
3166b97
 
 
 
72a6c2f
3166b97
 
 
 
6efb87d
3166b97
 
 
 
72a6c2f
3166b97
 
 
 
6efb87d
3166b97
 
 
 
 
72a6c2f
3166b97
 
 
 
 
6efb87d
3166b97
 
 
 
 
72a6c2f
3166b97
 
e13c42d
6efb87d
 
 
 
72a6c2f
 
 
 
 
3166b97
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# AUTOGENERATED! DO NOT EDIT! File to edit: app.ipynb.

# %% auto 0 
# Public API of this module: the Blocks app plus every UI callback it wires up.
__all__ = [
    'block',
    'make_clickable_model',
    'make_clickable_user',
    'get_submissions',
    'get_model_list',
    'get_initial_prompt',
    'make_demo',
]

# %% app.ipynb 0
import functools

import gradio as gr
import pandas as pd
from diffusers import StableDiffusionPipeline
from huggingface_hub import list_models

# %% app.ipynb 1
def get_model_list(category):
    """Return a Dropdown update listing the hackathon models tagged with *category*.

    Models are looked up on the Hub with the tags "dreambooth-hackathon" and
    *category*.

    Args:
        category: Theme tag selected in the UI (e.g. "animal").

    Returns:
        A ``gr.Dropdown.update`` whose choices are the matching model ids.
    """
    submissions = list_models(filter=["dreambooth-hackathon", category], full=True)
    model_ids = [submission.id for submission in submissions]

    # Keep preselecting the 5th model when there is one. Indexing [4]
    # unconditionally raised IndexError for themes with fewer than five
    # submissions, so fall back to the first model, or to no selection at all.
    if len(model_ids) > 4:
        default = model_ids[4]
    elif model_ids:
        default = model_ids[0]
    else:
        default = None
    return gr.Dropdown.update(choices=model_ids, value=default)


def get_initial_prompt(model_nm):
    """Return a Textbox update pre-filling a prompt for the selected model.

    The repo part of the model id (after the last "/") has its hyphens turned
    into spaces, e.g. "peterj/shbrcky-dog" -> "a photo of shbrcky dog ".

    Args:
        model_nm: Hub model id chosen in the dropdown. It can be ``None`` or
            empty when nothing is selected (e.g. a theme with no submissions).

    Returns:
        A ``gr.Textbox.update`` carrying the starter prompt, or an empty value
        when no model is selected.
    """
    if not model_nm:
        # Avoid calling .split on None when the dropdown has no selection.
        return gr.Textbox.update(value="")
    subject = model_nm.split("/")[-1].replace("-", " ")
    return gr.Textbox.update(value="a photo of " + subject + " ")

@functools.lru_cache(maxsize=1)
def _load_pipeline(model_name):
    """Load the Stable Diffusion pipeline for *model_name*, keeping only the latest one cached."""
    return StableDiffusionPipeline.from_pretrained(model_name)


def make_demo(model_name, prompt):
    """Generate one image for *prompt* with the selected DreamBooth model.

    Loading a pipeline is by far the slowest step, so the most recently used
    pipeline is kept in memory. Repeated generations with the same model then
    skip the reload, while at most one pipeline stays resident.

    Args:
        model_name: Hub id of the DreamBooth model chosen in the dropdown.
        prompt: Text prompt passed to the pipeline.

    Returns:
        The first image produced by the pipeline.

    Raises:
        ValueError: If no model has been selected.
    """
    if not model_name:
        raise ValueError("Please pick a DreamBooth model before generating an image.")
    pipeline = _load_pipeline(model_name)
    return pipeline(prompt).images[0]


def make_clickable_model(model_name, link=None):
    """Return an HTML anchor for a Hub model, labelled with the repo name only.

    Args:
        model_name: Full Hub id such as "user/model".
        link: Optional explicit URL. When it is missing or empty, the model's
            Hub page is used, so an empty string never yields ``href=""``.

    Returns:
        An ``<a>`` tag that opens in a new tab.
    """
    if not link:
        link = "https://huggingface.co/" + model_name
    # Show only the repo part; the owner is displayed separately.
    return f'<a target="_blank" href="{link}">{model_name.split("/")[-1]}</a>'
    
def make_clickable_user(user_id):
    """Render *user_id* as an HTML link to that user's Hugging Face profile (new tab)."""
    profile_url = "https://huggingface.co/" + user_id
    return '<a  target="_blank" href="%s">%s</a>' % (profile_url, user_id)

# %% app.ipynb 2
def get_submissions(category, prompt):
    """Build the leaderboard table for one hackathon theme.

    Args:
        category: Theme tag (e.g. "animal"), combined with the
            "dreambooth-hackathon" tag to filter Hub models.
        prompt: Current prompt text. It is not used for the table; it is
            accepted only because the Gradio events pass it as a second input.

    Returns:
        A DataFrame with columns Rank, User, Model and Likes, ordered by likes
        (descending), with Rank starting at 1.
    """
    submissions = list_models(filter=["dreambooth-hackathon", category], full=True)
    rows = [
        (
            make_clickable_user(submission.id.split("/")[0]),
            # Link to the model's Hub page. The prompt must not be passed as
            # the link: it would become the anchor's href.
            make_clickable_model(submission.id),
            submission.likes,
        )
        for submission in submissions
    ]

    df = pd.DataFrame(data=rows, columns=["User", "Model", "Likes"])
    # A stable sort keeps the Hub listing order among models with equal likes.
    df.sort_values(by=["Likes"], ascending=False, inplace=True, kind="mergesort")
    df.insert(0, "Rank", list(range(1, len(df) + 1)))
    return df

# %% app.ipynb 3
block = gr.Blocks()

# Leaderboard tabs as (tab label, Hub theme tag). The tag is what
# get_submissions / get_model_list receive as `category`.
LEADERBOARD_THEMES = [
    ("Animal 🐨", "animal"),
    ("Science 🔬", "science"),
    ("Food 🍔", "food"),
    ("Landscape 🏔", "landscape"),
    ("Wildcard 🔥", "wildcard"),
]

# One Gradio datatype per column returned by get_submissions: Rank, User, Model, Likes.
LEADERBOARD_DATATYPES = ["number", "markdown", "markdown", "number"]

with block:
    gr.Markdown(
        """# Gradio-powered leaderboard for the DreamBooth Hackathon
    
    Welcome to this Gradio-powered leaderboard! Select a theme and one of the DreamBooth models trained by hackathon participants, then complete the prompt shown (e.g., a photo of Shiba dog in a jungle). Note that image generation can take a long time (around 400 seconds), because the selected model's pipeline has to be loaded into memory first.  
    <br>**If you like a model demo, click on the model name in the table below and UPVOTE the model on the Hugging Face Hub**<br><br>
    DreamBooth Hackathon is an ongoing community event where participants **personalise a Stable Diffusion model** by fine-tuning it with a powerful technique called [_DreamBooth_](https://arxiv.org/abs/2208.12242). This technique allows one to implant a subject into the output domain of the model such that it can be synthesized with a _unique identifier_ (e.g., shiba dog) in the prompt. 
    This competition is composed of 5 _themes_ - Animals, Science, Food, Landscapes, and Wildcards. For details on how to participate, check out the hackathon's guide [here](https://github.com/huggingface/diffusion-models-class/blob/main/hackathon/README.md).
    """
    )
    with gr.Row():
        prompt_in = gr.Textbox(label="Type in a prompt after the given text..", value="")
        with gr.Column():
            theme = gr.Radio(
                label="Pick a Theme",
                choices=[tag for _, tag in LEADERBOARD_THEMES],
            )
            model_list = gr.Dropdown(label="Pick a DreamBooth model", choices=[])
            # The keyword is lowercase `value`; the original `Value=` did not
            # match Button's label parameter.
            button_in = gr.Button(value="Generate Image")
        image_out = gr.Image(label="Generated image with your choice of DreamBooth model")

    leaderboards = {}  # theme tag -> Dataframe component shown in that theme's tab
    with gr.Tabs():
        for tab_label, category in LEADERBOARD_THEMES:
            with gr.TabItem(tab_label):
                with gr.Row():
                    leaderboard = gr.components.Dataframe(
                        type="pandas",
                        datatype=list(LEADERBOARD_DATATYPES),
                        interactive=True,
                    )
                with gr.Row():
                    refresh = gr.Button("Refresh")
                    refresh.click(
                        get_submissions,
                        inputs=[gr.Variable(category), prompt_in],
                        outputs=leaderboard,
                    )
            leaderboards[category] = leaderboard

    theme.change(get_model_list, theme, model_list)
    model_list.change(get_initial_prompt, model_list, prompt_in)
    button_in.click(make_demo, [model_list, prompt_in], image_out)

    # Fill every leaderboard when the page is loaded.
    for category, leaderboard in leaderboards.items():
        block.load(
            get_submissions,
            inputs=[gr.Variable(category), prompt_in],
            outputs=leaderboard,
        )


if __name__ == "__main__":
    block.launch()