微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

当 run_eagerly = False

如何解决当 run_eagerly = False

当我对 Model 类进行子类化时,我在尝试评估在自定义 train_step() 中计算的张量时遇到了问题。当我在 Tensor.numpy() 内传递 run_eagerly = True 时,我可以使用 model.compile(...) 但据我所知,这效率不高。我尝试了其他建议,例如使用 Tensor.eval()(无认会话)、backend.get_values(Tensor)(Tensor 没有属性 numpy())等,但没有成功。我有一个我想在下面实现的简化示例:

class CustomModel(keras.Model):

def __init__(self,**kwargs):
    super().__init__(**kwargs)
    self.saved_pred = []

def train_step(self,data):
    x,y = data
    
    with tf.GradientTape() as tape:
        y_pred = self(x,training=True)  # Forward pass
        loss = self.compiled_loss(y,y_pred,regularization_losses=self.losses)

    # Compute gradients
    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss,trainable_vars)
    # Update weights
    self.optimizer.apply_gradients(zip(gradients,trainable_vars))
    # Update metrics (includes the metric that tracks the loss)
    self.compiled_metrics.update_state(y,y_pred)
    # Return a dict mapping metric names to current value

    self.saved_pred.append(y_pred)

    return {m.name: m.result() for m in self.metrics}

import numpy as np

inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs = inputs,outputs = outputs)
model.compile(optimizer=“adam”,loss=“mse”,metrics=[“mae”])

x = np.random.random((1000,32))
y = np.random.random((1000,1))
model.fit(x,y,epochs=3)

当我打印 model.saved_pred 时,我得到以下信息:

ListWrapper([<tf.Tensor ‘custom_model_1/dense_5/BiasAdd:0’ shape=(None,1) dtype=float32>,<tf.Tensor ‘custom_model_1/dense_5/BiasAdd:0’ shape=(None,1) dtype=float32>])

是否有某种方法可以在 train_step() 内部或在 model.fit() 之后提取这些张量的值(作为 numpy 数组)?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?