如何解决加载 PyTorch Lightning 训练的检查点
我使用的是 PyTorch Lightning 1.4.0 版,并为数据集定义了以下类:
class CustomTrainDataset(Dataset):
'''
Custom PyTorch Dataset for training
Args:
data (pd.DataFrame) - DF containing product info (and maybe also ratings)
all_itemIds (list) - Python3 list containing all Item IDs
'''
def __init__(self,data,all_orderIds):
self.users,self.items,self.labels = self.get_dataset(data,all_orderIds)
def __len__(self):
return len(self.users)
def __getitem__(self,idx):
return self.users[idx],self.items[idx],self.labels[idx]
def get_dataset(self,all_orderIds):
users,items,labels = [],[],[]
user_item_set = set(zip(train_ratings['CustomerID'],train_ratings['ItemCode']))
num_negatives = 7
for u,i in user_item_set:
users.append(u)
items.append(i)
labels.append(1)
for _ in range(num_negatives):
negative_item = np.random.choice(all_itemIds)
while (u,negative_item) in user_item_set:
negative_item = np.random.choice(all_itemIds)
users.append(u)
items.append(negative_item)
labels.append(0)
return torch.tensor(users),torch.tensor(items),torch.tensor(labels)
接着是PL类:
class NCF(pl.LightningModule):
'''
Neural Collaborative Filtering (NCF)
Args:
num_users (int): Number of unique users
num_items (int): Number of unique items
data (pd.DataFrame): Dataframe containing the food ratings for training
all_orderIds (list): List containing all orderIds (train + test)
'''
def __init__(self,num_users,num_items,all_itemIds):
# def __init__(self,ratings,all_movieIds):
super().__init__()
self.user_embedding = nn.Embedding(num_embeddings = num_users,embedding_dim = 8)
# self.user_embedding = nn.Embedding(num_embeddings = num_users,embedding_dim = 10)
self.item_embedding = nn.Embedding(num_embeddings = num_items,embedding_dim = 8)
# self.item_embedding = nn.Embedding(num_embeddings = num_items,embedding_dim = 10)
self.fc1 = nn.Linear(in_features = 16,out_features = 64)
# self.fc1 = nn.Linear(in_features = 20,out_features = 64)
self.fc2 = nn.Linear(in_features = 64,out_features = 64)
self.fc3 = nn.Linear(in_features = 64,out_features = 32)
self.output = nn.Linear(in_features = 32,out_features = 1)
self.data = data
# self.ratings = ratings
# self.all_movieIds = all_movieIds
self.all_orderIds = all_orderIds
def forward(self,user_input,item_input):
# Pass through embedding layers
user_embedded = self.user_embedding(user_input)
item_embedded = self.item_embedding(item_input)
# Concat the two embedding layers
vector = torch.cat([user_embedded,item_embedded],dim = -1)
# Pass through dense layer
vector = nn.ReLU()(self.fc1(vector))
vector = nn.ReLU()(self.fc2(vector))
vector = nn.ReLU()(self.fc3(vector))
# Output layer
pred = nn.Sigmoid()(self.output(vector))
return pred
def training_step(self,batch,batch_idx):
user_input,item_input,labels = batch
predicted_labels = self(user_input,item_input)
loss = nn.BCELoss()(predicted_labels,labels.view(-1,1).float())
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters())
def train_dataloader(self):
return DataLoader(
ChupsTrainDataset(
self.data,self.all_orderIds
),batch_size = 32,num_workers = 2
# Google Colab's suggested max number of worker in current
# system is 2 and not 4.
)
print(f"num_users = {num_users},num_items = {num_items} & all_itemIds = {len(all_itemIds)}")
# num_users = 12958,num_items = 511238 & all_itemIds = 9114
# Initialize NCF model-
model = NCF(num_users,train_ratings,all_itemIds)
trainer = pl.Trainer(
max_epochs = 75,gpus = 1,# max_epochs = 5,reload_dataloaders_every_n_epochs = True,# reload_dataloaders_every_epoch = True,# deprecated!
progress_bar_refresh_rate = 50,logger = False,checkpoint_callback = False)
trainer.fit(model)
# Save trained model as a checkpoint-
trainer.save_checkpoint("NCF_Trained.ckpt")
要加载保存的检查点,我尝试过:
trained_model = NCF.load_from_checkpoint(
"NCF_Trained.ckpt",num_users = num_users,num_items = train_ratings,data = train_ratings,all_itemIds = all_itemIds)
trained_model = NCF(num_users,all_orderIds).load_from_checkpoint(checkpoint_path = "NCF_Trained.ckpt")
但这些似乎不起作用。如何加载这个保存的检查点?
谢谢!
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。