KalelPark's LAB

[ 논문 리뷰 ] An Attention Free Transformer 본문

Data Science/Classification

[ 논문 리뷰 ] An Attention Free Transformer

kalelpark 2023. 5. 19. 09:42

Abstract

본 논문에서는 dot product self attention을 사용하지 않는 Attention Free Trnasformer (AFT)을 제시합니다. AFT layer는 key와 value를 position biases와 함께 결합됩니다. 이후, query와 함께 element-wise 형태로 계산됩니다. 이러한 작업은 context size와 dimension of features에서 선형복잡성을 가지고, large input과 model size에서도, 모두 호환하게 합니다. 또한 global connectivity를 유지하면서, locality와 spatial weight를 공유하는 장점을 가지는 2가지 모델의 변형은 AFT-local과 AFT-conv를 소개합니다.

Introduction

본 논문에서는 일반적으로 사용되는 computational module을 사용하지 않거나, 근사하지 않는 계산 Module을 사용하는 방식을 제안합니다. 우리는 이러한 방식을 Attention free transformer라고 부르며, Dot product 와 유사하게, AFT는 query, key, value 3가지 요소의 상호작용으로 구성됩니다. Attention Free Transformer(AFT) Dot product의 장점과 유사한 방식으로 Context내 2지점간의 직접적인 상호작용을 유지합니다. AFT는 attention map이 명시적으로 계산될 필요가 없으며, head의 개수가 Model의 feature dimension과 동일한 형태로 수행되는 것으로 해석될 수 있습니다. 

일반적인 Attention 방법론들은 Key Query의 개수가 증가하게 되면, model의  feature dimension이 상당히 늘어나, 계산 시간이 늘어날 뿐만 아니라, large model에서 부적합하다.

일반적으로,

이러한 부분에서 영감을 받아, AFT-local과 AFT-conv를 제시합니다. AFT-Local의 경우, learned position biases는 local region에 제약되어, global connectivity를 유지합니다. AFT-Conv는 spatial weight을 공유함으로써 design을 확장하고, CNN의 global receptive field를 확장합니다. localilty constraint는 better parameter와 계산비용의 효율성을 보여줍니다. 뿐만 아니라, 모든 task에서 높은 성능을 보여줍니다.

* 기존 Transformer에서 만연되게 사용되는 Multi-Head Attention에 대해서 소개하고 있습니다.

Method

3.1 Attention Free Transformer

본 논문에서는 일반적인 Transformer의 Architecture를 변경하지 않고, MHA의 대체제인 Attention Free Transformer(AFT)를 제시합니다. 한 마디로 설명하자면, AFT는 weight average of value를 query와 함께 element-wise multiplication으로 결합합니다. 특히, 가중치를 만드는 것은 key와 learned pairwise position biases로 구성되어 있습니다. 이러한 방식은 상당한 attention matrix를 계산할 필요도 없으며, MHA에 존재하는 value 와 query사이 상호작용을 유지합니다. 

3.2 AFT variants : locality, weight sharing and parameterization

위의 수식은, global connectivity, non-negative convolutional weights, sophisticated divisive/multiplicative로 해석될 수 있습니다. 또한 여러 실험을 진행하면서, AFT-full 과 AFT-lcoal에 대해서는, factorized 형태를 생성하는 것이 중요하다고 말합니다. 이러한 인수분해의 형태는 매개변수를 크게 줄일 수 있을 뿐만 아니라, 테스트 모델에서 높은 성능을 경험적으로 보여줍니다. 

import numpy as np
import torch
import torch.nn as nn
from torch.nn import init

class AFT_FULL(nn.Module):

    def __init__(self, d_model,n=49,simple=False):

        super(AFT_FULL, self).__init__()
        self.fc_q = nn.Linear(d_model, d_model)
        self.fc_k = nn.Linear(d_model, d_model)
        self.fc_v = nn.Linear(d_model,d_model)
        if(simple):
            self.position_biases=torch.zeros((n,n))
        else:
            self.position_biases=nn.Parameter(torch.ones((n,n)))
        self.d_model = d_model
        self.n=n
        self.sigmoid=nn.Sigmoid()

        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, input):

        bs, n,dim = input.shape

        q = self.fc_q(input) #bs,n,dim
        k = self.fc_k(input).view(1,bs,n,dim) #1,bs,n,dim
        v = self.fc_v(input).view(1,bs,n,dim) #1,bs,n,dim
        
        numerator=torch.sum(torch.exp(k+self.position_biases.view(n,1,-1,1))*v,dim=2) #n,bs,dim
        denominator=torch.sum(torch.exp(k+self.position_biases.view(n,1,-1,1)),dim=2) #n,bs,dim

        out=(numerator/denominator) #n,bs,dim
        out=self.sigmoid(q)*(out.permute(1,0,2)) #bs,n,dim

        return out


if __name__ == '__main__':
    input=torch.randn(3,196,768)
    aft_full = AFT_FULL(d_model=768, n=196)
    output=aft_full(input)
    print(output.shape)

Experiments

Conclusion

본 논문에서는 효율적인 계산을 하기 위하여, dot producc attention을 도입하였습니다. 또한 훌륭한 성능을 논문에서 보여줍니다.

Reference

https://arxiv.org/abs/2105.14103

 

Comments