Spaces:
Runtime error
Runtime error
add tqdm
Browse files- requirements.txt +2 -0
- train.py +7 -6
requirements.txt
CHANGED
@@ -6,3 +6,5 @@ scikit-image>=0.14.0
|
|
6 |
torchvision>=0.2.1
|
7 |
pillow>=7.2.0
|
8 |
lpips>=0.1.3
|
|
|
|
|
|
6 |
torchvision>=0.2.1
|
7 |
pillow>=7.2.0
|
8 |
lpips>=0.1.3
|
9 |
+
gdown
|
10 |
+
tqdm
|
train.py
CHANGED
@@ -12,6 +12,7 @@ from data_loader import (FileDataset,
|
|
12 |
RandomResizedCropWithAutoCenteringAndZeroPadding)
|
13 |
from torch.utils.data.distributed import DistributedSampler
|
14 |
from conr import CoNR
|
|
|
15 |
|
16 |
def data_sampler(dataset, shuffle, distributed):
|
17 |
|
@@ -123,7 +124,7 @@ def infer(args, humanflowmodel, image_names_list):
|
|
123 |
time_stamp = time.time()
|
124 |
prev_frame_rgb = []
|
125 |
prev_frame_a = []
|
126 |
-
for i, data in enumerate(train_data):
|
127 |
data_time_interval = time.time() - time_stamp
|
128 |
time_stamp = time.time()
|
129 |
with torch.no_grad():
|
@@ -137,11 +138,11 @@ def infer(args, humanflowmodel, image_names_list):
|
|
137 |
|
138 |
train_time_interval = time.time() - time_stamp
|
139 |
time_stamp = time.time()
|
140 |
-
if i % 5 == 0 and args.local_rank == 0:
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
with torch.no_grad():
|
146 |
|
147 |
if args.test_output_video:
|
|
|
12 |
RandomResizedCropWithAutoCenteringAndZeroPadding)
|
13 |
from torch.utils.data.distributed import DistributedSampler
|
14 |
from conr import CoNR
|
15 |
+
from tqdm import tqdm
|
16 |
|
17 |
def data_sampler(dataset, shuffle, distributed):
|
18 |
|
|
|
124 |
time_stamp = time.time()
|
125 |
prev_frame_rgb = []
|
126 |
prev_frame_a = []
|
127 |
+
for i, data in tqdm(enumerate(train_data)):
|
128 |
data_time_interval = time.time() - time_stamp
|
129 |
time_stamp = time.time()
|
130 |
with torch.no_grad():
|
|
|
138 |
|
139 |
train_time_interval = time.time() - time_stamp
|
140 |
time_stamp = time.time()
|
141 |
+
# if i % 5 == 0 and args.local_rank == 0:
|
142 |
+
# print("[infer batch: %4d/%4d] time:%2f+%2f" % (
|
143 |
+
# i, train_num,
|
144 |
+
# data_time_interval, train_time_interval
|
145 |
+
# ))
|
146 |
with torch.no_grad():
|
147 |
|
148 |
if args.test_output_video:
|