티스토리 뷰

들어가며

본 글은 Hierarchical Attention Prototypical Networks for Few-Shot Text Classification을 리뷰한 글입니다.

 

기존 Prototypical Network의 문제점

데이터에서 prototype의 분별력과 표현력을 약화시키는 다양한 노이즈의 부정적 영향을 고려하지 않았다.

기존 Attention 기반 Prototypical Network의 문제점

기존에서 PN에 attention을 이용한 접근도 있었지만 이러한 방법들은 의미 정보를 이용하지 않거나 노이즈의 영향을 정교하게 고려하지 않았다.

 

논문에서의 키 아이디어

세 가지 레벨에서 attention 메커니즘을 사용하여 few-shot text classification을 한다.

 

feature level attention: 클래스마다 다른 feature score를 얻기 위해 CNN을 사용한다.

word level attention: attention 메커니즘으로 인스턴스에서 각 단어 은닉 상태의 중요도를 배우도록 한다.

instance level multi cross attention: support set과 query set 사이의 multi cross attention으로 각은 클래스 안에 서로 다른 인스턴스의 중요도를 결정하고 각 클래스에서 모델이 더욱 분별있는 prototype을 얻도록 한다.

 

이러한 방법은 의미 공간(semantic space)에서의 표현력을 강화한다.

 

문제 정의

few-shot text classification에서 본 논문에서 배우고자 하는 함수는 \(D\)(\(D_{train}, D_{validation}, D_{test}\)로 나뉨)가 라벨링된 데이터, \(S\)가 support set, \(x\)는 라벨링 되지 않은 입력 텍스트 인스턴스, \(y\)가 출력 레이블일 때 \(G(D, S, x) \to y\)이다.

 

에피소드 학습 전략에 따라 처음엔 \(D_{train}\)에서 label set \(\mathfrak{L}\)을 샘플링한다. 그 다음 \(\mathfrak{L}\)에서 support set \(S\)와 query set \(Q\)를 샘플링하고 이 \(S\)와 \(Q\)를 모델에 전달하여 loss를 최소화한다. 여기서 \(\mathfrak{L}\)이 \(N\)개의 클래스, \(S\)가 \(K\)개의 인스턴스를 가지고 있으면 이를 \(N\)-way \(K\)-shot 학습이라고 한다. 본 논문에선 \(N\)이 5나 10, \(K\)가 5나 10인 경우를 고려한다.

 

모델

전반적인 모델의 아키텍처는 다음 그림과 같다.

모델은 크게 세 파트로 나뉜다. Instance Encoder, Prototypical Networks, Hierachical Attention.

 

Instance Encoder

이 층은 support set이나 query set의 각 인스턴스는 임베딩으로 각 단어를 입력 벡터로 표현한다.

 

instance encoder는 embedding layer와 instance encoding layer 두 개의 레이어로 구성되어있다.

Embedding Layer

\(T\)개의 단어로 구성된 인스턴스 \(x=\left\{w_{t},w_{2},...,w_{T}\right\}\)가 주어질 때 임베딩 행렬 \(W_{E}\)로 \(\boldsymbol{w}_{t}=W_{E}w_{t}\) 식을 통해 각 단어를 \(d\) 차원 벡터로 임베딩한다.

\[\left\{\boldsymbol{w}_{1}, \boldsymbol{w}_{2},...,\boldsymbol{w}_{T}\right\}, \boldsymbol{w}_{t} \in \mathbb{R}^{d}\]

Encoding Layer

encoding layer로 CNN을 쓰는데, 모델의 경량화와 속도의 측면을 고려해서 1개층의 CNN을 사용한다. CNN은 컨볼루션 커널로 각 단어의 숨은 정보(hidden annotations)를 얻는다. 윈도우 사이즈를 \(m\)이라 할 때 식은 다음과 같다.

