微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

NumPy 在数据集预处理中的效率

如何解决NumPy 在数据集预处理中的效率

我目前正在开展一个研究项目,该项目与使用在 EEG 数据集上运行的神经网络相关。我正在使用 BCICIV 2a 数据集,它由一系列包含受试者试验数据的文件组成。每个文件包含一组 25 个通道和一个非常长的 ~600000 时间步长的信号阵列。我一直在编写代码来将这些数据预处理为可以传递给神经网络的东西,但遇到了一些效率问题。目前,我已经编写了用于确定文件中所有试验在数组中的位置的代码,然后尝试提取存储在另一个数组中的 3D NumPy 数组。但是,当我尝试运行此代码时,它的速度慢得离谱。我对 NumPy 不是很熟悉,此时我的大部分经验都是用 C 语言编写的。我的目的是将预处理的结果写入一个单独的文件,该文件可以加载以避免预处理。从 C 的角度来看,所有需要做的就是移动指针以适当地格式化数据,所以我不知道为什么 NumPy 这么慢。任何建议都会非常有帮助,因为目前对于 1 个文件提取 1 个试验需要大约 2 分钟,一个文件中有 288 个试验和 9 个文件,这将比我想要的要长得多。我对如何充分利用 NumPy 对泛型列表的效率改进的知识不太满意。谢谢!

import glob,os
import numpy as np
import mne

DURATION = 313
XDIM = 7
YDIM = 6
IGnorE = ('EOG-left','EOG-central','EOG-right')

def getIndex(raw,tagIndex):
    return int(raw.annotations[tagIndex]['onset']*250)

def isEvent(raw,tagIndex,events):
    for event in events:
        if (raw.annotations[tagIndex]['description'] == event):
            return True
    return False

def getSlice1D(raw,channel,dur,index):
    if (type(channel) == int):
        channel = raw.ch_names[channel]
    return raw[channel][0][0][index:index+dur]

def getSliceFull(raw,index):
    trial = np.zeros((XDIM,YDIM,dur))
    for channel in raw.ch_names:
        if not channel in IGnorE:
            x,y = convertIndices(channel)
            trial[x][y] = getSlice1D(raw,index)
    return trial

def convertIndices(channel):
    xDict = {'EEG-Fz':3,'EEG-0':1,'EEG-1':2,'EEG-2':3,'EEG-3':4,'EEG-4':5,'EEG-5':0,'EEG-C3':1,'EEG-6':2,'EEG-Cz':3,'EEG-7':4,'EEG-C4':5,'EEG-8':6,'EEG-9':1,'EEG-10':2,'EEG-11':3,'EEG-12':4,'EEG-13':5,'EEG-14':2,'EEG-Pz':3,'EEG-15':4,'EEG-16':3}
    yDict = {'EEG-Fz':0,'EEG-1':1,'EEG-2':1,'EEG-3':1,'EEG-4':1,'EEG-5':2,'EEG-C3':2,'EEG-Cz':2,'EEG-7':2,'EEG-C4':2,'EEG-8':2,'EEG-9':3,'EEG-10':3,'EEG-12':3,'EEG-13':3,'EEG-14':4,'EEG-Pz':4,'EEG-16':5}
    return xDict[channel],yDict[channel]

data_files = glob.glob('../datasets/BCICIV_2a_gdf/*.gdf')

try:
    raw = mne.io.read_raw_gdf(data_files[0],verbose='ERROR')
except IndexError:
    print("No data files found")

event_times = []

for i in range(len(raw.annotations)):
    if (isEvent(raw,i,('769','770','771','772'))):
        event_times.append(getIndex(raw,i))

data = np.empty((len(event_times),XDIM,DURATION))

print(len(event_times))

for i,event in enumerate(event_times):
    data[i] = getSliceFull(raw,DURATION,event)

编辑: 我想回来添加一些关于数据集结构的更多细节。有一个包含数据的 25x~600000 数组和一个更短的注释对象,其中包含事件标记并将这些标记与更大数组中的时间相关联。特定事件表示运动图像线索,这是我的网络正在接受的试验,我试图提取一个 3D 切片,其中包括使用时间维度适当格式化的相关通道,发现时间维度为 313 时间步长。注释为我提供了调查的相关时间步长。 Ian 推荐的分析结果表明,主要计算时间位于 getSlice1D() 函数中。特别是在我索引原始对象的地方。从注释中提取事件时间的代码相对可以忽略不计。

解决方法

这是部分答案,因为评论中的格式有点垃圾,但是

def getIndex(raw,tagIndex):
    return int(raw.annotations[tagIndex]['onset']*250)


def isEvent(raw,tagIndex,events):
    for event in events:
        if (raw.annotations[tagIndex]['description'] == event):
            return True
    return False

for i in range(len(raw.annotations)):
    if (isEvent(raw,i,('769','770','771','772'))):
        event_times.append(getIndex(raw,i))

注意你是如何迭代 I 的。你可以做的是

def isEvent(raw_annotations_desc,raw_annotations_onset,events):
    flag_container = []

    for event in events:    # Iterate through all the events
        # Do a vectorized comparison across all the indices
        flag_container.append(raw_annotations_desc == event)
    # At this point flag_container will be of shape (|events|,len(raw_annotations_desc) 

    """
    Assuming I understand correctly,for a given index if  
        ANY of the events is true,we return true and get the index,correct?
    def getIndex(raw,tagIndex):
        return int(raw.annotations[tagIndex]['onset']*250)
    """
    flag_container = np.asarray(flag_container)  # Change raw list to np array
    
    # Python treats False as 0 and True as 1. So,we sum over the cols 
    # and we now have an array of shape (1,len(raw_annotations))
    flag_container = flag_container.sum(1)  

    # Add indices because we will use these later
    flag_container = np.asarray(np.arange(len(raw_annotations)),flag_container)

    # Almost there. Now,flag_container has 2 cols: the index AND the number of True in a given row
    
    # Gets us all the indices where the sum was greater than 1 (aka one positive)
    
    flag_container = flag_container[flag_container[1,:] > 0]  

    # Now,an array of shape (2,x <= len(raw_annotations_desc))
    flag_container = flag_container[0,:]  # We only care about the indices,not the actual count of positives now so we slice out the 0th-col

    return int(raw_annotations_onset[flag_container] * 250)

那种效果:) 这应该会加快速度

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。