在 Keras 中使用池化索引进行上采样（反池化）

如何在 Keras 中使用池化索引进行上采样（反池化）？

首先说明：我是深度学习的新手。

我正在尝试在 Keras 中实现一个 SegNet，它使用池化索引进行上采样。

我将下面的函数与 Lambda 层配合使用，以执行最大池化并保存池化索引：

def pool_argmax2D(x,pool_size=(2,2),strides=(2,2)):
    """Max-pool ``x`` and also return the argmax index of every maximum.

    Meant to be wrapped in a Keras ``Lambda`` layer so that the indices can
    later drive an unpooling step in the decoder.

    Args:
        x: input tensor for ``tf.nn.max_pool_with_argmax``, which only
            supports the NHWC (channels-last) data format.
        pool_size: (height, width) of the pooling window.
        strides: (height, width) stride of the pooling window.

    Returns:
        ``[output, argmax]`` exactly as produced by
        ``tf.nn.max_pool_with_argmax`` with ``padding='SAME'``.
    """
    padding = 'SAME'
    # The op takes 4-element NHWC lists; 1 on batch/channels means no
    # pooling across those axes.
    ksize = [1, pool_size[0], pool_size[1], 1]
    strides = [1, strides[0], strides[1], 1]
    # Bug fix: the window used to be hard-coded to ksize=[1, 1] (a 1x1
    # window), so the op only subsampled every stride-th pixel instead of
    # taking the max over `pool_size`.
    output, argmax = tf.nn.max_pool_with_argmax(
        x, ksize=ksize, strides=strides, padding=padding
    )

    return [output, argmax]

[...]
# Max-pool conv_10 through pool_argmax2D, keeping the argmax indices
# (mask_4) for the matching unpooling step. pool_size and conv_10 are
# presumably defined in the elided `[...]` code.
pool_4,mask_4 = Lambda(pool_argmax2D,arguments={'pool_size': pool_size,'strides': pool_size})(conv_10)
[...]

这似乎可行：在模型摘要中，它返回形状为 (None, h/2, w/2, channels) 的张量。但是，我在查找或编写一个可用的反池化（unpooling）函数时遇到了问题：我无法得到形状为 (None, 2h, 2w, channels) 的张量（None 表示批量大小）。

我已经尝试过在 Stack Overflow 上找到的一些反池化函数（包括但不限于）：Function1、Function2。

但都没有成功。

有人可以帮助我吗?谢谢

编辑: 这是我要使用的模型

