KalelPark's LAB

[Meta Learning] MetaDataLoader 구현 본문

Data Science/Meta Learning

[Meta Learning] MetaDataLoader 구현

kalelpark 2022. 12. 28. 20:55

* 해당 포스팅을 보기 전에, 이전 논문 리뷰를 우선적으로 참고하시는데 많은 도움이 될 것입니다. 감사합니다.

https://kalelpark.tistory.com/28

 

[논문 리뷰] Optimization as a model for few-shot learning?

GitHub를 참고하시면, CODE 및 다양한 논문 리뷰가 있습니다! 하단 링크를 참고하시기 바랍니다. (+ Star 및 Follow는 사랑입니다..!) https://github.com/kalelpark/Awesome-ComputerVision GitHub - kalelpark/Awesome-ComputerVis

kalelpark.tistory.com

Torchmeta란?

       Pytorch에서의 few-shot learning & meta-learning을 위한 DataLoader라고 볼 수 있습니다.
       해당 라이브러리에는 meta learning에 자주 등장하는 데이터셋들이 존재합니다.

import warnings
from typing import Any, Dict, Tuple

import matplotlib.pyplot as plt
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

     - torchmeta.datasets.helpers에는 일부 벤치마크 데이터셋에 대한 데이터 세팅을 포함하고 있습니다.     

     - 해당 Dataset의 매개변수를 활용하여, N-way, K-shot 세팅 등 기본적인 설정이 가능합니다.

Torchmeta DataLoader  정의

     - 여러 매개변수를 통하여, dataset 및 dataloader 생성하여, 가져오는 것이 가능합니다.

def get_dataloader(config : Dict[str, Any])-> Tuple[BatchMetaDataLoader, BatchMetaDataLoader, BatchMetaDataLoader]:
    train_dataset = omniglot(
        folder = config["folder_name"],
        shots = config["num_shots"],
        ways = config["num_ways"],
        shuffle = True,
        meta_train = True,
        download = config["download"],
    )
    
    train_dataloader = BatchMetaDataLoader(
        train_dataset, batch_size = config["batch_size"], shuffle = True
    )
    
    val_dataset = omniglot(
        folder = config["folder_name"],
        shots = config["num_shots"],
        ways = config["num_ways"],
        shuffle = True,
        meta_val = True,
        download = config["download"],
    )

    val_dataloader = BatchMetaDataLoader(
        val_dataset, batch_size = config["batch_size"], shuffle = True
    )
    
    test_dataset = omniglot(
        folder = config["folder_name"],
        shots = config["num_shots"],
        ways = config["num_ways"],
        shuffle = True,
        meta_train = True,
        download = config["download"],
    )

    test_dataloader = BatchMetaDataLoader(
        test_dataset, batch_size = config["batch_size"], shuffle = True
    )
    
    return train_dataloader, val_dataloader, test_dataloader

config를 보면 알 수 있다시피, 5-ways 2-shots이 되었으며,

각각의 5개의 class에 데이터가 2개씩 존재합니다.

config = {
    "folder_name" : "dataset",
    "download" : True,
    "num_shots" : 2,
    "num_ways" : 5,
    "batch_size" : 3,
    "num_batches_train" : 6000,
    "num_batches_test" : 2000,
    "num_batches_val" : 100,
    "device" : "cpu"
}

train_dataloader, val_dataloader, test_dataloader = get_dataloader(config)

for batch_idx, batch in enumerate(train_dataloader):
    if batch_idx >= config["num_batches_train"]:
        break
    
    support_xs = batch["train"][0].to(device = config["device"])
    support_ys = batch["train"][1].to(device = config["device"])
    query_xs = batch["test"][0].to(device = config["device"])
    query_ys = batch["test"][1].to(device = config["device"])
    
    print(
        f"support_x shape : {support_xs.shape}\n",
        f"support_y shape : {support_ys.shape}\n",
        f"query_x shape : {query_xs.shape}\n",
        f"query_y shape : {query_ys.shape}\n",
    )
    
    break

 

