当前位置: 首页 > 科技观察

简单使用PyTorch搭建GAN模型

时间:2023-03-17 19:01:33 科技观察

过去,人们普遍认为生成图像是不可能完成的任务,因为按照传统的机器学习思想,我们没有groundtruth来检验生成的图像是否合格。2014年,Goodfellow等人提出了生成对抗网络(GenerativeAdversarialNetwork,简称GAN),可以让我们完全依靠机器学习来生成极其逼真的图片。GAN的出现震惊了整个人工智能行业,计算机视觉和图像生成领域发生了巨大的变化。本文将带你了解GAN的工作原理,并介绍如何通过PyTorch轻松上手GAN。GAN的原理沿用了传统方法,模型的预测结果可以直接与已有的真实值进行比较。然而,我们很难定义和衡量什么被认为是“正确”生成的图像。古德费洛等人。提出了一个有趣的解决方案:我们可以先训练一个分类工具来自动区分生成的图像和真实图像。这样,我们就可以用这个分类工具来训练一个生成网络,直到它输出的图像完全是假的,连分类工具本身都没有办法判断真伪。按照这种思路,我们有一个GAN:一个生成器和一个鉴别器。生成器负责从给定的数据集生成图像,鉴别器负责区分图像是真还是假。GAN的运行过程如上图所示。Lossfunction在GAN的运行过程中,我们可以发现一个明显的矛盾:生成器和判别器很难同时优化。可以想象,这两个模型有着完全相反的目标:生成器想要尽可能地伪造真实的东西,而鉴别器必须看穿生成器生成的图像。为了说明这一点,我们令D(x)为鉴别器的输出,即x为真实图像的概率,令G(z)为生成器的输出。判别器类似于一个二元分类器,所以它的目标是最大化这个函数的结果:这个函数本质上是一个非负的二元交叉熵损失函数。另一方面,生成器的目标是最小化鉴别器做出正确决定的机会,因此它的目标是最小化上述函数的结果。因此,最终的损失函数将是两个分类器之间的极小极大博弈,表述如下:理论上,博弈的最终结果将是判别器判断成功的概率收敛到0.5。然而在实践中,极小极大博弈往往会导致网络不收敛,因此谨慎调整模型训练的参数非常重要。WhentrainingGAN,weshouldpayspecialattentiontohyperparameterssuchaslearningrate.ArelativelysmalllearningratecanallowGANtohaveamoreuniformoutputevenwhenthereisalotofinputnoise.ComputingenvironmentlibraryThisarticlewillguideyoutobuildtheentireprogram(includingtorchvision)throughPyTorch.Atthesametime,wewilluseMatplotlibtovisualizetheresultsofGANgeneration.以下代码能够导入上述所有库:"""ImportnecessarylibrariestocreateagenerativeadversarialnetworkThecodeismainlydevelopedusingthePyTorchlibrary"""importtimeimporttorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.dataimportDataLoaderfromtorchvisionimportdatasetsfromtorchvision.transformsimporttransformsfrommodelimportdiscriminator,generatorimportnumpyasnpimportmatplotlib.pyplotasplt数据集数据集对于训练GAN来说非常重要,尤其考虑到我们在GAN中处理Itisusuallyunstructureddata(usuallypictures,videos,etc.),andanyclasscanhavedatadistribution.ThisdistributionofdataispreciselythebasisfortheoutputgeneratedbytheGAN.InordertobetterdemonstratetheconstructionprocessofGAN,thisarticlewilltakeyoutousethesimplestMNISTdataset,whichcontains60,000picturesofhandwrittenArabicnumerals.High-qualityunstructureddatasetslikeMNISTcanbefoundonGewuti'spublicdatasetwebsite.Infact,theGewuTitaniumOpenDatasetsplatformcoversmanyhigh-qualitypublicdatasets,andcanalsoimplementdatasethostingandone-stopsearchfunctions.ThisisaverypracticalcommunityplatformforAIdevelopers.HardwareRequirementsIngeneral,whileitispossibletouseaCPUtotrainaneuralnetwork,thebestchoiceisactuallyaGPUbecauseitcangreatlyincreasethetrainingspeed.我们可以用下面的代码来测试我们的机器是否可以用GPU训练:"""DetermineifanyGPUsareavailable"""device=torch.device('cuda'iftorch.cuda.is_available()else'cpu')来实现网络结构由于数字是非常简单的信息,我们可以将判别器和生成器都组合成全连接层。我们可以使用下面的代码在PyTorch中构建判别器和生成器:self.fc1=nn.线性(784,512)self.fc2=nn.Linear(512,1)self.activation=nn.LeakyReLU(0.1)defforward(self,x):x=x.view(-1,784)x=self.activation(self.fc1(x))x=self.fc2(x)returnnn.Sigmoid()(x)classgenerator(nn.Module):def__init__(self):super(generator,self).__init__()self.fc1=nn.Linear(128,1024)self.fc2=nn.Linear(1024,2048)self.fc3=nn.Linear(2048,784)self.activation=nn.ReLU()defforward(self,x):x=self.activation(self.fc1(x))x=self.activation(self.fc2(x))x=self.fc3(x)x=x.view(-1,1,28,28)returnnn.Tanh()(x)Training在训练GAN时,我们需要在改进生成器的同时优化判别器,所以每次迭代都需要同时优化两个相互矛盾的损失函数。对于生成器,我们将输入一些随机噪声,让生成器根据小噪声改变输出图像:"""网络训练过程每一步都更新鉴别器和生成器的损失鉴别器的目的是对真假进行分类生成器的目的生成尽可能真实的图像"""forepochinrange(epochs):foridx,(imgs(train,_)x)inumerate1#Trainingthediscriminator#RealinputsareactualimagesoftheMNISTdataset#Fakeinputsarefromthegenerator#Realinputsshouldbeclassifiedas1andfakeas0real_inputs=imgs.to(device)real_outputs=D(real_inputs)real_label=torch.ones(real_inputs.shape[0],1).inputs(devtor.shape[0],128)-0.5)/0.5noise=noise.to(device)fake_inputs=G(noise)fake_outputs=D(fake_inputs)fake_label=torch.zeros(fake_inputs.shape[0],1).to(device)outputs=torch.cat((real_outputs,fake_outputs),0)targets=torch.cat((real_label,fake_label),0)D_loss=损失(输出,目标)D_optimizer.zero_grad()D_loss.backward()D_optimizer.step()#Trainingthegenerator#Forgen发生器,目标istomakethediscriminatorbelieveeverythingis1noise=(torch.rand(real_inputs.shape[0],128)-0.5)/0.5noise=noise.to(device)fake_inputs=G(noise)fake_outputs=D(fake_inputs)fake_targets=torch.ones([fake_inputs.shape[0],1]).to(device)G_loss=loss(fake_outputs,fake_targets)G_optimizer.zero_grad()G_loss.backward()G_optimizer.step()ifidx%100==0oridx==len(train_loader):打印('Epoch{}Iteration{}:discriminator_loss{:.3f}generator_loss{:.3f}'.format(epoch,idx,D_loss.item(),G_loss.item()))if(epoch+1)%10==0:torch.save(G,'Generator_epoch_{}.pth'.format(epoch))print('Modelsaved.')经过100个训练期后,我们可以可视化数据集,直接看到由来自随机噪声的模型:我们可以看到生成的结果与真实数据非常相似。考虑到我们这里只搭建了一个非常简单的模型,实际应用效果会有非常大的提升空间。不仅与GAN和以往的机器视觉专家有别样的思路,GAN在具体场景的应用更是让很多人惊叹深度网络的无限潜力。让我们来看看最著名的两个GAN扩展应用。朱俊彦等人2017年发表的CycleGAN可以直接将一张图片从X域转换到Y域,无需图片配对,比如把马变成斑马,把炎热的夏天变成隆冬,把莫奈的画变成范的画梵高等。这些看似不可能的转换,CycleGAN可以轻松完成,而且结果非常准确。GauGANNvidia使用GAN让人们只需几笔就可以勾勒出自己的想法,然后得到一幅极其逼真的真实场景图。虽然这一应用所需的计算成本极高,但GauGAN以其转化能力探索了前所未有的研究和应用领域。结语相信看到这里,你已经知道了GAN的大致工作原理,你可以很轻松的自己搭建一个GAN。