Notice
Recent Posts
Recent Comments
Link
«   2025/02   »
1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28
Tags
more
Archives
Today
Total
관리 메뉴

Zorba blog

LSTM and GRU(Long Short Term Memory & Gated Recurrent Unit) 본문

Machine Learning

LSTM and GRU(Long Short Term Memory & Gated Recurrent Unit)

Zorba blog 2022. 8. 6. 11:42

LSTM의 등장배경

게이트가 추가된 RNN

- RNN은 순환 경로를 포함하여 과거의 정보를 기억할 수 있음.

- 구조가 단순하여 구현은 쉽지만 안타깝게도 성능이 좋지 않음.

- 그 원인은 시간적으로 많이 떨어진 장기 의존 관계를 잘 학습할 수 없다는 데 있음.

   -> 따라서, 요즘에는 단순한 RNN 대신 LSTM이나 GRU라는 계층이 주로 쓰임.

- LSTM이나 GRU에는 게이트라는 구조가 더해져 있는데,

  이 게이트 덕분에 시계열 데이터의 장기 의존 관계를 학습할 수 있음.

 

Tom was watching TV in his room. Mary came into the room. Mary said hi to ?

 

"?" 에 들어가는 단어는 "Tom" 이다. RNN에서 이 문제에 올바르게 답하려면, 현재 맥락에서 "Tom was watching TV in his room" 과 "Mary cate into the room"이란 정보를 기억해둬야 함. 다시 말해 이런 정보를 RNN 계층의 은닉 상태에 인코딩해 보관해둬야 함.

 

RNN의 문제점

 

Tom was watching TV in his room. Mary came into the room. Mary said hi to ?

 

- 정답 레이블로 "Tom"이라는 단어가 주어졌을 때, RNNLM에서 기울기가 어떻게 전파되는지를 확인.

- 정답 레이블 "Tom"이라고 주어진 시점으로부터 과거 방향으로 기울기를 전달.

- 위 그림처럼 RNN 계층이 과거 방향으로 의미 있는 기울기를 전달함으로써 시간 방향의 의존 관계를 학습.

- 하지만 현재의 단순한 RNN 계층에서는 시간을 거슬러 올라갈수록 기울기가 작아지거나(기울기 소실) 혹은 커질 수 있으며(기울기 폭발), 대부분 둘 중 하나의 운명을 걷게 됨.

 

기울기 소실과 기울기 폭발의 원인

- RNN 계층에서 기울기 소실(혹은 기울기 폭발)이 일어나는 원인을 살펴보자.

- 길이가 T인 시계열 데이터를 가정하면 T번째 정답 레이블로부터 전해지는 기울기가 어떻게 변하는지 살펴보자.

- 시간 방향 기울기에 주목하면 역전파로 전해지는 기울기는 차례로 'tahn', '+', 'Matmul(행렬곱)' 연산을 통과.

- '+' 역전파는 상류에서 전해지는 기울기를 그대로 하류로 흘려보내니, 'tahn', 'Matmul(행렬곱)' 에만 주목.

 

1) 기울기 소실과 기울기 폭발의 원인 - tahn

- y=tahn(x) 일 때의 미분은 1-y^2 이며 각각 그래프로 그리면 아래 그림과 같음.

 

- dtanh에 주목하면 그 값은 1.0 이하이고, x가 0으로부터 멀어질수록 작아짐.

- 이는 역전파에서 기울기가 tanh 노드를 지날 때마다 값은 계속 작아진다는 뜻.

- 그래서 tanh 함수를 T번 통과하면 기울기도 T번 반복해서 작아지게 됨.

* RNN 계층의 활성화 함수로는 주로 tanh 함수를 사용하는데, 이를 ReLU로 바꾸면 기울기 소실을 줄일 수 있음.

 

2) 기울기 소실과 기울기 폭발의 원인 - Matmul

- 다음처럼 상류로부터 dh라는 기울기가 흘러들어왔을 때, 이때 Matmul 노드에서의 역전파는 dhWh^t 라는 행렬 곱으로 기울기를 계산.

- 그리고 같은 계산을 시계열 데이터의 시간 크기만큼 반복.

- 여기에서 주목할 점은 이 행렬 곱셈에서는 매번 똑같은 가중치인 Wh가 사용된다는 점.

- 역전파의 Matmul 노드 수(T) 만큼 dh를 갱신했을 때, 기울기의 크기는 시간에 비례해 지수적으로 증가.

- 이것이 바로 기울기 폭발(Exploding Gradient). 이러한 기울기 폭발이 일어나면 결국 Overflow를 일으켜 NaN 같은 값 발생.


