
Jointly optimizing an autoencoder and a fully connected network for classification

How can an autoencoder and a fully connected network be optimized jointly for classification?

I have a large unlabeled dataset and a much smaller labeled one. I therefore want to train a variational autoencoder on the unlabeled data first, and then use the encoder (with fully connected layers attached) to classify the labeled data into three classes. For hyperparameter optimization I want to use Optuna.

One possibility is to optimize the autoencoder first and then the fully connected network (the classifier), but the autoencoder may then learn an encoding that is meaningless for classification.

Is it possible to optimize the autoencoder and the fully connected network jointly?
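In Keras this is possible with the functional API: the shared encoder feeds two heads (a decoder and a classifier), Keras sums the two losses with the given weights, and one optimizer updates everything jointly. A minimal sketch with illustrative shapes (none of the names or dimensions below come from the question's actual code, and plain dense layers stand in for the VAE encoder/decoder):

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

latent_dim, num_classes = 8, 3

# shared encoder (stand-in for the VAE encoder)
inp = layers.Input(shape=(64,), name="x")
h = layers.Dense(32, activation="relu")(inp)
z = layers.Dense(latent_dim, activation="relu", name="z")(h)

# head 1: reconstruction (stand-in for the VAE decoder)
recon = layers.Dense(64, activation="sigmoid", name="recon")(
    layers.Dense(32, activation="relu")(z))

# head 2: classification on the same latent code
clf = layers.Dense(num_classes, activation="softmax", name="clf")(z)

model = Model(inp, [recon, clf])
model.compile(
    optimizer="adam",
    loss={"recon": "mse", "clf": "categorical_crossentropy"},
    # the weight trades reconstruction quality against class separation
    loss_weights={"recon": 1.0, "clf": 0.5},
)

x = np.random.rand(16, 64).astype("float32")
y = tf.keras.utils.to_categorical(np.random.randint(0, num_classes, 16),
                                  num_classes)
model.fit(x, {"recon": x, "clf": y}, epochs=1, verbose=0)
recon_pred, clf_pred = model.predict(x, verbose=0)
```

With only a few labels, a common semi-supervised variant feeds unlabeled batches through the same model with the classification loss masked out (e.g. a per-sample weight of 0 on the `clf` head).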

My autoencoder looks like this (params is simply a dictionary of hyperparameters):

inputs = Input(shape=image_size,name='encoder_input')
x = inputs

for i in range(len(params["conv_filter_encoder"])):
    x,_ = convolutional_unit(x,params["conv_filter_encoder"][i],params["conv_kernel_size_encoder"][i],params["strides_encoder"][i],batchnorm=params["batchnorm"][i],dropout=params["dropout"][i],maxpool=params["maxpool"][i],deconv=False)

shape = K.int_shape(x)

x = Flatten()(x)
x = Dense(params["inner_dim"],activation='relu')(x)
z_mean = Dense(params["latent_dim"],name='z_mean')(x)
z_log_var = Dense(params["latent_dim"],name='z_log_var')(x)

# use reparameterization trick to push the sampling out as input
# note that "output_shape" isn't necessary with the TensorFlow backend
z = Lambda(sampling,output_shape=(params["latent_dim"],),name='z')([z_mean,z_log_var])

# instantiate encoder model
encoder = Model(inputs,[z_mean,z_log_var,z],name='encoder')

# build decoder model
latent_inputs = Input(shape=(params["latent_dim"],),name='z_sampling')
x = Dense(params["inner_dim"],activation='relu')(latent_inputs)
x = Dense(shape[1] * shape[2] * shape[3],activation='relu')(x)
x = Reshape((shape[1],shape[2],shape[3]))(x)

len_batchnorm = len(params["batchnorm"])
len_dropout = len(params["dropout"])
for i in range(len(params["conv_filter_decoder"])):
    x,_ = convolutional_unit(x,params["conv_filter_decoder"][i],params["conv_kernel_size_decoder"][i],params["strides_decoder"][i],batchnorm=params["batchnorm"][len_batchnorm-i-1],dropout=params["dropout"][len_dropout-i-1],maxpool=None,deconv=True,activity_regularizer=params["activity_regularizer"])

outputs = Conv2DTranspose(filters=1,kernel_size=params["conv_kernel_size_decoder"][len(params["conv_kernel_size_decoder"])-1],activation='sigmoid',padding='same')(x)

# instantiate decoder model
decoder = Model(latent_inputs,outputs,name='decoder')

# instantiate VAE model
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs,outputs,name='vae')
vae.higgins_beta = K.variable(value=params["beta"])
loss = config["loss"].value

def vae_loss(x,x_decoded_mean):
    """VAE loss function"""
    # VAE loss = mse_loss or xent_loss + kl_loss
    if loss == Loss.mse.value:
        reconstruction_loss = mse(K.flatten(x),K.flatten(x_decoded_mean))
    elif loss == Loss.bce.value:
        reconstruction_loss = binary_crossentropy(K.flatten(x),K.flatten(x_decoded_mean))
    else:
        raise ValueError("Loss unknown")

    reconstruction_loss *= image_size[0] * image_size[1]
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss,axis=-1)
    # kl_loss *= -0.5
    kl_loss *= -vae.higgins_beta
    vae_loss = K.mean(reconstruction_loss + kl_loss)

    return vae_loss

batch_size = params["batch_size"]
optimizer = keras.optimizers.Adam(lr=params["learning_rate"],beta_1=0.9,beta_2=0.999,epsilon=1e-08,decay=params["learning_rate_decay"])
vae.compile(loss=vae_loss,optimizer=optimizer)

vae.fit(train_X,train_X,epochs=config.config["n_epochs"],batch_size=batch_size,verbose=0,callbacks=get_callbacks(config.config,autoencoder_path,encoder,decoder,vae),shuffle=shuffle,validation_data=(valid_X,valid_X))
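As a sanity check of the closed-form KL term used in `vae_loss` above (the KL divergence between N(mu, sigma^2) and N(0, 1); note the posted code folds the usual factor 0.5 into `vae.higgins_beta`, in beta-VAE style), a plain-numpy Monte Carlo estimate should agree with it:

```python
import numpy as np

mu = np.array([0.5, -1.0])
log_var = np.array([0.2, -0.3])

# closed form, matching vae_loss: -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
kl_closed = -0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var))

# Monte Carlo estimate of E_q[log q(z) - log p(z)] for comparison
rng = np.random.default_rng(0)
z = mu + np.exp(0.5 * log_var) * rng.standard_normal((200_000, 2))
log_q = -0.5 * (np.log(2 * np.pi) + log_var + (z - mu) ** 2 / np.exp(log_var))
log_p = -0.5 * (np.log(2 * np.pi) + z**2)
kl_mc = np.sum(np.mean(log_q - log_p, axis=0))
```

The two values match to within sampling noise, confirming the sign and the per-dimension sum in the loss.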

The fully connected network that I attach to the encoder looks like this:

latent = vae.predict(images)[0]
inputs = Input(shape=(input_shape,),name='fc_input')
den = inputs
for i in range(len(self.params["units"])):
    den = Dense(self.params["units"][i])(den)
    den = Activation('relu')(den)

out = Dense(self.num_classes,activation='softmax')(den)

model = Model(inputs,out,name='fcnn')

optimizer = keras.optimizers.Adam(lr=self.mc.config["fcnn"]["learning_rate"],decay=self.mc.config["fcnn"]["learning_rate_decay"])

model.compile(loss='categorical_crossentropy',optimizer=optimizer,metrics=['accuracy'])

model.fit(latent,y,epochs=self.params["n_epochs"],batch_size=self.params["batch_size"],shuffle=True)

y_prob = model.predict(latent)
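The snippet above trains the classifier on latents frozen by `vae.predict`, so the classification loss never reaches the encoder. An alternative is to stack the classifier on the encoder model itself and fine-tune end to end. A minimal sketch with illustrative shapes (the tiny dense "encoder" below merely stands in for the trained VAE encoder's `z_mean` path):

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

# stand-in for the pre-trained VAE encoder (z_mean output)
enc_in = layers.Input(shape=(64,))
h = layers.Dense(32, activation="relu")(enc_in)
z_mean = layers.Dense(8, name="z_mean")(h)
encoder = Model(enc_in, z_mean, name="encoder")

# classifier stacked directly on the encoder: gradients flow back into it
clf_out = layers.Dense(3, activation="softmax", name="clf")(encoder(enc_in))
clf = Model(enc_in, clf_out, name="encoder_clf")

# a small learning rate so fine-tuning does not wipe out the pre-training
clf.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
            loss="categorical_crossentropy", metrics=["accuracy"])

x = np.random.rand(16, 64).astype("float32")
y = tf.keras.utils.to_categorical(np.random.randint(0, 3, 16), 3)
clf.fit(x, y, epochs=1, verbose=0)
probs = clf.predict(x, verbose=0)
```

Freezing the encoder for a few warm-up epochs (`encoder.trainable = False`) before unfreezing it is a common middle ground between the fully frozen and fully joint setups.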
