KalelPark's LAB

[ Computer Vision ] MaskedAutoencoder overall CODE 본문

Data Science/CODE

[ Computer Vision ] MaskedAutoencoder overall CODE

kalelpark 2023. 5. 14. 09:25

MaskedAutoencoder

일반적인 MAE에서의 Overall CODE

class MaskedAutoencoder(nn.Module):
    """Base class for a Masked Autoencoder (MAE).

    Provides patch <-> image conversion, random patch masking, and the
    masked per-patch reconstruction loss. Subclasses must implement
    ``forward_encoder`` and ``forward_decoder``.
    """

    def __init__(self, decoder_patch_size: int = 16, norm_pix_loss: bool = True):
        # Bug fix: the original never set ``decoder_patch_size`` even though
        # patchify()/unpatchify() read it; expose it with a default so the
        # original no-argument construction still works.
        super().__init__()
        self.decoder_patch_size = decoder_patch_size
        # When True, the loss target is per-patch normalized pixels.
        self.norm_pix_loss = norm_pix_loss

    def patchify(self, imgs):
        """(N, 3, H, W) images -> (N, num_patches, p*p*3) flattened patches.

        Assumes square images whose side is divisible by the patch size.
        """
        p = self.decoder_patch_size
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        return x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))

    def unpatchify(self, x):
        """(N, num_patches, p*p*3) patches -> (N, 3, H, W) images.

        Exact inverse of :meth:`patchify`; requires a square patch grid.
        """
        p = self.decoder_patch_size
        # Bug fix: original had a syntax error here ('** . 5').
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        # Bug fix: the last dim was 'h * p' in the original; correct is 'w * p'
        # (identical here since h == w, but this is the right formula).
        return x.reshape(shape=(x.shape[0], 3, h * p, w * p))

    def masking_id(self, batch_size, mask_ratio):
        """Sample a random per-sample patch mask.

        Returns:
            ids_keep:    (N, len_keep) indices of the visible patches.
            ids_restore: (N, L) permutation that undoes the shuffle.
            mask:        (N, L) binary mask; 0 = keep, 1 = masked.
        """
        # Default 196 = 14x14 patches (ViT-B/16 @ 224); subclasses that set
        # ``self.num_patches`` (e.g. from a patch embed) override it.
        N = batch_size
        L = getattr(self, 'num_patches', 196)
        len_keep = int(L * (1 - mask_ratio))

        # Random scores per patch; argsort of noise ~ a random permutation.
        noise = torch.rand(N, L)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        mask = torch.ones([N, L])
        mask[:, :len_keep] = 0
        # Un-shuffle so the mask lines up with the original patch order.
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return ids_keep, ids_restore, mask

    def random_masking(self, x, ids_keep):
        """Keep only the patches selected by ids_keep: (N, L, D) -> (N, len_keep, D)."""
        N, L, D = x.shape
        # Bug fix: original used unsqueeze(1), giving an index of shape
        # (N, 1, len_keep * D) which torch.gather rejects; the gather index
        # must be (N, len_keep, D), i.e. unsqueeze(-1).
        index = ids_keep.unsqueeze(-1).repeat(1, 1, D)
        return torch.gather(x, dim=1, index=index)

    def forward_encoder(self, x, mask_ratio):
        """Encode visible patches; must be implemented by subclasses."""
        raise NotImplementedError

    def forward_decoder(self, x, ids_restore):
        """Decode the full patch sequence; must be implemented by subclasses.

        (Bug fix: the original signature misspelled ``self`` as ``sefl``.)
        """
        raise NotImplementedError

    def forward_loss(self, imgs, cls_pred, pred, mask):
        """Mean squared error computed on masked patches only.

        Args:
            imgs:     (N, 3, H, W) original images.
            cls_pred: unused here; kept for subclass interface compatibility.
            pred:     (N, L, p*p*3) predicted patches.
            mask:     (N, L) binary mask; loss is averaged where mask == 1.
        """
        num_preds = mask.sum()
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            # Normalize each target patch to zero mean / unit variance.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)                 # per-patch MSE
        loss = (loss * mask).sum() / num_preds   # mean over masked patches only
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        """Full MAE pass: encode visible patches, decode, compute masked loss."""
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        cls_pred, pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, cls_pred, pred, mask)
        return loss, pred, mask

 

Comments