新智元报道
编辑:LRS 好困
【新智元导读】SANA-Sprint是一个高效的蒸馏扩散模型,专为超快速文本到图像生成而设计。通过结合连续时间一致性蒸馏(sCM)和潜空间对抗蒸馏(LADD)的混合蒸馏策略,SANA-Sprint在一步内实现了7.59 FID和0.74 GenEval的最先进性能。SANA-Sprint仅需0.1秒即可在H100上生成高质量的1024x1024图像,在速度和质量的权衡方面树立了新的标杆。
扩散生成模型通常需要50-100次迭代去噪步骤,效率很低,时间步蒸馏技术可以极大提高推理效率,「基于分布的蒸馏」方法,如生成对抗网络GAN及其变分分数蒸馏VSD变体,以及「基于轨迹的蒸馏方法」(如直接蒸馏、渐进蒸馏、一致性模型)可以实现10-100倍的图像生成加速效果。
但仍然存在一些关键难点,比如基于GAN的方法由于对抗动态的振荡特性和模式坍塌问题,训练过程不稳定;基于VSD的方法需要联合训练一个额外的扩散模型,增加了计算开销;一致性模型虽然稳定,但在极少数步骤(例如少于4步)的情况下,生成质量会下降。
如何开发一个能够兼顾效率、灵活性和质量的蒸馏框架成了模型部署的关键。
论文地址: https://arxiv.org/pdf/2503.09641
项目主页:https://nvlabs.github.io/Sana/Sprint/
基于OpenAI提出的连续时间一致性模型(sCM)的方法,研究人员提出SANA-Sprint,进一步结合了LADD的对抗蒸馏技术,帮助模型在蒸馏过程中更好地保留细节信息,从而实现超快速且高质量的文本到图像生成,同时避免了离散化带来的误差,保留了传统一致性模型的优势。
SANA-Sprint的核心在于其创新的混合蒸馏框架和对ControlNet的集成,主要贡献包括:
混合蒸馏框架:设计了一种新颖的混合蒸馏框架,将预训练的流匹配模型无缝转换为TrigFlow模型,集成了连续时间一致性模型(sCM)和潜在对抗扩散蒸馏(LADD)。
sCM确保了模型与教师模型的一致性和多样性保留,而LADD则增强了单步生成的保真度,从而实现了统一的步长自适应采样。
卓越的速度/质量权衡:SANA-Sprint仅需1-4步即可实现卓越的性能。在H100上,SANA-Sprint仅需0.10-0.18秒即可生成1024x1024的图像,在MJHQ-30K数据集上实现了7.59的FID和0.74的GenEval分数,超越了FLUX-schnell(7.94FID/0.71GenEval),速度提升了10倍。
实时交互式生成:通过将ControlNet与SANA-Sprint集成,实现了在H100上仅需0.25秒的实时交互式图像生成。这为需要即时视觉反馈的应用(如ControlNet引导的图像生成/编辑)提供了可能,实现了更好的人机交互。
SANA-Sprint不仅在速度和性能上表现出色,生成的图像质量也非常高。
SANA-Sprint
SANA-Sprint方法主要包括以下四个关键步骤:
1. 无训练转换到TrigFlow
研究人员提出了一种简单的方法,通过直接的数学输入和输出转换,将预训练的流匹配模型转换为TrigFlow模型。这使得可以直接使用已有的预训练模型,无需额外的TrigFlow模型的训练。
动机是,虽然sCM使用TrigFlow公式简化了连续时间一致性模型的训练,但大多数基于分数的生成模型(如扩散模型和流匹配模型)并不直接支持TrigFlow。
为了克服这一挑战,SANA-Sprint提出了一种无需重新训练的转换方法,通过数学变换将流匹配模型转换TrigFlow模型,从而避免了复杂的额外算法设计和额外的计算成本。
2. 混合蒸馏策略
混合蒸馏策略结合了sCM和LADD两种蒸馏方法。sCM利用TrigFlow的公式简化了连续时间一致性模型的训练,而LADD则通过对抗训练在潜在空间中直接进行判别,进一步提升了生成质量。
3. 稳定训练的关键技术
密集时间嵌入(Dense Time-Embedding):为了稳定连续时间一致性模型的训练,SANA-Sprint采用了密集时间嵌入设计。通过将噪声系数 调整为
Query-Key归一化(QK-Normalization):在Transformer模型的自注意力和交叉注意力机制中引入了RMS归一化,进一步稳定了训练过程,尤其是在大模型和高分辨率场景下。
4. 集成ControlNet
将SANA-Sprint的训练流程应用于ControlNet任务,利用图像和文本提示作为条件,实现了SANA-ControlNet模型,并通过蒸馏得到SANA-Sprint-ControlNet,支持实时的图像编辑和生成。
实验结果
研究人员采用了两阶段的训练策略,详细的设置和评估协议在论文附录中进行了概述。
教师模型通过剪枝和微调SANA-1.5 4.8B模型得到,然后使用文中提出的训练范式进行蒸馏,使用包括FID、MJHQ-30K上的CLIP Score和GenEval在内的指标评估性能。
实验结果表明,SANA-Sprint在速度和质量方面均达到了最先进的水平。
效率与性能对比:在4步推理下,SANA-Sprint 0.6B实现了5.34个样本/秒的吞吐量和0.32秒的延迟,FID为6.48,GenEval为0.76;SANA-Sprint 1.6B 的吞吐量略低(5.20个样本/秒),但GenEval提升至0.77,优于更大的模型如FLUX-schnell 12B,其吞吐量仅为0.5个样本/秒,延迟为2.10秒。
单步生成性能:SANA-Sprint在单步生成方面也表现出色,实现了7.59的FID和0.74的GenEval分数,超越了其他单步生成方法。
实时交互式生成:集成ControlNet的SANA-Sprint模型在H100上实现了约200毫秒的推理速度,支持近乎实时的交互。
结论与展望
SANA-Sprint是一款高效的扩散模型,用于超快速的单步文本到图像生成,同时保留了多步采样的灵活性。通过采用结合了连续时间一致性蒸馏(sCM)和潜在对抗蒸馏(LADD)的混合蒸馏策略,SANA-Sprint在一步内实现了7.59的FID和0.74的GenEval分数,无需针对特定步骤进行训练。
该统一的步长自适应模型仅需0.1秒即可在H100上生成高质量的1024x1024图像,在速度和质量的权衡方面树立了新的标杆。
展望未来,SANA-Sprint的即时反馈特性将为实时交互应用(如响应迅速的创意工具和AIPC)开启新的可能性。
参考资料:
https://nvlabs.github.io/Sana/Sprint/