Drexubery commited on
Commit
7ccdbd8
1 Parent(s): 752a7e2
Files changed (3) hide show
  1. app.py +13 -14
  2. requirements.txt +1 -2
  3. viewcrafter.py +2 -1
app.py CHANGED
@@ -6,12 +6,10 @@ import spaces
6
  # os.system('pip install iopath')
7
  # os.system("pip install -v -v -v 'git+https://github.com/facebookresearch/pytorch3d.git@stable'")
8
  # os.system("cd pytorch3d && pip install -e . && cd ..")
9
- # os.system("pip install 'git+https://github.com/facebookresearch/pytorch3d.git'")
10
  os.system("mkdir -p checkpoints/ && wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P checkpoints/")
11
 
12
  import gradio as gr
13
  import random
14
- from viewcrafter import ViewCrafter
15
  from configs.infer_config import get_parser
16
  from huggingface_hub import hf_hub_download
17
 
@@ -34,11 +32,19 @@ def download_model():
34
  hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/', force_download=True)
35
 
36
  download_model()
 
 
 
 
 
 
 
 
37
 
38
  def viewcrafter_demo(opts):
39
  css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height:576px} #random_button {max-width: 100px !important}"""
40
  image2video = ViewCrafter(opts, gradio = True)
41
- image2video.run_gradio = spaces.GPU(image2video.run_gradio, duration=800)
42
  with gr.Blocks(analytics_enabled=False, css=css) as viewcrafter_iface:
43
  gr.Markdown("<div align='center'> <h1> ViewCrafter: Taming Video Diffusion Models for High-fidelity Novel View Synthesis </span> </h1> \
44
  <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
@@ -100,15 +106,8 @@ def viewcrafter_demo(opts):
100
  return viewcrafter_iface
101
 
102
 
103
- if __name__ == "__main__":
104
- parser = get_parser() # infer_config.py
105
- opts = parser.parse_args() # default device: 'cuda:0'
106
- opts.save_dir = './'
107
- os.makedirs(opts.save_dir,exist_ok=True)
108
- test_tensor = torch.Tensor([0]).cuda()
109
- opts.device = str(test_tensor.device)
110
- viewcrafter_iface = viewcrafter_demo(opts)
111
- viewcrafter_iface.queue(max_size=10)
112
- viewcrafter_iface.launch()
113
- # viewcrafter_iface.launch(server_name='127.0.0.1', server_port=80, max_threads=1,debug=False)
114
 
 
6
  # os.system('pip install iopath')
7
  # os.system("pip install -v -v -v 'git+https://github.com/facebookresearch/pytorch3d.git@stable'")
8
  # os.system("cd pytorch3d && pip install -e . && cd ..")
 
9
  os.system("mkdir -p checkpoints/ && wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P checkpoints/")
10
 
11
  import gradio as gr
12
  import random
 
13
  from configs.infer_config import get_parser
14
  from huggingface_hub import hf_hub_download
15
 
 
32
  hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/', force_download=True)
33
 
34
  download_model()
35
+ parser = get_parser() # infer_config.py
36
+ opts = parser.parse_args() # default device: 'cuda:0'
37
+ opts.save_dir = './'
38
+ os.makedirs(opts.save_dir,exist_ok=True)
39
+ test_tensor = torch.Tensor([0]).cuda()
40
+ opts.device = str(test_tensor.device)
41
+ os.system("pip install 'git+https://github.com/facebookresearch/pytorch3d.git'")
42
+ from viewcrafter import ViewCrafter
43
 
44
  def viewcrafter_demo(opts):
45
  css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height:576px} #random_button {max-width: 100px !important}"""
46
  image2video = ViewCrafter(opts, gradio = True)
47
+ image2video.run_gradio = spaces.GPU(image2video.run_gradio, duration=300)
48
  with gr.Blocks(analytics_enabled=False, css=css) as viewcrafter_iface:
49
  gr.Markdown("<div align='center'> <h1> ViewCrafter: Taming Video Diffusion Models for High-fidelity Novel View Synthesis </span> </h1> \
50
  <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
 
106
  return viewcrafter_iface
107
 
108
 
109
+ viewcrafter_iface = viewcrafter_demo(opts)
110
+ viewcrafter_iface.queue(max_size=10)
111
+ viewcrafter_iface.launch()
112
+ # viewcrafter_iface.launch(server_name='127.0.0.1', server_port=80, max_threads=1,debug=False)
 
 
 
 
 
 
 
113
 
requirements.txt CHANGED
@@ -28,5 +28,4 @@ omegaconf==2.3.0
28
  triton
29
  av
30
  xformers
31
- gradio
32
- git+https://github.com/facebookresearch/pytorch3d.git
 
28
  triton
29
  av
30
  xformers
31
+ gradio
 
viewcrafter.py CHANGED
@@ -143,7 +143,8 @@ class ViewCrafter:
143
  r = [float(i) for i in lines[2].split()]
144
  else:
145
  phi, theta, r = self.gradio_traj
146
- device = torch.device("cpu")
 
147
  camera_traj,num_views = generate_traj_txt(c2ws, H, W, focals, principal_points, phi, theta, r,self.opts.video_length, device,viz_traj=True, save_dir = self.opts.save_dir)
148
  # camera_traj,num_views = generate_traj_txt(c2ws, H, W, focals, principal_points, phi, theta, r,self.opts.video_length, self.device,viz_traj=True, save_dir = self.opts.save_dir)
149
  else:
 
143
  r = [float(i) for i in lines[2].split()]
144
  else:
145
  phi, theta, r = self.gradio_traj
146
+ # device = torch.device("cpu")
147
+ device = self.device
148
  camera_traj,num_views = generate_traj_txt(c2ws, H, W, focals, principal_points, phi, theta, r,self.opts.video_length, device,viz_traj=True, save_dir = self.opts.save_dir)
149
  # camera_traj,num_views = generate_traj_txt(c2ws, H, W, focals, principal_points, phi, theta, r,self.opts.video_length, self.device,viz_traj=True, save_dir = self.opts.save_dir)
150
  else: