如何解决将 feed_dict 与 Tensorflow Estimator API 一起使用,“您必须为占位符张量‘Placeholder’提供一个值为 dtype float 和 shape [?,784]”
我正在尝试将自定义 Estimator
与 Feed_dict
一起使用。根据几个相关问题,例如this one,我得出了以下代码。请注意,我从 dataset
返回 input_fn
,而不是 next_example,next_label
。错误是
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must Feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,784]
[[{{node Placeholder}}]]
完整的堆栈跟踪进一步向下。
我遗漏了一些关于数据 X
和 y
究竟是如何输入到图表中的基本概念。任何人都可以阐明我做错了什么吗?谢谢!
代码:
import numpy as np
import tensorflow as tf
class IteratorInitializerHook(tf.compat.v1.train.SessionRunHook):
def __init__(self):
super(IteratorInitializerHook,self).__init__()
self.iterator_initializer_func = None # Will be set in the input_fn
def after_create_session(self,session,coord):
# Initialize the iterator with the data Feed_dict
self.iterator_initializer_func(session)
def get_inputs(X,y):
iterator_initializer_hook = IteratorInitializerHook()
def input_fn():
X_pl = tf.compat.v1.placeholder(X.dtype,[None,X.shape[1]])
y_pl = tf.compat.v1.placeholder(y.dtype,y.shape[1]])
dataset = tf.compat.v1.data.Dataset.from_tensor_slices((X_pl,y_pl))
iterator = dataset.make_initializable_iterator()
dataset = dataset.batch(32)
iterator_initializer_hook.iterator_initializer_func = \
lambda sess: sess.run(iterator.initializer,Feed_dict={X_pl: X,y_pl: y})
return dataset
return input_fn,iterator_initializer_hook
class MyMnist:
def __init__(self,params,**kwargs):
self.loss = 0
self.optimizer = tf.compat.v1.train.AdamOptimizer()
self.W = tf.compat.v1.Variable(tf.zeros([784,10]),trainable=True,name="W")
self.b = tf.compat.v1.Variable(tf.zeros([10]),name="b")
def build_model(self,features,labels,mode):
"""
Build model and return output
"""
is_training = mode == tf.estimator.ModeKeys.TRAIN
output = tf.compat.v1.nn.softmax(
tf.matmul(features,self.W) + self.b
)
return output
def build_total_loss(self,model_outputs,mode):
"""
Return computed loss
"""
loss = tf.compat.v1.losses.softmax_cross_entropy(
labels,model_outputs
)
return loss
def build_optimizer(self):
"""
Setup the optimizer.
:returns: The optimizer
"""
print("build_optimizer")
lr = 0.01
optimizer = tf.compat.v1.train.AdamOptimizer(
learning_rate=lr,name="Adam"
)
return optimizer
def build_train_ops(self,loss):
"""
Setup optimizer and build train ops.
:param Tensor loss: The loss tensor
:return: Train ops
"""
print("build_train_ops")
self.optimizer = self.build_optimizer()
return self.optimizer.minimize(
loss,global_step=tf.compat.v1.train.get_global_step()
)
def model_fn(features,mode,params):
print('model_fn')
model = MyMnist(params)
output = model.build_model(features,mode)
loss = model.build_total_loss(output,mode)
if mode == tf.estimator.ModeKeys.TRAIN:
train_op = model.build_train_ops(loss)
log_hook = \
tf.compat.v1.train.LoggingTensorHook(
{"W is": model.W,"b is": model.b},every_n_iter=1)
return tf.estimator.EstimatorSpec(
mode=mode,loss=loss,train_op=train_op,training_hooks=[log_hook]
)
def train(argv=None):
print("train")
params = { 'mode': 'train','model_dir': './model_dir','training': {'steps': 10 },'size': 100
}
est = tf.estimator.Estimator(
model_fn,model_dir=params['model_dir'],params=params,)
(X_train,l_train),(X_test,l_test) = tf.keras.datasets.mnist.load_data()
y_train = np.zeros((l_train.shape[0],l_train.max()+1),dtype=np.float32)
y_train[np.arange(l_train.shape[0]),l_train] = 1
y_test = np.zeros((l_test.shape[0],l_test.max()+1),dtype=np.float32)
y_test[np.arange(l_test.shape[0]),l_test] = 1
X_train = X_train.reshape((X_train.shape[0],-1)).astype(np.float32)
X_test = X_test.reshape((X_test.shape[0],-1))
train_input_fn,train_iterator_initializer_hook = \
get_inputs(X_train,y_train)
test_input_fn,test_iterator_initializer_hook = get_inputs(X_test,y_test)
if params['mode'] == 'train':
est.train(
input_fn=train_input_fn,hooks=[train_iterator_initializer_hook],steps=params['training']['steps']
)
if __name__ == "__main__":
tf.compat.v1.disable_eager_execution()
tf.compat.v1.app.run(main=train)
堆栈跟踪:
$ python main.py
train
INFO:tensorflow:Using default config.
I0128 18:51:12.061203 139970842568512 estimator.py:1822] Using default config.
INFO:tensorflow:Using config: {'_model_dir': './model_dir','_tf_random_seed': None,'_save_summary_steps': 100,'_save_checkpoints_steps': None,'_save_checkpoints_secs': 600,'_session_config': allow_soft_placement: true
graph_options {
rewrite_options {
Meta_optimizer_iterations: ONE
}
},'_keep_checkpoint_max': 5,'_keep_checkpoint_every_n_hours': 10000,'_log_step_count_steps': 100,'_train_distribute': None,'_device_fn': None,'_protocol': None,'_eval_distribute': None,'_experimental_distribute': None,'_experimental_max_worker_delay_secs': None,'_session_creation_timeout_secs': 7200,'_service': None,'_cluster_spec': ClusterSpec({}),'_task_type': 'worker','_task_id': 0,'_global_id_in_cluster': 0,'_master': '','_evaluation_master': '','_is_chief': True,'_num_ps_replicas': 0,'_num_worker_replicas': 1}
I0128 18:51:12.061676 139970842568512 estimator.py:191] Using config: {'_model_dir': './model_dir','_num_worker_replicas': 1}
WARNING:tensorflow:From /srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/training_util.py:235: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
W0128 18:51:12.383132 139970842568512 deprecation.py:317] From /srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/training_util.py:235: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From main.py:21: DatasetV1.make_initializable_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
This is a deprecated API that should only be used in TF 1 graph mode and legacy TF 2 graph mode available through `tf.compat.v1`. In all other situations -- namely,eager mode and inside `tf.function` -- you can consume dataset elements using `for elem in dataset: ...` or by explicitly creating iterator via `iterator = iter(dataset)` and fetching its elements via `values = next(iterator)`. Furthermore,this API is not available in TF 2. During the transition from TF 1 to TF 2 you can use `tf.compat.v1.data.make_initializable_iterator(dataset)` to create a TF 1 graph mode style iterator for a dataset created through TF 2 APIs. Note that this should be a transient state of your code base as there are in general no guarantees about the interoperability of TF 1 and TF 2 code.
W0128 18:51:12.397508 139970842568512 deprecation.py:317] From main.py:21: DatasetV1.make_initializable_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
This is a deprecated API that should only be used in TF 1 graph mode and legacy TF 2 graph mode available through `tf.compat.v1`. In all other situations -- namely,this API is not available in TF 2. During the transition from TF 1 to TF 2 you can use `tf.compat.v1.data.make_initializable_iterator(dataset)` to create a TF 1 graph mode style iterator for a dataset created through TF 2 APIs. Note that this should be a transient state of your code base as there are in general no guarantees about the interoperability of TF 1 and TF 2 code.
INFO:tensorflow:Calling model_fn.
I0128 18:51:12.404662 139970842568512 estimator.py:1162] Calling model_fn.
model_fn
build_train_ops
build_optimizer
INFO:tensorflow:Done calling model_fn.
I0128 18:51:12.476285 139970842568512 estimator.py:1164] Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
I0128 18:51:12.477094 139970842568512 basic_session_run_hooks.py:546] Create CheckpointSaverHook.
data size:0.033848 MB
data size:0.033848 MB
INFO:tensorflow:Graph was finalized.
I0128 18:51:12.530334 139970842568512 monitored_session.py:246] Graph was finalized.
2021-01-28 18:51:12.530598: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (onednN)to use the following cpu instructions in performance-critical operations: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
To enable them in other operations,rebuild TensorFlow with the appropriate compiler flags.
2021-01-28 18:51:12.539575: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] cpu Frequency: 2400000000 Hz
2021-01-28 18:51:12.543713: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x559677c5ee30 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2021-01-28 18:51:12.543747: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host,Default Version
INFO:tensorflow:Running local_init_op.
I0128 18:51:12.571843 139970842568512 session_manager.py:505] Running local_init_op.
INFO:tensorflow:Done running local_init_op.
I0128 18:51:12.574033 139970842568512 session_manager.py:508] Done running local_init_op.
data size:0.052762 MB
data size:0.052762 MB
data size:0.052762 MB
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
I0128 18:51:12.677021 139970842568512 basic_session_run_hooks.py:613] Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./model_dir/model.ckpt.
I0128 18:51:12.677253 139970842568512 basic_session_run_hooks.py:618] Saving checkpoints for 0 into ./model_dir/model.ckpt.
data size:0.052762 MB
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
I0128 18:51:12.704515 139970842568512 basic_session_run_hooks.py:625] Calling checkpoint listeners after saving checkpoint 0...
Traceback (most recent call last):
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/client/session.py",line 1365,in _do_call
return fn(*args)
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/client/session.py",line 1349,in _run_fn
return self._call_tf_sessionrun(options,Feed_dict,fetch_list,File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/client/session.py",line 1441,in _call_tf_sessionrun
return tf_session.TF_SessionRun_wrapper(self._session,options,tensorflow.python.framework.errors_impl.InvalidArgumentError: You must Feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,784]
[[{{node Placeholder}}]]
During handling of the above exception,another exception occurred:
Traceback (most recent call last):
File "main.py",line 144,in <module>
tf.compat.v1.app.run(main=train)
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/platform/app.py",line 40,in run
_run(main=main,argv=argv,flags_parser=_parse_flags_tolerate_undef)
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/absl/app.py",line 303,in run
_run_main(main,args)
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/absl/app.py",line 251,in _run_main
sys.exit(main(argv))
File "main.py",line 136,in train
est.train(
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py",line 349,in train
loss = self._train_model(input_fn,hooks,saving_listeners)
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py",line 1175,in _train_model
return self._train_model_default(input_fn,line 1206,in _train_model_default
return self._train_with_estimator_spec(estimator_spec,worker_hooks,File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py",line 1495,in _train_with_estimator_spec
with training.MonitoredTrainingSession(
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/monitored_session.py",line 601,in MonitoredTrainingSession
return MonitoredSession(
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/monitored_session.py",line 1034,in __init__
super(MonitoredSession,self).__init__(
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/monitored_session.py",line 749,in __init__
self._sess = _RecoverableSession(self._coordinated_creator)
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/monitored_session.py",line 1231,in __init__
_WrappedSession.__init__(self,self._create_session())
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/monitored_session.py",line 1236,in _create_session
return self._sess_creator.create_session()
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/monitored_session.py",line 909,in create_session
hook.after_create_session(self.tf_sess,self.coord)
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/util.py",line 86,in after_create_session
session.run(self._initializer)
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/client/session.py",line 957,in run
result = self._run(None,fetches,options_ptr,line 1180,in _run
results = self._do_run(handle,final_targets,final_fetches,line 1358,in _do_run
return self._do_call(_run_fn,Feeds,targets,line 1384,in _do_call
raise type(e)(node_def,op,message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must Feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,784]
[[node Placeholder (defined at main.py:17) ]]
Original stack trace for 'Placeholder':
File "main.py",line 1201,in _train_model_default
self._get_features_and_labels_from_input_fn(input_fn,ModeKeys.TRAIN))
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py",line 1037,in _get_features_and_labels_from_input_fn
self._call_input_fn(input_fn,mode))
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py",line 1130,in _call_input_fn
return input_fn(**kwargs)
File "main.py",line 17,in input_fn
X_pl = tf.compat.v1.placeholder(X.dtype,X.shape[1]])
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py",line 3100,in placeholder
return gen_array_ops.placeholder(dtype=dtype,shape=shape,name=name)
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/ops/gen_array_ops.py",line 6808,in placeholder
_,_,_op,_outputs = _op_def_library._apply_op_helper(
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py",line 742,in _apply_op_helper
op = g._create_op_internal(op_type_name,inputs,dtypes=None,File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/framework/ops.py",line 3478,in _create_op_internal
ret = Operation(
File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/framework/ops.py",line 1949,in __init__
self._traceback = tf_stack.extract_stack()
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。