오답노트

[NLP] Long Short-Term Memory (LSTM) 본문

Python/DL

[NLP] Long Short-Term Memory (LSTM)

권멋져 2022. 10. 6. 20:41

https://wikidocs.net/22888

 

2) 장단기 메모리(Long Short-Term Memory, LSTM)

바닐라 아이스크림이 가장 기본적인 맛을 가진 아이스크림인 것처럼, 앞서 배운 RNN을 가장 단순한 형태의 RNN이라고 하여 바닐라 RNN(Vanilla RNN)이라고 합니다 ...

wikidocs.net

Long Short-Term Memory (LSTM)

기존 RNN의 문제점을 개선한 알고리즘이다. gradient의 소멸 및 폭주 현상 해소하고, 정보의 장거리 전달이 가능하여 기본 RNN에 비해 우수한 문제 처리 능력을 가지고 있다.

 

RNN을 구성하는 기본 단위를 기존의 퍼셉트론에서 좀 더 복잡한구조로 바꾸는 방법을 사용하여, Cell State, Forget Gate, Input Gate라는 개념을 이용한다. Gate는 sigmoid 함수를 이용해서 0에 가까운 값은 무시하고, 1에 가까운 값을 입력값에 활용하는 원리로 이루어져있다.

 

길이가 길어지면 뒤로 갈 수록 앞의 값을 제대로 전달하지 못하는 점을 Cell State로 해결했다.

 

⊕ : 각 벡터의 원소를 더한다

⊗ : 각 벡터의 원소를 곱한다

 

Forget Gate

현재의 입력 xt와 이전 hidden state ht-1를 더한 값에 sigmoid 함수를 씌운 형태이다. sigmoid 내부의 값이 1에 가까울 수록 유의미한 정보로 판단한다.

 

Input Gate

우선 앞에서 설명한 Forget Gate(it)의 형태를 찾아볼 수 있다.

gt는 현재의 입력 xt와 이전 hidden state ht-1를 더한 값에 하이퍼볼릭 탄젠트 함수를 씌운 형태이다. gt는 기존 RNN에서의 형태와 같다.

 

Cell state

이전 Cell state인 Ct-1 과 Forget gate의 결과값을 곱한 값과

Input Gate에서 만들어진 결과 값들을 곱한 값과 더한 값으로 현재 Cell state Ct를 만든다.

 

Forget gate는 이전 시점의 입력을 얼마나 반영할지를 의미하고, Input Gate는 현재 시점의 입력을 얼마나 반영할지를 결정한다.

 

Hidden state

 

Forget gate 와  Ct에 하이퍼볼릭 탄젠트를 씌운 값의 곱한 결과를 Hidden state 또는 output으로 향하게 된다.

 

Bi-directional LSTM

Bi-directional LSTM는 과거 시점의 입력뿐만 아니라 미래 시점의 입력으로부터 정보를 얻어야 할 경우에 이전과 이후의 시점 모두를 고려해서 현재 시점의 예측을 정확하게 할 수 있도록 고안된 양방향 RNN이다.

양방향 RNN은 하나의 출력값을 예측하기 위해 기본적으로 두 개의 메모리 셀을 사용하며 첫번째는 앞 시점의 은닉 상태를 전달 받아 현재의 은닉 상태를 계산하고, 두번째 메모리 셀은 뒤 시점의 은닉 상태를 전달 받아 현재의 은닉상태를 계산한다.

 

이 두 개의 값 모두가 현재 시점의 출력층에서 출력 값을 예측하기 위해 사용된다.

 

 

Pytorch 실습

!pip install -U torchtext==0.10.0

from google.colab import drive
drive.mount('/content/gdrive')

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchtext.legacy import data
import torchtext.datasets as datasets

class RNN_Text(nn.Module):
  def __init__(self, embed_num, class_num):
    super(RNN_Text, self).__init__()

    #V : 단어 사전 크기
    #C : 분류하고자 하는 클래스 개수
    #H : 히든 사이즈
    #D : 단어벡터 차원

    V = embed_num
    C = class_num
    H = 256
    D = 100
    
    self.embed = nn.Embedding(V,D)
    self.rnn = nn.LSTM(D,H,bidirectional = True)
    self.out = nn.Linear(H*2,C)
  
  def forward(self, x):
    x = self.embed(x)
    x,_ = self.rnn(x,(self.h,self.c))
    logit = self.out(x[-1])

    return logit
  
  def inithidden(self,b):
    self.h = torch.randn(2,b,256)
    self.c = torch.randn(2,b,256)

class mydataset(data.Dataset):
  @staticmethod
  def sort_key(ex):
    return len(ex.text)
  def __init__(self, text_field, label_field, path=None, examples=None, **kwargs):
    fields = [('text',text_field),('label',label_field)]
    if examples is None:
      path = self.dirname if path is None else path
      examples = []
      for i,line in enumerate(open(path,'r',encoding='utf-8')):
        if i == 0:
          continue
        line = line.strip().split('\t')
        txt = line[1].split(' ')

        examples += [data.Example.fromlist([txt,line[2]],fields)]
    super(mydataset, self).__init__(examples, fields, **kwargs)

text_field = data.Field(fix_length=30)
label_field = data.Field(sequential=False, batch_first = True, unk_token = None)

train_data = mydataset(text_field, label_field,path='/content/gdrive/My Drive/Colab Notebooks/train_tok.txt')
test_data = mydataset(text_field, label_field,path='/content/gdrive/My Drive/Colab Notebooks/test_tok.txt')

text_field.build_vocab(train_data)
label_field.build_vocab(train_data)

train_iter,test_iter = data.Iterator.splits(
    (train_data,test_data),
    batch_sizes=(100,1), repeat=False
)

rnn = RNN_Text(len(text_field.vocab),2)
optimizer = torch.optim.Adam(rnn.parameters())
rnn.train()

for epoch in range(10):
  totalloss = 0
  for batch in train_iter:
    optimizer.zero_grad()

    txt=batch.text
    label=batch.label

    rnn.inithidden(txt.size(1))
    pred = rnn(txt)

    loss = F.cross_entropy(pred,label)
    totalloss += loss.data

    loss.backward()
    optimizer.step()
  
  print(epoch,'epoch')
  print('loss : {:.3f}'.format(totalloss.numpy()))



from sklearn.metrics import classification_report
correct = 0
incorrect = 0
rnn.eval()
y_test =[]
prediction =[]

for batch in test_iter:
  txt = batch.text
  label = batch.label
  y_test.append(label.data[0])
  
  rnn.inithidden(txt.size(1))

  pred = rnn(txt)

  _, ans = torch.max(pred,dim=1)
  prediction.append(ans.data[0])

  if ans.data[0] == label.data[0]:
    correct +=1
  else:
    incorrect +=1


print('correct : ',correct)
print('incorrect : ',incorrect)
print(classification_report(
    torch.tensor(y_test),
    torch.tensor(prediction),
    digits=4,
    target_names=['negative','positive']
))

'Python > DL' 카테고리의 다른 글

[NLP] Transformer  (0) 2022.10.07
[NLP] Sequence-to-Sequence (seq2seq)  (1) 2022.10.06
[NLP] RNN (Recurrent Neural Network)  (0) 2022.10.06
[NLP] CNN 기반 텍스트 분류  (0) 2022.10.06
[NLP] 워드 임베딩 (Word Embedding)과 Word2Vec  (1) 2022.10.05