
TensorFlow's QAT does not seem to use the AllValuesQuantizer


I created two QAT models using the AllValuesQuantizer, one quantized per tensor and one quantized per channel. When inspecting their respective QuantizeWrapper layers, I noticed that both have scalar values for the kernel_min and kernel_max variables.

Here is an example of a per-tensor quantized model

Here is an example of a per-channel quantized model

As I understand from this paper, the min/max values of the kernel define the scale and zero-point quantization parameters. For per-tensor quantization it makes sense for the model to have a single min and max value, since the whole tensor shares one scale and one zero point. But for per-channel quantization, where each channel has its own scale and zero point, I would expect kernel_min and kernel_max to be vectors. Why aren't they?
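For reference, here is a minimal sketch (my own, not TFMOT's internal code) of how an asymmetric affine scale and zero point are typically derived from min/max statistics; the per-channel variant simply keeps one min/max pair per output channel, which is why one would expect vector-valued kernel_min and kernel_max:

import numpy as np

def affine_quant_params(t_min, t_max, num_bits=8):
    # Derive scale and zero point from observed min/max (asymmetric, unsigned integer range).
    qmin, qmax = 0, 2 ** num_bits - 1
    t_min, t_max = min(t_min, 0.0), max(t_max, 0.0)   # quantization range must contain zero
    scale = (t_max - t_min) / (qmax - qmin)
    zero_point = int(round(qmin - t_min / scale))
    return scale, zero_point

kernel = np.random.randn(3, 3, 3, 32)                  # HWIO conv kernel, 32 output channels

# Per-tensor: one scalar min/max -> one (scale, zero_point) pair for the whole kernel.
print(affine_quant_params(kernel.min(), kernel.max()))

# Per-channel (last axis): vector min/max -> one (scale, zero_point) pair per output channel.
ch_min = kernel.min(axis=(0, 1, 2))
ch_max = kernel.max(axis=(0, 1, 2))
print(len([affine_quant_params(lo, hi) for lo, hi in zip(ch_min, ch_max)]))   # 32 pairs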

In this github issue it is mentioned that QAT automatically uses per-tensor quantization (as of March 2020), but that this could change. To me it looks like QAT still only uses per-tensor quantization. If so, why is there a parameter that is supposed to enable per-channel quantization (see AllValuesQuantizer's per-axis boolean)?
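For context, this is the flag in question. As far as the public API goes, constructing the quantizer with per_axis=True is the only way to request per-channel (per output axis) behaviour; the other argument values below are just illustrative:

import tensorflow_model_optimization as tfmot

# per_axis=True is supposed to request per-axis (per-channel along the last dimension) quantization...
per_channel_quantizer = tfmot.quantization.keras.quantizers.AllValuesQuantizer(
    num_bits=8, per_axis=True, symmetric=True, narrow_range=False)

# ...while per_axis=False should give per-tensor quantization.
per_tensor_quantizer = tfmot.quantization.keras.quantizers.AllValuesQuantizer(
    num_bits=8, per_axis=False, symmetric=True, narrow_range=False)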

To illustrate my point further, I also noticed in the source code for the AllValuesQuantizer that self.per_axis is never passed on to the next function, so what is that variable even there for? Note that the other quantizers, LastValueQuantizer and MovingAverageQuantizer, do pass this variable on.

So: does TF's QAT perform per-channel quantization at all? To me it does not look like it. How can I use per-channel quantization with the AllValuesQuantizer?
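One way to check what actually happened is to print the shapes of the min/max variables on the QuantizeWrapper layers of the trained model (q_aware_model in the code further down); with true per-channel quantization the kernel min/max should be vectors with one entry per output channel, while per-tensor quantization leaves them as scalars. Variable names here follow TFMOT's current naming and may differ between versions:

# Inspect the min/max variables created by QAT: shape () means per-tensor,
# shape (num_output_channels,) would indicate per-channel quantization.
for layer in q_aware_model.layers:
    for var in layer.weights:
        if 'kernel_min' in var.name or 'kernel_max' in var.name:
            print(layer.name, var.name, var.shape)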

GitHub issue: https://github.com/tensorflow/tensorflow/issues/47858

Code to reproduce my two models:

import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot

from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

# Possible quantization aware quantizers:
QAT_ALL_VALUES = tfmot.quantization.keras.quantizers.AllValuesQuantizer
QAT_LAST_VALUE = tfmot.quantization.keras.quantizers.LastValueQuantizer
QAT_MA = tfmot.quantization.keras.quantizers.MovingAverageQuantizer


def quantization_aware_training(model,save,w_bits,a_bits,symmetric,per_axis,narrow_range,quantizer,batch_size=64,epochs=2):

    # Create quantized model's name string
    name = model.name + '_'
    name = name + str(w_bits) + 'wbits_' + str(a_bits) + 'abits_'

    if symmetric:
        name = name + 'sym_'
    else:
        name = name + 'asym_'

    if narrow_range:
        name = name + 'narr_'
    else:
        name = name + 'full_'

    if per_axis:
        name = name + 'perch_'
    else:
        name = name + 'perten_'

    if quantizer == QAT_ALL_VALUES:
        name = name + 'AV'
    elif quantizer == QAT_LAST_VALUE:
        name = name + 'LV'
    elif quantizer == QAT_MA:
        name = name + 'MA'

    # Quantization
    # *****
    quantize_apply = tfmot.quantization.keras.quantize_apply
    quantize_model = tfmot.quantization.keras.quantize_model
    quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
    clone_model = tf.keras.models.clone_model
    quantize_scope = tfmot.quantization.keras.quantize_scope

    supported_layers = [
        tf.keras.layers.Conv2D,]

    class Quantizer(tfmot.quantization.keras.QuantizeConfig):
        # Configure how to quantize weights.
        def get_weights_and_quantizers(self,layer):
            return [(layer.kernel,tfmot.quantization.keras.quantizers.LastValueQuantizer(num_bits=8,symmetric=True,narrow_range=False,per_axis=False))]

        # Configure how to quantize activations.
        def get_activations_and_quantizers(self,layer):
            return [(layer.activation,tfmot.quantization.keras.quantizers.MovingAverageQuantizer(num_bits=8,symmetric=False,narrow_range=False,per_axis=False))]

        def set_quantize_weights(self,layer,quantize_weights):
            # Add this line for each item returned in `get_weights_and_quantizers`
            #,in the same order
            layer.kernel = quantize_weights[0]

        def set_quantize_activations(self,layer,quantize_activations):
            # Add this line for each item returned in `get_activations_and_quantizers`
            #,in the same order.
            layer.activation = quantize_activations[0]

        # Configure how to quantize outputs (may be equivalent to activations).
        def get_output_quantizers(self,layer):
            return []

        def get_config(self):
            return {}

    class ConvQuantizer(Quantizer):
        # Configure how to quantize the Conv2D kernel with the requested quantizer and settings.
        def get_weights_and_quantizers(self,layer):
            return [(layer.kernel,quantizer(num_bits=w_bits,symmetric=symmetric,narrow_range=narrow_range,per_axis=per_axis))]

        # Configure how to quantize activations.
        def get_activations_and_quantizers(self,layer):
            return [(layer.activation,tfmot.quantization.keras.quantizers.MovingAverageQuantizer(num_bits=a_bits,symmetric=False,narrow_range=False,per_axis=False))]

    class DepthwiseQuantizer(Quantizer):
        # Configure how to quantize the depthwise kernel (kept per-tensor here).
        def get_weights_and_quantizers(self,layer):
            return [(layer.depthwise_kernel,quantizer(num_bits=w_bits,symmetric=symmetric,narrow_range=narrow_range,per_axis=False))]

    # Instead of simply using quantize_annotate_model or quantize_model we must use
    # quantize_annotate_layer since it's the only one with a quantize_config argument
    def quantize_all_layers(layer):
        if isinstance(layer,tf.keras.layers.DepthwiseConv2D):
            return quantize_annotate_layer(layer,quantize_config=DepthwiseQuantizer())
        elif isinstance(layer,tf.keras.layers.Conv2D):
            return quantize_annotate_layer(layer,quantize_config=ConvQuantizer())
        return layer

    annotated_model = clone_model(
        model,clone_function=quantize_all_layers
    )

    with quantize_scope(
        {'Quantizer': Quantizer},{'ConvQuantizer': ConvQuantizer},{'DepthwiseQuantizer': DepthwiseQuantizer}):
        q_aware_model = quantize_apply(annotated_model)

    # *****

    # Compile and train model
    optimizer = keras.optimizers.Adam(
        learning_rate=0.001)
    q_aware_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True),optimizer=optimizer,metrics=['sparse_categorical_accuracy'])

    (train_images,train_labels),_ = keras.datasets.cifar10.load_data()

    q_aware_model.fit(train_images,train_labels,batch_size=batch_size,epochs=epochs,verbose=1,validation_split=0.1)

    if save:
        save_path = 'models/temp/' + name
        q_aware_model.save(save_path + '.h5')

    return q_aware_model


def temp_net():
    dropout = 0.1

    model = keras.Sequential()
    model.add(keras.layers.Conv2D(32,(3,3),padding='same',input_shape=(32,32,3)))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.Activation('relu'))

    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(10,activation='softmax'))

    model._name = "temp_net"

    return model


if __name__ == "__main__":
    # symmetric and narrow_range are required by the function signature; the values used here are assumed.
    q_model = quantization_aware_training(model=temp_net(),save=True,w_bits=8,a_bits=8,symmetric=True,per_axis=False,narrow_range=False,quantizer=QAT_ALL_VALUES,epochs=1)