* 시각화를 해보자면,

for b in range(config["batch_size"]):
    fig = plt.figure(constrained_layout = True, figsize = (18, 4))
    subfigs = fig.subfigures(1, 2, wspace = 0.07)
    
    subfigs[0].set_facecolor("0.75")
    subfigs[0].suptitle("Support set", fontsize = "x-large")
    support_axs = subfigs.flat[0].subplots(nrows = 2, ncols = 5)
    
    for i, ax in enumerate(supprt_axs.T.flatten()):
        ax.imshow(support_xs[b][i].permute(1, 2, 0).squeeze(), aspect = "auto")

    subfigs[1].set_facecolor("0.75")
    subfigs[1].suptitle("Query set", fontsize = "x-large")
    query_axs = subfigs.flat[1].subplots(nrows = 2, ncols = 5)
    for i, ax in enumerate(query_axs.T.flatten()):
        ax.imshow(query_xs[b][i].permute(1, 2, 0).squeeze(), aspect="auto")
    
    fig.suptitle("Batch " + str(b), fontsize = "xx-large")
    plt.show()

Sinusoid란?

      - 일정한 패턴이 반복되는 매끈한 곡선 함수를 말합니다. Sinusoid 함수의 예측은 대표적인 메타 회귀 문제이다.

 

         K-shot 메타 회귀문제에서는 주어진 sinusoid 곡선 위 K개의 정점에 대한 좌표 값을 데이터로 받아,

         sinusoid 곡선의 진폭과 위상을 학습 및 예측을 하게 됩니다.

 

Sinusoid DataLoader 구현

      - torchmeta.toy.Sinsoid를 활용하여, 하나의 태스크에서 제공하는 set의 크기와

         데이터셋을 구성하는 전체 태스크의 개수를 설정 가능합니다.

import warnings
from typing import Any, Dict, Tuple

import torch
from torchmeta.toy import Sinusoid
from torchmeta.utils.data import BatchMetaDataLoader

warnings.filterwarnings(action = "ignore")

def get_dataloader(config : Dict[str, Any]) -> Tuple[BatchMetaDataLoader, BatchMetaDataLoader, BatchMetaDataLoader]:
    train_dataset = Sinusoid(
        num_samples_per_task = config["num_shots"] * 2,
        num_tasks = config["num_batches_train"] * config["batch_size"],
        noise_std = None
    )
    train_dataloader = BatchMetaDataLoader(train_dataset, batch_size = config["batch_size"])
    
    val_dataset = Sinusoid(
        num_samples_per_task = config["num_shots"] * 2,
        num_tasks = config["num_batches_train"] * config["batch_size"],
        noise_std = None
    )
    val_dataloader = BatchMetaDataLoader(val_dataset, batch_size = config["batch_size"])
    
    test_dataset = Sinusoid(
        num_samples_per_task = config["num_shots"] * 2,
        num_tasks = config["num_batches_train"] * config["batch_size"],
        noise_std = None
    )
    test_dataloader = BatchMetaDataLoader(test_dataset, batch_size = config["batch_size"])

    return train_dataloader, val_dataloader, test_dataloader

train_dataloader, val_dataloader, test_dataloader = get_dataloader(config)

for batch_idx, batch in enumerate(val_dataloader):
    xs, ys = batch
    support_xs = xs[:, : config["num_shots"], :].to(device = config["device"]).type(torch.float)
    query_xs = xs[:, : config["num_shots"], :].to(device = config["device"]).type(torch.float)
    support_ys = xs[:, : config["num_shots"], :].to(device = config["device"]).type(torch.float)
    query_ys = xs[:, : config["num_shots"], :].to(device = config["device"]).type(torch.float)

    print(
        f"support_x shape : {support_xs.shape}\n",
        f"support_x shape : {support_ys.shape}\n",
        f"query_x shape : {query_xs.shape}\n",
        f"query_x shape : {query_ys.shape}\n",
    )

    break
Comments