\[\boldsymbol{h}_{t} = \textrm{CNN}(\boldsymbol{w}_{t-\frac{m-1}{2}},...,\boldsymbol{w}_{t-\frac{m+1}{2}})\]

이 식에서 \(\boldsymbol{w}_{t}\)가 position embedding을 갖고 있다면 \(\boldsymbol{w}_{t}\)와 \(\boldsymbol{p}_{t}\)을 concat한다.

※ position embedding : positional encoding은 transformer에서 어순 정보를 나타낸다고 한다. 여기서도 비슷한 맥락으로 사용되지 않았을까 생각한다. 참고자료

\[\boldsymbol{wp}_{t}=\left [\boldsymbol{w}_{t} \oplus \boldsymbol{p}_{t}\right ]\]

\(\oplus \)를 concatation이라 할 때 \(\boldsymbol{h}_{t}\)은 다음과 같이 다시 쓰일 수 있다.

\[\boldsymbol{h}_{t} =\textrm{CNN}(\boldsymbol{wp}_{t-\frac{m-1}{2}},...,\boldsymbol{wp}_{t-\frac{m+1}{2}})\]

그 다음 모든 \(\boldsymbol{h}_{t}\) 값을 인스턴스 \(x\)의 전체 표현으로 만든다.

\[x=\left\{ \boldsymbol{h}_{1},\boldsymbol{h}_{2},...,\boldsymbol{h}_{t} \right\}\]

 

이 두 레이어를 하나의 포괄적인 함수로 정의한다.

\[\boldsymbol{x}=g_{\theta }(x)\]

이 함수에서 \(\theta \)는 학습될 네트워크의 파라미터를 의미한다.

 

Prototypical Networks

Prototypical network는 각 클래스의 표현으로 prototype 벡터를 계산한다. 이 prototype은 각 클래스의 embedded support 인스턴스의 평균값이고 prototype 벡터와 query 벡터의 거리(squared Euclidean distance)를 비교해서 가장 가까운 prototype의 클래스가 query의 클래스가 된다.

 

클래스 \(l_{i}\)에서의 prototype 벡터 \(c_{i}\)는 다음과 같이 서포트 셋의 임베딩된 인스턴스를 모두 평균하여 구한다.

\[\boldsymbol{c}_{i}=\frac{1}{n_{i}}\sum_{j=1}^{n_{i}}g_{\theta }(x_{j}^{i})\]

\(x_{i}^{j}\)는 \(i\)번째 클래스의 \(j\)번째 문장을 의미한다.

 

본 논문에선 벡터의 거리를 구할 때 기존 squared Euclidean distance 대신 squared Euclidean distance와 class feature score를 결합하여 더 좋은 성능을 이끌었다.

 

Hierarchical Attention

희귀한 데이터에서 중요한 정보를 더 얻기 위해서 hierachical attention 메커니즘을 이용한다. few-shot text classification을 위해 여기서 모델은 feature score 벡터를 얻고 각 클래스의 support set을 벡터 표현으로 변환한다.

 

Feature Level Attention

Feature level attention은 각 클래스에서의 서로 다른 피처의 중요도를 늘리거나 줄인다.

feature level 공간에선 어떤 특징 차원이 특정 클래스에서 더욱 분별있는 동시에 다른 피처는 헷갈리고 쓸모없을 수 있다. 따라서 Hybrid attention-based prototypical networks for noisy few-shot relation classification에서 제안된 것과 유사한 CNN 기반의 feature attention 메커니즘을 사용한다. 이는 각 클래스에서 support set의 모든 인스턴스에 의존하고 클래스에 따라 다이나믹하게 달라진다.

 

