如何解决为什么我的无监督域适应 PyTorch 模型预测的结果范围较小?
我为 WiFi RSS 数据中的无监督域适应创建了一个 PyTorch 模型。它需要使用这些数据来预测设备位置。该模型用于预测连续值。
但在这里,预测的范围只有 4 个单位。但是,实际值在 100 以上的范围内。像 11、14 等这样的小值是正确预测的。但是,较大的值预测范围高达 20 到 25,即使该值接近 100。
算法正在尝试预测它。但没有前进。我认为它陷入了一些局部最优或其他什么。如何纠正这个问题? 我需要在任何 y 值范围内进行预测。当我将 y 值除以 10 时,损失会下降(但是,我不能使用这些预测,因为它们是不正确的!)。所以,我认为这与 y 值的范围有关。
看看我的模型,
class DANN(nn.Module):
def __init__(self):
super().__init__()
self.num_features = num_features
self.feature_extractor = nn.Sequential(
nn.Linear(self.num_features,100),nn.ReLU(True),nn.Linear(100,nn.Dropout(),)
self.label_predictor = nn.Sequential(
nn.Linear(100,2),)
self.domain_classifier = nn.Sequential(
nn.Linear(100,nn.Batchnorm1d(100),nn.Logsoftmax(dim=1)
)
def forward(self,x,grl_lambda = 1.0):
x = x.expand(x.data.shape[0],num_features)
features = self.feature_extractor(x)
features_grl = GradientReversalFn.apply(features,grl_lambda)
label_pred = self.label_predictor(features)
domain_pred = self.domain_classifier(features_grl)
return label_pred,domain_pred
我已经训练过了,
#initializing parameters before training the model
lr = 1e-3
optimizer = optim.Adam(model.parameters(),lr)
loss_fn_label = torch.nn.L1Loss()
loss_fn_domain = torch.nn.NLLLoss()
max_batches = min(len(dx_source),len(dx_target))
#training loop
for epoch_idx in range(num_epochs):
print(f'Epoch {epoch_idx+1:03d}/{num_epochs:03d}',end = '\t')
dx_source_iter = iter(dx_source)
dx_target_iter = iter(dx_target)
dy_source_iter = iter(dy_source)
for batch_idx in range(max_batches-1):
optimizer.zero_grad() # clear all the gradients before calculating them
p = float(batch_idx + epoch_idx*max_batches)/(num_epochs*max_batches)
grl_lambda = 2./(1.+np.exp(-10*p))-1
#source training
x_s = next(dx_source_iter)
y_s = next(dy_source_iter)
y_s_domain = torch.zeros(num_batch,dtype = torch.long)
label_pred,domain_pred = model(x_s.float(),grl_lambda)
loss_s_label = loss_fn_label(label_pred.float(),y_s.float())
loss_s_domain = loss_fn_domain(domain_pred.float(),y_s_domain)
#target training
x_t = next(dx_target_iter)
y_t_domain = torch.ones(num_batch,dtype = torch.long)
_,domain_pred = model(x_t.float(),grl_lambda)
loss_t_domain = loss_fn_domain(domain_pred.float(),y_t_domain)
#optimization
loss = loss_t_domain + loss_s_domain + loss_s_label
loss.backward()
optimizer.step()
print(f's_label_loss: {loss_s_label.item():.4f}''\t'
f's_domain_loss: {loss_s_domain.item():.4f}''\t'
f't_domain_loss: {loss_t_domain.item():.4f}''\t'
f'grl_lambda: {grl_lambda:.3f}')
这是模型的输出。 (s 代表源,t 代表目标):
Epoch 001/020 s_label_loss: 37.9219 s_domain_loss: 0.7102 t_domain_loss: 0.7024 grl_lambda: 0.243
Epoch 002/020 s_label_loss: 31.6202 s_domain_loss: 0.6835 t_domain_loss: 0.7082 grl_lambda: 0.460
Epoch 003/020 s_label_loss: 33.5835 s_domain_loss: 0.6897 t_domain_loss: 0.6989 grl_lambda: 0.634
Epoch 004/020 s_label_loss: 28.1080 s_domain_loss: 0.6915 t_domain_loss: 0.7021 grl_lambda: 0.761
Epoch 005/020 s_label_loss: 30.8963 s_domain_loss: 0.7108 t_domain_loss: 0.6758 grl_lambda: 0.848
Epoch 006/020 s_label_loss: 40.5831 s_domain_loss: 0.6694 t_domain_loss: 0.7219 grl_lambda: 0.905
Epoch 007/020 s_label_loss: 36.4664 s_domain_loss: 0.6686 t_domain_loss: 0.7212 grl_lambda: 0.941
Epoch 008/020 s_label_loss: 27.1264 s_domain_loss: 0.6808 t_domain_loss: 0.7044 grl_lambda: 0.964
Epoch 009/020 s_label_loss: 31.7887 s_domain_loss: 0.7417 t_domain_loss: 0.6478 grl_lambda: 0.978
Epoch 010/020 s_label_loss: 34.0161 s_domain_loss: 0.6921 t_domain_loss: 0.6969 grl_lambda: 0.987
Epoch 011/020 s_label_loss: 31.1291 s_domain_loss: 0.6969 t_domain_loss: 0.6956 grl_lambda: 0.992
Epoch 012/020 s_label_loss: 40.8664 s_domain_loss: 0.6695 t_domain_loss: 0.7174 grl_lambda: 0.995
Epoch 013/020 s_label_loss: 35.8289 s_domain_loss: 0.6661 t_domain_loss: 0.7128 grl_lambda: 0.997
Epoch 014/020 s_label_loss: 39.4238 s_domain_loss: 0.6907 t_domain_loss: 0.6946 grl_lambda: 0.998
Epoch 015/020 s_label_loss: 35.6348 s_domain_loss: 0.6934 t_domain_loss: 0.6925 grl_lambda: 0.999
Epoch 016/020 s_label_loss: 35.0068 s_domain_loss: 0.6925 t_domain_loss: 0.6802 grl_lambda: 0.999
Epoch 017/020 s_label_loss: 36.3469 s_domain_loss: 0.7362 t_domain_loss: 0.6506 grl_lambda: 1.000
Epoch 018/020 s_label_loss: 37.8324 s_domain_loss: 0.6865 t_domain_loss: 0.6997 grl_lambda: 1.000
Epoch 019/020 s_label_loss: 34.9296 s_domain_loss: 0.6867 t_domain_loss: 0.6951 grl_lambda: 1.000
Epoch 020/020 s_label_loss: 35.0239 s_domain_loss: 0.6814 t_domain_loss: 0.7164 grl_lambda: 1.000
提前感谢您的时间和帮助! :)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。