def getSegNet3(n_ch,height,width,n_labels,output_mode="sigmoid"):
    """Build a SegNet-style encoder/decoder Keras model (the question's attempt).

    The convolutions use ``data_format='channels_first'``. Each max-pool is
    done by ``pool_argmax2D`` inside a ``Lambda`` placed between ``Permute``
    layers (presumably channels_first -> channels_last and back). The
    returned argmax masks (mask_1 .. mask_5) are handed to ``unpool2D`` in
    the mirrored decoder stages.

    Args:
        n_ch: size of the first axis of the Input shape (channels).
        height, width: input spatial size. NOTE(review): ``height`` is not
            used in the Input shape below; it looks lost in transcription.
        n_labels: number of filters of the final 1x1 convolution.
        output_mode: activation applied to the output (default "sigmoid").

    Returns:
        A ``Model`` named "SegNet".

    NOTE(review): this listing appears garbled by copy/paste. Confirm it
    against the original source before running:
      - several ``Lambda(pool_argmax2D,'strides': pool_size})`` calls are
        missing the ``arguments={'pool_size': pool_size,`` part (invalid
        syntax as shown);
      - ``pool_size`` is not a parameter here, although the traceback shows
        it in the signature;
      - most ``Conv2D`` calls lack a kernel size;
      - several ``Permute`` patterns have 2 entries, inconsistent with the
        3-entry ``Permute((2,3,1))`` used after conv_2;
      - ``Batchnormalization`` is presumably ``BatchNormalization``.
    """
    # encoder
    inputs = Input(shape=(n_ch,width))

    conv_1 = Conv2D(16,(3,3),kernel_initializer='he_normal',padding='same',data_format='channels_first')(inputs)
    conv_1 = Batchnormalization(axis=1)(conv_1)
    conv_1 = Activation("relu")(conv_1)
    conv_2 = Conv2D(16,data_format='channels_first')(conv_1)

    conv_2 = Batchnormalization(axis=1)(conv_2)
    conv_2 = Activation("relu")(conv_2)

    # Move channels last for pool_argmax2D, pool (keeping mask_1), then move
    # channels back to the front for the next channels_first convolutions.
    conv_2 = core.Permute((2,3,1))(conv_2)
    pool_1,mask_1 = Lambda(pool_argmax2D,'strides': pool_size})(conv_2)
    pool_1 = core.Permute((3,1,2))(pool_1)

    conv_3 = Conv2D(32,data_format='channels_first')(pool_1)
    conv_3 = Batchnormalization(axis=1)(conv_3)
    conv_3 = Activation("relu")(conv_3)
    conv_4 = Conv2D(32,data_format='channels_first')(conv_3)
    conv_4 = Batchnormalization(axis=1)(conv_4)
    conv_4 = Activation("relu")(conv_4)

    conv_4 = core.Permute((2,1))(conv_4)
    pool_2,mask_2 = Lambda(pool_argmax2D,'strides': pool_size})(conv_4)
    pool_2 = core.Permute((3,2))(pool_2)

    conv_5 = Conv2D(64,data_format='channels_first')(pool_2)
    conv_5 = Batchnormalization(axis=1)(conv_5)
    conv_5 = Activation("relu")(conv_5)
    conv_6 = Conv2D(64,data_format='channels_first')(conv_5)
    conv_6 = Batchnormalization(axis=1)(conv_6)
    conv_6 = Activation("relu")(conv_6)
    conv_7 = Conv2D(64,data_format='channels_first')(conv_6)
    conv_7 = Batchnormalization(axis=1)(conv_7)
    conv_7 = Activation("relu")(conv_7)

    conv_7 = core.Permute((2,1))(conv_7)
    pool_3,mask_3 = Lambda(pool_argmax2D,'strides': pool_size})(conv_7)
    pool_3 = core.Permute((3,2))(pool_3)

    conv_8 = Conv2D(128,data_format='channels_first')(pool_3)
    conv_8 = Batchnormalization(axis=1)(conv_8)
    conv_8 = Activation("relu")(conv_8)
    conv_9 = Conv2D(128,data_format='channels_first')(conv_8)
    conv_9 = Batchnormalization(axis=1)(conv_9)
    conv_9 = Activation("relu")(conv_9)
    conv_10 = Conv2D(128,data_format='channels_first')(conv_9)
    conv_10 = Batchnormalization(axis=1)(conv_10)
    conv_10 = Activation("relu")(conv_10)

    conv_10 = core.Permute((2,1))(conv_10)
    # NOTE(review): the Lambda(pool_argmax2D, ...) call and `mask_4` were lost
    # from this line, yet mask_4 is used by the decoder below.
    pool_4,'strides': pool_size})(conv_10)
    pool_4 = core.Permute((3,2))(pool_4)

    conv_11 = Conv2D(256,data_format='channels_first')(pool_4)
    conv_11 = Batchnormalization(axis=1)(conv_11)
    conv_11 = Activation("relu")(conv_11)
    conv_12 = Conv2D(256,data_format='channels_first')(conv_11)
    conv_12 = Batchnormalization(axis=1)(conv_12)
    conv_12 = Activation("relu")(conv_12)
    conv_13 = Conv2D(256,data_format='channels_first')(conv_12)
    conv_13 = Batchnormalization(axis=1)(conv_13)
    conv_13 = Activation("relu")(conv_13)

    conv_13 = core.Permute((2,1))(conv_13)
    # pool_5 is not permuted back: the decoder unpools it directly.
    pool_5,mask_5 = Lambda(pool_argmax2D,'strides': pool_size})(conv_13)

    print("Build encoder done..")

    # decoder: each stage unpools with the mask of the mirrored encoder stage
    # (mask_5 first, mask_1 last), then permutes back to channels_first.


    #unpool_1 = MaxUnpooling2D(pool_5,mask_5,(None,4,256))
    # NOTE(review): the traceback in the question is raised at conv_14 below,
    # because unpool_1's channel dimension is None. This is presumably because
    # unpool2D reshapes with a dynamic tf.shape and so loses the static shape.
    unpool_1 = Lambda(unpool2D,output_shape=(4,256),arguments={'ind':mask_5})(pool_5)
    unpool_1 = core.Permute((3,2))(unpool_1)

    conv_14 = Conv2D(256,data_format='channels_first')(unpool_1)
    conv_14 = Batchnormalization(axis=1)(conv_14)
    conv_14 = Activation("relu")(conv_14)
    conv_15 = Conv2D(256,data_format='channels_first')(conv_14)
    conv_15 = Batchnormalization(axis=1)(conv_15)
    conv_15 = Activation("relu")(conv_15)
    conv_16 = Conv2D(256,data_format='channels_first')(conv_15)
    conv_16 = Batchnormalization(axis=1)(conv_16)
    conv_16 = Activation("relu")(conv_16)

    conv_16 = core.Permute((2,1))(conv_16)
    # NOTE(review): `output_shape=(8,8,` is never closed on this line.
    unpool_2 = Lambda(unpool2D,output_shape=(8,8,arguments={'ind':mask_4})(conv_16)
    unpool_2 = core.Permute((3,2))(unpool_2)

    conv_17 = Conv2D(256,data_format='channels_first')(unpool_2)
    conv_17 = Batchnormalization(axis=1)(conv_17)
    conv_17 = Activation("relu")(conv_17)
    conv_18 = Conv2D(256,data_format='channels_first')(conv_17)
    conv_18 = Batchnormalization(axis=1)(conv_18)
    conv_18 = Activation("relu")(conv_18)
    conv_19 = Conv2D(128,data_format='channels_first')(conv_18)
    conv_19 = Batchnormalization(axis=1)(conv_19)
    conv_19 = Activation("relu")(conv_19)

    conv_19 = core.Permute((2,1))(conv_19)
    unpool_3 = Lambda(unpool2D,output_shape=(16,16,128),arguments={'ind':mask_3})(conv_19)
    unpool_3 = core.Permute((3,2))(unpool_3)


    conv_20 = Conv2D(128,data_format='channels_first')(unpool_3)
    conv_20 = Batchnormalization(axis=1)(conv_20)
    conv_20 = Activation("relu")(conv_20)
    conv_21 = Conv2D(128,data_format='channels_first')(conv_20)
    conv_21 = Batchnormalization(axis=1)(conv_21)
    conv_21 = Activation("relu")(conv_21)
    conv_22 = Conv2D(64,data_format='channels_first')(conv_21)
    conv_22 = Batchnormalization(axis=1)(conv_22)
    conv_22 = Activation("relu")(conv_22)

    conv_22 = core.Permute((2,1))(conv_22)
    unpool_4 = Lambda(unpool2D,output_shape=(32,32,64),arguments={'ind':mask_2})(conv_22)
    unpool_4 = core.Permute((3,2))(unpool_4)

    conv_23 = Conv2D(64,data_format='channels_first')(unpool_4)
    conv_23 = Batchnormalization(axis=1)(conv_23)
    conv_23 = Activation("relu")(conv_23)
    conv_24 = Conv2D(32,data_format='channels_first')(conv_23)
    conv_24 = Batchnormalization(axis=1)(conv_24)
    conv_24 = Activation("relu")(conv_24)

    conv_24 = core.Permute((2,1))(conv_24)
    # NOTE(review): `arguments{` is missing its `=` on this line.
    unpool_5 = Lambda(unpool2D,output_shape=(64,64,32),arguments{'ind':mask_1})(conv_24)
    unpool_5 = core.Permute((3,2))(unpool_5)

    conv_25 = Conv2D(32,data_format='channels_first')(unpool_5)
    conv_25 = Batchnormalization(axis=1)(conv_25)
    conv_25 = Activation("relu")(conv_25)
    # 1x1 convolution down to n_labels output channels.
    conv_26 = Convolution2D(n_labels,(1,1),padding="valid",data_format="channels_first")(conv_25)
    conv_26 = Batchnormalization(axis=1)(conv_26)


    outputs = Activation(output_mode)(conv_26)
    print("Build decoder done..")

    model = Model(inputs=inputs,outputs=outputs,name="SegNet")

    return model

我要使用的反池化函数：

def unpool2D(pool,ind,ksize=(2,2)):
    """Unpool ``pool`` by scattering each value back to its argmax position.

    Args:
        pool: pooled tensor, indexed below as 4-D (N, H, W, C).
        ind: argmax indices returned by ``tf.nn.max_pool_with_argmax`` for
            the pooling step being reversed; assumed to have the same shape
            as ``pool``.
        ksize: (height, width) upsampling factor. It should equal the stride
            of that pooling step.

    Returns:
        Tensor of shape (N, H * ksize[0], W * ksize[1], C). It holds the
        values of ``pool`` at their argmax positions and zeros elsewhere.
    """
    with tf.compat.v1.variable_scope("unpool"):
        # Do all index arithmetic in int64 so dtypes never mix; the argmax
        # output may be int32 or int64.
        ind = tf.cast(ind, tf.int64)
        input_shape = tf.shape(pool)
        output_shape = [input_shape[0],
                        input_shape[1] * ksize[0],
                        input_shape[2] * ksize[1],
                        input_shape[3]]

        flat_input_size = tf.math.cumprod(input_shape)[-1]
        flat_output_shape = tf.cast(
            [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]],
            tf.int64)

        pool_ = tf.reshape(pool, [flat_input_size])
        # Batch index of every element. The range must be (N, 1, 1, 1) to
        # broadcast against the 4-D `ind`; an (N, 1) shape would broadcast
        # against the (W, C) axes instead.
        batch_range = tf.reshape(
            tf.range(tf.cast(output_shape[0], tf.int64), dtype=tf.int64),
            shape=[input_shape[0], 1, 1, 1])
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, [flat_input_size, 1])

        # Within-sample flat index. The modulo also strips a batch offset if
        # the indices were built with include_batch_in_index=True.
        # (This line had been truncated to invalid syntax.)
        ind_ = tf.reshape(ind, [flat_input_size, 1]) % flat_output_shape[1]
        ind_ = tf.concat([b, ind_], 1)
        ret = tf.scatter_nd(ind_, pool_, shape=flat_output_shape)
        ret = tf.reshape(ret, output_shape)

        # The dynamic reshape leaves every static dimension unknown. A
        # following Conv2D then fails with "The channel dimension of the
        # inputs should be defined", so restore the dimensions that are known.
        if pool.shape.ndims == 4:
            ret.set_shape([
                pool.shape[0],
                None if pool.shape[1] is None else pool.shape[1] * ksize[0],
                None if pool.shape[2] is None else pool.shape[2] * ksize[1],
                pool.shape[3],
            ])
        return ret

