오답노트
[Keras] Flatton Layer 본문
Flatton Layer
모델링을 할 때 학습시킬 데이터를 전처리하는 경우 중, 데이터를 1차원 데이터로 만드는 경우가 존재한다.
이때 numpy를 활용하여 reshape 하는 과정을 거쳐서 차원을 변환 시킬수도 있다.
하지만 numpy를 활용하지 않고 Keras에 존재하는 Flatton Layer를 통해서 다차원 데이터를 1차원으로 바꿀 수 있다.
실습을 통해 reshape으로 차원을 변환하는 과정과 Flatton Layer를 통해 차원을 변환하는 과정을 알아보자.
라이브러리 및 데이터 불러오기
import tensorflow as tf
from tensorflow import keras
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random as rd
from sklearn.metrics import accuracy_score
#데이터 불러오기
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()
reshape
#데이터 reshape
train_x_reshape = train_x.reshape([train_x.shape[0],-1])
test_x_reshape = test_x.reshape([test_x.shape[0],-1])
#데이터 scaling
max_num, min_num = train_x.max(),train_x.min()
train_x = (train_x - min_num) / (max_num - min_num)
test_x = (test_x - min_num) / (max_num - min_num)
max_num, min_num = train_x_reshape.max(),train_x_reshape.min()
train_x_reshape = (train_x_reshape - min_num) / (max_num - min_num)
test_x_reshape = (test_x_reshape - min_num) / (max_num - min_num)
#one hot encoding
from tensorflow.keras.utils import to_categorical
n_class = len(np.unique(train_y))
train_y = to_categorical(train_y,n_class)
test_y = to_categorical(test_y,n_class)
#모델링
# 레이어 선언
il = keras.layers.Input(shape=(train_x_reshape.shape[1]))
h1 = keras.layers.Dense(512,'relu')(il)
h2 = keras.layers.Dense(512,'relu')(h1)
ol = keras.layers.Dense(n_class,'softmax')(h2)
# 모델 선언
model = keras.models.Model(il,ol)
#모델 컴파일
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adam(),
metrics=['accuracy'])
#모델 확인
model.summary()
#모델 학습
model.fit(train_x_reshape,train_y,
verbose=1,epochs=10)
#모델 예측
pred_train = model.predict(train_x_reshape)
pred_test = model.predict(test_x_reshape)
pred_train_argmax = pred_train.argmax(axis=1)
pred_test_argmax = pred_test.argmax(axis=1)
train_accuracy = accuracy_score(train_y.argmax(axis=1),pred_train_argmax)
test_accuracy = accuracy_score(test_y.argmax(axis=1),pred_test_argmax)
print('train data 정확도 : ',train_accuracy*100,'%') # 99.87%
print('test data 정확도 : ',test_accuracy*100,'%') # 98.59%
Flatton Layer
#레이어 선언
shape_size = tuple(train_x.shape[i] for i in range(1,len(train_x.shape)))
il = keras.layers.Input(shape=shape_size)
fl = keras.layers.Flatten()(il)
h1 = keras.layers.Dense(512,'relu')(fl)
ol = keras.layers.Dense(n_class,'softmax')(h1)
#모델 선언
model = keras.models.Model(il,ol)
#모델 컴파일
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adam(),
metrics=['accuracy'])
#모델 확인
model.summary()
#모델 학습
model.fit(train_x,train_y,
verbose=1,
epochs=10)
#모델 평가
pred_train = model.predict(train_x)
pred_test = model.predict(test_x)
pred_train_argmax = pred_train.argmax(axis=1)
pred_test_argmax = pred_test.argmax(axis=1)
train_accuracy = accuracy_score(train_y.argmax(axis=1),pred_train_argmax)
test_accuracy = accuracy_score(test_y.argmax(axis=1),pred_test_argmax)
print('train data 정확도 : ',train_accuracy*100,'%') # 99.75500000000001 %
print('test data 정확도 : ',test_accuracy*100,'%') # 98.13 %
Flatton Layer의 특성
Flatton Layer는 다차원 데이터를 1차원 데이터로 만들어준다.
아래 코드를 통해 알아보자
il = keras.layers.Input(shape=(10,10,10)) # 3차원 데이터
fl = keras.layers.Flatten()(il)
ol = keras.layers.Dense(n_class,'softmax')(fl)
#모델 선언
model = keras.models.Model(il,ol)
#모델 컴파일
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adam(),
metrics=['accuracy'])
#모델 확인
model.summary()
'''
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_7 (InputLayer) [(None, 10, 10, 10)] 0
flatten_2 (Flatten) (None, 1000)<<변환 후 0
dense_15 (Dense) (None, 10) 10010
=================================================================
Total params: 10,010
Trainable params: 10,010
Non-trainable params: 0
'''
'Python > DL' 카테고리의 다른 글
[Keras] 연산 Layer (0) | 2022.09.18 |
---|---|
[Keras] EarlyStopping (0) | 2022.09.15 |
[Keras] ANN (0) | 2022.09.14 |
[Keras] MultiClass - SoftMax (0) | 2022.09.13 |
[Keras] Functional API (0) | 2022.09.13 |