티스토리 뷰
[논문 리뷰]Few-shot Text Classification with Distributional Signatures
ProWiseman 2023. 5. 12. 14:50들어가며
본 글은 논문 Few-shot Text Classification with Distributional Signatures을 리뷰한 글입니다.
텍스트 도메인에서 메타러닝을 적용하기 어려운 이유
컴퓨터 비전에서는 에지와 같은 저수준의 패턴이나 그에 부합하는 표현들이 task끼리 서로 공유될 수 있지만 언어 데이터에선 대부분의 task들이 어휘(lexical) 수준에서 이루어진다. 단어 어느 task에선 유용하지만 어느 task에선 유용하지 않는다.
※ 여기서 lexicon은 단어의 발음, 의미, 다른 단어와의 관계 등 단어에 관한 정보가 저장되어 있는 곳을 의미하는 것 같다. 참고자료
논문에서의 키 아이디어
단어를 바로 고려하지 않고 단어 분포(i.e. tf-idf weighting)에 있는 특성인 distributional signature를 활용한다.
(tf-idf weighting은 단어의 중요도를 문서 집합의 빈도(frequency)와 특정 문서에 대한 치우침(skewness)을 통해 명시적으로 지정한다.)
Distributional Signature의 효용
본 논문에선 distributional signature를 활용하기 위해 단어의 빈도와 특정 클래스에 대한 단어의 중요도를 이용한다. 특정 클래스에 대한 단어의 중요도는 few-shot learning 이기 때문에 라벨링된 데이터가 부족해 신용할만한 추정치를 얻지는 못하지만 noisy한 추정치는 얻을 수 있고 이는 메타러닝 프레임워크로 정제할 수 있다.
이 부분에서 느낄 수 있듯 distibutional signature는 문장에서 lexical counterparts보다 표현력이 약하지만 메타 지식(meta knowledge)를 만들 때 일반화에 도움이 된다.
※ lexical counterparts란 nouns with a lexical structure and meaning similar to corresponding phrasal verbs라고 한다. 참고자료
기존 메타러닝 프레임워크에서의 확장
본 논문에서는 few-shot learning에서의 기존 표준 메타러닝 프레임워크에서 더 나아간 확장 버전을 도입하였다.
메터 러닝은 클래스 집합 \(y^{train}\)에서 라벨링 된 예제가 주어질 때 이 훈련 데이터로부터 정보를 얻도록 하는 모델을 개발하여 적은 어노테이션만으로 새로운 클래스 집합 \(y^{test}\)에서 예측하는 것이다.
기존의 메타 훈련과 메타 테스트
메타 훈련은 \(y^{train}\)에서 매 에피소드마다 \(N\)개의 클래스와 그 안에서 \(K\)개의 예제를 훈련 데이터로(support set), {L}개의 데이터를 테스트용으로(query set) 샘플링한다. 모델은 테스트 데이터의 loss를 기반으로 업데이트된다.(그림 4a 참고) 여셔기서 support set이 주어졌을 때 query set에 \(\textit{N-way K-shot classification}\)으로 예측한다고 말한다.
메타 테스트는 메타 훈련이 끝난 뒤에 우리 모델이 정말로 새로운 클래스에 빠르게 적을할 수 있는지 테스트하기 위해 동일한 에피소드 기반 메커니즘을 적용한다.\(Y^{test}\)에서 {N}개의 새로운 클래스를 샘플하고 거기서 각 클래스별로 support set과 query set을 샘플링한다. 그리고 모든 테스트 에피소드의 query set의 성능을 평균내 평가한다.
확장
기존의 메터 훈련에서 모든 예제에 접근할 수 있긴 해도 에피소드별로 작은 서브셋으로만 학습하기 때문에 distributional statistics를 충분히 이용할 수 없다. 따라서 더 강건한(robust) 추론을 위해 모든 훈련 예제에서 distributional statistics를 활용한다. 이를 위해 각 에피소드에서 source pool로 데이터를 증강한다. 메타 훈련동안 source pool은 에피소드로 선택된 클래스를 제외한 모든 학습 클래스의 예제를 포함한다. 메타 테스트동안에는 모든 훈련 예제를 포함한다.
이에 대한 개념은 아래에 그림으로도 나와있다.
논문에서 제시된 방법
논문의 목표는 입력의 distributional signature로 고품질의 attention을 배워 few-shot classification의 성능을 개선하는 것이다.
이를 위해 우선 특정 에피소드가 주어졌을 때, source pool과 support set에서 relevant statistics를 추출한다. 이런 통계값은 분류를 위한 단어 중요도를 rough하게 근사하기 때문에 attention generator라는 모듈을 이용해 relevant statistics를 고품질의 attention으로 만들어준다. 이 attention은 ridge regressor라는 downstream을 위한 예측 모듈에 가이던스를 제공한다.
전제적인 모델은 앞서 소개한 다음 두 모듈로 구성되어 있다.
attention generator : source pool과 support set의 distributional signature를 클래스별 단어 중요도를 반영하는 attention score로 변환하는 역할을 한다.
ridge regressor : attention generator가 제공한 attnetion으로 어휘 표현을 구성하고 downstream을 통해 약간의 학습 예제만으로 예측하는 법을 배운다. 제공된 attention은 단어 중요도의 inductive bias이다.
여기서 알고 넘어갈 점은 attention generator는 모든 학습 에피소드에 걸쳐 최적화를 하는 반면 ridge regressor는 각 에피소드마다 새로 학습한다.
Attention Generator
attention generator의 목표는 각 입력 예제의 distributional signature으로부터 단어 중요도를 얻는 것이다.
distributional signature를 얻기 위해 unigram statistics를 사용했다. 이를 선택한 이유는 명사의 섭동(단어가 미세하게 달라지는 것으로 이해했음)에 대해 강건할 것으로 생각되었기 때문이다. 또한 거대한 source pool로 모델에 general한 단어 중요도를 알리고 클래스별 단어 중요도를 추정하기 위해 작은 support set을 활용하였다.
자주 등장하는 단어는 덜 중요할 수도 있기 때문에 자주 등장하는 단어는 downweigh하고 드물게 등장하는 단어는 upweight한다. general한 단어 중요도를 구하는 방법은 다음과 같다.
\[s(x_{i}):=\frac{\epsilon}{\epsilon + \textbf{P}(x_{i})}\]
\(\epsilon = 10^{-3}\), \(x_{i}\)는 입력 예제 \(x\)의 \(i\) 번째 단어, \(\textbf{P}(x_{i})\)는 source pool에서 \(x_{i}\)의 unigram likelihood이다.
반면 support set에서 분별있는 단어는 query set에서도 분별있을 것이다. 따라서 다음의 통계식으로 클래스별 단어 중요도를 반영하도록 정의하였다.
\[t(x_{i}):=\mathfrak{H}(\textbf{P}(y|x_{i}))^{-1}\]
조건부 likelihood \(\textbf{P}(y|x_{i})\)는 support set에서 regularized linear classifier로 추정하였다. \(\mathfrak{H}\)는 엔트로피 연산자이다.
\(t(\cdot )\)은 주어진 단어에서 클래스 라벨이 \(y\)일 때 불확실성을 측정한다. 따라서 치우쳐진 분포를 가진 분포는 높은 가중치를 갖게 될 것이다.
이런 통계를 바로 적용하면 다음 두 가지 이유로 결과가 좋지 못할 것이다.
- 앞서 구한 두 통계는 상호 보완적인 정보를 지니지만 통합하는 방법이 불분명하다.
- 이러한 통계는 분류에서의 단어 중요도엔 다소 noisy하다.
이러한 간극을 잇기 위해 biLSTM으로 두 정보를 합친다. 그리고 마지막으로 단어 \(x_{i}\)의 attention score \(\alpha_{i}\)를 예측하기 위해 dot-product attention을 이용한다.
\[\alpha_{i} := \frac{exp(v^{T}h_{i})}{\sum_{j}exp(v^{T}h_{j})}\]
\(h_{i}\)는 \(i\) 번째 위치에서의 biLSTM의 출력값이고 \(v\)는 학습 가능한 벡터이다.
Ridge Regressor
ridge regressor의 목표는 attention generator가 제공한 정보로 몇 가지 예제를 보고 빠르게 예측하는 방법을 학습하는 것이다.
우선 앞서 생성된 attention score로 단어 중요도에 집중한 어휘 표현을 구성한다. 이 어휘 표현으로 support set에서 ridge regressor를 처음부터(from scratch) 학습시킨다. 마지막으로 query set에 대한 예측을 만들고 여기서 발생한 loss로 attention generator를 학습시킨다.
Constructing representations
서로 다른 단어가 분류에서 중요도에 다양한 수준을 보일 때 단련된 단어에 호의를 보이는 어휘 표현을 생성한다. (we construct lexical representations that favor pertinent words 원문을 잘 이해하지 못하겠음) 예제 \(x\)에서 표현은 다음과 같이 정의한다.
\[\phi(x) := \sum_{i}\alpha_{i} \cdot f_{ebd}(x_{i})\]
여기서 \(f_{ebd}(\cdot )\)는 사전 학습된 임베딩 함수로 단어를 \(\mathbb{R}^{E}\)로 매핑한다.
Training from the support set
N-way K-shot 분류 task에서 \(\Phi_{S} \in \mathbb{R}^{NK \times E}\)를 \(\phi (\cdot )\)으로부터 얻은 support set의 표현이라 하고 \(Y_{S} \in \mathbb{R}^{NK \times N}\)을 원-핫 레이블이라 한다.
우리는 다음의 이유로 ridge regression을 라벨링된 support set에 학습에 이용한다.
- ridge regression은 closed-form 솔루션으로 모델 전반에 end-to-end 미분이 가능하다.
- 적절한 규제를 함께 쓰면 ridge regression은 작은 support set에서도 오버피팅을 줄일 수 있다.
논문에서는 regularized squared loss를 사용하였다.
\[\mathfrak{L}^{RR}(W):=\left\|\Phi_{S}W-Y_{S} \right\|^{2}_{F}+\lambda \left\|W \right\|^{2}_{F}\]
가중치 행렬 \(W \in \mathbb{R}^{E \times N}\)에서 \(\left\|\cdot \right\|_{F}\)는 Frobenius norm을 의미한다. \(\lambda > 0\)은 \(W\)의 조건을 제어한다. closed-form 솔루션은 다음의 식으로 얻어질 수 있다.
\[W=\Phi^{T}_{S}(\Phi_{S}\Phi^{T}_{S}+\lambda I)^{-1}Y_{S}\]
\(I\)는 단위행렬이다.
Inference on the query set
\(\Phi_{Q}\)는 query set의 표현을 나타낸다고 한다.위 regularized squared loss를 회귀 목적함수로 사용하여 최적화를 해도 학습된 변환(\(W\))는 calibration step 이후에 잘 작동했다.
\[\widehat{Y}_{Q}=a\Phi_{Q}W+b\]
a와 b는 메타 파라미터로 메타 훈련을 통해 학습된다. 마지막으로 \(\widehat{Y}_{Q}\)에 소프트맥스를 적용하여 예측 확률 \(\widehat{P}_{Q}\)를 얻는다. 여기서 calibration은 소프트맥스의 스케일과 temperature만 조절한다.
\(\Phi_{S}\)와 \(\Phi_{Q}\)는 모두 \(\phi (\cdot )\)에 의존하기 때문에 크로스 엔트로피를 \(\widehat{P}_{Q}\)와 query set의 레이블에서만 계산하여 attention generator에 supervision을 제공한다.
실험
아래 표에서 Rep.은 Representations의 약자로 세 가지로 평가했다.
AVG는 각 예제의 표현을 임베딩의 평균값으로 한 것이다.
IDF는 각 예제의 표현을 단어 임베딩의 학습 셋의 문서 빈도의 역순의 가중 평균으로 한 것이다.
CNN은 입력입력 단어에 1D convolution을 적용하고 max-over-time pooling으로 한 것이다.
Alg.는 Algorithm의 약자이다.
RR은 Ridge Regressor이다.
NN은 유클리드 거리에서의 1-nearest-neighbor 분류기이다.
FT는 모든 학습 예제를 사전학습 한 후 support set으로 네트워크를 파인튜닝 한 것이다.
MAML과 PROTO(Prototypical network)는 다른 메타러닝 기법이다.
s()와 t()는 각각 general한 단어 중요도와 클래스별 단어 중요도이다.
이 외의 분석 결과들