KalelPark's LAB
[ Paper Implementation ] SimCLR DataLoader and info_loss Implementation
kalelpark 2023. 1. 5. 15:52
The code and many other paper reviews are on my GitHub; please see the link below.
(+ Stars and follows are much appreciated..!)
https://github.com/kalelpark/Awesome-ComputerVision
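Define a new transform wrapper for contrastive learning: it applies the same random augmentation pipeline n_view times to one image, producing several different views of it.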
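```python
class ContrastriveTransformations(object):
    """Apply the same random augmentation pipeline n_view times to one image."""

    def __init__(self, base_transforms, n_view=2):
        self.base_transforms = base_transforms
        self.n_view = n_view

    def __call__(self, x):
        # Each call re-samples the random augmentation parameters, so the views differ
        return [self.base_transforms(x) for _ in range(self.n_view)]
```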
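The wrapper works because base_transforms is called once per view, so each call draws fresh random parameters. A minimal stdlib-only sketch of the same idea, assuming a hypothetical toy_augment (random flip plus random crop on a 1-D list standing in for an image) in place of the torchvision pipeline:

```python
import random
from typing import Callable, List, Sequence


class NViewWrapper:
    """Same idea as ContrastriveTransformations: call a random transform n_view times."""

    def __init__(self, base_transform: Callable, n_view: int = 2):
        self.base_transform = base_transform
        self.n_view = n_view

    def __call__(self, x):
        # Every call re-samples the augmentation, so the views are independent
        return [self.base_transform(x) for _ in range(self.n_view)]


def make_toy_augment(crop_size: int, rng: random.Random) -> Callable[[Sequence[int]], List[int]]:
    """Hypothetical stand-in for an image augmentation on a 1-D 'image'."""

    def toy_augment(x: Sequence[int]) -> List[int]:
        pixels = list(x)
        if rng.random() < 0.5:  # random horizontal flip with p = 0.5
            pixels.reverse()
        start = rng.randint(0, len(pixels) - crop_size)  # random crop offset (inclusive bounds)
        return pixels[start:start + crop_size]

    return toy_augment


rng = random.Random(0)
wrapper = NViewWrapper(make_toy_augment(crop_size=4, rng=rng), n_view=2)
views = wrapper(list(range(10)))
print(views)  # two independently augmented crops of the same "image"
```

Seeding the generator only makes the demo reproducible; in training the randomness is what turns one image into a positive pair of distinct views.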
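```python
from torchvision import transforms

# SimCLR-style augmentation pipeline, sampled independently for every view
contrast_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(size=96),
    # Color distortion, applied with probability 0.8
    transforms.RandomApply([
        transforms.ColorJitter(
            brightness=0.5,
            contrast=0.5,
            saturation=0.5,
            hue=0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=9),
    transforms.ToTensor(),
    # A single mean/std broadcasts over all three channels: [0, 1] -> [-1, 1]
    transforms.Normalize((0.5,), (0.5,))
])
```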
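RandomResizedCrop(size=96) picks a random box and resizes it to 96×96; with torchvision's defaults (scale=(0.08, 1.0), ratio=(3/4, 4/3)), the two views of one image can cover quite different regions. Below is a stdlib-only sketch of the box sampling, modeled on torchvision's approach (area fraction uniform in scale, aspect ratio log-uniform in ratio, up to 10 attempts, then a centered fallback). The name sample_crop_box is hypothetical; treat this as an illustration, not the library source:

```python
import math
import random
from typing import Optional, Tuple


def sample_crop_box(height: int, width: int,
                    scale: Tuple[float, float] = (0.08, 1.0),
                    ratio: Tuple[float, float] = (3 / 4, 4 / 3),
                    rng: Optional[random.Random] = None,
                    attempts: int = 10) -> Tuple[int, int, int, int]:
    """Return (top, left, h, w) of a random crop box in the style of RandomResizedCrop."""
    rng = rng if rng is not None else random.Random()
    area = height * width
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(attempts):
        target_area = area * rng.uniform(scale[0], scale[1])  # fraction of the image area
        aspect = math.exp(rng.uniform(log_lo, log_hi))        # log-uniform aspect ratio w / h
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            top = rng.randint(0, height - h)
            left = rng.randint(0, width - w)
            return top, left, h, w
    # Fallback: a centered crop whose aspect ratio is clamped into `ratio`
    in_ratio = width / height
    if in_ratio < min(ratio):
        w, h = width, int(round(width / min(ratio)))
    elif in_ratio > max(ratio):
        h, w = height, int(round(height * max(ratio)))
    else:
        w, h = width, height
    return (height - h) // 2, (width - w) // 2, h, w


# STL10 images are 96x96; each sampled box would then be resized back to 96x96
rng = random.Random(42)
for _ in range(3):
    print(sample_crop_box(96, 96, rng=rng))
```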
Load the unlabeled STL10 split and wrap the pipeline with n_view = 2, so that every sample yields two views to compare.
```python
from torchvision.datasets import STL10

# Each item is ([view_1, view_2], label)
unlabeled_data = STL10(root="./data", split="unlabeled", download=True,
                       transform=ContrastriveTransformations(contrast_transforms, n_view=2))
```
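Both loss functions below rely on the batch layout. PyTorch's default collation transposes a batch of ([view_1, view_2], label) samples into a list of two batched tensors, and torch.cat(imgs, dim=0) stacks them view-major: rows 0..B-1 hold the first views and rows B..2B-1 the second views, so row i's positive partner is row (i + B) mod 2B. A stdlib sketch of that layout, using strings in place of tensors (simulate_collate and positive_index are hypothetical helpers):

```python
from typing import List, Tuple


def simulate_collate(samples: List[Tuple[List[str], int]]):
    """Mimic default collation: turn (views, label) samples into (per-view batches, labels)."""
    views_per_sample, labels = zip(*samples)
    per_view_batches = [list(batch) for batch in zip(*views_per_sample)]
    return per_view_batches, list(labels)


def positive_index(i: int, batch_size: int) -> int:
    """In the view-major layout, row i's positive partner is batch_size rows away."""
    return (i + batch_size) % (2 * batch_size)


# Three images, two "views" each (strings stand in for augmented tensors)
samples = [([f"img{k}_v1", f"img{k}_v2"], k) for k in range(3)]
per_view, labels = simulate_collate(samples)
stacked = per_view[0] + per_view[1]  # analogous to torch.cat(imgs, dim=0)
print(stacked)
# ['img0_v1', 'img1_v1', 'img2_v1', 'img0_v2', 'img1_v2', 'img2_v2']

# Source-image id of each row, like torch.cat([torch.arange(B) for _ in range(2)])
row_ids = list(range(3)) * 2
for i in range(len(stacked)):
    j = positive_index(i, 3)
    assert row_ids[i] == row_ids[j] and i != j
```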
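Define the loss function (InfoNCE, called NT-Xent in SimCLR). The features must be in view-major order: the first views of all B images, then their second views, so rows i and i + B form a positive pair.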
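```python
import torch
import torch.nn.functional as F


def info_nce_loss(features, n_views=2, temperature=0.2):
    # features: (n_views * B, dim), view-major; B is inferred instead of hardcoded to 32
    batch_size = features.shape[0] // n_views
    device = features.device
    # Row i belongs to source image i % B; same image -> positive pair
    labels = torch.cat([torch.arange(batch_size) for _ in range(n_views)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float().to(device)

    features = F.normalize(features, dim=1)
    similarity_matrix = torch.matmul(features, features.T)  # cosine similarities

    # Drop the diagonal (each row's similarity with itself)
    mask = torch.eye(labels.shape[0], dtype=torch.bool, device=device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

    # Split the remaining entries into positives and negatives
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

    # Positive goes in column 0, so the target class is 0 for every row (assumes n_views = 2)
    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)

    logits = logits / temperature
    return logits, labels


# output_feats: model outputs for torch.cat(views, dim=0), shape (2B, dim)
logits, labels = info_nce_loss(output_feats)
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(logits, labels)
```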
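What the logits/labels pair encodes: after the diagonal is removed, each row keeps its positive in column 0 and its 2B - 2 negatives after it, so cross-entropy with target 0 is -log of the softmax probability given to the positive. A loop-based, stdlib-only sketch of the same computation on small hand-made vectors, for understanding only (info_nce_from_logits is a hypothetical name; temperature 0.2 matches the code above):

```python
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def info_nce_from_logits(features: List[Sequence[float]], temperature: float = 0.2) -> float:
    """NT-Xent for a view-major batch of 2B vectors (rows i and i + B are a positive pair)."""
    z = [l2_normalize(f) for f in features]
    n = len(z)
    b = n // 2
    total = 0.0
    for i in range(n):
        pos = (i + b) % n
        # Scaled cosine similarities to every other row (the diagonal is dropped)
        sims = {k: sum(a * c for a, c in zip(z[i], z[k])) / temperature
                for k in range(n) if k != i}
        logits = [sims[pos]] + [s for k, s in sims.items() if k != pos]  # positive first
        # Cross-entropy with target class 0, via a numerically stable log-sum-exp
        m = max(logits)
        log_denominator = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_denominator - logits[0]
    return total / n


aligned = [[1, 0], [0, 1], [1, 0.1], [0.1, 1]]     # each pair points the same way
mismatched = [[1, 0], [0, 1], [0.1, 1], [1, 0.1]]  # pairs point different ways
print(info_nce_from_logits(aligned), info_nce_from_logits(mismatched))
```

When all vectors are identical, every logit in a row is equal and the loss is log(2B - 1), the value of a uniform guess over the 2B - 1 candidates.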
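Alternatively, the loss can be computed directly from a batch as a model method (this excerpt assumes the class defines self.convnet and self.hparams.temperature):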
```python
import torch
import torch.nn.functional as F


# Method of a model class that defines self.convnet and self.hparams.temperature
def info_nce_loss(self, batch, mode='train'):
    # `mode` is not used in this excerpt
    imgs, _ = batch                # imgs: list of n_view tensors, each (B, C, H, W)
    imgs = torch.cat(imgs, dim=0)  # (2B, C, H, W), view-major
    # Encode all images
    feats = self.convnet(imgs)
    # Calculate cosine similarity between every pair of rows: (2B, 2B)
    cos_sim = F.cosine_similarity(feats[:, None, :], feats[None, :, :], dim=-1)
    # Mask out cosine similarity to itself
    self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
    cos_sim.masked_fill_(self_mask, -9e15)
    # Find positive example -> B = cos_sim.shape[0] // 2 rows away from the original example
    pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0)
    # InfoNCE loss
    cos_sim = cos_sim / self.hparams.temperature
    nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
    nll = nll.mean()
    return nll
```
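This version computes the same quantity without reordering columns. Two facts make it work: masked_fill_ with -9e15 makes the exponential of the self-similarity underflow to 0, so logsumexp over the full row equals the log-denominator over k ≠ i; and rolling the identity mask by 2B // 2 = B rows marks column (i + B) mod 2B in row i, which is the positive. A stdlib sketch checking both (roll_rows is a hypothetical helper mimicking torch.roll along dim 0; the row values are made up):

```python
import math
from typing import List


def roll_rows(mask: List[List[bool]], shift: int) -> List[List[bool]]:
    """Row-wise equivalent of torch.roll(mask, shifts=shift, dims=0)."""
    n = len(mask)
    return [mask[(r - shift) % n] for r in range(n)]


def logsumexp(xs: List[float]) -> float:
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


n = 6  # 2B rows for B = 3 images
self_mask = [[r == c for c in range(n)] for r in range(n)]
pos_mask = roll_rows(self_mask, shift=n // 2)
positive_cols = [row.index(True) for row in pos_mask]
print(positive_cols)  # [3, 4, 5, 0, 1, 2]

# Masking "self" with a huge negative value removes it from logsumexp
row = [5.0, 1.0, -2.0, 0.5]  # made-up scaled similarities; index 0 plays "self"
masked = [-9e15] + row[1:]
print(abs(logsumexp(masked) - logsumexp(row[1:])) < 1e-12)  # True
```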
References
- sthalles/SimCLR, "PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations": https://github.com/sthalles/SimCLR
- UvA DL Notebooks v1.2 documentation, "Tutorial 17: Self-Supervised Contrastive Learning with SimCLR": https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial17/SimCLR.html