微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

TensorFlow Checkpoint变量未保存

如何解决TensorFlow Checkpoint变量未保存

我正在尝试将Checkpoint用于我的模型。在此之前,尝试了一个玩具示例。这没有错误运行。但是每次我跑步时,训练参数看起来都是从初始值开始的。不知道我是否在这里缺少什么?以下是即时通讯使用的代码

import numpy as np
import tensorflow as tf

X = tf.range(10.)
Y = 50.*X
    
class CGMM(object):
    def __init__(self):
        self.beta =  tf.Variable(1.,dtype=np.float32)

    @tf.function
    def objfun(self):
        beta = self.beta
        obj = tf.reduce_mean(tf.square(beta*self.X - self.Y))
        return obj

    def build_model(self,X,Y):
        self.X,self.Y=X,Y
        optimizer = tf.keras.optimizers.RMSprop(0.5)
        ckpt = tf.train.Checkpoint(step=tf.Variable(1),model =self.objfun,optimizer=optimizer)
        manager = tf.train.CheckpointManager(ckpt,'./tf_ckpts_cg',max_to_keep=3)

        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            print("Restored from {}".format(manager.latest_checkpoint))
        else:
            print("Initializing from scratch.")

        for i in range(20):
            optimizer.minimize(self.objfun,var_list =  self.beta)
            loss,beta = self.objfun(),self.beta
            # print(self.beta.numpy())
            ckpt.step.assign_add(1)
            if int(ckpt.step) % 5 == 0:
              save_path = manager.save()
              print("Saved checkpoint for step {}: {}".format(int(ckpt.step),save_path))
              print("loss {:1.2f}".format(loss.numpy()))
              print("beta {:1.2f}".format(beta.numpy()))

        return beta


model =CGMM()
opt_beta = model.build_model(X,Y)

结果第一次运行:

Initializing from scratch.
Saved checkpoint for step 5: ./tf_ckpts_cg/ckpt-1
loss 56509.74
beta 5.47
Saved checkpoint for step 10: ./tf_ckpts_cg/ckpt-2
loss 48354.54
beta 8.81
Saved checkpoint for step 15: ./tf_ckpts_cg/ckpt-3
loss 42085.54
beta 11.57
Saved checkpoint for step 20: ./tf_ckpts_cg/ckpt-4
loss 36750.57
beta 14.09

第二次结果:

Restored from ./tf_ckpts_cg/ckpt-4
Saved checkpoint for step 25: ./tf_ckpts_cg/ckpt-5
loss 54619.16
beta 6.22
Saved checkpoint for step 30: ./tf_ckpts_cg/ckpt-6
loss 46997.79
beta 9.39
Saved checkpoint for step 35: ./tf_ckpts_cg/ckpt-7
loss 40958.30
beta 12.09
Saved checkpoint for step 40: ./tf_ckpts_cg/ckpt-8
loss 35763.21
beta 14.58

解决方法

出了什么问题:

主要问题是您的beta变量不可跟踪:这意味着检查点对象将不会保存它。通过使用以下功能检查检查点的内容,我们可以看到:

>>> tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts_cg/')

[('_CHECKPOINTABLE_OBJECT_GRAPH',[]),('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',('optimizer/momentum/.ATTRIBUTES/VARIABLE_VALUE',('optimizer/rho/.ATTRIBUTES/VARIABLE_VALUE',('save_counter/.ATTRIBUTES/VARIABLE_VALUE',('step/.ATTRIBUTES/VARIABLE_VALUE',[])]

由检查点跟踪的唯一tf.Variable是优化程序中的那个,以及tf.train.Checkpoint对象本身使用的那些。


可能的解决方案:

要更改此设置,您需要跟踪变量。关于该主题的TensorFlow文档并不出色,但是经过一番搜索之后,您可以在tf.Variable文档中阅读以下内容:

将变量分配给的属性时会自动对其进行跟踪 从tf.Module继承的类型。

[...]

此跟踪然后允许 将变量值保存到训练检查点或SavedModels 其中包括序列化的TensorFlow图。

因此,通过使CGMM类继承自tf.Module,您可以跟踪beta变量并将其恢复!这是对您的代码的非常直接的更改:

class CGMM(tf.Module):
    def __init__(self):
        super(CGMM,self).__init__(name='CGMM')
        self.beta =  tf.Variable(1.,dtype=np.float32)

我们还需要告诉Checkpoint对象该模型现在是CGMM对象:

ckpt = tf.train.Checkpoint(step=tf.Variable(1),model=self,optimizer=optimizer)

现在,如果我们训练一些步骤并查看检查点文件的内容,我们将获得一些希望。 Beta变量现已保存:

>>> tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts_cg/'))

[('_CHECKPOINTABLE_OBJECT_GRAPH',('model/beta/.ATTRIBUTES/VARIABLE_VALUE',('model/beta/.OPTIMIZER_SLOT/optimizer/rms/.ATTRIBUTES/VARIABLE_VALUE',[])]

如果我们多次运行该程序,则会得到:

>>> run tf-ckpt.py

Restored from ./tf_ckpts_cg/ckpt-28
Saved checkpoint for step 145: ./tf_ckpts_cg/ckpt-29
loss 0.00
beta 49.99
Saved checkpoint for step 150: ./tf_ckpts_cg/ckpt-30
loss 0.00
beta 50.00

万岁!


注意:为了跟踪变量,您还可以使用任何类型的keras.layers.Layer和任何keras.Model。这可能是最简单的方法。

Training Checkpoint guide的摘录:

tf.train.Checkpoint,tf.keras.layers.Layer和 tf.keras.Model自动跟踪分配给它们的变量 属性。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。