Sign In
AI

Llama 3

E
Eunyoung Lee
이번 글에서는 메타의 LLM인 Llama3에 대해 살펴보도록 하겠습니다.

LLaMA(Large Language Model Meta AI) 모델

메타의 오픈소스 LLM 모델로 2023년 2월에 Llama1이 처음 발표되었다. 트랜스포머 아키텍처를 기반으로 하며, Llama2의 경우 Llama1에서 학습 데이터를 40% 더 많이 사용하였다. 7B, 13B, 70B 등 여러 파라미터 사이즈의 모델이 존재하며, Alpaca와 Vinuca 같이 수많은 파생 모델이 존재한다.

GPT-3와의 차이점

GPT-3 아키텍처
GeLU 대신 SwiGLU 활성화 함수 사용
Swish와 GLU의 조합인 활성화 함수로 실험적으로 성능이 뛰어나서 사용
Swish는 학습 가능한 파라미터인 $β$ 값에 따라 다른 특성을 가지는 활성화 함수
GLU(Gated Linear Units)는 모델이 신경망에서 정보가 흐르는 것을 선택적으로 조절할 수 있도록 해주는 방법
SwiGLU에서 β, W, b는 모두 학습 가능한 파라미터
SwiGLU(x) = x * sigmoid(β * x) + (**1** - sigmoid(β * x)) * (Wx + b)
절대 위치 임베딩 대신 Rotary Positional Embedding(RoPE) 사용
트랜스포머는 위치를 고려 못하기 때문에 위치 임베딩 추가 필요
기존 트랜스포머는 절대 위치 임베딩을 사용하여 위치에 따른 사인함수와 코사인 함수 값을 더하는 방식을 사용
RoPE는 시퀀스의 각 위치마다 고유한 값의 회전을 통하여 절대 위치 임베딩과 상대 위치 임베딩을 통합하는 방법
Root-Mean-Square(제곱평균제곱근) 레이어 정규화 사용
모델이 재스케일링에 영향을 받지 않고, 적절한 학습률을 자동으로 찾아가는 암시적 학습률 적응 능력을 갖게 됨
기본 레이어 정규화에 비하여 계산이 효율적이고 안정적
Grouped Query Attention 사용
기존 Multihead Attention에서는 계산된 key와 value 벡터를 디코딩 단계에 쓸 수 있도록 저장하는 Key-Value caching 때문에 연산에 비용이 많이 소요되고, 디코딩 단계마다 캐시를 로드하고 업데이트 해야 하기 때문에 메모리 오버헤드 발생
Grouped Query Attention은 Query를 그룹으로 나누고 각 그룹이 Key와 Value를 공유하도록 하여 긴 컨텍스트에서 어텐션 계산을 더 빨리 할 수 있도록 함
Llama2에서는 70B에만 GQA를 적용했지만, GQA는 모델 학습 이후에도 적용하여 추론 속도 개선 가능

Llama2 프롬프트 템플릿

Instruction 튜닝을 할 때 사용하는 포맷으로, LLM 모델이 다양한 태스크(질의응답, 번역, 요약 등)를 실행할 수 있도록 지시를 사용하여 튜닝하는 방법이다. Llama2는 시작 토큰으로 <s>, 끝 토큰으로 </s>를 사용한다.
<s>[INST] <<SYS>>\\n{system}\\n<</SYS>>\\n\\n{user}[/INST]</s>

Llama 3

메타에서 2024년 4월 18일에 발표한 LLaMA 시리즈 모델로, 8B와 70B 두 개의 모델을 발표했으며 현재 학습 중인 400B 모델은 차후 발표 예정이다. 현재 Llama3는 영어 위주 모델이지만, 차후 멀티링구얼과 멀티모달, 그리고 긴 컨텍스트를 다룰 수 있도록 학습할 예정이다. Instruction 튜닝된 모델로 LLaMA3 8B는 Gemma 7B나 Mistral 7B 보다 성능이 뛰어나고, LLaMA3 70B는 Gemini 1.5 및 Claude 3 Sonnet보다 성능이 월등히 뛰어나다.

모델 아키텍처

