Building a GAN with tensorflow.keras (directly runnable)

Preface

Keras is one of TensorFlow's high-level APIs; its code is concise and highly readable. This post uses tensorflow.keras to implement a GAN. The underlying theory is not covered in depth here; the post is meant only as a worked example for discussion.

See also: the Keras Chinese reference documentation.


Main content

I. General steps for building a GAN with tf.keras

1. First convert all image data into TensorFlow's tfrecords format, generated with creat_tfrecords.py (this is a script I originally wrote to produce labels for image classification; for a GAN the labels do not actually need to be stored).
2. Build a dataset from the generated tfrecords file with tf.data.TFRecordDataset; this post also shows a second way of reading tfrecords data, but the two approaches are essentially equivalent.
3. Build the generator network.
4. Build the discriminator network and combine the two into the gan network (the discriminator must be set to non-trainable before the gan network is compiled).
5. Build a loop that alternately trains the generator and the discriminator.
6. Save the gan model.

II. Usage

1. Creating the tfrecords dataset

creat_tfrecords.py
The tfrecords file is written to filename_train="./data/train.tfrecords".
Run in a terminal: python creat_tfrecords.py --data [path to the dataset]
This produces train.tfrecords; validation and test sets can be added in the same way. Note that the script infers the label from the filename, which is expected to contain 'cat' or 'dog' (e.g. cat.0.jpg).

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt 
import os
from PIL import Image
import random

objects = ['cat','dog']  # label 0 for 'cat', 1 for 'dog'

filename_train="./data/train.tfrecords"
writer_train= tf.python_io.TFRecordWriter(filename_train)

tf.app.flags.DEFINE_string(
    'data', None, 'path to the image dataset.')
FLAGS = tf.app.flags.FLAGS

# exit if no dataset path was passed via --data
if FLAGS.data is None:
    os._exit(0)

dim = (224,224)
object_path = FLAGS.data
total = os.listdir(object_path)
for index in total:
    img_path=os.path.join(object_path,index)
    img=Image.open(img_path)
    img=img.resize(dim)
    img_raw=img.tobytes()
    # infer the label from the filename: 0 for 'cat', 1 for 'dog'
    for i in range(len(objects)):
        if objects[i] in index:
            value = i
    example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[value])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                }))
    print([index,value])
    writer_train.write(example.SerializeToString())  # serialize the example to a string
writer_train.close()
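
Before training, it can be worth sanity-checking the generated file. The sketch below is my own addition (it assumes the TF 1.x API, matching the script above, and the default output path ./data/train.tfrecords): it counts the written records and decodes the first one.

import numpy as np
import tensorflow as tf

count = 0
for record in tf.python_io.tf_record_iterator("./data/train.tfrecords"):
    if count == 0:
        example = tf.train.Example()
        example.ParseFromString(record)
        label = example.features.feature['label'].int64_list.value[0]
        raw = example.features.feature['img_raw'].bytes_list.value[0]
        img = np.frombuffer(raw, dtype=np.uint8).reshape(224, 224, 3)
        print('first record: label =', label, 'image shape =', img.shape)
    count += 1
print('total records written:', count)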

2. Reading the data

Build the dataset with tf.data.TFRecordDataset.
The code is shown below (load_image is passed to map to decode the records); in main it is called as
train_datas,iter = dataset_tfrecords(tfrecords_path,use_keras_fit=False)

def load_image(serialized_example):   
    features={
        'label': tf.io.FixedLenFeature([], tf.int64),
        'img_raw' : tf.io.FixedLenFeature([], tf.string)}
    parsed_example = tf.io.parse_example(serialized_example,features)
    image = tf.decode_raw(parsed_example['img_raw'],tf.uint8)
    image = tf.reshape(image,[-1,224,224,3])
    image = tf.cast(image,tf.float32)*(1./255)
    label = tf.cast(parsed_example['label'], tf.int32)
    label = tf.reshape(label,[-1,1])
    return image,label
 
def dataset_tfrecords(tfrecords_path,use_keras_fit=True): 
    # whether the dataset will be consumed by tf.keras fit() (one pass) or by the manual loop (epochs passes)
    if use_keras_fit:
        epochs_data = 1
    else:
        epochs_data = epochs
    dataset = tf.data.TFRecordDataset([tfrecords_path])  # the list can hold several files, e.g. [tfrecords_name1, tfrecords_name2, ...], collected with os.listdir()
    dataset = dataset\
                .repeat(epochs_data)\
                .batch(batch_size)\
                .map(load_image,num_parallel_calls = 2)\
                .shuffle(1000)

    iter = dataset.make_one_shot_iterator()  # alternatively: dataset.make_initializable_iterator()
    train_datas = iter.get_next()  # access the tensors as train_datas[0] / train_datas[1]
    return train_datas,iter
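
As a quick check that the pipeline delivers what the GAN expects, one batch can be pulled through a session and its shapes printed. This is a minimal sketch assuming TF 1.x graph mode and that batch_size and tfrecords_path are defined as in the full script below:

train_datas, iter = dataset_tfrecords(tfrecords_path, use_keras_fit=False)
with tf.Session() as sess:
    images, labels = sess.run(train_datas)
    # expected: images (batch_size, 224, 224, 3) float32 in [0,1], labels (batch_size, 1) int32
    print(images.shape, images.dtype, labels.shape)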

3. Building the gan network

a. Building the generator network

    generator = keras.models.Sequential([
            #fullyconnected nets
            keras.layers.Dense(256,activation='selu',input_shape=[coding_size]),
            keras.layers.Dense(64,activation='selu'),
            keras.layers.Dense(256,activation='selu'),
            keras.layers.Dense(1024,activation='selu'),
            keras.layers.Dense(7*7*64,activation='selu'),
            keras.layers.Reshape([7,7,64]),
            #7*7*64
            # transposed convolutions
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #14*14*64
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #28*28*64
            keras.layers.Conv2DTranspose(32,kernel_size=3,strides=2,padding='same',activation='selu'),
            #56*56*32
            keras.layers.Conv2DTranspose(16,kernel_size=3,strides=2,padding='same',activation='selu'),
            #112*112*16
            keras.layers.Conv2DTranspose(3,kernel_size=3,strides=2,padding='same',activation='tanh'),  # use tanh instead of sigmoid
            #224*224*3
            keras.layers.Reshape([224,224,3])
            ])
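
With padding='same', a transposed convolution simply multiplies the spatial size by its stride (see the docstring in the full script), so the five stride-2 layers upsample 7 → 14 → 28 → 56 → 112 → 224 as the comments indicate. A quick check of the output shape, assuming coding_size = 30 as set in the full script:

import numpy as np
noise = np.random.normal(size=(1, coding_size)).astype(np.float32)
print(generator.predict(noise).shape)  # expected (1, 224, 224, 3)
generator.summary()                    # prints every layer's output shape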

b. Building the discriminator network

    discriminator = keras.models.Sequential([
            keras.layers.Conv2D(128,kernel_size=3,padding='same',strides=2,activation='selu',input_shape=[224,224,3]),
            keras.layers.MaxPool2D(pool_size=2),
            #56*56*128
            keras.layers.Conv2D(64,kernel_size=3,padding='same',strides=2,activation='selu'),
            keras.layers.MaxPool2D(pool_size=2),
            #14*14*64
            keras.layers.Conv2D(32,kernel_size=3,padding='same',strides=2,activation='selu'),
            #7*7*32
            keras.layers.Flatten(),
            #dropout 0.4
            keras.layers.Dropout(0.4),
            keras.layers.Dense(512,activation='selu'),
            keras.layers.Dropout(0.4),
            keras.layers.Dense(64,activation='selu'),
            keras.layers.Dropout(0.4),
            #the last net
            keras.layers.Dense(1,activation='sigmoid')
            ])
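
The discriminator maps a 224x224x3 image to a single probability from the final sigmoid. A minimal check (my own addition):

import numpy as np
x = np.random.rand(1, 224, 224, 3).astype(np.float32)
print(discriminator.predict(x).shape)  # expected (1, 1), a value in (0, 1)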

c. Combining the generator and discriminator into the gan network

gan = keras.models.Sequential([generator,discriminator])

4. Compiling (setting up the loss and optimizer)

The order matters here: the discriminator is compiled first while still trainable, then frozen with trainable=False before the combined gan model is compiled, so discriminator.train_on_batch still updates its weights while gan.train_on_batch only updates the generator.

    #compile the net
    discriminator.compile(loss="binary_crossentropy",optimizer='rmsprop')# metrics=['accuracy'])
    discriminator.trainable=False
    gan.compile(loss="binary_crossentropy",optimizer='rmsprop')# metrics=['accuracy'])

5. Training the network (the training loop)

Get the dataset:

train_datas,iter = dataset_tfrecords(tfrecords_path,use_keras_fit=False)

The loop body (cv2 is used inside it to inspect the generator's output):

    generator,discriminator = gan.layers
    sess = tf.Session()
    for step in range(num_steps):
        #get the time
        start_time = time.time()
        #phase 1 - training the discriminator
        noise = np.random.normal(size=batch_size*coding_size).reshape([batch_size,coding_size])
        noise = np.cast[np.float32](noise)
        generated_images = generator.predict(noise)
        train_datas_ = sess.run(train_datas)
        x_fake_and_real = np.concatenate([generated_images,train_datas_[0]],axis = 0)#np.concatenate
        # never call tf.concat (or define any other tf op) inside the loop body:
        # the graph keeps growing, memory gets exhausted and training slows down more and more
        y1 = np.array([[0.]]*batch_size+[[1.]]*batch_size)
        discriminator.trainable = True
        dis_loss = discriminator.train_on_batch(x_fake_and_real,y1)
        # keras's train_on_batch is a good fit for this alternating GAN training
        #phase 2 - training the generator
        noise = np.random.normal(size=batch_size*coding_size).reshape([batch_size,coding_size])
        noise = np.cast[np.float32](noise)
        y2 = np.array([[1.]]*batch_size)
        discriminator.trainable = False
        ad_loss = gan.train_on_batch(noise,y2)
        duration = time.time()-start_time
        if step % 5 == 0:
            #gan.save_weights('gan.h5')
            print("The step is %d,discriminator loss:%.3f,adversarial loss:%.3f"%(step,dis_loss,ad_loss),end=' ')
            print('%.2f s/step'%(duration))
        if step % 30 == 0 and step != 0:
            noise = np.random.normal(size=[1,coding_size])
            noise = np.cast[np.float32](noise)
            fake_image = generator.predict(noise,steps=1)
            # recover the image:
            # 1. after multiplying by 255, cast to uint8
            # 2. alternatively keep float32 values in [0,1], which can also be displayed directly
            arr_img = np.array([fake_image],np.float32).reshape([224,224,3])*255
            arr_img = np.cast[np.uint8](arr_img)
            # the tfrecords were written with PIL.Image (RGB), so convert to BGR before displaying with cv2
            arr_img = cv2.cvtColor(arr_img,cv2.COLOR_RGB2BGR)
            cv2.imshow('fake image',arr_img)
            cv2.waitKey(1500)#show the fake image 1.5s
            cv2.destroyAllWindows()

6.保存网络

    #save the models
    model_version = '0001'
    model_name = 'gans'
    model_path = os.path.join(model_name,model_version)
    tf.saved_model.save(gan,model_path)
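
To reuse the trained model later, the exported SavedModel directory can be loaded back. This is only a sketch under the assumption of TF 2.x behaviour, where keras.models.load_model accepts a SavedModel directory (on TF 1.x the loading APIs differ):

restored_gan = keras.models.load_model(model_path)  # e.g. 'gans/0001'
restored_generator = restored_gan.layers[0]         # the generator sub-model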

7. The complete gans.py (runnable)

# -*- coding: utf-8 -*-
'''
    @author:zyl
    author is zouyuelin
    a Master of Tianjin University(TJU)
'''

import tensorflow as tf
from tensorflow import keras
#tf.enable_eager_execution()
import numpy as np
from PIL import Image
import os
import cv2
import time

batch_size = 32
epochs = 120
num_steps = 2000
coding_size = 30
tfrecords_path = 'data/train.tfrecords'

#--------------------------------------datasetTfrecord----------------   
def load_image(serialized_example):   
    features={
        'label': tf.io.FixedLenFeature([], tf.int64),
        'img_raw' : tf.io.FixedLenFeature([], tf.string)}
    parsed_example = tf.io.parse_example(serialized_example,features)
    image = tf.decode_raw(parsed_example['img_raw'],tf.uint8)
    image = tf.reshape(image,[-1,224,224,3])
    image = tf.cast(image,tf.float32)*(1./255)
    label = tf.cast(parsed_example['label'], tf.int32)
    label = tf.reshape(label,[-1,1])
    return image,label
 
def dataset_tfrecords(tfrecords_path,use_keras_fit=True): 
    # whether the dataset will be consumed by tf.keras fit() (one pass) or by the manual loop (epochs passes)
    if use_keras_fit:
        epochs_data = 1
    else:
        epochs_data = epochs
    dataset = tf.data.TFRecordDataset([tfrecords_path])  # the list can hold several files, e.g. [tfrecords_name1, tfrecords_name2, ...], collected with os.listdir()
    dataset = dataset\
                .repeat(epochs_data)\
                .batch(batch_size)\
                .map(load_image,num_parallel_calls = 2)\
                .shuffle(1000)

    iter = dataset.make_one_shot_iterator()  # alternatively: dataset.make_initializable_iterator()
    train_datas = iter.get_next()  # access the tensors as train_datas[0] / train_datas[1]
    return train_datas,iter

#------------------------------------tf.TFRecordReader-----------------
def read_and_decode(tfrecords_path):
    # build a filename queue from the tfrecords file
    filename_queue = tf.train.string_input_producer([tfrecords_path],shuffle=True) 
    reader = tf.TFRecordReader()
    _,  serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw' : tf.FixedLenFeature([], tf.string)})

    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image,[224,224,3])  # reshape to 224*224*3
    image = tf.cast(image,tf.float32)*(1./255)  # scale the image tensor to [0,1]
    label = tf.cast(features['label'], tf.int32)
    img_batch, label_batch = tf.train.shuffle_batch([image,label],
                    batch_size=batch_size,
                    num_threads=4,
                    capacity= 640,
                    min_after_dequeue=5)
    return [img_batch,label_batch]

# Autoencoder (defined for reference; not called during GAN training)
def autoencode():
        encoder = keras.models.Sequential([
            keras.layers.Conv2D(32,kernel_size=3,padding='same',strides=2,activation='selu',input_shape=[224,224,3]),
            #112*112*32
            keras.layers.MaxPool2D(pool_size=2),
            #56*56*32
            keras.layers.Conv2D(64,kernel_size=3,padding='same',strides=2,activation='selu'),
            #28*28*64
            keras.layers.MaxPool2D(pool_size=2),
            #14*14*64
            keras.layers.Conv2D(128,kernel_size=3,padding='same',strides=2,activation='selu'),
            #7*7*128
            # transposed convolutions
            keras.layers.Conv2DTranspose(128,kernel_size=3,strides=2,padding='same',activation='selu'),
            #14*14*128
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #28*28*64
            keras.layers.Conv2DTranspose(32,kernel_size=3,strides=2,padding='same',activation='selu'),
            #56*56*32
            keras.layers.Conv2DTranspose(16,kernel_size=3,strides=2,padding='same',activation='selu'),
            #112*112*16
            keras.layers.Conv2DTranspose(3,kernel_size=3,strides=2,padding='same',activation='tanh'),  # use tanh instead of sigmoid
            #224*224*3
            keras.layers.Reshape([224,224,3])
            ])
        return encoder

def training_keras():
    '''
        Output size of convolution and pooling:
            output_size = (input_size - kernel_size + 2*padding)/strides + 1

        Output size of a Keras transposed convolution (out_padding is usually not needed):
        1. if padding = 'valid':
            output_size = (input_size - 1)*strides + kernel_size
        2. if padding = 'same':
            output_size = input_size * strides
    '''
    generator = keras.models.Sequential([
            #fullyconnected nets
            keras.layers.Dense(256,activation='selu',input_shape=[coding_size]),
            keras.layers.Dense(64,activation='selu'),
            keras.layers.Dense(256,activation='selu'),
            keras.layers.Dense(1024,activation='selu'),
            keras.layers.Dense(7*7*64,activation='selu'),
            keras.layers.Reshape([7,7,64]),
            #7*7*64
            # transposed convolutions
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #14*14*64
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #28*28*64
            keras.layers.Conv2DTranspose(32,kernel_size=3,strides=2,padding='same',activation='selu'),
            #56*56*32
            keras.layers.Conv2DTranspose(16,kernel_size=3,strides=2,padding='same',activation='selu'),
            #112*112*16
            keras.layers.Conv2DTranspose(3,kernel_size=3,strides=2,padding='same',activation='tanh'),  # use tanh instead of sigmoid
            #224*224*3
            keras.layers.Reshape([224,224,3])
            ])
            
    discriminator = keras.models.Sequential([
            keras.layers.Conv2D(128,kernel_size=3,padding='same',strides=2,activation='selu',input_shape=[224,224,3]),
            keras.layers.MaxPool2D(pool_size=2),
            #56*56*128
            keras.layers.Conv2D(64,kernel_size=3,padding='same',strides=2,activation='selu'),
            keras.layers.MaxPool2D(pool_size=2),
            #14*14*64
            keras.layers.Conv2D(32,kernel_size=3,padding='same',strides=2,activation='selu'),
            #7*7*32
            keras.layers.Flatten(),
            #dropout 0.4
            keras.layers.Dropout(0.4),
            keras.layers.Dense(512,activation='selu'),
            keras.layers.Dropout(0.4),
            keras.layers.Dense(64,activation='selu'),
            keras.layers.Dropout(0.4),
            #the last net
            keras.layers.Dense(1,activation='sigmoid')
            ])
    #gans network        
    gan = keras.models.Sequential([generator,discriminator])
    
    #compile the net
    discriminator.compile(loss="binary_crossentropy",optimizer='rmsprop')# metrics=['accuracy'])
    discriminator.trainable=False
    gan.compile(loss="binary_crossentropy",optimizer='rmsprop')# metrics=['accuracy'])
    
    #dataset
    #train_datas = read_and_decode(tfrecords_path)
    train_datas,iter = dataset_tfrecords(tfrecords_path,use_keras_fit=False)
    
    #sess = tf.Session()
    #sess.run(iter.initializer)
    
    generator,discriminator = gan.layers
    print("-----------------start---------------")
    sess = tf.Session()
    for step in range(num_steps):
        #get the time
        start_time = time.time()
        #phase 1 - training the discriminator
        noise = np.random.normal(size=batch_size*coding_size).reshape([batch_size,coding_size])
        noise = np.cast[np.float32](noise)
        generated_images = generator.predict(noise)
        train_datas_ = sess.run(train_datas)
        x_fake_and_real = np.concatenate([generated_images,train_datas_[0]],axis = 0)#np.concatenate
        # never call tf.concat (or define any other tf op) inside the loop body:
        # the graph keeps growing, memory gets exhausted and training slows down more and more
        y1 = np.array([[0.]]*batch_size+[[1.]]*batch_size)
        discriminator.trainable = True
        dis_loss = discriminator.train_on_batch(x_fake_and_real,y1)
        # keras's train_on_batch is a good fit for this alternating GAN training
        #phase 2 - training the generator
        noise = np.random.normal(size=batch_size*coding_size).reshape([batch_size,coding_size])
        noise = np.cast[np.float32](noise)
        y2 = np.array([[1.]]*batch_size)
        discriminator.trainable = False
        ad_loss = gan.train_on_batch(noise,y2)
        duration = time.time()-start_time
        if step % 5 == 0:
            #gan.save_weights('gan.h5')
            print("The step is %d,discriminator loss:%.3f,adversarial loss:%.3f"%(step,dis_loss,ad_loss),end=' ')
            print('%.2f s/step'%(duration))
        if step % 30 == 0 and step != 0:
            noise = np.random.normal(size=[1,coding_size])
            noise = np.cast[np.float32](noise)
            fake_image = generator.predict(noise,steps=1)
            # recover the image:
            # 1. after multiplying by 255, cast to uint8
            # 2. alternatively keep float32 values in [0,1], which can also be displayed directly
            arr_img = np.array([fake_image],np.float32).reshape([224,224,3])*255
            arr_img = np.cast[np.uint8](arr_img)
            # the tfrecords were written with PIL.Image (RGB), so convert to BGR before displaying with cv2
            arr_img = cv2.cvtColor(arr_img,cv2.COLOR_RGB2BGR)
            cv2.imshow('fake image',arr_img)
            cv2.waitKey(1500)#show the fake image 1.5s
            cv2.destroyAllWindows()
            
    #save the models 
    model_version = '0001'
    model_name = 'gans'
    model_path = os.path.join(model_name,model_version)
    tf.saved_model.save(gan,model_path)
    
def main():
    training_keras()
if __name__ == '__main__':
    main()

This completes a simple GAN training run. Once data/train.tfrecords has been generated, the script can be run directly with python gans.py.


References

Paper: "Generative Adversarial Networks"
Reference code:
https://github.com/eriklindernoren/Keras-GAN/blob/master/gan/gan.py
Reference blog:
https://blog.csdn.net/u010138055/article/details/94441812

Closing words

I'm a master's student just getting started with deep learning and machine learning; corrections and suggestions for the shortcomings here are very welcome.
