김형 BLOG

Tiled MLP 대규모 언어 모델의 긴 시퀀스 학습

대규모 언어 모델(LLM) 연구에서 긴 시퀀스(Long Sequence) 처리 능력은 모델의 성능을 결정짓는 중요한 요소입니다. 하지만 시퀀스 길이가 늘어날수록 급격히 증가하는 메모리 사용량은 학습과 추론 과정에서 큰 병목으로 작용합니다. 본 글에서는 이러한 문제를 해결하기 위해 제안된 ‘Tiled MLP’ 기술의 개념과 원리, 그리고 최신 프레임워크인 Unsloth에 적용된 실제 사례를 통해 그 효과를 살펴봅니다.

1. 연구 배경: 메모리 병목의 원인

LLM 학습 시 메모리 부족 현상(OOM)은 흔히 어텐션(Attention) 메커니즘의 2차원적 복잡도($O(N^2)$) 때문이라고 알려져 있습니다. 그러나 시퀀스 길이가 수백만 토큰 단위로 길어지면, 어텐션뿐만 아니라 피드포워드 네트워크(MLP)나 로짓(Logits) 계산과 같은 선형 연산에서도 막대한 양의 활성화(Activation) 메모리가 필요하게 됩니다.

예를 들어 Llama-3.1-8B 모델 기준으로 시퀀스 길이가 125K에 달할 경우, 각 레이어의 은닉 상태(hidden_states) 텐서를 저장하는 데만 약 30.5GiB의 메모리가 필요합니다. 기존의 학습 방식은 전체 시퀀스에 대해 연산을 한 번에 수행하려다 보니, 이러한 중간 텐서를 모두 GPU 메모리에 올려야 해 물리적인 한계에 부딪히게 됩니다.

2. Tiled MLP의 핵심 개념

Tiled MLP는 시퀀스 차원을 따라 연산을 작은 단위(Tile)로 나누어 순차적으로 처리하는 기술입니다. 핵심은 전체 시퀀스에 대한 중간 활성화 값을 한 번에 저장하지 않고, 타일 단위로 계산을 수행한 뒤 필요한 값만 남기고 메모리를 해제하는 것입니다

이 방식을 적용하면 이론적으로 MLP 연산에 필요한 피크 메모리 사용량을 시퀀스 길이 $N$에 비례하는 $O(N)$에서 상수 시간인 $O(1)$로 줄일 수 있습니다. 즉, 시퀀스 길이가 아무리 길어져도 MLP 연산 시 순간적으로 점유하는 메모리 양은 일정 수준으로 유지됩니다.

3. 주요 특징 및 이점

메모리 효율성의 극대화 Llama-3.1-8B 모델의 16K 시퀀스 학습 시, 로짓 계산에만 약 8GiB의 메모리가 필요합니다. 이를 Tiled MLP 방식을 적용하여 분할 처리하면, 실제 피크 메모리 사용량을 획기적으로 낮출 수 있습니다 이는 제한된 GPU 자원 내에서 더 큰 배치 크기나 더 긴 시퀀스를 처리할 수 있는 여유 공간을 확보해 줍니다.

image

기존 최적화 기법과의 시너지 Tiled MLP는 단독으로 사용될 때보다 활성화 체크포인트를 CPU로 오프로딩(Activation Checkpoint Offload)하는 기술과 결합될 때 그 효과가 배가됩니다.Arctic Long Sequence Training(ALST) 연구에 따르면, 이 두 기술을 함께 사용할 경우 시퀀스 길이를 기존 대비 약 3.5배까지 확장할 수 있음이 확인되었습니다.

4. 실제 적용 사례: Unsloth 500K 컨텍스트 파인튜닝

Tiled MLP 기술은 단순한 연구 단계에 머무르지 않고, 파인튜닝 프레임워크인 Unsloth에 통합되어 실제 성과를 내고 있습니다. 최근 Unsloth는 Snowflake의 연구진과 협업하여 단일 GPU 환경에서의 한계를 획기적으로 극복했습니다.

단일 GPU에서의 컨텍스트 확장 Unsloth는 Tiled MLP 도입을 통해 단일 80GB H100 GPU에서 gpt-oss-20b 모델의 컨텍스트 길이를 기존 80K에서 500K 이상으로 확장했습니다. 더 나아가 작은 모델의 경우 단일 GPU에서 최대 100만(1M) 컨텍스트 길이까지 훈련이 가능한 환경을 구현했습니다.

통합된 최적화 기술 (Loss Refactoring & Gradient Checkpointing) Unsloth 구현에서는 Tiled MLP와 함께 ‘Loss Refactoring(Chunk & Fuse)’ 기술이 적용되었습니다. 이는 전체 시퀀스를 한 번에 처리하는 대신 VRAM 용량에 따라 동적으로 조절되는 청크 단위로 손실(Loss)을 계산하여 피크 메모리를 줄입니다. 또한, 활성화 데이터를 CPU로 오프로드할 때 CUDA 스트림을 활용하여 훈련 오버헤드를 기존 1~3%에서 0.1% 이하로 낮춘 개선된 Gradient Checkpointing 기술이 사용되었습니다.

트레이드오프와 사용 편의성 Tiled MLP를 적용하면 메모리 사용량을 약 40% 절감하고 컨텍스트 용량을 2배 늘릴 수 있지만, 추가적인 연산 과정으로 인해 스텝 당 소요 시간이 약 1.3배 증가하는 트레이드오프가 존재합니다. 그러나 사용자는 FastLanguageModel.from_pretrained 함수에서 unsloth_tiled_mlp = True로 설정하는 것만으로 복잡한 튜닝 없이 이 기능을 즉시 활용할 수 있습니다.

Tiled MLP는 대규모 언어 모델의 긴 시퀀스 학습에 있어 필수적인 메모리 최적화 기술입니다. 연산을 타일링하여 피크 메모리를 $O(1)$ 수준으로 억제하는 이 접근 방식은 ALST 논문을 통해 이론적으로 입증되었으며, Unsloth와 같은 프레임워크를 통해 실제 개발 환경에서도 단일 GPU로 수십만 토큰을 처리할 수 있음을 증명했습니다. 연산 시간의 소폭 증가는 존재하지만, 하드웨어 증설 없이 문맥 처리 능력을 비약적으로 향상시킬 수 있다는 점에서 매우 실용적인 기술이라 할 수 있습니다.

참조문서 https://docs.unsloth.ai/new/500k-context-length-fine-tuning https://docs.unsloth.ai/new/500k-context-length-fine-tuning#tiled-mlp-unlocking-500k https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt_oss_(20B)_500K_Context_Fine_tuning.ipynb