KalelPark's LAB

[CODE] Masked AutoEncoder CODE로 살펴보기 본문

Data Science/CODE

[CODE] Masked AutoEncoder CODE로 살펴보기

kalelpark 2023. 3. 16. 19:44
 

Load Library

import torch
import timm
import numpy as np

from einops import repeat, rearrange
from einops.layers.torch import Rearrange

from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block

Patch shuffle을 하기 위한 Class 및 function 구축

    * forward_indexes  : (16, 2)

    * backward_indexes  : (16, 2)

       이후 take_indexes에서 torch.gather는 dim = 0(patch 축, T)을 기준으로 index를 적용합니다.

       따라서 index를 channel 차원(C)만큼 복제하여 (T, B, C) 형태로 맞춘 뒤, batch별로 patch의 순서만 재배열한 결과가 출력됩니다.

def random_indexes(size : int):
    """Draw a random permutation of ``range(size)`` and its inverse.

    Randomness comes from NumPy's global RNG (``np.random.shuffle``).

    Args:
        size: number of positions to permute.

    Returns:
        A ``(forward, backward)`` tuple of 1-D integer arrays of length
        ``size`` where ``forward[backward]`` and ``backward[forward]`` are
        both ``arange(size)``.
    """
    permutation = np.arange(size)
    # In-place shuffle of the identity ordering.
    np.random.shuffle(permutation)
    # The argsort of a permutation is the permutation that undoes it.
    inverse = np.argsort(permutation)
    return permutation, inverse

def take_indexes(sequence, indexes):
    """Reorder ``sequence`` along its first axis, independently per batch column.

    Args:
        sequence: tensor of shape (T, B, C).
        indexes: integer (long) tensor of shape (T', B); ``indexes[t, b]`` is
            the row of ``sequence[:, b]`` placed at output position ``t``.

    Returns:
        Tensor of shape (T', B, C) with
        ``out[t, b, :] == sequence[indexes[t, b], b, :]``.
    """
    # gather needs an index with the same rank as the input, so broadcast the
    # (T', B) index over the channel axis. expand() creates a view, no copy.
    expanded = indexes.unsqueeze(-1).expand(-1, -1, sequence.shape[-1])
    return torch.gather(sequence, 0, expanded)

class PatchShuffle(torch.nn.Module):
    """Randomly permute patches per sample and keep only a leading fraction.

    Args:
        ratio: fraction of patches to discard; ``int(T * (1 - ratio))``
            patches are kept.
    """

    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        """Shuffle ``patches`` along the first axis and truncate.

        Args:
            patches: tensor of shape (T, B, C).

        Returns:
            A tuple ``(kept, forward_indexes, backward_indexes)``: the first
            ``int(T * (1 - ratio))`` shuffled patches, and the (T, B) long
            tensors holding each sample's permutation and its inverse, on the
            same device as ``patches``.
        """
        num_patches, batch_size, _ = patches.shape
        num_kept = int(num_patches * (1 - self.ratio))

        # One independent permutation (and its inverse) per batch element.
        forward_columns = []
        backward_columns = []
        for _ in range(batch_size):
            fwd, bwd = random_indexes(num_patches)
            forward_columns.append(fwd)
            backward_columns.append(bwd)

        # Stack on the last axis so the result is laid out as (T, B).
        forward_indexes = torch.as_tensor(
            np.stack(forward_columns, axis = -1), dtype = torch.long
        ).to(patches.device)
        backward_indexes = torch.as_tensor(
            np.stack(backward_columns, axis = -1), dtype = torch.long
        ).to(patches.device)

        shuffled = take_indexes(patches, forward_indexes)
        visible = shuffled[:num_kept]

        return visible, forward_indexes, backward_indexes

Masked Auto Encoder 구축하기

  * 여기서는 이전에 불러온 라이브러리를 사용합니다.

     from timm.models.vision_transformer import Block

     from timm.models.layers import trunc_normal_

           - trunc_normal_ : layer 내 parameter를 초기화할 때 사용합니다.

           - Block : transformer Architecture 내 Block을 의미합니다.

 

* Masked AutoEncoder는 asymmetric(비대칭) 구조라는 점에 주의하시기 바랍니다.

