Sign In

DeFT: Decoding with Flash Tree-attention for Efficient Tree-structured LLM Inference

Created by
  • Haebom
Category
Empty

저자

Jinwei Yao, Kaiqi Chen, Kexun Zhang, Jiaxuan You, Binhang Yuan, Zeke Wang, Tao Lin

개요

본 논문은 트리 구조를 갖는 여러 생성 호출을 처리하는 복잡한 작업(예: 퓨샷 프롬프팅, 다단계 추론, 예측적 디코딩 등)에 점점 더 많이 사용되는 대규모 언어 모델(LLM)의 추론 시스템의 비효율성 문제를 해결하기 위해 DeFT(Decoding with Flash Tree-Attention) 알고리즘을 제안합니다. 기존 시스템은 쿼리와 KV 캐시의 부적절한 분할로 인해 공유 접두사의 KV 캐시에 대한 메모리 접근(IO) 재사용 부족 및 부족한 부하 분산이라는 두 가지 주요 문제를 갖습니다. DeFT는 접두사 인식 및 부하 분산 KV 캐시 분할을 통해 이러한 문제를 해결합니다. 구체적으로, KV-Guided Grouping을 통해 공유 접두사의 KV 캐시를 반복적으로 로드하는 것을 피하고, Flattened Tree KV Splitting을 통해 KV 캐시를 균등하게 분산시켜 GPU 활용도를 높입니다. 실험 결과, DeFT는 기존 최첨단 알고리즘에 비해 최대 2.23/3.59배의 종단 간/어텐션 지연 시간 단축을 달성합니다.

시사점, 한계점

시사점:
트리 기반 LLM 응용 프로그램의 추론 속도를 크게 향상시킬 수 있는 효율적인 어텐션 알고리즘을 제시합니다.
KV-Guided Grouping과 Flattened Tree KV Splitting 기법을 통해 KV 캐시 IO를 획기적으로 줄이고 GPU 활용도를 높였습니다.
실제 트리 기반 작업에서 상당한 성능 향상을 보여줍니다.
공개된 코드를 통해 재현성을 확보합니다.
한계점:
제안된 알고리즘의 효율성은 특정 하드웨어 환경에 의존할 수 있습니다.
다양한 트리 구조 및 LLM 아키텍처에 대한 일반화 가능성에 대한 추가 연구가 필요합니다.
특정 작업에 대한 최적화가 이루어졌으므로 다른 유형의 작업에 대한 성능은 추가 검증이 필요합니다.
👍