update image_cache

Changed files:
- .gitignore +3 -3
- get_webvid_prompt.py +0 -32
- model/model_manager.py +0 -55
- model/models/generate_image_cache.py +0 -99
- model/models/generate_video_cache.py +0 -62
.gitignore
CHANGED
@@ -174,6 +174,6 @@ ksort-logs/
 cache_video/
 cache_image/
 
-
-
-
+model/models/generate_image_cache.py
+model/models/generate_video_cache.py
+get_webvid_prompt.py
get_webvid_prompt.py
DELETED
@@ -1,32 +0,0 @@
-from datasets import load_dataset
-import pandas as pd
-import re
-# # Load the WebVid dataset
-# dataset = load_dataset('webvid', 'webvid-10m', split='train')
-# from datasets import load_dataset
-
-ds = load_dataset("TempoFunk/webvid-10M", cache_dir="/mnt/data/lizhikai/webvid/")
-v = ds['validation']['name']
-# Reasonable bounds on caption length
-MIN_LENGTH = 30
-MAX_LENGTH = 300
-pattern = re.compile(r'^[a-zA-Z\s]+$')
-
-# Filter out empty captions and captions that are too short or too long
-v = [s for s in v if len(s) >= MIN_LENGTH and len(s) <= MAX_LENGTH and pattern.match(s)]
-
-# Path of the output file
-file_path = 'webvid_prompt.txt'
-
-# Open the file in write mode
-with open(file_path, 'w', encoding='utf-8') as file:
-    # Iterate over the list and write each caption to the file
-    for item in v:
-        if '\n' in item:
-            continue
-        else:
-            file.write(item + '\n')
-
-print("The list of captions has been saved to the file.")
-
-
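The filter in the deleted script amounts to a single predicate over captions: keep strings between MIN_LENGTH and MAX_LENGTH characters, made up only of ASCII letters and whitespace, and (checked later, at write time) containing no newline. A minimal sketch of the same rule folded into one function; the function name and default values are illustrative, not part of the repo:

import re

_PATTERN = re.compile(r'^[a-zA-Z\s]+$')

def keep_caption(s, min_length=30, max_length=300):
    # Same checks the deleted script applied, combined into one predicate.
    return (
        min_length <= len(s) <= max_length
        and _PATTERN.match(s) is not None
        and '\n' not in s
    )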
model/model_manager.py
CHANGED
@@ -8,7 +8,6 @@ import torch
 from PIL import Image
 from openai import OpenAI
 from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS, load_pipeline
-from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum
 from serve.upload import get_random_mscoco_prompt, get_random_video_prompt, get_ssh_random_video_prompt, get_ssh_random_image_prompt
 from serve.constants import SSH_CACHE_OPENSOURCE, SSH_CACHE_ADVANCE, SSH_CACHE_PIKA, SSH_CACHE_SORA, SSH_CACHE_IMAGE
 
@@ -62,14 +61,6 @@ class ModelManager:
 
         return result
 
-    def generate_image_ig_museum(self, model_name):
-        model_name = model_name.split('_')[1]
-        result_list = draw_from_imagen_museum("t2i", model_name)
-        image_link = result_list[0]
-        prompt = result_list[1]
-
-        return image_link, prompt
-
 
     def generate_image_ig_parallel_anony(self, prompt, model_A, model_B, model_C, model_D):
         if model_A == "" and model_B == "" and model_C == "" and model_D == "":
@@ -174,15 +165,6 @@ class ModelManager:
            results = [future.result() for future in futures]
        return results[0], results[1]
 
-    def generate_image_ig_museum_parallel(self, model_A, model_B):
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            model_1 = model_A.split('_')[1]
-            model_2 = model_B.split('_')[1]
-            result_list = draw2_from_imagen_museum("t2i", model_1, model_2)
-            image_links = result_list[0]
-            prompt_list = result_list[1]
-            return image_links[0], image_links[1], prompt_list[0]
-
 
     @spaces.GPU(duration=200)
     def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
@@ -190,14 +172,6 @@ class ModelManager:
        result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
        return result
 
-    def generate_image_ie_museum(self, model_name):
-        model_name = model_name.split('_')[1]
-        result_list = draw_from_imagen_museum("tie", model_name)
-        image_links = result_list[0]
-        prompt_list = result_list[1]
-        # image_links = [src, model]
-        # prompt_list = [source_caption, target_caption, instruction]
-        return image_links[0], image_links[1], prompt_list[0], prompt_list[1], prompt_list[2]
 
     def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
        model_names = [model_A, model_B]
@@ -208,17 +182,6 @@ class ModelManager:
            results = [future.result() for future in futures]
        return results[0], results[1]
 
-    def generate_image_ie_museum_parallel(self, model_A, model_B):
-        model_names = [model_A, model_B]
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            model_1 = model_names[0].split('_')[1]
-            model_2 = model_names[1].split('_')[1]
-            result_list = draw2_from_imagen_museum("tie", model_1, model_2)
-            image_links = result_list[0]
-            prompt_list = result_list[1]
-            # image_links = [src, model_A, model_B]
-            # prompt_list = [source_caption, target_caption, instruction]
-            return image_links[0], image_links[1], image_links[2], prompt_list[0], prompt_list[1], prompt_list[2]
 
     def generate_image_ie_parallel_anony(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
        if model_A == "" and model_B == "":
@@ -229,21 +192,3 @@ class ModelManager:
            futures = [executor.submit(self.generate_image_ie, textbox_source, textbox_target, textbox_instruct, source_image, model) for model in model_names]
            results = [future.result() for future in futures]
        return results[0], results[1], model_names[0], model_names[1]
-
-    def generate_image_ie_museum_parallel_anony(self, model_A, model_B):
-        if model_A == "" and model_B == "":
-            model_names = random.sample([model for model in self.model_ie_list], 2)
-        else:
-            model_names = [model_A, model_B]
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            model_1 = model_names[0].split('_')[1]
-            model_2 = model_names[1].split('_')[1]
-            result_list = draw2_from_imagen_museum("tie", model_1, model_2)
-            image_links = result_list[0]
-            prompt_list = result_list[1]
-            # image_links = [src, model_A, model_B]
-            # prompt_list = [source_caption, target_caption, instruction]
-            return image_links[0], image_links[1], image_links[2], prompt_list[0], prompt_list[1], prompt_list[2], model_names[0], model_names[1]
-
-
-        raise NotImplementedError
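The surviving *_parallel and *_parallel_anony methods in this file all rely on the same fan-out pattern: one executor.submit per model, then future.result() in submission order. A minimal standalone sketch of that pattern; generate_fn and the argument list are placeholders, not the actual ModelManager signatures:

import concurrent.futures

def generate_parallel(generate_fn, prompt, model_names):
    # Fan out one generation job per model and keep results in submission order.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(generate_fn, prompt, name) for name in model_names]
        return [future.result() for future in futures]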
model/models/generate_image_cache.py
DELETED
@@ -1,99 +0,0 @@
-from huggingface_models import load_huggingface_model
-from replicate_api_models import load_replicate_model
-from openai_api_models import load_openai_model
-from other_api_models import load_other_model
-import concurrent.futures
-import os
-import io, time
-import requests
-import json
-from PIL import Image
-
-
-IMAGE_GENERATION_MODELS = [
-    # 'replicate_SDXL_text2image',
-    # 'replicate_SD-v3.0_text2image',
-    # 'replicate_SD-v2.1_text2image',
-    # 'replicate_SD-v1.5_text2image',
-    # 'replicate_SDXL-Lightning_text2image',
-    # 'replicate_Kandinsky-v2.0_text2image',
-    # 'replicate_Kandinsky-v2.2_text2image',
-    # 'replicate_Proteus-v0.2_text2image',
-    # 'replicate_Playground-v2.0_text2image',
-    # 'replicate_Playground-v2.5_text2image',
-    # 'replicate_Dreamshaper-xl-turbo_text2image',
-    # 'replicate_SDXL-Deepcache_text2image',
-    # 'replicate_Openjourney-v4_text2image',
-    # 'replicate_LCM-v1.5_text2image',
-    # 'replicate_Realvisxl-v3.0_text2image',
-    # 'replicate_Realvisxl-v2.0_text2image',
-    # 'replicate_Pixart-Sigma_text2image',
-    # 'replicate_SSD-1b_text2image',
-    # 'replicate_Open-Dalle-v1.1_text2image',
-    # 'replicate_Deepfloyd-IF_text2image',
-    # 'huggingface_SD-turbo_text2image',
-    # 'huggingface_SDXL-turbo_text2image',
-    # 'huggingface_Stable-cascade_text2image',
-    # 'openai_Dalle-2_text2image',
-    # 'openai_Dalle-3_text2image',
-    'other_Midjourney-v6.0_text2image',
-    'other_Midjourney-v5.0_text2image',
-    # "replicate_FLUX.1-schnell_text2image",
-    # "replicate_FLUX.1-pro_text2image",
-    # "replicate_FLUX.1-dev_text2image",
-]
-
-Prompts = [
-    # 'An aerial view of someone walking through a forest alone in the style of Romanticism.',
-    # 'With dark tones and backlit resolution, this oil painting depicts a thunderstorm over a cityscape.',
-    # 'The rendering depicts a futuristic train station with volumetric lighting in an Art Nouveau style.',
-    # 'An Impressionist illustration depicts a river winding through a meadow.', # featuring a thick black outline
-    # 'Photo of a black and white picture of a person facing the sunset from a bench.',
-    # 'The skyline of a city is painted in bright, high-resolution colors.',
-    # 'A sketch shows two robots talking to each other, featuring a surreal look and narrow aspect ratio.',
-    # 'An abstract Dadaist collage in neon tones and 4K resolutions of a post-apocalyptic world.',
-    # 'With abstract elements and a rococo style, the painting depicts a garden in high resolution.',
-    # 'A picture of a senior man walking in the rain and looking directly at the camera from a medium distance.',
-]
-
-def load_pipeline(model_name):
-    model_source, model_name, model_type = model_name.split("_")
-    if model_source == "replicate":
-        pipe = load_replicate_model(model_name, model_type)
-    elif model_source == "huggingface":
-        pipe = load_huggingface_model(model_name, model_type)
-    elif model_source == "openai":
-        pipe = load_openai_model(model_name, model_type)
-    elif model_source == "other":
-        pipe = load_other_model(model_name, model_type)
-    else:
-        raise ValueError(f"Model source {model_source} not supported")
-    return pipe
-
-def generate_image_ig_api(prompt, model_name):
-    pipe = load_pipeline(model_name)
-    result = pipe(prompt=prompt)
-    return result
-
-save_names = []
-for name in IMAGE_GENERATION_MODELS:
-    model_source, model_name, model_type = name.split("_")
-    save_names.append(model_name)
-
-for i, prompt in enumerate(Prompts):
-    print("save the {} prompt".format(i+1))
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        futures = [executor.submit(generate_image_ig_api, prompt, model) for model in IMAGE_GENERATION_MODELS]
-        results = [future.result() for future in futures]
-
-    root_dir = '/rscratch/zhendong/lizhikai/ksort/ksort_image_cache/'
-    save_dir = os.path.join(root_dir, f'output-{i+4}')
-    if not os.path.exists(save_dir):
-        os.makedirs(save_dir, exist_ok=True)
-    with open(os.path.join(save_dir, "prompt.txt"), 'w', encoding='utf-8') as file:
-        file.write(prompt)
-
-    for j, result in enumerate(results):
-        result = result.resize((512, 512))
-        file_path = os.path.join(save_dir, f'{save_names[j]}.jpg')
-        result.save(file_path, format="JPEG")
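The deleted script wrote one directory per prompt, holding prompt.txt plus a 512x512 JPEG per model, named after the model. A hedged sketch of reading such a cache entry back; the directory layout is inferred from the script above and the function name is illustrative:

import os
from PIL import Image

def load_cache_entry(save_dir):
    # Each cache directory holds prompt.txt plus one <model_name>.jpg per model.
    with open(os.path.join(save_dir, "prompt.txt"), encoding="utf-8") as f:
        prompt = f.read()
    images = {
        os.path.splitext(name)[0]: Image.open(os.path.join(save_dir, name))
        for name in os.listdir(save_dir)
        if name.endswith(".jpg")
    }
    return prompt, images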
model/models/generate_video_cache.py
DELETED
@@ -1,62 +0,0 @@
-
-
-file_path = '/home/lizhikai/webvid_prompt100.txt'
-str_list = []
-with open(file_path, 'r', encoding='utf-8') as file:
-    for line in file:
-        str_list.append(line.strip())
-        if len(str_list) == 100:
-            break
-
-def generate_image_ig_api(prompt, model_name):
-    model_source, model_name, model_type = model_name.split("_")
-    pipe = load_replicate_model(model_name, model_type)
-    result = pipe(prompt=prompt)
-    return result
-model_names = ['replicate_Zeroscope-v2-xl_text2video',
-               # 'replicate_Damo-Text-to-Video_text2video',
-               'replicate_Animate-Diff_text2video',
-               'replicate_OpenSora_text2video',
-               'replicate_LaVie_text2video',
-               'replicate_VideoCrafter2_text2video',
-               'replicate_Stable-Video-Diffusion_text2video',
-               ]
-save_names = []
-for name in model_names:
-    model_source, model_name, model_type = name.split("_")
-    save_names.append(model_name)
-
-for i, prompt in enumerate(str_list):
-    print("save the {} prompt".format(i+1))
-    # if i+1 < 97:
-    #     continue
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        futures = [executor.submit(generate_image_ig_api, prompt, model) for model in model_names]
-        results = [future.result() for future in futures]
-
-    root_dir = '/mnt/data/lizhikai/ksort_video_cache/'
-    save_dir = os.path.join(root_dir, f'cache_{i+1}')
-    if not os.path.exists(save_dir):
-        os.makedirs(save_dir, exist_ok=True)
-    with open(os.path.join(save_dir, "prompt.txt"), 'w', encoding='utf-8') as file:
-        file.write(prompt)
-
-    # Download and save the videos
-    repeat_num = 5
-    for j, url in enumerate(results):
-        while 1:
-            time.sleep(1)
-            response = requests.get(url, stream=True)
-            if response.status_code == 200:
-                file_path = os.path.join(save_dir, f'{save_names[j]}.mp4')
-                with open(file_path, 'wb') as file:
-                    for chunk in response.iter_content(chunk_size=8192):
-                        file.write(chunk)
-                print(f"Video {j} saved to {file_path}")
-                break
-            else:
-                repeat_num = repeat_num - 1
-                if repeat_num == 0:
-                    print(f"Failed to save video {j}")
-                    # raise ValueError("Video request failed.")
-                    continue
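Note that the deleted video-cache script referenced os, time, requests, concurrent.futures, and load_replicate_model without importing them, so it presumably ran with a header much like the one in generate_image_cache.py. A sketch of the imports it would have needed, assuming the same sibling-module layout as the image-cache script:

import concurrent.futures
import os
import time

import requests
# Assumed to come from the sibling replicate_api_models module,
# as in generate_image_cache.py.
from replicate_api_models import load_replicate_model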