일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
- Depth estimation
- pytorch
- Vision
- Python
- ML
- Meta Learning
- math
- 3d
- 알고리즘
- classification
- computervision
- Front
- 자료구조
- PRML
- GAN
- Torch
- 딥러닝
- CV
- REACT
- dl
- nerf
- clean code
- cs
- algorithm
- 머신러닝
- FineGrained
- nlp
- FGVC
- SSL
- web
- Today
- Total
KalelPark's LAB
[Meta Learning] MetaDataLoader 구현 본문
* 해당 포스팅을 보기 전에, 이전 논문 리뷰를 우선적으로 참고하시는데 많은 도움이 될 것입니다. 감사합니다.
https://kalelpark.tistory.com/28
Torchmeta란?
Pytorch에서의 few-shot learning & meta-learning을 위한 DataLoader라고 볼 수 있습니다.
해당 라이브러리에는 meta learning에 자주 등장하는 데이터셋들이 존재합니다.
import warnings
from typing import Any, Dict, Tuple
import matplotlib.pyplot as plt
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader
- torchmeta.datasets.helpers에는 일부 벤치마크 데이터셋에 대한 데이터 세팅을 포함하고 있습니다.
- 해당 Dataset의 매개변수를 활용하여, N-way, K-shot 세팅 등 기본적인 설정이 가능합니다.
Torchmeta DataLoader 정의
- 여러 매개변수를 통하여, dataset 및 dataloader 생성하여, 가져오는 것이 가능합니다.
def get_dataloader(config : Dict[str, Any])-> Tuple[BatchMetaDataLoader, BatchMetaDataLoader, BatchMetaDataLoader]:
train_dataset = omniglot(
folder = config["folder_name"],
shots = config["num_shots"],
ways = config["num_ways"],
shuffle = True,
meta_train = True,
download = config["download"],
)
train_dataloader = BatchMetaDataLoader(
train_dataset, batch_size = config["batch_size"], shuffle = True
)
val_dataset = omniglot(
folder = config["folder_name"],
shots = config["num_shots"],
ways = config["num_ways"],
shuffle = True,
meta_val = True,
download = config["download"],
)
val_dataloader = BatchMetaDataLoader(
val_dataset, batch_size = config["batch_size"], shuffle = True
)
test_dataset = omniglot(
folder = config["folder_name"],
shots = config["num_shots"],
ways = config["num_ways"],
shuffle = True,
meta_train = True,
download = config["download"],
)
test_dataloader = BatchMetaDataLoader(
test_dataset, batch_size = config["batch_size"], shuffle = True
)
return train_dataloader, val_dataloader, test_dataloader
config를 보면 알 수 있다시피, 5-ways 2-shots이 되었으며,
각각의 5개의 class에 데이터가 2개씩 존재합니다.
config = {
"folder_name" : "dataset",
"download" : True,
"num_shots" : 2,
"num_ways" : 5,
"batch_size" : 3,
"num_batches_train" : 6000,
"num_batches_test" : 2000,
"num_batches_val" : 100,
"device" : "cpu"
}
train_dataloader, val_dataloader, test_dataloader = get_dataloader(config)
for batch_idx, batch in enumerate(train_dataloader):
if batch_idx >= config["num_batches_train"]:
break
support_xs = batch["train"][0].to(device = config["device"])
support_ys = batch["train"][1].to(device = config["device"])
query_xs = batch["test"][0].to(device = config["device"])
query_ys = batch["test"][1].to(device = config["device"])
print(
f"support_x shape : {support_xs.shape}\n",
f"support_y shape : {support_ys.shape}\n",
f"query_x shape : {query_xs.shape}\n",
f"query_y shape : {query_ys.shape}\n",
)
break
* 시각화를 해보자면,
for b in range(config["batch_size"]):
fig = plt.figure(constrained_layout = True, figsize = (18, 4))
subfigs = fig.subfigures(1, 2, wspace = 0.07)
subfigs[0].set_facecolor("0.75")
subfigs[0].suptitle("Support set", fontsize = "x-large")
support_axs = subfigs.flat[0].subplots(nrows = 2, ncols = 5)
for i, ax in enumerate(supprt_axs.T.flatten()):
ax.imshow(support_xs[b][i].permute(1, 2, 0).squeeze(), aspect = "auto")
subfigs[1].set_facecolor("0.75")
subfigs[1].suptitle("Query set", fontsize = "x-large")
query_axs = subfigs.flat[1].subplots(nrows = 2, ncols = 5)
for i, ax in enumerate(query_axs.T.flatten()):
ax.imshow(query_xs[b][i].permute(1, 2, 0).squeeze(), aspect="auto")
fig.suptitle("Batch " + str(b), fontsize = "xx-large")
plt.show()
Sinusoid란?
- 일정한 패턴이 반복되는 매끈한 곡선 함수를 말합니다. Sinusoid 함수의 예측은 대표적인 메타 회귀 문제이다.
K-shot 메타 회귀문제에서는 주어진 sinusoid 곡선 위 K개의 정점에 대한 좌표 값을 데이터로 받아,
sinusoid 곡선의 진폭과 위상을 학습 및 예측을 하게 됩니다.
Sinusoid DataLoader 구현
- torchmeta.toy.Sinsoid를 활용하여, 하나의 태스크에서 제공하는 set의 크기와
데이터셋을 구성하는 전체 태스크의 개수를 설정 가능합니다.
import warnings
from typing import Any, Dict, Tuple
import torch
from torchmeta.toy import Sinusoid
from torchmeta.utils.data import BatchMetaDataLoader
warnings.filterwarnings(action = "ignore")
def get_dataloader(config : Dict[str, Any]) -> Tuple[BatchMetaDataLoader, BatchMetaDataLoader, BatchMetaDataLoader]:
train_dataset = Sinusoid(
num_samples_per_task = config["num_shots"] * 2,
num_tasks = config["num_batches_train"] * config["batch_size"],
noise_std = None
)
train_dataloader = BatchMetaDataLoader(train_dataset, batch_size = config["batch_size"])
val_dataset = Sinusoid(
num_samples_per_task = config["num_shots"] * 2,
num_tasks = config["num_batches_train"] * config["batch_size"],
noise_std = None
)
val_dataloader = BatchMetaDataLoader(val_dataset, batch_size = config["batch_size"])
test_dataset = Sinusoid(
num_samples_per_task = config["num_shots"] * 2,
num_tasks = config["num_batches_train"] * config["batch_size"],
noise_std = None
)
test_dataloader = BatchMetaDataLoader(test_dataset, batch_size = config["batch_size"])
return train_dataloader, val_dataloader, test_dataloader
train_dataloader, val_dataloader, test_dataloader = get_dataloader(config)
for batch_idx, batch in enumerate(val_dataloader):
xs, ys = batch
support_xs = xs[:, : config["num_shots"], :].to(device = config["device"]).type(torch.float)
query_xs = xs[:, : config["num_shots"], :].to(device = config["device"]).type(torch.float)
support_ys = xs[:, : config["num_shots"], :].to(device = config["device"]).type(torch.float)
query_ys = xs[:, : config["num_shots"], :].to(device = config["device"]).type(torch.float)
print(
f"support_x shape : {support_xs.shape}\n",
f"support_x shape : {support_ys.shape}\n",
f"query_x shape : {query_xs.shape}\n",
f"query_x shape : {query_ys.shape}\n",
)
break
'Data Science > Meta Learning' 카테고리의 다른 글
[ 논문 리뷰 ] One-shot Learning with Memory-Augmented Neural Networks? (0) | 2022.12.29 |
---|---|
[Meta Learning] Neural Turing Machines이란? (2) | 2022.12.28 |
[ Meta Learning ] 모델 기반 메타 러닝 이해하기 (0) | 2022.12.28 |
[논문 리뷰] Optimization as a model for few-shot learning? (0) | 2022.12.27 |
[Meta Learning] What's Meta Learning? (0) | 2022.12.25 |