微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Tensorflow 废话重塑值

如何解决Tensorflow 废话重塑值

在使用 tensorflow.keras.layers.Reshape 时,我遇到了奇怪的错误。它从哪里获得 47409408 值? 207936 对应正确的大小(69312*3)。

一个奇怪的方面是,如果我在重塑之前放了一个展平层。

Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None,304,228,3)       30        
_________________________________________________________________
reshape (Reshape)            (None,69312,3)          0         
=================================================================
Total params: 30
Trainable params: 30
Non-trainable params: 0
____________________________________

(0) 无效参数:reshape 的输入是一个有 207936 个值的张量,但请求的形状有 47409408

import tensorflow as tf
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image
from tensorflow.keras import datasets,layers,models,preprocessing
import os
from natsort import natsorted
from tensorflow.keras.models import Model

BATCH_SIZE = 32
EPOCHS = 15
LEARNING_RATE = 1e-4

#jpegs with values from 0 to 255
img_dir = ".../normalized_imgs"
# .npy files of size (69312,3)
pts_dir = ".../normalized_pts"
img_files = [os.path.join(img_dir,f)
               for f in natsorted(os.listdir(img_dir))]

pts_files = [os.path.join(pts_dir,f)
            for f in natsorted(os.listdir(pts_dir))]

img = Image.open(img_files[0])
pts = np.load(pts_files[0])

def parse_img_input(img_file,pts_file):
        def _parse_input(img_file,pts_file):
                # get image
                d_filepath = img_file.numpy().decode()
                d_image_decoded = tf.image.decode_jpeg(tf.io.read_file(d_filepath),channels=1)
                d_image = tf.cast(d_image_decoded,tf.float32) / 255.0
    
                # get numpy data
                pts_filepath = pts_file.numpy().decode()
                pts = np.load(pts_filepath,allow_pickle= True)

                print("d_image ",d_image.shape )
                return d_image,pts
        return tf.py_function(_parse_input,inp=[img_file,pts_file],Tout=[tf.float32,tf.float32])

class SimpleCNN(Model):
        def __init__(self):
                super(SimpleCNN,self).__init__()
                input_shape = (img.size[0],img.size[1],1)
                self.model = model = models.Sequential()
                model.add(tf.keras.Input(shape= input_shape))
                model.add(layers.Conv2D(3,(3,3),padding='same'))
                model.add(layers.Reshape((pts.shape[0],pts.shape[1])))


# split input data into train,test sets
X_train_file,X_test_file,y_train_file,y_test_file = train_test_split(img_files,pts_files,test_size=0.2,random_state=0)

model = SimpleCNN()

dataset_train = tf.data.Dataset.from_tensor_slices((X_train_file,y_train_file))
dataset_train = dataset_train.map(parse_img_input)

dataset_test = tf.data.Dataset.from_tensor_slices((X_test_file,y_test_file))
dataset_test = dataset_test.map(parse_img_input)

model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),loss= tf.losses.MeanSquaredError(),metrics= [tf.keras.metrics.get('accuracy')])
model.fit(dataset_train,epochs=EPOCHS,shuffle=True,validation_data= dataset_test)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。