如何将 KERAS `MLP` 模型转换为 `Pythorch` 模型

如何解决如何将 KERAS `MLP` 模型转换为 `Pythorch` 模型

我有一个使用某个项目的 MLP 模型

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Flatten
from keras.layers import Embedding

model = Sequential()
embedding_layer = Embedding(vocab_size,50,input_length=len(X[0]))
model.add(embedding_layer)
model.add(Flatten())
model.add(Dense(100,activation='relu'))
model.add(Dense(3,activation='softmax'))
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
model.summary()
model.fit(X,Y,epochs=10,verbose=1)

但我想将此 MLP 模型转换为 Pythorch 模型。

我该怎么做?

非常感谢。

我添加此代码以生成 data

def generate_batch_data(x,y,batch_size):
    i,batch = 0,0
    for batch,i in enumerate(range(0,len(x) - batch_size,batch_size),1):
        x_batch = x[i : i + batch_size]
        y_batch = y[i : i + batch_size]
        yield x_batch,y_batch
    if i + batch_size < len(x):
        yield x[i + batch_size :],y[i + batch_size :]
    if batch == 0:
        yield x,y

和这段代码:

epochs = 10
batch_size = 10
for epoch in range(epochs):
    print(f'Epoch {epoch+1}/{epochs}')
   
    for x_train,y_train in generate_batch_data(x_train,y_train,batch_size):
        y_hat = model(x_train)
        loss = criterion(y_hat,y_train)
        acc = (y_hat.argmax(1) == y).float().mean()

        print(f'loss: {loss},accuracy: {acc}')

这次我犯了这个错误:

RuntimeError                              Traceback (most recent call last)
<ipython-input-50-7ffa61cb7f34> in <module>()
      5 
      6     for x_train,batch_size):
----> 7         y_hat = model(x_train)
      8         loss = criterion(y_hat,y_train)
      9         acc = (y_hat.argmax(1) == y).float().mean()

4 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in embedding(input,weight,padding_idx,max_norm,norm_type,scale_grad_by_freq,sparse)
   1850         # remove once script supports set_grad_enabled
   1851         _no_grad_embedding_renorm_(weight,input,norm_type)
-> 1852     return torch.embedding(weight,sparse)
   1853 
   1854 

RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.FloatTensor instead (while checking arguments for embedding)  

我在这里添加了torch.LongTensor

def tokenize_and_pad_text(df,max_seq):
  tokenized_text = tokenize_text(df,max_seq)
  padded_text = pad_text(tokenized_text,max_seq)
  return torch.LongTensor(padded_text)

train_indices = tokenize_and_pad_text(df_train,max_seq)

x_train = bert_model(train_indices)[0]

并在此处更改:

y_hat = model(x_train.long())

但是这次它给出了这个错误:

IndexError                                Traceback (most recent call last)
<ipython-input-67-9ad38a8c062a> in <module>()
      5 
      6     for x_train,batch_size):
----> 7         y_hat = model(x_train.long())
      8         loss = criterion(y_hat,sparse)
   1853 
   1854 

IndexError: index out of range in self

模型是这样的:

import torch.nn as nn

model = nn.Sequential(
    nn.Embedding(num_embeddings=148,embedding_dim=768),nn.Flatten(),nn.Linear((768*148),148),nn.ReLU(),nn.Linear(148,3))

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters())

通常像这样使用 dataframe context column

def tokenize_text(df,max_seq):
  return[
         tokenizer.encode(text,add_special_tokens=True)[:max_seq] for text in df.context.values
  ]

def pad_text(tokenized_text,max_seq):
  return np.array([el + [0] * (max_seq - len(el)) for el in tokenized_text])

和评论后:我认为len(df.context)148并将num_embeddings更新为148。我仍然不断得到the same error

非常感谢。

解决方法

这是带有损失函数和优化器的模型定义:

import torch.nn as nn

model = nn.Sequential(
    nn.Embedding(num_embeddings=vocab_size,embedding_dim=50),nn.Flatten(),nn.Linear(50*len(X[0]),100),nn.ReLU(),nn.Linear(100,3))

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters())

这是一个训练循环的粗略轮廓:

epochs = 10
for epoch in range(epochs):
    print(f'Epoch {epoch+1}/{epochs}')

    for x,y in data:
        y_hat = model(x)
        loss = criterion(y_hat,y)
        acc = (y_hat.argmax(1) == y).float().mean()

        print(f'loss: {loss},accuracy: {acc}')

假设 data 包含 (x,y) 个训练点,并且 y 有一个包含真实类索引的维度。


其他评论:

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams[&#39;font.sans-serif&#39;] = [&#39;SimHei&#39;] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -&gt; systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping(&quot;/hires&quot;) public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate&lt;String
使用vite构建项目报错 C:\Users\ychen\work&gt;npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-
参考1 参考2 解决方案 # 点击安装源 协议选择 http:// 路径填写 mirrors.aliyun.com/centos/8.3.2011/BaseOS/x86_64/os URL类型 软件库URL 其他路径 # 版本 7 mirrors.aliyun.com/centos/7/os/x86
报错1 [root@slave1 data_mocker]# kafka-console-consumer.sh --bootstrap-server slave1:9092 --topic topic_db [2023-12-19 18:31:12,770] WARN [Consumer clie
错误1 # 重写数据 hive (edu)&gt; insert overwrite table dwd_trade_cart_add_inc &gt; select data.id, &gt; data.user_id, &gt; data.course_id, &gt; date_format(
错误1 hive (edu)&gt; insert into huanhuan values(1,&#39;haoge&#39;); Query ID = root_20240110071417_fe1517ad-3607-41f4-bdcf-d00b98ac443e Total jobs = 1
报错1:执行到如下就不执行了,没有显示Successfully registered new MBean. [root@slave1 bin]# /usr/local/software/flume-1.9.0/bin/flume-ng agent -n a1 -c /usr/local/softwa
虚拟及没有启动任何服务器查看jps会显示jps,如果没有显示任何东西 [root@slave2 ~]# jps 9647 Jps 解决方案 # 进入/tmp查看 [root@slave1 dfs]# cd /tmp [root@slave1 tmp]# ll 总用量 48 drwxr-xr-x. 2
报错1 hive&gt; show databases; OK Failed with exception java.io.IOException:java.lang.RuntimeException: Error in configuring object Time taken: 0.474 se
报错1 [root@localhost ~]# vim -bash: vim: 未找到命令 安装vim yum -y install vim* # 查看是否安装成功 [root@hadoop01 hadoop]# rpm -qa |grep vim vim-X11-7.4.629-8.el7_9.x
修改hadoop配置 vi /usr/local/software/hadoop-2.9.2/etc/hadoop/yarn-site.xml # 添加如下 &lt;configuration&gt; &lt;property&gt; &lt;name&gt;yarn.nodemanager.res