如何解决使用 Pytorch Lightning 训练时模型权重未更新
我是 Pytorch Lightning 的新手。我使用 PL 制作了一个非常简单的模型。 我在训练前后检查了模型的权重,但知道在训练期间损失减少,它们完全相同。
def main(args,df_train,df_dev,df_test) :
""" main function"""
# Wandb connect
wandb_connect()
wandb_logger = WandbLogger(project="project name",name="Run name")
# Tokenization
[df_train,df_test],params,tokenizer_qid,tokenizer_uid,tokenizer_qu_id,tokenizer_rank = apply_tokenization([df_train,df_test])
# Dataloadeers
[train_loader,dev_loader,test_loader] = list(map(lambda x : Dataset_SM(x).get_dataloader(args.batch_size),[df_train,df_test]))
# Model definition
model = NCM(**params).to(device)
# Weight before training
WW = model.emb_qid.weight
print(torch.mean(model.emb_qid.weight))
# Train & Eval
es = EarlyStopping(monitor='dev_loss',patience=4)
checkpoint_callback = ModelCheckpoint(dirpath=args.result_path)
trainer = pl.Trainer(max_epochs=args.n_epochs,callbacks=[es,checkpoint_callback],val_check_interval=args.val_check_interval,logger=wandb_logger,gpus=1)
trainer.fit(model,train_loader,dev_loader)
trainer.save_checkpoint(args.result_path + "example.ckpt")
loaded_model = NCM.load_from_checkpoint(checkpoint_path=args.result_path + "example.ckpt",**params)
print(loaded_model.emb_qid.weight == WW)
如果我错过了什么,有人可以告诉我吗?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。