Spaces:
Running
Running
JacobLinCool
commited on
Commit
•
fa9dd69
1
Parent(s):
0e6fd1f
feat: early return when trained 10 epoch
Browse files
infer/modules/train/train.py
CHANGED
@@ -248,8 +248,8 @@ def run(rank, n_gpus, hps, logger: logging.Logger, state):
|
|
248 |
scaler = GradScaler(enabled=hps.train.fp16_run)
|
249 |
|
250 |
cache = []
|
|
|
251 |
for epoch in range(epoch_str, hps.train.epochs + 1):
|
252 |
-
print("epoch", epoch)
|
253 |
if rank == 0:
|
254 |
train_and_evaluate(
|
255 |
rank,
|
@@ -283,6 +283,10 @@ def run(rank, n_gpus, hps, logger: logging.Logger, state):
|
|
283 |
scheduler_g.step()
|
284 |
scheduler_d.step()
|
285 |
|
|
|
|
|
|
|
|
|
286 |
|
287 |
def train_and_evaluate(
|
288 |
rank,
|
|
|
248 |
scaler = GradScaler(enabled=hps.train.fp16_run)
|
249 |
|
250 |
cache = []
|
251 |
+
trained = 0
|
252 |
for epoch in range(epoch_str, hps.train.epochs + 1):
|
|
|
253 |
if rank == 0:
|
254 |
train_and_evaluate(
|
255 |
rank,
|
|
|
283 |
scheduler_g.step()
|
284 |
scheduler_d.step()
|
285 |
|
286 |
+
trained += 1
|
287 |
+
if trained >= 10:
|
288 |
+
break
|
289 |
+
|
290 |
|
291 |
def train_and_evaluate(
|
292 |
rank,
|