mirror of
https://github.com/malarinv/tacotron2
synced 2026-03-08 09:42:34 +00:00
train.py: val logger on gpu 0 only
This commit is contained in:
7
train.py
7
train.py
@@ -142,6 +142,7 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus,
|
|||||||
val_loss = val_loss / (i + 1)
|
val_loss = val_loss / (i + 1)
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
|
if rank == 0:
|
||||||
print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss))
|
print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss))
|
||||||
logger.log_validation(reduced_val_loss, model, y, y_pred, iteration)
|
logger.log_validation(reduced_val_loss, model, y, y_pred, iteration)
|
||||||
|
|
||||||
@@ -236,9 +237,9 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
|
|||||||
reduced_loss, grad_norm, learning_rate, duration, iteration)
|
reduced_loss, grad_norm, learning_rate, duration, iteration)
|
||||||
|
|
||||||
if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
|
if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
|
||||||
validate(model, criterion, valset, iteration, hparams.batch_size,
|
validate(model, criterion, valset, iteration,
|
||||||
n_gpus, collate_fn, logger, hparams.distributed_run, rank)
|
hparams.batch_size, n_gpus, collate_fn, logger,
|
||||||
|
hparams.distributed_run, rank)
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
checkpoint_path = os.path.join(
|
checkpoint_path = os.path.join(
|
||||||
output_directory, "checkpoint_{}".format(iteration))
|
output_directory, "checkpoint_{}".format(iteration))
|
||||||
|
|||||||
Reference in New Issue
Block a user