ML \ DL/PyTorch Zero To All

Pytorch Lecture 04: Back-Propagation and Autograd

lfgwy 2022. 7. 13. 17:03

방학동안 학회에서 김성훈 교수님의 PyTorch Zero To All 강의로 공부를 하게 된 김에 스스로 정리해보려고 합니다.

좋은 강의 공유해주신 김성훈 교수님께 감사드립니다.

 

강의링크:

https://www.youtube.com/playlist?list=PLlMkM4tgfjnJ3I-dbhO9JTw7gNty6o_2m

 

PyTorchZeroToAll (in English)

Basic ML/DL lectures using PyTorch in English.

www.youtube.com

 

설명에 들어가기에 앞서 backpropagation에 관해서는 Michigan University의 Justion Johnson교수가 강의한 Deep Learning for Computer Vision에서 설명이 정말 잘 되었다고 생각하기에 해당 강의 링크도 공유합니다.

 

강의링크(Deep Learning for Computer Vision - Backpropagation):

https://www.youtube.com/watch?v=dB-u77Y5a6A&list=PL5-TkQAfAZFbzxjBHtzdVCWE0Zbhomg7r&index=6 

 

Back-propagation and Autograd

2강, 3강에서 다룬 모델들은 굉장히 단순한 모델이기 때문에, gradient를 직접 계산하는 것이 가능했습니다. 하지만 우리가 그림과 같이 복잡한 네트워크를 다루는 경우, 또 각 노드들 사이의 관계가 비선형인 경우, gradient를 직접 계산하는 것이 매우 어렵고 심지어 불가능한 경우도 많습니다.

PyTorch Zero to All 김성훈 교수님 강의자료

 

보다 나은 방법은 computational graph와 chain rule을 이용하는 방법입니다. 많은 노드로 구성되어있는 네트워크의 한 개 노드를 보고 있다고 가정해봅시다. f는 x와 y를 input으로 받아 output z를 출력해냅니다. 우리에게 z에 대해 loss의 gradient가 주어졌다면, chain rule을 통해 x에 대한 loss의 gradient와 y에 대한 loss의 gradient를 구할 수 있습니다. 

Pytorch Zero to All 김성훈 교수님 강의자료

예제) f가 *이고, x = 2, y = 3인 경우

forward pass를 진행합니다. z = 2 * 3 = 6입니다. z = x * y이기 때문에 local gradient는 각각 $\frac{\partial z}{\partial x} = \frac{\partial x * y}{\partial x} = y = 3$, $\frac{\partial z}{\partial y} = \frac{\partial x * y}{\partial y} = x = 2$가 됩니다.  그리고 $\frac{\partial L}{\partial z} = 5$라고 주어졌다면 chain rule에 의해 $\frac{\partial L}{\partial x}$ = 15, $\frac{\partial L}{\partial y}$ = 10이 됩니다. 이 때 연산의 결과를 역으로 추적하기 때문에 위와 같은 과정을 역전파(Backpropagation)이라고 합니다. 

 

이제 우리가 2강, 3강 때 사용했던 선형 모델에 대해 backpropagation을 실시해봅시다.

 

1. x = 1, y = 2, w = 1인 경우

Pytorch Zero to All 김성훈 교수님 강의자료

우선, forward pass를 진행합니다. $\hat{y} = x * w = 1$입니다. s = $\hat{y} - y = 1 - 2 = -1$입니다. loss = $s^2$ = 1입니다. loss가 $s^2$이므로, 초록색 게이트에서의 local gradient는 $\frac{\partial loss}{\partial s} = \frac{\partial s^2}{\partial s} = 2s = -2$입니다. 

빨강색 게이트에서의 local gradient는 $\frac{\partial s}{\partial \hat{y}} = \frac{\partial \hat{y} - y}{\partial \hat{y}} = 1$입니다. chain rule을 이용하여 $\frac{\partial loss}{\partial \hat{y}} = \frac{\partial loss}{\partial s} \frac{\partial s}{\partial \hat{y}} = -2 * 1 = -2$라는 것을 구할 수 있습니다.

파랑색 게이트에서의 local gradient는 $\frac{\partial \hat{y}}{\partial w} = \frac{\partial xw}{\partial w} = x = 1$입니다. 마찬가지로 chain rule을 이용하여 $\frac{\partial loss}{\partial w} = \frac{\partial loss}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial w} = -2 * 1 = -2$라는 것을 연산 가능합니다. 우리가 구하고자 했던 $\frac{\partial loss}{\partial w}$를 backpropagation을 사용하여 구한 것입니다.  

 

2. x = 2, y = 4, w = 1인 경우

Pytorch Zero to All 김성훈 교수님 강의자료

1번과 마찬가지로 forward pass부터 진행합니다. $\hat{y} = x * w = 2$입니다. s = $\hat{y} - y = 2 - 4 = -2$입니다. loss = $s^2$ = 4입니다. loss가 $s^2$이므로, 초록색 게이트에서의 local gradient는 $\frac{\partial loss}{\partial s} = \frac{\partial s^2}{\partial s} = 2s = -4$입니다. 

