新智元报道
编辑:英智 好困
【新智元导读】SANA 1.5是一种高效可扩展的线性扩散Transformer,针对文本生成图像任务进行了三项创新:高效的模型增长策略、深度剪枝和推理时扩展策略。这些创新不仅大幅降低了训练和推理成本,还在生成质量上达到了最先进的水平。
近年来,文本生成图像的技术不断突破,但随着模型规模的扩大,计算成本也随之急剧上升。
为此,英伟达联合MIT、清华、北大等机构的研究人员提出了一种高效可扩展的线性扩散Transformer——SANA,在大幅降低计算需求的情况下,还能保持有竞争力的性能。
SANA1.5在此基础上,聚焦了两个关键问题:
线性扩散Transformer的可扩展性如何?
在扩展大规模线性DiT时,怎样降低训练成本?
论文链接:https://arxiv.org/pdf/2501.18427
SANA 1.5:高效模型扩展三大创新
SANA 1.5在SANA 1.0(已被ICLR 2025接收)的基础上,有三项关键创新。
首先,研究者提出了一种高效的模型增长策略,使得SANA可以从1.6B(20层)扩展到4.8B(60层)参数,同时显著减少计算资源消耗,并结合了一种节省内存的8位优化器。
与传统的从头开始训练大模型不同,通过有策略地初始化额外模块,可以让大模型保留小模型的先验知识。与从头训练相比,这种方法能减少60%的训练时间。
其二,引入了模型深度剪枝技术,实现了高效的模型压缩。通过识别并保留关键的块,实现高效的模型压缩,然后通过微调快速恢复模型质量,实现灵活的模型配置。
其三,研究者提出了一种推理期间扩展策略,引入了重复采样策略,使得SANA在推理时通过计算而非参数扩展,使小模型也能达到大模型的生成质量。
通过生成多个样本,并利用基于视觉语言模型(VLM)的选择机制,将GenEval分数从0.72提升至0.80。
与从头开始训练大模型不同,研究者首先将一个包含N个Transformer层的基础模型扩展到N+M层(在实验中,N=20,M=40),同时保留其学到的知识。
在推理阶段,采用两种互补的方法,实现高效部署:
模型深度剪枝机制:识别并保留关键的Transformer块,从而在小的微调成本下,实现灵活的模型配置。
推理时扩展策略:借助重复采样和VLM引导选择,在计算资源和模型容量之间权衡。
同时,内存高效CAME-8bit优化器让单个消费级GPU上微调十亿级别的模型成为可能。
下图展示了这些组件如何在不同的计算资源预算下协同工作,实现高效扩展。
模型增长
研究者提出一种高效的模型增长策略,目的是对预训练的DiT模型进行扩展,把它从层增加到+层,同时保留模型已经学到的知识。
研究过程中,探索了三种初始化策略,最终选定部分保留初始化方法。这是因为该方法既简单又稳定。
在这个策略里,预训练的N层继续发挥特征提取的作用,而新增加的M层一开始是随机初始化,从恒等映射起步,慢慢学习优化特征表示。
实验结果显示,与循环扩展和块扩展策略相比,这种部分保留初始化方法在训练时的动态表现最为稳定。
模型剪枝
本文提出了一种模型深度剪枝方法,能高效地将大模型压缩成各种较小的配置,同时保持模型质量。
受Minitron启发,通过输入输出相似性模式分析块的重要性:
这里的表示第i个transformer的第t个token。
模型的头部和尾部块的重要性较高,而中间层的输入和输出特征相似性较高,表明这些层主要用于逐步优化生成的结果。根据排序后的块重要性,对transformer块进行剪枝。
剪枝会逐步削弱高频细节,因为,在剪枝后进一步微调模型,以弥补信息损失。
使用与大模型相同的训练损失来监督剪枝后的模型。剪枝模型的适配过程非常简单,仅需100步微调,剪枝后的1.6B参数模型就能达到与完整的4.8B参数模型相近的质量,并且优于SANA 1.0的1.6B模型。
推理时扩展
SANA 1.5经过充分训练,在高效扩展的基础上,生成能力有了显著提升。受LLM推理时扩展的启发,研究者也想通过这种方式,让SANA 1.5表现得更好。
对SANA和很多扩散模型来说,增加去噪步数是一种常见的推理时扩展方法。但实际上,这个方法不太理想。一方面,新增的去噪步骤没办法修正之前出现的错误;另一方面,生成质量很快就会达到瓶颈。
相较而言,增加采样次数是更有潜力的方向。
研究者用视觉语言模型(VLM)来判断生成图像和文本提示是否匹配。他们以NVILA-2B为基础模型,专门制作了一个数据集对其进行微调。
微调后的VLM能自动比较并评价生成的图像,经过多轮筛选,选出排名top-N的候选图像。这不仅确保了评选结果的可靠性,还能有效过滤与文本提示不匹配的图像。
模型增长、模型深度剪枝和推理扩展,构成了一个高效的模型扩展框架。三种方法协同配合,证明了精心设计的优化策略,远比单纯增加参数更有效。
模型增长策略探索了更大的优化空间,挖掘出更优质的特征表示。
模型深度剪枝精准识别并保留了关键特征,从而实现高效部署。
推理时间扩展表明,当模型容量有限时,借助额外的推理时间和计算资源,能让模型达到与大模型相似甚至更好的效果。
为了实现大模型的高效训练与微调,研究者对CAME进行扩展,引入按块8位量化,从而实现CAME-8bit优化器。
CAME-8bit相比AdamW-32bit减少了约8倍的内存使用,同时保持训练的稳定性。
该优化器不仅在预训练阶段效果显著,在单GPU微调场景中更是意义非凡。用RTX 4090这样的消费级GPU,就能轻松微调SANA 4.8B。
研究揭示了高效扩展不仅仅依赖于增加模型容量。通过充分利用小模型的知识,并设计模型的增长-剪枝,更高的生成质量并不一定需要更大的模型。
SANA 1.5 评估结果
实验表明,SANA 1.5的训练收敛速度比传统方法(扩大规模并从头开始训练)快2.5倍。
训练扩展策略将GenEval分数从0.66提升至0.72,并通过推理扩展将其进一步提高至0.80,在GenEval基准测试中达到了最先进的性能。
模型增长
将SANA-4.8B与当前最先进的文本生成图像方法进行了比较,结果如表所示。
从SANA-1.6B到4.8B的扩展带来了显著的改进:GenEval得分提升0.06(从0.66增加到0.72),FID降低0.34(从5.76降至5.42),DPG得分提升0.2(从84.8增加到85.0)。
和当前最先进的方法相比,SANA-4.8B模型的参数数量少很多,却能达到和大模型一样甚至更好的效果。
SANA-4.8B的GenEval得分为0.72,接近Playground v3的0.76。
在运行速度上,SANA-4.8B的延迟比FLUX-dev(23.0秒)低5.5倍;吞吐量为0.26样本/秒,是FLUX-dev(0.04样本/秒)的6.5倍,这使得SANA-4.8B在实际应用中更具优势。
模型剪枝
为了和SANA 1.0(1.6B)公平比较,此次训练的SANA 1.5(4.8B)模型,没有用高质量数据做监督微调。
所有结果都是针对512×512尺寸的图像评估得出的。经过修剪和微调的模型,仅用较低的计算成本,得分就达到了0.672,超过了从头训练模型的0.664。
推理时扩展
将推理扩展应用于SANA 1.5(4.8B)模型,并在GenEval基准上与其他大型图像生成模型进行了比较。
通过从2048张生成的图像中选择样本,经过推理扩展的模型在整体准确率上比单张图像生成提高了8%,在「颜色」「位置」和「归属」子任务上提升明显。
不仅如此,借助推理时扩展,SANA 1.5(4.8B)模型的整体准确率比Playground v3 (24B)高4%。
结果表明,即使模型容量有限,提高推理效率,也能提升模型生成图像的质量和准确性。
SANA:超高效文生图
在这里介绍一下之前的SANA工作。
SANA是一个超高效的文本生成图像框架,能生成高达4096×4096分辨率的图像,不仅画质清晰,还能让图像和输入文本精准匹配,而且生成速度超快,在笔记本电脑的GPU上就能运行。
SANA为何如此强大?这得益于它的创新设计:
深度压缩自动编码器:传统自动编码器压缩图像的能力有限,一般只能压缩8倍。而SANA的自动编码器能达到32倍压缩,大大减少了潜在tokens数量,计算效率也就更高了。
线性DiT:SANA用线性注意力替换了DiT中的标准注意力。在处理高分辨率图像时,速度更快,还不会降低图像质量。
仅解码文本编码器:SANA不用T5做文本编码器了,而是采用现代化的小型仅解码大模型。同时,通过上下文学习,设计出更贴合实际需求的指令,让生成的图像和输入文本对应得更好。
高效训练与采样:SANA提出了Flow-DPM-Solver方法,减少了采样步骤。再配合高效的字幕标注与选取,让模型更快收敛。
经过这些优化,SANA-0.6B表现十分出色。
它生成图像的质量和像Flux-12B这样的现代大型扩散模型差不多,但模型体积缩小了20倍,数据处理能力却提升了100倍以上。
SANA-0.6B运行要求不高,在只有16GB显存的笔记本GPU上就能运行,生成一张1024×1024分辨率的图像,用时不到1秒。
这意味着,创作者们用普通的笔记本电脑,就能轻松制作高质量图像,大大降低了内容创作的成本。
研究者提出新的深度压缩自动编码器,将压缩比例提升到32倍,和压缩比例为8倍的自动编码器相比,F32自动编码器生成的潜在tokens减少了16倍。
这一改进对于高效训练和超高分辨率图像生成,至关重要。
研究者提出一种全新的线性DiT,用线性注意力替代传统的二次复杂度注意力,将计算复杂度从原本的O(N²) 降低至O(N)。另一方面,在MLP层引入3×3深度可分卷积,增强潜在tokens的局部信息。
在生成效果上,线性注意力与传统注意力相当,在生成4K图像时,推理延迟降低了1.7倍。Mix-FFN结构让模型无需位置编码,也能生成高质量图像,这让它成为首个无需位置嵌入的DiT变体。
在文本编码器的选择上,研究者选用了仅解码的小型大语言模型Gemma,以此提升对提示词的理解与推理能力。相较于CLIP和T5,Gemma在文本理解和指令执行方面表现更为出色。
为充分发挥Gemma的优势,研究者优化训练稳定性,设计复杂人类指令,借助Gemma的上下文学习能力,进一步提高了图像与文本的匹配质量。
研究者提出一种自动标注与训练策略,借助多个视觉语言模型(VLM)生成多样化的重新描述文本。然后,运用基于CLIPScore的策略,筛选出CLIPScore较高的描述,以此增强模型的收敛性和对齐效果。
在推理环节,相较于Flow-Euler-Solver,Flow-DPM-Solver将推理步骤从28-50步缩减至14-20步,不仅提升了速度,生成效果也更为出色。
参考资料:
https://huggingface.co/papers/2501.18427
https://x.com/xieenze_jr/status/1885510823767875799
https://nvlabs.github.io/SANA/