Spaces: Anonymous committed
add spaces
Commit e84616f • Parent(s): e994f84
app.py
CHANGED
```diff
@@ -22,6 +22,9 @@ from funcs import (
 from utils.utils import instantiate_from_config
 from utils.utils_freetraj import plan_path
 
+video_length = 16
+width = 512
+height = 320
 MAX_KEYS = 5
 
 ckpt_dir_512 = "checkpoints/base_512_v2"
```
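These three constants move out of `infer` (see the matching removal in the `@@ -75` hunk below) to module scope so the new `infer_gpu_part` at the end of this diff can read them without receiving them as arguments. A quick sanity check of the latent noise shape they imply (a minimal standalone sketch; the `// 8` spatial reduction mirrors the `h, w` computation in `infer_gpu_part`, and the channel count is an assumed placeholder for `model.channels`):

```python
# Standalone sketch of the latent noise shape implied by the new constants.
# channels=4 is an assumption; the app reads the real value from model.channels.
video_length = 16
width = 512
height = 320

batch_size = 1
channels = 4  # placeholder for model.channels
h, w = height // 8, width // 8  # 40, 64
noise_shape = [batch_size, channels, video_length, h, w]
print(noise_shape)  # [1, 4, 16, 40, 64]
```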
```diff
@@ -56,7 +59,7 @@ def check(radio_mode):
     video_bbox_path = "output_freetraj_bbox.mp4"
     return video_path, video_bbox_path
 
-
+
 def infer(*user_args):
     prompt_in = user_args[0]
     target_indices = user_args[1]
```
```diff
@@ -75,9 +78,6 @@ def infer(*user_args):
     w_positions = user_args[-MAX_KEYS:]
     print(user_args)
 
-    video_length = 16
-    width = 512
-    height = 320
     if radio_mode == 'ori':
         config_512 = "configs/inference_t2v_512_v2.0.yaml"
     else:
```
```diff
@@ -110,15 +110,6 @@ def infer(*user_args):
 
     config_512 = OmegaConf.load(config_512)
     model_config_512 = config_512.pop("model", OmegaConf.create())
-    model = instantiate_from_config(model_config_512)
-    model = model.cuda()
-    model = load_model_checkpoint(model, ckpt_path_512)
-    model.eval()
-
-    if seed is None:
-        seed = int.from_bytes(os.urandom(2), "big")
-    print(f"Using seed: {seed}")
-    seed_everything(seed)
 
     args = argparse.Namespace(
         mode="base",
```
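The model construction and seeding deleted here are not lost: they reappear inside `infer_gpu_part` (added at the bottom of this diff), where the GPU is actually attached, and the relocated version also gains an explicit `else` branch for a user-supplied seed. The fallback draws two random bytes, so the derived seed always lands in 0..65535; as a standalone illustration:

```python
import os

# Same fallback as the relocated code: two random bytes, big-endian,
# giving a seed in the range 0..65535.
seed = int.from_bytes(os.urandom(2), "big")
assert 0 <= seed <= 0xFFFF
print(f"Using seed: {seed}")
```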
```diff
@@ -127,57 +118,20 @@ def infer(*user_args):
         ddim_steps=ddim_steps,
         ddim_eta=0.0,
         bs=1,
-        height=height,
-        width=width,
-        frames=video_length,
         fps=video_fps,
         unconditional_guidance_scale=unconditional_guidance_scale,
         unconditional_guidance_scale_temporal=None,
         cond_input=None,
+        prompt_in = prompt_in,
+        seed = seed,
         ddim_edit = ddim_edit,
+        model_config_512 = model_config_512,
+        idx_list = idx_list,
+        input_traj = input_traj,
     )
 
-
-
-    frames = model.temporal_length if args.frames < 0 else args.frames
-    channels = model.channels
-
-    batch_size = 1
-    noise_shape = [batch_size, channels, frames, h, w]
-    fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
-    prompts = [prompt_in]
-    text_emb = model.get_learned_conditioning(prompts)
-
-    cond = {"c_crossattn": [text_emb], "fps": fps}
-
-    ## inference
-    if radio_mode == 'ori':
-        batch_samples = batch_ddim_sampling(
-            model,
-            cond,
-            noise_shape,
-            args.n_samples,
-            args.ddim_steps,
-            args.ddim_eta,
-            args.unconditional_guidance_scale,
-            args=args,
-        )
-    else:
-        batch_samples = batch_ddim_sampling_freetraj(
-            model,
-            cond,
-            noise_shape,
-            args.n_samples,
-            args.ddim_steps,
-            args.ddim_eta,
-            args.unconditional_guidance_scale,
-            idx_list = idx_list,
-            input_traj = input_traj,
-            args=args,
-        )
-
-    vid_tensor = batch_samples[0]
-    video = vid_tensor.detach().cpu()
+    video = infer_gpu_part(args)
+
     video = torch.clamp(video.float(), -1.0, 1.0)
     video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
 
```
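`infer` now packs everything the GPU stage needs (prompt, seed, model config, trajectory inputs) into the existing `argparse.Namespace` and hands over that single object via `infer_gpu_part(args)`. A minimal sketch of the hand-off pattern, with hypothetical names (`run_gpu_stage` is not the app's real function):

```python
import argparse

# Hypothetical, minimal version of the single-argument hand-off used above:
# bundle all inputs into one Namespace so the GPU stage takes one argument.
def run_gpu_stage(args):
    return f"sampling {args.prompt_in!r} with seed {args.seed}"

args = argparse.Namespace(prompt_in="a red panda climbing", seed=123)
print(run_gpu_stage(args))
```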
```diff
@@ -251,6 +205,67 @@ def infer(*user_args):
 
     return video_path, video_bbox_path
 
+
+
+@spaces.GPU(duration=270)
+def infer_gpu_part(args):
+
+    model = instantiate_from_config(args.model_config_512)
+    model = model.cuda()
+    model = load_model_checkpoint(model, ckpt_path_512)
+    model.eval()
+
+    if args.seed is None:
+        seed = int.from_bytes(os.urandom(2), "big")
+    else:
+        seed = args.seed
+    print(f"Using seed: {seed}")
+    seed_everything(seed)
+
+    ## latent noise shape
+    h, w = height // 8, width // 8
+    frames = video_length
+    channels = model.channels
+
+    batch_size = 1
+    noise_shape = [batch_size, channels, frames, h, w]
+    fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
+    prompts = [args.prompt_in]
+    text_emb = model.get_learned_conditioning(prompts)
+
+    cond = {"c_crossattn": [text_emb], "fps": fps}
+
+    ## inference
+    if radio_mode == 'ori':
+        batch_samples = batch_ddim_sampling(
+            model,
+            cond,
+            noise_shape,
+            args.n_samples,
+            args.ddim_steps,
+            args.ddim_eta,
+            args.unconditional_guidance_scale,
+            args=args,
+        )
+    else:
+        batch_samples = batch_ddim_sampling_freetraj(
+            model,
+            cond,
+            noise_shape,
+            args.n_samples,
+            args.ddim_steps,
+            args.ddim_eta,
+            args.unconditional_guidance_scale,
+            idx_list = args.idx_list,
+            input_traj = args.input_traj,
+            args=args,
+        )
+
+    vid_tensor = batch_samples[0]
+    video = vid_tensor.detach().cpu()
+
+    return video
+
 
 examples = [
     ["A squirrel jumping from one tree to another.",],
```
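This new function is what the commit message "add spaces" refers to: on Hugging Face ZeroGPU hardware, CUDA is only available inside functions decorated with `spaces.GPU`, so all model loading, seeding, and sampling is hoisted into `infer_gpu_part`, with `duration=270` requesting up to 270 seconds of GPU time per call. One caveat: the function still reads `radio_mode`, which is never packed into `args`, so it likely needs the same treatment as `idx_list` and `input_traj`. A minimal sketch of the decorator pattern (assumes the `spaces` package on a ZeroGPU Space; the tensor work is a placeholder, not the app's real model call):

```python
import spaces
import torch

@spaces.GPU(duration=270)  # a GPU is attached only while this function runs
def gpu_stage():
    # CUDA work must stay inside the decorated function on ZeroGPU.
    noise = torch.randn(1, 4, 16, 40, 64, device="cuda")
    return noise.cpu()  # move results off-GPU before returning
```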