微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在 Mobile Net 中对量化的 TensorFlow lite 模型使用批量归一化?

如何解决如何在 Mobile Net 中对量化的 TensorFlow lite 模型使用批量归一化?

我想在嵌入式系统(Coral USB Accelerator)上运行移动网络来分析我转换为频谱图的录音。 我已经根据 Tensorflow 中的 Mobile Net paper 实现了模型。

def SeparableConv(x,num_filters,strides,alpha=1.0):
    x = tf.keras.layers.DepthwiseConv2D(kernel_size=3,padding='same')(x)
    x = tf.keras.layers.Batchnormalization()(x) 
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Conv2D(np.floor( num_filters * alpha),kernel_size=(1,1),strides=strides,use_bias=False,padding='same')(x) 
    x = tf.keras.layers.Batchnormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    return x 


def Conv(x,kernel_size,strides=1,alpha=1.0):
    x = tf.keras.layers.Conv2D((np.floor( num_filters * alpha)),kernel_size=kernel_size,padding='same')(x)
    x = tf.keras.layers.Batchnormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.ReLU()(x)
    return x

inputs = tf.keras.layers.Input(shape=(64,64,1))

x = Conv(inputs,num_filters=16,kernel_size=3,strides=2)
x = SeparableConv(x,num_filters=32,strides=1)
x = SeparableConv(x,num_filters=64,num_filters=128,num_filters=256,num_filters=512,strides=1)

x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(512)(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dropout(0.001)(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(2)(x)

outputs = tf.keras.layers.ReLU()(x)
model = tf.keras.models.Model(inputs,outputs)

要在 Coral USB Accelerator 上运行它,我需要对其进行量化并将其转换为 TensorFlow Lite 模型(然后使用 Edge TPU 编译器再次编译它)。但是在从 32 位浮点数量化到 8 位整数后,预测会变得更糟。问题似乎是批量标准化keras.layers.BatchNormalization()。这不是 TensorFlow lite 中允许的指令。但由于 Mobile Net 是专门为嵌入式系统设计的,而批量归一化是其中的一个基本部分,我无法想象这在 Edge TPU 上是不可能的。所以我想问一下是否有人知道获取 Mobile Net 的解决方法,特别是在 Coral USB Accelerator 上工作的 Batch normalization?

非常感谢您提前提供帮助和建议!

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?