微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在 Tensorflow 中实现依赖于实例的/条件损失

如何解决在 Tensorflow 中实现依赖于实例的/条件损失

我希望能够根据给定批次中发生的实例计算条件损失(两种不同的损失)。我正在从头开始编写自定义 train_step,因为我相信这提供了实现我所想的灵活性。但是,我在弄清楚如何实现这一点上有点困难。

在每个训练步骤中,我计算批次中每个实例的真实标签和预测标签间的分类分类交叉熵)损失,这是标准的。此外,我包括一个正则化损失,它不是为批处理中的每个实例计算的,而是仅为实例的一个子集计算的。这就是为什么我提到一个条件损失或两个损失目标。

在训练之前,我已经指定了一个训练实例 ID 列表(每个训练实例都有一个唯一的 ID)。每当这些实例中的任何一个碰巧出现在当前批次中时,我都会仅使用这些特定实例来计算正则化项。如果这些情况都没有发生,我只计算标准分类损失。正则化项的目标是鼓励特定训练实例(由实例 id 指定)和一组附加实例(现在我们可以只假设单个实例)之间的特征相似性,以平方距离衡量。

这是我目前所拥有的。这不是一个有效的实现,但希望展示我所描述的以及我希望实现的目标。模型接受图像张量并输出特征表示(用于正则化项)和预测向量(用于分类损失)。随意忽略我正在使用的方法并建议替代方法。例如,创建自定义损失函数或使用 tf.cond 可能会有所帮助。注意:我使用的是 tensorflow 2/tf.GradientTape()

class MNIST_Classifier(tf.keras.Model):
    def __init__(self,model,train_sub_ids,reg_example,lmbda,**kwargs):
        super(MNIST_Classifier,self).__init__(**kwargs)
        self.model = model
        self.train_sub_ids = train_sub_ids
        self.reg_example = reg_examples
        self.lmbda = lmbda
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.classification_loss_tracker = keras.metrics.Mean(
            name="classification_loss"
        )
        self.reg_loss_tracker = keras.metrics.Mean(name="reg_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,self.classification_loss_tracker,self.reg_loss_tracker,]

    def train_step(self,data):
        with tf.GradientTape() as tape:
            ids,x,y = data # batch includes id,image,label
            _,y_pred = self.model(x) # get predictions,features don't matter for classification loss
            
            # compute classification loss for all instances 
            classification_loss = tf.reduce_mean(
              tf.keras.losses.categorical_crossentropy(y,y_pred)
            )

            # compute reg loss for subset of instances (Could be none)
            # step 1: obtain instances from batch where id is in self.train_sub_id
            # Todo: this won't work because it's not using tensor operations...need to replace 
            x_sub = [img for id,img in zip(ids,x) if any(id==i for i in self.train_sub_id)]
            if x_sub:
              features_sub,_  = self.model(x_sub)
              # step 2: compute features and predictions for reg example
              features_reg,_  = self.model(x_reg)
              # should still work if features_sub and features_reg are different shapes in batch (left most) dim
              reg_loss = tf.reduce_mean(tf.math.squared_difference(features_sub,features_reg))
            else:
              reg_loss = 0
            
            total_loss = classification_loss + self.lmbda*reg_loss

        variables = self.trainable_weights
        grads = tape.gradient(total_loss,variables)
        self.optimizer.apply_gradients(zip(grads,self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.classification_loss_tracker.update_state(classification_loss)
        self.reg_loss_tracker.update_state(reg_loss)

        return {
            "total_loss": self.total_loss_tracker.result(),"classification_loss": self.classification_loss_tracker.result(),"reg_loss": self.reg_loss_tracker.result(),}

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?