这是我得到的错误：

~/bones-adamo/models.py in getSegNet3(n_ch,pool_size,output_mode)
   1013     unpool_1 = core.Permute((3,2))(unpool_1)
   1014 
-> 1015     conv_14 = Conv2D(256,data_format='channels_first')(unpool_1)
   1016     conv_14 = Batchnormalization(axis=1)(conv_14)
   1017     conv_14 = Activation("relu")(conv_14)

~/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self,*args,**kwargs)
    923     # >> model = tf.keras.Model(inputs,outputs)
    924     if _in_functional_construction_mode(self,inputs,args,kwargs,input_list):
--> 925       return self._functional_construction_call(inputs,926                                                 input_list)
    927 

~/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _functional_construction_call(self,input_list)
   1096         # Build layer if applicable (if the `build` method has been
   1097         # overridden).
-> 1098         self._maybe_build(inputs)
   1099         cast_inputs = self._maybe_cast_inputs(inputs,input_list)
   1100 

~/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _maybe_build(self,inputs)
   2641         # operations.
   2642         with tf_utils.maybe_init_scope(self):
-> 2643           self.build(input_shapes)  # pylint:disable=not-callable
   2644       # We must set also ensure that the layer is marked as built,and the build
   2645       # shape is stored since user defined build functions may not be calling

