最近,CV研究人员对transformer产生了浓厚的兴趣,并取得了很多突破。这表明Transformer有潜力成为计算机视觉任务(如分类、检测和分割)的强大通用模型。我们都很好奇:Transformer在计算机视觉领域能走多远?Transformer如何执行更困难的视觉任务,例如生成对抗网络(GAN)?在这种好奇心的驱使下,得克萨斯大学奥斯汀分校的姜一帆和王章扬以及IBM研究院的ShiyuChang等研究人员进行了首次实验研究,构建了一个完全没有卷积的纯transformer架构。GAN并将其命名为TransGAN。与其他基于transformer的视觉模型相比,仅使用transformer构建GAN似乎更具挑战性,因为与分类等任务相比,真实图像生成的门槛更高,而且GAN训练本身的高度不稳定性。论文链接:https://arxiv.org/pdf/2102.07074.pdf代码链接:https://github.com/VITA-Group/TransGAN从结构上看,TransGAN由两部分组成:一是记忆-friendlytransformer-based一个可以逐渐提高特征分辨率同时降低嵌入维度的生成器;另一个是基于变压器的补丁级鉴别器。研究人员还发现,TransGAN显着受益于数据增强(超过标准GAN)、生成器的多任务协同训练策略,以及强调自然图像邻域平滑度的局部初始化自注意力。这些发现表明TransGAN可以有效地扩展到更大的模型和具有更高分辨率的图像数据集。实验结果表明,与当前基于卷积主干的SOTAGAN相比,最先进的TransGAN实现了极具竞争力的性能。具体来说,TransGAN在STL-10上实现了IS得分为10.10和FID为25.32的新SOTA。这项研究表明,GAN可能不需要依赖卷积主干和许多专用模块,并且纯变换器足以生成图像。在论文的相关讨论中,有读者开玩笑说,“注意力真的正在变成‘你所需要的’”。不过,也有研究人员表达了他们的担忧:在Transformer席卷整个社区的大背景下,仅凭力量,Bo的小实验室将如何生存?如果transformer真的成为“刚需”社区,如何提高这类架构的计算效率将成为一个难点的研究问题。TransformerEncoderBasedonPureTransformer-basedGANsastheBasicBlock研究人员选择使用Transformer编码器(Vaswanietal.,2017)作为基本块,变化最小。编码器由两个组件组成,第一个组件由多头自注意力模块构成,第二个组件是具有GELU非线性的前馈MLP(多层感知器)。此外,研究人员在两个组件之前应用了层归一化(Baetal.,2016)。这两个组件也使用剩余连接。记忆友好型生成器NLP中的Transformer将每个单词作为输入(Devlin等人,2018年)。然而,如果图像是通过以类似方式堆叠Transformer编码器逐像素生成的,低分辨率图像(例如32×32)也可能导致长序列(1024)和更昂贵的自注意力开销。因此,为了避免过多的开销,研究人员受到基于CNN的GAN中常见设计思想的启发,在多个阶段迭代地提高分辨率(Denton等人,2015年;Karras等人,2017年)。他们的策略是逐渐增加输入序列并降低嵌入维数。如下图1左侧所示,研究人员提出了一种内存友好的、基于Transformer的由多个阶段组成的生成器:每个阶段堆叠了多个编码器块(默认情况下为5、2和2)。通过分段设计,研究人员逐渐提高特征图分辨率,直到达到目标分辨率H_T×W_T。具体来说,生成器将随机噪声作为输入,并通过MLP将随机噪声传递给长度为H×W×C的向量。这个向量被转化为一个分辨率为H×W(默认H=W=8)的featuremap,每个点都是一个C维的embedding。然后将该特征图视为长度为64的C维标记序列,并与可学习的位置编码相结合。与BERT(Devlin等人,2018)类似,提出的Transformer编码器将嵌入标记作为输入并递归计算每个标记之间的匹配。为了合成更高分辨率的图像,研究人员在每个阶段后插入了一个由重塑和像素洗牌模块组成的上采样模块。具体操作上,upsampling模块首先将1D序列的tokenembedding转化为2Dfeaturemap,然后使用pixelshuffle模块对2Dfeaturemap的分辨率进行上采样,并对embedding维度进行下采样,最终得到输出。然后,二维特征映射X'_0再次转换为嵌入标记的一维序列,其中标记数为4HW,嵌入维度为C/4。因此,在每个阶段,分辨率(H,W)都会增加一倍,而嵌入维数C会减少到输入的四分之一。这种权衡策略缓和了内存和计算需求的激增。研究人员在多个阶段重复上述过程,直到达到分辨率(H_T,W_T)。然后,他们将嵌入维度投影到3并获得RGB图像。判别器的标记化输入与那些需要准确合成每个像素的生成器不同,本研究中提出的判别器只需要区分真假图像。这允许研究人员将输入图像语义标记化到更粗糙的补丁级别(Dosovitskiy等人,2020年)。如上图1右侧所示,鉴别器将图像的一个补丁作为输入。研究人员将输入图像分解为8×8的小块,其中每个小块都可以看作是一个“词”。然后,8×8patches通过线性展平层转化为一维的tokenembeddings序列,其中token的数量N=8×8=64,嵌入维度为C。之后,研究人员添加了一个可学习的位置代码和一维序列开头的[cls]标记。通过Transformer编码器后,分类头仅使用[cls]标记输出真假预测。CIFAR-10上的实验结果研究人员在CIFAR-10数据集上对比了TransGAN和最近基于卷积的GAN研究,结果如下表5所示:如上表5所示,TransGAN优于AutoGAN(Gong等人,2019),在IS分数方面也优于许多竞争对手,如SN-GAN(Miyatoetal.,2018),改进MMDGAN(Wangetal.,2018a),MGAN(Hoangetal.,2018)。TransGAN仅次于ProgressiveGAN和StyleGANv2。比较FID结果,发现TransGAN甚至优于ProgressiveGAN,略低于StyleGANv2(Karrasetal.,2020b)。在CIFAR-10上生成的可视化示例如下图4所示:STL-10上的结果研究人员将TransGAN应用于另一个流行的48×48分辨率基准,STL-10。为了适应目标分辨率,本研究将第一阶段的输入特征图从(8×8)=64增加到(12×12)=144,然后将提出的TransGAN-XL与自动搜索的ConvNets和比较了手工制作的ConvNet,结果如下表6所示:与CIFAR-10上的结果不同,本研究发现TransGAN优于所有当前模型,并在IS和FID分数方面实现了新的SOTA性能。高分辨率生成由于TransGAN在标准基准CIFAR-10和STL-10上取得了良好的性能,研究人员将TransGAN用于更具挑战性的数据集CelebA64×64,结果如下表10所示:TransGAN-XLFID12.23的分数表明TransGAN-XL适用于高分辨率任务。可视化结果如图4所示。局限性尽管TransGAN取得了不错的成绩,但与最好的手工设计的GAN相比,它仍有很大的改进空间。在论文的最后,作者指出了以下具体的改进方向:对G和D进行更复杂的分词操作,例如使用一些语义分组(Wuetal.,2020)。使用前置任务对Transformer进行预训练可能会改进本研究中现有的MT-CT。更强大的注意力形式,例如(Zhuetal.,2020)。一种更高效的自注意力形式(Wangetal.,2020;Choromanskietal.,2020),它不仅有助于提高模型效率,还可以节省内存开销,从而有助于生成更高分辨率的图像。作者简介本文第一作者江一凡是德克萨斯大学奥斯汀分校电气与计算机工程系博士一年级学生(此前曾在德克萨斯农工大学学习一年)。毕业于华中科技大学,获学士学位。深度学习等。目前主要从事神经结构搜索、视频理解、高级表示学习等领域的研究,师从大学电气与计算机工程系助理教授王章阳得克萨斯州在奥斯汀。本科期间,姜一凡在字节跳动人工智能实验室实习。今年夏天,他将作为实习生加入谷歌研究院。一帆首页:https://yifanjiang.net/
