How to apply kfold.split to a dictionary of lists?


I want to train a Keras model with cross-validation, but my data is a dictionary of lists.

I want to use 10 folds, so in each validation step I want a subset containing 10% of the dict keys, and in the next step another, different 10% subset.

Example: for the first validation step:

pairs_train = {'0': list1, '1': list2, '2': list3, '3': list4, '4': list5, '5': list6, '6': list7, '7': list8, '8': list9}

pairs_val = {'9': list10}
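The split described above can be sketched in plain Python, without sklearn or pandas. This is a minimal sketch, assuming contiguous (unshuffled) folds; `dict_kfold` is a hypothetical helper name, not part of any library. The first fold validates on the last block of keys, which reproduces the example above (`'0'` to `'8'` for training, `'9'` for validation).

```python
def dict_kfold(pairs, k=10):
    """Yield (pairs_train, pairs_val) for k contiguous folds over the dict keys.

    Hypothetical helper: the first fold validates on the last block of keys,
    later folds move towards the front. Keys are not shuffled.
    """
    keys = list(pairs)
    n = len(keys)
    if not 2 <= k <= n:
        raise ValueError("k must be between 2 and the number of keys")
    # Fold sizes differ by at most one key when n is not divisible by k
    sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    end = n
    for size in sizes:
        start = end - size
        val_keys = set(keys[start:end])
        # Rebuild both dicts, keeping the original key order
        pairs_train = {key: pairs[key] for key in keys if key not in val_keys}
        pairs_val = {key: pairs[key] for key in keys if key in val_keys}
        yield pairs_train, pairs_val
        end = start


# Toy data standing in for the real lists
pairs = {str(i): [i, i + 1] for i in range(10)}
folds = list(dict_kfold(pairs, k=10))
first_train, first_val = folds[0]
print(list(first_val))  # ['9']
```

With 10 keys and `k=10`, each validation set holds exactly one key (10%); to get random folds, shuffle the key list before slicing.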

This is my function:

def crossValidation(self,k_folds=10):
    cv_accuracy_train = []
    cv_accuracy_val = []
    cv_loss_train = []
    cv_loss_val = []

    s = pd.Series(pairs)

    idx = 0
    for train_idx,val_idx in kfold.split(s):
        print("=========================================")
        print("====== K Fold Validation step => %d/%d =======" % (idx,k_folds))
        print("=========================================")

        train_gen = DataGenerator(pairs=s[train_idx],batch_size=self.param_grid['batch_size'],nr_files=len(self.Data.all_files),nr_tests=len(self.Data.all_tests),negative_ratio=self.param_grid['negative_ratio'])

        val_gen = DataGenerator(pairs=s[val_idx],negative_ratio=self.param_grid['negative_ratio'])

        # Train
        h = self.model.fit(train_gen,validation_data=val_gen,epochs=self.param_grid['nb_epochs'],verbose=2)

        cv_accuracy_train.append(np.array(h.history['mae'])[-1])
        cv_accuracy_val.append(np.array(h.history['val_mae'])[-1])
        cv_loss_train.append(np.array(h.history['loss'])[-1])
        cv_loss_val.append(np.array(h.history['val_loss'])[-1])
        idx += 1

Traceback:

  File "/Users/joaolousada/Documents/5ºAno/Master-Thesis/main/Prioritizer/Prioritizer.py", line 173, in crossValidation
    train_gen = DataGenerator(pairs=s[train_idx],
  File "/Users/joaolousada/opt/anaconda3/lib/python3.7/site-packages/pandas/core/series.py", line 908, in __getitem__
    return self._get_with(key)
  File "/Users/joaolousada/opt/anaconda3/lib/python3.7/site-packages/pandas/core/series.py", line 943, in _get_with
    return self.loc[key]
  File "/Users/joaolousada/opt/anaconda3/lib/python3.7/site-packages/pandas/core/indexing.py", line 879, in __getitem__
    return self._getitem_axis(maybe_callable, axis=axis)
  File "/Users/joaolousada/opt/anaconda3/lib/python3.7/site-packages/pandas/core/indexing.py", line 1099, in _getitem_axis
    return self._getitem_iterable(key, line 1037, in _getitem_iterable
    keyarr, indexer = self._get_listlike_indexer(key, axis, raise_missing=False)
  File "/Users/joaolousada/opt/anaconda3/lib/python3.7/site-packages/pandas/core/indexing.py", line 1254, in _get_listlike_indexer
    self._validate_read_indexer(keyarr, indexer, raise_missing=raise_missing)
  File "/Users/joaolousada/opt/anaconda3/lib/python3.7/site-packages/pandas/core/indexing.py", line 1298, in _validate_read_indexer
    raise KeyError(f"None of [{key}] are in the [{axis_name}]")
KeyError: "None of [Int64Index([   0,1,2,3,4,5,6,7,8,9,\n            ...\n            3257,3258,3261,3262,3263,3265,3266,3267,3268,3269],\n           dtype='int64',length=2943)] are in the [index]"
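The traceback shows `s[train_idx]` going through `self.loc[key]`, i.e. pandas treats the integer positions returned by `kfold.split` as index *labels*. A minimal reproduction, assuming (this is a guess, not stated in the question) that the real dict keys are integers that do not coincide with their positions; the data below is made up:

```python
import numpy as np
import pandas as pd

# Hypothetical dict whose integer keys are NOT 0..n-1
pairs = {100: [1, 2], 200: [3], 300: [4, 5, 6]}
s = pd.Series(pairs)

train_idx = np.array([0, 1])  # positions, like the ones KFold.split returns

# On an integer index, s[...] looks the integers up as labels -> KeyError
try:
    s[train_idx]
except KeyError as err:
    print("KeyError:", err)

# .iloc always interprets the integers as positions
print(s.iloc[train_idx].index.tolist())  # [100, 200]
```

Using `s.iloc[train_idx]` (and `s.iloc[val_idx]`) makes the positional intent explicit regardless of the key type.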

Answers

Suppose the dict's values are lists, for example:

pairs = {'0': [1, 2, 3], '1': [1, ...], '2': [4, 6, 8], '3': [2, 1, 9], '4': [9, 7, ...], '5': [4, ...], '6': [9, ...], '7': [9, ...], '8': [1, ...], '9': [4, ...]}

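The following function returns, for each fold, the keys to use for training and the keys to use for validation, so you can split the dictionary by key: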

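import numpy as np

def kfold_split(pairs: dict, perc: float, shuffle: bool) -> list:
    """Return a list of (train_keys, val_keys) tuples, one per fold."""
    keys = list(pairs.keys())
    sets = len(keys)
    cv_perc = int(sets * perc)     # number of keys per validation fold
    folds = int(sets / cv_perc)    # e.g. 10 keys with perc=.2 -> 5 folds

    indices = []

    for fold in range(folds):

        # If you want to generate random keys
        if shuffle:

            # Choose random keys without replacement, so one fold has no duplicates.
            # Each fold is sampled independently, so a key can still appear in the
            # validation set of several folds (unlike a true k-fold split).
            random_keys = list(np.random.choice(keys, cv_perc, replace=False))

            other_keys = [k for k in keys if k not in random_keys]

            indices.append((other_keys, random_keys))

        else:

            # Take consecutive blocks of keys, starting from the end of the list
            if fold == 0:
                fold_keys = keys[-cv_perc * (fold + 1):]
            else:
                fold_keys = keys[-cv_perc * (fold + 1):-cv_perc * fold]

            # List comprehension keeps the original key order (a set would not)
            other_keys = [k for k in keys if k not in fold_keys]

            indices.append((other_keys, fold_keys))

    return indices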

You can retrieve random indices:

kfold_split(pairs, perc=.2, shuffle=True)
>>>
[(['6', '2', '1', '5', '4', '7', '0', '3'], ['9', '8']), (['6', '9', ...], ['8', '2']), (['2', '8', ...], ['6', '0']), (..., ['1', '6']), (..., [..., '1'])]

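or ordered indices:

kfold_split(pairs, perc=.2, shuffle=False)
>>>
[(['0', '1', '2', '3', '4', '5', '6', '7'], ['8', '9']),
 (['0', '1', '2', '3', '4', '5', '8', '9'], ['6', '7']),
 (['0', '1', '2', '3', '6', '7', '8', '9'], ['4', '5']),
 (['0', '1', '4', '5', '6', '7', '8', '9'], ['2', '3']),
 (['2', '3', '4', '5', '6', '7', '8', '9'], ['0', '1'])]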

Then you can filter the dictionary using these keys:

result = kfold_split(pairs, perc=.2, shuffle=False)

for indices in result:
    train_indices, test_indices = indices

    # Filter the dict by key: validation/test data
    pair_test = {k: v for k, v in pairs.items() if k in test_indices}

    # Train data
    pair_train = {k: v for k, v in pairs.items() if k in train_indices}
    
    # Some other stuff here
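Because the `shuffle=True` branch draws each validation fold independently, the same key can land in several validation folds (the random output above already repeats some validation keys). A small sanity check helps catch this before training. This is a hypothetical helper, assuming folds are given as a list of `(train_keys, val_keys)` tuples like the ones `kfold_split` returns:

```python
def check_folds(keys, folds):
    """Check that (train_keys, val_keys) folds form a proper k-fold split.

    Returns a list of problems; an empty list means the split looks valid.
    """
    problems = []
    keys = set(keys)
    seen_val = set()
    for i, (train_keys, val_keys) in enumerate(folds):
        train_keys, val_keys = set(train_keys), set(val_keys)
        if train_keys & val_keys:
            problems.append(f"fold {i}: keys in both train and val")
        if train_keys | val_keys != keys:
            problems.append(f"fold {i}: train + val does not cover all keys")
        if seen_val & val_keys:
            problems.append(f"fold {i}: val keys already used in an earlier fold")
        seen_val |= val_keys
    # In k-fold every key should be validated exactly once
    if seen_val != keys:
        problems.append("some keys are never used for validation")
    return problems


keys = ['0', '1', '2', '3']
good = [(['2', '3'], ['0', '1']), (['0', '1'], ['2', '3'])]
bad = [(['1', '2'], ['0', '3']), (['0', '3'], ['1', '0'])]
print(check_folds(keys, good))  # []
print(check_folds(keys, bad))
```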

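I managed to solve my own problem by putting all the dictionary keys into an np.array and passing that to kf.split(). Then, once I have the indices, I use them to look up the corresponding dictionary keys. I'm not sure whether there is a more optimized / Pythonic solution, but it works well.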
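The idea above can be isolated from the Keras and `DataGenerator` parts (which are specific to the asker's code). A minimal sketch using scikit-learn's `KFold`; `dict_folds` is a hypothetical name. It keeps the keys in a plain Python list instead of an `np.array`, so the keys of the resulting dicts stay ordinary `str`/`int` objects rather than NumPy scalars:

```python
import numpy as np
from sklearn.model_selection import KFold


def dict_folds(pairs, n_splits=10, seed=None):
    """Yield (pairs_train, pairs_val) dicts using KFold positions on the keys."""
    keys = list(pairs.keys())
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    # KFold only needs the number of samples; it returns integer positions
    for train_idx, val_idx in kfold.split(np.arange(len(keys))):
        pairs_train = {keys[i]: pairs[keys[i]] for i in train_idx}
        pairs_val = {keys[i]: pairs[keys[i]] for i in val_idx}
        yield pairs_train, pairs_val


# Toy data: 10 keys, so each of the 10 folds validates on one key (10%)
pairs = {str(i): [i, 2 * i] for i in range(10)}
for fold, (pairs_train, pairs_val) in enumerate(dict_folds(pairs, seed=0), start=1):
    print(fold, sorted(pairs_val), len(pairs_train))
```

Unlike sampling each fold independently, `KFold` partitions the positions, so every key is used for validation exactly once.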

from sklearn.model_selection import KFold
import numpy as np

def crossValidation(self, k_folds=10):
    cv_accuracy_train = []
    cv_accuracy_val = []
    cv_loss_train = []
    cv_loss_val = []

    # Array of dict keys; KFold.split returns integer positions into it
    s = np.array(list(self.Data.pairs.keys()))

    kfold = KFold(n_splits=k_folds, shuffle=True)

    for idx, (train_idx, val_idx) in enumerate(kfold.split(s), start=1):
        print("=========================================")
        print("====== K Fold Validation step => %d/%d =======" % (idx, k_folds))
        print("=========================================")
        # Map the positions back to keys and rebuild the train/validation dicts
        pairs_train = {s[i]: self.Data.pairs[s[i]] for i in train_idx}
        pairs_val = {s[i]: self.Data.pairs[s[i]] for i in val_idx}

        train_gen = DataGenerator(pairs=pairs_train, batch_size=self.param_grid['batch_size'], nr_files=len(self.Data.all_files), nr_tests=len(self.Data.all_tests), negative_ratio=self.param_grid['negative_ratio'])

        val_gen = DataGenerator(pairs=pairs_val, negative_ratio=self.param_grid['negative_ratio'])

        # Train
        h = self.model.fit(train_gen, validation_data=val_gen, epochs=self.param_grid['nb_epochs'], verbose=2)

        # Keep the last-epoch metrics of this fold
        cv_accuracy_train.append(np.array(h.history['accuracy'])[-1])
        cv_accuracy_val.append(np.array(h.history['val_accuracy'])[-1])
        cv_loss_train.append(np.array(h.history['loss'])[-1])
        cv_loss_val.append(np.array(h.history['val_loss'])[-1])
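The method above stores the last-epoch metric of each fold but does not aggregate them. A common summary is the mean and standard deviation across folds. A minimal sketch with a hypothetical helper `summarize_cv`; the numbers are made up purely for illustration:

```python
import numpy as np


def summarize_cv(values):
    """Return (mean, std) of per-fold metric values, e.g. cv_accuracy_val."""
    arr = np.asarray(values, dtype=float)
    # np.std defaults to the population standard deviation (ddof=0)
    return float(arr.mean()), float(arr.std())


# Hypothetical per-fold validation accuracies, one value per fold
cv_accuracy_val = [0.81, 0.79, 0.84]
mean, std = summarize_cv(cv_accuracy_val)
print(f"val accuracy: {mean:.3f} +/- {std:.3f}")
```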
