mirror of https://github.com/malarinv/tacotron2
train.py: adding routine to warm start and ignore layers, e.g. embedding.weight
parent
bb67613493
commit
3869781877
16
train.py
16
train.py
|
|
@ -89,11 +89,18 @@ def load_model(hparams):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def warm_start_model(checkpoint_path, model):
|
def warm_start_model(checkpoint_path, model, ignore_layers):
|
||||||
assert os.path.isfile(checkpoint_path)
|
assert os.path.isfile(checkpoint_path)
|
||||||
print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
|
print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
|
||||||
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
||||||
model.load_state_dict(checkpoint_dict['state_dict'])
|
model_dict = checkpoint_dict['state_dict']
|
||||||
|
if len(ignore_layers) > 0:
|
||||||
|
model_dict = {k: v for k, v in model_dict.items()
|
||||||
|
if k not in ignore_layers}
|
||||||
|
dummy_dict = model.state_dict()
|
||||||
|
dummy_dict.update(model_dict)
|
||||||
|
model_dict = dummy_dict
|
||||||
|
model.load_state_dict(model_dict)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -189,7 +196,8 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
|
||||||
epoch_offset = 0
|
epoch_offset = 0
|
||||||
if checkpoint_path is not None:
|
if checkpoint_path is not None:
|
||||||
if warm_start:
|
if warm_start:
|
||||||
model = warm_start_model(checkpoint_path, model)
|
model = warm_start_model(
|
||||||
|
checkpoint_path, model, hparams.ignore_layers)
|
||||||
else:
|
else:
|
||||||
model, optimizer, _learning_rate, iteration = load_checkpoint(
|
model, optimizer, _learning_rate, iteration = load_checkpoint(
|
||||||
checkpoint_path, model, optimizer)
|
checkpoint_path, model, optimizer)
|
||||||
|
|
@ -258,7 +266,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
|
parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
|
||||||
required=False, help='checkpoint path')
|
required=False, help='checkpoint path')
|
||||||
parser.add_argument('--warm_start', action='store_true',
|
parser.add_argument('--warm_start', action='store_true',
|
||||||
help='load the model only (warm start)')
|
help='load model weights only, ignore specified layers')
|
||||||
parser.add_argument('--n_gpus', type=int, default=1,
|
parser.add_argument('--n_gpus', type=int, default=1,
|
||||||
required=False, help='number of gpus')
|
required=False, help='number of gpus')
|
||||||
parser.add_argument('--rank', type=int, default=0,
|
parser.add_argument('--rank', type=int, default=0,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue