使用validation_data嵌入LSTM

如何解决使用validation_data嵌入LSTM

我正在尝试在 LSTM 中使用带有validation_data 的嵌入。然而,嵌入似乎改变了数据的形状。考虑到验证数据与嵌入后的数据具有不同的形状,我应该如何继续使用validation_data?非常感谢。

请在下面找到可重现的 python 代码

import pandas as pd
import numpy as np
from keras.models import Sequential
from keras.layers import Embedding,LSTM,Dense
from matplotlib import pyplot as plt


#fake data
np.random.seed(287528)
dat=pd.DataFrame(np.random.rand(2880*10).reshape(2880,10),columns = ['y','x1','x2','x3','x4','x5','x6','x7','x8','x9'])
x=dat.iloc[:,[1,2,3,4,5,6,7,8,9]]
y=dat.iloc[:,0]
dat.head()
dat.shape
x=x.values # convert DataFrame to array
y=y.values
y=np.reshape(y,[y.shape[0],1])
train_x=x[0:1584]
train_y=y[0:1584]
test_x=x[1584:2880]
test_y=y[1584:2880]

m_train_x=np.mean(train_x,axis=0) 
m_train_y=np.mean(train_y,axis=0)
std_train_x=np.std(train_x,axis=0)
std_train_y=np.std(train_y,axis=0)
s_train_x=(train_x-m_train_x)/std_train_x
s_train_y=(train_y-m_train_y)/std_train_y
s_test_x=(test_x-m_train_x)/std_train_x
s_test_y=(test_y-m_train_y)/std_train_y
print(s_train_x.shape,s_train_y.shape,s_test_x.shape,s_test_y.shape)

s_train_x=np.reshape(s_train_x,[-1,144,s_train_x.shape[1]]) 
s_train_y=np.reshape(s_train_y,s_train_y.shape[1]])
s_test_x=np.reshape(s_test_x,1,s_test_x.shape[1]]) 
s_test_y=np.reshape(s_test_y,s_test_y.shape[1]])
print(s_train_x.shape,s_test_y.shape)

model=Sequential()
model.add(Embedding(9,input_length=144))
model.add(LSTM(100,dropout=0.2)) 
model.add(Dense(1))
model.compile(loss='mse',optimizer='rmsprop',metrics=['accuracy'])

# things start go wrong here:
 history=model.fit(s_train_x,s_train_y,epochs=10,batch_size=100,validation_data=(s_test_x[1440:1584],s_test_y[1440:1584]),shuffle=False)


错误如下:

Epoch 1/10
WARNING:tensorflow:Model was constructed with shape (None,144) for input KerasTensor(type_spec=TensorSpec(shape=(None,144),dtype=tf.float32,name='embedding_23_input'),name='embedding_23_input',description="created by layer 'embedding_23_input'"),but it was called on an input with incompatible shape (None,9).
Traceback (most recent call last):
  File "<stdin>",line 1,in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py",line 1100,in fit
    tmp_logs = self.train_function(iterator)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py",line 828,in __call__
    result = self._call(*args,**kwds)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py",line 871,in _call
    self._initialize(args,kwds,add_initializers_to=initializers)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py",line 726,in _initialize
    *args,**kwds))
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/eager/function.py",line 2969,in _get_concrete_function_internal_garbage_collected
    graph_function,_ = self._maybe_define_function(args,kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/eager/function.py",line 3361,in _maybe_define_function
    graph_function = self._create_graph_function(args,line 3206,in _create_graph_function
    capture_by_value=self._capture_by_value),File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py",line 990,in func_graph_from_py_func
    func_outputs = python_func(*func_args,**func_kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py",line 634,in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args,**kwds)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py",line 977,in wrapper
    raise e.ag_error_Metadata.to_exception(e)
ValueError: in user code:

    /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:805 train_function  *
        return step_function(self,iterator)
    /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:795 step_function  **
        outputs = model.distribute_strategy.run(run_step,args=(data,))
    /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:1259 run
        return self._extended.call_for_each_replica(fn,args=args,kwargs=kwargs)
    /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:2730 call_for_each_replica
        return self._call_for_each_replica(fn,args,kwargs)
    /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:3417 _call_for_each_replica
        return fn(*args,**kwargs)
    /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:788 run_step  **
        outputs = model.train_step(data)
    /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:754 train_step
        y_pred = self(x,training=True)
    /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py:1012 __call__
        outputs = call_fn(inputs,*args,**kwargs)
    /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/keras/engine/sequential.py:375 call
        return super(Sequential,self).call(inputs,training=training,mask=mask)
    /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py:425 call
        inputs,mask=mask)
    /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py:560 _run_internal_graph
        outputs = node.layer(*args,**kwargs)
    /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/keras/layers/recurrent.py:660 __call__
        return super(RNN,self).__call__(inputs,**kwargs)
    /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py:998 __call__
        input_spec.assert_input_compatibility(self.input_spec,inputs,self.name)
    /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/python/keras/engine/input_spec.py:223 assert_input_compatibility
        str(tuple(shape)))

    ValueError: Input 0 of layer lstm_29 is incompatible with the layer: expected ndim=3,found ndim=4. Full shape received: (None,9,3)

似乎原始特征的数量(9)和嵌入的主成分的数量都包含在图层的暗淡中,这不是我想要的。请给我一些建议。非常感谢。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?