Spaces:
Runtime error
Runtime error
add seed
Browse files- OmniGen/__pycache__/__init__.cpython-310.pyc +0 -0
- OmniGen/__pycache__/model.cpython-310.pyc +0 -0
- OmniGen/__pycache__/pipeline.cpython-310.pyc +0 -0
- OmniGen/__pycache__/processor.cpython-310.pyc +0 -0
- OmniGen/__pycache__/scheduler.cpython-310.pyc +0 -0
- OmniGen/__pycache__/transformer.cpython-310.pyc +0 -0
- OmniGen/__pycache__/utils.cpython-310.pyc +0 -0
- OmniGen/model.py +16 -11
- OmniGen/pipeline.py +36 -8
- OmniGen/processor.py +11 -25
- OmniGen/scheduler.py +1 -1
- OmniGen/transformer.py +1 -1
- OmniGen/utils.py +110 -0
- app.py +14 -5
OmniGen/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (325 Bytes). View file
|
|
OmniGen/__pycache__/model.cpython-310.pyc
ADDED
Binary file (12.5 kB). View file
|
|
OmniGen/__pycache__/pipeline.cpython-310.pyc
ADDED
Binary file (8.5 kB). View file
|
|
OmniGen/__pycache__/processor.cpython-310.pyc
ADDED
Binary file (11.2 kB). View file
|
|
OmniGen/__pycache__/scheduler.cpython-310.pyc
ADDED
Binary file (2.75 kB). View file
|
|
OmniGen/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (3.95 kB). View file
|
|
OmniGen/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.52 kB). View file
|
|
OmniGen/model.py
CHANGED
@@ -5,7 +5,10 @@ import torch.nn as nn
|
|
5 |
import numpy as np
|
6 |
import math
|
7 |
from typing import Dict
|
|
|
|
|
8 |
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
|
|
9 |
|
10 |
from OmniGen.transformer import Phi3Config, Phi3Transformer
|
11 |
|
@@ -145,7 +148,7 @@ class PatchEmbedMR(nn.Module):
|
|
145 |
return x
|
146 |
|
147 |
|
148 |
-
class OmniGen(nn.Module):
|
149 |
"""
|
150 |
Diffusion model with a Transformer backbone.
|
151 |
"""
|
@@ -191,7 +194,7 @@ class OmniGen(nn.Module):
|
|
191 |
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
192 |
config = Phi3Config.from_pretrained(model_name)
|
193 |
model = cls(config)
|
194 |
-
ckpt = torch.load(os.path.join(model_name, 'model.pt'))
|
195 |
model.load_state_dict(ckpt)
|
196 |
return model
|
197 |
|
@@ -304,7 +307,7 @@ class OmniGen(nn.Module):
|
|
304 |
return latents, num_tokens, shapes
|
305 |
|
306 |
|
307 |
-
def forward(self, x, timestep,
|
308 |
"""
|
309 |
|
310 |
"""
|
@@ -312,16 +315,16 @@ class OmniGen(nn.Module):
|
|
312 |
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
313 |
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
314 |
|
315 |
-
if
|
316 |
-
input_latents, _, _ = self.patch_multiple_resolutions(
|
317 |
-
if
|
318 |
-
condition_embeds = self.llm.embed_tokens(
|
319 |
input_img_inx = 0
|
320 |
-
for b_inx in
|
321 |
-
for start_inx, end_inx in
|
322 |
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
323 |
input_img_inx += 1
|
324 |
-
if
|
325 |
assert input_img_inx == len(input_latents)
|
326 |
|
327 |
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
@@ -344,7 +347,9 @@ class OmniGen(nn.Module):
|
|
344 |
x = self.final_layer(image_embedding, time_emb)
|
345 |
latents = self.unpatchify(x, shapes[0], shapes[1])
|
346 |
|
347 |
-
|
|
|
|
|
348 |
|
349 |
@torch.no_grad()
|
350 |
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
|
|
|
5 |
import numpy as np
|
6 |
import math
|
7 |
from typing import Dict
|
8 |
+
|
9 |
+
from diffusers.loaders import PeftAdapterMixin
|
10 |
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
11 |
+
from huggingface_hub import snapshot_download
|
12 |
|
13 |
from OmniGen.transformer import Phi3Config, Phi3Transformer
|
14 |
|
|
|
148 |
return x
|
149 |
|
150 |
|
151 |
+
class OmniGen(nn.Module, PeftAdapterMixin):
|
152 |
"""
|
153 |
Diffusion model with a Transformer backbone.
|
154 |
"""
|
|
|
194 |
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
195 |
config = Phi3Config.from_pretrained(model_name)
|
196 |
model = cls(config)
|
197 |
+
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
|
198 |
model.load_state_dict(ckpt)
|
199 |
return model
|
200 |
|
|
|
307 |
return latents, num_tokens, shapes
|
308 |
|
309 |
|
310 |
+
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True):
|
311 |
"""
|
312 |
|
313 |
"""
|
|
|
315 |
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
316 |
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
317 |
|
318 |
+
if input_img_latents is not None:
|
319 |
+
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
|
320 |
+
if input_ids is not None:
|
321 |
+
condition_embeds = self.llm.embed_tokens(input_ids).clone()
|
322 |
input_img_inx = 0
|
323 |
+
for b_inx in input_image_sizes.keys():
|
324 |
+
for start_inx, end_inx in input_image_sizes[b_inx]:
|
325 |
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
326 |
input_img_inx += 1
|
327 |
+
if input_img_latents is not None:
|
328 |
assert input_img_inx == len(input_latents)
|
329 |
|
330 |
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
|
|
347 |
x = self.final_layer(image_embedding, time_emb)
|
348 |
latents = self.unpatchify(x, shapes[0], shapes[1])
|
349 |
|
350 |
+
if past_key_values:
|
351 |
+
return latents, past_key_values
|
352 |
+
return latents
|
353 |
|
354 |
@torch.no_grad()
|
355 |
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
|
OmniGen/pipeline.py
CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
|
|
6 |
import numpy as np
|
7 |
import torch
|
8 |
from huggingface_hub import snapshot_download
|
|
|
9 |
from diffusers.models import AutoencoderKL
|
10 |
from diffusers.utils import (
|
11 |
USE_PEFT_BACKEND,
|
@@ -31,7 +32,7 @@ EXAMPLE_DOC_STRING = """
|
|
31 |
>>> prompt = "A woman holds a bouquet of flowers and faces the camera"
|
32 |
>>> image = pipe(
|
33 |
... prompt,
|
34 |
-
... guidance_scale=
|
35 |
... num_inference_steps=50,
|
36 |
... ).images[0]
|
37 |
>>> image.save("t2i.png")
|
@@ -53,23 +54,42 @@ class OmniGenPipeline:
|
|
53 |
|
54 |
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
55 |
self.model.to(self.device)
|
|
|
56 |
self.vae.to(self.device)
|
57 |
|
58 |
@classmethod
|
59 |
-
def from_pretrained(cls, model_name):
|
60 |
if not os.path.exists(model_name):
|
|
|
61 |
cache_folder = os.getenv('HF_HUB_CACHE')
|
62 |
-
print(cache_folder)
|
63 |
model_name = snapshot_download(repo_id=model_name,
|
64 |
cache_dir=cache_folder,
|
65 |
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
66 |
logger.info(f"Downloaded model to {model_name}")
|
67 |
model = OmniGen.from_pretrained(model_name)
|
68 |
processor = OmniGenProcessor.from_pretrained(model_name)
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
return cls(vae, model, processor)
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
def vae_encode(self, x, dtype):
|
74 |
if self.vae.config.shift_factor is not None:
|
75 |
x = self.vae.encode(x).latent_dist.sample()
|
@@ -100,6 +120,7 @@ class OmniGenPipeline:
|
|
100 |
separate_cfg_infer: bool = False,
|
101 |
use_kv_cache: bool = True,
|
102 |
dtype: torch.dtype = torch.bfloat16,
|
|
|
103 |
):
|
104 |
r"""
|
105 |
Function invoked when calling the pipeline for generation.
|
@@ -128,15 +149,18 @@ class OmniGenPipeline:
|
|
128 |
separate_cfg_infer (`bool`, *optional*, defaults to False):
|
129 |
Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
|
130 |
use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
|
131 |
-
|
|
|
|
|
132 |
Examples:
|
133 |
|
134 |
Returns:
|
135 |
A list with the generated images.
|
136 |
"""
|
137 |
assert height%16 == 0 and width%16 == 0
|
138 |
-
if
|
139 |
-
|
|
|
140 |
if input_images is None:
|
141 |
use_img_guidance = False
|
142 |
if isinstance(prompt, str):
|
@@ -149,7 +173,11 @@ class OmniGenPipeline:
|
|
149 |
num_cfg = 2 if use_img_guidance else 1
|
150 |
latent_size_h, latent_size_w = height//8, width//8
|
151 |
|
152 |
-
|
|
|
|
|
|
|
|
|
153 |
latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
|
154 |
|
155 |
input_img_latents = []
|
|
|
6 |
import numpy as np
|
7 |
import torch
|
8 |
from huggingface_hub import snapshot_download
|
9 |
+
from peft import LoraConfig, PeftModel
|
10 |
from diffusers.models import AutoencoderKL
|
11 |
from diffusers.utils import (
|
12 |
USE_PEFT_BACKEND,
|
|
|
32 |
>>> prompt = "A woman holds a bouquet of flowers and faces the camera"
|
33 |
>>> image = pipe(
|
34 |
... prompt,
|
35 |
+
... guidance_scale=3.0,
|
36 |
... num_inference_steps=50,
|
37 |
... ).images[0]
|
38 |
>>> image.save("t2i.png")
|
|
|
54 |
|
55 |
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
56 |
self.model.to(self.device)
|
57 |
+
self.model.eval()
|
58 |
self.vae.to(self.device)
|
59 |
|
60 |
@classmethod
|
61 |
+
def from_pretrained(cls, model_name, vae_path: str=None):
|
62 |
if not os.path.exists(model_name):
|
63 |
+
logger.info("Model not found, downloading...")
|
64 |
cache_folder = os.getenv('HF_HUB_CACHE')
|
|
|
65 |
model_name = snapshot_download(repo_id=model_name,
|
66 |
cache_dir=cache_folder,
|
67 |
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
68 |
logger.info(f"Downloaded model to {model_name}")
|
69 |
model = OmniGen.from_pretrained(model_name)
|
70 |
processor = OmniGenProcessor.from_pretrained(model_name)
|
71 |
+
|
72 |
+
if os.path.exists(os.path.join(model_name, "vae")):
|
73 |
+
vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae"))
|
74 |
+
elif vae_path is not None:
|
75 |
+
vae = AutoencoderKL.from_pretrained(vae_path).to(device)
|
76 |
+
else:
|
77 |
+
logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF")
|
78 |
+
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
|
79 |
|
80 |
return cls(vae, model, processor)
|
81 |
|
82 |
+
def merge_lora(self, lora_path: str):
|
83 |
+
model = PeftModel.from_pretrained(self.model, lora_path)
|
84 |
+
model.merge_and_unload()
|
85 |
+
self.model = model
|
86 |
+
|
87 |
+
def to(self, device: Union[str, torch.device]):
|
88 |
+
if isinstance(device, str):
|
89 |
+
device = torch.device(device)
|
90 |
+
self.model.to(device)
|
91 |
+
self.vae.to(device)
|
92 |
+
|
93 |
def vae_encode(self, x, dtype):
|
94 |
if self.vae.config.shift_factor is not None:
|
95 |
x = self.vae.encode(x).latent_dist.sample()
|
|
|
120 |
separate_cfg_infer: bool = False,
|
121 |
use_kv_cache: bool = True,
|
122 |
dtype: torch.dtype = torch.bfloat16,
|
123 |
+
seed: int = None,
|
124 |
):
|
125 |
r"""
|
126 |
Function invoked when calling the pipeline for generation.
|
|
|
149 |
separate_cfg_infer (`bool`, *optional*, defaults to False):
|
150 |
Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
|
151 |
use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
|
152 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
153 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
154 |
+
to make generation deterministic.
|
155 |
Examples:
|
156 |
|
157 |
Returns:
|
158 |
A list with the generated images.
|
159 |
"""
|
160 |
assert height%16 == 0 and width%16 == 0
|
161 |
+
if separate_cfg_infer:
|
162 |
+
use_kv_cache = False
|
163 |
+
# raise "Currently, don't support both use_kv_cache and separate_cfg_infer"
|
164 |
if input_images is None:
|
165 |
use_img_guidance = False
|
166 |
if isinstance(prompt, str):
|
|
|
173 |
num_cfg = 2 if use_img_guidance else 1
|
174 |
latent_size_h, latent_size_w = height//8, width//8
|
175 |
|
176 |
+
if seed is not None:
|
177 |
+
generator = torch.Generator(device=self.device).manual_seed(seed)
|
178 |
+
else:
|
179 |
+
generator = None
|
180 |
+
latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
|
181 |
latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
|
182 |
|
183 |
input_img_latents = []
|
OmniGen/processor.py
CHANGED
@@ -11,28 +11,15 @@ from torchvision import transforms
|
|
11 |
from transformers import AutoTokenizer
|
12 |
from huggingface_hub import snapshot_download
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
-
def crop_arr(pil_image, max_image_size):
|
16 |
-
while min(*pil_image.size) >= 2 * max_image_size:
|
17 |
-
pil_image = pil_image.resize(
|
18 |
-
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
19 |
-
)
|
20 |
|
21 |
-
if max(*pil_image.size) > max_image_size:
|
22 |
-
scale = max_image_size / max(*pil_image.size)
|
23 |
-
pil_image = pil_image.resize(
|
24 |
-
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
25 |
-
)
|
26 |
-
|
27 |
-
arr = np.array(pil_image)
|
28 |
-
crop_y1 = (arr.shape[0] % 16) // 2
|
29 |
-
crop_y2 = arr.shape[0] % 16 - crop_y1
|
30 |
-
|
31 |
-
crop_x1 = (arr.shape[1] % 16) // 2
|
32 |
-
crop_x2 = arr.shape[1] % 16 - crop_x1
|
33 |
-
|
34 |
-
arr = arr[crop_y1:arr.shape[0]-crop_y2, crop_x1:arr.shape[1]-crop_x2]
|
35 |
-
return Image.fromarray(arr)
|
36 |
|
37 |
|
38 |
class OmniGenProcessor:
|
@@ -68,6 +55,7 @@ class OmniGenProcessor:
|
|
68 |
return self.image_transform(image)
|
69 |
|
70 |
def process_multi_modal_prompt(self, text, input_images):
|
|
|
71 |
if input_images is None or len(input_images) == 0:
|
72 |
model_inputs = self.text_tokenizer(text)
|
73 |
return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
|
@@ -132,7 +120,6 @@ class OmniGenProcessor:
|
|
132 |
for i in range(len(instructions)):
|
133 |
cur_instruction = instructions[i]
|
134 |
cur_input_images = None if input_images is None else input_images[i]
|
135 |
-
cur_instruction = self.add_prefix_instruction(cur_instruction)
|
136 |
if cur_input_images is not None and len(cur_input_images) > 0:
|
137 |
cur_input_images = [self.process_image(x) for x in cur_input_images]
|
138 |
else:
|
@@ -143,14 +130,13 @@ class OmniGenProcessor:
|
|
143 |
|
144 |
|
145 |
neg_mllm_input, img_cfg_mllm_input = None, None
|
146 |
-
|
147 |
-
neg_mllm_input = self.process_multi_modal_prompt(neg_instruction, None)
|
148 |
if use_img_cfg:
|
149 |
if cur_input_images is not None and len(cur_input_images) >= 1:
|
150 |
img_cfg_prompt = [f"<img><|image_{i+1}|></img>" for i in range(len(cur_input_images))]
|
151 |
-
img_cfg_mllm_input = self.process_multi_modal_prompt(
|
152 |
else:
|
153 |
-
img_cfg_mllm_input =
|
154 |
|
155 |
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
|
156 |
|
|
|
11 |
from transformers import AutoTokenizer
|
12 |
from huggingface_hub import snapshot_download
|
13 |
|
14 |
+
from OmniGen.utils import (
|
15 |
+
create_logger,
|
16 |
+
update_ema,
|
17 |
+
requires_grad,
|
18 |
+
center_crop_arr,
|
19 |
+
crop_arr,
|
20 |
+
)
|
21 |
|
|
|
|
|
|
|
|
|
|
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
class OmniGenProcessor:
|
|
|
55 |
return self.image_transform(image)
|
56 |
|
57 |
def process_multi_modal_prompt(self, text, input_images):
|
58 |
+
text = self.add_prefix_instruction(text)
|
59 |
if input_images is None or len(input_images) == 0:
|
60 |
model_inputs = self.text_tokenizer(text)
|
61 |
return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
|
|
|
120 |
for i in range(len(instructions)):
|
121 |
cur_instruction = instructions[i]
|
122 |
cur_input_images = None if input_images is None else input_images[i]
|
|
|
123 |
if cur_input_images is not None and len(cur_input_images) > 0:
|
124 |
cur_input_images = [self.process_image(x) for x in cur_input_images]
|
125 |
else:
|
|
|
130 |
|
131 |
|
132 |
neg_mllm_input, img_cfg_mllm_input = None, None
|
133 |
+
neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
|
|
|
134 |
if use_img_cfg:
|
135 |
if cur_input_images is not None and len(cur_input_images) >= 1:
|
136 |
img_cfg_prompt = [f"<img><|image_{i+1}|></img>" for i in range(len(cur_input_images))]
|
137 |
+
img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
|
138 |
else:
|
139 |
+
img_cfg_mllm_input = neg_mllm_input
|
140 |
|
141 |
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
|
142 |
|
OmniGen/scheduler.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import torch
|
2 |
from tqdm import tqdm
|
3 |
-
from transformers.cache_utils import Cache, DynamicCache
|
4 |
|
5 |
class OmniGenScheduler:
|
6 |
def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
|
|
|
1 |
import torch
|
2 |
from tqdm import tqdm
|
3 |
+
from transformers.cache_utils import Cache, DynamicCache
|
4 |
|
5 |
class OmniGenScheduler:
|
6 |
def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
|
OmniGen/transformer.py
CHANGED
@@ -16,7 +16,7 @@ from transformers.modeling_outputs import (
|
|
16 |
)
|
17 |
from transformers.modeling_utils import PreTrainedModel
|
18 |
from transformers import Phi3Config, Phi3Model
|
19 |
-
from transformers.cache_utils import Cache, DynamicCache,
|
20 |
from transformers.utils import logging
|
21 |
|
22 |
logger = logging.get_logger(__name__)
|
|
|
16 |
)
|
17 |
from transformers.modeling_utils import PreTrainedModel
|
18 |
from transformers import Phi3Config, Phi3Model
|
19 |
+
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
20 |
from transformers.utils import logging
|
21 |
|
22 |
logger = logging.get_logger(__name__)
|
OmniGen/utils.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
from PIL import Image
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
def create_logger(logging_dir):
|
8 |
+
"""
|
9 |
+
Create a logger that writes to a log file and stdout.
|
10 |
+
"""
|
11 |
+
logging.basicConfig(
|
12 |
+
level=logging.INFO,
|
13 |
+
format='[\033[34m%(asctime)s\033[0m] %(message)s',
|
14 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
15 |
+
handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
|
16 |
+
)
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
return logger
|
19 |
+
|
20 |
+
|
21 |
+
@torch.no_grad()
|
22 |
+
def update_ema(ema_model, model, decay=0.9999):
|
23 |
+
"""
|
24 |
+
Step the EMA model towards the current model.
|
25 |
+
"""
|
26 |
+
ema_params = dict(ema_model.named_parameters())
|
27 |
+
for name, param in model.named_parameters():
|
28 |
+
# TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
|
29 |
+
ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
def requires_grad(model, flag=True):
|
35 |
+
"""
|
36 |
+
Set requires_grad flag for all parameters in a model.
|
37 |
+
"""
|
38 |
+
for p in model.parameters():
|
39 |
+
p.requires_grad = flag
|
40 |
+
|
41 |
+
|
42 |
+
def center_crop_arr(pil_image, image_size):
|
43 |
+
"""
|
44 |
+
Center cropping implementation from ADM.
|
45 |
+
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
46 |
+
"""
|
47 |
+
while min(*pil_image.size) >= 2 * image_size:
|
48 |
+
pil_image = pil_image.resize(
|
49 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
50 |
+
)
|
51 |
+
|
52 |
+
scale = image_size / min(*pil_image.size)
|
53 |
+
pil_image = pil_image.resize(
|
54 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
55 |
+
)
|
56 |
+
|
57 |
+
arr = np.array(pil_image)
|
58 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
59 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
60 |
+
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
def crop_arr(pil_image, max_image_size):
|
65 |
+
while min(*pil_image.size) >= 2 * max_image_size:
|
66 |
+
pil_image = pil_image.resize(
|
67 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
68 |
+
)
|
69 |
+
|
70 |
+
if max(*pil_image.size) > max_image_size:
|
71 |
+
scale = max_image_size / max(*pil_image.size)
|
72 |
+
pil_image = pil_image.resize(
|
73 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
74 |
+
)
|
75 |
+
|
76 |
+
if min(*pil_image.size) < 16:
|
77 |
+
scale = 16 / min(*pil_image.size)
|
78 |
+
pil_image = pil_image.resize(
|
79 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
80 |
+
)
|
81 |
+
|
82 |
+
arr = np.array(pil_image)
|
83 |
+
crop_y1 = (arr.shape[0] % 16) // 2
|
84 |
+
crop_y2 = arr.shape[0] % 16 - crop_y1
|
85 |
+
|
86 |
+
crop_x1 = (arr.shape[1] % 16) // 2
|
87 |
+
crop_x2 = arr.shape[1] % 16 - crop_x1
|
88 |
+
|
89 |
+
arr = arr[crop_y1:arr.shape[0]-crop_y2, crop_x1:arr.shape[1]-crop_x2]
|
90 |
+
return Image.fromarray(arr)
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
def vae_encode(vae, x, weight_dtype):
|
95 |
+
if x is not None:
|
96 |
+
if vae.config.shift_factor is not None:
|
97 |
+
x = vae.encode(x).latent_dist.sample()
|
98 |
+
x = (x - vae.config.shift_factor) * vae.config.scaling_factor
|
99 |
+
else:
|
100 |
+
x = vae.encode(x).latent_dist.sample().mul_(vae.config.scaling_factor)
|
101 |
+
x = x.to(weight_dtype)
|
102 |
+
return x
|
103 |
+
|
104 |
+
def vae_encode_list(vae, x, weight_dtype):
|
105 |
+
latents = []
|
106 |
+
for img in x:
|
107 |
+
img = vae_encode(vae, img, weight_dtype)
|
108 |
+
latents.append(img)
|
109 |
+
return latents
|
110 |
+
|
app.py
CHANGED
@@ -11,7 +11,7 @@ pipe = OmniGenPipeline.from_pretrained(
|
|
11 |
|
12 |
@spaces.GPU
|
13 |
# 示例处理函数:生成图像
|
14 |
-
def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps):
|
15 |
input_images = [img1, img2, img3]
|
16 |
# 去除 None
|
17 |
input_images = [img for img in input_images if img is not None]
|
@@ -28,6 +28,7 @@ def generate_image(text, img1, img2, img3, height, width, guidance_scale, infere
|
|
28 |
num_inference_steps=inference_steps,
|
29 |
separate_cfg_infer=True,
|
30 |
use_kv_cache=False,
|
|
|
31 |
)
|
32 |
img = output[0]
|
33 |
return img
|
@@ -54,6 +55,7 @@ def get_example():
|
|
54 |
1024,
|
55 |
3.0,
|
56 |
20,
|
|
|
57 |
],
|
58 |
[
|
59 |
"Three zebras are standing side by side on a vibrant savannah, each showcasing unique patterns and characteristics that highlight their individuality. The zebra on the left has a strikingly bold black and white stripe pattern, with wider stripes that create a dramatic contrast against its sleek body. In the middle, the zebra features a more subtle stripe arrangement, with thinner stripes that blend seamlessly into a slightly sandy-colored coat, giving it a softer appearance. On the right, the zebra's stripes are more irregular, with a distinct patch of brown fur near its shoulder, adding a layer of uniqueness to its overall look. Together, these zebras create a captivating scene, each representing the diverse beauty of their species in the wild. The right zebras is the zebras from <img><|image_1|></img>. The center zebras is from <img><|image_2|></img>. The left zebras is the zebras from <img><|image_3|></img>.",
|
@@ -64,22 +66,23 @@ def get_example():
|
|
64 |
1024,
|
65 |
3.0,
|
66 |
20,
|
|
|
67 |
],
|
68 |
]
|
69 |
return case
|
70 |
|
71 |
-
def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, inference_steps):
|
72 |
-
return generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps)
|
73 |
|
74 |
|
75 |
# Gradio 接口
|
76 |
with gr.Blocks() as demo:
|
77 |
-
gr.Markdown("
|
78 |
with gr.Row():
|
79 |
with gr.Column():
|
80 |
# 文本输入框
|
81 |
prompt_input = gr.Textbox(
|
82 |
-
label="Enter your prompt", placeholder="Type your prompt here..."
|
83 |
)
|
84 |
|
85 |
with gr.Row(equal_height=True):
|
@@ -105,6 +108,10 @@ with gr.Blocks() as demo:
|
|
105 |
label="Inference Steps", minimum=1, maximum=50, value=50, step=1
|
106 |
)
|
107 |
|
|
|
|
|
|
|
|
|
108 |
# 生成按钮
|
109 |
generate_button = gr.Button("Generate Image")
|
110 |
|
@@ -124,6 +131,7 @@ with gr.Blocks() as demo:
|
|
124 |
width_input,
|
125 |
guidance_scale_input,
|
126 |
num_inference_steps,
|
|
|
127 |
],
|
128 |
outputs=output_image,
|
129 |
)
|
@@ -140,6 +148,7 @@ with gr.Blocks() as demo:
|
|
140 |
width_input,
|
141 |
guidance_scale_input,
|
142 |
num_inference_steps,
|
|
|
143 |
],
|
144 |
outputs=output_image,
|
145 |
)
|
|
|
11 |
|
12 |
@spaces.GPU
|
13 |
# 示例处理函数:生成图像
|
14 |
+
def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
|
15 |
input_images = [img1, img2, img3]
|
16 |
# 去除 None
|
17 |
input_images = [img for img in input_images if img is not None]
|
|
|
28 |
num_inference_steps=inference_steps,
|
29 |
separate_cfg_infer=True,
|
30 |
use_kv_cache=False,
|
31 |
+
seed=seed,
|
32 |
)
|
33 |
img = output[0]
|
34 |
return img
|
|
|
55 |
1024,
|
56 |
3.0,
|
57 |
20,
|
58 |
+
42,
|
59 |
],
|
60 |
[
|
61 |
"Three zebras are standing side by side on a vibrant savannah, each showcasing unique patterns and characteristics that highlight their individuality. The zebra on the left has a strikingly bold black and white stripe pattern, with wider stripes that create a dramatic contrast against its sleek body. In the middle, the zebra features a more subtle stripe arrangement, with thinner stripes that blend seamlessly into a slightly sandy-colored coat, giving it a softer appearance. On the right, the zebra's stripes are more irregular, with a distinct patch of brown fur near its shoulder, adding a layer of uniqueness to its overall look. Together, these zebras create a captivating scene, each representing the diverse beauty of their species in the wild. The right zebras is the zebras from <img><|image_1|></img>. The center zebras is from <img><|image_2|></img>. The left zebras is the zebras from <img><|image_3|></img>.",
|
|
|
66 |
1024,
|
67 |
3.0,
|
68 |
20,
|
69 |
+
42,
|
70 |
],
|
71 |
]
|
72 |
return case
|
73 |
|
74 |
+
def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
|
75 |
+
return generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed)
|
76 |
|
77 |
|
78 |
# Gradio 接口
|
79 |
with gr.Blocks() as demo:
|
80 |
+
gr.Markdown("# OmniGen: Unified Image Generation")
|
81 |
with gr.Row():
|
82 |
with gr.Column():
|
83 |
# 文本输入框
|
84 |
prompt_input = gr.Textbox(
|
85 |
+
label="Enter your prompt, use <img><|image_i|></img> tokens for images", placeholder="Type your prompt here..."
|
86 |
)
|
87 |
|
88 |
with gr.Row(equal_height=True):
|
|
|
108 |
label="Inference Steps", minimum=1, maximum=50, value=50, step=1
|
109 |
)
|
110 |
|
111 |
+
seed_input = gr.Slider(
|
112 |
+
label="Seed", minimum=0, maximum=2147483647, value=42, step=1
|
113 |
+
)
|
114 |
+
|
115 |
# 生成按钮
|
116 |
generate_button = gr.Button("Generate Image")
|
117 |
|
|
|
131 |
width_input,
|
132 |
guidance_scale_input,
|
133 |
num_inference_steps,
|
134 |
+
seed_input,
|
135 |
],
|
136 |
outputs=output_image,
|
137 |
)
|
|
|
148 |
width_input,
|
149 |
guidance_scale_input,
|
150 |
num_inference_steps,
|
151 |
+
seed_input,
|
152 |
],
|
153 |
outputs=output_image,
|
154 |
)
|