빨강색 게이트에서의 local gradient는 $\frac{\partial s}{\partial \hat{y}} = \frac{\partial \hat{y} - y}{\partial \hat{y}} = 1$입니다. chain rule을 이용하여 $\frac{\partial loss}{\partial \hat{y}} = \frac{\partial loss}{\partial s} \frac{\partial s}{\partial \hat{y}} = -4 * 1 = -4$라는 것을 구할 수 있습니다.

파랑색 게이트에서의 local gradient는 $\frac{\partial \hat{y}}{\partial w} = \frac{\partial xw}{\partial w} = x = 2$입니다. 마찬가지로 chain rule을 이용하여 $\frac{\partial loss}{\partial w} = \frac{\partial loss}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial w} = -4 * 2 = -8$이라는 것을 연산 가능합니다. 우리가 구하고자 했던 $\frac{\partial loss}{\partial w}$를 backpropagation을 사용하여 구한 것입니다.  

 

3. x = 1, y = 2, w = 1, b = 2인 경우

PyTorch Zero to All 김성훈 교수님 강의자료

b가 추가되었지만 과정은 동일합니다. foward pass부터 진행합니다. $\hat{y} = x * w + b = 3$입니다. s = $\hat{y} - y = 3 - 2 = 1$입니다. loss = $s^2$ = 1입니다. loss가 $s^2$이므로, 초록색 게이트에서의 local gradient는 $\frac{\partial loss}{\partial s} = \frac{\partial s^2}{\partial s} = 2s = 2$입니다.

빨강색 게이트에서의 local gradient는 $\frac{\partial s}{\partial \hat{y}} = \frac{\partial \hat{y} - y}{\partial \hat{y}} = 1$입니다. chain rule을 이용하여 $\frac{\partial loss}{\partial \hat{y}} = \frac{\partial loss}{\partial s} \frac{\partial s}{\partial \hat{y}} = 2 * 1 = 2$라는 것을 구할 수 있습니다.

추가된 b는 결국엔 상수이기 때문에 backpropagation의 미분 과정에서 사라집니다. 따라서 이전의 파랑색 게이트에서와 같이 backpropagation을 진행합니다. 하늘색 게이트에서의 local gradient는 $\frac{\partial \hat{y}}{\partial w} = \frac{\partial xw + b}{\partial w} = x = 1$입니다. 마찬가지로 chain rule을 이용하여 $\frac{\partial loss}{\partial w} = \frac{\partial loss}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial w} = 2 * 1 = 2$이라는 것을 연산 가능합니다. 우리가 구하고자 했던 $\frac{\partial loss}{\partial w}$를 backpropagation을 사용하여 구한 것입니다.  

 

import torch
import pdb

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = torch.tensor([1.0], requires_grad=True)
#2강에서와 마찬가지로 실습에 필요한 데이터를 생성해줍니다.
#requires_grad = True로 설정하면 그 tensor에서 이뤄진 모든 연산들을 추적합니다.
#후에 backward()를 호출하여 모든 gradient들을 자동으로 계산할 수 있습니다. 
#이 변화도는 .grad에 누적됩니다. 추적을 중단하려면 detach()를 호출하면 됩니다.
#메모리 사용량을 줄이기 위해 코드블럭을 with torch.no_grad():로 감싸기도 합니다.

# our model forward pass
def forward(x):
    return x * w
#역시나 2강에서와 마찬가지로 y_hat값인 x * w, x_data와 weight을 곱해주는 함수 forward를 정의합니다.

# Loss function
def loss(y_pred, y_val):
    return (y_pred - y_val) ** 2
#error인 (y_hat - y)^2를 계산해주는 함수 loss를 정의합니다.


# Before training
print("Prediction (before training)",  4, forward(4).item())
#학습 이전에 w의 initial value는 1이기 때문에 forward(4)의 결과는 4.0입니다.

# Training loop
# 0부터 9까지 10번 반복
for epoch in range(10):
    for x_val, y_val in zip(x_data, y_data):
        y_pred = forward(x_val) # 1) Forward pass
        l = loss(y_pred, y_val) # 2) Compute loss, forward propagation
        l.backward() # 3) Back propagation to update weights
        #autograd를 이용하여 backpropagation 실시합니다. require_grad = True를 갖는 모든
        #텐서들에 대해 loss의 gradient를 계산합니다. 이후 w.grad는 w에 대한 loss의 gradient를 
        #갖는 텐서가 됩니다.
        print("\tgrad: ", x_val, y_val, w.grad.item()) #\t를 하면 들여쓰기 해줍니다
        w.data = w.data - 0.01 * w.grad.item()

        # Manually zero the gradients after updating weights
        # 0으로 초기화해주지 않으면 이전 gradient가 다음 루프에 간섭을 하게 되어 원하는 방향으로 학습이 
        # 진행되지 않습니다.
        w.grad.data.zero_()

    print(f"Epoch: {epoch} | Loss: {l.item()}")

# After training
print("Prediction (after training)",  4, forward(4).item())
#학습을 진행한 후에는 실제값인 8과 유사한 값을 보이는 것을 확인할 수 있습니다.