CNN 기반 feature attention 메커니즘을 수식으로 설명한다. \(n_{i}\)는 클래스의 샘플 개수, \(T\)는 문장에서의 단어 개수일 때(\(d\)는 임베딩 차원인 것 같음)는 , 클래스 \(l_{i}\)의 support set \(S_{i} \in \mathbb{R}^{n_{i}\times T\times d}을 앞선 instance encoder 파트의 출력이라 할 때 다음과 같이 표현한다.

\[\boldsymbol{S}_{i}=\left\{\boldsymbol{x}^{1}, \boldsymbol{x}^{2},...,\boldsymbol{x}^{n_{i}}\right\}\]

CNN의 아키텍처는 위 표와 같으며 각 층의 수식은 다음과 같이 표현한다.

\(\boldsymbol{S}_{i}\)에서 각 인스턴스에 max pooling 레이터를 적용한 새 피쳐맵을 \(\boldsymbol{S}_{\boldsymbol{c}_{i}} \in \mathbb{R}^{n_{i} \times d}\), 그 다음 3개층의 합성곱 레이어로 클래스 \(l_{i}\)의 score 벡터인 \(\boldsymbol{\lambda}_{i} \in \mathbb{R}^{d}\)이다.

 

그 다음 새로운 거리 함수는 다음과 같다. 앞서 언급한 대로 squared euclidian distance에 class feature score를 결합한 형태이다.

\[d(\boldsymbol{c}_{i},\boldsymbol{q}')=(\boldsymbol{c}_{i},\boldsymbol{q}')^{2}\cdot \boldsymbol{\lambda }_{i}\]

여기서 \(\boldsymbol{q}'\)은 다음 섹션에서 소개할 word level attention 메커니즘을 통과한 query 벡터이다.

 

Word Level Attention

인스턴스마다 서로 다른 단어의 중요도가 다르기 때문에 Word level attention은 어떤 단어가 쓸모있는지 나타낸다. 중요한 단어를 얻고 이를 취합해 더 유용한 정보로 만들기 위해 attention 메커니즘을 적용한다. 이 식은 다음과 같다.

\[\boldsymbol{u}_{t}^{j}=\textrm{tanh}(\boldsymbol{W}_{w}\boldsymbol{h}_{t}^{j}+\boldsymbol{b}_{w})\]

\[v_{t}^{j}=\boldsymbol{u}_{t}^{j\textrm{T}}\boldsymbol{u}_{w}\]

\[\alpha_{t}^{j}=\frac{\sum_{t}\textrm{exp}(v_{t}^{j})}{\textrm{exp}(v_{t}^{j})}\]

\[\boldsymbol{s}^{j}=\sum_{t}\alpha_{t}^{j}\boldsymbol{h}_{t}^{j}\]

여기서 \(\boldsymbol{h}_{t}^{j}\)는 인스턴스 \(\boldsymbol{x}^{j}\)의 \(t\)번째 은닉 단어 임베딩을 의미한다. 이는 instance encoder를 통해 인코딩 된다.그리고 은닉 사이즈는 \(\boldsymbol{x}^{j}\)와 동일하다.

 

우선 \(\boldsymbol{W}_{w}\)와 \(\boldsymbol{b}_{w}\)는 MLP를 의미하며 tanh로 \(\boldsymbol{h}_{t}^{j}\)를 \(\boldsymbol{u}_{t}^{j}\)로 변환한다.

그 다음 이 \(\boldsymbol{u}_{t}^{j}\)와 word level의 가중치 벡터 \(\boldsymbol{u}_{w}\) 사이에 dot product를 한 후(\(v_{t}^{j}\)) softmax 함수로 정규화한다.(\(\alpha_{t}^{j}\))

마지막으로 (\(\alpha_{t}^{j}\))와 \(\boldsymbol{h}_{t}^{j}\)의 가중합으로 instance level 벡터 \(\boldsymbol{s}^{j}\)를 계산한다.

word level의 가중치 벡터 \(\boldsymbol{u}_{w}\)는 인스턴스마다 중요한 단어를 선택할 수 있도록 돕는다. 이는 초기엔 무작위로 초기화되고 학습되는 동안 파라미터 \(\theta\)와 함께 최적화된다.

 

Instance Level Multi Cross Attention

Instance level multi cross attention은 서로 다른 query 인스턴스에서 중요한 support 인스턴스를 뽑아낼 수 있다.

target query 인스턴스가 들어올 때 모든 support set의 인스턴스가 동일하게 클래스 prototype에 기여하는 건 아니기 때문에 query 인스턴스를 올바르게 분류할 단서를 제공하는 support 인스턴스의 중요도를 강조하기 위해 Instance level multi cross attention 메커니즘을 사용한다.

 

클래스 \(l_{i}\)에서 support set \(\boldsymbol{S}_{i}' \in \mathbb{R}^{n_{i} \times d}\)과 query 벡터 \(\boldsymbol{q}' \in \mathbb{R}^{d}\)이 주어지고 이들은 각각 instance encoder와 word level attention을 통과했을 때, \(\boldsymbol{S}_{i}'\) 안에 각 support 벡터 \(\boldsymbol{s}_{i}^{j}\)가 query \(\boldsymbol{q}'\)에 대해 가중치 \(\beta_{i}^{j}\)를 가진다. 이를 수식화하면 다음과 같다.

\[\boldsymbol{c}_{i}=\sum_{j=1}^{n_{i}}\beta_{i}^{j}\boldsymbol{s}_{i}^{j}\]

\(\boldsymbol{r}_{i}^{j}=\beta_{i}^{j}\boldsymbol{s}_{i}^{j}\)를 가중 prototype 벡터라 정의하고 \(\beta_{i}^{j}\)은 다음과 같이 정의된다.

\[\beta_{i}^{j}=\frac{\sum_{j=1}^{n_{i}}\textrm{exp}(\gamma_{i}^{j})}{\textrm{exp}(\gamma_{i}^{j})}\]

\[\gamma_{i}^{j}=\textrm{sum}\left\{\sigma (f_{\varphi }(mca))\right\}\]

\[mca=\left [ \boldsymbol{s}_{i\phi }^{j}\oplus \boldsymbol{q}_{phi}'\oplus \boldsymbol{\tau }_{1}\oplus \boldsymbol{\tau }_{2} \right ]\]

\[\boldsymbol{\tau}_{1} = \left| \boldsymbol{s}_{i\phi}^{j}-\boldsymbol{q}_{\phi}' \right|, \boldsymbol{\tau}_{2}=\boldsymbol{s}_{i\phi}^{j}\odot \boldsymbol{q}_{\phi}'\]

\[\boldsymbol{s}_{i\phi}^{j}=f_{\phi}(\boldsymbol{s}_{i}^{j}), \boldsymbol{q}_{\phi}'=f_{\phi}(\boldsymbol{q}')\]

여기서 \(f_{\phi}\)는 선형 레이어, \(\left|\cdot \right|\)는 element-wise 절댓값, \(\odot\)은 element-wise product이다. 이 두 element-wise 연산자는 \(\boldsymbol{s}_{i}^{j}\)와 \(\boldsymbol{q}_{\phi}'\) 사이에서 서로 다른 정보 \(\boldsymbol{\tau}_{1}\)와 \(\boldsymbol{\tau}_{2}\)를 얻기 위해 쓰인다. 그 다음 모든 값을 concat하여 multi cross attention information인 \(mca\)를 구성한다. \(f_{\varphi}(\cdot )\)는 선형 함수, \(\sigma(\cdot )\)는 tanh 활성함수, sum{\(\cdot\)}은 벡터의 모든 요소 값의 합을 의미한다. \(\gamma_{i}^{j}\)는 support set \(\boldsymbol{s}_{i}\) 안에 \(j\) 인스턴스의 가중치이고, \(\beta_{i}^{j}\)에 softmax 함수로 정규화한다.

 

이로써 multi scross attention mechanism의 protype은 query와 관련된 support 인스턴스에 더 집중하고 support set의 이해력(capacity)을 개선할 수 있다.

 

실험

댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/12   »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30 31
글 보관함