대외활동/DSC CV Study

week2 - Backpropagation 오차역전파법

여니여니_ 2020. 1. 13. 08:46

오차역전파법

신경망 학습에서, 신경망의 가중치(weight)에 대한 손실 함수의 기울기를 구하기 위해 미분을 사용했습니다. 미분 계산은 단순하고 구현하기 쉽지만, 시간이 오래걸리는 단점이 있습니다. 초기 가중치에서의 손실 함수를 계산하고 그것을 미분하여 다음 가중치 계산에 사용하는데, 뉴런이 많아져 각 노드의 매개변수(가중치, 편향)가 많아지면 그 모든 매개변수를 임의로 대입하며 학습시키면 시간도 매우 오래걸리고 계산량도 많아 비효율적입니다. 이것을 효율적으로 하기 위한 것이 오차역전파 알고리즘입니다.

 

학습이란 각 뉴런의 파라미터에 의해 계산된 최종 출력을 토대로 손실함수를 구하고, 그 손실함수를 최소화하는 방향으로 각 뉴런의 매개변수(가중치, 편향)를 변화시켜나가는 것입니다. 학습이 완료되었다는 것은, 학습 모델이 손실함수가 최소화되도록 모든 뉴런에서 각각의 매개변수 값을 찾았다는 것입니다.  

 

출력에 대한 영향력을 알기 위하여 오차역전파법(Back Propagation)을 사용합니다. 이 알고리즘은 학습 시에 입력부터 출발하여 각 노드의 매개변수를 통한 계산 후 출력으로 나오는 Forward Propagation과는 반대로 출력으로부터 입력까지 연쇄법칙(Chain Rule)을 통해 input쪽으로 거슬러 올라갑니다.

 

Chain Rule을 이용하면 출력에 대한 각 노드 매개변수의 미분 값이 계산되는데, 이 미분 값이 바로 특정 매개변수가 전체 출력에 주는 영향력이 되는 것입니다.  

 

코드 구현

 

순전파와 역전파를 가진 간단한 클래스를 만든다.

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class MulLayer:
    def __init__(self):
        self.x = None
        self.y = None
        
    def forward(self, x, y):
        self.x = x
        self.y = y
        out = x * y
        return out
    
    def backward(self, dout):
        dx = dout * self.y
        dy = dout * self.x
        return dx, dy

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
apple = 100
apple_num = 2
tax = 1.1
 
mul_apple_layer = MulLayer()
mul_tax_layer = MulLayer()
 
# forward propagation
apple_price = mul_apple_layer.forward(apple, apple_num)
price = mul_tax_layer.forward(apple_price, tax)
 
print(price)
 
dprice = 1
dapple_price, dtax = mul_tax_layer.backward(dprice)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)
 
print(dapple, dapple_num, dtax)
 

 

 

 

 

 

 

 

https://sacko.tistory.com/39

 

문과생도 이해하는 딥러닝 (6) - 오차역전파법 실습 1

2017/09/27 - 문과생도 이해하는 딥러닝 (1) - 퍼셉트론 Perceptron 2017/10/18 - 문과생도 이해하는 딥러닝 (2) - 신경망 Neural Network 2017/10/25 - 문과생도 이해하는 딥러닝 (3) - 오차 역전파, 경사하강법..

sacko.tistory.com

 

https://excelsior-cjh.tistory.com/171

 

03. 오차역전파 - BackPropagation

이번 포스팅은 '밑바닥부터 시작하는 딥러닝' 교재로 공부한 것을 정리한 것입니다. 아래의 이미지들은 해당 교재의 GitHub에서 가져왔으며, 혹시 문제가 된다면 이 포스팅은 삭제하도록 하겠습니다.. ㅜㅜ 오차역..

excelsior-cjh.tistory.com

 

https://ratsgo.github.io/deep%20learning/2017/05/14/backprop/

 

 

오차 역전파 (backpropagation) · ratsgo's blog

이번 글에서는 오차 역전파법(backpropagation)에 대해 살펴보도록 하겠습니다. 이번 글은 미국 스탠포드대학의 CS231n 강의를 기본으로 하되, 고려대학교 데이터사이언스 연구실의 김해동 석사과정이 쉽게 설명한 자료를 정리했음을 먼저 밝힙니다. 그럼 시작하겠습니다. 계산그래프와 chain rule 계산그래프(computational graph)는 계산과정을 그래프로 나타낸 것입니다. 노드(node, 꼭지점)은 함수(연산), 엣지(edge, 간선)

ratsgo.github.io

 

 

https://www.youtube.com/watch?v=Ilg3gGewQ5U