[CODE] Exploring the Masked AutoEncoder through Code
Load Library
import torch
import timm
import numpy as np
from einops import repeat, rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block
Building the class and functions for patch shuffling

* forward_indexes : (16, 2)
* backward_indexes : (16, 2)

When take_indexes is applied afterwards, gather operates along dim 0 (the token axis), so the sequence is re-ordered token by token; the indexes are repeated across the channel dimension so that every channel of a token moves together. A short sanity check follows the two functions below.
def random_indexes(size : int):
    # a random permutation and its inverse (argsort undoes the shuffle)
    forward_indexes = np.arange(size)
    np.random.shuffle(forward_indexes)
    backward_indexes = np.argsort(forward_indexes)
    return forward_indexes, backward_indexes

def take_indexes(sequence, indexes):
    # sequence: (T, B, C), indexes: (T, B); gather along dim 0 after
    # broadcasting the indexes over the channel dimension
    return torch.gather(sequence, 0, repeat(indexes, "t b -> t b c", c = sequence.shape[-1]))
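As a quick sanity check (a minimal sketch using the (16, 2) shapes above), gathering with forward_indexes and then with backward_indexes should restore the original order:

seq = torch.arange(16, dtype = torch.float).reshape(16, 1, 1).repeat(1, 2, 10)  # (T, B, C)
f, b = random_indexes(16)
fwd = torch.as_tensor(np.stack([f, f], axis = -1), dtype = torch.long)  # (16, 2)
bwd = torch.as_tensor(np.stack([b, b], axis = -1), dtype = torch.long)  # (16, 2)
shuffled = take_indexes(seq, fwd)
restored = take_indexes(shuffled, bwd)
assert torch.equal(restored, seq)  # backward_indexes inverts forward_indexes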
class PatchShuffle(torch.nn.Module):
    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        T, B, C = patches.shape  # e.g. (16, 2, 10)
        remain_T = int(T * (1 - self.ratio))  # tokens kept after masking, e.g. 4
        indexes = [random_indexes(T) for _ in range(B)]  # B pairs of (forward, backward) permutations
        forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis = -1), dtype = torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis = -1), dtype = torch.long).to(patches.device)
        patches = take_indexes(patches, forward_indexes)  # shuffle tokens independently per batch element
        patches = patches[:remain_T]  # keep only the visible tokens
        return patches, forward_indexes, backward_indexes
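A minimal usage sketch: with a ratio of 0.75 and the (16, 2, 10) shapes from the comments above, only 4 visible tokens survive.

shuffle = PatchShuffle(ratio = 0.75)
patches = torch.randn(16, 2, 10)  # (T, B, C)
visible, fwd, bwd = shuffle(patches)
print(visible.shape)  # torch.Size([4, 2, 10])
print(fwd.shape, bwd.shape)  # torch.Size([16, 2]) torch.Size([16, 2])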
Building the MAE Encoder

* The libraries imported earlier are used here.

from timm.models.vision_transformer import Block
from timm.models.layers import trunc_normal_

- trunc_normal_ : used to initialize parameters inside a layer with a truncated normal distribution (see the short sketch after this list).
- Block : a Transformer block from the ViT architecture.

* One point to keep in mind: the Masked AutoEncoder is an asymmetric design. The encoder runs only on the visible patches, while a much lighter decoder (4 layers vs. 12 here) reconstructs the full image.
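For intuition, a small sketch of trunc_normal_ (timm's in-place truncated-normal initializer, as used in init_weight below):

t = torch.empty(1, 1, 192)
trunc_normal_(t, std = .02)  # fills t in place; samples ~ N(0, 0.02^2), truncated to [-2, 2] by default
print(t.std())  # roughly 0.02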
class MAE_Encoder(torch.nn.Module):
    def __init__(self,
                 image_size = 32,
                 patch_size = 2,
                 emb_dim = 192,
                 num_layer = 12,
                 num_head = 3,
                 mask_ratio = 0.75) -> None:
        super().__init__()

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)
        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)  # non-overlapping patch embedding
        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])
        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.cls_token, std = .02)
        trunc_normal_(self.pos_embedding, std = .02)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, "b c h w -> (h w) b c")
        patches = patches + self.pos_embedding
        patches, forward_indexes, backward_indexes = self.shuffle(patches)  # keep only the visible tokens
        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim = 0)
        patches = rearrange(patches, "t b c -> b t c")
        features = self.layer_norm(self.transformer(patches))
        features = rearrange(features, "b t c -> t b c")  # back to token-first, as the decoder expects
        return features, backward_indexes
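A shape-check sketch with the default arguments (a CIFAR-10-sized input is assumed): (32 / 2)^2 = 256 patches, of which 25% remain visible, plus one cls token.

encoder = MAE_Encoder()
img = torch.randn(2, 3, 32, 32)
features, backward_indexes = encoder(img)
print(features.shape)  # torch.Size([65, 2, 192]) -> 1 cls token + 64 visible tokens
print(backward_indexes.shape)  # torch.Size([256, 2])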
Building the MAE Decoder

* I feel I need to review the Transformer architecture again!
class MAE_Decoder(torch.nn.Module):
    def __init__(self,
                 image_size = 32,
                 patch_size = 2,
                 emb_dim = 192,
                 num_layer = 4,
                 num_head = 3) -> None:
        super().__init__()

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))  # +1 for the cls token
        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])
        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)  # token -> raw pixel values of one patch
        self.patch2img = Rearrange("(h w) b (c p1 p2) -> b c (h p1) (w p2)", p1 = patch_size, p2 = patch_size, h = image_size // patch_size)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.mask_token, std = .02)
        trunc_normal_(self.pos_embedding, std = .02)

    def forward(self, features, backward_indexes):
        T = features.shape[0]  # visible tokens + cls token
        # prepend index 0 and shift the rest by 1 so the cls token stays in front
        backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim = 0)
        # pad with mask tokens back to full length, then unshuffle into image order
        features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim = 0)
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        features = rearrange(features, "t b c -> b t c")
        features = self.transformer(features)
        features = rearrange(features, "b t c -> t b c")
        features = features[1:]  # drop the cls token

        patches = self.head(features)
        mask = torch.zeros_like(patches)
        mask[T - 1:] = 1  # shuffled positions past the T - 1 visible patches (cls excluded) were masked
        mask = take_indexes(mask, backward_indexes[1:] - 1)  # unshuffle the mask into image order
        img = self.patch2img(patches)
        mask = self.patch2img(mask)
        return img, mask
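Finally, a minimal end-to-end sketch tying the two modules together. The loss follows the MAE recipe of computing MSE only on the masked pixels, dividing by the mask ratio so the mean is effectively taken over the masked region:

encoder = MAE_Encoder(mask_ratio = 0.75)
decoder = MAE_Decoder()
img = torch.randn(2, 3, 32, 32)
features, backward_indexes = encoder(img)
pred, mask = decoder(features, backward_indexes)  # both (2, 3, 32, 32); mask is 1 on masked pixels
loss = torch.mean((pred - img) ** 2 * mask) / 0.75
loss.backward()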
'Data Science > CODE' 카테고리의 다른 글
[CODE] Gradient Accumulate이란? (0) | 2023.03.20 |
---|---|
[CODE] MixUp 분할해서 구현하기 (3) | 2023.03.20 |
[CODE] Multi-GPU 활용하기 (0) | 2023.03.19 |
[CODE] How to making useful logger? (0) | 2023.03.15 |
[CODE] NLP's Study for BERT tutorial (1) | 2023.03.13 |
Comments