티스토리 뷰
[논문 리뷰]Few-shot Text Classification with Distributional Signatures
ProWiseman 2023. 5. 12. 14:50들어가며
본 글은 논문 Few-shot Text Classification with Distributional Signatures을 리뷰한 글입니다.
텍스트 도메인에서 메타러닝을 적용하기 어려운 이유
컴퓨터 비전에서는 에지와 같은 저수준의 패턴이나 그에 부합하는 표현들이 task끼리 서로 공유될 수 있지만 언어 데이터에선 대부분의 task들이 어휘(lexical) 수준에서 이루어진다. 단어 어느 task에선 유용하지만 어느 task에선 유용하지 않는다.
※ 여기서 lexicon은 단어의 발음, 의미, 다른 단어와의 관계 등 단어에 관한 정보가 저장되어 있는 곳을 의미하는 것 같다. 참고자료
논문에서의 키 아이디어
단어를 바로 고려하지 않고 단어 분포(i.e. tf-idf weighting)에 있는 특성인 distributional signature를 활용한다.
(tf-idf weighting은 단어의 중요도를 문서 집합의 빈도(frequency)와 특정 문서에 대한 치우침(skewness)을 통해 명시적으로 지정한다.)
Distributional Signature의 효용
본 논문에선 distributional signature를 활용하기 위해 단어의 빈도와 특정 클래스에 대한 단어의 중요도를 이용한다. 특정 클래스에 대한 단어의 중요도는 few-shot learning 이기 때문에 라벨링된 데이터가 부족해 신용할만한 추정치를 얻지는 못하지만 noisy한 추정치는 얻을 수 있고 이는 메타러닝 프레임워크로 정제할 수 있다.
이 부분에서 느낄 수 있듯 distibutional signature는 문장에서 lexical counterparts보다 표현력이 약하지만 메타 지식(meta knowledge)를 만들 때 일반화에 도움이 된다.
※ lexical counterparts란 nouns with a lexical structure and meaning similar to corresponding phrasal verbs라고 한다. 참고자료
기존 메타러닝 프레임워크에서의 확장
본 논문에서는 few-shot learning에서의 기존 표준 메타러닝 프레임워크에서 더 나아간 확장 버전을 도입하였다.
메터 러닝은 클래스 집합
기존의 메타 훈련과 메타 테스트
메타 훈련은
메타 테스트는 메타 훈련이 끝난 뒤에 우리 모델이 정말로 새로운 클래스에 빠르게 적을할 수 있는지 테스트하기 위해 동일한 에피소드 기반 메커니즘을 적용한다.
확장
기존의 메터 훈련에서 모든 예제에 접근할 수 있긴 해도 에피소드별로 작은 서브셋으로만 학습하기 때문에 distributional statistics를 충분히 이용할 수 없다. 따라서 더 강건한(robust) 추론을 위해 모든 훈련 예제에서 distributional statistics를 활용한다. 이를 위해 각 에피소드에서 source pool로 데이터를 증강한다. 메타 훈련동안 source pool은 에피소드로 선택된 클래스를 제외한 모든 학습 클래스의 예제를 포함한다. 메타 테스트동안에는 모든 훈련 예제를 포함한다.
이에 대한 개념은 아래에 그림으로도 나와있다.


