当前位置: 首页 > 科技观察

Soft Diffusion:谷歌新框架从通用扩散过程中正确调度、学习和采样

时间:2023-03-14 12:19:53 科技观察

SoftDiffusion:Google的新框架正确地调度、学习和从通用扩散过程中采样以生成样本。在YangSong等人的论文《Score-based generative modeling through stochastic differential equations》中,这两类模型被统一在一个框架下,被广泛称为扩散模型。目前,扩散模型在图像、音频和视频生成以及求解反问题等一系列应用中取得了巨大成功。在论文《Elucidating the design space of diffusionbased generative models》中,TeroKarras等研究人员分析了扩散模型的设计空间并确定了3个阶段,i)选择噪声级别的时间表,ii)选择网络参数化(每个参数化生成不同的损失函数),iii)设计采样算法。近日,在谷歌研究院与UT-Austin合作的一篇arXiv论文中,几位研究者认为扩散模型还有一个重要的步骤:腐败。总的来说,corruption是一个添加不同量级噪声的过程,对于DDMP也需要rescaling。尽管已经尝试使用不同的分布进行扩散,但仍然缺乏一个通用的框架。因此,我们提出了一个框架,用于为更一般的损伤过程设计扩散模型。具体来说,他们提出了一个名为SoftScoreMatching的新训练目标和一种新颖的采样方法MomentumSampler。理论结果表明,对于满足软分数匹配正则化条件的损坏过程能够学习它们的分数(即似然梯度),扩散必须将任何图像转换为具有非零似然的任何图像。在实验部分,研究人员在CelebA和CIFAR-10上训练模型,其中在CelebA上训练的模型线性扩散模型的SOTAFID得分为1.85。同时,研究人员训练的模型明显快于使用原始高斯去噪扩散训练的模型。论文地址:https://arxiv.org/pdf/2209.05442.pdf方法概述一般来说,扩散模型通过逆转逐渐增加噪声的腐败过程来生成图像。我们展示了如何学习反转涉及线性确定性退化和随机加性噪声的扩散。具体来说,研究人员展示了使用更通用的损伤模型来训练扩散模型的框架,该框架由三部分组成,即新的训练目标SoftScoreMatching、新颖的采样方法MomentumSampler和损伤机制的调度。我们先来看训练目标SoftScoreMatching,这个名字的灵感来自软过滤,一个摄影术语,指的是去除细节的滤镜。它以可证明的方式学习常规线性损坏过程的分数,还在网络中加入过滤过程,并训练模型预测与漫射观察相匹配的损坏图像。只要扩散将非零概率分配给任何一对干净、损坏的图像,这个训练目标就可以证明学习分数。此外,当损坏中存在附加噪声时,始终可以满足此条件。具体来说,研究人员以以下形式探索了损伤过程。在此过程中,研究人员发现噪声在经验上(即为了获得更好的结果)和理论上(即对于学习分数)都很重要。这也成为它与ColdDiffusion的关键区别,ColdDiffusion逆转确定性损坏的并发工作。第二种是采样方法MomentumSampling。研究人员证明,采样器的选择对生成样本的质量有重大影响。他们提出了MomentumSampler来反转一般的线性损坏过程。该采样器使用不同扩散级别的损坏凸组合,并受到优化中的动量方法的启发。这种采样方法的灵感来自YangSong等人提出的扩散模型的连续公式。上面的纸。MomentumSampler的算法如下图所示。下图直观地展示了不同采样方式对生成样本质量的影响。左侧使用NaiveSampler采样的图像看起来重复且缺乏细节,而右侧的MomentumSampler显着提高了采样质量和FID分数。最后是调度。即使退化的类型是预定义的(如模糊),决定在每个扩散步骤中损坏多少也不是微不足道的。研究人员提出了一种原则性工具来指导损伤过程的设计。为了找到时间表,他们最小化了沿路径分布之间的Wasserstein距离。直觉上,研究人员希望从完全损坏的发行版平稳过渡到干净的发行版。实验结果研究人员在CelebA-64和CIFAR-10上评??估了所提出的方法,这两者都是图像生成的标准基线。实验的主要目的是了解损坏类型的影响。研究人员首先试验了模糊和低振幅的噪音来检测腐败。结果表明,他们提出的模型在CelebA上取得了SOTA结果,FID得分为1.85,优于所有其他仅添加噪声并可能重新缩放图像的方法。另外,在CIFAR-10上得到的FID分数是4.64,虽然不是SOTA,但还是有竞争力的。此外,在CIFAR-10和CelebA数据集上,研究人员的方法在另一个指标采样时间上也表现更好。另一个额外的好处是显着的计算优势。去模糊(几乎没有噪声)似乎是比图像生成去噪方法更有效的操作。下图显示了FID分数如何随功能评估次数(NFE)变化。从结果中,我们可以看出,与CIFAR-10和CelebA数据集上的标准高斯去噪扩散模型相比,我们的模型可以使用更少的步骤实现相同或更好的质量。