如何完成 WGAN 网络的训练?
我使用 DCGAN 合成医学图像(512*512)。但是,目前 DCGAN 太不稳定了。因此,我正在尝试将我的 DCGAN 网络更改为 WGAN。
How to increase image_size in DCGAN
数据和参数
# Hyper-parameters shared by the data pipeline, the two networks and the
# training loop.

# Root directory for dataset.
# NOTE(review): `grade` is not defined in this section; it must be assigned
# before this line runs.
daTaroot = f"./processed/{grade}/{grade}/"
# Number of workers for DataLoader
workers = 4
# Batch size during training
batch_size = 128
# Spatial size of training images. All images will be resized to this
# size using a transformer. The generator/discriminator layer stacks below
# are annotated for 512 x 512 images.
image_size = 512
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Base number of feature maps in the generator (layers use multiples of it)
ngf = 16
# Base number of feature maps in the discriminator (layers use multiples of it)
ndf = 16
# Number of training epochs.
# NOTE: the training-loop section reassigns num_epochs to 1000 before looping.
num_epochs = 500
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparam for Adam optimizers (the optimizers created below are
# RMSprop, which does not take this value)
beta1 = 0.5
# Number of GPUs available. Use 0 for cpu mode.
ngpu = 2
# WGAN weight-clipping bound: the training loop clamps the discriminator
# parameters with this value (this is weight clipping, not gradient clipping).
clamp_num=0.01
我改变了 weight_init()
def weight_init(m):
    """Initialise one module's parameters in place (DCGAN/WGAN scheme).

    Intended for ``net.apply(weight_init)``, which calls it on every
    sub-module.

    * Modules whose class name contains ``"Conv"`` (``Conv2d``,
      ``ConvTranspose2d``, ...) get weights drawn from N(0, 0.02).
    * Modules whose class name contains ``"BatchNorm"`` get weights drawn
      from N(1, 0.02) and their bias reset to 0.
    * Every other module, and any module without a ``weight`` tensor
      (e.g. ``BatchNorm2d(affine=False)``), is left untouched.

    Args:
        m: the module to initialise.
    """
    class_name = type(m).__name__
    weight = getattr(m, "weight", None)
    if weight is None:
        # Nothing to initialise (activation layers, containers, affine=False).
        return
    if "Conv" in class_name:
        weight.data.normal_(0.0, 0.02)
    elif "BatchNorm" in class_name:
        # The previous lowercase 'norm' test never matched 'BatchNorm2d',
        # so batch-norm layers silently kept their default initialisation.
        weight.data.normal_(1.0, 0.02)
        bias = getattr(m, "bias", None)
        if bias is not None:
            bias.data.fill_(0.0)
变更生成器
class Generator(nn.Module):
    """DCGAN-style generator producing 512 x 512 images from latent vectors.

    The network is one 4x4 transposed convolution that turns the latent
    vector into a 4 x 4 feature map, followed by seven stride-2 transposed
    convolutions that each double the spatial size (4 -> 8 -> ... -> 512).
    The output goes through ``Tanh`` and therefore lies in [-1, 1].

    Args:
        ngpu: number of GPUs; stored on the instance as ``self.ngpu``.
        latent_dim: size of the latent vector. Defaults to the module-level
            ``nz``.
        feature_maps: base channel count. Defaults to the module-level
            ``ngf``.
        out_channels: channels of the generated image. Defaults to the
            module-level ``nc``.
    """

    def __init__(self, ngpu, latent_dim=None, feature_maps=None, out_channels=None):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        # Fall back to the file's global hyper-parameters unless overridden.
        latent_dim = nz if latent_dim is None else latent_dim
        feature_maps = ngf if feature_maps is None else feature_maps
        out_channels = nc if out_channels is None else out_channels
        fm = feature_maps
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(latent_dim, fm * 64, 4, 1, 0, bias=False),
            nn.BatchNorm2d(fm * 64),
            nn.ReLU(True),
            # state size. (fm*64) x 4 x 4
            nn.ConvTranspose2d(fm * 64, fm * 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 32),
            nn.ReLU(True),
            # state size. (fm*32) x 8 x 8
            nn.ConvTranspose2d(fm * 32, fm * 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 16),
            nn.ReLU(True),
            # state size. (fm*16) x 16 x 16
            nn.ConvTranspose2d(fm * 16, fm * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 8),
            nn.ReLU(True),
            # state size. (fm*8) x 32 x 32
            nn.ConvTranspose2d(fm * 8, fm * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 4),
            nn.ReLU(True),
            # state size. (fm*4) x 64 x 64
            nn.ConvTranspose2d(fm * 4, fm * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 2),
            nn.ReLU(True),
            # state size. (fm*2) x 128 x 128
            nn.ConvTranspose2d(fm * 2, fm, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm),
            nn.ReLU(True),
            # state size. (fm) x 256 x 256
            nn.ConvTranspose2d(fm, out_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (out_channels) x 512 x 512
        )

    def forward(self, x):
        """Generate images from latent vectors.

        Args:
            x: latent batch of shape ``(N, latent_dim, 1, 1)``; a flat
                ``(N, latent_dim)`` batch is also accepted and reshaped.

        Returns:
            Tensor of shape ``(N, out_channels, 512, 512)`` in [-1, 1].
        """
        if x.dim() == 2:
            # ConvTranspose2d needs explicit 1 x 1 spatial dimensions.
            x = x.view(x.size(0), -1, 1, 1)
        return self.main(x)
和鉴别器
class discriminator(nn.Module):
    """WGAN critic for 512 x 512 images.

    Seven stride-2 4x4 convolutions halve the spatial size each time
    (512 -> 256 -> ... -> 4), then a final 4x4 convolution reduces the
    4 x 4 map to a single unbounded score per image. There is deliberately
    no ``Sigmoid`` at the end: a WGAN critic outputs a real-valued score,
    not a probability.

    Args:
        ngpu: number of GPUs; stored on the instance as ``self.ngpu``.
        in_channels: channels of the input image. Defaults to the
            module-level ``nc``.
        feature_maps: base channel count. Defaults to the module-level
            ``ndf``.
    """

    def __init__(self, ngpu, in_channels=None, feature_maps=None):
        super(discriminator, self).__init__()
        self.ngpu = ngpu
        # Fall back to the file's global hyper-parameters unless overridden.
        in_channels = nc if in_channels is None else in_channels
        feature_maps = ndf if feature_maps is None else feature_maps
        fm = feature_maps
        self.main = nn.Sequential(
            # input is (in_channels) x 512 x 512
            nn.Conv2d(in_channels, fm, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (fm) x 256 x 256
            nn.Conv2d(fm, fm * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (fm*2) x 128 x 128
            nn.Conv2d(fm * 2, fm * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (fm*4) x 64 x 64
            nn.Conv2d(fm * 4, fm * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (fm*8) x 32 x 32
            nn.Conv2d(fm * 8, fm * 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 16),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (fm*16) x 16 x 16
            nn.Conv2d(fm * 16, fm * 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 32),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (fm*32) x 8 x 8
            nn.Conv2d(fm * 32, fm * 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 64),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (fm*64) x 4 x 4
            nn.Conv2d(fm * 64, 1, 4, 1, 0, bias=False),
            # Modification 1: remove sigmoid -- the critic score is unbounded.
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: image batch of shape ``(N, in_channels, 512, 512)``.

        Returns:
            Tensor of shape ``(N, 1, 1, 1)`` with one raw score per image.
        """
        return self.main(x)
另外,更改优化器
from torch.optim import RMSprop
# modification 2: Use RMSprop instead of Adam
# NOTE(review): the WGAN paper trains with RMSprop at lr = 5e-5; the value of
# `lr` configured in this file is larger -- confirm it is intended.
optimizerD = RMSprop(netD.parameters(), lr=lr)
optimizerG = RMSprop(netG.parameters(), lr=lr)
# modification3: No Log in loss -- WGAN optimises raw critic scores directly,
# so no BCE criterion object is created.
# Create batch of latent vectors that we will use to visualize
# the progression of the generator. The generator's first layer is a
# ConvTranspose2d, so the latent batch needs explicit 1 x 1 spatial dims.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# Real/fake label convention inherited from the DCGAN version; a WGAN loss
# does not need labels, the names are kept for code that still refers to them.
real_label = 1
fake_label = 0
最后,训练代码如下。我猜问题出在训练代码上。(另外,打印部分的内容我没有做任何改动。)希望有人能帮我看看应该如何修改训练代码,使其能够按 WGAN 的方式正确运行。
# Training Loop (WGAN with weight clipping)
#
# Sign convention: the critic netD is trained to score real images higher
# than generated ones, i.e. it minimises  D(G(z)) - D(x); the generator netG
# minimises  -D(G(z)).  Both losses are scalar batch means, so plain
# .backward() works (calling backward on the un-reduced score vector fails).

# Number of critic updates per generator update (5 in the WGAN paper).
# Set to 1 to update the generator on every batch.
n_critic = 5
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
num_epochs = 1000
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the DataLoader
    # NOTE(review): `DataLoader` has to be the data-loader *instance* built
    # for the dataset; if this name is the torch.utils.data.DataLoader class
    # itself, iterating it raises TypeError -- confirm where it is defined.
    for i, data in enumerate(DataLoader, 0):
        ############################
        # (1) Update D network (critic): maximize D(x) - D(G(z))
        ###########################
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        # Generate batch of latent vectors and a fake image batch with G
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        # Critic scores; the fake batch is detached so this step does not
        # back-propagate into the generator.
        score_real = netD(real_cpu).view(-1).mean()
        score_fake = netD(fake.detach()).view(-1).mean()
        errD = score_fake - score_real
        errD.backward()
        D_x = score_real.item()
        D_G_z1 = score_fake.item()
        # Update D
        optimizerD.step()
        # Weight clipping keeps the critic approximately Lipschitz. The range
        # must be symmetric: clamp_(c, c) would set every weight to exactly c.
        for parm in netD.parameters():
            parm.data.clamp_(-clamp_num, clamp_num)
        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        # `iters` is 0 on the very first batch, so errG / D_G_z2 are always
        # defined before the logging code below reads them.
        if iters % n_critic == 0:
            netG.zero_grad()
            # Since we just updated D, perform another forward pass of the
            # fake batch through D -- NOT detached, otherwise no gradient
            # would reach the generator.
            errG = -netD(fake).view(-1).mean()
            errG.backward()
            D_G_z2 = -errG.item()
            # Update G
            optimizerG.step()
        # Output training stats
        if i % 1000 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch,num_epochs,i,len(DataLoader),errD.item(),errG.item(),D_x,D_G_z1,D_G_z2))
        # Save Losses for plotting later (between generator steps the most
        # recent generator loss is repeated)
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(DataLoader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(utils.make_grid(fake,padding=0,normalize=True))
        iters += 1
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。