
Why does my unsupervised domain adaptation PyTorch model predict within such a small range?


I created a PyTorch model for unsupervised domain adaptation on WiFi RSS data. It uses these data to predict the device's location, so the model predicts continuous values.

The predictions span a range of only about 4 units, while the actual values go above 100. Small values such as 11 or 14 are predicted correctly, but larger values only get predicted up to about 20 to 25, even when the true value is close to 100.

The algorithm keeps trying to predict them, but it makes no progress; I think it is stuck in some local optimum or something similar. How can I correct this? I need predictions over the whole range of y values. When I divide the y values by 10, the loss goes down (but I cannot use those predictions, because they are on the wrong scale!). So I suspect the problem is related to the range of the y values.
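For reference, the divide-by-10 experiment looked roughly like this (a minimal sketch, not my exact code; pred, y_true and y_scale are illustrative names):

# sketch of the divide-by-10 experiment (illustrative names)
y_scale = 10.0

# training against scaled targets makes the L1 loss drop
loss = torch.nn.L1Loss()(pred.float(), (y_true / y_scale).float())

# but to use the outputs as locations they would have to be
# multiplied back into the original units
pred_in_original_units = pred.detach() * y_scale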

Here is my model:

import torch
import torch.nn as nn

class DANN(nn.Module):
  def __init__(self):
    super().__init__()

    self.num_features = num_features

    # shared feature extractor
    self.feature_extractor = nn.Sequential(
        nn.Linear(self.num_features, 100),
        nn.ReLU(True),
        nn.Linear(100, 100),
        nn.Dropout(),
    )

    # regression head: predicts the 2D location (continuous values)
    self.label_predictor = nn.Sequential(
        nn.Linear(100, 2),
    )

    # domain classifier: source (0) vs. target (1)
    self.domain_classifier = nn.Sequential(
        nn.Linear(100, 100),
        nn.BatchNorm1d(100),
        nn.Linear(100, 2),
        nn.LogSoftmax(dim=1),
    )

  def forward(self, x, grl_lambda=1.0):
    x = x.expand(x.shape[0], self.num_features)

    features = self.feature_extractor(x)
    # reverse gradients flowing back from the domain classifier
    features_grl = GradientReversalFn.apply(features, grl_lambda)
    label_pred = self.label_predictor(features)
    domain_pred = self.domain_classifier(features_grl)

    return label_pred, domain_pred
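GradientReversalFn is the usual DANN gradient reversal layer; mine follows the standard implementation (shown here as a sketch so the model code above is self-contained):

class GradientReversalFn(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x, grl_lambda):
    # identity in the forward pass; remember the scaling factor
    ctx.grl_lambda = grl_lambda
    return x.view_as(x)

  @staticmethod
  def backward(ctx, grad_output):
    # flip and scale the gradient on the way back
    return -ctx.grl_lambda * grad_output, None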

I trained it as follows:

import numpy as np
import torch.optim as optim

# initializing parameters before training the model
lr = 1e-3
optimizer = optim.Adam(model.parameters(), lr)

loss_fn_label = torch.nn.L1Loss()    # regression loss for the location
loss_fn_domain = torch.nn.NLLLoss()  # domain classification loss

max_batches = min(len(dx_source), len(dx_target))

# training loop
for epoch_idx in range(num_epochs):
  print(f'Epoch {epoch_idx+1:03d}/{num_epochs:03d}', end='\t')

  dx_source_iter = iter(dx_source)
  dx_target_iter = iter(dx_target)

  dy_source_iter = iter(dy_source)

  for batch_idx in range(max_batches - 1):
    optimizer.zero_grad()  # clear all the gradients before calculating them

    # ramp grl_lambda from 0 towards 1 over the course of training
    p = float(batch_idx + epoch_idx * max_batches) / (num_epochs * max_batches)
    grl_lambda = 2. / (1. + np.exp(-10 * p)) - 1

    # source training: label loss + domain loss (domain label 0)
    x_s = next(dx_source_iter)
    y_s = next(dy_source_iter)
    y_s_domain = torch.zeros(num_batch, dtype=torch.long)

    label_pred, domain_pred = model(x_s.float(), grl_lambda)

    loss_s_label = loss_fn_label(label_pred.float(), y_s.float())
    loss_s_domain = loss_fn_domain(domain_pred.float(), y_s_domain)

    # target training: domain loss only (domain label 1)
    x_t = next(dx_target_iter)
    y_t_domain = torch.ones(num_batch, dtype=torch.long)

    _, domain_pred = model(x_t.float(), grl_lambda)

    loss_t_domain = loss_fn_domain(domain_pred.float(), y_t_domain)

    # optimization: sum all three losses and update
    loss = loss_t_domain + loss_s_domain + loss_s_label
    loss.backward()
    optimizer.step()

  print(f's_label_loss: {loss_s_label.item():.4f}''\t'
     f's_domain_loss: {loss_s_domain.item():.4f}''\t'
     f't_domain_loss: {loss_t_domain.item():.4f}''\t'
     f'grl_lambda: {grl_lambda:.3f}')
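Afterwards I look at the predictions like this (a sketch; x_test stands in for a batch of test inputs):

model.eval()
with torch.no_grad():
  label_pred, _ = model(x_test.float())
# the spread between min and max here is only about 4 units
print(label_pred.min(), label_pred.max())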

Here is the output of the model (s stands for source, t for target). Note that s_label_loss oscillates between roughly 27 and 40 without a clear downward trend, while both domain losses hover around ln 2 ≈ 0.693:

Epoch 001/020   s_label_loss: 37.9219   s_domain_loss: 0.7102   t_domain_loss: 0.7024   grl_lambda: 0.243
Epoch 002/020   s_label_loss: 31.6202   s_domain_loss: 0.6835   t_domain_loss: 0.7082   grl_lambda: 0.460
Epoch 003/020   s_label_loss: 33.5835   s_domain_loss: 0.6897   t_domain_loss: 0.6989   grl_lambda: 0.634
Epoch 004/020   s_label_loss: 28.1080   s_domain_loss: 0.6915   t_domain_loss: 0.7021   grl_lambda: 0.761
Epoch 005/020   s_label_loss: 30.8963   s_domain_loss: 0.7108   t_domain_loss: 0.6758   grl_lambda: 0.848
Epoch 006/020   s_label_loss: 40.5831   s_domain_loss: 0.6694   t_domain_loss: 0.7219   grl_lambda: 0.905
Epoch 007/020   s_label_loss: 36.4664   s_domain_loss: 0.6686   t_domain_loss: 0.7212   grl_lambda: 0.941
Epoch 008/020   s_label_loss: 27.1264   s_domain_loss: 0.6808   t_domain_loss: 0.7044   grl_lambda: 0.964
Epoch 009/020   s_label_loss: 31.7887   s_domain_loss: 0.7417   t_domain_loss: 0.6478   grl_lambda: 0.978
Epoch 010/020   s_label_loss: 34.0161   s_domain_loss: 0.6921   t_domain_loss: 0.6969   grl_lambda: 0.987
Epoch 011/020   s_label_loss: 31.1291   s_domain_loss: 0.6969   t_domain_loss: 0.6956   grl_lambda: 0.992
Epoch 012/020   s_label_loss: 40.8664   s_domain_loss: 0.6695   t_domain_loss: 0.7174   grl_lambda: 0.995
Epoch 013/020   s_label_loss: 35.8289   s_domain_loss: 0.6661   t_domain_loss: 0.7128   grl_lambda: 0.997
Epoch 014/020   s_label_loss: 39.4238   s_domain_loss: 0.6907   t_domain_loss: 0.6946   grl_lambda: 0.998
Epoch 015/020   s_label_loss: 35.6348   s_domain_loss: 0.6934   t_domain_loss: 0.6925   grl_lambda: 0.999
Epoch 016/020   s_label_loss: 35.0068   s_domain_loss: 0.6925   t_domain_loss: 0.6802   grl_lambda: 0.999
Epoch 017/020   s_label_loss: 36.3469   s_domain_loss: 0.7362   t_domain_loss: 0.6506   grl_lambda: 1.000
Epoch 018/020   s_label_loss: 37.8324   s_domain_loss: 0.6865   t_domain_loss: 0.6997   grl_lambda: 1.000
Epoch 019/020   s_label_loss: 34.9296   s_domain_loss: 0.6867   t_domain_loss: 0.6951   grl_lambda: 1.000
Epoch 020/020   s_label_loss: 35.0239   s_domain_loss: 0.6814   t_domain_loss: 0.7164   grl_lambda: 1.000

Thanks in advance for your time and help! :)
