hparams.py: adding use saved learning rate param

latest_model
rafaelvalle 2018-06-05 08:12:49 -07:00
parent 22bcff1155
commit 5f0ea06c41
1 changed file with 4 additions and 1 deletion

View File

@ -190,8 +190,11 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
     if warm_start:
         model = warm_start_model(checkpoint_path, model)
     else:
-        model, optimizer, learning_rate, iteration = load_checkpoint(
+        model, optimizer, _learning_rate, iteration = load_checkpoint(
             checkpoint_path, model, optimizer)
+        if hparams.use_saved_learning_rate:
+            learning_rate = _learning_rate
         iteration += 1  # next iteration is iteration + 1
     epoch_offset = max(0, int(iteration / len(train_loader)))