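How to fix strange step behavior when combining TensorFlow Estimators with the Dataset API
I am running into problems with the training-loop behavior of TensorFlow's Estimator and Dataset APIs. The code is as follows (TF 2.3):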
import tensorflow as tf

# X_train, y_train, feature_cols, and model_dir are defined elsewhere.
NUM_EXAMPLES = X_train.shape[0]  # dataset has 8000 elements
BATCH_SIZE = NUM_EXAMPLES
STEPS = None
N_EPOCHS = 100

def make_input_fn(X, y, n_epochs=N_EPOCHS, shuffle=True):
    dataset = tf.data.Dataset.from_tensor_slices((X.to_dict(orient='list'), y))
    if shuffle:
        dataset = dataset.shuffle(NUM_EXAMPLES)
    return dataset.repeat(n_epochs).batch(BATCH_SIZE)

estimator = tf.estimator.BoostedTreesClassifier(
    feature_cols,
    config=tf.estimator.RunConfig(
        model_dir=model_dir, save_checkpoints_steps=100
    ),
    n_trees=50,
    max_depth=6,
    n_batches_per_layer=1,
    l2_regularization=0.1,
)

estimator.train(input_fn=lambda: make_input_fn(X_train, y_train), steps=STEPS)
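I just don't understand the behavior I am seeing. The number of steps the TF estimator trains for seems to be capped at 300, no matter what I set for batch_size, the number of training steps, or the number of epochs.
My dataset has 8K training elements. When I choose steps=None with n_epochs=100 and batch_size=1000, I expect TensorFlow to run 100 (n_epochs) * 8 (steps required for 1 epoch) = 800 steps, but it does not; it runs 300.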
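The expectation above is plain arithmetic on the input pipeline: `dataset.repeat(n_epochs).batch(batch_size)` yields every element `n_epochs` times and groups them into batches, keeping a final partial batch. A minimal sketch of that count follows; the helper name `expected_dataset_steps` is hypothetical and not part of TensorFlow.

```python
import math

def expected_dataset_steps(num_examples, n_epochs, batch_size):
    """Number of batches produced by dataset.repeat(n_epochs).batch(batch_size).

    Assumes the default drop_remainder=False, so a final partial batch
    still counts as one step.
    """
    return math.ceil(num_examples * n_epochs / batch_size)

# The case described in the question: 8000 examples, 100 epochs, batch size 1000.
print(expected_dataset_steps(8000, 100, 1000))  # 800
```

This is only the number of batches the dataset can supply; the estimator may stop earlier for other reasons.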
Below is a summary of several experiments with different N_EPOCHS, BATCH_SIZE, and STEPS values. The first 3 look fine to me, but the rest do not.
| # | STEPS | N_EPOCHS | BATCH_SIZE | Steps TF actually ran (est.train) | Steps I expected |
|---|---|---|---|---|---|
| 1 | None | 100 | 8000 | 100 | 100 |
| 2 | None | 200 | 8000 | 200 | 200 |
| 3 | None | 300 | 8000 | 300 | 300 |
| 4 | None | 400 | 8000 | 300 | 400 |
| 5 | None | 100 | 1000 | 300 | 800 |
| 6 | None | 600 | 10 | 300 | 600 * 800 |
| 7 | 400 | 400 | 8000 | 300 | 400 |
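As you can see, from row 4 onward my expectation does not match the number of steps TensorFlow actually ran. This basically means that when I lower batch_size to, say, 10, it trains on only 300 batches of size 10, which is wrong. I cannot work out what is incorrect in my implementation, even after looking at the documentation. Any help is much appreciated!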
Also, it makes no difference whether I use train_and_evaluate with Specs or call train directly; train is used here for simplicity. The log of the train function (for experiment 7) is as follows:
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported,or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported,or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/estimator-run-1609770778/model.ckpt.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported,or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.69314593,step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/estimator-run-1609770778/model.ckpt.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported,or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:global_step/sec: 2.88621
INFO:tensorflow:loss = 0.64561623,step = 99 (34.648 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 200...
INFO:tensorflow:Saving checkpoints for 200 into /tmp/estimator-run-1609770778/model.ckpt.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported,or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 200...
INFO:tensorflow:global_step/sec: 2.9199
INFO:tensorflow:loss = 0.6292723,step = 199 (34.248 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 300...
INFO:tensorflow:Saving checkpoints for 300 into /tmp/estimator-run-1609770778/model.ckpt.
WARNING:tensorflow:Issue encountered when serializing resources.
Type is unsupported,or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.
'_Resource' object has no attribute 'name'
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 300...
INFO:tensorflow:global_step/sec: 2.83013
INFO:tensorflow:loss = 0.6164282,step = 299 (35.334 sec)
INFO:tensorflow:Loss for final step: 0.6164282.
Solution
I think the clue here is that n_trees * max_depth = 300:
'n_trees': 50,'max_depth': 6,
Also, take a look at the following line from this test:
# It will stop after 5 steps because of the max depth and num trees.
num_steps = 100
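I don't know the exact logic, but I guess it simply adds one layer to one tree with each batch. You set the number of trees and the maximum depth, and training cannot continue once all the trees have been built.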
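To check that guess against the experiments in the question, here is a rough sketch that models the number of training steps as the smallest of three limits: the batches the dataset can supply, the `steps` argument (if given), and `n_trees * max_depth`. This is an assumption based on the reasoning above, not TensorFlow's documented logic, and it only covers the question's setting of `n_batches_per_layer=1`. The function name `modeled_train_steps` is hypothetical.

```python
import math

def modeled_train_steps(num_examples, n_epochs, batch_size,
                        n_trees, max_depth, steps=None):
    """Model the step count as the tightest of the known limits.

    Assumption (from the answer's guess): with n_batches_per_layer=1, each
    step grows one layer of one tree, so training ends after at most
    n_trees * max_depth steps.
    """
    dataset_steps = math.ceil(num_examples * n_epochs / batch_size)
    limits = [dataset_steps, n_trees * max_depth]
    if steps is not None:
        limits.append(steps)
    return min(limits)

# (STEPS, N_EPOCHS, BATCH_SIZE) for the seven experiments in the question.
experiments = [
    (None, 100, 8000),
    (None, 200, 8000),
    (None, 300, 8000),
    (None, 400, 8000),
    (None, 100, 1000),
    (None, 600, 10),
    (400, 400, 8000),
]

for steps, n_epochs, batch_size in experiments:
    modeled = modeled_train_steps(8000, n_epochs, batch_size,
                                  n_trees=50, max_depth=6, steps=steps)
    print(steps, n_epochs, batch_size, modeled)
```

With n_trees=50 and max_depth=6, this model gives 100, 200, 300, 300, 300, 300, 300, which matches the steps TensorFlow actually ran in all seven experiments. If the guess is right, training for more steps would require raising n_trees or max_depth rather than changing the batch size or epoch count.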