DEEP.I - Lab

오프라인 공간의 지능화를 꿈꾸는 딥아이 연구실입니다.

Python/Tensorflow

[Tensorflow] GAN (생산적 적대 신경망) 구현하기

Jongwon Kim 2021. 2. 17. 11:15
반응형

Concept

그림 1. GAN 2.0: NVIDIA’s Hyperrealistic Face Generator 생성된 가짜 이미지

요즘 가장 흥미롭게 연구 중인 GAN (Generative Adversarial Network: 생산적 적대 신경망)입니다. GAN은 Neural Network에 뿌리를 두고 있으나 비지도 학습으로 정의되며, 두 개의 신경망이 서로 경쟁하며 학습하게 됩니다. 2014년 처음 아이디어가 제안된 이후, 급격한 연구적 성장을 거듭하며 현재는 놀라울 정도로 진보된 기술로 성장하고 있습니다.

 

이번 포스팅에서는 GAN의 오리지널 버전의 알고리즘을 간단하게 살펴본 뒤, MNIST 손글씨 인식 데이터를 이용해 텐서플로우로 구현해보도록 하겠습니다. 

 

Algorithm

그림 2. GAN 신경망의 기본 구조

기본적인 구조는 간단합니다. 가짜 이미지 생성을 위한 생성자(Generator) 신경망과 진짜와 가짜 이미지 판별을 위한 판별자(Discriminator) 신경망으로 구성됩니다. 두 신경망 모두 FC(Fully Connected Layers)로 구성된 다층 신경망입니다.

 

1. 생성자 네트워크

비지도 학습으로 정의되는 GAN의 생성자는 임의로 생성된 잡음을 입력으로 영상을 생성합니다. 보통 1 x 100 ~ 200의 잡음을 입력하며, 기존 MLP와 동일하게 2 ~ 3개의 층으로 구성된 네트워크로 설계하면 됩니다.

 

MNIST  데이터를 목표할 경우, 28 x 28의 이미지 크기에 맞게 1 x 768의 출력층을 가지게 됩니다. CNN 구조가 없기 때문에 2차원 데이터를 모두 1차원으로 변경해주어야 합니다.

2. 판별자 네트워크

생성자 네트워크와 역순으로 설계됩니다. 1 x 768이 입력되며 신경망 층을 통해 최종적으로 참 (1) 과 거짓 (0)을 판단하게 됩니다. 손글씨 인식을 위한 분류 네트워크와 같은 구조이며 출력은 1개입니다.

3. 손실 함수

GAN은 일반적으로 판별자 네트워크 기반 크로스 앤트로피 손실 함수를 통해 학습하게 됩니다.

 

생성자 네트워크 (G)는 거짓 이미지가 판별자 네트워크에 입력되었을 때 값이 최소 (MIN)가 되도록 하며 (거짓 이미지가 거짓인지 모르도록)판별자 네트워크 (D) 는 거짓 이미지와 참 이미지가 판별자 네트워크에 입력되었을 때 값이 최대 (MAX) (거짓을 거짓으로 참을 참으로 판별하도록)가 되도록 학습하게 됩니다.

 

이처럼 서로 반대되는 방향으로 학습되는 구조이기때문에, 경쟁학습 이라고도 하며 Zero-Sum 게임과 유사합니다.

 

 

SourceCode

github.com/DEEPI-LAB/python-tensorflow-MNIST-GANs.git

 

DEEPI-LAB/python-tensorflow-MNIST-GANs

Contribute to DEEPI-LAB/python-tensorflow-MNIST-GANs development by creating an account on GitHub.

github.com

git clone https://github.com/DEEPI-LAB/python-tensorflow-MNIST-GANs.git

1. MNIST 데이터 불러오기

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from scipy import io

mnist_x = io.loadmat('train_input.mat')['images']
minst_y = io.loadmat('train_output.mat')['y']
mnist_x = mnist_x.astype('float32')

