xierui.0097 commited on
Commit
ca20c7a
·
1 Parent(s): 24a09de
video_to_video/video_to_video_model.py CHANGED
@@ -14,6 +14,20 @@ from video_to_video.diffusion.schedules_sdedit import noise_schedule
14
  from video_to_video.utils.logger import get_logger
15
 
16
  from diffusers import AutoencoderKLTemporalDecoder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  logger = get_logger()
19
 
@@ -34,6 +48,10 @@ class VideoToVideo_sr():
34
  generator.eval()
35
 
36
  cfg.model_path = opt.model_path
 
 
 
 
37
  load_dict = torch.load(cfg.model_path, map_location='cpu')
38
  if 'state_dict' in load_dict:
39
  load_dict = load_dict['state_dict']
 
14
  from video_to_video.utils.logger import get_logger
15
 
16
  from diffusers import AutoencoderKLTemporalDecoder
17
+ import requests
18
+
19
+ def download_model(url, model_path):
20
+ if not os.path.exists(model_path):
21
+ print(f"Model not found at {model_path}, downloading...")
22
+ response = requests.get(url, stream=True)
23
+ with open(model_path, 'wb') as f:
24
+ for chunk in response.iter_content(chunk_size=1024):
25
+ if chunk:
26
+ f.write(chunk)
27
+ print(f"Model downloaded to {model_path}")
28
+ else:
29
+ print(f"Model found at {model_path}, skipping download.")
30
+
31
 
32
  logger = get_logger()
33
 
 
48
  generator.eval()
49
 
50
  cfg.model_path = opt.model_path
51
+ # download weight
52
+ model_url = 'https://huggingface.co/SherryX/STAR/resolve/main/I2VGen-XL-based/heavy_deg.pt'
53
+ download_model(model_url, cfg.model_path)
54
+
55
  load_dict = torch.load(cfg.model_path, map_location='cpu')
56
  if 'state_dict' in load_dict:
57
  load_dict = load_dict['state_dict']