微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Fit_generator 在纪元结束时产生两个详细信息

如何解决Fit_generator 在纪元结束时产生两个详细信息

我正在使用类似的混合生成

import numpy as np
from tensorflow.keras.utils import Sequence

class MixupGenerator(Sequence):
    def __init__(self,x_train,y_train,batch_size=32,alpha=0.2,shuffle=True):
        self.X_train = x_train
        self.y_train = y_train
        self.batch_size = batch_size
        self.alpha = alpha
        self.shuffle = shuffle
        self.sample_num = len(x_train)
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    #@threadsafe_generator
    def __call__(self):
        with self.lock:
            while True:
                indexes = self.__get_exploration_order()
                itr_num = int(len(indexes) // (self.batch_size * 2))

                for i in range(itr_num):
                    batch_ids = indexes[i * self.batch_size * 2:(i + 1) * self.batch_size * 2]
                    X,y = self.__data_generation(batch_ids)

                    yield X,y

    def __get_exploration_order(self):
        indexes = np.arange(self.sample_num)

        if self.shuffle:
            np.random.shuffle(indexes)

        return indexes

    def __data_generation(self,batch_ids):
        _,h,w,c = self.X_train.shape
        l = np.random.beta(self.alpha,self.alpha,self.batch_size)
        X_l = l.reshape(self.batch_size,1,1)
        y_l = l.reshape(self.batch_size,1)

        X1 = self.X_train[batch_ids[:self.batch_size]]
        X2 = self.X_train[batch_ids[self.batch_size:]]

        X = X1 * X_l + X2 * (1.0 - X_l)

        if isinstance(self.y_train,list):
            y = []

            for y_train_ in self.y_train:
                y1 = y_train_[batch_ids[:self.batch_size]]
                y2 = y_train_[batch_ids[self.batch_size:]]
                y.append(y1 * y_l + y2 * (1.0 - y_l))
        else:
            y1 = self.y_train[batch_ids[:self.batch_size]]
            y2 = self.y_train[batch_ids[self.batch_size:]]
            y = y1 * y_l + y2 * (1.0 - y_l)

        return X,y

我在训练期间有 13965 个样本,在测试期间有 2970 个样本。我称之为适合:

history = model.fit_generator(train_datagen,validation_data=(val_x,val_y),epochs=epochs,steps_per_epoch=np.ceil((x.shape[0] - 1) / config.batch_size),callbacks=callbacks,verbose=tr_verbose)

batch_size = 32

verbose 比较少,是不是因为 epochs 和 batch size 是十进制的?

时代 49/500 436/437 [============================>.] - ETA:0s - 损失:0.1408 - categorical_accuracy:0.8295Epoch 1/ 500 2968/437 [============================================== ================================================== ================================================== ================================================== ========] - 6s 2ms/sample - 损失:0.2304 - categorical_accuracy:0.5162 437/437 [==============================] - 131 秒 299 毫秒/步 - 损失:0.1409 - categorical_accuracy:0.8294 - val_loss :0.2510 - val_categorical_accuracy:0.5162

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?