안녕하세요. 오늘 리뷰할 논문은 Meta AI에서 발표한 TOKEN MERGING: YOUR VIT BUT FASTER라는 논문입니다. ICLR Oral(noticable top 5%)로 올라간 논문입니다. 주관적인 의견이지만 Transformer 기반의 모델은 딥러닝 전반에서 좋은 모습을 보여주고 있습니다. 그와 동시에 파라미터를 많이 사용하는 딥러닝의 경향성을 가속화 시켰다고 생각합니다.
그 동안 Model pruning이나 quantization 같이 computing cost를 절약하는 류의 연구들이 동시에 진행이 되었었는 오늘 리뷰할 논문 ToMe는 이러한 연구들과 방향성이 같다고 보시면 될 것 같습니다. Transformer는 GPU 메모리를 차지할 때, 입력으로 들어오는 token의 갯수에 영향을 굉장히 많이 받는 구조입니다. 그렇기 때문에 transformer 구조를 사용하는 모델에서 효율성을 올리는 방법 중에 하나가 token의 갯수를 줄이는 것입니다. ToMe에서는 제목 그대로 input token들을 merging하는 방식으로 token의 갯수를 줄여나가고 결과적으로 모델의 효율성을 올렸다고 합니다.
https://openreview.net/forum?id=JroZRaRw7Eu
0. Abstract
저자들은 기존에 학습된 ViT 모델을 추가적인 학습이 없이도 throughput을 높이는 방법인 ToMe를 제안합니다. ToMe는 가벼운 matching 알고리즘을 사용해서 비슷한 token들을 합치는 방법이라고 합니다. 토큰들은 합치는 (merging) 방식은 토큰을 버리는 (pruning)하는 방법 못지않게 빠른 동시에 성능은 더 좋을 편이라고 합니다. 물론 학습을 할 때에도 ToMe를 넣어서 학습할 수 있습니다.
사전에 학습된 모델인 ViT-L, ViT-H 같은 모델은 ToMe를 적용했을 때 throughput이 대략 2배 정도 상승시키면서 성능(정확도)는 0.2-0.3% 정도만 떨어진다고 합니다. 그리고 MAE를 video domain에서 finetuning할 때 함께 사용하면 연산량이 적기 때문인지 학습속도도 2배 정도 빨라지는 것을 확인할 수 있었다고 합니다. Audio 분야에서도 마찬가지로 ViT-B 모델이 throughput이 2배 늘어나는 반면 성능은 0.4%(mAP)만 떨어졌다고 합니다.
ToMe에서는 matching algorithm으로 bipartite알고리즘을 사용하고 있는데, 해당 matching 알고리즘은 DETR에서도 비슷하게 사용된 알고리즘입니다.
1. Introduction
최근 거대모델들이 좋은 성능을 보여주고 있지만 비용 등의 문제들이 존재합니다. 그리고 거대모델의 연산 속도를 높이면서 성능 하락을 최대한 막는 것 또한 쉽지 않다고 합니다. 제가 이 분야를 많이 아는게 아니라서 얼마나 속도를 높일 때 성능이 얼마만큼 유지되는지에 대한 감이 없긴합니다. 최근에 많이 시도 되고 있는 방법이 runtime 상에서 token을 없애는 방식이 많이 시도되고 있다고 합니다. 하지만 token을 버리는 방식은 말그대로 token들을 없애고 거기서 representation을 뽑아내는 방식은 구조적으로 단점이 있습니다. 최대한 버려도 되는 token을 버린다고 해도 결국에는 정보가 소실되는 것이며 그렇기 때문에 token을 무작정 많이 버리기도 쉽지 않다고 합니다. 최근 관련 연구들은 모델을 처음부터 다시 학습하여야하며 학습 속도가 올라가지도 않는다고 합니다. 또한 버리는 token의 갯수가 가변적이면 batch단위로 처리할 수 없게된다고 합니다.
기존의 문제들을 해결한 ToMe는 아래와 같은 contribution이 있다고 합니다.
- 학습에 적용 여부와 상관 없이 throughput을 높인 방법 제안 + 학습에 적용 시 학습 속도 향상 효과
- ViT 외에 변형 구조들에 대해서도 다양한 실험을 진행
- SOTA Token pruning 기법과 비교한 실험
- 다양한 modality에서도 실험을 진행
- Token merging된 결과를 시각화하여 보임
2. Related Work
Efficient Transformer, Token Reduction, Combining Tokens와 관련된 내용들입니다. 관심이 있으신 분은 논문을 찾아보시면 좋을 것 같습니다.
3. Token Merging
ToMe의 목적은 ViT에 token merging module을 추가하여, 해당 모듈이 token들을 합치면서 학습/추론 시 throughput을 높이도록하는 것입니다. 그리고 굳이 추가 학습이 없이도 바로 사용이 가능하다는 것이 특징입니다.
Strategy: transformer의 각 block 마다 ToMe 모듈을 넣어서 매 block을 지날 때, 고정된 갯수(r) 만큼 token을 합친다고 합니다. L개의 transformer layer가 쌓인 모델이라고 치면 총 rL번 token merging이 일어났게 됩니다. 항상 고정된 갯수의 token을 합치는 이유는 batched inference를 하기 위함이라고 합니다. 다이나믹하게 token pruning을 하는 방법도 기존에 제시가 되었고, 성능도 좋은 편이지만 이런 구조는 batched inference를 할 때 제약이 많다고 합니다.
아래 그림은 ToMe의 작동 방식에 대한 figure입니다. 기존 방법들은 token pruning을 block 맨 앞에서 하는 경우가 대부분이었는데 ToMe는 attention과 MLP 사이에서 token merging을 진행한다고 합니다. 이러한 구조는 token merging을 할 때 attention의 결과 값을 참고할 수 있다는 장점이 있다고 합니다.
Token Similarity: Transformer 안에서 연산이 이루어지는 feature space는 기본적으로 overparmeterized 되어있기 때문에 token을 쉽게 encoding 할 수 있습니다. 하지만 feature 안에 noise가 껴있을 수도 있기 때문에 feature space 상에서 token 간의 거리는 token 간의 유사도를 완벽하게는 아니지만 어느정도 유사하게 대변할 수 있다고 합니다.
여기서는 attention 안에 있는 K(keys)를 사용해서 token 간의 유사도(cosine similarity)를 구한다고 합니다: key와 각 token을 dot product 하는 식으로 구함. 이것과 관련된 실험은 뒤에 나와있습니다.
Bipartite Soft Matching: Token 간의 유사도가 위와 같은 방식으로 정해져있다는 가정 하에 token을 r개 만큼 줄일 수 있는 효율적인 matching 알고리즘이 필요합니다. 후보로 제안되었던 방법은 k-means, graph cuts 방법 등이 있었습니다. 하지만 모델마다 다르겠지만 대략 1000개 안팎의 token을 L(층 수) 번 계산을 해야하기에 굉장히 빠르고 효율적인 알고리즘이 필요했고 iterative clustering 쪽 방법은 모두 배제되었습니다.
그리고 저자들은 병렬처리가 불가능한 iterative한 방법이 아니면서, merging을 점진적인 형식으로 할 수 있는 방법을 찾았다고 합니다.
그리하여 Bipartite soft matching 알고리즘을 고안하는데 아래와 같은 방식으로 작동합니다.
1. Token을 대략적으로 비슷한 크기의 두 개의 set A, set B로 나눈다.
2. A에 있는 token들을 B에서 가장 비슷한 token과 연결한다.
3. 연결된 선 중에서 유사도가 가장 높은 r개만 남긴다.
4. 연결되어있는 Token들끼리 합친다. (feature를 avereaging 하는 등의 방식으로)
5. Set A와 set B를 다시 하나로 만든다.
이 절차는 아래 그림에도 나와있습니다.
Tracking Token Size: 복수 개의 token이 동일한 key로 합쳐지면 softmax 연산을 할 때 비교적으로 영향력이 적어지는 방향으로 연산이 된다고 합니다. 그리하여 아래와 같은 proportional attention이라는 변형된 계산식을 사용합니다. 아래 식에서 s는 각 token이 몇 개의 token들이 합쳐져서 만들어진 것인지를 나타낸다고 합니다.
Training With Merging: 사전에 학습이 된 ViT에 바로 붙여서 사용할 수 있도록 고안이 되었지만 ToMe를 같이 끼워서 학습을 할 때 정확성 저하 방지나 빠른 학습 같은 효과가 있기도 합니다. ToMe와 모델을 함께 학습할 때는 별도의 training trick 없이 pooling layer처럼 취급하면서 학습을 하면 된다고 합니다.
4. Image Experiments
실험으로는 ViT를 ImageNet-1k 데이터셋으로 AugReg, MAE, SWAG, DeiT 방법을 사용해서 실험하여 비교하였습니다. MAE와 DeiT 방법은 ToMe를 ViT에 달아서 별도로 학습을 진행하였고, 나머지는 사전에 학습된 weight를 가져와서 사용했다고 합니다.
MAE는 제가 전에 리뷰했던 포스트가 있어서 궁금하신 분은 아래 링크를 구경해보셔도 좋을 것 같습니다. MAE를 포함한 다른 학습 방법들에 대해서 구글링을 하셔도 좋을 것 같습니다.
Throughput을 측정할 때는 V100 GPU에 각자 최적의 batch size대로 설정하였다고 합니다.
4.1. Design Choices
일종의 ablation 처럼 최적의 옵션을 찾기 위해 실험을 진행하였습니다.
Token Similarity: 위의 (a) 표를 보면 Token 간의 유사도를 측정할 때 token을 그대로 사용(X)하는 것 보다 위에서 말씀드렸던 것처럼 Key(K)를 곱해서 사용하는 것이 가장 성능이 좋았다고 합니다. Feature distance를 측정할 때는 cosine similarity를 통해 측정하는 것이 성능이 좋았다고 합니다. 또한 연산상의 효율 문제도 있고 attention haed들을 concat하기보단 평균을 구하는 쪽으로 선택했다고 합니다.
Algorithmic Choices: Token들을 합칠 때에는 token 갯수에 비례한 (위의 식 참고) weighted avg 방식이 성능이 좋았다고 합니다. Bipartite matching을 할 때는 token을 2개의 set(A, B)로 나눕니다. Merging을 한 후에는 남은 token들을 concat합니다. 다음에 token들을 A, B로 나눌 때 전체 token들을 A, B에 번갈아 할당하면서 나눌 때 성능이 가장 좋았다고 합니다.
Proportional Attention: Merge된 token은 하나 이상의 token을 위의 식을 통해서 합치면서 모든 token들을 대변하게 됩니다. 이와 같은 proportional attention은 MAE를 제외한 방식에만 더 효과가 좋았다고 합니다. MAE에서 이런 양상이 나오는 이유는 MAE에서는 encoder에서 token을 받을 때 이미 일정 비율을 버린 상태로 받기 때문입니다. 이런 성능 격차는 pretraining 때 ToMe를 넣어서 학습하면 줄어들긴 한다고 합니다. 그래서 pretrain에서 바로 ToMe를 적용한 MAE 버전 ViT가 아닐 때에만 proportional attention 방식을 사용하지 않았다고 합니다.
Comparing Matching Algorithms: Token 수를 줄이는 알고리즘을 비교한 실험은 아래와 같다고 합니다.
Seleting a Token Merging Schedule: 매 layer 마다 정해진 수 만큼의 token을 합치는 constant schedule 방식과 처음에는 많이 merging을 하다가 점차 merge하는 token의 갯수를 점차 줄여나가는 decreasing schedule 방식을 비교하였습니다. Constant schedule 방식이 optimal한 경향을 보여준다고 합니다. 저 점들은 다양하게 15,000개의 schuler를 가지고 샘플링하여 그린 것이라고 합니다. AugReg ViT-B/16을 사용하였다고 합니다.
4.2. Model Sweep
다양한 모델을 가지고 실험하는 부분입니다.
4.3. Comparison to Other Works
다른 방법들과 비교를 한 부분입니다.
아래 table로 대체하도록 하겠습니다.
4.4. Visualizations
마지막 Layer의 결과 값을 가지고 시각화를 한 모습입니다. MAE를 사용한 ViT-H/r7/trained from scratch 버전이라고 합니다.
Video/Audio 부분에도 적용을 한 Section이 있는데 이 부분은 넘어가도록 하겠습니다.