
Using Tensorflow Estimators with the Dataset API leads to weird step behavior


I'm having some trouble with the training-loop behavior of Tensorflow's Estimator and Dataset APIs. The code is as follows (tf2.3):


import tensorflow as tf

NUM_EXAMPLES = X_train.shape[0]  # dataset has 8000 elements
BATCH_SIZE = NUM_EXAMPLES
STEPS = None
N_EPOCHS = 100

def make_input_fn(X, y, n_epochs=N_EPOCHS, shuffle=True):
    # Build (feature dict, label) pairs from a pandas DataFrame and labels.
    dataset = tf.data.Dataset.from_tensor_slices((X.to_dict(orient='list'), y))
    if shuffle:
        dataset = dataset.shuffle(NUM_EXAMPLES)
    # repeat() comes before batch(), so batches may span epoch boundaries.
    return dataset.repeat(n_epochs).batch(BATCH_SIZE)


estimator = tf.estimator.BoostedTreesClassifier(feature_cols, **{
    'config': tf.estimator.RunConfig(
        model_dir=model_dir, save_checkpoints_steps=100
    ),
    'n_trees': 50, 'max_depth': 6, 'n_batches_per_layer': 1, 'l2_regularization': 0.1
})

estimator.train(input_fn=lambda: make_input_fn(X_train, y_train), steps=STEPS)

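I just don't understand the behavior I'm seeing. The number of steps the TF estimator trains for seems to be capped at 300, no matter what I set for batch_size, the number of training steps, or the number of epochs.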

My dataset has 8K training elements. When I choose n_epochs=100, batch_size=1000, steps=None, I expect tensorflow to run 100 (n_epochs) * 8 (steps required for 1 epoch) steps, but it doesn't: it runs 300.
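The expectation above is plain arithmetic on the input pipeline. A minimal sketch, assuming the only limit is the data running out (the helper name `batches_until_exhausted` is hypothetical and not part of TensorFlow):

```python
def batches_until_exhausted(num_examples, n_epochs, batch_size):
    """Count the batches yielded by dataset.repeat(n_epochs).batch(batch_size).

    repeat() runs before batch(), so batches may span epoch boundaries, and
    the final partial batch is kept (drop_remainder defaults to False).
    """
    total_elements = num_examples * n_epochs
    # Integer ceiling division: a trailing partial batch still counts as a step.
    return (total_elements + batch_size - 1) // batch_size


print(batches_until_exhausted(8000, 100, 1000))  # 800
print(batches_until_exhausted(8000, 100, 8000))  # 100
print(batches_until_exhausted(8000, 600, 10))    # 480000
```

This is the step count one would expect if nothing but the input function stopped training.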

Below is a summary of several experiments with different N_EPOCHS, BATCH_SIZE and STEPS. The first 3 look fine to me, but the rest do not.

#   STEPS   N_EPOCHS   BATCH_SIZE   TF training steps (est.train)   Steps I expected
1   None    100        8000         100                             100
2   None    200        8000         200                             200
3   None    300        8000         300                             300
4   None    400        8000         300                             400
5   None    100        1000         300                             800
6   None    600        10           300                             600 * 800
7   400     400        8000         300                             400

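As you can see, from row 4 onward my expectation does not match the number of steps the tensorflow training actually ran. This basically means that when I lower batch_size to 10, it runs only 300 times on batches of size 10, which is wrong, but I can't work out from the docs what is incorrect in my implementation. Any help is much appreciated!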

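Also, it makes no difference whether I use train_and_evaluate with Specs or call train directly; train is used here for simplicity. The logs of the train function are as follows (for the 7th experiment):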

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/estimator-run-1609770778/model.ckpt.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.69314593, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/estimator-run-1609770778/model.ckpt.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:global_step/sec: 2.88621
INFO:tensorflow:loss = 0.64561623, step = 99 (34.648 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 200...
INFO:tensorflow:Saving checkpoints for 200 into /tmp/estimator-run-1609770778/model.ckpt.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 200...
INFO:tensorflow:global_step/sec: 2.9199
INFO:tensorflow:loss = 0.6292723, step = 199 (34.248 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 300...
INFO:tensorflow:Saving checkpoints for 300 into /tmp/estimator-run-1609770778/model.ckpt.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 300...
INFO:tensorflow:global_step/sec: 2.83013
INFO:tensorflow:loss = 0.6164282, step = 299 (35.334 sec)
INFO:tensorflow:Loss for final step: 0.6164282.

Solution

I think the clue here is that n_trees * max_depth = 300:

    'n_trees': 50, 'max_depth': 6,

Also, take a look at this line in this test:

      # It will stop after 5 steps because of the max depth and num trees.
      num_steps = 100

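I don't know the exact logic; my guess is that it simply adds one layer to a tree with each batch. You set the number of trees and the max depth, and it cannot keep training once all the trees have been built.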
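The guess above can be checked against the experiment table with a little arithmetic. The sketch below assumes training stops at whichever limit is hit first: the input running out, the `steps` argument, or a tree-building cap of n_trees * max_depth * n_batches_per_layer. The function name `expected_train_steps` is hypothetical, and the n_batches_per_layer factor is an assumption (it is 1 in the question, so it does not change these numbers); this is a model of the observed behavior, not TensorFlow's actual stopping logic.

```python
def expected_train_steps(num_examples, n_epochs, batch_size, steps=None,
                         n_trees=50, max_depth=6, n_batches_per_layer=1):
    """Predict the number of training steps under the 'one layer per batch' guess."""
    # Limit 1: batches yielded by dataset.repeat(n_epochs).batch(batch_size).
    data_steps = (num_examples * n_epochs + batch_size - 1) // batch_size
    # Limit 2 (assumed): every layer of every tree consumes n_batches_per_layer steps.
    tree_steps = n_trees * max_depth * n_batches_per_layer
    limits = [data_steps, tree_steps]
    # Limit 3: the explicit steps argument, when one is given.
    if steps is not None:
        limits.append(steps)
    return min(limits)


# (STEPS, N_EPOCHS, BATCH_SIZE, observed TF training steps) from the table
experiments = [
    (None, 100, 8000, 100),
    (None, 200, 8000, 200),
    (None, 300, 8000, 300),
    (None, 400, 8000, 300),
    (None, 100, 1000, 300),
    (None, 600, 10, 300),
    (400, 400, 8000, 300),
]

for steps, n_epochs, batch_size, observed in experiments:
    predicted = expected_train_steps(8000, n_epochs, batch_size, steps=steps)
    print(steps, n_epochs, batch_size, predicted, predicted == observed)
```

Under these assumptions all seven rows match. If the model holds, running more than 300 steps would mean raising n_trees or max_depth rather than changing the batch size or epoch count.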
