KalelPark's LAB

[PYTORCH] What is nn.Unfold?


kalelpark 2023. 5. 10. 15:36

 

What is PyTorch unfold?

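For a batched tensor, nn.Unfold slides a window over the input, just as a convolution does, and extracts each local block it covers.
For example, an MNIST image of size (1, 1, 28, 28) becomes (1, 49, 16) with kernel_size=7 and stride=7: 49 values per block, 16 blocks. With the kernel_size=2, stride=2 setting used in the code below, the output is (1, 4, 196).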
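As a quick check of the shapes described above, here is a minimal sketch. It assumes a synthetic tensor in place of a real MNIST image (to avoid the download) and uses kernel_size=7, stride=7, the setting that yields (1, 49, 16).

```python
import torch
import torch.nn as nn

# Synthetic stand-in for one MNIST image (assumption): batch 1, 1 channel, 28x28
image = torch.arange(28 * 28, dtype=torch.float32).view(1, 1, 28, 28)

# kernel_size=7, stride=7 -> non-overlapping 7x7 blocks
unfold = nn.Unfold(kernel_size=7, stride=7)
blocks = unfold(image)

# Output is (N, C * 7 * 7, L) with L = (28 / 7) ** 2 = 16 blocks
print(blocks.shape)  # torch.Size([1, 49, 16])

# Column 0 is the top-left 7x7 block, flattened row by row
first_block = blocks[0, :, 0].reshape(7, 7)
print(torch.equal(first_block, image[0, 0, :7, :7]))
```

Each column of the output is one block, so the last dimension counts block positions and the middle dimension counts the values inside a block.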

EX)

import torch.nn as nn
from matplotlib import pyplot as plt
from torchvision import datasets, transforms


if __name__ == '__main__':
    train_data = datasets.MNIST(root='./data/',
                                train=True,
                                download=True,
                                transform=transforms.ToTensor())

    test_data = datasets.MNIST(root='./data/',
                               train=False,
                               download=True,
                               transform=transforms.ToTensor())

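    image, label = train_data[0]
    image = image.unsqueeze(0)  # add a batch dimension: 1, 1, 28, 28
    print(image.size())
    # Extract non-overlapping 2x2 blocks; the output size is 1, 4, 196
    unfold = nn.Unfold(kernel_size=2, stride=2, padding=0)
    x = unfold(image)

    fig, axs = plt.subplot_mosaic([['left', 'right']], layout='constrained')
    axs['left'].imshow(image.squeeze().numpy(), cmap='gray')

    # Row 0 holds the top-left pixel of every 2x2 block,
    # so reshaping it to 14x14 gives a 2x-downsampled image
    x = x[:, 0, ...].view(1, 1, 14, 14)
    axs['right'].imshow(x.squeeze().numpy(), cmap='gray')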
    plt.show()
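The right-hand plot above looks like a smaller copy of the digit because each row of the unfolded tensor collects one fixed position from every 2x2 block. The sketch below checks this numerically. It assumes a random tensor as a stand-in for the MNIST image, and the loop variable names are chosen only for illustration.

```python
import torch
import torch.nn as nn

# Random stand-in for an MNIST image (assumption): 1, 1, 28, 28
image = torch.rand(1, 1, 28, 28)

unfold = nn.Unfold(kernel_size=2, stride=2, padding=0)
x = unfold(image)  # 1, 4, 196

# Row k corresponds to offset (dy, dx) inside each 2x2 block
all_match = True
for k in range(4):
    dy, dx = divmod(k, 2)
    plane = x[:, k, :].reshape(1, 1, 14, 14)
    # Same pixels as taking every second row and column from that offset
    all_match = all_match and torch.equal(plane, image[:, :, dy::2, dx::2])

print(x.shape, all_match)
```

So `x[:, 0, ...]` in the example is the same as `image[:, :, ::2, ::2]`, and the other three rows are the same subsampling shifted by one pixel.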

 

Main uses:

1) https://runebook.dev/ko/docs/pytorch/generated/torch.nn.unfold

2) https://www.facebook.com/groups/PyTorchKR/posts/1685133764959631/
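One use noted in the PyTorch documentation for Unfold is that a convolution can be written as unfold followed by a matrix multiplication. The following is a minimal sketch of that idea, not a replacement for `conv2d`; the tensor sizes are arbitrary values chosen for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
inp = torch.randn(1, 3, 10, 12)   # illustrative input: N, C, H, W
weight = torch.randn(2, 3, 4, 5)  # out_channels, in_channels, kh, kw

# Flatten every 4x5 window into a column: 1, 3*4*5, L
cols = F.unfold(inp, kernel_size=(4, 5))

# Multiply the flattened kernels with the columns: 1, 2, L
out = weight.view(2, -1) @ cols

# L = (10 - 4 + 1) * (12 - 5 + 1) = 7 * 8, so view it as the conv output
out = out.view(1, 2, 7, 8)

# Compare against the built-in convolution (small float differences expected)
reference = F.conv2d(inp, weight)
print((out - reference).abs().max())
```

The same unfold step is also a convenient way to cut an image into patches, since each output column is one flattened patch.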

Reference

https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html

 
