微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Mxnet,使用 Pandas 从 csv 文件加载数据并馈送至 NN 模型

如何解决Mxnet,使用 Pandas 从 csv 文件加载数据并馈送至 NN 模型

我在正确加载 .csv 文件以用作非常简单的密集 NN 模型的输入时遇到问题。 csv 文件包含所有输入特征和一个“目标”列,用作回归的输出

这是我目前所做的:

def main():

    batch_size = 500

    ## load input file
    df_data = pd.read_csv('some_file.csv',index_col=0)
    ## random train/test split
    df_train = df_data.sample(frac=0.8,random_state=200)
    df_test = df_data.drop(df_train.index)

    ## data pre-processing
    df_train.reset_index(drop=True,inplace=True)
    df_test.reset_index(drop=True,inplace=True)    
    y_train = df_train['target'].to_numpy(dtype=np.float64)
    y_test = df_test['target'].to_numpy(dtype=np.float64)
    X_train = df_train.drop(['target'],axis=1).to_numpy(dtype=np.float64)
    X_test = df_test.drop(['target'],axis=1).to_numpy(dtype=np.float64)


    dataset = mx.gluon.data.dataset.ArrayDataset(X_train,y_train)
    data_loader = mx.gluon.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)

    ##   building model 
    model = nn.Sequential()
    model.add(nn.Dense(150))
    model.add(nn.Dense(1))
    model.initialize(init.normal(sigma=0.01))

    ## loss function (squared loss)
    loss = gloss.L2Loss()

    ## optimization algorithm,specify:
    trainer = gluon.Trainer(model.collect_params(),'sgd',{'learning_rate': 0.03})

    ##   training   #
    num_epochs = 10
    for epoch in range(1,num_epochs + 1):
        for X_batch,Y_batch in data_loader:
            with autograd.record():
                l = loss(model(X_batch),Y_batch)
            l.backward()
            trainer.step(batch_size)
        # overall (entire dataset) loss after epoch
        l = loss(model(X_train),y_train)
        print(f'\nEpoch {epoch},loss: {l.mean().asnumpy()}')

我收到错误

mxnet.base.MXNetError: [16:09:03] src/operator/numpy/linalg/./../../tensor/../elemwise_op_common.h:135: Check Failed: assign(&dattr,vec.at(i)): Incompatible attr in node  at 1-th input: expected float64,got float32

所以,我尝试通过将 np.float64 切换为 np.float32 来转换数据,但我得到了:

File "/home/lews/anaconda3/envs/gluon/lib/python3.7/site-packages/mxnet/gluon/block.py",line 1136,in forward
raise ValueError('In HybridBlock,there must be one ndarray or one Symbol in the input.'
ValueError: In HybridBlock,there must be one ndarray or one Symbol in the input. Please check the type of the args.

加载这些数据的正确方法是什么?

解决方法

我通过使用修复了它

 ## data pre-processing
y_train = np.array(df_train['target'].to_numpy().reshape(-1,1),dtype=np.float32)
y_test = np.array(df_test['target'].to_numpy().reshape(-1,dtype=np.float32)
X_train = np.array(df_train.drop(['target'],axis=1).to_numpy(),dtype=np.float32)
X_test = np.array(df_test.drop(['target'],dtype=np.float32)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?