DEEP.I - Lab

오프라인 공간의 지능화를 꿈꾸는 딥아이 연구실입니다.

Matlab

[Matlab] 다층 퍼셉트론(MLP)을 이용한 MNIST 손글씨 인식 알고리즘 구현

Jongwon Kim 2020. 12. 5. 16:00
반응형

MNIST DATASET

MNIST 데이터셋은 머신러닝을 입문하는 분들이 처음 접하게 되는 데이터 중 하나입니다. 28 x 28 해상도를 가지는 흑백 이미지로 구성되어있지만, 영상 처리 알고리즘 이외 K-Measn, PCA, RNN 등 다양항 기법이 적용 가능하여 초기 데이터 분석 단계에서 연습에 활용되고 있습니다.

 

 

저 역시 처음 머신러닝에 입문했을 때 XOR 게이트 문제 이후, 머리를 쓰며 가장 많이 다뤄본 데이터입니다. 이제 막 입문하시는 분들이라면 Tesnorflow 나 Pytorch가 제공하는 함수 사용 이전에 직접 수식을 코딩하고 데이터 전처리하는 연습은 꼭 가지시길 바랍니다. 그런 의미에서 이번 포스팅에서는 모두를 위한 인공지능 교육에 활용했었던 MATLAB 기반 MINIST 손글씨 인식 알고리즘 코드를 구현하겠습니다. 모든 코드는 함수없이 직접 구현되었습니다.

MULTI-LAYER PERCEPTRON

 

2개의 은닉층을 가지는 MLP 구조로 설계하도록 하겠습니다. MLP는 CNN처럼 2차원 영상이 입력될 수 없기 때문에 1차원으로 데이터를 변환시켜줘야 합니다. 28 X 28 이미지는 784 개의 1차원 벡터로 변환되어 N개의 노드를 가지는 은닉층에 입력됩니다. 첨부된 데이터는 총 60,000개이며 784개의 차원을 가지므로 60,000 x 784 의 형태를 가지게 됩니다.

 

 

Algorithm

 

0. 샘플코드 다운로드

github.com/DEEPI-LAB/matlab-mnist-multi-layer-perceptron.git

 

DEEPI-LAB/matlab-mnist-multi-layer-perceptron

The repository implements the a simple Multi-Layers Neural Network from scratch for MNIST classification. - DEEPI-LAB/matlab-mnist-multi-layer-perceptron

github.com

git clone https://github.com/DEEPI-LAB/matlab-mnist-multi-layer-perceptron.git

또는 위 링크에서 좌측 상단 초록색 CODE를 클릭 후, 하단 Download ZIP 으로 직접 다운로드 가능합니다.

 

 

1. 데이터 확인

%% MNIST 데이터 확인                                            
mnist = images(:,1:200);                        
mnist = reshape(mnist,28,28,200); 
montage(mnist)        

 

2. 신경망 구조 설계 및 미니 배치 크기 설정

% 1st Layer
node_1w = fc_node('weight', imRe^2, pNum_1);
node_1b = fc_node('bias', pNum_1,1);
% 2nd Layer
node_2w = fc_node('weight', pNum_1, pNum_2);
node_2b = fc_node('bias', pNum_2,1);
% 3rd Layer
node_3w = fc_node('weight', pNum_2, 10);
node_3b = fc_node('bias', 10,1);
% batch size
batch =64;

 

입력 벡터의 크기와 일치하도록 은닉층 노드 수를 설정하고 가중치값과 편향 값을 randn 함수로 초기화 해줍니다. 학습은 미니 배치 방식으로 64개씩 학습을 진행했습니다.

 

3. 학습 데이터 셔플링 및 미니 배치 메모리 지정

% data shuffle
p = randperm(cols);                                           
X = x(:,p(1:batch));
Y = y(p(1:batch),:);

% batch memory init (weight)
batch_1 = 0; batch_2 = 0; batch_3 = 0; 
% batch memory init (bias)
batch_4 = 0; batch_5 = 0; batch_6 = 0;

 

학습을 반복하면서 자동으로 데이터가 섞일 수 있도록 randperm 함수를 이용하여 셔플링을 진행해줍니다. 병렬 처리 프로세스가 아닌 For문으로 모든 미니 배치 데이터 값을 연산해야 하기 때문에 배치 메모리를 사전에 지정하는 방식으로 미내 배치 학습을 구현했습니다.

 

4. 순전파 단계

%% Feed Forward propagation

f1 = relu(X(:,i)' * node_1w + node_1b');
f2 = relu(f1 * node_2w + node_2b');
f3 = exp(f2 * node_3w + node_3b') / sum(exp(f2 * node_3w + node_3b')) ;

5. 오차 함수 단계

%% Error
P(i) = find(f3==max(f3));
O(i) = find(Y(i,:)==max(Y(i,:)));
E(i,:) = - sum(Y(i,:).*log(f3));

