Tensorflow NN实现不收敛

如何解决Tensorflow NN实现不收敛

我正在尝试仅使用Tensorflow来实现一个简单的前馈神经网络,并且它没有收敛。我不确定问题出在网络的体系结构还是培训过程的实现。使用Keras构建的简单2层NN似乎收敛了:

from keras.layers import LSTM,Dense,Flatten,Conv1D
from keras import Sequential
model = Sequential()
model.add(Dense(32,activation='relu'))
model.add(Dense(32,activation='relu'))
model.add(Dense(21,activation='softmax'))
model.compile(loss='sparse_categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
history = model.fit(np.array(train_in),np.array(train_target),epochs=10,validation_split=0.1,batch_size=16)
    Epoch 2/10 59717/59717 [==============================] - 4s 71us/sample - loss: 1.4021 - accuracy: 0.6812 - val_loss: 1.1049 - val_accuracy: 0.7066
Epoch 3/10
59717/59717 [==============================] - 4s 70us/sample - loss: 1.0942 - accuracy: 0.7321 - val_loss: 1.2269 - val_accuracy: 0.7015
Epoch 4/10
59717/59717 [==============================] - 4s 70us/sample - loss: 0.9096 - accuracy: 0.7654 - val_loss: 0.8207 - val_accuracy: 0.7905
Epoch 5/10
59717/59717 [==============================] - 4s 70us/sample - loss: 0.8373 - accuracy: 0.7790 - val_loss: 0.6863 - val_accuracy: 0.8267
Epoch 6/10
59717/59717 [==============================] - 4s 72us/sample - loss: 0.7925 - accuracy: 0.7918 - val_loss: 0.8132 - val_accuracy: 0.7929
Epoch 7/10
59717/59717 [==============================] - 4s 73us/sample - loss: 0.7916 - accuracy: 0.7925 - val_loss: 0.6749 - val_accuracy: 0.8210
Epoch 8/10
19600/59717 [========>.....................] - ETA: 2s - loss: 0.7475 - accuracy: 0.8011

这是我在Tensorflow中对同一网络的实现:

tf.compat.v1.disable_eager_execution()
batch_size = 10
hid_dim = 32
output_dim = 21
features = train_x.shape[1]

x = tf.compat.v1.placeholder(tf.float32,(batch_size,features),name='x')
y = tf.compat.v1.placeholder(tf.int32,),name='y')

w1 = tf.Variable(tf.compat.v1.random_normal([features,hid_dim]),dtype=tf.float32)
b1 = tf.Variable(tf.compat.v1.random_normal([hid_dim]),dtype=tf.float32)

w2 = tf.Variable(tf.compat.v1.random_normal([hid_dim,output_dim]),dtype=tf.float32)
b2 = tf.Variable(tf.compat.v1.random_normal([output_dim]),dtype=tf.float32)


h1 = tf.nn.relu(tf.matmul(x,w1) + b1)
h2 = tf.matmul(h1,w2) + b2


loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h2,labels=y))
optimizer = tf.compat.v1.train.AdamOptimizer(0.001).minimize(loss)
pred = tf.nn.softmax(h2)

这是我的培训程序实施。在我的情况下,batch_size是固定的,因此在每个时期,我都会将整个数据集逐批馈送到网络。我计算每个批次后的损失,并将其添加到数组中。在每个时期之后,我取该时期的批次损失数组的平均值,以获得我的总体时期损失:

train_in = np.array(train_x)

train_target = np.array(train_y)
train_target = np.squeeze(train_target)

y_t = train_target

num_of_train_batches = len(train_in)/batch_size
init=tf.compat.v1.global_variables_initializer()
print('TRAIN BATCHES: ',num_of_train_batches) 
epoch_list = []
epoch_losses = [] 
epochs = 50
with tf.compat.v1.Session() as sess:
    sess.run(init)
    print('TRAINING')
    for epoch in range(epochs):
      lt = []
      ft = 0
      tt = 1

      train_losses = []
      print('EPOCH: ',epoch)
      epoch_list.append(epoch)
      # RUN WHOLE SET
      for it in range(int(num_of_train_batches)): #len(x_train)/batch_size
          # OPTIMIZE
          _,batch_loss = sess.run([optimizer,loss],Feed_dict={x:train_in[ft*batch_size:tt*batch_size],y:train_target[ft*batch_size:tt*batch_size]})
          train_losses.append(batch_loss)
          
          ft+=1
          tt+=1

      epoch_losses.append(np.array(train_losses).mean())

      print('EPOCH: ',epoch)
      print('LOSS: ',np.array(train_losses).mean())

TRAIN BATCHES:  2200.0
TRAINING
EPOCH:  0
EPOCH:  0
LOSS:  1370.9271
EPOCH:  1
EPOCH:  1
LOSS:  64.23466
EPOCH:  2
EPOCH:  2
LOSS:  36.015495
EPOCH:  3
EPOCH:  3
LOSS:  30.292429
EPOCH:  4
EPOCH:  4
LOSS:  26.436918
EPOCH:  5
EPOCH:  5
LOSS:  25.689302
EPOCH:  6
EPOCH:  6
LOSS:  23.730627
EPOCH:  7
EPOCH:  7
LOSS:  22.356762
EPOCH:  8
EPOCH:  8
LOSS:  21.81124

我的Keras实现仅在使用相同数量的隐藏层和隐藏层大小的8个纪元后才达到0.75损失,但是我的TF实现即使在15个纪元后仍显示大于10的损失。

有人可以指出为什么会发生这种情况吗?我想这个问题与训练程序有关,而不是实际的NN。

欢迎所有建议!

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?