微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

为什么 tf.keras 对我来说比 Keras 慢得多? 24 it/s 到 2274 it/s

如何解决为什么 tf.keras 对我来说比 Keras 慢得多? 24 it/s 到 2274 it/s

我发现在我的代码库中,Keras 比 tf.keras 快得多。在我看来,tf.keras 慢得令人无法接受。 我创建了一个类似的神经网络,一次使用 tf.keras,第二次使用 Kears。 然后在 OpenAI Gym Mountain-Car-v0 环境中运行一个只有 predict 函数的简化循环。

所以我的问题是,如果我在使用框架时犯了一个巨大的错误,或者它背后的底层代码基础有什么不同吗?

结果:

Tf.Keras:10000/10000 [06:53

Keras:10000/10000 [00:04

代码库:

Keras 版本:2.3.1

import keras
from keras.models import Sequential
from keras.layers import Dense,Activation
from keras.optimizers import Adam

model = Sequential()
model.add(Dense(24,input_dim=env.observation_space.shape[0],activation="relu"))
model.add(Dense(24,activation="relu"))
model.add(Dense(env.action_space.n,activation='linear'))

model.compile(loss='mse',optimizer=Adam(lr=0.001))
print("Keras version: ",keras.__version__)

tf.keras 版本:2.2.4-tf

from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Activation
from tensorflow.keras.optimizers import Adam

model = keras.Sequential([
  Dense(24,activation = 'relu'),Dense(24,Dense(env.action_space.n,activation='linear')
])

model.compile(loss='mse',optimizer=Adam(lr=0.001))
print("tf.keras version: ",keras.__version__)

测试循环:

from tqdm import tqdm
for a in tqdm(range(10000)):
    state = env.reset()
    model.predict(state.reshape(-1,env.observation_space.shape[0]))

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。