LSTM

기울기 소실과 LSTM

- 이제 기울기 소실을 일으키지 않는다는(혹은 일으키기 어렵게 한다는) LSTM 구조에 대해 살펴보고 이 구조를 개량한 GRU의 구조까지 살펴보자.

 

- 그림에서 보듯 LSTM 계층의 인터페이스에는 c라는 경로가 있다는 차이가 있음.

- C를 기억 셀 이라 하며, LSTM 전용의 기억 메커니즘.

- 기억 셀의 특징은 데이터를 자기 자신으로만 (LSTM 계층 내에서만) 주고받는다는 것.

- 즉, LSTM 계층 내에서만 완결되고, 다른 계층으로는 출력하지 않음.

- 반면, LSTM의 은닉 상태 h는 RNN 계층과 마찬가지로 다른 계층으로 출력.

(LSTM의 출력은 은닉 상태 벡터 h뿐. 그러므로 C의 존재 자체를 생각할 필요가 없음.)

 

LSTM 계층 조립하기

- 기억 셀 Ct에는 시각 t에서의 LSTM의 기억이 저장되어 있는데, 과거로부터 시각 t까지에 필요한 모든 정보가 저장되어 있다고 가정.

- 필요한 정보를 모두 간직한 이 기억을 바탕으로 외부 계층에 은닉 상태 ht를 출력.

- 은닉상태 ht는 기억 셀 ct에 단순히 tanh 함수를 적용했을 뿐.

- LSTM 구조에서의 핵심은 ht는 단기상태(short term state), ct는 장기상태(long term state) 라고 볼 수 있음.

 

LSTM의 게이트

- 게이트는 데이터의 흐름을 제어. 마치 아래 그림처럼 물의 흐름을 멈추거나 배출하는 것이 게이트의 역할.

- 여기서 중요한 것은  '게이트를 얼마나 열까' 라는 것도 데이터로부터 자동으로 학습한다는 점.

- 게이트의 열림 상태를 구할 때는 시그모이드 함수를 사용. 그 이유는 시그모이드 함수의 출력이 0.0~1.0 사이의 실수이기 때문.

 

Output 게이트 - (1)

- 앞에서 은닉상태 ht는 기억 셀 ct에 단순히 tanh 함수를 적용했을 뿐이라고 설명.

- 이번에는 tanh(ct)에 게이트를 적용.

- 즉, tanh(ct)의 각 원소에 대해 '그것이 다음 시각의 은닉 상태에 얼마나 중요한가'를 조정

- output 게이트의 열림 상태(다음 몇%만 흘려보낼까)는 입력 xt와 이전 상태 ht-1로부터 구함.

식 삽입

- RNN 계층의 계산에서 tanh가 아닌 Sigmoid를 사용했다는 점만이 다르다는 것을 확인.

 

Output 게이트 - (2)

- ht는 o(output 게이트 수행 식 시그마의 출력)와 tanh(ct)의 곱으로 계산.(원소별 곱, 아다마르 곱 이라고도 함)

식 삽입

- tanh의 출력은 -1~1의 실수이고, 인코딩된 정보의 강략 정도를 표시한다고 해석할 수 있음.

- 시그모이드 함수의 출력은 0~1의 실수이며, 데이터를 얼마나 통과시킬지를 정하는 비율.

- 주로 게이트에서는 시그모이드 함수가, 실질적인 '정보'를 지니는 데이터에는 tanh 함수가 활성화 함수로 사용.

 

forget 게이트

- 다음은 기억셀에 '무엇을 잊을까'를 지시하는 것. 이것도 게이트를 사용.

- ct-1의 기억 중에서 불필요한 기억을 잊게 해주는 게이트를 forget 게이트라고 함.

 

새로운 기억 셀

- forget 게이트를 거치면서 이전 시각의 기억 셀로부터 잊어야 할 기억이 삭제.

- 새로 기억해야 할 정보를 추가해야 하므로 tanh 노드를 추가.

- tanh 노드가 계산한 결과가 이전 시각의 기억 셀 ct-1에 추가.

- 이 tanh 노드는 '게이트'가 아니며, 새로운 정보를 기억 셀에 추가하는 것이 목적.

- 따라서 활성화 함수로는 시그모이드 함수가 아닌 tanh 함수가 사용.

 

Input 게이트

- 마지막으로 새로운 정보가 들어있는 g에 게이트를 하나 추가.

- input 게이트는 g의 각 원소가 새로 추가되는 정보로써의 가치가 얼마나 큰지를 판단.

