wzhouxiff commited on
Commit
9e2e83c
1 Parent(s): 44aebc4

add model path

Browse files
Files changed (1) hide show
  1. app.py +16 -5
app.py CHANGED
@@ -295,11 +295,14 @@ def fn_traj_reset():
295
  ###########################################
296
  model_path='./checkpoints/motionctrl.pth'
297
  config_path='./configs/inference/config_both.yaml'
 
 
298
 
299
  config = OmegaConf.load(config_path)
300
  model_config = config.pop("model", OmegaConf.create())
301
  model = instantiate_from_config(model_config)
302
- model = model.cuda()
 
303
 
304
  model = load_model_checkpoint(model, model_path)
305
  model.eval()
@@ -332,21 +335,29 @@ def model_run(prompts, infer_mode, seed, n_samples):
332
 
333
  if infer_mode == MODE[0]:
334
  camera_poses = RT
335
- camera_poses = torch.tensor(camera_poses).float().cuda()
336
  camera_poses = camera_poses.unsqueeze(0)
337
  trajs = None
 
 
338
  elif infer_mode == MODE[1]:
339
  trajs = traj_flow
340
- trajs = torch.tensor(trajs).float().cuda()
341
  trajs = trajs.unsqueeze(0)
342
  camera_poses = None
 
 
343
  else:
344
  camera_poses = RT
345
  trajs = traj_flow
346
- camera_poses = torch.tensor(camera_poses).float().cuda()
347
- trajs = torch.tensor(trajs).float().cuda()
348
  camera_poses = camera_poses.unsqueeze(0)
349
  trajs = trajs.unsqueeze(0)
 
 
 
 
350
 
351
  ddim_sampler = DDIMSampler(model)
352
  batch_size = noise_shape[0]
 
295
  ###########################################
296
  model_path='./checkpoints/motionctrl.pth'
297
  config_path='./configs/inference/config_both.yaml'
298
+ if not os.path.exists(model_path):
299
+ os.system(f'wget https://huggingface.co/TencentARC/MotionCtrl/resolve/main/motionctrl.pth?download=true -P ./checkpoints/')
300
 
301
  config = OmegaConf.load(config_path)
302
  model_config = config.pop("model", OmegaConf.create())
303
  model = instantiate_from_config(model_config)
304
+ if torch.cuda.is_available():
305
+ model = model.cuda()
306
 
307
  model = load_model_checkpoint(model, model_path)
308
  model.eval()
 
335
 
336
  if infer_mode == MODE[0]:
337
  camera_poses = RT
338
+ camera_poses = torch.tensor(camera_poses).float()
339
  camera_poses = camera_poses.unsqueeze(0)
340
  trajs = None
341
+ if torch.cuda.is_available():
342
+ camera_poses = camera_poses.cuda()
343
  elif infer_mode == MODE[1]:
344
  trajs = traj_flow
345
+ trajs = torch.tensor(trajs).float()
346
  trajs = trajs.unsqueeze(0)
347
  camera_poses = None
348
+ if torch.cuda.is_available():
349
+ trajs = trajs.cuda()
350
  else:
351
  camera_poses = RT
352
  trajs = traj_flow
353
+ camera_poses = torch.tensor(camera_poses).float()
354
+ trajs = torch.tensor(trajs).float()
355
  camera_poses = camera_poses.unsqueeze(0)
356
  trajs = trajs.unsqueeze(0)
357
+ if torch.cuda.is_available():
358
+ camera_poses = camera_poses.cuda()
359
+ trajs = trajs.cuda()
360
+
361
 
362
  ddim_sampler = DDIMSampler(model)
363
  batch_size = noise_shape[0]