如何解决“ ImbalancedDatasetSampler”类无法访问TensorDataset标签
我正在处理不平衡的数据(仅有2%的少数样本)。我在pytorch中尝试了“ WeightedRandomSampler”方法来构建深度学习模型,该模型对我的验证集有效,但在独立测试集的情况下会失败。我碰到 https://github.com/ufoym/imbalanced-dataset-sampler 我想对我的数据尝试这种方法。问题是-这个“ ImbalancedDatasetSampler”模块无法找出我的TensorDataset对象中的标签。
tensor_x = torch.Tensor(X_TRAIN) # X_TRAIN shape [16200,1105]
tensor_y = torch.Tensor(Y_TRAIN.values) # Y_TRAIN shape [16200,1]
dataset_train = TensorDataset(tensor_x,tensor_y)
train_loader = data_utils.DataLoader(train_dataset,batch_size = BATCH_SIZE,sampler=ImbalancedDatasetSampler(train_dataset))
返回错误
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
<ipython-input-13-b7cc711025fa> in <module>
2 train_loader = data_utils.DataLoader(train_dataset,3 batch_size = BATCH_SIZE,----> 4 sampler=ImbalancedDatasetSampler(train_dataset))
~/py_torch_sampler.py in __init__(self,dataset,indices,num_samples,callback_get_label)
30 label_to_count = {}
31 for idx in self.indices:
---> 32 label = self._get_label(dataset,idx)
33 if label in label_to_count:
34 label_to_count[label] += 1
~/py_torch_sampler.py in _get_label(self,idx)
51 return self.callback_get_label(dataset,idx)
52 else:
---> 53 raise NotImplementedError
54
55 def __iter__(self):
NotImplementedError:
有人可以告诉我如何解决这个问题吗?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。