2. 생성자 네트워크 구조 설계

 

저는 은닉층과 입력의 갯수를 맞추기 위해 256개의 노이즈를 생성했습니다. 100개나 200개나 크게 다르지는 않습니다. 활성화 함수는 relu를 사용하였으며 출력단에는 선형회기 예측과 같은 메커니즘이므로, 시그모이드 활성화 함수를 사용합니다.

# Generator
Generator = tf.keras.Sequential([
    tf.keras.layers.Input(256,30),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(784, activation='sigmoid')])

3. 판별자 네트워크 구조 설계 및 Optimizer 정의

# Discriminator
Discriminator = tf.keras.Sequential([
    tf.keras.layers.Input(784),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')])

# define Optimizer
Doptimizer = tf.keras.optimizers.Adam(0.001)
Goptimizer = tf.keras.optimizers.Adam(0.001)

4. 학습 모델 생성

 

텐서플로우는 버전에 따라 다양한 학습 방법이 존재합니다. tessorflow 2.0 이상 버전에서 작성하였으며, 2개의 신경망 학습이 진행되야하기때문에 두개의 손실함수를 업데이트 해주는 방식으로 코드를 구현합니다. 

#%% Training Step
def get_noise(batch_size,n_noise):
    return tf.random.normal([batch_size,n_noise])

@tf.function
def train_step(inputs):

    with tf.GradientTape() as t1, tf.GradientTape() as t2:
        # 잡음으로부터 이미지 생성
        G = Generator(get_noise(30,256))
    	# 판별자 입력
        Z = Discriminator(G)
        R = Discriminator(inputs)   
        # 손실 함수 연산
        loss_D = -tf.reduce_mean(tf.math.log(R) + tf.math.log(1 - Z))
        loss_G = -tf.reduce_mean(tf.math.log(Z))
    
    # 판별자 업데이트      
    Dgradients = t1.gradient(loss_D, Discriminator.trainable_variables)
    Doptimizer.apply_gradients(zip(Dgradients, Discriminator.trainable_variables))
    # 생성자 업데이트
    Ggradients = t2.gradient(loss_G,Generator.trainable_variables)
    Goptimizer.apply_gradients(zip(Ggradients, Generator.trainable_variables)) 

5. 신경망 학습

 

학습에 따라 생성된 이미지를 확인하기 위해 plot으로 이미지를 확인하는 구문을 추가할 수 있습니다.

# 배치 사이즈
total_batch = int(60000/30) 
        
for epoch in tf.range(15):
    k = 0
    for i in tf.range(total_batch):
        batch_input = mnist_x.T[i*30:(i+1)*30]
    
        inputs = tf.Variable([batch_input],tf.float32)
        train_step(inputs)
        print(k)
        k = k + 1
		
        # 생성된 이미지
        if k%100 == 0:
            G = Generator(get_noise(10,256))
        
            fig, ax = plt.subplots(1,10 ,figsize=(10, 1))
                
            for j in range(10):
                ax[j].set_axis_off()
                ax[j].imshow(np.reshape(G[j], (28, 28)).T,cmap='gray')
            plt.pause(0.001)
            plt.show()

6. 학습 결과

 

 

Your Best AI Partner DEEP.I
AI 바우처 공급 기업
객체 추적 및 행동 분석 솔루션 | 제조 생산품 품질 검사 솔루션 | AI 엣지 컴퓨팅 시스템 개발

인공지능 프로젝트 개발 외주 및 상담
E-mail: contact@deep-i.ai
Site: www.deep-i.ai

 

딥아이 DEEP.I | AI 기반 지능형 기업 솔루션

딥아이는 AI 기술의 정상화라는 목표를 갖고, 최첨단 딥러닝 기술 기반의 기업 솔루션을 제공하고 있으며, 이를 통해 고도의 AI 기반 객체 탐지, 분석, 추적 기능을 통합하여 다양한 산업 분야에

deep-i.ai

 

반응형