如何解决无法计算ConcatV2,因为输入#1从零开始应该是浮点张量,但是是双张量[Op:ConcatV2]名称:concat
导入库
%matplotlib inline
import tensorflow as tf
from tensorflow import keras
import numpy as np
import plot_utils
import matplotlib.pyplot as plt
from tqdm import tqdm
print('Tensorflow version:',tf.__version__)
任务3:创建一批培训数据
batch_size = 32
# This dataset fills a buffer with buffer_size elements,#then randomly samples elements from this buffer,replacing the selected elements with new elements.
dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(1000)
#Combines consecutive elements of this dataset into batches.
dataset = dataset.batch(batch_size,drop_remainder=True).prefetch(1)
#Creates a Dataset that prefetches elements from this dataset
print(dataset)
output:<PrefetchDataset shapes: (32,32,3),types: tf.float64>
任务4:为DCGAN建立发电机网络
num_features = 100
generator = keras.models.Sequential([
keras.layers.Dense(256*4*4,input_shape=[num_features]),keras.layers.Reshape([4,4,256]),keras.layers.Batchnormalization(),keras.layers.Conv2DTranspose(128,(4,4),(2,2),padding="same",activation="selu"),keras.layers.Conv2DTranspose(3,(3,activation="tanh"),])
import numpy as np
import matplotlib.pyplot as plt
def show(images,n_cols=None):
n_cols = n_cols or len(images)
n_rows = (len(images) - 1) // n_cols + 1
if images.shape[-1] == 1:
images = np.squeeze(images,axis=-1)
plt.figure(figsize=(n_cols,n_rows))
for index,image in enumerate(images):
plt.subplot(n_rows,n_cols,index + 1)
plt.imshow(image,cmap="binary")
plt.axis("off")
noise = tf.random.normal(shape=[1,num_features])
generated_images = generator(noise,training=False)
show(generated_images,1)
任务5:为DCGAN建立鉴别器网络
discriminator = keras.models.Sequential([
keras.layers.Conv2D(64,input_shape=[32,3]),keras.layers.LeakyReLU(0.2),keras.layers.Dropout(0.3),keras.layers.Conv2D(128,padding="same"),keras.layers.Conv2D(256,keras.layers.Flatten(),keras.layers.Dense(1,activation='sigmoid')
])
decision = discriminator(generated_images)
print(decision)
output:tf.Tensor([[0.5006197]],shape=(1,1),dtype=float32)
discriminator.compile(loss="binary_crossentropy",optimizer="rmsprop")
discriminator.trainable = False
gan = keras.models.Sequential([generator,discriminator])
gan.compile(loss="binary_crossentropy",optimizer="rmsprop")
from IPython import display
from tqdm import tqdm
seed = tf.random.normal(shape=[batch_size,100])
任务7:定义培训程序
from tqdm import tqdm
def train_DCGAN(gan,dataset,batch_size,num_features,epochs=5):
generator,discriminator = gan.layers
for epoch in tqdm(range(epochs)):
print("Epoch {}/{}".format(epoch + 1,epochs))
for X_batch in dataset:
noise = tf.random.normal(shape=[batch_size,num_features])
generated_images = generator(noise)
X_fake_and_real = tf.concat([generated_images,X_batch],axis=0)
y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
discriminator.trainable = True
discriminator.train_on_batch(X_fake_and_real,y1)
noise = tf.random.normal(shape=[batch_size,num_features])
y2 = tf.constant([[1.]] * batch_size)
discriminator.trainable = False
gan.train_on_batch(noise,y2)
# Produce images for the GIF as we go
display.clear_output(wait=True)
generate_and_save_images(generator,epoch + 1,seed)
display.clear_output(wait=True)
generate_and_save_images(generator,epochs,seed)
## Source https://www.tensorflow.org/tutorials/generative/DCGAN#create_a_gif
def generate_and_save_images(model,epoch,test_input):
# Notice `training` is set to False.
# This is so all layers run in inference mode (batchnorm).
predictions = model(test_input,training=False)
fig = plt.figure(figsize=(10,10))
for i in range(25):
plt.subplot(5,5,i+1)
plt.imshow(predictions[i,:,0] * 127.5 + 127.5,cmap='binary')
plt.axis('off')
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
plt.show()
任务8:训练DCGAN
x_train_DCGAN = x_train.reshape(-1,3) * 2. - 1.
batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices(x_train_DCGAN)
dataset = dataset.shuffle(1000)
dataset = dataset.batch(batch_size,drop_remainder=True).prefetch(1)
这是主要问题
%%time
train_DCGAN(gan,epochs=10)**
output:
7 noise = tf.random.normal(shape=[batch_size,num_features])
8 generated_images = generator(noise)
----> 9 X_fake_and_real = tf.concat([generated_images,axis=0)
10 y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
11 discriminator.trainable = True
cannot compute ConcatV2 as input #1(zero-based) was expected to be a float tensor but is a double tensor [Op:ConcatV2] name: concat
是Cifar10 DCGAN,我真的不了解此错误以及如何解决。
解决方法
默认情况下,Tensorflow使用float32。您必须将数据转换为tf.float32。
X = tf.cast(yourDATA,tf.float32)
,
在执行 tf.concat 操作之前,以下代码段在受相同 tensorflow 示例启发的代码中对我有用:
X_batch = tf.cast(X_batch,tf.float32)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。