微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何计算cifar10数据的均值和标准差

如何解决如何计算cifar10数据的均值和标准差

Pytorch 使用以下值作为 cifar10 数据的均值和标准值: transforms.normalize((0.5,0.5,0.5),(0.5,0.5))

我需要理解计算背后的概念,因为这个数据是 3 通道图像,我不明白什么是相加和除以什么等等。 另外,如果有人可以分享计算平均值和标准的代码,那就太感谢了。

解决方法

0.5 值只是三个通道 (r,g,b) 上 cifar10 平均值和标准值的近似值。 cifar10 训练集的精确值是

  • 意思是:0.49139968,0.48215827,0.44653124
  • 标准:0.24703233 0.24348505 0.26158768

您可以使用以下脚本计算这些:

import torch
import numpy
import torchvision.datasets as datasets
from torchvision import transforms

cifar_trainset = datasets.CIFAR10(root='./data',train=True,download=True,transform=transforms.ToTensor())

imgs = [item[0] for item in cifar_trainset] # item[0] and item[1] are image and its label
imgs = torch.stack(imgs,dim=0).numpy()

# calculate mean over each channel (r,b)
mean_r = imgs[:,:,:].mean()
mean_g = imgs[:,1,:].mean()
mean_b = imgs[:,2,:].mean()
print(mean_r,mean_g,mean_b)

# calculate std over each channel (r,b)
std_r = imgs[:,:].std()
std_g = imgs[:,:].std()
std_b = imgs[:,:].std()
print(std_r,std_g,std_b)

此外,您可能会发现相同的均值和标准值 herehere

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。