본 논문은 장문 시퀀스 추론 작업(텍스트 또는 이미지/비디오 생성 등)에서 대규모 모델의 어텐션 계산 속도를 높이기 위해 Flash Attention 기반의 저정밀 수치적으로 동등한 알고리즘인 PASA를 개발했습니다. PASA는 온라인 의사 평균 이동 및 전역 복구라는 두 가지 새로운 기법을 도입하여 오버플로 불안정성이나 용인할 수 없는 수치 정확도 손실 없이 Flash Attention 프로세스 전반에 걸쳐 반정밀 계산을 가능하게 합니다. 이 알고리즘은 데이터 이동을 줄이고 계산 FLOPs를 증가시킴으로써 Ascend NPU와 같은 메모리 제한 AI 하드웨어 아키텍처에서 성능을 향상시킵니다. 설계된 랜덤 벤치마크와 실제 대규모 모델을 사용하여 알고리즘을 검증했습니다. 대규모 모델(Qwen2-7B 언어 모델 및 Stable-Video-Diffusion 다중 모드 모델)의 두 가지 범주에서 어텐션 입력 데이터의 큰 편향과 진폭이 반정밀도에서 수치 오버플로($>65504$)에 기여하는 중요한 요소임을 발견했습니다. 특히, 오버플로는 시퀀스 차원의 큰 편향과 Stable-Video-Diffusion 모델의 헤드 차원에서 쿼리와 키 사이의 공진 메커니즘으로 인해 발생합니다. 공진 메커니즘은 쿼리와 키 행렬 간의 위상 일치 또는 180도 위상 이동으로 정의되며, 어텐션 점수 행렬의 요소 값을 현저하게 증폭시킵니다. 이 문제는 Qwen 모델에도 적용됩니다. 또한, root mean square error (RMSE)와 고정밀 어텐션을 사용하여 생성된 최종 텍스트 및 비디오를 비교하여 수치 정확도를 평가했습니다.
시사점, 한계점
•
시사점:
◦
메모리 제약이 있는 AI 하드웨어(예: Ascend NPU)에서 장문 시퀀스 추론 작업의 성능을 향상시키는 저정밀 어텐션 계산 알고리즘 PASA 제시.
◦
온라인 의사 평균 이동 및 전역 복구 기법을 통해 반정밀 계산에서 오버플로 및 정확도 손실 문제 해결.
◦
대규모 언어 모델 및 다중 모드 모델에서의 어텐션 계산 오버플로 원인 분석 및 해결 방안 제시.
•
한계점:
◦
PASA 알고리즘의 성능은 특정 하드웨어 아키텍처(Ascend NPU)에 최적화되어 다른 아키텍처에서의 일반화 가능성에 대한 추가 연구 필요.
◦
현재 검증된 모델 외 다른 유형의 대규모 모델에 대한 적용 가능성 및 성능 평가 추가 연구 필요.