微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何从 ImageDataGenerator 获取 x_train?

如何解决如何从 ImageDataGenerator 获取 x_train?

我正在研究图像分类 CNN,我想知道如何使用 ImageGenerator 获得 x_train 和 y_train 形式。这样做的原因是我想将我的模型拟合为 fit note fit.generator

trainDataGen = ImageDataGenerator(rescale= 1./255,rotation_range =30,width_shift_range=0.1,height_shift_range=0.1,shear_range=0.2,zoom_range=0.2,horizontal_flip=False,vertical_flip=False,fill_mode='nearest',)

trainGenSet = trainDataGen.flow_from_directory(
    path +'Train',target_size=(28,28),batch_size=32,class_mode='categorical',color_mode='grayscale',)


x_train,y_train = trainGenSet.next()

打印(x_train)是(32,28,1) 因为批量大小为 32。 我总共有 5000 个火车数据集 我想得到 (5000,1) for print(x_train)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。