KalelPark's LAB
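[ Computer Vision ] MaskedAutoencoder overall CODE
MaskedAutoencoder
Overall code for a standard MAE. The base class handles patchifying, random masking and the reconstruction loss; `forward_encoder` and `forward_decoder` are left to subclasses.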
```python
import torch
import torch.nn as nn


class MaskedAutoencoder(nn.Module):
    """Base MAE: patchify, random masking and loss.
    Subclasses implement forward_encoder / forward_decoder and set self.decoder_patch_size."""

    def __init__(self):
        super().__init__()
        self.norm_pix_loss = True  # normalize each target patch before the MSE

    def patchify(self, imgs):  # imgs (N, 3, H, W) -> patches (N, L, p*p*3)
        p = self.decoder_patch_size
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):  # patches (N, L, p*p*3) -> imgs (N, 3, H, W)
        p = self.decoder_patch_size
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
        return imgs

    def masking_id(self, batch_size, mask_ratio):
        # L = self.patch_embed.num_patches; 196 = 14 x 14, e.g. a 224x224 image with 16x16 patches
        N, L = batch_size, 196
        len_keep = int(L * (1 - mask_ratio))
        noise = torch.rand(N, L)                             # one random score per patch
        ids_shuffle = torch.argsort(noise, dim=1)            # lowest scores are kept
        ids_restore = torch.argsort(ids_shuffle, dim=1)      # inverse permutation of ids_shuffle
        ids_keep = ids_shuffle[:, :len_keep]
        mask = torch.ones([N, L])                            # 1 = masked, 0 = visible
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)  # back to original patch order
        return ids_keep, ids_restore, mask

    def random_masking(self, x, ids_keep):  # keep only the visible tokens
        N, L, D = x.shape
        # the gather index must have shape (N, len_keep, D), so unsqueeze the last dim
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        return x_masked

    def forward_encoder(self, x, mask_ratio):
        raise NotImplementedError  # expected to return latent, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        raise NotImplementedError  # expected to return cls_pred, pred

    def forward_loss(self, imgs, cls_pred, pred, mask):  # cls_pred is unused here
        num_preds = mask.sum()
        target = self.patchify(imgs)  # (N, L, p*p*3)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)                 # mean loss per patch, (N, L)
        loss = (loss * mask).sum() / num_preds   # average over masked patches only
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        cls_pred, pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, cls_pred, pred, mask)
        return loss, pred, mask
```
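To check the shape bookkeeping of `patchify`/`unpatchify` in isolation, here is a minimal standalone sketch of the same einsum-based reshaping. The 224x224 input and the patch size of 16 are assumed values for illustration, and the functions take `p` as an argument instead of reading `self.decoder_patch_size`.

```python
import torch


def patchify(imgs: torch.Tensor, p: int) -> torch.Tensor:
    """(N, C, H, W) -> (N, L, p*p*C) with L = (H/p) * (W/p); assumes square images."""
    n, c, H, W = imgs.shape
    assert H == W and H % p == 0
    h = w = H // p
    x = imgs.reshape(n, c, h, p, w, p)
    # move the patch-grid axes (h, w) forward and the in-patch pixel/channel axes last
    x = torch.einsum("nchpwq->nhwpqc", x)
    return x.reshape(n, h * w, p * p * c)


def unpatchify(x: torch.Tensor, p: int, c: int = 3) -> torch.Tensor:
    """(N, L, p*p*C) -> (N, C, h*p, w*p); inverse of patchify."""
    n, L, _ = x.shape
    h = w = int(L ** 0.5)
    assert h * w == L
    x = x.reshape(n, h, w, p, p, c)
    x = torch.einsum("nhwpqc->nchpwq", x)
    return x.reshape(n, c, h * p, w * p)


imgs = torch.randn(2, 3, 224, 224)
patches = patchify(imgs, p=16)
print(patches.shape)  # torch.Size([2, 196, 768])
restored = unpatchify(patches, p=16)
print(torch.equal(restored, imgs))  # True
```

Each of the 196 rows holds one 16x16 patch flattened in (row, column, channel) order, which is why the last dimension is 16 * 16 * 3 = 768.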
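The masking in `masking_id` and `random_masking` relies on an argsort trick: sorting random noise gives a random permutation `ids_shuffle`, and sorting that permutation again gives its inverse `ids_restore`, which a decoder can use to put tokens back in their original order. The sketch below is a standalone illustration on a seeded toy tensor (N=2, L=16, D=4). The helper name `random_masking_ids`, the all-zero `mask_token`, and the decoder-side unshuffle are assumptions for illustration, since the post leaves `forward_decoder` unimplemented.

```python
import torch


def random_masking_ids(N: int, L: int, mask_ratio: float, generator=None):
    """Hypothetical helper: per-sample random masking via argsort of uniform noise."""
    len_keep = int(L * (1 - mask_ratio))
    noise = torch.rand(N, L, generator=generator)        # one random score per patch
    ids_shuffle = torch.argsort(noise, dim=1)            # ascending: smallest scores kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)      # inverse permutation of ids_shuffle
    ids_keep = ids_shuffle[:, :len_keep]
    mask = torch.ones(N, L)
    mask[:, :len_keep] = 0                               # 0 = keep, 1 = masked (shuffled order)
    mask = torch.gather(mask, dim=1, index=ids_restore)  # back to original patch order
    return ids_keep, ids_restore, mask


N, L, D = 2, 16, 4
x = torch.arange(N * L * D, dtype=torch.float32).reshape(N, L, D)
g = torch.Generator().manual_seed(0)
ids_keep, ids_restore, mask = random_masking_ids(N, L, mask_ratio=0.75, generator=g)

# encoder side: keep only the visible tokens (gather index has shape (N, len_keep, D))
x_visible = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))

# decoder side (assumed): append mask tokens, then unshuffle with ids_restore
mask_token = torch.zeros(1, 1, D)
x_full = torch.cat([x_visible, mask_token.expand(N, L - x_visible.shape[1], D)], dim=1)
x_full = torch.gather(x_full, 1, ids_restore.unsqueeze(-1).repeat(1, 1, D))

print(x_visible.shape)  # torch.Size([2, 4, 4])
print(mask.sum(dim=1))  # tensor([12., 12.])
# visible tokens land back at their original positions
print(torch.equal(x_full[mask == 0], x[mask == 0]))  # True
```

With `mask_ratio=0.75` only a quarter of the tokens reach the encoder, and because `ids_restore` inverts `ids_shuffle`, the decoder sees every visible token at its original patch position with mask tokens filling the rest.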
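`forward_loss` computes the MSE only on masked patches and, with `norm_pix_loss`, against per-patch normalized pixels. A minimal standalone version on made-up toy tensors (the name `mae_loss` is hypothetical) shows both properties: predictions on visible patches do not affect the loss, and a perfect prediction of the normalized target on the masked patches gives zero.

```python
import torch


def mae_loss(target: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor,
             norm_pix_loss: bool = True) -> torch.Tensor:
    """target, pred: (N, L, P) patch tensors; mask: (N, L) with 1 = masked."""
    if norm_pix_loss:
        # normalize each target patch by its own mean and variance
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1e-6) ** 0.5
    loss = ((pred - target) ** 2).mean(dim=-1)  # per-patch MSE, shape (N, L)
    return (loss * mask).sum() / mask.sum()     # average over masked patches only


torch.manual_seed(0)
target = torch.randn(2, 4, 12)           # N=2 images, L=4 patches, P=12 values per patch
mask = torch.tensor([[1., 1., 0., 0.],
                     [1., 0., 0., 0.]])  # 3 masked patches in total

# a perfect prediction of the normalized target on masked patches ...
norm_target = (target - target.mean(-1, keepdim=True)) / (target.var(-1, keepdim=True) + 1e-6) ** 0.5
pred = norm_target.clone()
# ... and arbitrary values on visible patches, which the loss ignores
pred[mask == 0] = 100.0

print(mae_loss(target, pred, mask))  # tensor(0.)
```

Dividing by `mask.sum()` rather than `N * L` makes the loss a mean over masked patches only, so changing `mask_ratio` does not rescale it.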