Llama2와 비교하여 달라진 점
기존 32K에서 더 커진 128K 토큰을 가진 효율적으로 인코딩이 가능한 토크나이저 사용
토큰 사이즈가 커지며 다국어 임베딩 품질이 더 좋아짐
기존 4096에서 더 커진 8192 시퀀스 길이를 갖고 있어 더 긴 텍스트를 입력으로 받을 수 있음
8B 모델에서도 Group Query Attention 사용
토큰 효율성이 높아져 같은 텍스트에서 15% 더 적은 토큰을 사용

Llama3 학습

Llama2보다 7배 더 큰 공개적으로 사용이 가능한 150억 개가 넘는 토큰으로 학습
150억 개의 토큰을 학습했을 때까지 모델 성능이 log-linear하게 향상됨
5%는 높은 퀄리티의 영어가 아닌 30개가 넘는 언어로 이루어진 다국어 데이터
8B 모델은 2023년 3월까지, 70B 모델은 2023년 12월까지의 데이터를 학습
Llama2를 통해 좋은 품질의 데이터를 걸러낼 수 있어 여러가지의 필터를 만들어 사용
24,000 개의 H100 GPU로 구성된 두 개의 클러스터에서 학습
Llama3 8B의 경우 130만 GPU 시간 소요되었으며 70B는 640만 GPU 시간 소요
데이터 병렬 학습, 모델 병렬 학습, 파이프라인 병렬 학습을 모두 사용
Supervised Fine-tuning, Rejection Sampling, 강화학습(PPO, DPO) 사용하여 추가 학습

책임성과 안전성

안전성을 위해 SFT와 RLHF를 사용하여 인간의 선호도를 반영
위험한 답변을 끌어내려할 때 적대적인 프롬프트를 생성하도록 하는 레드 팀 접근 사용
Llama3 8B 기반의 LLM 안전보호 모델인 Llama Guard 2를 통해 위험 카테고리 11가지로 구별 가능
코드 인터프리터 남용 방지, 안전한 커멘드 실행, 안전한 코드를 위한 추론 시 필터링인 Code Shield 2 추가

오픈소스

Llama3를 학습하는 데 1000억 달러가 들었지만 모델을 오픈소스로 공개함
강력한 AI가 소수의 손에 집중되는 것은 위험할 수 있기 때문에 모델을 공개하고 균형 잡힌 경쟁을 보장
월간 활성 사용자 수가 7억 이상일 경우 메타의 라이센스 허가 하에 상업적 사용 가능
페이스북, 인스타그램, 왓츠앱, 페이스북 메신저에서 Llama3의 기반 Meta AI 어시스턴트 무료로 사용 가능

Llama3 Instruction 모델 프롬프트 템플릿

<|begin_of_text|><|start_header_id|>system<|end_header_id|> {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|> {{ user_message_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|> {{ model_answer_1 }}<|eot_id|><|start_header_id|>user<|end_header_id|> {{ user_message_2 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
<|begin_of_text|>: 시작 토큰
<|end_of_text|>: 끝 토큰
<|eot_id|>: 멀티 턴 대화 형식에서 메시지가 끝날 때마다 사용됨
<|start_header_id|>{role}<|end_header_id|>: system, user, assistant와 같이 역할 부여에 사용

huggingface에서 사용 방법

import transformers import torch model_id = "meta-llama/Meta-Llama-3-8B-Instruct" pipeline = transformers.pipeline( "text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device="cuda", ) messages = [ {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"}, {"role": "user", "content": "Who are you?"}, ] prompt = pipeline.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) terminators = [ pipeline.tokenizer.eos_token_id, pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>") ] outputs = pipeline( prompt, max_new_tokens=256, eos_token_id=terminators, do_sample=True, temperature=0.6, top_p=0.9, ) print(outputs[0]["generated_text"][len(prompt):])
참조
Kp
Subscribe to 'KPMG Lighthouse'
Subscribe to my site to be the first to receive notifications and emails about the latest updates, including new posts.
Join Slashpage and subscribe to 'KPMG Lighthouse'!
Subscribe
👍🏻
1
😍
1