- 즉, 새로운 정보를 무비판적으로 수용하는 것이 아니라, 적절히 취사선택하는 것이 이 게이트의 역할.

 

LSTM의 기울기 흐름

- LSTM의 기울기 소실을 없애주는 원리. 기억 셀 C의 역전파에 주목.

- 기억셀의 역전파에는 '+'와 'X' 노드만을 지나게 됨.

- '+' 노드는 상류 기울기를 그대로 흘릴 뿐이므로 남는 것은 'X' 노드, 이 노드는 '행렬 곱'이 아닌 '아다마르 곱'으로 계산

- RNN의 역전파에서는 똑같은 가중치 행렬을 사용해서 '행렬 곱'을 반복했고, 그래서 기울기 소실 혹은 폭발이 일어남.

- 반면, 이번 LSTM의 역전파에서는 '행렬 곱'이 아닌  '원소별 곱'이 이뤄지고, 매 시각 다른 게이트 값을 이용해 원소별 곱을 계산.

- 이처럼 매번 새로운 게이트 값을 이용하므로 곱셈의 효과가 누적되지 않아 기울기 소실이 일어나기 어려운 것.

- 위 그림의 'X' 노드의 계산은 forget 게이트가 제어.

- 역전파 계산시 forget 게이트의 출력과 상류 기울기의 곱이 계산되므로 forget 게이트가 '잊어야 한다'고 판단한 기억 셀의 원소에 대해서는 그 기울기가 작아지고, '잊어서는 안 된다'고 판단한 원소에 대해서는 그 기울기가 약화되지 않은 채로 과거 방향으로 전해짐. 따라서 중요한 정보의 기울기는 소실 없이 전파.


LSTM 구현

Affine 변환

수식 입력

- 주목한 부분은 위에서부터 4개의 f,g,i,o의 수식에 포함된 아핀 변환(xWx+hWh+b).

- 그림에서 보듯 4개의 가중치와 편향을 하나로 모아서 처리. 총 4번을 수행하던 아핀 변환을 단 1회의 계산으로 끝마침.

 

LSTM의 역전파

- slice 노드의 역전파에서는 4개의 행렬을 연결.

- 그림에서는 df, dg, di, do를 연결하여 dA를 만듬.


GRU

GRU(Gate Recurrent Unit)

- LSTM은 아주 좋은 계층이지만 매개변수가 많아서 계산이 오래 걸리는 것이 단점.

- 그래서 최근에는 LSTM을 대신할 '게이트가 추가된 RNN'이 많이 제안되고 있음.

- 그 중 유명하고 검증된 GRU(게이트가 추가된 RNN)

- LSTM의 게이트를 사용한다는 개념은 유지한 채, 매개변수를 줄여 계산시간을 줄임.

 

GRU의 계산 그래프

 

- GRU에서 수행하는 계산은 4개의 식으로 표현. 6개였던 LSTM에 비해 간단해진 것을 확인.

- GRU는 이처럼 LSTM을 더 단순하게 만든 아키텍처.

- GRU에는 기억 셀은 없고, 시간 방향으로 전파하는 것은 은닉 상태 h뿐.

- r과 z라는 2개의 게이트를 사용.(r은 reset게이트, z는 update 게이트)

- Reset 게이트 r은 과거의 은닉 상태를 얼마나 무시할지를 결정.\

- 만약 r이 0이면, h=tanh(xtWx+(r0ht-1)Wh+b) 으로부터, 새로운 은닉상태 h는 입력 xt 만으로 결정.

  (즉, 과거의 은닉 상태를 완전히 무시)

- Update 게이트는 은닉 상태를 갱신하는 게이트.

- LSTM의 forget 게이트와 input 게이트의 2가지 역할을 혼자 담당.

- forget 게이트로써의 기능은 (1-z)0ht-1 부분. 과거의 은닉 상태에서 잊어야 할 정보를 삭제.

- input 게이트로써의 기능은 z0h 부분. 이에 따라 새로 추가된 정보에 input 게이트의 가중치를 부여.


LSTM vs GRU

- LSTM과 GRU중 어느 쪽을 사용해야 하는지를 묻는다면, 주어진 문제와 하이퍼파라미터에 따라 승자가 달라짐.

 


아다마르 곱(Hardamard Product)

- 일반 행렬 곱은 mxn 행렬과 nxp 꼴의 행렬을 곱하지만, Hardamard Product는 mxn과 mxn의 같은 꼴을 가지는 행렬끼리 같은 위치의 원소끼리 각각 곱함.

Comments