如何使神经网络与 python 中的 sklearn CountVectorizer 一起工作?

如何解决如何使神经网络与 python 中的 sklearn CountVectorizer 一起工作?

所以我有一个具有多输出预测(连续浮点类型)的项目,我正在测试多个模型。我现在被困在神经网络中,因为我在 model.fit 函数中不断收到此错误

ValueError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:805 train_function  *
        return step_function(self,iterator)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:795 step_function  **
        outputs = model.distribute_strategy.run(run_step,args=(data,))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:1259 run
        return self._extended.call_for_each_replica(fn,args=args,kwargs=kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2730 call_for_each_replica
        return self._call_for_each_replica(fn,args,kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:3417 _call_for_each_replica
        return fn(*args,**kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:788 run_step  **
        outputs = model.train_step(data)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:754 train_step
        y_pred = self(x,training=True)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:998 __call__
        input_spec.assert_input_compatibility(self.input_spec,inputs,self.name)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/input_spec.py:259 assert_input_compatibility
        ' but received input with shape ' + display_shape(x.shape))

    ValueError: Input 0 of layer sequential_2 is incompatible with the layer: expected axis -1 of input shape to have value 200 but received input with shape (None,3386)

奇怪的是,它在一个小的训练集上工作了一段时间(看到所有的 epochs 摘要和执行都没有问题),然后当我改变到实际更大的训练集时,我得到了这个错误。我第一次尝试使用第一个训练集时也得到了它,但不知何故它奏效了。

起初,我收到了这个错误


InvalidArgumentError: indices[2] = [0,1540] is out of order. Many sparse ops require sorted indices.
    Use `tf.sparse.reorder` to create a correctly ordered copy.

 [Op:SerializeManySparse]

然后我将 .toarray() 添加到我的输入数据中,从那以后我得到了上面的数据(轴=-1)

我的 X_train 是 CountVectorizer 函数的结果(返回一个 csr 矩阵)。我尝试了各种其他方法,例如重塑或转换为 SparseTensor,但我仍然遇到此错误

X_train.shape 的结果是 (200,3386) - 不像错误中所说的那样 (None,3386)。

我会给你留下一些关于如何获得输入/输出向量和模型的代码

#prepare training data
X_train_raw = df_train.msg.values
X_train_clean = np.asarray(preprocess(X_train_raw))
vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(X_train_clean).toarray()
Y_train = df_train.drop(['id'],axis=1).drop(['msg'],axis=1).values

.........

model = Sequential()
model.add(layers.Dense(16,input_dim=200,activation='relu'))
model.add(layers.Dense(6,activation='relu'))
model.add(layers.Dense(2))

model.compile(loss='mse',optimizer='adam',metrics=['mae'])


model.fit(X_train,Y_train,epochs=1000,verbose=2)
predicted_validation = model.predict(X_validation)
mse_value,mae_value = model.evaluate(X_validation,Y_validation,verbose=0)
test_loss,test_metrics= model.evaluate(X_validation,verbose=0)

print('test loss',test_loss)
print('test metrics',test_metrics)

predicted_scores_test = model.predict(X_test)

如果您有任何建议,请告诉我!也许我没有很好地使用 Sequential 模型,我是 ML 的新手。

谢谢!

解决方法

在您的第一层中,您应该指定 input_dim=3386:这是您的数据具有的特征数量。或者,更好的是,由于不同的数据集会在 CountVectorizer 中产生不同数量的单词,因此请使用 input_dim=len(vectorizer.vocabulary_),这样您就不必在更改数据时随时更改它。

X_train.shape 的结果是 (200,3386) - 不像错误中所说的那样 (None,3386)。

作为输入的第一维的 200None 替换,因为网络看到的形状将取决于您如何批量训练样本。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?