当前位置: 首页 > 科技观察

轻松构建PyTorch生成对抗网络(GAN)

时间:2023-03-19 12:47:12 科技观察

您眼前这张图片中的人不是真实的,而是由机器学习模型创建的虚拟人。图片取自维基百科上的GAN词条,画面细节丰富,色彩鲜艳,令人印象深刻。生成对抗网络(GAN)是一种生成式机器学习模型,广泛应用于广告、游戏、娱乐、媒体、医药等行业。它可以用来创建虚构的人物和场景,模拟人脸老化,变换图像风格。,以及生成化学式等。下面两张图分别展示了图片到图片转换的效果和基于语义布局合成场景的效果。本文将从工程实践的角度,带领读者借助AWS机器学习相关的云计算服务,搭建第一个基于PyTorch机器学习框架的生成对抗网络,开启全新有趣的机器学习和人工智能体验.还等什么,马上开始吧!主要内容题目及解决方案概述模型开发环境生成对抗网络模型模型训练与验证结语与总结题目及解决方案概述下图两组手写数字图片,你能认出电脑生成的“手写”字体吗?是哪一组?这篇文章的题目是用机器学习的方法“模仿手写字体”。为了完成本主题,您将体验生成对抗网络的设计和实现。“仿手写字体”和人像生成的基本原理和工程过程基本相同。虽然它们在复杂度和精度要求上存在一定差距,但通过解决“模仿手写字体”的问题,可以提高生成对抗网络的原理和工程性。实践打好基础,之后可以逐步尝试和探索更复杂、更高级的网络架构和应用场景。《生成对抗网络》(GAN),由IanGoodfellow等人提出。2014年,是由生成网络和判别网络组成的深度神经网络架构。生成网络生成“虚假”数据并试图欺骗判别网络;判别网络验证生成数据的真实性,并尝试正确识别所有“虚假”数据。在训练迭代的过程中,两个网络不断进化对抗,直到达到均衡状态(参考:纳什均衡),判别网络无法再识别“虚假”数据,训练结束。2016年,AlecRadford等人发表的论文《深度卷积生成对抗网络》(DCGAN)。率先将卷积神经网络应用于生成对抗网络的模型算法设计,替代全连接层,提高图像场景训练的稳定性。性别。AmazonSageMaker是AWS的一项完全托管的机器学习服务。通过AmazonSageMaker可以快速轻松地完成数据处理和机器学习训练。经过训练的模型可以直接部署到完全托管的生产环境中。AmazonSageMaker提供了一个托管的JupyterNotebook实例,它通过SageMakerSDK与AWS的各种云服务集成,允许您访问数据源以进行探索和分析。SageMakerSDK是一个开源的AmazonSageMaker开发包,可以帮助您使用好AmazonSageMaker提供的托管容器镜像,以及AWS的其他云服务,如计算和存储资源。如上图所示,训练数据将来自AmazonS3存储桶;训练框架和托管算法以容器镜像的形式提供,在训练时与代码结合;模型代码运行在AmazonSageMaker管理的计算实例中,结合训练时的数据;培训输出进入专用的AmazonS3存储桶。在接下来的讲解中,我们将学习如何通过SageMakerSDK使用这些资源。我们会使用AmazonSageMaker、AmazonS3、AmazonEC2等AWS服务,会产生一定的云资源使用费。在模型开发环境中创建Notebook实例,请打开AmazonSageMaker控制面板(点击打开北京大区|宁夏大区),点击Notebook实例按钮进入notebook实例列表。如果您是第一次使用AmazonSageMaker,您的Notebook实例列表将显示为空列表,您需要单击创建笔记本实例按钮来创建一个新的JupyterNotebook实例。进入创建笔记本实例页面后,请在笔记本实例名称字段中输入实例名称。本文将使用MySageMakerInstance作为实例名称,您可以选择任何您认为合适的名称。本文将使用默认实例类型,因此Notebook实例类型选项将保留*ml.t2.medium*。如果您是第一次使用AmazonSageMaker,您需要创建一个IAM角色,以便笔记本实例可以访问AmazonS3服务。请在IAM角色选项中点击创建新角色。AmazonSageMaker将创建一个具有必要权限的角色,并将该角色分配给正在创建的实例。另外,根据你的实际情况,你也可以选择已有的角色。在CreateanIAMrole弹出窗口中,您可以选择*AnyS3bucket*以便笔记本实例可以访问您账户中的所有存储桶。另外,根据您的需要,您还可以选择SpecificS3buckets并输入bucket名称。单击“创建角色”按钮,将创建这个新角色。此时,您可以看到AmazonSageMaker已经为您创建了一个角色,名称类似于*AmazonSageMaker-ExecutionRole-****。其他字段可以使用默认值,请点击Createnotebookinstance按钮创建一个实例。回到Notebook实例页面,会看到MySageMakerInstance笔记本实例处于Pending状态,这个状态会持续2分钟左右,直到变为InService状态。编写第一行代码并单击打开JupyterLab链接。在新的页面上,你会看到熟悉的JupyterNotebook加载界面。本文默认使用JupyterLabnotebook作为项目环境,大家可以根据需要选择使用传统的Jupyternotebook。您将通过单击conda_pytorch_p36笔记本图标创建一个名为Untitled.ipynb的笔记本,稍后您可以更改其名称。或者,您也可以通过“文件”>“新建”>“笔记本”菜单路径创建此笔记本,然后选择conda_pytorch_p36作为内核。在新建的Untitled.ipynbnotebook中,我们输入第一行命令如下,importtorchprint(f"HelloPyTorch{torch.__version__}")源码下载请在notebook中输入如下命令,将代码下载到本地实例的文件系统。下载完成后,您可以通过文件浏览器浏览源代码结构。本文中涵盖的代码和笔记本已使用AmazonSageMaker托管的Python3.6、PyTorch1.4和JupyterLab进行了验证。本文涉及的代码和笔记本都可以在这里找到。生成对抗网络模型算法原理DCGAN模型的生成网络包含10层。它使用跨步转置卷积层来提高张量的分辨率。输入形状为(batchsize,100),输出形状为(batchsize,64,64,3)。换句话说,生成网络获取噪声向量并对其进行转换,直到生成最终图像。判别网络也包含10层。它接收一张(64,64,3)格式的图片,使用2D卷积层进行下采样,最后传给全连接层进行分类。分类结果为1或0,即真假。DCGAN模型的训练过程大致可以分为三个子过程。首先,Generator网络使用一个随机数作为输入,生成一张“假”图片;接下来,分别用“真”图片和“假”图片训练鉴别器网络,并更新参数;最后,更新生成器网络参数。代码分析项目目录byos-pytorch-gan的文件结构如下。文件model.py包含3个类,即Generator和Discriminator。classGenerator(nn.Module):...classDiscriminator(nn.Module):...classDCGAN(object):"""AwrapperclassforGeneratorandDiscriminator,'train_step'methodisforsinglebatchtraining."""...filetrain.pyforGeneratorandDiscriminator的两个神经网络的训练主要包括以下方法,defparse_args():...defget_datasets(dataset_name,...):...deftrain(dataloader,hps,...):...modeldebugging开发调试时,train.py脚本可以直接从Linux命令行运行。可以通过命令行参数指定超参数、输入数据通道、模型等训练输出存储目录。pythondcgan/train.py--datasetqmnist\--model-dir'/home/myhome/byom-pytorch-gan/model'\--output-dir'/home/myhome/byom-pytorch-gan/tmp'\--data-dir'/home/myhome/byom-pytorch-gan/data'\--hps'{"beta1":0.5,"dataset":"qmnist","epochs":15,"learning-rate":0.0002,"log-interval":64,"nc":1,"nz":100,"sample-interval":100}'这样的训练脚本参数设计,不仅提供了很好的调试方法,而且兼容SageMakerContainer集成的规范和必要条件考虑了模型开发的自由度和训练环境的可移植性。对于模型训练和验证,请找到并打开名为dcgan.ipynb的笔记本文件。本笔记本将介绍和执行培训过程。本节代码部分省略,请参考notebook代码。互联网环境下有很多公开的数据集,对机器学习的工程和科学研究,如算法学习、效果评估等有很大的帮助。我们将使用QMNIST手写字体数据集来训练模型,最终生成逼真的“手写”字体效果图案。数据准备PyTorch框架的torchvision.datasets包中提供了QMNIST数据集,您可以按照说明将QMNIST数据集下载到本地备份。fromtorchvisionimportdatasetsdataroot='./data'trainset=datasets.QMNIST(root=dataroot,train=True,download=True)testset=datasets.QMNIST(root=dataroot,train=False,download=True)AmazonSageMaker为您创建一个默认AmazonS3存储桶,用于存储机器学习工作流程中可能需要的各种文件和数据。我们可以通过SageMakerSDK中sagemaker.session.Session类的default_bucket方法获取这个bucket的名称。fromsagemaker.sessionimportSessionsess=Session()#S3bucketforsavingcodeandmodelartifacts.#Feelfreetospecifyadifferentbuckethereifyouwish.bucket=sess.default_bucket()SageMakerSDK提供用于操作AmazonS3服务的包和类。S3Downloader类用于访问或下载S3中的对象,而S3Uploader类用于将本地文件上传到S3。您将下载的数据上传到AmazonS3进行模型训练。在模型训练过程中,不从互联网下载数据,避免了通过互联网获取训练数据带来的网络延迟,也避免了直接访问互联网可能带来的模型训练安全风险。fromsagemaker.s3importS3Uploaderass3ups3_data_location=s3up.upload(f"{dataroot}/QMNIST",f"s3://{bucket}/data/qmnist")通过sagemaker.getexecutionrole()方法训练执行,可以预先分配当前notebook给notebook实例的角色,这个角色会用来获取训练资源,比如下载训练框架镜像,分配AmazonEC2计算资源等等。可以在notebook中定义用于训练模型的超参数,实现与算法代码的分离,在创建训练任务时传入超参数,与训练任务动态结合。hps={“学习率”:0.0002,“epochs”:15,“dataset”:“qmnist”,“beta1”:0.5,“sample-interval”:200,“log-interval”:64}sagemaker.pytorch包中的PyTorch类是基于PyTorch框架的模型拟合器,可以用来创建和执行训练任务,也可以部署训练好的模型。参数列表中train_instance_type用于指定CPU或GPU实例类型,训练脚本和模型代码所在的目录由source_dir指定,训练脚本的文件名必须由entry_point明确定义。这些参数会和其他参数一起传递给训练任务,它们决定了训练任务的运行环境和模型训练时的参数。fromsagemaker.pytorchimportPyTorchestimator=PyTorch(role=role,entry_point='train.py',source_dir='dcgan',output_path=s3_model_artifacts_location,code_location=s3_custom_code_upload_location,train_instance_count=1,train_instance_type='ml.train_max_wait=86400,framework.version='1.40',py_version='py3',hyperparameters=hps)请特别注意train_use_spot_instances参数,True值代表你想优先使用SPOT实例。由于机器学习训练通常需要大量的计算资源来长时间运行,利用好SPOT可以帮助您实现有效的成本控制。SPOT实例的价格可能是按需实例价格的20%到60%,具体取决于选择的实例类型、区域和时间。实际价格会有所不同。现在您已经创建了一个PyTorch对象,是时候将其与AmazonS3上预先存在的数据相匹配了。以下命令将执行训练任务,训练数据将作为名为QMNIST的输入通道导入到训练环境中。训练开始时,会将AmazonS3上的训练数据下载到模型训练环境的本地文件系统中,训练脚本train.py会从本地磁盘加载数据进行训练。#Starttrainingestimator.fit({'QMNIST':s3_data_location},wait=False)根据您选择的训练实例,训练过程可能持续数十分钟到数小时不等。建议将wait参数设置为False。此选项会将笔记本与训练任务分开。在训练时间长、训练日志多的场景下,可以避免因网络中断或session超时导致notebook上下文丢失。训练任务离开笔记本后,输出会暂时不可见。可以执行下面的代码,notebook会获取并加载之前的训练session。模型设计考虑了GPU加速训练的能力,因此使用GPU实例进行训练会比CPU实例更快。例如,一个p3.2xlarge实例大约需要15分钟,而一个c5.xlarge实例可能需要6个多小时。目前的模型不支持分布式和并行训练,所以多实例和多CPU/GPU不会带来更多的训练速度提升。训练完成后,模型将上传到AmazonS3,位于创建PyTorch对象时提供的output_path参数指定的位置。模型验证您将从AmazonS3下载训练好的模型到notebook所在实例的本地文件系统。下面的代码会加载模型,然后输入一个随机数得到推理结果,并以图片的形式显示出来。执行以下命令加载训练好的模型,通过该模型生成一组“手写”数字字体。fromhelperimport*importmatplotlib.pyplotaspltimportnumpyasnpimporttorchfromdcgan.modelimportGeneratordevice=torch.device("cuda:0"iftorch.cuda.is_available()else"cpu")params={'nz':nz,'nc':nc,'ngf':ngf}model=load_model(Generator,params,"./model/generator_state.pth",device=device)img=generate_fake_handwriting(model,batch_size=batch_size,nz=nz,device=device)plt.imshow(np.asarray(img))结论和总结PyTorch框架,近年来发展迅速,正在被广泛认可和应用。越来越多的新模型采用PyTorch框架,部分模型迁移到PyTorch,或完全基于PyTorch重新实现。生态环境不断丰富,应用领域不断拓展。PyTorch已经成为事实上的主流框架之一。AmazonSageMaker与多种AWS服务紧密集成,例如各种类型和大小的AmazonEC2计算实例、AmazonS3、AmazonECR等,为机器学习工程实践提供端到端、一致的体验。AmazonSageMaker继续支持主要的机器学习框架,PyTorch就是其中之一。使用PyTorch开发的机器学习算法和模型可以轻松移植到AmazonSageMaker的工程和服务环境中,然后使用AmazonSageMaker全托管的JupyterNotebook、训练容器镜像、服务容器镜像、训练任务管理、部署环境托管等功能,简化机器学习工程复杂度,提高生产效率,降低运维成本。DCGAN是生成对抗网络领域的一个里程碑,是当今许多复杂生成对抗网络的基石。文章开头提到的StyleGAN,用文本合成图像的StackGAN,用草图生成图像的Pix2pix,还有网上颇受争议的DeepFakes等等,都有DCGAN的影子。相信通过本文的介绍和工程实践,将有助于大家理解生成对抗网络的原理和工程方法。