mirror of https://github.com/malarinv/tacotron2
train.py: adding routine to warm start and ignore layers, e.g. embedding.weight
parent bb67613493
commit 3869781877
train.py (16 changed lines)
@@ -89,11 +89,18 @@ def load_model(hparams):
     return model


-def warm_start_model(checkpoint_path, model):
+def warm_start_model(checkpoint_path, model, ignore_layers):
     assert os.path.isfile(checkpoint_path)
     print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
     checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
-    model.load_state_dict(checkpoint_dict['state_dict'])
+    model_dict = checkpoint_dict['state_dict']
+    if len(ignore_layers) > 0:
+        model_dict = {k: v for k, v in model_dict.items()
+                      if k not in ignore_layers}
+        dummy_dict = model.state_dict()
+        dummy_dict.update(model_dict)
+        model_dict = dummy_dict
+    model.load_state_dict(model_dict)
     return model
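The first hunk replaces the unconditional load_state_dict with a partial load: keys listed in ignore_layers are dropped from the checkpoint's state dict, then backfilled from the freshly initialized model so load_state_dict still receives a complete dict. Below is a minimal standalone sketch of the same pattern; TinyModel, the vocabulary sizes, and the checkpoint filename are invented for illustration and are not part of the tacotron2 code.

# Minimal sketch of the partial-load pattern (illustrative names only).
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self, vocab_size, dim=4):
        super(TinyModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.linear = nn.Linear(dim, dim)


# A checkpoint saved from a run with a different vocabulary size, so its
# embedding.weight no longer matches the new model's shape.
torch.save({'state_dict': TinyModel(vocab_size=20).state_dict()},
           'pretrained.pt')

model = TinyModel(vocab_size=10)
checkpoint_dict = torch.load('pretrained.pt', map_location='cpu')
model_dict = checkpoint_dict['state_dict']

ignore_layers = ['embedding.weight']
if len(ignore_layers) > 0:
    # Drop the incompatible keys from the checkpoint ...
    model_dict = {k: v for k, v in model_dict.items()
                  if k not in ignore_layers}
    # ... and backfill them from the freshly initialized model, so the
    # merged dict covers every parameter.
    dummy_dict = model.state_dict()
    dummy_dict.update(model_dict)
    model_dict = dummy_dict

model.load_state_dict(model_dict)  # embedding stays random, linear is warm

Because the merged dict covers every parameter, load_state_dict can stay in its default strict mode rather than falling back to strict=False.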
@@ -189,7 +196,8 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
     epoch_offset = 0
     if checkpoint_path is not None:
         if warm_start:
-            model = warm_start_model(checkpoint_path, model)
+            model = warm_start_model(
+                checkpoint_path, model, hparams.ignore_layers)
         else:
             model, optimizer, _learning_rate, iteration = load_checkpoint(
                 checkpoint_path, model, optimizer)
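The second hunk threads the list through from the hyperparameters instead of hard-coding it. This diff does not show hparams.py, but in the upstream NVIDIA tacotron2 the default is exactly the example from the commit message; assuming this mirror matches, the relevant entry looks like:

# hparams.py (upstream default; assumed unchanged in this mirror)
ignore_layers=['embedding.weight'],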
@@ -258,7 +266,7 @@ if __name__ == '__main__':
     parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
                         required=False, help='checkpoint path')
     parser.add_argument('--warm_start', action='store_true',
-                        help='load the model only (warm start)')
+                        help='load model weights only, ignore specified layers')
     parser.add_argument('--n_gpus', type=int, default=1,
                         required=False, help='number of gpus')
     parser.add_argument('--rank', type=int, default=0,
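With the flag's new semantics, a warm start from a previously trained checkpoint would be invoked along these lines; the checkpoint filename is a placeholder, and -o/-l are assumed to be the output/log directory flags defined elsewhere in this parser:

python train.py -o outdir -l logdir -c pretrained.pt --warm_start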