


[CODE] Attention Conv2d Implementation

kalelpark 2023. 4. 13. 10:57

Let's apply Attention to an input image. :) The idea is CAM-style (Class Activation Map): the convolutional feature maps are weighted by the FC layer's weights, summed over channels, and normalized into a single-channel attention map.

CODE

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import transforms

trans_main = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Normalize each RGB channel to roughly [-1, 1]
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
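
As a quick sanity check (a minimal sketch using a dummy all-black PIL image rather than a real photo), trans_main turns a PIL image of any size into a normalized torch.Size([3, 224, 224]) tensor:

dummy = Image.new("RGB", (500, 375))     # dummy image; any input size works, Resize maps it to 224 x 224
x = trans_main(dummy)
print(x.shape)                           # torch.Size([3, 224, 224])
print(x.min().item(), x.max().item())    # -1.0 -1.0 for an all-black dummy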

class ImageAttentionMap(nn.Module):
    def __init__(self):
        super(ImageAttentionMap, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )

        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 1)

    def forward(self, x):
        x = self.conv_block(x)        # conv_block : torch.Size([1, 64, 224, 224])
        attention_map = x.clone()     # keep the feature maps to build the attention map

        x = self.global_avg_pool(x)   # global_avg_pool : torch.Size([1, 64, 1, 1])

        x = x.view(x.size(0), -1)     # x.view : torch.Size([1, 64])
        x = self.fc(x)
        x = torch.sigmoid(x)          # F.sigmoid is deprecated; use torch.sigmoid

        # CAM-style weighting: reuse the FC weights as per-channel importances
        attention_weights = self.fc.weight.detach().squeeze()
        # attention_weights :  torch.Size([64])

        attention_weights = attention_weights.view(1, -1, 1, 1)
        # attention_weights :  torch.Size([1, 64, 1, 1])

        attention_map = attention_map * attention_weights
        attention_map = torch.sum(attention_map, dim=1, keepdim=True)
        attention_map = F.relu(attention_map)
        attention_map = attention_map / (torch.max(attention_map) + 1e-8)  # scale to [0, 1]
        return x, attention_map
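
To run the model end to end, here is a minimal usage sketch. The file name sample.jpg is a hypothetical placeholder; any RGB image works:

model = ImageAttentionMap()
model.eval()

img = Image.open("sample.jpg").convert("RGB")   # hypothetical path
inp = trans_main(img).unsqueeze(0)              # torch.Size([1, 3, 224, 224])

with torch.no_grad():
    score, attention_map = model(inp)

print(score.shape)            # torch.Size([1, 1]) -- sigmoid output
print(attention_map.shape)    # torch.Size([1, 1, 224, 224])

# Overlay-ready map: drop the batch/channel dims and convert to a grayscale PIL image
att = (attention_map.squeeze() * 255).byte()
Image.fromarray(att.numpy(), mode="L").save("attention_map.png")

Note that the weights here are randomly initialized, so the resulting map is only meaningful once the network has been trained on an actual task.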
