Data Analysis

Python/머신러닝 & 딥러닝

[딥러닝] 합성곱 신경망 모델 만들기 (MNIST 데이터)

Holy_Water 2023. 1. 19. 15:36

합성곱신경망은 완전 연결 신경망 보다 훨씬 가중치가 작으면서도 이미지 분류 문제를 더 잘 해결한다. 

 

합성곱 신경망 구조

데이터 세트 불러오기

import tensorflow as tf
(x_train_all, y_train_all), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

 

훈련 데이터 세트를 훈련세트와 검증세트로 나누기

from sklearn.model_selection import train_test_split
x_train, x_val, y_train, y_val = train_test_split(x_train_all,y_train_all,
                        stratify=y_train_all, test_size=0.2, random_state=42)

 

타깃을 원-핫 인코딩으로 변환하기

y_train_encoded = tf.keras.utils.to_categorical(y_train)
y_val_encoded = tf.keras.utils.to_categorical(y_val)

 

입력데이터 준비하기 (마지막에 컬러채널 추가후 확인)

x_train = x_train.reshape(-1, 28, 28, 1)
x_val = x_val.reshape(-1, 28, 28, 1)
print(x_train.shape)
# (48000, 28, 28, 1)

 

입력데이터 표준화전처리 하기

x_train = x_train / 255
x_val = x_val / 255

 

케라스로 합성곱 신경망 만들기

from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

 

합성곱층 쌓기
Conv2D: 1번 합성곱 커널 개수(높이,너비 포함)

conv1 = tf.keras.Sequential()
conv1.add(Conv2D(10,(3,3), activation = 'relu', padding = 'same', input_shape=(28, 28, 1)))

 

풀링층 쌓기
MaxPooling (높이,너비)

conv1.add(MaxPooling2D((2, 2)))

 

완전 연결층에 중비할 수 있도록 특성 맵 펼치기

conv1.add(Dense(100, activation='relu'))
conv1.add(Dense(10, activation='softmax'))

 

모델 구조 살펴보기

conv1.summary()

 

합성곱 신경망 모델 훈련하기

conv1.compile(optimizer='adam',
    loss='categorical_crossentropy', metrics=['accuracy'])

history = conv1.fit(x_train, y_train_encoded, epochs=20,
                    validation_data=(x_val,y_val_encoded))

 

손실그래프와 정확도 그래프 확인하기

손실 그래프

import matplotlib.pyplot as plt
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train_loss','val_loss'])

정확도 그래프

plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train_accuracy','val_accuracy'])

 

검증세트 정확도 계산하기

loss, accuracy = conv1.evaluate(x_val, y_val_encoded, verbose=0)
print(accuracy)
# 0.9106666445732117