如何解决用于 LSTM 的 Kerastuner
我正在处理一个文本分类问题,并尝试使用 Kerastuner 来确定我的 LSTM 网络的最佳配置。下面是相同的代码:
keras 调谐器
def build_model(hp):
num_hidden_layers =1
num_units = 8
dropout_rate = 0.1
learning_rate=0.01
if hp:
num_hidden_layers = hp.Int('num_hidden_layers',min_value=2,max_value=100,step=5)
num_units = hp.Int('num_units',min_value=50,max_value=2000,step=50)
dropout_rate = hp.Float('dropout_rate',min_value=0.1,max_value=0.5)
learning_rate = hp.Float('learning_rate',min_value=0.0001,max_value=0.01)
momentum_rate = hp.Float('momentum_rate',min_value=0.5,max_value=0.9)
vocab_size = len(tokenizer.word_index)+1
max_sequence_length = 500
embedding_size = 300
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Embedding(input_dim=vocab_size,output_dim=embedding_size,input_length=max_sequence_length,weights=[embedding_matrix],trainable=False))
for _ in range(0,num_hidden_layers):
model.add(tf.keras.layers.LSTM(num_units))
model.add(tf.keras.layers.Dropout(dropout_rate))
model.add(tf.keras.layers.Dense(1,activation='sigmoid'))
model.compile(
loss = 'mse',optimizer =tf.keras.optimizers.SGD(learning_rate=learning_rate,momentum=momentum_rate),metrics = [tf.keras.metrics.BinaryCrossentropy(name='binary_crossentropy')]
)
return model
class CustomTuner(kerastuner.tuners.Bayesianoptimization):
def run_trial(self,trial,*args,**kwargs):
kwargs['batch_size'] = trial.hyperparameters.Int('batch_size',128,1024,step=32)
super(CustomTuner,self).run_trial(trial,**kwargs)
tuner = CustomTuner(
build_model,objective=kerastuner.Objective('val_loss','min'),max_trials=2,executions_per_trial=1,directory='/dbfs/FileStore/GDPR_Dev/Data/',project_name = 'nn_logs_lstm_30062021',overwrite=True
ValueError: 层 lstm_1 的输入 0 与层不兼容:预期 ndim=3,发现 ndim=2。收到完整形状:(无,50)
有人可以帮我解决这个问题吗?
解决方法
目前,您的数据是二维 (N x M),但是,输入数据应该是三维的。要解决这个问题,您应该将输入重塑为 N x M x 1 矩阵,如下所示:
x = np.reshape(x,(x.shape[0],x.shape[1],1))
如果您的输入是多元的,那么所需的输入形状将是 N x M x K,其中 k 是维数
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。