KalelPark's LAB


[ Paper Implementation ] SimCLR DataLoader and InfoNCE Loss Implementation

kalelpark 2023. 1. 5. 15:52

The GitHub repository linked below contains the code along with many other paper reviews.
(+ Stars and follows are always appreciated!)

https://github.com/kalelpark/Awesome-ComputerVision

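Create a new transform for the contrastive transformation. The wrapper applies the same random augmentation pipeline to one image n_view times, producing n_view differently augmented views of it.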

import torchvision.transforms as transforms
from torchvision.datasets import STL10

class ContrastiveTransformations(object):
    def __init__(self, base_transforms, n_view = 2):
        self.base_transforms = base_transforms
        self.n_view = n_view
    
    def __call__(self, x):
        # Each call re-samples the random augmentation parameters, so the n_view outputs differ
        return [self.base_transforms(x) for _ in range(self.n_view)]

contrast_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(size = 96),
    transforms.RandomApply([
        transforms.ColorJitter(
            brightness = 0.5,
            contrast = 0.5,
            saturation = 0.5,
            hue = 0.1)
    ], p = 0.8),
    transforms.RandomGrayscale(p = 0.2),
    transforms.GaussianBlur(kernel_size = 9),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
    ])

Set n_view = 2 so that every sample in the dataset is returned as a pair of views to be compared against each other.

# SimCLR ignores labels, so the unlabeled split of STL10 is used
unlabeled_data = STL10(root = "./data", split = "unlabeled", download = True,
                       transform = ContrastiveTransformations(contrast_transforms, n_view = 2))
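How the views line up inside a batch is what the loss relies on. The following is a minimal stdlib-only sketch, assuming PyTorch's default collation (which turns a batch of `[view_1, view_2]` items into a list of two batched tensors) followed by `torch.cat(imgs, dim=0)`. The functions `toy_augment` and `collate` are hypothetical stand-ins for `contrast_transforms` and the DataLoader's collation, and plain lists stand in for tensors.

```python
import random

class ContrastiveTransformations:
    """Apply the same random pipeline n_view times to one sample."""
    def __init__(self, base_transforms, n_view=2):
        self.base_transforms = base_transforms
        self.n_view = n_view

    def __call__(self, x):
        return [self.base_transforms(x) for _ in range(self.n_view)]

_rng = random.Random(0)

def toy_augment(x):
    # Hypothetical stand-in for contrast_transforms: keep the source id and
    # attach a random "augmentation seed" so every call yields a different view.
    return (x, _rng.randint(0, 10**6))

def collate(samples):
    # Hypothetical stand-in for default collation of [view_1, ..., view_n] items:
    # returns n_view lists, each holding one view of every sample in the batch.
    return [list(views) for views in zip(*samples)]

transform = ContrastiveTransformations(toy_augment, n_view=2)
batch = [transform(name) for name in ["cat", "dog", "ship"]]   # B = 3
imgs = collate(batch)        # [views_1 (length B), views_2 (length B)]
flat = imgs[0] + imgs[1]     # same ordering as torch.cat(imgs, dim=0)

B = len(batch)
for i in range(B):
    # Row i and row i + B come from the same source image: a positive pair.
    print(flat[i][0], flat[i + B][0])   # prints "cat cat", "dog dog", "ship ship"
```

This view-major layout is why positives can be found either by repeating the sample indices `0..B-1` once per view, or by shifting the identity mask by half of the concatenated batch.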

Set up the loss function (InfoNCE, called NT-Xent in the SimCLR paper)
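For reference, the quantity being computed can be written out. For a batch of B images there are 2B embeddings z_1, ..., z_{2B}; p(i) denotes the index of the other view of the same image and τ is the temperature (hard-coded as 0.2 in the first version of the code):

```latex
\ell_i = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_{p(i)}) / \tau\right)}
                    {\sum_{k=1,\,k \neq i}^{2B} \exp\left(\mathrm{sim}(z_i, z_k) / \tau\right)},
\qquad
\mathrm{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert},
\qquad
\mathcal{L} = \frac{1}{2B} \sum_{i=1}^{2B} \ell_i
```

With the view-major ordering produced by `torch.cat(imgs, dim=0)`, p(i) = (i + B) mod 2B (0-indexed). Expanding the log gives the equivalent form ℓ_i = -sim(z_i, z_{p(i)})/τ + log Σ_{k≠i} exp(sim(z_i, z_k)/τ), which is the "negative positive plus logsumexp" expression evaluated by the Lightning-style version.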

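import torch
import torch.nn.functional as F

def info_nce_loss(features, n_views = 2, temperature = 0.2):
    # features: (n_views * B, dim), ordered view by view as produced by torch.cat(imgs, dim=0)
    device = features.device
    batch_size = features.shape[0] // n_views   # infer B instead of hard-coding 32

    # Row r belongs to sample r % B; rows with equal sample indices are positive pairs
    labels = torch.cat([torch.arange(batch_size) for _ in range(n_views)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float().to(device)

    features = F.normalize(features, dim=1)
    similarity_matrix = torch.matmul(features, features.T)   # cosine similarity

    # Discard the diagonal (each sample compared with itself)
    mask = torch.eye(labels.shape[0], dtype=torch.bool, device=device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

    # Split each row into positives (other views of the same sample) and negatives
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

    # The positive sits in column 0, so the cross-entropy target is 0 for every row
    # (with n_views > 2 only the first positive is used as the target)
    logits = torch.cat([positives, negatives], dim=1) / temperature
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)
    return logits, labels

# output_feats: projection-head outputs for torch.cat(imgs, dim=0), shape (2B, dim)
logits, labels = info_nce_loss(output_feats)
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(logits, labels)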
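To see what the masking steps above do, here is a stdlib-only sketch that repeats them on index lists for a tiny batch (B = 2, two views, so 4 rows). The helper name `positive_and_negative_columns` is hypothetical; it mirrors the `labels`/`mask` logic and reports, for each row, which original columns end up as positives and which as negatives.

```python
def positive_and_negative_columns(batch_size, n_views=2):
    n = batch_size * n_views
    # Same as torch.cat([torch.arange(B) for _ in range(n_views)])
    ids = [i % batch_size for i in range(n)]
    rows = []
    for i in range(n):
        # Drop the diagonal entry (a sample compared with itself),
        # keeping the original column index k of each remaining entry.
        kept = [k for k in range(n) if k != i]
        pos = [k for k in kept if ids[k] == ids[i]]
        neg = [k for k in kept if ids[k] != ids[i]]
        rows.append((pos, neg))
    return rows

for i, (pos, neg) in enumerate(positive_and_negative_columns(batch_size=2)):
    print(i, pos, neg)
# 0 [2] [1, 3]
# 1 [3] [0, 2]
# 2 [0] [1, 3]
# 3 [1] [0, 2]
```

Each row has exactly one positive (n_views - 1 in general) and 2B - 2 negatives. Concatenating `[positives, negatives]` puts that positive in column 0, which is why every target passed to `CrossEntropyLoss` is 0.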
        
Alternatively, the same loss can be written as a method of a PyTorch Lightning module:

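import torch
import torch.nn.functional as F

# Method of a PyTorch Lightning module that defines self.convnet (encoder + projection head)
# and self.hparams.temperature
def info_nce_loss(self, batch, mode='train'):
    imgs, _ = batch                  # imgs: list of n_view tensors, each (B, C, H, W)
    imgs = torch.cat(imgs, dim=0)    # (2B, C, H, W), view-major order

    # Encode all images
    feats = self.convnet(imgs)
    # Pairwise cosine similarity, shape (2B, 2B)
    cos_sim = F.cosine_similarity(feats[:, None, :], feats[None, :, :], dim=-1)
    # Mask out cosine similarity to itself (large negative value so exp() -> 0 in float32)
    self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
    cos_sim.masked_fill_(self_mask, -9e15)
    # Positive example: half of the concatenated batch (2B // 2 = B rows) away from the original
    pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0)
    # InfoNCE loss
    cos_sim = cos_sim / self.hparams.temperature
    nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
    nll = nll.mean()
    # Log under the mode prefix (e.g. "train_loss") and return the loss for backprop
    self.log(mode + '_loss', nll)
    return nll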
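Both versions compute the same number: cross-entropy over `[positive, negatives]` with target 0 equals `-positive + logsumexp(row without the diagonal)`. The stdlib-only sketch below evaluates both forms on a few hand-made 2-D embeddings (hypothetical data, B = 2, τ = 0.2) and checks that they agree. It is a cross-check of the math under these assumptions, not a replacement for the tensor code.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def nt_xent_two_ways(feats, temperature=0.2):
    n = len(feats)            # n = 2B rows in view-major order
    half = n // 2
    ce_form, lse_form = [], []
    for i in range(n):
        p = (i + half) % n    # index of the other view, as with self_mask.roll(...)
        # Scaled similarities to every other row (diagonal excluded)
        sims = {k: cosine(feats[i], feats[k]) / temperature for k in range(n) if k != i}
        # Version 1: logits = [positive, negatives...], cross-entropy with target 0
        logits = [sims[p]] + [s for k, s in sims.items() if k != p]
        ce_form.append(-(logits[0] - logsumexp(logits)))
        # Version 2: -positive + logsumexp over the row with the diagonal removed
        lse_form.append(-sims[p] + logsumexp(list(sims.values())))
    return sum(ce_form) / n, sum(lse_form) / n

# Rows 0/2 and 1/3 are the two views of the same image (positive pairs)
feats = [[1.0, 0.1], [0.2, 1.0], [0.9, 0.3], [0.1, 0.8]]
loss_ce, loss_lse = nt_xent_two_ways(feats)
print(loss_ce, loss_lse)
```

Swapping rows 2 and 3, so that each anchor's "positive" is actually the dissimilar embedding, raises the loss, which is the behaviour the contrastive objective is meant to penalize.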

References

sthalles/SimCLR: PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
https://github.com/sthalles/SimCLR

Tutorial 17: Self-Supervised Contrastive Learning with SimCLR (UvA DL Notebooks v1.2 documentation)
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial17/SimCLR.html