问题复制 AutoKeras StructuredDataClassifier

如何解决问题复制 AutoKeras StructuredDataClassifier

我有一个使用 AutoKeras 生成的模型,我想复制该模型,以便我可以使用 keras 调谐器构建它以进行进一步的超参数调整。但是我在复制模型时遇到了问题。 autokeras模型的模型总结是:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None,11)]              0         
_________________________________________________________________
multi_category_encoding (Mul (None,11)                0         
_________________________________________________________________
normalization (normalization (None,11)                23        
_________________________________________________________________
dense (Dense)                (None,16)                192       
_________________________________________________________________
re_lu (ReLU)                 (None,16)                0         
_________________________________________________________________
dense_1 (Dense)              (None,32)                544       
_________________________________________________________________
re_lu_1 (ReLU)               (None,32)                0         
_________________________________________________________________
dense_2 (Dense)              (None,3)                 99        
_________________________________________________________________
classification_head_1 (Softm (None,3)                 0         
=================================================================
Total params: 858
Trainable params: 835
Non-trainable params: 23

图层配置

{'batch_input_shape': (None,11),'dtype': 'string','sparse': False,'ragged': False,'name': 'input_1'}
{'name': 'multi_category_encoding','trainable': True,'dtype': 'float32','encoding': ListWrapper(['int','int','int'])}
{'name': 'normalization','axis': (-1,)}
{'name': 'dense','units': 16,'activation': 'linear','use_bias': True,'kernel_initializer': {'class_name': 'GlorotUniform','config': {'seed': None}},'bias_initializer': {'class_name': 'Zeros','config': {}},'kernel_regularizer': None,'bias_regularizer': None,'activity_regularizer': None,'kernel_constraint': None,'bias_constraint': None}
{'name': 're_lu','max_value': None,'negative_slope': array(0.,dtype=float32),'threshold': array(0.,dtype=float32)}
{'name': 'dense_1','units': 32,'bias_constraint': None}
{'name': 're_lu_1',dtype=float32)}
{'name': 'dense_2','units': 3,'bias_constraint': None}
{'name': 'classification_head_1','axis': -1}

我的训练数据是一个数据框,它被转换为包含数字和分类数据的字符串类型。由于输出softmax,我使用 LabelBinarizer 来转换目标类。

为了确保模型被正确复制,我使用 keras.clone_model 创建模型的副本并尝试自己训练。但是当我尝试自己训练它时,尽管达到了 500 个 epoch,但准确率并没有提高。

在从头开始训练模型时,我是否遗漏了什么?

解决方法

AutoKeras 不支持任何直接转换 - 它的依赖项过于内置,无法与包本身隔离。上面的答案表明缺少 softmax 激活是错误的,因为确实存在:

classification_head_1 (Softm --> 可能文本被截断了

接下来 - 您是否注意到缺少参数? 858 是一个非常小的数字 - 那是因为大多数层都有 0 参数 - Autokeras 使用构成其自定义块的自定义层(更多来自 their docs 的块)

您可以看到,要重新创建这些自定义层,您需要它们的确切代码 - 在撰写本文时无法将其隔离(尽管 @haifeng-jin 正在讨论它),因为它们有特定的包用于处理输入数据以及是什么为他们的 NAS (Neural Architecture search) 和他们执行的优化例程提供动力。

除非你可以研究他们的代码和自定义层的实现并重新创建它(这本身会是相当多的工作,但因为代码已经可用,所以不会太多),如果你使用 { {1}} 适用于预定义的 keras 层。这显然会导致模型损坏(例如您目前使用的模型)。

更重要的是,keras.clone_model 会自行调整超参数 - 如果您想进一步调整模型,只需将 AutoKeras 运行更长的时间即可获得更好的结果。

tl;dr 您不能直接克隆具有包内依赖项的自定义层和块。但是如果你想进行超参数调优,你可以运行更长时间的搜索以获得更好的模型。

,

我终于能够解决我的问题。这很奇怪,但即使自定义多类别图层没有参数,它也包含自己的数据映射。为了扩展模型并检查层深度的影响,我通过从现有模型添加多类别层创建了一个新模型。一旦我做了这个训练准确率匹配 AutoKeras。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?