KalelPark's LAB

[CODE] Gradient Clipping이란? 본문

Data Science/CODE

[CODE] Gradient Clipping이란?

kalelpark 2023. 3. 21. 00:23

Gradient Clipping이란?

    주로 RNN계열에서 gradient vanishing이나 gradient exploding이 많이 발생하는데, gradient exploding을 방지하여,
    학습의 안정화를 도모하기 위해 사용하는 방법입니다.  

Gradient Clipping과 L2_Norm

    Clipping이란, gradient가 일정 threshold를 넘어가면, clipping을 해줍니다. clipping은 gradient의 L2 norm으로 나눠주는 방식입니다. Clipping이 없으면, gradient가 너무 뛰어서, global minimum에 도달하지 않고, 너무 엉뚱한 방향으로 향하게 되지만, Clipping을 해주게 되면, gradient vector가 방향은 유지하고, 적은 값의 이동을 하여, 도달하고자 하는 곳으로 인정적으로 내려갑니다.

수식을 보면, 기울기 norm이 threshold를 넘으면, 기울기 벡터를 최댓값보다 큰 만큼의 비율로 나누어줍니다. 따라서, 기울기는 항상 역치보다 작습니다.

  이러한 방법은 학습의 발산을 방지함과 동시에 기울기의 방향 자체가 바뀌지 않고, 유지하게 하므로, 모델 파라미터가 학습해야 하는 방향을 잃지 않게 합니다. 즉 손실 함수를 최소화하기 위한 기울기의 방향을 유지한 채로 크기만 조절합니다.

CODE

import torch

max_norm = 5
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3, weight_decay = 0)

#  traininig
  optimizer.zero_grad()
  loss.backward()
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
  optimizer.step()
  
# 2 example>

import torch.optim as optim
import torch.nn.utils as torch_utils

learning_rate = 1.
max_grad_norm = 5.

optimizer = optim.SGD(model.parameters(), lr=learning_rate)
# In orther to avoid gradient exploding, we apply gradient clipping.
torch_utils.clip_grad_norm_(model.parameters(),
                            max_grad_norm
                            )
# Take a step of gradient descent.
optimizer.step()

 

Comments