微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何使用自定义 CTC 层正确保存和加载模型Keras 示例

如何解决如何使用自定义 CTC 层正确保存和加载模型Keras 示例

我正在 Keras 上关注本教程,但我不知道如何在训练后使用自定义层正确保存此模型并加载它。 herehere 中已经提到了这个问题,但显然这些解决方案都不适用于这个 Keras 示例。有人能指出我正确的方向吗?

P.S:这里是代码的主要部分:

class CTCLayer(layers.Layer):
    def __init__(self,name=None):
        super().__init__(name=name)
        self.loss_fn = keras.backend.ctc_batch_cost

    def call(self,y_true,y_pred):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        batch_len = tf.cast(tf.shape(y_true)[0],dtype="int64")
        input_length = tf.cast(tf.shape(y_pred)[1],dtype="int64")
        label_length = tf.cast(tf.shape(y_true)[1],dtype="int64")

        input_length = input_length * tf.ones(shape=(batch_len,1),dtype="int64")
        label_length = label_length * tf.ones(shape=(batch_len,dtype="int64")

        loss = self.loss_fn(y_true,y_pred,input_length,label_length)
        self.add_loss(loss)

        # At test time,just return the computed predictions
        return y_pred


def build_model():
    # Inputs to the model
    input_img = layers.Input(
        shape=(img_width,img_height,name="image",dtype="float32"
    )
    labels = layers.Input(name="label",shape=(None,),dtype="float32")

    # First conv block
    x = layers.Conv2D(
        32,(3,3),activation="relu",kernel_initializer="he_normal",padding="same",name="Conv1",)(input_img)
    x = layers.MaxPooling2D((2,2),name="pool1")(x)

    # Second conv block
    x = layers.Conv2D(
        64,name="Conv2",)(x)
    x = layers.MaxPooling2D((2,name="pool2")(x)

    # We have used two max pool with pool size and strides 2.
    # Hence,downsampled feature maps are 4x smaller. The number of
    # filters in the last layer is 64. Reshape accordingly before
    # passing the output to the RNN part of the model
    new_shape = ((img_width // 4),(img_height // 4) * 64)
    x = layers.Reshape(target_shape=new_shape,name="reshape")(x)
    x = layers.Dense(64,name="dense1")(x)
    x = layers.Dropout(0.2)(x)

    # RNNs
    x = layers.Bidirectional(layers.LSTM(128,return_sequences=True,dropout=0.25))(x)
    x = layers.Bidirectional(layers.LSTM(64,dropout=0.25))(x)

    # Output layer
    x = layers.Dense(len(characters) + 1,activation="softmax",name="dense2")(x)

    # Add CTC layer for calculating CTC loss at each step
    output = CTCLayer(name="ctc_loss")(labels,x)

    # Define the model
    model = keras.models.Model(
        inputs=[input_img,labels],outputs=output,name="ocr_model_v1"
    )
    # Optimizer
    opt = keras.optimizers.Adam()
    # Compile the model and return
    model.compile(optimizer=opt)
    return model


# Get the model
model = build_model()
model.summary()class CTCLayer(layers.Layer):
    def __init__(self,name="ocr_model_v1"
    )
    # Optimizer
    opt = keras.optimizers.Adam()
    # Compile the model and return
    model.compile(optimizer=opt)
    return model


# Get the model
model = build_model()
model.summary()

epochs = 100
early_stopping_patience = 10
# Add early stopping
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss",patience=early_stopping_patience,restore_best_weights=True
)

# Train the model
history = model.fit(
    train_dataset,validation_data=validation_dataset,epochs=epochs,callbacks=[early_stopping],)

# Get the prediction model by extracting layers till the output layer
prediction_model = keras.models.Model(
    model.get_layer(name="image").input,model.get_layer(name="dense2").output
)
prediction_model.summary()

解决方法

@Amirhosein,在 Horovod 存储库中查看此函数:

序列化: https://github.com/horovod/horovod/blob/6f0bb9fae826167559501701d4a5a0380284b5f0/horovod/spark/keras/util.py#L115

反序列化: https://github.com/horovod/horovod/blob/6f0bb9fae826167559501701d4a5a0380284b5f0/horovod/spark/keras/remote.py#L267

反序列化的使用示例: https://github.com/horovod/horovod/blob/6f0bb9fae826167559501701d4a5a0380284b5f0/horovod/spark/keras/remote.py#L118

如果您使用自定义指标或自定义损失函数等自定义对象,则需要使用示例中的 custom_object_scope

它在底层使用了一个名为 cloudpickle (https://pypi.org/project/cloudpickle/) 的包来将 KerasModel 转换为字符串,反之亦然。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。