Sign In

Dispatch-Aware Ragged Attention for Pruned Vision Transformers

์ž‘์„ฑ์ž
  • Haebom
์นดํ…Œ๊ณ ๋ฆฌ
Empty

์ €์ž

Seifeldin Abdellatif, Ahmad Almasri

๐Ÿ’ก ๊ฐœ์š”

๋ณธ ๋…ผ๋ฌธ์€ Vision Transformer(ViT)์—์„œ ํ† ํฐ ๊ฐ€์ง€์น˜๊ธฐ(pruning)๋ฅผ ํ†ตํ•ด ์—ฐ์‚ฐ๋Ÿ‰์„ ์ค„์ด๋”๋ผ๋„, ์งง์€ ์‹œํ€€์Šค ๊ธธ์ด์—์„œ ๋ฐœ์ƒํ•˜๋Š” ์ปค๋„ ๋””์ŠคํŒจ์น˜ ์˜ค๋ฒ„ํ—ค๋“œ๋กœ ์ธํ•ด ์‹ค์ œ ์†๋„ ํ–ฅ์ƒ์ด ๋‚˜ํƒ€๋‚˜์ง€ ์•Š๋Š” ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฅผ ์œ„ํ•ด ๊ธฐ์กด FlashAttention-2๋ณด๋‹ค ๋‚ฎ์€ ๋””์ŠคํŒจ์น˜ ์˜ค๋ฒ„ํ—ค๋“œ๋ฅผ ๊ฐ€์ง€๋Š” ๊ฒฝ๋Ÿ‰ํ™”๋œ ์–‘๋ฐฉํ–ฅ Triton ์–ดํ…์…˜ ์ปค๋„์„ ์ œ์•ˆํ•ฉ๋‹ˆ๋‹ค. ์ œ์•ˆ๋œ ๋ฐฉ๋ฒ•์€ ํŒจ๋”ฉ ๊ธฐ๋ฐ˜ PyTorch SDPA ๋Œ€๋น„ ์ตœ๋Œ€ 2.51๋ฐฐ์˜ ์ฒ˜๋ฆฌ๋Ÿ‰ ํ–ฅ์ƒ์„ ๋‹ฌ์„ฑํ•˜๋ฉฐ, ๊ฐ€์ง€์น˜๊ธฐ๋œ ViT์˜ ํšจ์œจ์„ฑ์„ ๋†’์ž…๋‹ˆ๋‹ค.

๐Ÿ”‘ ์‹œ์‚ฌ์  ๋ฐ ํ•œ๊ณ„

โ€ข
ViT์˜ ํ† ํฐ ๊ฐ€์ง€์น˜๊ธฐ์—์„œ ๋ฐœ์ƒํ•˜๋Š” ์‹ค์ œ ์†๋„ ํ–ฅ์ƒ์€ ์ปค๋„ ๋””์ŠคํŒจ์น˜ ์˜ค๋ฒ„ํ—ค๋“œ์— ์˜ํ•ด ์ œํ•œ๋  ์ˆ˜ ์žˆ์œผ๋ฉฐ, ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•œ ํšจ์œจ์ ์ธ ์ปค๋„ ๊ตฌํ˜„์ด ์ค‘์š”ํ•ฉ๋‹ˆ๋‹ค.
โ€ข
์ œ์•ˆ๋œ Dispatch-Aware Ragged Attention ์ปค๋„์€ ์งง์€ ์‹œํ€€์Šค ๊ธธ์ด์—์„œ๋„ ๊ฐ€์ง€์น˜๊ธฐ์˜ ์ด์ ์„ ์‚ด๋ ค ํšจ์œจ์„ฑ์„ ํฌ๊ฒŒ ํ–ฅ์ƒ์‹œํ‚ฌ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
โ€ข
์‹คํ—˜ ๊ฒฐ๊ณผ๋Š” ์ œ์•ˆ๋œ ์‹œ์Šคํ…œ์ด ๋‹ค์–‘ํ•œ ์ž…๋ ฅ ํฌ๊ธฐ ๋ฐ ๊ฐ€์ง€์น˜๊ธฐ์œจ์—์„œ ์šฐ์ˆ˜ํ•œ ์„ฑ๋Šฅ์„ ๋ณด์ด๋ฉฐ, ์ˆ˜์น˜์  ์ •ํ™•์„ฑ ๋˜ํ•œ ๊ฒ€์ฆ๋˜์—ˆ์Œ์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค.
โ€ข
์ œ์•ˆ๋œ ์ปค๋„์ด ํŠน์ • ํ•˜๋“œ์›จ์–ด(NVIDIA RTX 4000 Ada Generation GPU)์— ์ตœ์ ํ™”๋˜์–ด ์žˆ์–ด, ๋‹ค๋ฅธ ์•„ํ‚คํ…์ฒ˜์—์„œ์˜ ์„ฑ๋Šฅ ๊ฒ€์ฆ ๋ฐ ์ผ๋ฐ˜ํ™” ๊ฐ€๋Šฅ์„ฑ์— ๋Œ€ํ•œ ์ถ”๊ฐ€ ์—ฐ๊ตฌ๊ฐ€ ํ•„์š”ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
๐Ÿ‘