微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

“函数调用堆栈:train_function”在图像识别 AI 中是什么意思?

如何解决“函数调用堆栈:train_function”在图像识别 AI 中是什么意思?

所以我正在尝试将图像识别 AI 用于一个项目。我使用的模块是本网站 https://towardsdatascience.com/train-image-recognition-ai-with-5-lines-of-code-8ed0bdd8d9ba 上的模块。我已经安装了所有的 pip 及​​其给定的变体。

from imageai.Classification.Custom import ClassificationModelTrainer

directory = "C:\\Users\\DELL\\Desktop\\AI Project\\idenprof-jpg\\idenprof\\"
model_trainer = ClassificationModelTrainer()
model_trainer.setModelTypeAsresnet50()
model_trainer.setDataDirectory(directory)
model_trainer.trainModel(num_objects=10,num_experiments=200,enhance_data=True,batch_size=32,show_network_summary=True)

然后我收到此错误

回溯(最近一次调用最后一次): 文件“C:\Users\DELL\Desktop\AI Project\Test 2.py”,第 7 行,在 model_trainer.trainModel(num_objects=10,num_experiments=200,enhance_data=True,batch_size=32,show_network_summary=True) 文件“C:\Users\DELL\AppData\Local\Programs\Python\python38\lib\site-packages\imageai\Classification\Custom_init_.py”,第 393 行,在 trainModel 中 model.fit_generator(train_generator,steps_per_epoch=int(num_train/batch_size),epochs=self.__num_epochs,文件“C:\Users\DELL\AppData\Local\Programs\Python\python38\lib\site-packages\tensorflow\python\keras\engine\training.py”,第 1847 行,在 fit_generator 中 返回 self.fit( 文件“C:\Users\DELL\AppData\Local\Programs\Python\python38\lib\site-packages\tensorflow\python\keras\engine\training.py”,第 1100 行,适合 tmp_logs = self.train_function(iterator) 文件“C:\Users\DELL\AppData\Local\Programs\Python\python38\lib\site-packages\tensorflow\python\eager\def_function.py”,第 828 行,调用 结果 = self._call(*args,**kwds) 文件“C:\Users\DELL\AppData\Local\Programs\Python\python38\lib\site-packages\tensorflow\python\eager\def_function.py”,第 888 行,在 _call 返回 self._stateless_fn(*args,**kwds) 文件“C:\Users\DELL\AppData\Local\Programs\Python\python38\lib\site-packages\tensorflow\python\eager\function.py”,第 2942 行,调用 返回 graph_function._call_flat( 文件“C:\Users\DELL\AppData\Local\Programs\Python\python38\lib\site-packages\tensorflow\python\eager\function.py”,第 1918 行,在 _call_flat 返回 self._build_call_outputs(self._inference_function.call( 文件“C:\Users\DELL\AppData\Local\Programs\Python\python38\lib\site-packages\tensorflow\python\eager\function.py”,第 555 行,调用输出 = 执行.执行( 文件“C:\Users\DELL\AppData\Local\Programs\Python\python38\lib\site-packages\tensorflow\python\eager\execute.py”,第 59 行,在 quick_execute 张量 = pywrap_tfe.TFE_Py_Execute(ctx.handle,device_name,op_name,tensorflow.python.framework.errors_impl.InvalidArgumentError:logits 和标签必须是可广播的:logits_size=[32,10] labels_size=[32,13] [[node categorical_crossentropy/softmax_cross_entropy_with_logits(定义在 C:\Users\DELL\AppData\Local\Programs\Python\python38\lib\site-packages\imageai\Classification\Custom_init。 py:393)]] [操作:__inference_train_function_11908] 函数调用栈: train_function

我该怎么做才能解决这个问题?另外,如果格式不正确,这是第一次提问。

解决方法

train_function 是调用 model.fit() 时隐藏在幕后的函数。大致是这样的:

def train_step(self,data):
    x,y = data

    with tf.GradientTape() as tape:
        y_pred = self(x,training=True)  # Forward pass
        # Compute our own loss
        loss = keras.losses.mean_squared_error(y,y_pred)

    # Compute gradients
    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss,trainable_vars)

    # Update weights
    self.optimizer.apply_gradients(zip(gradients,trainable_vars))

    # Compute our own metrics
    loss_tracker.update_state(loss)
    mae_metric.update_state(y,y_pred)
    return {"loss": loss_tracker.result(),"mae": mae_metric.result()}

当您在回溯中看到这一点时,它仅表明错误在此函数中的某处。如您所见,它可以是很多东西,因为 train_function 中有很多操作。但是,如果您查找,则会有更多信息。在你的情况下:

InvalidArgumentError:logits 和标签必须是可广播的:logits_size=[32,10] labels_size=[32,13] [[node categorical_crossentropy/softmax_cross_entropy_with_logits

您的神经网络有 10 个输出神经元,但您的标签有 13 个类别。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?