Shen Nie, Fengqi Zhu, Chao Du, Tianyu Pang, Qian Liu, Guangtao Zeng, Min Lin, Chongxuan Li
개요
본 논문은 Masked Diffusion Models (MDMs)의 확장성과 주요 언어 과제(텍스트 생성 및 이해)에서의 효과를 최초로 규명합니다. 자기회귀 모델(ARMs)과 비교 가능한 확장 속도를 보이며, 상대적으로 작은 컴퓨팅 격차를 갖는 MDM의 확장 법칙을 제시합니다. 최대 11억 개의 파라미터를 가진 MDM 계열을 훈련하여 동일하거나 더 큰 크기의 ARM과 성능을 체계적으로 비교 평가합니다. MDM의 확률적 공식을 완전히 활용하여, 대규모 비짝 데이터를 효과적으로 활용하는 간단하면서도 효과적인 비지도 분류기 없는 안내(classifier-free guidance)를 제안합니다. 언어 이해 측면에서, 11억 파라미터의 MDM은 동일한 데이터로 훈련된 11억 파라미터의 TinyLlama 모델을 8개의 제로샷 벤치마크 중 4개에서 능가합니다. 특히, GSM8K 데이터셋에서 70억 파라미터의 Llama-2 모델과 경쟁력 있는 수학 추론 능력을 달성합니다. 텍스트 생성에서, 16배 더 많은 사전 훈련 시간을 가진 MDM은 KV-Cache 가속 샘플링 기법을 통해 ARM과 유연한 절충안을 제공합니다. MDM은 ARM과 성능이 동일하면서 샘플링 속도는 1.4배 빠릅니다. 또한, MDM은 양방향 추론을 효과적으로 처리하고 데이터의 시간적 변화에 적응함으로써 ARM에 어려운 과제를 해결합니다. 특히, 11억 파라미터의 MDM은 130억 파라미터의 Llama-2 및 1750억 파라미터의 GPT-3와 같이 훨씬 더 많은 데이터와 계산을 사용하는 훨씬 큰 ARM에서 발생하는 역설적 저주(reverse curse)를 극복합니다. 코드는 https://github.com/ML-GSAI/SMDM 에서 확인할 수 있습니다.