KalelPark's LAB

[Pytorch] torch.gather 코드로 간략하게 이해하기 본문

Python/Pytorch

[Pytorch] torch.gather 코드로 간략하게 이해하기

kalelpark 2023. 3. 16. 17:46

매번 까먹어서, 다시 다듬어보고자 한다.

torch.gather란?

   공식문서에 따르면 차원에 해 정해진 축에 따라 값을 모읍니다.

import torch
t = torch.tensor([[1, 2], [3, 4]])
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))

즉, 차원에 따라 값을 재배치한다고 이해하면 됩니다.

* 코드로 한번 더 이해해보도록 하겠습니다. 

out[i][j][k] = input[index[i][j][k]][j][k]
out[i][j][k] = input[i][index[j][k]][k]
out[i][j][k] = input[i][j][index[k]]

// 위의 값처럼 indexing이 처리되는 것을 알 수 있습니다.
Example>
import torch

t = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
k = torch.gather(t, 1, torch.tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2]]))
print(k)

tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])

 

Reference

https://woongjun-warehouse.tistory.com/48

 

Comments