6. 역전파 단계

%% Back propagation  

b3 = f3 - Y(i,:);    
b2 = b3 * node_3w' .* reluGradient(f2);     
b1 = b2 * node_2w' .* reluGradient(f1);

 

MNIST는 0부터 1까지 총 10개의 클래스를 가지고 있으므로 크로스 엔트로피를 오차 함수로 설정했습니다. 활성화 함수는 순 전파 역전파 모두 구현이 쉽고 성능에서도 Sigmoid 대비 우위를 가지고 있는 ReLu 함수를 사용했습니다.

 

7. 미니 배치 업데이트 단계

%% Batch 
         batch_1 = batch_1 + (alpha * f2' * b3);        
         batch_4 = batch_4 + (alpha * b3)'; 
         
         batch_2 = batch_2 + (alpha * f1' * b2);   
         batch_5 = batch_5 + (alpha * b2)';
    
         batch_3=  batch_3 + (alpha * X(:,i) * b1);
         batch_6 = batch_6 + (alpha * b1)' ;
         
    end
    
        %% Update
        node_3w = node_3w - batch_1 / batch;
        node_2w = node_2w - batch_2 / batch;
        node_1w = node_1w - batch_3 / batch;
        
        node_3b = node_3b - batch_4 / batch;
        node_2b = node_2b - batch_5 / batch;
        node_1b = node_1b - batch_6 / batch;

 

For문 연산으로 미니배치만큼 역전파 값을 구해 준 뒤, 모두 더한 다음 가중치와 편향 값을 업데이트하게 됩니다. 매트랩이 취약한 For문이 넘쳐흐르네요. 어느 정도 역전파와 배치에 대한 이해도가 높으신 분들은 병렬 처리 프로세서 방식으로 변경해 보시길 바랍니다.

 

8. 전체 코드 및 결과

 

% *********************************************
% MNIST Neural Networks
% @author: Deep.I Inc. @Jongwon Kim
% deepi.contact.us@gmail.com
% Revision date: 2020-12-01
% See here for more information :
%    https://deep-eye.tistory.com
%    https://deep-i.net
% **********************************************
% STRUCTURE : BATCH MLP
% input : 787 x 60000
% output : 10 x 60000
% MODE : Batch
% ACTIVATION FUNCTION : 'Relu'
% ERROR RATE : 2.51

%% F5를 눌러서 실행해주세요.

clear all
clc
cla
close all
input("\n\n 퍼셉트론을 활용한 손글씨 인식 프로그램 입니다. [엔터키를 눌러주세요] ")

%%
load train\train_input.mat; 
load train\train_output.mat; 

clc
input("\n\n 배포해드린 숫자 데이터를 로드할게요. [엔터키를 눌러주세요] ")                                              
                                              
%% 학습 데이터를 한번 봅시다                                              
mnist = images(:,1:200);                        
mnist = reshape(mnist,28,28,200); 
montage(mnist)                                      

title("학습에 활용할 손글씨 이미지 입니다. [엔터키를 눌러주세요] ")
clc
input("\n\n 학습에 활용할 손글씨 이미지 입니다. [엔터키를 눌러주세요] ")

x = images;
cols = size(x,2);
imRe = 28;
    
clc
fprintf("데이터의 개수 : %d \n이미지 해상도 : %d x %d\n입력 차원 : %d\n",cols,imRe,imRe,imRe^2)
alpha = input("\n\n 학습의 정도를 결정하는 Learning Rate을 입력해주세요. [0.1~ 0.0001] ");
clc
pNum_1  = input("\n\n 첫번째 층의 퍼셉트론(뉴런)의 개수를 입력해주세요. [1~inf] ");
pNum_2  = input(" 두번째 층의 퍼셉트론(뉴런)의 개수를 입력해주세요. [1~inf] ");
eh  = input(" 학습을 몇번 반복할지 반복 횟수를 입력해주세요. [1~inf] ");
fprintf("\n\n %d-%d-%d-%d 의 구조를 갖는 신경망이 완성되었습니다.",imRe^2,pNum_1,pNum_2,10);
input(" 엔터를 누루면 학습을 시작합니다!");

% 1st Layer
node_1w = fc_node('weight', imRe^2, pNum_1);
node_1b = fc_node('bias', pNum_1,1);
% 2nd Layer
node_2w = fc_node('weight', pNum_1, pNum_2);
node_2b = fc_node('bias', pNum_2,1);
% 3rd Layer
node_3w = fc_node('weight', pNum_2, 10);
node_3b = fc_node('bias', 10,1);
% batch size
batch =64;

close all
for z = 1 : eh
    
% data shuffle
p = randperm(cols);                                           
X = x(:,p(1:batch));
Y = y(p(1:batch),:);

% batch memory init (weight)
batch_1 = 0; batch_2 = 0; batch_3 = 0; 
% batch memory init (bias)
batch_4 = 0; batch_5 = 0; batch_6 = 0;
        
    for i = 1 : batch    

%% Feed Forward propagation

f1 = relu(X(:,i)' * node_1w + node_1b');
f2 = relu(f1 * node_2w + node_2b');
f3 = exp(f2 * node_3w + node_3b') / sum(exp(f2 * node_3w + node_3b')) ;
        
%% Error
P(i) = find(f3==max(f3));
O(i) = find(Y(i,:)==max(Y(i,:)));
E(i,:) = - sum(Y(i,:).*log(f3));
        
        
%% Back propagation  

b3 = f3 - Y(i,:);    
b2 = b3 * node_3w' .* reluGradient(f2);     
b1 = b2 * node_2w' .* reluGradient(f1);
       
        %% Batch 
         batch_1 = batch_1 + (alpha * f2' * b3);        
         batch_4 = batch_4 + (alpha * b3)'; 
         
         batch_2 = batch_2 + (alpha * f1' * b2);   
         batch_5 = batch_5 + (alpha * b2)';
    
         batch_3=  batch_3 + (alpha * X(:,i) * b1);
         batch_6 = batch_6 + (alpha * b1)' ;
         
    end
    
        %% Update
        node_3w = node_3w - batch_1 / batch;
        node_2w = node_2w - batch_2 / batch;
        node_1w = node_1w - batch_3 / batch;
        
        node_3b = node_3b - batch_4 / batch;
        node_2b = node_2b - batch_5 / batch;
        node_1b = node_1b - batch_6 / batch;
        
        %% 그래프 보기
        tex2 = mean(P == O);
        tex1 = mean(E);
        mse(z,1) = mean(E);
        format shortG
        clc        
        fprintf("학습 횟수 : %d번\n",z)
        fprintf("학습된 글자 수 : %d 개 (한 번 반복에 64개씩 학습을 진행합니다.)\n",z*batch)
        fprintf("학습 데이터 손글씨 인식률 : %0.2f%%\n",round(tex2*100,4))
        fprintf("전체 학습 오차(MSE) : %0.5f",round(tex1,4))
        cla
        subplot(1,2,1)
        plot(mse);
        axis([0 inf 0 5])
        title("MSE")
        drawnow;
        subplot(1,2,2)
        
        testing = reshape(X,28,28,64);
        montage(testing(:,:,1:20));
        title("학습중인 숫자")
        drawnow;
        %
end
clc
fprintf("학습 횟수 : %d번\n",z)
fprintf("학습된 글자 수 : %d 개 (한 번 반복에 64개씩 학습을 진행합니다.)\n",z*batch)
fprintf("학습 데이터 손글씨 인식률 : %0.2f%%\n",round(tex2*100,4))
fprintf("전체 학습 오차(MSE) : %0.5f",round(tex1,4))
input("    학습이 완료되었습니다. 테스트 데이터로 실험을 해봅시다!")

%%
load test\test_input.mat; 
load test\test_output.mat; 

results = [];
for i= 1 : 10000
    f1 = relu(test(:,i)' * node_1w + node_1b');
    f2 = relu(f1 * node_2w + node_2b');
    f3 = exp(f2 * node_3w + node_3b') / sum(exp(f2 * node_3w + node_3b'));
    results(i) =  min( yy(i,:) == (f3 ==max(f3)));
end

fprintf("전체 테스트 데이터 학습 결과 %0.2f %% 정확도\n",round(mean(results)*100,2))

results = [];
for i= 1 : 10
    f1 = relu(test(:,i)' * node_1w + node_1b');
    f2 = relu(f1 * node_2w + node_2b');
    f3 = exp(f2 * node_3w + node_3b') / sum(exp(f2 * node_3w + node_3b'));
    results(i) =  min( yy(i,:) == (f3 ==max(f3)));
    
    im = reshape(test(:,i)',28,28);
    imshow(im);
    title(find(f3 ==max(f3))-1)
    input("")
end

 

 

CNN 분류 모델과 비교하여 단순한 기법이지만 MNIST 학습 성능은 평균적으로 95 ~ 98% 수준으로 높습니다. 첨부된 코드에서 학습 성능 향상을 위해 개선된 부분은 다음과 같습니다.

 

  • 데이터 셔플링, ReLu 함수, 크로스 앤트로피 오차

 

현대의 신경망에는 이외에 다양한 기술적 테크닉이 존재합니다. 함수 구현이 아닌, 이론을 통해 이해가 된 알고리즘을 자신만의 코드로 변환하여 MNIST의 성능을 향상해보시길 바랍니다. 

 

 

Your Best AI Partner DEEP.I
Jetson 시리즈 기반 엣지 컴퓨팅 시스템 제작
머신러닝 프로젝트 제작 및 상담
머신러닝 접목 졸업작품 상담
E-mail : deepi.contact.us@gmail.com
Site : www.deep-i.net

 

 

반응형