KalelPark's LAB
[ Paper Implementation ] Implementing the SimCLR DataLoader and info_loss
kalelpark 2023. 1. 5. 15:52
The code and many other paper reviews are available on GitHub; see the link below.
(Stars and follows are much appreciated!)
https://github.com/kalelpark/Awesome-ComputerVision
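Create a new transform wrapper for contrastive learning. It applies the same random augmentation pipeline n_view times to a single image, so each sample yields several differently augmented views.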
```python
class ContrastriveTransformations(object):
    def __init__(self, base_transforms, n_view=2):
        self.base_transforms = base_transforms
        self.n_view = n_view

    def __call__(self, x):
        # Each call samples new random augmentations, so the n_view outputs differ
        return [self.base_transforms(x) for _ in range(self.n_view)]
```
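```python
from torchvision import transforms

# Random augmentation pipeline; every call produces a different view of the image
contrast_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(size=96),   # random crop resized to 96x96 (STL10 resolution)
    transforms.RandomApply([
        transforms.ColorJitter(
            brightness=0.5,
            contrast=0.5,
            saturation=0.5,
            hue=0.1)
    ], p=0.8),                               # color jitter with probability 0.8
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=9),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),    # maps [0, 1] to [-1, 1]
])
```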
Wrap the dataset's transform so that every sample returns n_view views to compare:
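```python
from torchvision.datasets import STL10
# Each item becomes ([view_1, view_2], label); the label is a placeholder in the unlabeled split
unlabeled_data = STL10(root="./data", split="unlabeled", download=True,
                       transform=ContrastriveTransformations(contrast_transforms, n_view=2))
```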
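Because the wrapper returns a list, PyTorch's default collate function yields each batch as a list of n_view tensors (plus the labels), and torch.cat(imgs, dim=0) then stacks them view-major, which is the ordering both loss functions below rely on. A torch-only sketch that avoids downloading STL10, assuming a hypothetical ToyImages dataset of random tensors and a hypothetical toy_augment (random flip plus small noise) standing in for contrast_transforms:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ContrastriveTransformations:
    # Same multi-view wrapper as in the post (class name kept as spelled there)
    def __init__(self, base_transforms, n_view=2):
        self.base_transforms = base_transforms
        self.n_view = n_view

    def __call__(self, x):
        return [self.base_transforms(x) for _ in range(self.n_view)]


def toy_augment(x):
    # Hypothetical stand-in for contrast_transforms: random horizontal flip plus small noise
    if torch.rand(1).item() < 0.5:
        x = x.flip(-1)
    return x + 0.01 * torch.randn_like(x)


class ToyImages(Dataset):
    # Hypothetical dataset of random 3x96x96 "images" with a dummy label,
    # standing in for STL10(split="unlabeled", transform=...)
    def __init__(self, n=10, transform=None):
        self.data = torch.rand(n, 3, 96, 96)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, -1


torch.manual_seed(0)
dataset = ToyImages(transform=ContrastriveTransformations(toy_augment, n_view=2))
loader = DataLoader(dataset, batch_size=4, shuffle=True)

imgs, _ = next(iter(loader))
# Default collate turns each item's [view_1, view_2] into a list of 2 batched tensors,
# each of shape (4, 3, 96, 96)
# Concatenating along dim 0 gives a view-major batch of shape (8, 3, 96, 96):
# rows 0..3 are the first views, rows 4..7 the second views
stacked = torch.cat(imgs, dim=0)
```

Rows i and i + batch_size of stacked are therefore the two views of the same image.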
Define the loss function. The first version follows the sthalles/SimCLR repository listed in the references.
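Both versions in this post compute the NT-Xent (normalized temperature-scaled cross-entropy) form of InfoNCE used by SimCLR. Take $2N$ features $z_0, \dots, z_{2N-1}$ (two views of each of $N$ images, stacked view-major so that $z_i$ and $z_{j(i)}$ with $j(i) = (i + N) \bmod 2N$ come from the same image), cosine similarity $\mathrm{sim}(u, v) = u^\top v / (\lVert u \rVert \lVert v \rVert)$ and temperature $\tau$ (0.2 in the first version):

$$
\ell_i = -\log \frac{\exp\big(\mathrm{sim}(z_i, z_{j(i)})/\tau\big)}{\sum_{k \neq i} \exp\big(\mathrm{sim}(z_i, z_k)/\tau\big)},
\qquad
\mathcal{L} = \frac{1}{2N} \sum_{i=0}^{2N-1} \ell_i
$$

In the first version the positive term is the logit in column 0 and the denominator runs over the $2N - 1$ off-diagonal entries of row $i$, which is exactly what CrossEntropyLoss with target 0 computes. The second version writes $\ell_i = -\mathrm{sim}(z_i, z_{j(i)})/\tau + \log \sum_{k \neq i} \exp\big(\mathrm{sim}(z_i, z_k)/\tau\big)$ directly with logsumexp.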
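```python
import torch
import torch.nn.functional as F

def info_nce_loss(features, n_views=2, temperature=0.2):
    # features: (n_views * batch_size, dim), view-major order as produced by torch.cat(imgs, dim=0)
    batch_size = features.shape[0] // n_views  # originally hard-coded as 32
    device = features.device
    # labels[i, j] = 1 if rows i and j are views of the same image
    labels = torch.cat([torch.arange(batch_size) for _ in range(n_views)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float().to(device)
    # Cosine similarity between all pairs of L2-normalized features
    features = F.normalize(features, dim=1)
    similarity_matrix = torch.matmul(features, features.T)
    # Discard the main diagonal (self-similarity) from labels and similarities
    mask = torch.eye(labels.shape[0], dtype=torch.bool, device=device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
    # Positives: other views of the same image; negatives: everything else
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
    # Positives sit in column 0, so the cross-entropy target is 0 for every row
    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)
    logits = logits / temperature  # temperature tau (0.2 in the original)
    return logits, labels
```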
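```python
# output_feats: projection-head outputs for torch.cat(imgs, dim=0), shape (2 * batch_size, dim)
logits, labels = info_nce_loss(output_feats)
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(logits, labels)
```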
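The masking steps in info_nce_loss are easiest to follow on a tiny batch. A sketch with hypothetical sizes (N = 3 images, 2 views, random 8-d features) that repeats the same indexing outside the function and checks that each row's single positive is the other view of the same image, at index (i + N) mod 2N:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, n_views, dim = 3, 2, 8  # hypothetical: 3 images x 2 views, 8-d features
features = torch.randn(n_views * N, dim)

# Row i and row i + N are the two views of image i (view-major order)
ids = torch.cat([torch.arange(N) for _ in range(n_views)], dim=0)  # [0, 1, 2, 0, 1, 2]
labels = (ids.unsqueeze(0) == ids.unsqueeze(1)).float()            # (6, 6)

z = F.normalize(features, dim=1)
sim = z @ z.T                                                      # cosine similarities, (6, 6)

# Dropping the diagonal leaves 5 entries per row: 1 positive and 4 negatives
mask = torch.eye(labels.shape[0], dtype=torch.bool)
labels_off = labels[~mask].view(labels.shape[0], -1)               # (6, 5)
sim_off = sim[~mask].view(sim.shape[0], -1)                        # (6, 5)
positives = sim_off[labels_off.bool()].view(labels.shape[0], -1)   # (6, 1)
negatives = sim_off[~labels_off.bool()].view(sim.shape[0], -1)     # (6, 4)

# The positive for row i is its similarity to row (i + N) % (2N)
expected = torch.stack([sim[i, (i + N) % (2 * N)] for i in range(2 * N)]).unsqueeze(1)

# Column 0 holds the positive, so every row's cross-entropy target is 0
logits = torch.cat([positives, negatives], dim=1) / 0.2
targets = torch.zeros(logits.shape[0], dtype=torch.long)
loss = F.cross_entropy(logits, targets)
```

This is also why the first version effectively assumes n_views = 2: with more views each row would have several positives, but only column 0 is used as the target.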
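Alternatively, the loss can be computed directly from the masked cosine-similarity matrix. This version follows the UvA deep learning tutorial listed in the references, where it is a method of a PyTorch Lightning module: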
```python
import torch
import torch.nn.functional as F

# Method of a LightningModule that defines self.convnet (encoder + projection head)
# and self.hparams.temperature; mode is only used for logging in the tutorial
def info_nce_loss(self, batch, mode='train'):
    imgs, _ = batch
    imgs = torch.cat(imgs, dim=0)  # (2 * batch_size, C, H, W), view-major
    # Encode all images
    feats = self.convnet(imgs)
    # Calculate cosine similarity between every pair of features
    cos_sim = F.cosine_similarity(feats[:, None, :], feats[None, :, :], dim=-1)
    # Mask out cosine similarity to itself
    self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
    cos_sim.masked_fill_(self_mask, -9e15)
    # Positive example: the other view, cos_sim.shape[0] // 2 (= batch_size) rows away
    pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0)
    # InfoNCE loss
    cos_sim = cos_sim / self.hparams.temperature
    nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
    nll = nll.mean()
    return nll
```
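The two versions should give the same value for the same features: the first drops the diagonal explicitly, while the second fills it with -9e15 so it contributes essentially nothing to the log-sum-exp. A quick sanity check, assuming both are rewritten as plain functions of a feature tensor (the names loss_sthalles_style and loss_uva_style are hypothetical, and the encoder call and self.hparams are replaced by feats and temperature arguments):

```python
import torch
import torch.nn.functional as F


def loss_sthalles_style(features, n_views=2, temperature=0.2):
    # First version: explicit positives/negatives, then cross-entropy with target 0
    batch_size = features.shape[0] // n_views
    labels = torch.cat([torch.arange(batch_size) for _ in range(n_views)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float().to(features.device)
    features = F.normalize(features, dim=1)
    sim = features @ features.T
    mask = torch.eye(labels.shape[0], dtype=torch.bool, device=features.device)
    labels = labels[~mask].view(labels.shape[0], -1)
    sim = sim[~mask].view(sim.shape[0], -1)
    positives = sim[labels.bool()].view(labels.shape[0], -1)
    negatives = sim[~labels.bool()].view(sim.shape[0], -1)
    logits = torch.cat([positives, negatives], dim=1) / temperature
    targets = torch.zeros(logits.shape[0], dtype=torch.long, device=features.device)
    return F.cross_entropy(logits, targets)


def loss_uva_style(feats, temperature=0.2):
    # Second version: mask self-similarity; the positive sits half the batch away
    cos_sim = F.cosine_similarity(feats[:, None, :], feats[None, :, :], dim=-1)
    self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
    cos_sim.masked_fill_(self_mask, -9e15)
    pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0)
    cos_sim = cos_sim / temperature
    nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
    return nll.mean()


torch.manual_seed(0)
feats = torch.randn(2 * 16, 128)  # hypothetical: 16 images x 2 views, 128-d projections
a = loss_sthalles_style(feats)
b = loss_uva_style(feats)
# a and b should agree up to floating-point rounding
```

The choice between them is mostly stylistic: the first produces logits/targets that plug into any cross-entropy criterion, the second avoids building the reshaped positive/negative matrices.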
References
https://github.com/sthalles/SimCLR
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial17/SimCLR.html