tensorflow线性回归的pytorch等效项是什么?

如何解决tensorflow线性回归的pytorch等效项是什么?

我正在学习pytorch,以对此处创建的这种数据进行基本的线性回归:

from sklearn.datasets import make_regression

x,y = make_regression(n_samples=100,n_features=1,noise=15,random_state=42)
y = y.reshape(-1,1)
print(x.shape,y.shape)

plt.scatter(x,y)

我知道使用te​​nsorflow这段代码可以解决:

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(units=1,activation='linear',input_shape=(x.shape[1],)))

model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.05),loss='mse')

hist = model.fit(x,y,epochs=15,verbose=0)

但是我需要知道pytorch的等效形式是什么,我试图做的是这样:

# Model Class
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.linear = nn.Linear(1,1)
        
    def forward(self,x):
        x = self.linear(x)
        return x
    
    def predict(self,x):
        return self.forward(x)
    
model = Net()

loss_fn = F.mse_loss
opt = torch.optim.SGD(modelo.parameters(),lr=0.05)

# Funcao para treinar
def fit(num_epochs,model,loss_fn,opt,train_dl):
    
    
    # Repeat for given number of epochs
    for epoch in range(num_epochs):
        
        # Train with batches of data
        for xb,yb in train_dl:
            
            # 1. Generate predictions
            pred = model(xb)
            
            # 2. Calculate Loss
            loss = loss_fn(pred,yb)
            
            # 3. Campute gradients
            loss.backward()
            
            # 4. Update parameters using gradients
            opt.step()
            
            # 5. Reset the gradients to zero
            opt.zero_grad()
            
        # Print the progress
        if (epoch+1) % 10 == 0:
            print('Epoch [{}/{}],Loss: {:.4f}'.format(epoch+1,num_epochs,loss.item()))

# Training
fit(200,data_loader)

但是模型没有学到任何东西,我不知道该怎么办。

输入/输出尺寸为(1/1)

解决方法

数据集

首先,您应该定义torch.utils.data.Dataset

import torch
from sklearn.datasets import make_regression


class RegressionDataset(torch.utils.data.Dataset):
    def __init__(self):
        data = make_regression(n_samples=100,n_features=1,noise=0.1,random_state=42)
        self.x = torch.from_numpy(data[0]).float()
        self.y = torch.from_numpy(data[1]).float()

    def __len__(self):
        return len(self.x)

    def __getitem__(self,index):
        return self.x[index],self.y[index]

它将numpy数据转换为tensor内的PyTorch的{​​{1}}并将数据转换为__init__float默认具有numpy,而PyTorch的默认是double,以便使用更少的内存。

除此之外,它仅会返回特征的float和相应的回归目标。

健身

几乎在那里,但是您必须将模型的输出展平(如下所述)。 tuple将返回形状为torch.nn.Linear的张量,而您的目标形状为(batch,1)(batch,)将删除不必要的flatten()尺寸。

1

型号

这就是您实际需要的:

# 2. Calculate Loss
loss = criterion(pred.flatten(),yb)

任何层都可以直接调用,不需要model = torch.nn.Linear(1,1) 和简单模型的继承。

呼叫

剩下的几乎可以了,您只需要创建torch.utils.data.DataLoader并传递数据集的实例即可。 forward的作用是多次发行DataLoader的{​​{1}}并创建一批指定大小的文件(还有其他有趣的事情,但这就是想法):

__getitem__

还请注意,我使用了dataset,因为在这种情况下,我们传递对象看起来比函数要好。

整个代码

为了使其更容易:

dataset = RegressionDataset()
dataloader = torch.utils.data.DataLoader(dataset,batch_size=32)
model = torch.nn.Linear(1,1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),lr=3e-4)

fit(5000,model,criterion,optimizer,dataloader)

您应该避免torch.nn.MSELoss()的损失,为更困难/更轻松的回归任务改变import torch from sklearn.datasets import make_regression class RegressionDataset(torch.utils.data.Dataset): def __init__(self): data = make_regression(n_samples=100,self.y[index] # Funcao para treinar def fit(num_epochs,train_dl): # Repeat for given number of epochs for epoch in range(num_epochs): # Train with batches of data for xb,yb in train_dl: # 1. Generate predictions pred = model(xb) # 2. Calculate Loss loss = criterion(pred.flatten(),yb) # 3. Compute gradients loss.backward() # 4. Update parameters using gradients optimizer.step() # 5. Reset the gradients to zero optimizer.zero_grad() # Print the progress if (epoch + 1) % 10 == 0: print( "Epoch [{}/{}],Loss: {:.4f}".format(epoch + 1,num_epochs,loss.item()) ) dataset = RegressionDataset() dataloader = torch.utils.data.DataLoader(dataset,dataloader) 或其他参数。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams['font.sans-serif'] = ['SimHei'] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -> systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping("/hires") public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate<String
使用vite构建项目报错 C:\Users\ychen\work>npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-
参考1 参考2 解决方案 # 点击安装源 协议选择 http:// 路径填写 mirrors.aliyun.com/centos/8.3.2011/BaseOS/x86_64/os URL类型 软件库URL 其他路径 # 版本 7 mirrors.aliyun.com/centos/7/os/x86
报错1 [root@slave1 data_mocker]# kafka-console-consumer.sh --bootstrap-server slave1:9092 --topic topic_db [2023-12-19 18:31:12,770] WARN [Consumer clie
错误1 # 重写数据 hive (edu)> insert overwrite table dwd_trade_cart_add_inc > select data.id, > data.user_id, > data.course_id, > date_format(
错误1 hive (edu)> insert into huanhuan values(1,'haoge'); Query ID = root_20240110071417_fe1517ad-3607-41f4-bdcf-d00b98ac443e Total jobs = 1
报错1:执行到如下就不执行了,没有显示Successfully registered new MBean. [root@slave1 bin]# /usr/local/software/flume-1.9.0/bin/flume-ng agent -n a1 -c /usr/local/softwa
虚拟及没有启动任何服务器查看jps会显示jps,如果没有显示任何东西 [root@slave2 ~]# jps 9647 Jps 解决方案 # 进入/tmp查看 [root@slave1 dfs]# cd /tmp [root@slave1 tmp]# ll 总用量 48 drwxr-xr-x. 2
报错1 hive> show databases; OK Failed with exception java.io.IOException:java.lang.RuntimeException: Error in configuring object Time taken: 0.474 se
报错1 [root@localhost ~]# vim -bash: vim: 未找到命令 安装vim yum -y install vim* # 查看是否安装成功 [root@hadoop01 hadoop]# rpm -qa |grep vim vim-X11-7.4.629-8.el7_9.x
修改hadoop配置 vi /usr/local/software/hadoop-2.9.2/etc/hadoop/yarn-site.xml # 添加如下 <configuration> <property> <name>yarn.nodemanager.res