~/venv/lib/python3.8/site-packages/tensorflow/python/keras/layers/convolutional.py in build(self,input_shape)
    185   def build(self,input_shape):
    186     input_shape = tensor_shape.TensorShape(input_shape)
--> 187     input_channel = self._get_input_channel(input_shape)
    188     if input_channel % self.groups != 0:
    189       raise ValueError(

~/venv/lib/python3.8/site-packages/tensorflow/python/keras/layers/convolutional.py in _get_input_channel(self,input_shape)
    357     channel_axis = self._get_channel_axis()
    358     if input_shape.dims[channel_axis].value is None:
--> 359       raise ValueError('The channel dimension of the inputs '
    360                        'should be defined. Found `None`.')
    361     return int(input_shape[channel_axis])

ValueError: The channel dimension of the inputs should be defined. Found `None`.

解决方法

好的，我解决了这个问题：一开始我没有发现模型架构中的问题。如果要使用池化索引进行上采样，建议使用这些自定义层（here）：

class MaxUnpooling2D(Layer):
    """Keras layer that reverses a max-pool using its saved argmax indices.

    Call it on ``[updates, mask]``:
    - ``updates`` is the (N, H, W, C) tensor to unpool.
    - ``mask`` holds the flattened argmax indices from
      ``tf.nn.max_pool_with_argmax`` for the matching pooling step
      (same shape as ``updates``).

    Each value of ``updates`` is written to the position its index names
    inside an (N, H * size[0], W * size[1], C) tensor of zeros.

    Args:
        size: (height, width) upsampling factor; it should match the
            stride of the pooling step being reversed.
    """

    def __init__(self,size=(2,2),**kwargs):
        super(MaxUnpooling2D,self).__init__(**kwargs)
        self.size = size

    def call(self,inputs,output_shape=None):
        updates,mask = inputs[0],inputs[1]
        with tf.compat.v1.variable_scope(self.name):
            mask = K.cast(mask,'int32')
            input_shape = tf.shape(updates,out_type='int32')
            use_default_shape = output_shape is None
            if use_default_shape:
                output_shape = (
                    input_shape[0],
                    input_shape[1] * self.size[0],
                    input_shape[2] * self.size[1],
                    input_shape[3])

            # Number of elements in one unpooled sample.
            per_sample = output_shape[1] * output_shape[2] * output_shape[3]
            # The argmax only carries the batch offset when it was built with
            # include_batch_in_index=True. Reduce every index to its offset
            # within the sample, then add the batch offset explicitly, so that
            # samples never scatter onto one another (scatter_nd sums
            # duplicate indices).
            batch_offset = K.reshape(
                tf.range(output_shape[0], dtype='int32') * per_sample,
                [-1, 1, 1, 1])
            flat_index = mask % per_sample + batch_offset

            ret = tf.scatter_nd(
                K.expand_dims(K.flatten(flat_index)),
                K.flatten(updates),
                [K.prod(output_shape)])

            # Bug fix: the old target `[-1, input_shape[3]]` produced a 2-D
            # tensor, contradicting compute_output_shape.
            ret = K.reshape(
                ret, [-1, output_shape[1], output_shape[2], output_shape[3]])
            if use_default_shape:
                # The reshape above is dynamic, so re-attach the static shape.
                # A following Conv2D needs the channel dimension to be defined.
                ret.set_shape(
                    self.compute_output_shape([updates.shape, mask.shape]))
        return ret

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'size': self.size
        })
        return config

    @staticmethod
    def _scale_dim(dim, factor):
        """Multiply a possibly-unknown (None) static dimension by ``factor``."""
        return None if dim is None else dim * factor

    def compute_output_shape(self,input_shape):
        mask_shape = input_shape[1]
        return (
                mask_shape[0],
                self._scale_dim(mask_shape[1], self.size[0]),
                self._scale_dim(mask_shape[2], self.size[1]),
                mask_shape[3],
                )

用法示例:

# Unpool conv_19 to twice its height and width (default size=(2,2)), using
# the argmax indices saved by the matching pooling step (mask_3).
unpool_3 = MaxUnpooling2D()([conv_19,mask_3])

我添加了 get_config，以避免如下错误（在 `__init__` 中带有参数的自定义层必须重写 `get_config`）：

NotImplementedError: Layer MaxPoolingWithArgmax2D has arguments in `__init__` and therefore must override `get_config`.

希望这个答案可以帮助其他用户

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?