keras.models.load_model() raises ValueError: Got 0 inputs for equation "baik,baij->bakj", expecting 2

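My code includes a batch matrix multiplication, "tf.einsum('baik,baij->bakj',q,k)/np.sqrt(dv)", as one of its steps. After training the model I saved it with "model.save('./model')", and now I want to load the saved model. I tried "model = keras.models.load_model('./model',compile=False,custom_objects={'f1': f1})", but it gives the error below. Why does this happen?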
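
For reference, here is what that einsum equation computes, sketched with NumPy's einsum (same subscript notation) on small made-up sizes. It assumes q and k are 4-D, shaped (batch, l, nv, dv) as produced by Reshape([l, nv, dv]); the array names query/key and the sizes are illustrative only. The index i appears in both inputs but not in the output, so it is summed away, and the result has shape (batch, l, dv, dv). That is the same as a batched matmul of q (last two axes swapped) with k.

```python
import numpy as np

# Illustrative sizes: batch b, positions l, heads nv, per-head depth dv.
b, l, nv, dv = 2, 5, 3, 4
rng = np.random.default_rng(0)
query = rng.standard_normal((b, l, nv, dv))  # plays the role of q
key = rng.standard_normal((b, l, nv, dv))    # plays the role of k

# att[b, a, k, j] = sum over i of query[b, a, i, k] * key[b, a, i, j]
att = np.einsum("baik,baij->bakj", query, key) / np.sqrt(dv)

# Equivalent batched matmul: swap the last two axes of query, then matmul with key.
att_matmul = np.matmul(np.swapaxes(query, -1, -2), key) / np.sqrt(dv)

print(att.shape)                     # (2, 5, 4, 4)
print(np.allclose(att, att_matmul))  # True
```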

Traceback (most recent call last):
  File "test_pro.py", line 37, in <module>
    model = keras.models.load_model('./model', custom_objects={'f1': f1})
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py", line 212, in load_model
    return saved_model_load.load(filepath, compile, options)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 147, in load
    keras_loader.finalize_objects()
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 612, in finalize_objects
    self._reconstruct_all_models()
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 631, in _reconstruct_all_models
    self._reconstruct_model(model_id, model, layers)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 677, in _reconstruct_model
    created_layers) = functional_lib.reconstruct_from_config(
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 1285, in reconstruct_from_config
    process_node(layer, node_data)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 1233, in process_node
    output_tensors = layer(input_tensors, **kwargs)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1012, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/layers/core.py", line 1327, in _call_wrapper
    return self._call_wrapper(*args, ...
  File ..., line 1359, in _call_wrapper
    result = self.function(*args, **kwargs)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 201, in wrapper
    return target(*args, **kwargs)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/ops/special_math_ops.py", line 751, in einsum
    return _einsum_v2(equation, *inputs, ...
  File ..., line 1174, in _einsum_v2
    _einsum_v2_parse_and_resolve_equation(equation, input_shapes))
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/ops/special_math_ops.py", line 1254, in _einsum_v2_parse_and_resolve_equation
    raise ValueError('Got {} inputs for equation "{}", expecting {}'.format(
ValueError: Got 0 inputs for equation "baik,baij->bakj", expecting 2
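
The error text comes from TensorFlow's einsum argument check (the last frame, _einsum_v2_parse_and_resolve_equation): the number of comma-separated operand terms to the left of "->" must match the number of tensors passed. The frames above it (reconstruct_from_config, process_node, then _call_wrapper in keras/layers/core.py) show that the failing call happened while Keras was rebuilding the model from its saved config and re-invoking tf.einsum through a layer wrapper, and that this rebuilt call reached einsum with the equation but no tensors. The snippet below is a simplified re-creation of that check in plain Python; the helper names are hypothetical and this is not TensorFlow's actual code.

```python
def expected_operand_count(equation: str) -> int:
    """Hypothetical helper: count operand terms to the left of '->'."""
    lhs = equation.split("->", 1)[0]
    return len(lhs.split(","))


def check_operands(equation: str, *inputs) -> None:
    """Raise an error in the same style as tf.einsum when the counts disagree."""
    expected = expected_operand_count(equation)
    if len(inputs) != expected:
        raise ValueError('Got {} inputs for equation "{}", expecting {}'.format(
            len(inputs), equation, expected))


check_operands("baik,baij->bakj", "q", "k")  # two operands: passes silently

try:
    check_operands("baik,baij->bakj")  # equation only, as in the rebuilt call
except ValueError as err:
    message = str(err)
    print(message)  # Got 0 inputs for equation "baik,baij->bakj", expecting 2
```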

This is how I create the model:

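import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Add, BatchNormalization, Dense, Input, Lambda, Reshape
from tensorflow.keras.models import Model


def MultiHeadsAttModel(l=7*7, d=1024, dv=64, dout=1024, nv=16):
    # Three inputs of shape (l, d): values, queries, keys
    v1 = Input(shape=(l, d))
    q1 = Input(shape=(l, d))
    k1 = Input(shape=(l, d))

    # Project each input to nv heads of depth dv
    v2 = Dense(dv*nv, activation="relu")(v1)
    q2 = Dense(dv*nv, activation="relu")(q1)
    k2 = Dense(dv*nv, activation="relu")(k1)

    # Split the last axis into (nv, dv); the einsum equations below expect 4-D tensors
    v = Reshape([l, nv, dv])(v2)
    q = Reshape([l, nv, dv])(q2)
    k = Reshape([l, nv, dv])(k2)
    att = tf.einsum('baik,baij->bakj', q, k)/np.sqrt(dv)  # batch matrix multiplication
    att = Lambda(lambda x: K.softmax(x), output_shape=(l, nv))(att)
    out = tf.einsum('bajk,baik->baji', att, v)
    out = Reshape([l, d])(out)
    out = Add()([out, q1])  # residual connection back to the query input

    out = Dense(dout, activation="relu")(out)

    return Model(inputs=[q1, k1, v1], outputs=out)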
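
The tensor shapes through MultiHeadsAttModel can be traced with a NumPy sketch that uses the default sizes (l=49, d=1024, dv=64, nv=16). It is an illustration only: each Dense layer is replaced by a hypothetical random-weight stand-in without bias, and nothing here touches saving or loading.

```python
import numpy as np

# Default sizes from MultiHeadsAttModel; b is an arbitrary batch size.
b, l, d, dv, nv = 2, 7 * 7, 1024, 64, 16
rng = np.random.default_rng(0)


def dense_relu(x, units):
    """Stand-in for Dense(units, activation="relu"): random weights, no bias."""
    w = rng.standard_normal((x.shape[-1], units)) * 0.01
    return np.maximum(x @ w, 0.0)


def softmax(x, axis=-1):
    """Numerically stable softmax over the last axis (K.softmax's default axis)."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


q1 = k1 = v1 = rng.standard_normal((b, l, d))  # the model is called as att([x, x, x])

v = dense_relu(v1, dv * nv).reshape(b, l, nv, dv)
q = dense_relu(q1, dv * nv).reshape(b, l, nv, dv)
k = dense_relu(k1, dv * nv).reshape(b, l, nv, dv)

att = softmax(np.einsum("baik,baij->bakj", q, k) / np.sqrt(dv))  # (b, l, dv, dv)
out = np.einsum("bajk,baik->baji", att, v)                       # (b, l, dv, nv)
out = out.reshape(b, l, d) + q1                                  # residual add
print(att.shape, out.shape)  # (2, 49, 64, 64) (2, 49, 1024)
```

The final reshape only works because dv * nv equals d (64 × 16 = 1024); with other values, Reshape([l, d]) would fail before the Add() residual is reached.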



   
def create_model(input_shape, output_classes):
    mobile = tf.keras.applications.mobilenet.MobileNet(weights='imagenet')
    x = mobile.layers[-6].input

    if True:
        # Treat the 7x7x1024 feature map as 49 positions of width 1024
        x = Reshape([7*7, 1024])(x)
        att = MultiHeadsAttModel(l=7*7, nv=16)
        x = att([x, x, x])
        x = Reshape([7, 7, 1024])(x)
        x = BatchNormalization()(x)
        # ... (the rest of create_model is not shown)
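
In create_model the attention block sees the 7×7×1024 feature map as 49 positions of width 1024. Keras Reshape follows the same row-major order as NumPy's reshape, so flattening and restoring the grid is lossless. A small NumPy sketch with made-up data (the array names are illustrative):

```python
import numpy as np

# Made-up batch of feature maps: 7x7 spatial grid, 1024 channels.
feats = np.random.default_rng(0).standard_normal((2, 7, 7, 1024))

tokens = feats.reshape(2, 7 * 7, 1024)    # like Reshape([7*7, 1024])
restored = tokens.reshape(2, 7, 7, 1024)  # like Reshape([7, 7, 1024])

# Row-major order: position r * 7 + c holds grid cell (r, c).
print(np.array_equal(tokens[:, 3 * 7 + 5], feats[:, 3, 5]))  # True
print(np.array_equal(restored, feats))                       # True
```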
