File size: 5,347 Bytes
6a4771c
 
8d16ec9
6a4771c
 
 
2bbf193
8d16ec9
6a4771c
0d9d6b6
 
 
6a4771c
8d16ec9
2bbf193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a4771c
 
 
0d9d6b6
 
 
6a4771c
 
50af147
6a4771c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d9d6b6
 
 
6a4771c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50af147
6a4771c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import random

# Third-party imports keep their original relative order (import order can matter for `spaces`).
import torch
import spaces
import gradio as gr
from diffusers import FluxInpaintPipeline
import numpy as np
import google.generativeai as genai

# Header text for the page; rendered through gr.Markdown when the UI is built.
MARKDOWN = """
# Prompt Canvas🎨
Thanks to [Black Forest Labs](https://huggingface.co/black-forest-labs) team for creating this amazing model,
and a big thanks to [Gothos](https://github.com/Gothos) for taking it to the next level by enabling inpainting with the FLUX.
"""

# Gemini setup.
# The API key is read from the `Gemini_API` environment variable; if the
# variable is not set, the lookup raises KeyError while the module is loading.
# NOTE: this relies on `os` being imported at the top of the file.
genai.configure(api_key=os.environ['Gemini_API'])
# Shared model handle used by gemini_predict().
gemini_flash = genai.GenerativeModel(model_name='gemini-1.5-flash-002')

def gemini_predict(prompt):
    """Ask Gemini which part of the scene a user's edit request refers to.

    The request is embedded in a few-shot instruction that asks for a short
    (1 to 3 word) answer naming the thing to change, e.g. ``Lips`` or ``Sofa``.

    Args:
        prompt: The user's edit request as free text.

    Returns:
        The text of the model's reply, exactly as returned by the API.
    """
    instruction = f"""You are the best text analyser.
                         You have to analyse a user query and identify what the user wants to change, from a given user query.
        
                         Examples:
                             Query: Change Lipstick colour to blue
                             Response: Lips
        
                             Query: Add a nose stud
                             Response: Nose
        
                             Query: Add a wallpaper to the right wall
                             Response: Right wall
        
                             Query: Change the Sofa's colour to Purple
                             Response: Sofa
        
                        Your response should be in 1 or 2-3 words
                        Query : {prompt}
                        """
    reply = gemini_flash.generate_content(instruction)
    return reply.text


# Upper bound for randomly drawn seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Always "cuda"; a CPU fallback (torch.cuda.is_available()) was left disabled
# by the author -- presumably on purpose for the hosting environment.
DEVICE = "cuda"

# Load the FLUX.1-schnell inpainting pipeline in bfloat16 and move it to DEVICE.
inpaint_pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
).to(DEVICE)



@spaces.GPU()
def process(input_image_editor, mask_image, input_text, strength, seed, randomize_seed, num_inference_steps, guidance_scale=3.5, progress=gr.Progress(track_tqdm=True)):
    """Inpaint the edited image according to a text prompt.

    Args:
        input_image_editor: Value of the image editor component; its
            ``'background'`` entry is the image to edit.
        mask_image: Uploaded mask, or ``None`` when no mask was uploaded. In
            that case the first layer drawn in the editor is used as the mask
            (assumes the editor value exposes brush strokes under ``'layers'``,
            as documented for Gradio's ImageEditor).
        input_text: Prompt describing the desired content.
        strength: Inpainting strength forwarded to the pipeline.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: When true, a fresh seed in ``[0, MAX_SEED]`` is drawn.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        progress: Gradio progress tracker (mirrors tqdm progress in the UI).

    Returns:
        A ``(result_image, mask_image, seed)`` tuple: the generated image, the
        mask that was actually used, and the seed that was actually used.

    Raises:
        gr.Error: If the prompt, the image, or the mask is missing.
    """
    if not input_text:
        raise gr.Error("Please enter a text prompt.")

    # The editor value may be None/empty before anything has been uploaded,
    # so avoid subscripting it blindly.
    image = input_image_editor.get('background') if input_image_editor else None
    if image is None:
        raise gr.Error("Please upload an image.")

    # Without an uploaded mask, fall back to what the user painted with the
    # editor's brush instead of handing None to the pipeline.
    if mask_image is None:
        drawn_layers = input_image_editor.get('layers') or []
        mask_image = drawn_layers[0] if drawn_layers else None
    if mask_image is None:
        raise gr.Error("Please upload a mask or draw one on the image.")

    width, height = image.size

    # A cleared seed field arrives as None; treat it like "randomize".
    if randomize_seed or seed is None:
        seed = random.randint(0, MAX_SEED)
    # UI number widgets may deliver floats; manual_seed() needs an int.
    seed = int(seed)

    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    result = inpaint_pipe(
        prompt=input_text,
        image=image,
        mask_image=mask_image,
        width=width,
        height=height,
        strength=strength,
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        guidance_scale=guidance_scale,
    ).images[0]

    return result, mask_image, seed

# Page layout: controls in the left column, results in the right column.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(MARKDOWN)

    with gr.Row():
        # ---- Left column: inputs -------------------------------------------
        with gr.Column(scale=1):
            # Image editor limited to a single fixed white brush.
            input_image_component = gr.ImageEditor(
                label="Image",
                type="pil",
                sources=["upload", "webcam"],
                image_mode="RGB",
                layers=False,
                brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
            )
            input_text_component = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            # Generation parameters, collapsed by default.
            with gr.Accordion("Advanced Settings", open=False):
                strength_slider = gr.Slider(
                    label="Strength", minimum=0.0, maximum=1.0, step=0.01, value=0.7
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps", minimum=1, maximum=100, step=1, value=30
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale", minimum=1, maximum=15, step=0.1, value=3.5
                )
                seed_number = gr.Number(label="Seed", value=42, precision=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            # Optional ready-made mask.
            with gr.Accordion("Upload a mask", open=False):
                uploaded_mask_component = gr.Image(
                    label="Already made mask (black pixels will be preserved, white pixels will be redrawn)",
                    sources=["upload"],
                    type="pil",
                )

            submit_button_component = gr.Button(value="Inpaint", variant="primary")

        # ---- Right column: outputs -----------------------------------------
        with gr.Column(scale=1):
            output_image_component = gr.Image(
                label="Generated Image", type="pil", image_mode="RGB"
            )
            output_mask_component = gr.Image(
                label="Generated Mask", type="pil", image_mode="RGB"
            )
            with gr.Accordion("Debug Info", open=False):
                output_seed = gr.Number(label="Used Seed")

    # Wire the button to process(); the input order matches its positional parameters.
    submit_button_component.click(
        fn=process,
        inputs=[
            input_image_component,
            uploaded_mask_component,
            input_text_component,
            strength_slider,
            seed_number,
            randomize_seed,
            num_inference_steps,
            guidance_scale,
        ],
        outputs=[
            output_image_component,
            output_mask_component,
            output_seed,
        ],
    )

# Start the app; errors raised in handlers are shown in the UI.
demo.launch(show_error=True, debug=False)