大家好。最近大家都在玩AI画画。在GitHub上找到一个开源项目,分享给大家。今天分享的项目是通过使用GAN生成对抗网络来实现的。我们之前分享过很多关于GAN原理和实战的文章。想了解更多的朋友可以去看历史文章。文末获取源码和数据集,下面分享如何训练和运行项目。1.准备环境,安装tensorflow-gpu1.15.0,GPU显卡使用2080Ti,cuda版本10.0。git下载项目AnimeGANv2源代码。环境搭建完成后,需要准备dataset和vgg19。下载dataset.zip压缩文件,里面包含6k真人图片和2k漫画图片,用于GAN训练。Vgg19是用来计算loss的,下面会详细介绍。2.网络模型生成对抗网络需要定义两个模型,一个是生成器,一个是判别器。生成器网络定义如下:withtf.variable_scope('A'):inputs=Conv2DNormLReLU(inputs,32,7)inputs=Conv2DNormLReLU(inputs,64,strides=2)inputs=Conv2DNormLReLU(inputs,64)withtf.variable_scope('B'):inputs=Conv2DNormLReLU(inputs,128,strides=2)inputs=Conv2DNormLReLU(inputs,128)withtf.variable_scope('C'):inputs=Conv2DNormLReLU(inputs,128)inputs=self.InvertedRes_block(输入,2,256,1,'r1')inputs=self.InvertedRes_block(inputs,2,256,1,'r2')inputs=self.InvertedRes_block(inputs,2,256,1,'r3')inputs=self.InvertedRes_block(inputs,2,256,1,'r4')inputs=Conv2DNormLReLU(inputs,128)withtf.variable_scope('D'):inputs=Unsample(inputs,128)inputs=Conv2DNormLReLU(inputs,128)withtf.variable_scope('E'):inputs=Unsample(inputs,64)inputs=Conv2DNormLReLU(inputs,64)inputs=Conv2DNormLReLU(inputs,32,7)withtf.variable_scope('out_layer'):out=Conv2D(输入,过滤器=3,kernel_size=1,strides=1)self.fake=tf.tanh(out)generator中的主要模块是reverseresidualblock残差结构(a)和reverseresidualblock(b)判别器网络结构如下:defD_net(x_init,ch,n_dis,sn,scope,reuse):channel=ch//2withtf.variable_scope(scope,reuse=reuse):x=conv(x_init,channel,kernel=3,stride=1,pad=1,use_bias=False,sn=sn,scope='conv_0')x=lrelu(x,0.2)foriinrange(1,n_dis):x=conv(x,channel*2,kernel=3,stride=2,pad=1,use_bias=False,sn=sn,scope='conv_s2_'+str(i))x=lrelu(x,0.2)x=conv(x,channel*4,kernel=3,stride=1,pad=1,use_bias=False,sn=sn,scope='conv_s1_'+str(i))x=layer_norm(x,scope='1_norm_'+str(i))x=lrelu(x,0.2)channel=channel*2x=conv(x,channel*2,kernel=3,stride=1,pad=1,use_bias=False,sn=sn,scope='last_conv')x=layer_norm(x,scope='2_ins_norm')x=lrelu(x,0.2)x=conv(x,channels=1,kernel=3,stride=1,pad=1,use_bias=False,sn=sn,scope='D_logit')返回x3。Loss在计算loss之前,使用VGG19网络对图像进行矢量化处理。这个过程有点像NLP中的Embedding操作。Eembedding是将词转为向量,VGG19是将图片转为向量。VGG19定义损失部分的计算逻辑如下:defcon_sty_loss(vgg,real,anime,fake):#真实图像矢量化vgg.build(real)real_feature_map=vgg.conv4_4_no_activation#生成图像矢量化vgg.build(fake)fake_feature_map=vgg.conv4_4_no_activation#漫画风格矢量化vgg.build(anime[:fake_feature_map.shape[0]])anime_feature_map=vgg.conv4_4_no_activation#真实图片和生成图片的损失c_loss=L1_loss(real_feature_map,fake_feature_map)#漫画风格和生成图片Losss_loss=style_loss(anime_feature_map,fake_feature_map)returnc_loss,s_loss这里使用vgg19计算真实图片(参数real)和生成图片(参数fake),生成图片(参数fake)和漫画风格的loss(参数动漫)损失。c_loss,s_loss=con_sty_loss(self.vgg,self.real,self.anime_gray,self.generated)t_loss=self.con_weight*c_loss+self.sty_weight*s_loss+color_loss(self.real,self.generated)*self.color_weight+tv_loss最后赋予这两个loss不同的权重,让generator生成的图片既保留真实图片的外观,又迁移到漫画风格4.训练在项目目录下执行如下命令开始训练python火车。py--datasetHayao--epoch101--init_epoch10运行成功后就可以看到数据了。同时也可以看出损失在不断的减少。源码和数据集已经打包,需要的朋友可以在评论区留言。如果您觉得本文对您有用,请点个赞鼓励一下。以后会继续分享优秀的Python+AI项目。
