Python keras.backend 模块,batch_set_value() 实例源码
我们从Python开源项目中,提取了以下3个代码示例,用于说明如何使用keras.backend.batch_set_value()。
def get_weights_from_h5_group(model, model_weights, verbose=1):
    """Copy saved weights from an HDF5-style group into same-named layers.

    For every layer of `model` whose name is a non-empty entry in
    `model_weights`, the saved values listed under that entry's
    ``attrs['weight_names']`` are paired with the layer's
    ``trainable_weights + non_trainable_weights`` and assigned through one
    ``K.batch_set_value`` call.

    Args:
        model: object exposing ``layers``; each layer exposes ``name``,
            ``trainable_weights`` and ``non_trainable_weights``.
        model_weights: mapping from layer name to a group that supports
            ``len()``, ``attrs['weight_names']`` and lookup by weight name
            (presumably an ``h5py.Group`` -- confirm against the caller).
        verbose: when truthy, print the name of each layer being set.

    Raises:
        Exception: if a layer expects a different number of weights than
            the saved group provides.
    """
    weight_value_tuples = []
    # The index is only used to identify the offending layer in the error
    # message below.
    for k, layer in enumerate(model.layers):
        name = layer.name
        if name not in model_weights or len(model_weights[name]) == 0:
            continue
        layer_weights = model_weights[name]
        # Depending on the h5py version, attribute strings come back as
        # bytes or as str; only decode when there is something to decode.
        weight_names = [n.decode('utf8') if hasattr(n, 'decode') else n
                        for n in layer_weights.attrs['weight_names']]
        weight_values = [layer_weights[weight_name]
                         for weight_name in weight_names]
        symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
        if len(weight_values) != len(symbolic_weights):
            raise Exception('Layer #' + str(k) +
                            ' (named "' + layer.name +
                            '" in the current model) was found to '
                            'correspond to layer ' + name +
                            ' in the save file. '
                            'However the new layer ' + layer.name +
                            ' expects ' + str(len(symbolic_weights)) +
                            ' weights,but the saved weights have ' +
                            str(len(weight_values)) +
                            ' elements.')
        if verbose:
            print('Setting_weights for layer:', name)
        weight_value_tuples.extend(zip(symbolic_weights, weight_values))
    # One batched assignment for all collected (variable, value) pairs.
    K.batch_set_value(weight_value_tuples)
def load_weights(model, weights_path):
    """Load per-layer weights from an HDF5 file into `model`.

    The file is expected to carry a root ``layer_names`` attribute and, for
    each listed name, a group with a ``weight_names`` attribute and one
    dataset per weight. Saved values are assigned, in order, to the
    ``weights`` of every layer in `model` that has the same name.

    NOTE(review): the previous docstring said the weights come from Caffe
    models; nothing in this function is Caffe-specific -- confirm upstream.

    Args:
        model: object exposing ``layers``; each layer exposes ``name`` and
            ``weights``.
        weights_path: path of the HDF5 weights file.

    Returns:
        The list of layer names read from the file.

    Raises:
        ImportError: if the module-level ``h5py`` is ``None``.
    """
    print("Loading weights...")
    if h5py is None:
        raise ImportError('`load_weights` requires h5py.')
    # The context manager guarantees the file is closed even when reading
    # or assigning weights fails part-way through.
    with h5py.File(weights_path, mode='r') as f:
        # New file format. Attribute strings may be bytes or str depending
        # on the h5py version, so decode only when needed.
        layer_names = [n.decode('utf8') if hasattr(n, 'decode') else n
                       for n in f.attrs['layer_names']]
        # Reverse index of layer name to list of layers with name.
        index = {}
        for layer in model.layers:
            if layer.name:
                index.setdefault(layer.name, []).append(layer)
        # We batch weight value assignments in a single backend call
        # which provides a speedup in TensorFlow.
        weight_value_tuples = []
        for name in layer_names:
            g = f[name]
            weight_names = [n.decode('utf8') if hasattr(n, 'decode') else n
                            for n in g.attrs['weight_names']]
            weight_values = [g[weight_name] for weight_name in weight_names]
            for layer in index.get(name, []):
                symbolic_weights = layer.weights
                # Set values. Indexing (rather than zip) keeps the
                # IndexError when the file holds more weights than the layer.
                for i, value in enumerate(weight_values):
                    weight_value_tuples.append((symbolic_weights[i], value))
        # Must run while the file is still open: the values are datasets
        # that are read from the file when they are assigned.
        K.batch_set_value(weight_value_tuples)
    return layer_names
def load_weights(model, weights_path):
    """Load weights from an HDF5 file into the same-named layers of `model`.

    The file is expected to carry a root ``layer_names`` attribute and, for
    each listed name, a group with a ``weight_names`` attribute and one
    dataset per weight. A warning is printed when the number of saved
    layers differs from ``len(model.layers)``. Layers without saved weights
    are skipped.

    Args:
        model: object exposing ``layers`` and ``get_layer(name=...)``.
        weights_path: path of the HDF5 weights file.

    Raises:
        Exception: if `weights_path` is not an existing file, or if a layer
            expects a different number of weights than the file provides.
    """
    # Validate the path first so a bad path fails fast, before the heavier
    # imports below are attempted.
    if not os.path.isfile(weights_path):
        raise Exception("File does not exist.")
    from keras import backend as K
    import h5py
    # The context manager closes the file on every exit path, including
    # the weight-count mismatch raised inside the loop.
    with h5py.File(weights_path, mode='r') as f:
        # new file format; attribute strings may be bytes or str depending
        # on the h5py version, so decode only when needed.
        layer_names = [n.decode('utf8') if hasattr(n, 'decode') else n
                       for n in f.attrs['layer_names']]
        if len(layer_names) != len(model.layers):
            print("Warning: Layer count different")
        # we batch weight value assignments in a single backend call
        # which provides a speedup in TensorFlow.
        weight_value_tuples = []
        for k, name in enumerate(layer_names):
            g = f[name]
            weight_names = [n.decode('utf8') if hasattr(n, 'decode') else n
                            for n in g.attrs['weight_names']]
            layer = model.get_layer(name=name)
            if not (layer and len(weight_names)):
                continue
            weight_values = [g[weight_name] for weight_name in weight_names]
            if not hasattr(layer, 'trainable_weights'):
                print("Layer %s (%s) has no trainable weights,but we tried to load it." % (
                    name, type(layer).__name__))
                continue
            symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
            if len(weight_values) != len(symbolic_weights):
                raise Exception('Layer #' + str(k) +
                                ' (named "' + layer.name +
                                '" in the current model) was found to '
                                'correspond to layer ' + name +
                                ' in the save file. '
                                'However the new layer ' + layer.name +
                                ' expects ' + str(len(symbolic_weights)) +
                                ' weights,but the saved weights have ' +
                                str(len(weight_values)) +
                                ' elements.')
            weight_value_tuples.extend(zip(symbolic_weights, weight_values))
        # Must run while the file is still open: the values are datasets
        # that are read from the file when they are assigned.
        K.batch_set_value(weight_value_tuples)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 [email protected] 举报,一经查实,本站将立刻删除。