
Cannot access all the data when enumerating a DataLoader

How to fix not being able to access all the data when enumerating a DataLoader

I have defined a custom Dataset and a custom DataLoader, and I want to access all of the batches with for i, batch in enumerate(loader). However, this for loop gives me a different number of batches in every epoch, and each of those counts is far smaller than the actual number of batches (which should equal number_of_samples / batch_size).

Here is how I define the dataset and the data loader:



import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader


class UsptoDataset(Dataset):
    def __init__(self, csv_file):
        df = pd.read_csv(csv_file)
        self.rea_trees = df['reactants_trees'].to_numpy()
        self.syn_trees = df['synthons_trees'].to_numpy()
        self.syn_smiles = df['synthons'].to_numpy()
        self.product_smiles = df['product'].to_numpy()

    def __len__(self):
        return len(self.rea_trees)

    def __getitem__(self, item):
        rea_tree = self.rea_trees[item]
        syn_tree = self.syn_trees[item]
        syn_smile = self.syn_smiles[item]
        pro_smile = self.product_smiles[item]
        # The snippet that processes the data is omitted here; it produces the
        # variables used in the return statement below.
        return {'input_words': input_words, 'input_chars': input_chars,
                'syn_tree_indices': syn_tree_indices, 'syn_rule_nl_left': syn_rule_nl_left,
                'syn_rule_nl_right': syn_rule_nl_right, 'rea_tree_indices': rea_tree_indices,
                'rea_rule_nl_left': rea_rule_nl_left, 'rea_rule_nl_right': rea_rule_nl_right,
                'class_mask': class_mask, 'query_paths': query_paths, 'labels': labels,
                'parent_matrix': parent_matrix, 'syn_parent_matrix': syn_parent_matrix,
                'path_lens': path_lens, 'syn_path_lens': syn_path_lens}

    @staticmethod
    def collate_fn(batch):
        # Stack each feature across the batch and convert it to a tensor.
        input_words = torch.tensor(np.stack([_['input_words'] for _ in batch], axis=0), dtype=torch.long)
        input_chars = torch.tensor(np.stack([_['input_chars'] for _ in batch], axis=0), dtype=torch.long)
        syn_tree_indices = torch.tensor(np.stack([_['syn_tree_indices'] for _ in batch], axis=0), dtype=torch.long)
        syn_rule_nl_left = torch.tensor(np.stack([_['syn_rule_nl_left'] for _ in batch], axis=0), dtype=torch.long)
        syn_rule_nl_right = torch.tensor(np.stack([_['syn_rule_nl_right'] for _ in batch], axis=0), dtype=torch.long)
        rea_tree_indices = torch.tensor(np.stack([_['rea_tree_indices'] for _ in batch], axis=0), dtype=torch.long)
        rea_rule_nl_left = torch.tensor(np.stack([_['rea_rule_nl_left'] for _ in batch], axis=0), dtype=torch.long)
        rea_rule_nl_right = torch.tensor(np.stack([_['rea_rule_nl_right'] for _ in batch], axis=0), dtype=torch.long)
        class_mask = torch.tensor(np.stack([_['class_mask'] for _ in batch], axis=0), dtype=torch.float32)
        query_paths = torch.tensor(np.stack([_['query_paths'] for _ in batch], axis=0), dtype=torch.long)
        labels = torch.tensor(np.stack([_['labels'] for _ in batch], axis=0), dtype=torch.long)
        parent_matrix = torch.tensor(np.stack([_['parent_matrix'] for _ in batch], axis=0), dtype=torch.float)
        syn_parent_matrix = torch.tensor(np.stack([_['syn_parent_matrix'] for _ in batch], axis=0), dtype=torch.float)
        path_lens = torch.tensor(np.stack([_['path_lens'] for _ in batch], axis=0), dtype=torch.long)
        syn_path_lens = torch.tensor(np.stack([_['syn_path_lens'] for _ in batch], axis=0), dtype=torch.long)

        return_dict = {'input_words': input_words, 'input_chars': input_chars,
                       'syn_tree_indices': syn_tree_indices, 'syn_rule_nl_left': syn_rule_nl_left,
                       'syn_rule_nl_right': syn_rule_nl_right, 'rea_tree_indices': rea_tree_indices,
                       'rea_rule_nl_left': rea_rule_nl_left, 'rea_rule_nl_right': rea_rule_nl_right,
                       'class_mask': class_mask, 'query_paths': query_paths, 'labels': labels,
                       'parent_matrix': parent_matrix, 'syn_parent_matrix': syn_parent_matrix,
                       'path_lens': path_lens, 'syn_path_lens': syn_path_lens}
        return return_dict


train_dataset = UsptoDataset("train_trees.csv")

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=1, collate_fn=UsptoDataset.collate_fn)

When I use the data loader as follows, it gives me a different number of batches in every epoch:

epoch_steps = len(train_loader)
for e in range(epochs):
    for j,batch_data in enumerate(train_loader):
        step = e * epoch_steps + j

The log shows that the first epoch has only 5 batches, the second epoch has 3 batches, the third epoch has 5 batches again, and so on:

Config:
Namespace(batch_size_per_gpu=4, epochs=400, eval_every_epoch=1, hidden_size=128, keep=10, log_every_step=1, lr=0.001, new_model=False, save_dir='saved_model/', workers=1)
2021-01-06 15:33:17,909 - __main__ - WARNING - Checkpoints not found in dir saved_model/, creating a new model.
2021-01-06 15:33:18,340 - __main__ - INFO - Step: 0, Loss: 5.4213, Rule acc: 0.1388
2021-01-06 15:33:18,686 - __main__ - INFO - Step: 1, Loss: 4.884, Rule acc: 0.542
2021-01-06 15:33:18,941 - __main__ - INFO - Step: 2, Loss: 4.6205, Rule acc: 0.6122
2021-01-06 15:33:19,174 - __main__ - INFO - Step: 3, Loss: 4.4442, Rule acc: 0.61
2021-01-06 15:33:19,424 - __main__ - INFO - Step: 4, Loss: 4.3033, Rule acc: 0.6211
2021-01-06 15:33:20,684 - __main__ - INFO - Dev Loss: 3.5034, Dev Sample Acc: 0.0, Dev Rule Acc: 0.5970844200679234, in epoch 0
2021-01-06 15:33:22,203 - __main__ - INFO - Test Loss: 3.4878, Test Sample Acc: 0.0, Test Rule Acc: 0.6470248053471247
2021-01-06 15:33:22,394 - __main__ - INFO - Found better dev sample accuracy 0.0 in epoch 0
2021-01-06 15:33:22,803 - __main__ - INFO - Step: 10002, Loss: 3.6232, Rule acc: 0.6555
2021-01-06 15:33:23,046 - __main__ - INFO - Step: 10003, Loss: 3.53, Rule acc: 0.6442
2021-01-06 15:33:23,286 - __main__ - INFO - Step: 10004, Loss: 3.4907, Rule acc: 0.6498
2021-01-06 15:33:24,617 - __main__ - INFO - Dev Loss: 3.3081, Dev Rule Acc: 0.5980878387178693, in epoch 1
2021-01-06 15:33:26,215 - __main__ - INFO - Test Loss: 3.2859, Test Rule Acc: 0.6466992994149526
2021-01-06 15:33:26,857 - __main__ - INFO - Step: 20004, Loss: 3.3965, Rule acc: 0.6493
2021-01-06 15:33:27,093 - __main__ - INFO - Step: 20005, Loss: 3.3797, Rule acc: 0.6314
2021-01-06 15:33:27,353 - __main__ - INFO - Step: 20006, Loss: 3.3959, Rule acc: 0.5727
2021-01-06 15:33:27,609 - __main__ - INFO - Step: 20007, Loss: 3.3632, Rule acc: 0.6279
2021-01-06 15:33:27,837 - __main__ - INFO - Step: 20008, Loss: 3.3331, Rule acc: 0.6158
2021-01-06 15:33:29,122 - __main__ - INFO - Dev Loss: 3.0911, Dev Rule Acc: 0.6016287207603455, in epoch 2
2021-01-06 15:33:30,689 - __main__ - INFO - Test Loss: 3.0651, Test Rule Acc: 0.6531393428643545
2021-01-06 15:33:32,143 - __main__ - INFO - Dev Loss: 3.0911, in epoch 3
2021-01-06 15:33:33,765 - __main__ - INFO - Test Loss: 3.0651, Test Rule Acc: 0.6531393428643545
2021-01-06 15:33:34,359 - __main__ - INFO - Step: 40008, Loss: 3.108, Rule acc: 0.6816
2021-01-06 15:33:34,604 - __main__ - INFO - Step: 40009, Loss: 3.0756, Rule acc: 0.6732
2021-01-06 15:33:35,823 - __main__ - INFO - Dev Loss: 3.0419, Dev Rule Acc: 0.613776079245976, in epoch 4

For reference, the values of len(train_loader.dataset), batch_size and len(train_loader) are 40008, 4 and 10002 respectively, which is exactly what I expect. So it is confusing that enumerate only gives me a handful of batches, e.g. 3 or 5, instead of the expected 10002.
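A quick way to confirm the mismatch is to count how many batches the loader actually yields and compare it with the expected number. A minimal sketch of such a check (this snippet is not part of the original code):

import math

expected_batches = math.ceil(len(train_loader.dataset) / train_loader.batch_size)  # 40008 / 4 -> 10002
actual_batches = sum(1 for _ in train_loader)
print(expected_batches, actual_batches)  # these should match; with the behaviour described above, actual_batches comes out far smaller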

Solution

I am not sure what exactly is wrong with your code. From what I can tell, what you are trying to do in collate_fn is to gather and stack the data of each feature type across the batch, or something along those lines.

You are using input_words, input_chars, syn_tree_indices, syn_rule_nl_left, syn_rule_nl_right, rea_tree_indices, rea_rule_nl_left, rea_rule_nl_right, class_mask, query_paths, labels, parent_matrix, syn_parent_matrix, path_lens and syn_path_lens as keys. In my example we will keep things simple and only use a, b, c and d.

  • __getitem__ returns a single data point from your dataset. In our case it is a dictionary {'a': ..., 'b': ..., 'c': ..., 'd': ...}.

  • collate_fn is the intermediate layer between the dataset and the data loader when the data is handed back. It receives a list of batch elements (the elements gathered one by one by __getitem__), and what you return from it is the collated batch: something that turns [{'a': ..., 'b': ..., 'c': ..., 'd': ...}, ...] into {'a': [...], 'b': [...], 'c': [...], 'd': [...]}, where key 'a' holds all the data of the 'a' feature, and so on (a minimal sketch is shown right after this list).
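A collate_fn that does just this kind of per-key gathering could look roughly like the sketch below (simple_collate_fn is an illustrative name, not part of the original code):

def simple_collate_fn(batch):
    # batch is a list of per-sample dicts, e.g. [{'a': 1, 'b': 2, ...}, {'a': 11, 'b': 22, ...}]
    # Gather the values of every key across the batch into one list per key.
    return {key: [sample[key] for sample in batch] for key in batch[0]}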

Now, you may not know this, but for this simple kind of collation you do not actually need a collate_fn at all. I believe tuples and dictionaries are handled automatically by PyTorch data loaders, which means that if you return a dictionary from __getitem__, your data loader will collate it by key automatically.

Here it is again with our minimal example:

class D(Dataset):
    def __init__(self):
        super(D, self).__init__()
        self.a = [1, 11, 111, 1111, 11111]
        self.b = [2, 22, 222, 2222, 22222]
        self.c = [3, 33, 333, 3333, 33333]
        self.d = [4, 44, 444, 4444, 44444]

    def __getitem__(self, i):
        return {
            'a': self.a[i],
            'b': self.b[i],
            'c': self.c[i],
            'd': self.d[i]
        }

    def __len__(self):
        return len(self.a)

As you can see in the printout below, the data is collated by key.
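For example, iterating over a DataLoader built on this dataset with the default collation (a sketch, assuming a batch_size of 2) should print something close to the following:

loader = DataLoader(D(), batch_size=2)
for batch in loader:
    print(batch)

# Roughly expected output:
# {'a': tensor([ 1, 11]), 'b': tensor([ 2, 22]), 'c': tensor([ 3, 33]), 'd': tensor([ 4, 44])}
# {'a': tensor([ 111, 1111]), 'b': tensor([ 222, 2222]), 'c': tensor([ 333, 3333]), 'd': tensor([ 444, 4444])}
# {'a': tensor([11111]), 'b': tensor([22222]), 'c': tensor([33333]), 'd': tensor([44444])}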

Providing a collate_fn argument will override this automatic collation.
