worldforest 2024. 12. 12. 17:09

LSTM은 여러 게이트와 셀 상태를 포함하고 있다.

 

(1) 입력 게이트 it

입력을 받아 새 정보를 셀 상태에 얼마나 반영할지 결정

학습되는 가중치 : Wi(입력에 대한 가중치), Ui(이전 상태에 대한 가중치), bi(바이어스)

 

 (2) 망각 게이트 ft

이전 셀 상태의 정보를 얼마나 유지할지 결정

학습되는 가중치 : Wf(입력에 대한 가중치), Uf(이전 상태에 대한 가중치), bf(바이어스)

 

(3) 출력 게이트 ot

어떤 정보를 출력할지 결정

학습되는 가중치 : Wo(입력에 대한 가중치), Uo(이전 상태에 대한 가중치), bo(바이어스)

 

(4) 셀 상태 후보 Ct

새로운 정보를 기반으로 업데이트될 셀 상태 후보 값

학습되는 가중치 : Wc(입력에 대한 가중치), Uc(이전 상태에 대한 가중치), bc(바이어스)

 

 

학습 가능한 파라미터 수 = 4 * (ninput * nhidden + n^2hiden + nhidden)

ninput 입력 피처의 차원

nhidden LSTM 유닛의 수

LSTM은 게이트와 셀 상태 후보 값에 대해 각각 독립적인 가중치를 갖기 때문에 4*를 곱합니다.

 

 

 

LSTM의 가중치들은 시간 축을 따라 역전파를 수행하는 BPRR 알고리즘을 통해 학습됩니다.

Backpropagation Through Time, BPTT

 

  • 순전파(forward pass): 현재 입력과 이전 상태를 사용해 각 게이트와 셀 상태를 계산.
  • 손실 계산(loss calculation): 예측값과 실제값의 차이를 바탕으로 손실을 계산.
  • 역전파(backpropagation): 손실에 따라 가중치와 바이어스를 업데이트.

학습의 핵심은 역전파와 옵티마이저를 통해 가중치를 업데이트 하는 것입니다.

 

 

초기 장기 기억은 tanh를 통과하지 않고 모델의 초기값으로 설정됩니다.

이후 임시 장기 기억은 tanh를 통과해 새 정보를 계산하며 이 값이 게이트를 통해 장기기억을 업데이트 합니다.

출력단계에서 장기기억의 일부가 tanh를 거쳐 은닉 상태로 변환됩니다.

 

 


 

 

1. 출력 게이트 계산 Ot

현재 입력과 이전 단기 데이터가 출력 게이트를 통해 출력 게이트 값을 생성합니다.

ot=σ(Woxt+Uoht1+bo)

여기서 σ 시그모이드 함수(출력 범위 [0, 1])

Wo, Uo, bo : 출력 게이트의 가중치 및 바이어스

 

2. 셀 상태 갱신

망각 게이트와 입력 게이트를 통해 장기 기억 업데이트​

3. 셀 상태의 비선형 변환(tanh(Ct))

업데이트된 장기 기억은 tanh를 통해 출력용 값으로 변환

4. 현재 단기 데이터 계산 ht

변환된 셀 상태 값과 출력 게이트 값을 곱해 현재 단기 데이터 생성

 

장기 기억Ct의 변환 값에 출력 게이트를 곱하는 과정은

출력 게이트가 현재 시점에서 어떤 정보가 중요한지 결정/장기 기억의 정보를 선택적으로 은닉 상태로 전달

ht는 출력 게이트와 장기 기억의 상호작용으로 생성되고 현재 시점에서 중요한 단기 정보를 포함

반응형