티스토리 뷰
들어가며
Few-shot learning에 대표적인 matching network, prototypical, Relation network은 모두 이미지 도메인에서 작성된 논문이다. 추후 산학협력 프로젝트로 진행하는 추천시스템 프로젝트에선 텍스트 데이터를 사용하므로 이에 적용하기 위해서 텍스트 분야에 적용된 17년과 23년 사이에 작성된 Metric 기반 논문 혹은 코드를 찾고자 한다.
Relation Network의 Embedding Network에 BERT를 적용한 접근
해당 접근은 이 블로그에 정리돼 있으며 코드는 작성자의 깃허브 레포에서 확인할 수 있다.
해당 접근에선 워드 임베딩으로 당시 SOTA 모델인 BERT를 사용하였다. 해당 블로그에선 BERT를 few-shot setting과 dataset에 파인튜닝하여 사용했다.
BERT엔 다음의 두 가지 주요 문제점이 존재한다.
- BERT의 높은 연산 복잡도
- 코어 레벨에서 단어 표현을 수정하는 능력
이 문제를 피하기 위해서 필요에 따라 단어 임베딩을 수정할 레이어를 최상위 레이어에 쌓는다. 해당 블로그에선 BiGRU를 사용했다. (성능, 연산량 측면 고려)
실험은 캐글에 있는 news category dataset을 사용했다. 카테고리의 50%는 학습, 20%는 검증, 30%는 테스트에 사용했고 5 way 2 shot으로 실험했다.
GRU를 이용한 Embedding Network는 다음과 같다.
class TextGRUEncoder(nn.Module):
def __init__(self, bert_path):
super(TextGRUEncoder, self).__init__()
self.bert_path = bert_path
self.bert = transformers.BertModel.from_pretrained('bert-base-uncased').eval()
self.GRU = nn.GRU(768, 1024, batch_first=True, bidirectional=True)
self.fc1 = nn.Linear(1024*2, 1024)
self.fc2 = nn.Linear(1024, 768)
def forward(self, ids, mask, token_type_ids):
with torch.no_grad():
o1, o2 = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
out, _ = self.GRU(o1)
h = F.relu(self.fc1(torch.sum(out, 1).squeeze(1)))
h = self.fc2(h)
return h
Relation Network론 별 다른 설명은 없었으나 네트워크론 다음과 같은 네트워크를 사용했다.
class TextLSTMRelationNetwork(nn.Module):
def __init__(self):
super(TextLSTMRelationNetwork, self).__init__()
self.fc1 = nn.Linear(768*2, 768)
self.fc2 = nn.Linear(768, 128)
self.fc3 = nn.Linear(128, 1)
def forward(self, x):
h = F.relu(self.fc1(x))
h = F.relu(self.fc2(h))
h = torch.sigmoid(self.fc3(h))
return h
Knowledge guided metric learning
요약하자면 기존 relation network에 사전 지식을 활용하도록 만든 네트워크이다.
Knowledge Guided Metric Learning for Few-Shot Text Classification에선 외부 지식을 few-shot learning에 도입하여 사람을 모방한다. 외부 지식으로 서로 다른 metric을 만드는 parameter generator network를 사용하며 비슷한 task는 비슷한 metric을 쓰고 다른 task는 다른 metric을 사용한다.
Metric-based approach는 샘플을 적절한 측징 공간에 표현하는 것을 배우고 거리 metric을 사용해 레이블을 예측한다.
단어가 특정 task에선 매우 유용하지만 다른 task에선 아닐 수도 있다는 점 때문에 Metric-based approach를 곧바로 text classification에 적용하기는 힘들다. 단일 metric이 few-shot text classification에서 모든 task를 커버하기엔 불충분하기에 knowledge guided metric learning을 사용한다.
이 방법은 knowledge base(KB)의 외부 지식을 사용한다. symbolic fact(토큰의 의미를 하드코딩한 것을 의미하는 것 같음)는 일반화가 좋지 못한 점과 데이터 희소성의 문제로 KB의 distributed representation을 사용한다.(참고) 이 KB에 기반하여 parameter generator networkfh task와 관련된 relation network의 파라미터를 생성한다. 이 파라미터로 task-relevat relation network가 다양한 task에 대한 다양한 metric을 적용하고 비슷한 task는 비슷한 metric을 사용한다는 것을 보장한다.
Encoder
인코더론 pre-trained BERT를 사용한다.
메타러닝에선 각 클래스의 표현은 각 클래스에 속한 임베드된 문장의 평균 벡터이다. 따라서 원 relation network와 마찬가지로 각 서포트 셋의 평균 벡터와 쿼리 표현을 concatenate한다.
아키텍처의 이 부분을 의미한다.
Knowledge Guided Relation Network
이 모듈은 위 과정에서 서포트 표현과 쿼리 표현을 합친 표현과 서포트셋의 지식을 입력으로 받고 0~1 사이의 스칼라 값을 생성한다. 이는 쿼리 문장과 클래스 표현의 유사도를 나타내며 relation score라 불린다.
아키텍처의 이 부분을 의미한다.
기존 relation network와 다른 점은 relation network를 task-agnostic relation network와 task-relevant relation network로 나눈다. task agnostic relation network는 기존의 metric 함수이고 task relevant relation network는 여러 task에 적응한다.
Task-Agnostic Relation Network
이 네트워크는 원래 relation network처럼 모든 task에 통합된 metric을 배운다.
Task-Relevant Relation Network
이 네트워크는 외부 지식으로 다양한 task에 다양한 metric을 적용한다.
상세히 살펴보면 이 네트워크는 각 서포트 셋에서 잠재적으로 KB 콘셉과 연관된 집합을 검색한다. 여기서 콘셉은 후술할 개념으로 KB 임베딩과 연관된 콘셉이다.
KB 임베딩을 요소별로 평균을 내서 서포트 셋의 지식 표현을 형성한다. 이 지식 표현으로 task-relevant relation network의 파라미터를 생성한다. 이렇게 생성된 파라미터로 서포트셋의 클래스와 쿼리 입력에 대한 task-relevant score를 생성한다.
이렇게 두 개의 네트워크로 생성된 두 개의 relation score는 다음의 식으로 최종 relation socre로 변환한다.
이 relation score를 두개의 fc 레이어와 MSE로 학습한다.
Knowledge Embedding and Retrieval
KB로 NELL(Never-Ending Language Learning)을 사용했으며 주어와 목적어 사이의 특정한 관계를 나타내는 (subject, relation, object)의 형태로 저장된다. e.g., (Intel, competes with, Nvidia)
Knowledge Embedding
symbolic fact는 일반화 성능이 좋지 않고 데이터 희소성 문제도 있기 때문에 distributed representation을 사용한다. 여기선 위에서 말했듯 (s, r, o)를 사용하고 각 요소는 벡터로 임베딩돼 있다. 이 벡터는 실제 세계에서의 유효성이 학습돼고 이를 위해 BILINEAR 모델을 사용한다.
이 벡터의 임베딩을 학습할 때 margin-based ranking loss를 사용해 KB 안에 있는 s, r, o 세 요소를 긍정, 부정으로 구성한다.
Knowledge Retrieval
exact string matching으로 주어진 구절에서 엔티티 언급을 인식하고 인식된 엔티티 언급을 KB의 주제(s)에 연결하는 데 사용된다. 이렇게 해서 이에 부합한 목적어(concept) 후보를 모은다. 이 과정을 통해 가능한 KB 임베딩과 관련된 relevant KB concept의 집합을 얻을 수 있다.
해당 논문의 아키텍처에서 BERT 인코더를 위에서 나온대 파인튜닝한 형식으로 바꾸면 성능이 더 높아질 수도 있을 것 같다.
Hierarchical Attention Prototypical Networks for Few-Shot Text Classification
https://aclanthology.org/D19-1045/
기존 prototypical network는 노이즈에 대해 고려를 하지 않기 때문에 이 논문은 세 수준에서의 attention 메커니즘을 이용하여 semantic 공간에서의 표현력을 강화한다.
Feature level : CNN으로 다양한 클래스에 대한 feature 스코어를 얻는다.
Word level : attention 메커니즘으로 각 단어의 instance 내의 은닉층의 중요도를 학습한다.
Instance level : 서포트 셋과 목표 쿼리 사이의 다중 교차 attention으로 같은 클래스에서의 서로 다른 인스턴스의 중요도를 결정짓고 모델이 각 클래스에 대해 더욱 분별있는 프로토타입을 얻을 수 있도록 한다.
Knowledge-Aware Meta-learning for Low-Resource Text Classification
논문 : https://aclanthology.org/2021.emnlp-main.136.pdf
소스코드 : https://github.com/huaxiuyao/KGML
외부 지식 베이스(KB)로 삭습과 훈련 task 간의 간극을 이어주고 적은 자원으로 text classification에서 메타 러닝을 더욱 효율적으로 할 수 있도록 한다.
해당 논문의 키 아이디어는 아래 이미지처럼 모든 task에서 공유되는 KB(knowledge base)에서 추출된 엔티티의 문장별 지식 그래프(KG)를 구성하고 포함하여 각 문장에 대한 추가 표현을 GNN으로 연산하는 것이다.
KGML이란 네트워크로 문장에 특화된 지식 그래프(KG)의 표현을 추출하고 지식을 융합하는 프레임워크이다.
Learning to Few-Shot Learn Across Diverse Natural Language Classification Tasks
https://arxiv.org/abs/1911.03863
STraTA: Self-Training with Task Augmentation for Better Few-shot Learning
https://arxiv.org/abs/2109.06270
https://github.com/google-research/google-research/tree/master/STraTA
'공부한 내용 정리 > 인공지능' 카테고리의 다른 글
[논문 리뷰]Hierarchical Attention Prototypical Networks for Few-Shot Text Classification (0) | 2023.05.19 |
---|---|
추천 시스템 (0) | 2023.05.15 |
[RecSys] Modern Recommendation Systems with Neural Networks (0) | 2023.04.12 |
[논문 리뷰]Few-shot learning for short text classification (0) | 2023.04.06 |
인공지능 기초 4 (1) | 2023.03.25 |