如何解决解释使用带有 RepeatDataset 和 BatchDataset 类型对象的 SHAP 用 BERT 构建的模型
我使用预训练的 BERT 权重构建了一个有点复杂的模型。模型结构如下:
Model: "model_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_ids (InputLayer) [(None,32)] 0
__________________________________________________________________________________________________
attention_mask (InputLayer) [(None,32)] 0
__________________________________________________________________________________________________
token_type_ids (InputLayer) [(None,32)] 0
__________________________________________________________________________________________________
tf_bert_model_1 (TFBertModel) ((None,32,768),(N 109482240 input_ids[0][0]
attention_mask[0][0]
token_type_ids[0][0]
__________________________________________________________________________________________________
dropout_76 (Dropout) (None,768) 0 tf_bert_model_1[0][0]
__________________________________________________________________________________________________
lstm_1 (LSTM) (None,256) 1049600 dropout_76[0][0]
__________________________________________________________________________________________________
dense_5 (Dense) (None,128) 32896 lstm_1[0][0]
__________________________________________________________________________________________________
dropout_77 (Dropout) (None,128) 0 dense_5[0][0]
__________________________________________________________________________________________________
dense_6 (Dense) (None,64) 8256 dropout_77[0][0]
__________________________________________________________________________________________________
dense_7 (Dense) (None,32) 2080 dense_6[0][0]
__________________________________________________________________________________________________
dense_8 (Dense) (None,16) 528 dense_7[0][0]
__________________________________________________________________________________________________
dense_9 (Dense) (None,7) 119 dense_8[0][0]
==================================================================================================
Total params: 110,575,719
Trainable params: 110,719
Non-trainable params: 0
__________________________________________________________________________________________________
我使用了 RepeatDataset 类型的对象将数据提供给模型。它是使用以下代码创建的:
train_ds = tf.data.Dataset.from_tensor_slices((train_inp,train_mask,train_type_ids,train_out)).map(convert_to_features).shuffle(100).batch(BATCH_SIZE).repeat(5)
type(test_ds)
所以类型是:tensorflow.python.data.ops.dataset_ops.RepeatDataset
。现在我想使用 SHAP 添加模型说明。
我已经尝试使用 DeepExplainer
进行实现。它试图获得数据的形状,在我的例子中是 train_ds
。但是作为 RepeateDataset
类型的对象,它没有 shape 属性。我怎样才能克服模型?或者还有其他方法可以将 SHAP 与 RepeatDataset
类型的对象一起使用吗?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。