class MAE_Encoder(torch.nn.Module):
    """Masked-autoencoder encoder that runs transformer blocks on kept patches only.

    The image is split into patches by a strided convolution, a learned
    positional embedding is added, ``PatchShuffle`` drops a random subset of
    the patches, a class token is prepended and the result is passed through
    a stack of transformer ``Block`` layers followed by a LayerNorm.

    Args:
        image_size: side length of the (square) input image in pixels.
        patch_size: side length of a patch; used as conv kernel and stride.
        emb_dim: embedding width of every token.
        num_layer: number of transformer blocks.
        num_head: number of attention heads per block.
        mask_ratio: ratio handed to ``PatchShuffle`` (fraction of patches dropped).
    """

    def __init__(self,
                 image_size = 32,
                 patch_size = 2,
                 emb_dim = 192,
                 num_layer = 12,
                 num_head = 3,
                 mask_ratio = 0.75) -> None:
        super().__init__()

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        # One positional vector per patch, broadcast over the batch axis.
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)
        # Non-overlapping conv: each output position embeds one patch of a 3-channel image.
        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)
        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])
        self.layer_norm = torch.nn.LayerNorm(emb_dim)
        self.init_weight()

    def init_weight(self):
        """Initialise the class token and positional embedding (truncated normal, std 0.02)."""
        trunc_normal_(self.cls_token, std = .02)
        trunc_normal_(self.pos_embedding, std = .02)

    def forward(self, img):
        """Encode the randomly kept patches of ``img``.

        Args:
            img: image batch of shape (B, 3, image_size, image_size).

        Returns:
            A tuple ``(features, backward_indexes)``. ``features`` has shape
            (1 + kept, B, emb_dim) with the class token at position 0;
            ``backward_indexes`` is the (num_patches, B) inverse permutation
            produced by ``PatchShuffle``.
        """
        patches = self.patchify(img)
        patches = rearrange(patches, "b c h w -> (h w) b c")
        patches = patches + self.pos_embedding

        patches, _, backward_indexes = self.shuffle(patches)
        cls_tokens = self.cls_token.expand(-1, patches.shape[1], -1)
        patches = torch.cat([cls_tokens, patches], dim = 0)

        # The transformer blocks expect batch-first input.
        patches = rearrange(patches, "t b c -> b t c")
        features = self.layer_norm(self.transformer(patches))
        # Back to token-first: MAE_Decoder.forward treats dim 0 as the token axis.
        features = rearrange(features, "b t c -> t b c")

        return features, backward_indexes

Masked AutoEncoder Decoder 구축하기

    * transformer 구조를 다시 복습할 필요성을 느낍니다..,! 

class MAE_Decoder(torch.nn.Module):
    """Masked-autoencoder decoder that reconstructs the full image from kept tokens.

    The encoder output is padded with a learned mask token for every dropped
    patch, restored to the original patch order, run through transformer
    ``Block`` layers and projected back to pixel values.

    Args:
        image_size: side length of the (square) output image in pixels.
        patch_size: side length of a patch.
        emb_dim: embedding width of every token.
        num_layer: number of transformer blocks.
        num_head: number of attention heads per block.
    """

    def __init__(self, image_size = 32,
                 patch_size = 2,
                 emb_dim = 192,
                 num_layer = 4,
                 num_head = 3) -> None:
        super().__init__()

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        # "+ 1" reserves a position for the class token.
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))
        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])
        # Each token is projected to the pixels of one 3-channel patch.
        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange("(h w) b (c p1 p2) -> b c (h p1) (w p2)", p1 = patch_size, p2 = patch_size, h = image_size // patch_size)
        self.init_weight()

    def init_weight(self):
        """Initialise the mask token and positional embedding (truncated normal, std 0.02)."""
        trunc_normal_(self.mask_token, std = .02)
        trunc_normal_(self.pos_embedding, std = .02)

    def forward(self, features, backward_indexes):
        """Reconstruct the image and the mask of dropped patches.

        Args:
            features: token-first tensor of shape (T, B, emb_dim) whose first
                token is the class token followed by the kept patches.
            backward_indexes: (num_patches, B) long tensor holding the inverse
                of the shuffle applied before encoding.

        Returns:
            A tuple ``(img, mask)`` of tensors shaped
            (B, 3, image_size, image_size). ``mask`` is 1 where the patch was
            dropped before encoding and 0 where it was kept.
        """
        T = features.shape[0]
        # Shift patch indexes by one and put the class token (index 0) first.
        backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim = 0)
        # Append one mask token for every dropped patch.
        features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim = 0)
        # Undo the shuffle so tokens are back in their original positions.
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        features = rearrange(features, "t b c -> b t c")
        features = self.transformer(features)
        features = rearrange(features, "b t c -> t b c")
        features = features[1:]  # drop the class token

        patches = self.head(features)
        mask = torch.zeros_like(patches)
        # T counts the class token, which is no longer in `patches`, so in
        # shuffled order the first T - 1 rows are the kept patches.
        mask[T - 1:] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        img = self.patch2img(patches)
        mask = self.patch2img(mask)

        return img, mask

 

'Data Science > CODE' 카테고리의 다른 글

[CODE] Gradient Accumulate이란?  (0) 2023.03.20
[CODE] MixUp 분할해서 구현하기  (3) 2023.03.20
[CODE] Multi-GPU 활용하기  (0) 2023.03.19
[CODE] How to making useful logger?  (0) 2023.03.15
[CODE] NLP's Study for BERT tutorial  (1) 2023.03.13
Comments