animrods committed
Commit d174eab · verified · 1 Parent(s): 013786a

Update app.py

Files changed (1):
  1. app.py +49 -131
app.py CHANGED
@@ -4,36 +4,35 @@ import numpy as np
 import diffusers
 import os
 import random
-import spaces
 from PIL import Image

 hf_token = os.environ.get("HF_TOKEN")
 from diffusers import AutoPipelineForText2Image

-
-device = "cuda" #if torch.cuda.is_available() else "cpu"
+device = "cuda" # if torch.cuda.is_available() else "cpu"
 pipe = AutoPipelineForText2Image.from_pretrained("briaai/BRIA-2.3", torch_dtype=torch.float16, force_zeros_for_empty_prompt=False).to(device)
 pipe.load_ip_adapter("briaai/Image-Prompt", subfolder='models', weight_name="ip_adapter_bria.bin")
 pipe.to(device)
-# default_negative_prompt= "" #"Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"

 MAX_SEED = np.iinfo(np.int32).max

 @spaces.GPU(enable_queue=True)
-def predict(prompt, ip_adapter_image, ip_adapter_scale=0.5, negative_prompt="", seed=100, randomize_seed=False, center_crop=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=50, progress=gr.Progress(track_tqdm=True)):
+def predict(prompt, ip_adapter_images, ip_adapter_scale=0.5, negative_prompt="", seed=100, randomize_seed=False, center_crop=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=50, progress=gr.Progress(track_tqdm=True)):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-
+
     # Optionally resize images if center crop is not selected
     if not center_crop:
-        ip_adapter_image = [image.resize((224, 224)) for image in ip_adapter_image]
-
+        ip_adapter_images = [image.resize((224, 224)) for image in ip_adapter_images]
+
+    # Create a generator for reproducible random seed
     generator = torch.Generator(device="cuda").manual_seed(seed)
     pipe.set_ip_adapter_scale([ip_adapter_scale])

-    image = pipe(
+    # Pass all images at once to the pipe
+    result_images = pipe(
         prompt=prompt,
-        ip_adapter_image=ip_adapter_image,
+        ip_adapter_image=ip_adapter_images,  # Pass the list of images
         negative_prompt=negative_prompt,
         height=height,
         width=width,
@@ -41,142 +40,61 @@ def predict(prompt, ip_adapter_image, ip_adapter_scale=0.5, negative_prompt="",
         guidance_scale=guidance_scale,
         num_images_per_prompt=1,
         generator=generator,
-    ).images[0]
+    ).images

-    return image, seed
+    return result_images, seed

-def swap_to_gallery(images):
-    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
-
 examples = [
-    ["high quality", "example1.png", 1.0, "", 1000, False, False, 1152, 896],
-    ["capybara", "example2.png", 0.7, "", 1000, False, False, 1152, 896],
+    ["high quality", ["example1.png", "example2.png"], 1.0, "", 1000, False, False, 1152, 896],
 ]

-css="""
+css = """
 #col-container {
-    margin: 0 auto;
-    max-width: 1024px;
-}
-#result img{
-    object-position: top;
-}
-#result .image-container{
-    height: 100%
+    display: flex;
+    flex-direction: column;
+    align-items: center;
+    padding: 10px;
 }
 """
+
 with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""
-        # Bria's Image-Prompt-Adapter
-        """)
-
+    with gr.Column():
         with gr.Row():
-            with gr.Column():
-                # ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil")
-                ip_adapter_image = gr.Gallery(
-                    label="IP-Adapter Image/s", show_label=False, elem_id="image-gallery",
-                    columns=[3], rows=[1], object_fit="contain", height="auto")
-
-                ip_adapter_scale = gr.Slider(
-                    label="Image Input Scale",
-                    info="Use 1 for creating image variations",
-                    minimum=0.0,
-                    maximum=1.0,
-                    step=0.05,
-                    value=1.0,
-                )
-            with gr.Column():
-                result = gr.Image(label="Result", elem_id="result", format="png")
-                prompt = gr.Text(
-                    label="Prompt",
-                    show_label=True,
-                    lines=1,
-                    placeholder="Enter your prompt",
-                    container=True,
-                    info='For image variation, leave empty or try a prompt like: "high quality".'
-                )
-
+            prompt = gr.Textbox(label="Prompt", lines=1)
+            ip_adapter_images = gr.Gallery(label="Input Images", elem_id="image-gallery").style(grid=[2], preview=True)
+
+        ip_adapter_scale = gr.Slider(label="IP Adapter Scale", minimum=0.0, maximum=1.0, step=0.1, value=0.5)
+        negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Optional", lines=1)
+
         with gr.Row():
-            width = gr.Slider(
-                label="Width",
-                minimum=256,
-                maximum=2048,
-                step=32,
-                value=1024,
-            )
-            height = gr.Slider(
-                label="Height",
-                minimum=256,
-                maximum=2048,
-                step=32,
-                value=1024,
-            )
-            run_button = gr.Button("Run", scale=0)
-
-
-        with gr.Accordion("Advanced Settings", open=False):
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
+            seed = gr.Number(label="Seed", value=100)
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+            center_crop = gr.Checkbox(label="Center Crop Image", value=False, info="If not checked, the images will be resized.")
+
+        with gr.Row():
+            guidance_scale = gr.Slider(
+                label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=7.0
             )
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=1000,
+            num_inference_steps = gr.Slider(
+                label="Number of Inference Steps", minimum=1, maximum=100, step=1, value=25
             )
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-            center_crop = gr.Checkbox(label="Center Crop image", value=False, info="If not checked, the IP-Adapter image input would be resized to a square.")
-            # with gr.Row():
-            #     width = gr.Slider(
-            #         label="Width",
-            #         minimum=256,
-            #         maximum=2048,
-            #         step=32,
-            #         value=1024,
-            #     )
-            #     height = gr.Slider(
-            #         label="Height",
-            #         minimum=256,
-            #         maximum=2048,
-            #         step=32,
-            #         value=1024,
-            #     )
-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=7.0,
-                )
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=100,
-                    step=1,
-                    value=25,
-                )
-

-        # gr.Examples(
-        #     examples=examples,
-        #     fn=predict,
-        #     inputs=[prompt, ip_adapter_image, ip_adapter_scale, negative_prompt, seed, randomize_seed, center_crop, width, height],
-        #     outputs=[result, seed],
-        #     cache_examples="lazy"
-        # )
-
-        gr.on(
-            triggers=[run_button.click, prompt.submit],
-            fn=predict,
-            inputs=[prompt, ip_adapter_image, ip_adapter_scale, negative_prompt, seed, randomize_seed, center_crop, width, height, guidance_scale, num_inference_steps],
+        result = gr.Gallery(label="Generated Images").style(grid=[2], preview=True)
+
+        run_button = gr.Button("Run")
+
+        run_button.click(
+            predict,
+            inputs=[prompt, ip_adapter_images, ip_adapter_scale, negative_prompt, seed, randomize_seed, center_crop, width, height, guidance_scale, num_inference_steps],
             outputs=[result, seed]
         )

-    demo.queue(max_size=25,api_open=False).launch(show_api=False)
+    gr.Examples(
+        examples=examples,
+        fn=predict,
+        inputs=[prompt, ip_adapter_images, ip_adapter_scale, negative_prompt, seed, randomize_seed, center_crop, width, height],
+        outputs=[result, seed],
+        cache_examples="lazy"
+    )

-    # image_blocks.queue(max_size=25,api_open=False).launch(show_api=False)
+demo.queue(max_size=25, api_open=False).launch(show_api=False)
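
Note on the updated predict(): it resizes every element of ip_adapter_images with image.resize((224, 224)), which assumes PIL images, but depending on the Gradio version and Gallery configuration a gr.Gallery input can hand the callback file paths or (image, caption) tuples instead. A defensive normalization sketch (hypothetical helper, not part of this commit):

from PIL import Image

def to_pil_list(gallery_value):
    # Hypothetical helper, not in this commit: best-effort conversion of a
    # gr.Gallery input value into PIL images. Items may arrive as PIL images,
    # file paths, or (item, caption) tuples depending on the Gradio version.
    images = []
    for item in gallery_value or []:
        if isinstance(item, (tuple, list)):  # (image_or_path, caption)
            item = item[0]
        images.append(item if isinstance(item, Image.Image) else Image.open(item))
    return images

predict() could then call ip_adapter_images = to_pil_list(ip_adapter_images) before the resize step.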
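
Note on the wiring: the new run_button.click(...) and gr.Examples(...) input lists still reference width and height, but the rewritten layout no longer defines those sliders, so building the Blocks would fail with a NameError. One minimal way to keep those lists resolving is to re-declare the sliders this commit deletes, with the same ranges, inside the with gr.Column(): block (a sketch, not part of the commit):

# Sketch: restore the width/height sliders removed above so that
# inputs=[..., width, height, ...] keeps resolving. Assumes the surrounding
# `with gr.Blocks(): / with gr.Column():` context and `import gradio as gr` from app.py.
with gr.Row():
    width = gr.Slider(label="Width", minimum=256, maximum=2048, step=32, value=1024)
    height = gr.Slider(label="Height", minimum=256, maximum=2048, step=32, value=1024)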
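
Note on gr.Gallery(...).style(grid=[2], preview=True): .style() is the Gradio 3.x API and is gone in Gradio 4.x, where layout options are passed to the constructor (the gallery removed above already uses columns/rows/object_fit/height that way). If the Space is pinned to a 4.x image, an equivalent declaration might look like the sketch below; treating preview as a constructor kwarg is an assumption:

# Sketch for Gradio 4.x, where components no longer expose .style():
ip_adapter_images = gr.Gallery(
    label="Input Images",
    elem_id="image-gallery",
    columns=[2],        # replaces .style(grid=[2])
    preview=True,       # assumption: accepted by the constructor on 4.x
    object_fit="contain",
    height="auto",
)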