논문에서 제시된 방법
논문의 목표는 입력의 distributional signature로 고품질의 attention을 배워 few-shot classification의 성능을 개선하는 것이다.
이를 위해 우선 특정 에피소드가 주어졌을 때, source pool과 support set에서 relevant statistics를 추출한다. 이런 통계값은 분류를 위한 단어 중요도를 rough하게 근사하기 때문에 attention generator라는 모듈을 이용해 relevant statistics를 고품질의 attention으로 만들어준다. 이 attention은 ridge regressor라는 downstream을 위한 예측 모듈에 가이던스를 제공한다.
전제적인 모델은 앞서 소개한 다음 두 모듈로 구성되어 있다.
attention generator : source pool과 support set의 distributional signature를 클래스별 단어 중요도를 반영하는 attention score로 변환하는 역할을 한다.
ridge regressor : attention generator가 제공한 attnetion으로 어휘 표현을 구성하고 downstream을 통해 약간의 학습 예제만으로 예측하는 법을 배운다. 제공된 attention은 단어 중요도의 inductive bias이다.
여기서 알고 넘어갈 점은 attention generator는 모든 학습 에피소드에 걸쳐 최적화를 하는 반면 ridge regressor는 각 에피소드마다 새로 학습한다.
Attention Generator
attention generator의 목표는 각 입력 예제의 distributional signature으로부터 단어 중요도를 얻는 것이다.
distributional signature를 얻기 위해 unigram statistics를 사용했다. 이를 선택한 이유는 명사의 섭동(단어가 미세하게 달라지는 것으로 이해했음)에 대해 강건할 것으로 생각되었기 때문이다. 또한 거대한 source pool로 모델에 general한 단어 중요도를 알리고 클래스별 단어 중요도를 추정하기 위해 작은 support set을 활용하였다.
자주 등장하는 단어는 덜 중요할 수도 있기 때문에 자주 등장하는 단어는 downweigh하고 드물게 등장하는 단어는 upweight한다. general한 단어 중요도를 구하는 방법은 다음과 같다.
반면 support set에서 분별있는 단어는 query set에서도 분별있을 것이다. 따라서 다음의 통계식으로 클래스별 단어 중요도를 반영하도록 정의하였다.
조건부 likelihood
이런 통계를 바로 적용하면 다음 두 가지 이유로 결과가 좋지 못할 것이다.
- 앞서 구한 두 통계는 상호 보완적인 정보를 지니지만 통합하는 방법이 불분명하다.
- 이러한 통계는 분류에서의 단어 중요도엔 다소 noisy하다.
이러한 간극을 잇기 위해 biLSTM으로 두 정보를 합친다. 그리고 마지막으로 단어
Ridge Regressor
ridge regressor의 목표는 attention generator가 제공한 정보로 몇 가지 예제를 보고 빠르게 예측하는 방법을 학습하는 것이다.
우선 앞서 생성된 attention score로 단어 중요도에 집중한 어휘 표현을 구성한다. 이 어휘 표현으로 support set에서 ridge regressor를 처음부터(from scratch) 학습시킨다. 마지막으로 query set에 대한 예측을 만들고 여기서 발생한 loss로 attention generator를 학습시킨다.
Constructing representations
서로 다른 단어가 분류에서 중요도에 다양한 수준을 보일 때 단련된 단어에 호의를 보이는 어휘 표현을 생성한다. (we construct lexical representations that favor pertinent words 원문을 잘 이해하지 못하겠음) 예제
여기서
Training from the support set
N-way K-shot 분류 task에서
우리는 다음의 이유로 ridge regression을 라벨링된 support set에 학습에 이용한다.
- ridge regression은 closed-form 솔루션으로 모델 전반에 end-to-end 미분이 가능하다.
- 적절한 규제를 함께 쓰면 ridge regression은 작은 support set에서도 오버피팅을 줄일 수 있다.
논문에서는 regularized squared loss를 사용하였다.
가중치 행렬
Inference on the query set
a와 b는 메타 파라미터로 메타 훈련을 통해 학습된다. 마지막으로

실험
아래 표에서 Rep.은 Representations의 약자로 세 가지로 평가했다.
AVG는 각 예제의 표현을 임베딩의 평균값으로 한 것이다.
IDF는 각 예제의 표현을 단어 임베딩의 학습 셋의 문서 빈도의 역순의 가중 평균으로 한 것이다.
CNN은 입력입력 단어에 1D convolution을 적용하고 max-over-time pooling으로 한 것이다.
Alg.는 Algorithm의 약자이다.
RR은 Ridge Regressor이다.
NN은 유클리드 거리에서의 1-nearest-neighbor 분류기이다.
FT는 모든 학습 예제를 사전학습 한 후 support set으로 네트워크를 파인튜닝 한 것이다.
MAML과 PROTO(Prototypical network)는 다른 메타러닝 기법이다.
s()와 t()는 각각 general한 단어 중요도와 클래스별 단어 중요도이다.

이 외의 분석 결과들

