微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

具有约束权重的 keras 层在模型初始化期间引发错误

如何解决具有约束权重的 keras 层在模型初始化期间引发错误

我正在训练一个 actor-critic 模型,该模型在 actor 网络中具有一个受限层。该约束强制权重为对角线。下面是一个可重现的最小示例

import tensorflow.compat.v1 as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense,Input,Conv1D,Concatenate,Batchnormalization,Reshape
from tensorflow.keras.constraints import Constraint
from tensorflow.python.keras.utils.vis_utils import plot_model
from tensorflow.keras.optimizers import Adam
from tensorflow.python.keras.backend import set_session
from tensorflow.python.keras import backend as K


import numpy as np
import random
from collections import deque

tf.disable_v2_behavior() 

# For more repetitive results
np.random.seed(1)
random.seed(1)

class DiagonalWeight(Constraint):
    """Constrains the weights to be diagonal.
    """
    def __call__(self,w):
        N = K.int_shape(w)[-1]
        m = K.eye(N)
        return w*m

state_shape = (10,3)
class AC():
    def __init__(self,sess,LRA,LRC):
        self.sess = sess     #session
        
        self.LRA = LRA                                   #learning rate for actor
        self.LRC = LRC                                   #learning rate for critic
        
        self.graph = tf.get_default_graph()
        set_session(self.sess)

        # training actor and target actor 
        self.actor,self.input_actor = self.create_actor()
        plot_model(self.actor,to_file='Ac_architecture.png',show_shapes=True,show_layer_names=True)
        self.target_actor,_ = self.create_actor()
        ###initialize the weights of the target with the weights of training actor
        self.target_actor.set_weights(self.actor.get_weights())

        #training critic and target critic
        self.critic,self.critic_state_input,self.critic_action_input = self.create_critic()
        self.target_critic,_,_  = self.create_critic()
        plot_model(self.critic,to_file='Cr_architecture.png',show_layer_names=True)
        ###initialize the weights of the target with the weights of training critic
        self.target_critic.set_weights(self.critic.get_weights())


        ######################## Actor/Critic Grads ########################################################
        self.actor_critic_grad = tf.placeholder(tf.float32,[None,state_shape[0],state_shape[0]+2])   
        actor_weights = self.actor.trainable_weights
        self.actor_grads = tf.gradients(self.actor.output,actor_weights,-self.actor_critic_grad) 
        grads = zip(self.actor_grads,actor_weights)
        self.optimize = tf.train.AdamOptimizer(self.LRA).apply_gradients(grads) 
        self.critic_grads = tf.gradients(self.critic.output,self.critic_action_input)

        # Initialize for later gradient calculations
        self.sess.run(tf.global_variables_initializer())
        #######################################################################################################################

    def create_actor(self):
        actor_input = Input(shape=state_shape,name='state_input')     
        h3 = Conv1D(128,3,padding='same',activation='relu',name='h3')(actor_input)
        h3 = Batchnormalization(name='h3_BN')(h3)

        matrix = Conv1D(state_shape[0],name='matrix')(h3)
        vect0= Conv1D(1,name='vect0')(h3)
        vect0_resh = Reshape((1,state_shape[0]))(vect0)
        vect1 = Dense(state_shape[0],name='vect1',\
                              use_bias=False,kernel_constraint=DiagonalWeight())(vect0_resh) #,kernel_constraint=DiagonalWeight()
        vect1 = Reshape((state_shape[0],1))(vect1)
        actor_output = Concatenate(axis=-1)([vect0,vect1,matrix]) 

        model = Model(actor_input,actor_output)  
        adam  = Adam(lr=self.LRA)
        model.compile(loss="mse",optimizer=adam)
        return model,actor_input

    def create_critic(self):
        state_input = Input(shape=state_shape,name='state_input')     
        action_input = Input(shape=(state_shape[0],state_shape[0]+2),name='action_input')                                           
        critic_input = Concatenate(axis=-1)([state_input,action_input]) 
        h3 = Conv1D(128,name='h3')(critic_input)
        h3 = Batchnormalization(name='h3_BN')(h3)

        Q = Conv1D(state_shape[0]+2,name='Q')(h3)

        model = Model([state_input,action_input],Q)
        adam  = Adam(lr=self.LRC)
        model.compile(loss="mse",state_input,action_input


sess = tf.compat.v1.Session()
K.set_session(sess)
agent = AC(sess,0.01,0.001)

当我从密集层移除内核约束时,一切正常,但是当我添加它时,我得到以下三个错误之一:

tensorflow.python.framework.errors_impl.InvalidArgumentError: You must Feed a value for placeholder tensor 'Placeholder_19' with dtype float and shape [?,10,12]
         [[node Placeholder_19 (defined at AC.py:59) ]]

tensorflow.python.framework.errors_impl.InvalidArgumentError: You must Feed a value for placeholder tensor 'state_input' with dtype float and shape [?,3]
         [[node state_input (defined at AC.py:73) ]]

tensorflow.python.framework.errors_impl.FailedPreconditionError: Error while reading resource variable beta1_power from Container: localhost. This Could mean that the variable was uninitialized. Not found: Resource localhost/beta1_power/class tensorflow::Var does not exist.
         [[node Adam/update_vect1/kernel/ResourceApplyAdam/ReadVariableOp (defined at AC.py:64) ]]

密集层中的核是一个方阵,所以应用对角线约束应该没有问题。

解决方法

对于任何对解决方案感兴趣的人: 在 class DiagonalWeight 中,我将 m = K.eye(N) 行更改为 m = tf.eye(N)

我不知道这是如何工作的,但我认为这与会话有关。 Tensorflow 需要存储操作以备将来在训练中使用。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?