
PyTorch BatchNorm2d Calculation

I am trying to understand the mechanics of PyTorch's BatchNorm2d by reproducing its computation by hand. My example code:

import torch
from torch import nn

torch.manual_seed(123)

a = torch.rand(3, 2, 3, 3)
print(a)

print(nn.BatchNorm2d(2)(a))

# my attempt: statistics across the batch dimension only
mean_by_plane_feature = torch.mean(a, dim=0)
std_by_plane_feature = torch.std(a, dim=0)
print(mean_by_plane_feature)
print(std_by_plane_feature)

Output:

tensor([[[[0.2961, 0.5166, 0.2517], [0.6886, 0.0740, 0.8665], [0.1366, 0.1025, 0.1841]], [[0.7264, 0.3153, 0.6871], [0.0756, 0.1966, 0.3164], [0.4017, 0.1186, 0.8274]]], [[[0.3821, 0.6605, 0.8536], [0.5932, 0.6367, 0.9826], [0.2745, 0.6584, 0.2775]], [[0.8573, 0.8993, 0.0390], [0.9268, 0.7388, 0.7179], [0.7058, 0.9156, 0.4340]]], [[[0.0772, 0.3565, 0.1479], [0.5331, 0.4066, 0.2318], [0.4545, 0.9737, 0.4606]], [[0.5159, 0.4220, 0.5786], [0.9455, 0.8057, 0.6775], [0.6087, 0.6179, 0.6932]]]])
tensor([[[[-0.5621, 0.2574, -0.7273], [0.8968, -1.3879, 1.5584], [-1.1552, -1.2819, -0.9787]], [[0.5369, -1.0117, 0.3888], [-1.9141, -1.4584, -1.0073], [-0.6859, -1.7524, 0.9171]]], [[[-0.2425, 0.7925, 1.5103], [0.5422, 0.7042, 1.9901], [-0.6425, 0.7846, -0.6311]], [[1.0298, 1.1880, -2.0520], [1.2915, 0.5833, 0.5047], [0.4593, 1.2495, -0.5645]]], [[[-1.3761, -0.3375, -1.1132], [0.3187, -0.1512, -0.8011], [0.0269, 1.9569, 0.0493]], [[-0.2561, -0.6096, -0.0199], [1.3619, 0.8356, 0.3525], [0.0933, 0.1281, 0.4116]]]], grad_fn=<NativeBatchNormBackward>)
tensor([[[0.2518, 0.5112, 0.4177], [0.6049, 0.3724, 0.6937], [0.2885, 0.5782, 0.3074]], [[0.6999, 0.5455, 0.4349], [0.6493, 0.5804, 0.5706], [0.5721, 0.5507, 0.6515]]])
tensor([[[0.1572, 0.1521, 0.3810], [0.0784, 0.2829, 0.4042], [0.1594, 0.4411, 0.1406]], [[0.1723, 0.3110, 0.3471], [0.4969, 0.3340, 0.2211], [0.1553, 0.4028, 0.2000]]])

I found that the BatchNorm output is not what I expected. For example, the cross-batch mean at the first position of the first feature plane is 0.2518, and the std is 0.1572. The normalized value of the first element would then be (0.2961 - 0.2518) / 0.1572 = 0.2818, not -0.5621.

My questions:

1. Am I computing the mean correctly this way (across the batch, per plane position and per feature)? As I understand it, batchnorm is meant to handle the problem of different features having different scales, so it should at least be per feature dimension, but I am not sure whether the "plane dimensions" should also be aggregated.

2. Do I need any other modification to get the same output as BatchNorm2d?

Solution

Here is the implementation of BatchNorm2d in PyTorch (source1, source2). You can use it to verify the operations you performed.

class MyBatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MyBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            # per-channel statistics over the batch and both spatial dims
            mean = input.mean([0, 2, 3])
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean \
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1) \
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        # broadcast the per-channel statistics over (N, C, H, W)
        input = (input - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

        return input

The output of nn.BatchNorm2d(2)(a) and MyBatchNorm2d(2)(a) is identical.
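This answers both questions directly: in training mode, BatchNorm2d computes one mean and one biased variance per feature channel, aggregated over the batch and both spatial dimensions (dims 0, 2 and 3). Below is a minimal sketch of that check, reusing the tensor a from the question; the atol passed to torch.allclose is an arbitrary tolerance, and the 1e-5 matches BatchNorm2d's default eps.

# Per-channel statistics over batch and spatial dims (N, H, W) = dims (0, 2, 3),
# using *biased* variance -- not the unbiased estimate torch.std uses by default.
mean = a.mean(dim=(0, 2, 3), keepdim=True)                 # shape (1, 2, 1, 1)
var = a.var(dim=(0, 2, 3), unbiased=False, keepdim=True)   # biased variance
manual = (a - mean) / torch.sqrt(var + 1e-5)               # 1e-5 = default eps

print(torch.allclose(manual, nn.BatchNorm2d(2)(a), atol=1e-6))  # True

This also pinpoints the mismatch in the question: torch.mean(a, dim=0) and torch.std(a, dim=0) keep separate statistics for every spatial position and use the unbiased std, so they cannot reproduce BatchNorm2d's output.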
