방학동안 학회에서 김성훈 교수님의 PyTorch Zero To All 강의로 공부를 하게 된 김에 스스로 정리해보려고 합니다.
좋은 강의 공유해주신 김성훈 교수님께 감사드립니다.
강의링크:
https://www.youtube.com/playlist?list=PLlMkM4tgfjnJ3I-dbhO9JTw7gNty6o_2m
PyTorchZeroToAll (in English)
Basic ML/DL lectures using PyTorch in English.
www.youtube.com
설명에 들어가기에 앞서 backpropagation에 관해서는 Michigan University의 Justion Johnson교수가 강의한 Deep Learning for Computer Vision에서 설명이 정말 잘 되었다고 생각하기에 해당 강의 링크도 공유합니다.
강의링크(Deep Learning for Computer Vision - Backpropagation):
https://www.youtube.com/watch?v=dB-u77Y5a6A&list=PL5-TkQAfAZFbzxjBHtzdVCWE0Zbhomg7r&index=6
Back-propagation and Autograd
2강, 3강에서 다룬 모델들은 굉장히 단순한 모델이기 때문에, gradient를 직접 계산하는 것이 가능했습니다. 하지만 우리가 그림과 같이 복잡한 네트워크를 다루는 경우, 또 각 노드들 사이의 관계가 비선형인 경우, gradient를 직접 계산하는 것이 매우 어렵고 심지어 불가능한 경우도 많습니다.
보다 나은 방법은 computational graph와 chain rule을 이용하는 방법입니다. 많은 노드로 구성되어있는 네트워크의 한 개 노드를 보고 있다고 가정해봅시다. f는 x와 y를 input으로 받아 output z를 출력해냅니다. 우리에게 z에 대해 loss의 gradient가 주어졌다면, chain rule을 통해 x에 대한 loss의 gradient와 y에 대한 loss의 gradient를 구할 수 있습니다.
예제) f가 *이고, x = 2, y = 3인 경우
forward pass를 진행합니다. z = 2 * 3 = 6입니다. z = x * y이기 때문에 local gradient는 각각 $\frac{\partial z}{\partial x} = \frac{\partial x * y}{\partial x} = y = 3$, $\frac{\partial z}{\partial y} = \frac{\partial x * y}{\partial y} = x = 2$가 됩니다. 그리고 $\frac{\partial L}{\partial z} = 5$라고 주어졌다면 chain rule에 의해 $\frac{\partial L}{\partial x}$ = 15, $\frac{\partial L}{\partial y}$ = 10이 됩니다. 이 때 연산의 결과를 역으로 추적하기 때문에 위와 같은 과정을 역전파(Backpropagation)이라고 합니다.
이제 우리가 2강, 3강 때 사용했던 선형 모델에 대해 backpropagation을 실시해봅시다.
1. x = 1, y = 2, w = 1인 경우
우선, forward pass를 진행합니다. $\hat{y} = x * w = 1$입니다. s = $\hat{y} - y = 1 - 2 = -1$입니다. loss = $s^2$ = 1입니다. loss가 $s^2$이므로, 초록색 게이트에서의 local gradient는 $\frac{\partial loss}{\partial s} = \frac{\partial s^2}{\partial s} = 2s = -2$입니다.
빨강색 게이트에서의 local gradient는 $\frac{\partial s}{\partial \hat{y}} = \frac{\partial \hat{y} - y}{\partial \hat{y}} = 1$입니다. chain rule을 이용하여 $\frac{\partial loss}{\partial \hat{y}} = \frac{\partial loss}{\partial s} \frac{\partial s}{\partial \hat{y}} = -2 * 1 = -2$라는 것을 구할 수 있습니다.
파랑색 게이트에서의 local gradient는 $\frac{\partial \hat{y}}{\partial w} = \frac{\partial xw}{\partial w} = x = 1$입니다. 마찬가지로 chain rule을 이용하여 $\frac{\partial loss}{\partial w} = \frac{\partial loss}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial w} = -2 * 1 = -2$라는 것을 연산 가능합니다. 우리가 구하고자 했던 $\frac{\partial loss}{\partial w}$를 backpropagation을 사용하여 구한 것입니다.
2. x = 2, y = 4, w = 1인 경우
1번과 마찬가지로 forward pass부터 진행합니다. $\hat{y} = x * w = 2$입니다. s = $\hat{y} - y = 2 - 4 = -2$입니다. loss = $s^2$ = 4입니다. loss가 $s^2$이므로, 초록색 게이트에서의 local gradient는 $\frac{\partial loss}{\partial s} = \frac{\partial s^2}{\partial s} = 2s = -4$입니다.
빨강색 게이트에서의 local gradient는 $\frac{\partial s}{\partial \hat{y}} = \frac{\partial \hat{y} - y}{\partial \hat{y}} = 1$입니다. chain rule을 이용하여 $\frac{\partial loss}{\partial \hat{y}} = \frac{\partial loss}{\partial s} \frac{\partial s}{\partial \hat{y}} = -4 * 1 = -4$라는 것을 구할 수 있습니다.
파랑색 게이트에서의 local gradient는 $\frac{\partial \hat{y}}{\partial w} = \frac{\partial xw}{\partial w} = x = 2$입니다. 마찬가지로 chain rule을 이용하여 $\frac{\partial loss}{\partial w} = \frac{\partial loss}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial w} = -4 * 2 = -8$이라는 것을 연산 가능합니다. 우리가 구하고자 했던 $\frac{\partial loss}{\partial w}$를 backpropagation을 사용하여 구한 것입니다.
3. x = 1, y = 2, w = 1, b = 2인 경우
b가 추가되었지만 과정은 동일합니다. foward pass부터 진행합니다. $\hat{y} = x * w + b = 3$입니다. s = $\hat{y} - y = 3 - 2 = 1$입니다. loss = $s^2$ = 1입니다. loss가 $s^2$이므로, 초록색 게이트에서의 local gradient는 $\frac{\partial loss}{\partial s} = \frac{\partial s^2}{\partial s} = 2s = 2$입니다.
빨강색 게이트에서의 local gradient는 $\frac{\partial s}{\partial \hat{y}} = \frac{\partial \hat{y} - y}{\partial \hat{y}} = 1$입니다. chain rule을 이용하여 $\frac{\partial loss}{\partial \hat{y}} = \frac{\partial loss}{\partial s} \frac{\partial s}{\partial \hat{y}} = 2 * 1 = 2$라는 것을 구할 수 있습니다.
추가된 b는 결국엔 상수이기 때문에 backpropagation의 미분 과정에서 사라집니다. 따라서 이전의 파랑색 게이트에서와 같이 backpropagation을 진행합니다. 하늘색 게이트에서의 local gradient는 $\frac{\partial \hat{y}}{\partial w} = \frac{\partial xw + b}{\partial w} = x = 1$입니다. 마찬가지로 chain rule을 이용하여 $\frac{\partial loss}{\partial w} = \frac{\partial loss}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial w} = 2 * 1 = 2$이라는 것을 연산 가능합니다. 우리가 구하고자 했던 $\frac{\partial loss}{\partial w}$를 backpropagation을 사용하여 구한 것입니다.
import torch
import pdb
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = torch.tensor([1.0], requires_grad=True)
#2강에서와 마찬가지로 실습에 필요한 데이터를 생성해줍니다.
#requires_grad = True로 설정하면 그 tensor에서 이뤄진 모든 연산들을 추적합니다.
#후에 backward()를 호출하여 모든 gradient들을 자동으로 계산할 수 있습니다.
#이 변화도는 .grad에 누적됩니다. 추적을 중단하려면 detach()를 호출하면 됩니다.
#메모리 사용량을 줄이기 위해 코드블럭을 with torch.no_grad():로 감싸기도 합니다.
# our model forward pass
def forward(x):
return x * w
#역시나 2강에서와 마찬가지로 y_hat값인 x * w, x_data와 weight을 곱해주는 함수 forward를 정의합니다.
# Loss function
def loss(y_pred, y_val):
return (y_pred - y_val) ** 2
#error인 (y_hat - y)^2를 계산해주는 함수 loss를 정의합니다.
# Before training
print("Prediction (before training)", 4, forward(4).item())
#학습 이전에 w의 initial value는 1이기 때문에 forward(4)의 결과는 4.0입니다.
# Training loop
# 0부터 9까지 10번 반복
for epoch in range(10):
for x_val, y_val in zip(x_data, y_data):
y_pred = forward(x_val) # 1) Forward pass
l = loss(y_pred, y_val) # 2) Compute loss, forward propagation
l.backward() # 3) Back propagation to update weights
#autograd를 이용하여 backpropagation 실시합니다. require_grad = True를 갖는 모든
#텐서들에 대해 loss의 gradient를 계산합니다. 이후 w.grad는 w에 대한 loss의 gradient를
#갖는 텐서가 됩니다.
print("\tgrad: ", x_val, y_val, w.grad.item()) #\t를 하면 들여쓰기 해줍니다
w.data = w.data - 0.01 * w.grad.item()
# Manually zero the gradients after updating weights
# 0으로 초기화해주지 않으면 이전 gradient가 다음 루프에 간섭을 하게 되어 원하는 방향으로 학습이
# 진행되지 않습니다.
w.grad.data.zero_()
print(f"Epoch: {epoch} | Loss: {l.item()}")
# After training
print("Prediction (after training)", 4, forward(4).item())
#학습을 진행한 후에는 실제값인 8과 유사한 값을 보이는 것을 확인할 수 있습니다.
'ML \ DL > PyTorch Zero To All' 카테고리의 다른 글
PyTorch Lecture 10: Basic CNN (0) | 2022.07.26 |
---|---|
PyTorch Lecture 09: Softmax Classifier (0) | 2022.07.14 |
PyTorch Lecture 06: Logistic Regression (0) | 2022.07.13 |
PyTorch Lecture 03: Gradient Descent (0) | 2022.07.13 |
PyTorch Lecture 02: Linear Model (0) | 2022.07.13 |