如何解决迁移到Tensorflow v2:从tensorflow.python.keras.engine.data_adapter实例化DataHandler
我问了我一个问题,因为在从以下存储库https://github.com/cyprienruffino/CTCModel迁移脚本到Tensorflow 2时遇到了一些麻烦,从而在keras模型上实现了CTC丢失。在CTCModel类的方法predict
中(在keras_ctcmodel目录中),创建了DataHandler类的对象,如data_handler = data_adapter.DataHandler(...),其中data_adapter的导入方式如下:from tensorflow.python.keras.engine import data_adapter
。我得到的错误是未定义 DataHandler ,我还通过检查脚本 data_adapter.py 对此进行了验证。
下面我附上了我所指的代码块:
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
data_handler = data_adapter.DataHandler(
x=x,batch_size=batch_size,steps_per_epoch=steps,initial_epoch=0,epochs=1,max_queue_size=10,workers=1,use_multiprocessing=False,model=self,steps_per_execution=tf.constant(1,dtype=tf.int32))
if not isinstance(callbacks,callbacks_module.CallbackList):
callbacks = callbacks_module.CallbackList(
callbacks,add_history=True,add_progbar=verbose != 0,verbose=verbose,steps=data_handler.inferred_steps)
if self.model_pred.stateful:
if x[0].shape[0] > batch_size and x[0].shape[0] % batch_size != 0:
raise ValueError('In a stateful network,'
'you should only pass inputs with '
'a number of samples that can be '
'divided by the batch size. Found: ' +
str(x[0].shape[0]) + ' samples. '
'Batch size: ' + str(batch_size) + '.')
# Prepare inputs,delegate logic to `_predict_loop`.
ins = x
outputs = None
predict_function = self.model_pred.make_predict_function()
self.model_pred._predict_counter.assign(0)
callbacks.on_predict_begin()
for _,iterator in data_handler.enumerate_epochs(): # Single epoch.
with data_handler.catch_stop_iteration():
for step in data_handler.steps():
callbacks.on_predict_batch_begin(step)
tmp_batch_outputs = predict_function(iterator)
if data_handler.should_sync:
context.async_wait()
batch_outputs = tmp_batch_outputs # No error,Now safe to assign.
if outputs is None:
outputs = nest.map_structure(lambda batch_output: [batch_output],batch_outputs)
else:
nest.map_structure_up_to(
batch_outputs,lambda output,batch_output: output.append(batch_output),outputs,batch_outputs)
end_step = step + data_handler.step_increment
callbacks.on_predict_batch_end(end_step,{'outputs': batch_outputs})
callbacks.on_predict_end()
all_outputs = nest.map_structure_up_to(batch_outputs,concat,outputs)
return tf_utils.to_numpy_or_python_type(all_outputs)
因此,我试图实例化 data_adapter.py 中提供的其他类,这些类似乎在对data_adapter.DataHandler
的原始调用中定义了一些属性,但由于使用的类无法提供一些必需的方法。
是否可以在Tensorflow 2上进行迁移或使此行代码正常工作?
非常感谢!
最诚挚的问候
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。