일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
- pytorch
- clean code
- ML
- cs
- Depth estimation
- CV
- algorithm
- Torch
- nerf
- Python
- Meta Learning
- classification
- 자료구조
- GAN
- 머신러닝
- computervision
- FineGrained
- dl
- SSL
- 딥러닝
- math
- Vision
- nlp
- Front
- PRML
- REACT
- web
- 3d
- FGVC
- 알고리즘
- Today
- Total
KalelPark's LAB
[논문 리뷰] Representation Learning with Contrastive Predictive Coding 본문
[논문 리뷰] Representation Learning with Contrastive Predictive Coding
kalelpark 2023. 3. 12. 10:43
GitHub를 참고하시면, CODE 및 다양한 논문 리뷰가 있습니다! 하단 링크를 참고하시기 바랍니다.
(+ Star 및 Follow는 사랑입니다..!)
https://github.com/kalelpark/Awesome-ComputerVision
Abstract
본 논문은 high-dimensional data를 추출하기 위한 Contrastive Predictive Coding을 소개합니다.
Contrastive Predictive Coding은 latent space를 예측하도록 Probabilistic contrastive loss를 사용하여 학습을 진행합니다.
논문에서는 Speech, image, text등 범용적으로 활용가능하다고 말합니다. ( Contrastive Learning의 기초 논문이다. )
Introduction
사람들의 뇌는 일반적으로, 관찰된 다양한 정보를 통하여 예측을 한다고 합니다. 단어의 경우, 주변 맥락을 통하여 어휘를 예측하거나, 이미지의 경우 주변 image들의 patch간의 관계를 통하여 색상을 예측한다고 합니다. 이러한 아이디어에서 영감을 받아, unsupervised learning시 layer를 더욱 쌓는다면, high-level information을 얻을 수 있다는 가정하에 실험을 진행하였다고 합니다.
Method
기존의 방법론들은 Mutual Information(MI)을 기반으로 정보를 추출하여 학습을 진행합니다. (자세한 내용은 생략)
Contrastive Learning은 기존 class를 예측하는 것에 초점을 두어, 다른 데이터와의 차이점을 학습하지 못한다는 단점이 있기에,
비교하는 대상과의 차이 또한 학습을 하고자, 논문에서는 CPC를 소개합니다.
InfoNCE는 context 즉, 맥락에서 negative sample과 positive sample 중 positive sample의 비율에 log를 씌워 loss를 측정합니다.
- 원하는 곳의 latent vector : poisitive sample. (택 1)
- 원하는 곳 이외의 latent vector : negative sample (특정 위치 제외 (개수 선정 가능))
수식으로 표현하면, 아래와 같다.
코드로 본다면,
import torch
import torch.nn.functional as F
def infonce_loss(embeddings, temperature = 0.07):
dot_product = torch.matmul(embeddings, embeddings.T())
batch_size = embeddings.shape[0]
num_negatives = batch_size -1
positive_mask = torch.eye(batch_size, dtype = torch.bool)
positive_logits = dot_product[positive_mask].view(batch_size, 1)
negative_mask = ~positive_mask
negative_logits = dot_product[negative_mask].view(batch_size, num_negatives)
logits = torch.cat([positive_logits, negative_logits], dim = 1) / temperature
labels = torch.zeros(batch_size, dtype = torch.long)
loss = F.cross_entropy(logits, labels)
return loss
Reference