微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

无法计算ConcatV2,因为输入#1从零开始应该是浮点张量,但是是双张量[Op:ConcatV2]名称:concat

如何解决无法计算ConcatV2,因为输入#1从零开始应该是浮点张量,但是是双张量[Op:ConcatV2]名称:concat

导入库

%matplotlib inline
import tensorflow as tf
from tensorflow import keras
import numpy as np
import plot_utils
import matplotlib.pyplot as plt
from tqdm import tqdm
print('Tensorflow version:',tf.__version__)

任务3:创建一批培训数据

batch_size = 32
# This dataset fills a buffer with buffer_size elements,#then randomly samples elements from this buffer,replacing the selected elements with new elements.
dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(1000)
#Combines consecutive elements of this dataset into batches.
dataset = dataset.batch(batch_size,drop_remainder=True).prefetch(1)
#Creates a Dataset that prefetches elements from this dataset


print(dataset)
output:<PrefetchDataset shapes: (32,32,3),types: tf.float64>

任务4:为DCGAN建立发电机网络

num_features = 100
generator = keras.models.Sequential([
    keras.layers.Dense(256*4*4,input_shape=[num_features]),keras.layers.Reshape([4,4,256]),keras.layers.Batchnormalization(),keras.layers.Conv2DTranspose(128,(4,4),(2,2),padding="same",activation="selu"),keras.layers.Conv2DTranspose(3,(3,activation="tanh"),])

import numpy as np
import matplotlib.pyplot as plt

def show(images,n_cols=None):
    n_cols = n_cols or len(images)
    n_rows = (len(images) - 1) // n_cols + 1
    if images.shape[-1] == 1:
        images = np.squeeze(images,axis=-1)
    plt.figure(figsize=(n_cols,n_rows))
    for index,image in enumerate(images):
        plt.subplot(n_rows,n_cols,index + 1)
        plt.imshow(image,cmap="binary")
        plt.axis("off")

noise = tf.random.normal(shape=[1,num_features])
generated_images = generator(noise,training=False)
show(generated_images,1)

任务5:为DCGAN建立鉴别器网络

discriminator = keras.models.Sequential([
    keras.layers.Conv2D(64,input_shape=[32,3]),keras.layers.LeakyReLU(0.2),keras.layers.Dropout(0.3),keras.layers.Conv2D(128,padding="same"),keras.layers.Conv2D(256,keras.layers.Flatten(),keras.layers.Dense(1,activation='sigmoid')
])

decision = discriminator(generated_images)
print(decision)
output:tf.Tensor([[0.5006197]],shape=(1,1),dtype=float32)

任务6:编译深度卷积生成对抗网络(DCGAN

discriminator.compile(loss="binary_crossentropy",optimizer="rmsprop")
discriminator.trainable = False
gan = keras.models.Sequential([generator,discriminator])
gan.compile(loss="binary_crossentropy",optimizer="rmsprop")


from IPython import display
from tqdm import tqdm
seed = tf.random.normal(shape=[batch_size,100])

任务7:定义培训程序

from tqdm import tqdm
def train_DCGAN(gan,dataset,batch_size,num_features,epochs=5):
    generator,discriminator = gan.layers
    for epoch in tqdm(range(epochs)):
        print("Epoch {}/{}".format(epoch + 1,epochs))
        for X_batch in dataset:
            noise = tf.random.normal(shape=[batch_size,num_features])
            generated_images = generator(noise)
            X_fake_and_real = tf.concat([generated_images,X_batch],axis=0)
            y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
            discriminator.trainable = True
            discriminator.train_on_batch(X_fake_and_real,y1)
            noise = tf.random.normal(shape=[batch_size,num_features])
            y2 = tf.constant([[1.]] * batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise,y2)
            # Produce images for the GIF as we go
        display.clear_output(wait=True)
        generate_and_save_images(generator,epoch + 1,seed)
        
    display.clear_output(wait=True)
    generate_and_save_images(generator,epochs,seed)


## Source https://www.tensorflow.org/tutorials/generative/DCGAN#create_a_gif
def generate_and_save_images(model,epoch,test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input,training=False)

  fig = plt.figure(figsize=(10,10))

  for i in range(25):
      plt.subplot(5,5,i+1)
      plt.imshow(predictions[i,:,0] * 127.5 + 127.5,cmap='binary')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

任务8:训练DCGAN

x_train_DCGAN = x_train.reshape(-1,3) * 2. - 1.

batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices(x_train_DCGAN)
dataset = dataset.shuffle(1000)
dataset = dataset.batch(batch_size,drop_remainder=True).prefetch(1)

这是主要问题

%%time
train_DCGAN(gan,epochs=10)**
output:
    7             noise = tf.random.normal(shape=[batch_size,num_features])
      8             generated_images = generator(noise)
----> 9             X_fake_and_real = tf.concat([generated_images,axis=0)
     10             y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
     11             discriminator.trainable = True
cannot compute ConcatV2 as input #1(zero-based) was expected to be a float tensor but is a double tensor [Op:ConcatV2] name: concat

是Cifar10 DCGAN,我真的不了解此错误以及如何解决

解决方法

默认情况下,Tensorflow使用float32。您必须将数据转换为tf.float32。

X = tf.cast(yourDATA,tf.float32) 
,

在执行 tf.concat 操作之前,以下代码段在受相同 tensorflow 示例启发的代码中对我有用:

X_batch = tf.cast(X_batch,tf.float32)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?