当前位置: 首页 > 科技观察

涵盖了18+个SOTAGAN的实现,这个图像生成领域的库已经成为流行的

时间:2023-03-20 19:58:19 科技观察

GAN。自提出以来,迅速受到广泛关注。我们可以把GAN分为两类,一类是无条件生成;另一种是基于条件信息的生成。近日,韩国浦项科技大学的一名硕士生在GitHub上开源了一个项目,该项目提供了一个具有代表性的生成对抗网络(GAN)的实现,用于条件/无条件图像生成。最近机器之心在GitHub上看到一个很有意义的项目PyTorch-StudioGAN,它是一个PyTorch库,提供了代表生成对抗网络(GAN)的条件/无条件图像生成的实现。根据主页介绍,该项目旨在提供一个统一的现代GAN平台,以便机器学习领域的研究人员能够快速比较和分析新的想法和方法。本项目作者为韩国浦项科技大学硕士研究生。他的研究兴趣主要包括深度学习、机器学习和计算机视觉。项目地址:https://github.com/POSTECH-CVLab/PyTorch-StudioGAN具体而言,本项目具有以下显着特点:提供大量PyTorch框架的GAN实现;基于CIFAR10、TinyImageNet和ImageNet数据集的GAN基准;与原始实现相比,性能更好,内存消耗更低;使用完全最新的PyTorch环境预训练模型;支持多GPU(DP、DDP和多节点DDP)、混合精度、simultaneousbatchnormalizationVisualization、LARS、Tensorboardvisualization等分析方法。对于这个PyTorchGAN库,有网友表示:“看起来很不错!如果能提供top-k和各种增强方法等现代训练实践就更好了。”对此,项目作者表示将提交NeurIPS论文。在该日期之后,添加了一些改进的方法,例如Sinha等人的Tok-KtrainingwithLangevinsampling和SimCLRaugmentation。此外,还有网友询问该项目是否可以用于图像以外的其他领域。作者说是的,即使有些稳定器(如diffaug、ada等)不能使用,你仍然可以通过调整dataLoader来训练自己的模型。18+SOTAGAN实现如下图所示。项目作者提供了18+个SOTAGAN实现,包括DCGAN、LSGAN、GGAN、WGAN-WC、WGAN-GP、WGAN-DRA、ACGAN、ProjGAN、SNGAN、SAGAN、BigGAN、BigGAN-Deep、CRGAN、ICRGAN、LOGAN、DiffAugGAN、ADAGAN、ContraGAN和FreezeD。cBN:条件批量归一化;AC:辅助分类器;PD:投影鉴别器;CL:对比学习。其中,需要注意以下几点:G/D_type表示将标签信息注入生成器或判别器的方式;EMA表示更新的指数移动平均线应用于生成器;TinyImageNet数据集上的实验使用ResNet架构而不是CNN。下图中,StyleGAN2就是即将到来的GAN网络,其中AdaIN代表AdaptiveInstanceNormalization。环境要求AnacondaPython>=3.66.0.0<=Pillow<=7.0.0scipy==1.1.0sklearnseabornh5pytqdmtorch>=1.6.0torchvision>=0.7.0tensorboard5.4.0<=gcc<=7.4.0torchlars用户可以使用以下方法安装推荐环境:condaenvcreate-fenvironment.yml-nstudiogan也可以在docker中使用以下方法:dockerpullmgkang/studiogan:latest下面是创建名为“studioGAN”的容器的命令,也可以使用端口号为6006来连接tensoreboard。dockerrun-it--gpusall--shm-size128g-p6006:6006--namestudioGAN-v/home/USER:/root/code--workdir/root/codemgkang/studiogan:latest/bin/bashusageusingGPU0在此case,模型的训练“-t”和评估“-e”在CONFIG_PATH中定义:CUDA_VISIBLE_DEVICES=0python3src/main.py-t-e-cCONFIG_PATHisusingGPU(0,1,2,3)andDataParallel接下来,训练CONFIG_PATH中定义模型的“-t”和求值“-e”:CUDA_VISIBLE_DEVICES=0,1,2,3python3src/main.py-t-e-cCONFIG_PATH在python3src/main.py程序中查看可用选项,你可以通过Tensorboard监控IS、FID、F_beta、AuthenticityAccuracies和最大奇异值:~PyTorch-StudioGAN/logs/RUN_NAME>>>tensorboard--logdir=./--portPORT可视化和分析生成图像StudioGAN支持图像可视化,k最近邻分析、线性差分和频率分析。所有结果都保存在“./figures/RUN_NAME/*.png”中。图像可视化的代码和例子如下:CUDA_VISIBLE_DEVICES=0,...,Npython3src/main.py-iv-std_stat--standing_stepSTANDING_STEP-cCONFIG_PATH--checkpoint_folderCHECKPOINT_FOLDER--log_output_pathLOG_OUTPUT_PATHk最近邻分析,这里固定K=7,在第一列是生成的图像:CUDA_VISIBLE_DEVICES=0,...,Npython3src/main.py-knn-std_stat--standing_stepSTANDING_STEP-cCONFIG_PATH--checkpoint_folderCHECKPOINT_FOLDER--log_output_pathLOG_OUTPUT_PATH线性插值的代码和示例(仅限有条件的BigResNet模型)如下:CUDA_VISIBLE_DEVICES=0,...,Npython3src/main.py-itp-std_stat--standing_stepSTANDING_STEP-cCONFIG_PATH--checkpoint_folderCHECKPOINT_FOLDER--log_output_pathLOG_OUTPUT_PATH