블로그 이미지
평범하게 살고 싶은 월급쟁이 기술적인 토론 환영합니다.같이 이야기 하고 싶으시면 부담 말고 연락주세요:이메일-bwcho75골뱅이지메일 닷컴. 조대협


Archive»


 
 

연예인 얼굴 인식 모델을 만들어보자

#2 CNN 모델을 만들고 학습 시켜보기

조대협 (http://bcho.tistroy.com)

선행 학습 자료

이 글은 딥러닝 컨볼루셔널 네트워크 (이하 CNN)을 이용하여 사람의 얼굴을 인식하는 모델을 만드는 튜토리얼이다. 이 글을 이해하기 위해서는 머신러닝과 컨볼루셔널 네트워크등에 대한 사전 지식이 필요한데, 사전 지식이 부족한 사람은 아래 글을 먼저 읽어보기를 추천한다.

 

머신러닝의 개요 http://bcho.tistory.com/1140

머신러닝의 기본 원리는 http://bcho.tistory.com/1139

이산 분류의 원리에 대해서는 http://bcho.tistory.com/1142

인공 신경망에 대한 개념은 http://bcho.tistory.com/1147

컨볼루셔널 네트워크에 대한 개념 http://bcho.tistory.com/1149

학습용 데이타 전처리 http://bcho.tistory.com/1176

학습용 데이타 전처리를 스케일링 하기 http://bcho.tistory.com/1177

손글씨를 CNN을 이용하여 인식하는 모델 만들기 http://bcho.tistory.com/1156

손글씨 인식 CNN 모델을 이용하여 숫자 인식 하기 http://bcho.tistory.com/1157

환경

본 예제는 텐서플로우 1.1과 파이썬 2.7 그리고 Jupyter 노트북 환경 및 구글 클라우드를 사용하여 개발되었다.

준비된 데이타

학습에 사용한 데이타는 96x96 사이즈의 얼굴 이미지로, 총 5명의 사진(안젤리나 졸리, 니콜키드만, 제시카 알바, 빅토리아 베컴,설현)을 이용하였으며, 인당 학습 데이타 40장 테스트 데이타 10장으로 총 250장의 얼굴 이미지를 사용하였다.

사전 데이타를 준비할때, 정면 얼굴을 사용하였으며, 얼굴 각도 변화 폭이 최대한 적은 이미지를 사용하였다. (참고 : https://www.slideshare.net/Byungwook/ss-76098082 ) 만약에 이 모델로 학습이 제대로 되지 않는다면 학습에 사용된 데이타가 적절하지 않은것이기 때문에 데이타를 정재해서 학습하기를 권장한다.

데이타 수집 및 정재 과정에 대한 내용은 http://bcho.tistory.com/1177 를 참고하기 바란다.

 

컨볼루셔널 네트워크 모델

얼굴 인식을 위해서, 머신러닝 모델 중 이미지 인식에 탁월한 성능을 보이는 CNN 모델을 사용하였다. 테스트용 모델이기 때문에 모델은 복잡하지 않게 설계하였다.

 

학습과 예측에 사용되는 이미지는 96x96픽셀의 RGB 컬러 이미지를 사용하였다.

아래 그림과 같은 모델을 사용했는데, 총 4개의 Convolutional 계층과, 2개의 Fully connected 계층, 하나의 Dropout 계층을 사용하였다.


Convolutional 계층의 크기는 각각 16,32,64,128개를 사용하였고, 사용된 Convolutional 필터의 사이즈는 3x3 이다.

Fully connected 계층은 각각 512, 1024를 사용하였고 Dropout 계층에서는 Keep_prob값을 0.7로 둬서 30%의 뉴론이 drop out 되도록 하여 학습을 진행하였다.

 

학습 결과 5개의 카테고리에 대해서 총 200장의 이미지로 맥북 프로 i7 CPU 기준 7000 스텝정도의 학습을 진행한 결과 테스트 정확도 기준 90% 정도의 정확도를 얻을 수 있었다.

코드 설명

텐서플로우로 구현된 코드를 살펴보자

파일에서 데이타 읽기

먼저 학습 데이타를 읽어오는 부분이다.

학습과 테스트에서 읽어드리는 데이타의 포맷은 다음과 같다

 

/Users/terrycho/training_data_class5_40/validate/s1.jpg,Sulhyun,3

이미지 파일 경로, 사람 이름 , 숫자 라벨

 

파일에서 데이타를 읽어서 처리 하는 함수는 read_data_batch(), read_data(), get_input_queue()  세가지 함수가 사용된다.

  • get_input_queue() 함수는 CSV 파일을 한줄씩 읽어서, 파일 경로 및 숫자 라벨 두가지를 리턴할 수 있는 큐를 만들어서 리턴한다.

  • read_data() 함수는 get_input_queue()에서 리턴한 큐로 부터 데이타를 하나씩 읽어서 리턴한다.

  • read_batch_data()함수는 read_data() 함수를 이용하여, 데이타를 읽어서 일정 단위(배치)로 묶어서 리턴을 하고, 그 과정에서 이미지 데이타를 뻥튀기 하는 작업을 한다.

즉 호출 구조는 다음과 같다.

 

read_batch_data():

 → Queue = get_input_queue()

 → image,label = read_data(Queue)

 → image_data = 이미지 데이타 뻥튀기

Return image_data,label

 

실제 코드를 보자

get_input_queue

get_input_queue() 함수는 CSV 파일을 읽어서 image와 labels을 리턴하는 input queue를 만들어서 리턴하는 함수이다.

 

def get_input_queue(csv_file_name,num_epochs = None):

   train_images = []

   train_labels = []

   for line in open(csv_file_name,'r'):

       cols = re.split(',|\n',line)

       train_images.append(cols[0])

       # 3rd column is label and needs to be converted to int type

       train_labels.append(int(cols[2]) )

                           

   input_queue = tf.train.slice_input_producer([train_images,train_labels],

                                              num_epochs = num_epochs,shuffle = True)

   

   return input_queue

 

CSV 파일을 순차적으로 읽은 후에, train_images와 train_labels라는 배열에 넣은 다음 tf.train.slice_input_producer를 이용하여 큐를 만들어냈다. 이때 중요한 점은 shuffle=True라는 옵션을 준것인데, 만약에 이 옵션을 주지 않으면, 학습 데이타를 큐에서 읽을때 CSV에서 읽은 순차적으로 데이타를 리턴한다. 즉 현재 데이타 포맷은 Jessica Alba가 40개, Jolie 가 40개, Nicole Kidman이 40개 .. 식으로 순서대로 들어가 있기 때문에, Jessica Alba를 40개 리턴한 후 Jolie를 40개 리턴하는 식이 된다.  이럴 경우 Convolutional 네트워크가 Jessica Alba에 치우쳐지기 때문에 제대로 학습이 되지 않는다. Shuffle은 필수이다.

read_data()

input_queue에서 데이타를 읽는 부분인데 특이한 점은 input_queue에서 읽어드린 이미지 파일명의 파일을 읽어서 데이타 객체로 저장해야 한다. 텐서플로우에서는 tf.image.decode_jpeg, tf.image.decode_png 등을 이용하여 이러한 기능을 제공한다.

def read_data(input_queue):

   image_file = input_queue[0]

   label = input_queue[1]

   

   image =  tf.image.decode_jpeg(tf.read_file(image_file),channels=FLAGS.image_color)

   

   return image,label,image_file

read_data_batch()

마지막으로 read_data_batch() 함수 부분이다.get_input_queue에서 읽은 큐를 가지고 read_data함수에 넣어서 이미지 데이타와 라벨을 읽어서 리턴하는 값을 받아서 일정 단위로 (배치) 묶어서 리턴하는 함수이다. 중요한 부분이 데이타를 뻥튀기 하는 부분이 있다.

이 모델에서 학습 데이타가 클래스당 40개 밖에 되지 않기 때문에 학습데이타가 부족하다. 그래서 여기서 사용한 방법은 read_data에서 리턴된 이미지 데이타에 대해서 tf.image.random_xx 함수를 이용하여 좌우를 바꾸거나, brightness,contrast,hue,saturation 함수를 이용하여 매번 색을 바꿔서 리턴하도록 하였다.

 

def read_data_batch(csv_file_name,batch_size=FLAGS.batch_size):

   input_queue = get_input_queue(csv_file_name)

   image,label,file_name= read_data(input_queue)

   image = tf.reshape(image,[FLAGS.image_size,FLAGS.image_size,FLAGS.image_color])

   

   # random image

   image = tf.image.random_flip_left_right(image)

   image = tf.image.random_brightness(image,max_delta=0.5)

   image = tf.image.random_contrast(image,lower=0.2,upper=2.0)

   image = tf.image.random_hue(image,max_delta=0.08)

   image = tf.image.random_saturation(image,lower=0.2,upper=2.0)

   

   batch_image,batch_label,batch_file = tf.train.batch([image,label,file_name],batch_size=batch_size)

   #,enqueue_many=True)

   batch_file = tf.reshape(batch_file,[batch_size,1])

 

   batch_label_on_hot=tf.one_hot(tf.to_int64(batch_label),

       FLAGS.num_classes, on_value=1.0, off_value=0.0)

   return batch_image,batch_label_on_hot,batch_file

 

그리고 마지막 부분에 label을 tf.one_hot을 이용해서 변환한것을 볼 수 있는데, 입력된 label은 0,1,2,3,4 과 같은 단일 정수이다. 그런데, CNN에서 나오는 결과는 정수가 아니라 클래스가 5개인 (분류하는 사람이 5명이기 때문에) 행렬이다. 즉 Jessica Alba일 가능성이 90%이고, Jolie일 가능성이 10%이면 결과는 [0.9,0.1,0,0,0] 식으로 리턴이 되기 때문에, 입력된 라벨 0은 [1,0,0,0,0], 라벨 1은 [0,1,0,0,0] 라벨 2는 [0,0,1,0,0] 식으로 변환되어야 한다. tf.one_hot 이라는 함수가 이 기능을 수행해준다.

 

모델 코드

모델은 앞서 설명했듯이 4개의 Convolutional 계층과, 2개의 Fully connected 계층 그리고 Dropout 계층을 사용한다. 각각의 계층별로는 코드가 다르지 않고 인지만 다르니 하나씩 만 설명하도록 한다.

 

Convolutional 계층

아래 코드는 두번째 Convolutional 계층의 코드이다.

  • FLAGS.conv2_layer_size 는 이 Convolutional 계층의 뉴런의 수로 32개를 사용한다.

  • FLAGS.conv2_filter_size 는 필터 사이즈를 지정하는데, 3x3 을 사용한다.

  • FLAGS.stride2 = 1 는 필터의 이동 속도로 한칸씩 이동하도록 정의했다.

 

# convolutional network layer 2

def conv2(input_data):

   FLAGS.conv2_filter_size = 3

   FLAGS.conv2_layer_size = 32

   FLAGS.stride2 = 1

   

   with tf.name_scope('conv_2'):

       W_conv2 = tf.Variable(tf.truncated_normal(

                       [FLAGS.conv2_filter_size,FLAGS.conv2_filter_size,FLAGS.conv1_layer_size,FLAGS.conv2_layer_size],

                                             stddev=0.1))

       b2 = tf.Variable(tf.truncated_normal(

                       [FLAGS.conv2_layer_size],stddev=0.1))

       h_conv2 = tf.nn.conv2d(input_data,W_conv2,strides=[1,1,1,1],padding='SAME')

       h_conv2_relu = tf.nn.relu(tf.add(h_conv2,b2))

       h_conv2_maxpool = tf.nn.max_pool(h_conv2_relu

                                       ,ksize=[1,2,2,1]

                                       ,strides=[1,2,2,1],padding='SAME')

       

       

   return h_conv2_maxpool

 

다음 Weight 값 W_conv2 와 Bias 값 b2를 지정한후에, 간단하게 tf.nn.conv2d 함수를 이용하면 2차원의 Convolutional 네트워크를 정의해준다. 다음 결과가 나오면 이 결과를 액티베이션 함수인 relu 함수에 넣은 후에, 마지막으로 max pooling 을 이용하여 결과를 뽑아낸다.

 

각 값의 의미에 대해서는 http://bcho.tistory.com/1149 의 컨볼루셔널 네트워크 개념 글을 참고하기 바란다.

같은 방법으로 총 4개의 Convolutional 계층을 중첩한다.

 

Fully Connected 계층

앞서 정의한 4개의 Convolutional 계층을 통과하면 다음 두개의 Fully Connected 계층을 통과하게 되는데 모양은 다음과 같다.

  • FLAGS.fc1_layer_size = 512 를 통하여 Fully connected 계층의 뉴런 수를 512개로 지정하였다.

 

# fully connected layer 1

def fc1(input_data):

   input_layer_size = 6*6*FLAGS.conv4_layer_size

   FLAGS.fc1_layer_size = 512

   

   with tf.name_scope('fc_1'):

       # 앞에서 입력받은 다차원 텐서를 fcc에 넣기 위해서 1차원으로 피는 작업

       input_data_reshape = tf.reshape(input_data, [-1, input_layer_size])

       W_fc1 = tf.Variable(tf.truncated_normal([input_layer_size,FLAGS.fc1_layer_size],stddev=0.1))

       b_fc1 = tf.Variable(tf.truncated_normal(

                       [FLAGS.fc1_layer_size],stddev=0.1))

       h_fc1 = tf.add(tf.matmul(input_data_reshape,W_fc1) , b_fc1) # h_fc1 = input_data*W_fc1 + b_fc1

       h_fc1_relu = tf.nn.relu(h_fc1)

   

   return h_fc1_relu

 

Fully connected 계층은 단순하게 relu(W*x + b) 함수이기 때문에 이 함수를 위와 같이 그대로 적용하였다.

마지막 계층

Fully connected 계층을 거쳐 나온 데이타는 Dropout 계층을 거친후에, 5개의 카테고리에 대한 확률로 결과를 내기 위해서 final_out 계층을 거치게 되는데, 이 과정에서 softmax 함수를 사용해야 하나, 학습 과정에서는 별도로 softmax 함수를 사용하지 않는다. softmax는 나온 결과의 합이 1.0이 되도록 값을 변환해주는 것인데, 학습 과정에서는 5개의 결과 값이 어떤 값이 나오던 가장 큰 값에 해당하는 것이 예측된 값이기 때문에, 그 값과 입력된 라벨을 비교하면 되기 때문이다.

즉 예를 들어 Jessica Alba일 확률이 100%면 실제 예측에서는 [1,0,0,0,0] 식으로 결과가 나와야 되지만, 학습 중는 Jessica Alaba 로 예측이 되었다고만 알면 되기 때문에 결과가 [1292,-0.221,-0.221,-0.221] 식으로 나오더라도 최대값만 찾으면 되기 때문에 별도로 softmax 함수를 적용할 필요가 없다. Softmax 함수는 연산 비용이 큰 함수이기 때문에 일반적으로 학습 단계에서는 적용하지 않는다.

 

마지막 계층의 코드는 다음과 같다.

# final layer

def final_out(input_data):

 

   with tf.name_scope('final_out'):

       W_fo = tf.Variable(tf.truncated_normal([FLAGS.fc2_layer_size,FLAGS.num_classes],stddev=0.1))

       b_fo = tf.Variable(tf.truncated_normal(

                       [FLAGS.num_classes],stddev=0.1))

       h_fo = tf.add(tf.matmul(input_data,W_fo) , b_fo) # h_fc1 = input_data*W_fc1 + b_fc1

       

   # 최종 레이어에 softmax 함수는 적용하지 않았다.

       

   return h_fo

전체 네트워크 모델 정의

이제 각 CNN의 각 계층을 함수로 정의 하였으면 각 계층을 묶어 보도록 하자. 묶는 법은 간단하다 앞 계층에서 나온 계층을 순서대로 배열하고 앞에서 나온 결과를 뒤의 계층에 넣는 식으로 묶으면 된다.

 

# build cnn_graph

def build_model(images,keep_prob):

   # define CNN network graph

   # output shape will be (*,48,48,16)

   r_cnn1 = conv1(images) # convolutional layer 1

   print ("shape after cnn1 ",r_cnn1.get_shape())

   

   # output shape will be (*,24,24,32)

   r_cnn2 = conv2(r_cnn1) # convolutional layer 2

   print ("shape after cnn2 :",r_cnn2.get_shape() )

   

   # output shape will be (*,12,12,64)

   r_cnn3 = conv3(r_cnn2) # convolutional layer 3

   print ("shape after cnn3 :",r_cnn3.get_shape() )

 

   # output shape will be (*,6,6,128)

   r_cnn4 = conv4(r_cnn3) # convolutional layer 4

   print ("shape after cnn4 :",r_cnn4.get_shape() )

   

   # fully connected layer 1

   r_fc1 = fc1(r_cnn4)

   print ("shape after fc1 :",r_fc1.get_shape() )

 

   # fully connected layer2

   r_fc2 = fc2(r_fc1)

   print ("shape after fc2 :",r_fc2.get_shape() )

   

   ## drop out

   # 참고 http://stackoverflow.com/questions/34597316/why-input-is-scaled-in-tf-nn-dropout-in-tensorflow

   # 트레이닝시에는 keep_prob < 1.0 , Test 시에는 1.0으로 한다.

   r_dropout = tf.nn.dropout(r_fc2,keep_prob)

   print ("shape after dropout :",r_dropout.get_shape() )

   

   # final layer

   r_out = final_out(r_dropout)

   print ("shape after final layer :",r_out.get_shape() )

 

   return r_out

 

이 build_model 함수는 image 를 입력 값으로 받아서 어떤 카테고리에 속할지를 리턴하는 컨볼루셔널 네트워크이다.  중간에 Dropout 계층이 추가되어 있는데, tf.nn.dropout함수를 이용하면 간단하게 dropout 계층을 구현할 수 있다. r_fc2는 Dropout 계층 앞의 Fully Connected 계층에서 나온 값이고,  두번째 인자로 남긴 keep_prob는 Dropout 비율이다.

 

   r_dropout = tf.nn.dropout(r_fc2,keep_prob)

   print ("shape after dropout :",r_dropout.get_shape() )

 

모델 학습

데이타를 읽는 부분과 학습용 모델 정의가 끝났으면 실제로 학습을 시켜보자

 

def main(argv=None):

   

   # define placeholders for image data & label for traning dataset

   

   images = tf.placeholder(tf.float32,[None,FLAGS.image_size,FLAGS.image_size,FLAGS.image_color])

   labels = tf.placeholder(tf.int32,[None,FLAGS.num_classes])

   image_batch,label_batch,file_batch = read_data_batch(TRAINING_FILE)

 

먼저 학습용 모델에 넣기 위한 image 데이타를 읽어드릴 placeholder를 images로 정의하고, 다음으로 모델에 의해 계산된 결과와 비교하기 위해서 학습데이타에서 읽어드린 label 데이타를 저장하기 위한 placeholder를 labels로 정의한다. 다음 image_batch,label_batch,fle_batch 변수에 배치로 학습용 데이타를 읽어드린다. 그리고 dropout 계층에서 dropout 비율을 지정할 keep_prob를 place holder로 정의한다.

각 변수가 지정되었으면, build_model 함수를 호출하여, images 값과 keep_prob 값을 넘겨서 Convolutional 네트워크에 값을 넣도록 그래프를 정의하고 그 결과 값을 prediction으로 정의한다.

 

   keep_prob = tf.placeholder(tf.float32) # dropout ratio

   prediction = build_model(images,keep_prob)

   # define loss function

   loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=labels))

   tf.summary.scalar('loss',loss)

 

   #define optimizer

   optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)

   train = optimizer.minimize(loss)

 

중간 중간에 학습 과정을 시각화 하기 위해서 tf.summary.scalar 함수를 이용하여 loss 값을 저장하였다.

 

그래프 생성이 완료 되었으면, 학습에서 계산할 비용 함수를 정의한다. 비용함수는 sofrmax cross entopy 함수를 이용하여, 모델에 의해서 예측된 값 prediction 과, 학습 파일에서 읽어드린 label 값을 비교하여 loss 값에 저장한다.

그리고 이 비용 최적화 함수를 위해서 옵티마이져를 AdamOptimizer를 정의하여, loss 값을 최적화 하도록 하였다.

 

학습용 모델 정의와, 비용 함수, 옵티마이저 정의가 끝났으면 학습 중간 중간 학습된 모델을 테스트하기 위한 Validation 관련 항목등을 정의한다.

 

   # for validation

   #with tf.name_scope("prediction"):

   validate_image_batch,validate_label_batch,validate_file_batch = read_data_batch(VALIDATION_FILE)

   label_max = tf.argmax(labels,1)

   pre_max = tf.argmax(prediction,1)

   correct_pred = tf.equal(tf.argmax(prediction,1),tf.argmax(labels,1))

   accuracy = tf.reduce_mean(tf.cast(correct_pred,tf.float32))

           

   tf.summary.scalar('accuracy',accuracy)

      

   startTime = datetime.now()

 

학습용 데이타가 아니라 검증용 데이타를 VALIDATION_FILE에서 읽어서 데이타를 validate_image_batch,validate_label_batch,validate_file_batch에 저장한다. 다음, 정확도 체크를 위해서 학습에서 예측된 라벨값과, 학습 데이타용 라벨값을 비교하여 같은지 틀린지를 비교하고, 이를 가지고 평균을 내서 정확도 (accuracy)로 사용한다.

 

학습용 모델과, 테스트용 데이타 등이 준비되었으면 이제 학습을 시작한다.

학습을 시직하기 전에, 학습된 모델을 저장하기 위해서 tf.train.Saver()를 지정한다. 그리고, 그래프로 loss와 accuracy등을 저장하기 위해서 Summary write를 저장한다.

다음 tf.global_variable_initializer()를 수행하여 변수를 초기화 하고, queue에서 데이타를 읽기 위해서 tf.train.Corrdinator를 선언하고 tf.start_queue_runners를 지정하여, queue 러너를 실행한다.

 

   #build the summary tensor based on the tF collection of Summaries

   summary = tf.summary.merge_all()

   

   with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess:

       saver = tf.train.Saver() # create saver to store training model into file

       summary_writer = tf.summary.FileWriter(FLAGS.log_dir,sess.graph)

       

       init_op = tf.global_variables_initializer() # use this for tensorflow 0.12rc0

       coord = tf.train.Coordinator()

       threads = tf.train.start_queue_runners(sess=sess, coord=coord)

       sess.run(init_op)

 

변수 초기화와 세션이 준비되었기 때문에 이제 학습을 시작해보자. for 루프를 이용하여 총 10,000 스텝의 학습을 하도록 하였다.

 

       for i in range(10000):

           images_,labels_ = sess.run([image_batch,label_batch])

 

다음 image_batch와 label_batch에서 값을 읽어서 앞에서 정의한 모델에 넣고 train 그래프 (AdamOptimizer를 정의한)를 실행한다.

 

           sess.run(train,feed_dict={images:images_,labels:labels_,keep_prob:0.7})

 

이때 앞에서 읽은 images_와, labels_ 데이타를 피딩하고 keep_prob 값을 0.7로 하여 30% 정도의 값을 Dropout 시킨다.

 

다음 10 스텝 마다 학습 상태를 체크하도록 하였다.

           

           if i % 10 == 0:

               now = datetime.now()-startTime

               print('## time:',now,' steps:',i)         

               

               # print out training status

               rt = sess.run([label_max,pre_max,loss,accuracy],feed_dict={images:images_

                                                         , labels:labels_

                                                         , keep_prob:1.0})

               print ('Prediction loss:',rt[2],' accuracy:',rt[3])

위와 같이 loss 값과 accuracy 값을 받아서 출력하여 현재 모델의 비용 함수 값과 정확도를 측정하고

 

               # validation steps

               validate_images_,validate_labels_ = sess.run([validate_image_batch,validate_label_batch])

               rv = sess.run([label_max,pre_max,loss,accuracy],feed_dict={images:validate_images_

                                                         , labels:validate_labels_

                                                         , keep_prob:1.0})

               print ('Validation loss:',rv[2],' accuracy:',rv[3])

학습용 데이타가 아니라 위와 같이 테스트용 데이타를 피딩하여, 테스트용 데이타로 정확도를 검증한다. 이때 keep_prob를 1.0으로 해서 Dropout 없이 100% 네트워크를 활용한다.

 

               if(rv[3] > 0.9):

                   Break

 

만약에 테스트 정확도가 90% 이상이면 학습을 멈춘다. 그리고 아래와 같이 Summary

 

               # validation accuracy

               summary_str = sess.run(summary,feed_dict={images:validate_images_

                                                         , labels:validate_labels_

                                                         , keep_prob:1.0})

 

               summary_writer.add_summary(summary_str,i)

               summary_writer.flush()

 

마지막으로 다음과 같이 학습이 다된 모델을 saver.save를 이용하여 저장하고, 사용된 리소스들을 정리한다.

       saver.save(sess, 'face_recog') # save session

       coord.request_stop()

       coord.join(threads)

       print('finish')

   

main()

 

이렇게 학습을 끝내면 본인의 경우 약 7000 스텝에서 테스트 정확도 91%로 끝난것을 확인할 수 있다.

 

아래는 텐서보드를 이용하여 학습 과정을 시각화한 내용이다.

 


 

코드는 공개가 가능하지만 학습에 사용한 데이타는 저작권 문제로 공유가 불가능하다. 약 200장의 사진만 제대로 수집을 하면 되기 때문에 각자 수집을 해서 학습을 도전해보는 것을 권장한다. (더 많은 인물에 대한 시도를 해보는것도 좋겠다.)

정리 하며

혹시나 이 튜토리얼을 따라하면서 학습 데이타를 공개할 수 있는 분들이 있다면 다른 분들에게도 많은 도움이 될것이라고 생각한다. 가능하면 데이타가 공개되었으면 좋겠다.

전체 코드는 https://github.com/bwcho75/facerecognition/blob/master/1.%2BFace%2BRecognition%2BTraining.ipynb 에 있다.

그리고 직접 사진을 수집해보면, 데이타 수집 및 가공이 얼마나 어려운지 알 수 있기 때문에 직접 한번 시도해보는 것도 권장한다. 아래는 크롬브라우져 플러그인으로 구글 검색에서 나온 이미지를 싹 긁을 수 있는 플러그인이다. Bulk Download Images (ZIG)

https://www.youtube.com/watch?v=k5ioaelzEBM

 



이 플러그인을 이용하면 손쉽게 특정 인물의 데이타를 수집할 수 있다.

다음 글에서는 학습이 끝난 데이타를 이용해서 실제로 예측을 해보는 부분에 대해서 소개하도록 하겠다.

 

 

 

텐서플로우 - 파일에서 학습데이타를 읽어보자#1


조대협 (http://bcho.tistory.com)


텐서플로우를 학습하면서 실제 모델을 만들어보려고 하니 생각보다 데이타 처리에 대한 부분에서 많은 노하우가 필요하다는 것을 알게되었다. MNIST와 같은 예제는 데이타가 다 이쁘게 정리되어서 학습 하기 좋은 형태로 되어 있지만, 실제로 내 모델을 만들고 학습을 하기 위해서는 데이타에 대한 정재와 분류 작업등이 많이 필요하다.


이번글에서는 학습에 필요한 데이타를 파일에서 읽을때 필요한 큐에 대한 개념에 대해서 알아보도록 한다.


피딩 (Feeding) 개념 복습


텐서플로우에서 모델을 학습 시킬때, 학습 데이타를 모델에 적용하는 방법은 일반적으로 피딩 (feeding)이라는 방법을 사용한다. 메모리상의 어떤 변수 리스트 형태로 값을 저장한 후에, 모델을 세션에서 실행할 때, 리스트에서 값을 하나씩 읽어서 모델에 집어 넣는 방식이다.



위의 그림을 보면, y=W*x라는 모델에서 학습 데이타 x는 [1,2,3,4,5]로, 첫번째 학습에는 1, 두번째 학습에는 2를 적용하는 식으로 피딩이 된다.

그런데, 이렇게 피딩을 하려면, 학습 데이타 [1,2,3,4,5]가 메모리에 모두 적재되어야 하는데, 실제로 모델을 만들어서 학습을할때는 데이타의 양이 많기 때문에 메모리에 모두 적재하고 학습을 할 수 가 없고, 파일에서 읽어드리면서 학습을 해야 한다.


텐서플로우 큐에 대해서

이러한 문제를 해결하기 위해서는 파일에서 데이타를 읽어가면서, 읽은 데이타를 순차적으로 모델에 피딩하면 되는데, 이때 큐를 사용한다.


파일에서 데이타를 읽는 방법에 앞서서 큐를 설명하면, 큐에 데이타를 넣는 것(Enqueue) 은 Queue Runner 라는 것이 한다.

이 Queue Runner가 큐에 어떤 데이타를 어떻게 넣을지를 정의 하는 것이 Enqueue_operation인데, 데이타를 읽어서 실제로 어떻게 Queue에 Enqueue 하는지를 정의한다.


이 Queue Runner는 멀티 쓰레드로 작동하는데, Queue Runner 안의 쓰레드들을 관리해주기 위해서 별도로 Coordinator라는 것을 사용한다.


이 개념을 정리해서 도식화 해주면 다음과 같다.


=


Queue Runner 는 여러개의 쓰레드 (T)를 가지고 있고, 이 쓰레드들은 Coordinator들에 의해서 관리된다. Queue Runner 가 Queue에 데이타를 넣을때는 Enqueue_op이라는 operation에 의해 정의된 데로 데이타를 Queue에 집어 넣는다.


위의 개념을 코드로 구현해보자


import tensorflow as tf


QUEUE_LENGTH = 20

q = tf.FIFOQueue(QUEUE_LENGTH,"float")

enq_ops = q.enqueue_many(([1.0,2.0,3.0,4.0],) )

qr = tf.train.QueueRunner(q,[enq_ops,enq_ops,enq_ops])


sess = tf.Session()

# Create a coordinator, launch the queue runner threads.

coord = tf.train.Coordinator()

threads = qr.create_threads(sess, coord=coord, start=True)


for step in xrange(20):

   print(sess.run(q.dequeue()))


coord.request_stop()

coord.join(threads)


sess.close()


Queue 생성

tf.FIFOQUEUE를 이용해서 큐를 생성한다.

q = tf.FIFOQueue(QUEUE_LENGTH,"float")

첫번째 인자는 큐의 길이를 정하고, 두번째는 dtype으로 큐에 들어갈 데이타형을 지정한다.

Queue Runner 생성

다음은 Queue Runner를 만들기 위해서 enqueue_operation 과, QueueRunner를 생성한다.

enq_ops = q.enqueue_many(([1.0,2.0,3.0,4.0],) )

qr = tf.train.QueueRunner(q,[enq_ops,enq_ops,enq_ops])

enqueue operation인 enq_ops는 위와 같이 한번에 [1.0,2.0,3.0,4.0] 을 큐에 넣는 operation으로 지정한다.

그리고 Queue Runner를 정의하는데, 앞에 만든 큐에 데이타를 넣을것이기 때문에 인자로 큐 ‘q’를 넘기고 list 형태로 enq_ops를 3개를 넘긴다. 3개를 넘기는 이유는 Queue Runner가 멀티쓰레드 기반이기 때문에 각 쓰레드에서 Enqueue시 사용할 Operation을 넘기는 것으로, 3개를 넘긴것은 3개의 쓰레드에 Enqueue 함수를 각각 지정한 것이다.

만약 동일한 enqueue operation을 여러개의 쓰레드로 넘길 경우 위 코드처럼 일일이 enqueue operation을 쓸 필요 없이

qr = tf.train.QueueRunner(q,[enq_ops]*NUM_OF_THREAD)

[enq_ops] 에 쓰레드 수 (NUM_OF_THREAD)를 곱해주면 된다.

Coordinator 생성

이제 Queue Runner에서 사용할 쓰레드들을 관리할 Coordinator를 생성하자

coord = tf.train.Coordinator()

Queue Runner용 쓰레드 생성

Queue Runner와 쓰레드를 관리할 Coordinator 가 생성되었으면, Queue Runner에서 사용할 쓰레드들을 생성하자

threads = qr.create_threads(sess, coord=coord, start=True)

생성시에는 세션과, Coordinator를 지정하고, start=True로 해준다.

start=True로 설정하지 않으면, 쓰레드가 생성은 되었지만, 동작을 하지 않기 때문에, 큐에 메세지를 넣지 않는다.

큐 사용

이제 큐에서 데이타를 꺼내와 보자. 아래코드는 큐에서 20번 데이타를 꺼내와서 출력하는 코드이다.

for step in xrange(20):

   print(sess.run(q.dequeue()))


큐가 비워지면, QueueRunner를 이용하여 계속해서 데이타를 채워 넣는다. 즉 큐가 비기전에 계속해서 [1.0,2.0,3.0,4.0] 데이타가 큐에 계속 쌓인다.

쓰레드 정지

큐 사용이 끝났으면 Queue Runner의 쓰레드들을 모두 정지 시켜야 한다.

coord.request_stop()

을 이용하면 모든 쓰레드들을 정지 시킨다.

coord.join(threads)

는 다음 코드를 진행하기전에, Queue Runner의 모든 쓰레드들이 정지될때 까지 기다리는 코드이다.

멀티 쓰레드

Queue Runner가 멀티 쓰레드라고 하는데, 그렇다면 쓰레드들이 어떻게 데이타를 큐에 넣고 enqueue 연산은 어떻게 동작할까?

그래서, 간단한 테스트를 해봤다. 3개의 쓰레드를 만든 후에, 각 쓰레드에 따른 enqueue operation을 다르게 지정해봤다.

import tensorflow as tf


QUEUE_LENGTH = 20

q = tf.FIFOQueue(QUEUE_LENGTH,"float")

enq_ops1 = q.enqueue_many(([1.0,2.0,3.0],) )

enq_ops2 = q.enqueue_many(([4.0,5.0,6.0],) )

enq_ops3 = q.enqueue_many(([6.0,7.0,8.0],) )

qr = tf.train.QueueRunner(q,[enq_ops1,enq_ops2,enq_ops3])


sess = tf.Session()

# Create a coordinator, launch the queue runner threads.

coord = tf.train.Coordinator()

threads = qr.create_threads(sess, coord=coord, start=True)


for step in xrange(20):

   print(sess.run(q.dequeue()))


coord.request_stop()

coord.join(threads)


sess.close()


실행을 했더니, 다음과 같은 결과를 얻었다.


첫번째 실행 결과

1.0

2.0

3.0

4.0

5.0

6.0

6.0

7.0

8.0



두번째 실행결과

1.0

2.0

3.0

1.0

2.0

3.0

4.0

5.0

6.0


결과에서 보는것과 같이 Queue Runner의 3개의 쓰레드중 하나가 무작위로 (순서에 상관없이) 실행되서 데이타가 들어가는 것을 볼 수 있었다.


파일에서 데이타 읽기


자 그러면 이 큐를 이용해서, 파일 목록을 읽고, 파일을 열어서 학습 데이타를 추출해서 학습 파이프라인에 데이타를 넣어주면 된다.

텐서 플로우에서는 파일에서 데이타를 읽는 처리를 위해서 앞에서 설명한 큐 뿐만 아니라 Reader와 Decoder와 같은 부가적인 기능을 제공한다.


  1. 파일 목록을 읽는다.

  2. 읽은 파일목록을 filename queue에 저장한다.

  3. Reader 가 finename queue 에서 파일명을 하나씩 읽어온다.

  4. Decoder에서 해당 파일을 열어서 데이타를 읽어들인다.

  5. 필요하면 읽어드린 데이타를 텐서플로우 모델에 맞게 정재한다. (이미지를 리사이즈 하거나, 칼라 사진을 흑백으로 바꾸거나 하는 등의 작업)

  6. 텐서 플로우에 맞게 정재된 학습 데이타를 학습 데이타 큐인 Example Queue에 저장한다.

  7. 모델에서 Example Queue로 부터 학습 데이타를 읽어서 학습을 한다.


먼저 파일 목록을 읽는 부분은 파일 목록을 읽어서 각 파일명을  큐에 넣은 부분을 살펴보자.

다음 예제코드는 파일명 목록을 받은 후에, filename queue에 파일명을 넣은후에, 파일명을 하나씩 꺼내는 예제이다.

import tensorflow as tf


filename_queue = tf.train.string_input_producer(["1","2","3"],shuffle=False)


with tf.Session() as sess:

   

   coord = tf.train.Coordinator()

   threads = tf.train.start_queue_runners(coord=coord,sess=sess)

   

   for step in xrange(10):

       print(sess.run(filename_queue.dequeue()) )


   coord.request_stop()

   coord.join(threads)


코드를 보면 큐 생성이나, enqueue operation 처리들이 다소 다른것을 볼 수 있는데, 이는 텐서플로우에서는  학습용 파일 목록을 편리하게 처리 하기 위해서 조금 더 추상화된 함수들을 제공하기 때문이다.


filename_queue = tf.train.string_input_producer(["1","2","3"],shuffle=False)


train.xx_input_producer() 함수는 입력 받은 큐를 만드는 역할을 한다.

위의 명령을 수행하면, filename queue 가 FIFO (First In First Out)형태로 생긴다.


큐가 생기기는 하지만, 실제로 큐에 파일명이 들어가지는 않는다. (아직 Queue Runner와 쓰레드들을 생성하지 않았기 때문에)

다음으로 쓰레드를 관리하기 위한 Coordinator 를 생성한다.

   coord = tf.train.Coordinator()

Coordinator 가 생성이 되었으면 Queue Runner와 Queue Runner에서 사용할 Thread들을 생성해주는데,  start_queue_runner 라는 함수로, 이 기능들을 모두 구현해놨다.

   threads = tf.train.start_queue_runners(coord=coord,sess=sess)

이 함수는 Queue Runner와, 쓰레드 생성 및 시작 뿐 만 아니라 Queue Runner 쓰레드가 사용하는 enqueue operation 까지 파일형태에 맞춰서 자동으로 생성 및 지정해준다.






Queue, Queue Runner, Coordinator와 Queue Runner가 사용할 쓰레드들이 생성되고 시작되었기 때문에,Queue Runner는 filename queue에 파일명을 enqueue 하기 시작한다.

파일명 Shuffling

위의 예제를 실행하면 파일명이 다음과 같이 1,2,3 이 순차적으로 반복되서 나오는 것을 볼 수 있다.

실행 결과

1

2

3

1

2

3

1

2

3

1


만약에 파일명을 랜덤하게 섞어서 나오게 하려면 어떻게해야 할까? (매번 학습시 학습데이타가 일정 패턴으로 몰려서 편향되지 않고, 랜덤하게 나와서 학습 효과를 높이고자 할때)

filename_queue = tf.train.string_input_producer(["1","2","3"],shuffle=False)

큐를 만들때, 다음과 같이 셔플 옵션을 True로 주면 된다.

filename_queue = tf.train.string_input_producer(["1","2","3"],shuffle=True)

실행 결과

2

1

3

2

3

1

2

3

1

1

지금까지 파일명을 지정해서 이 파일명들을 filename queue에 넣는 방법에 대해서 알아보았다.

다음은 이 file name queue에서 파일을 순차적으로 꺼내서

  • 파일을 읽어드리고

  • 각 파일을 파싱해서 학습 데이타를 만들고

  • 학습 데이타용 큐 (example queue)에 넣는 방법

에 대해서 설명하도록 한다.



딥러닝을 이용한 숫자 이미지 인식 #2/2


앞서 MNIST 데이타를 이용한 필기체 숫자를 인식하는 모델을 컨볼루셔널 네트워크 (CNN)을 이용하여 만들었다. 이번에는 이 모델을 이용해서 필기체 숫자 이미지를 인식하는 코드를 만들어 보자


조금 더 테스트를 쉽게 하기 위해서, 파이썬 주피터 노트북내에서 HTML 을 이용하여 마우스로 숫자를 그릴 수 있도록 하고, 그려진 이미지를 어떤 숫자인지 인식하도록 만들어 보겠다.



모델 로딩

먼저 앞의 예제에서 학습을한 모델을 로딩해보도록 하자.

이 코드는 주피터 노트북에서 작성할때, 모델을 학습 시키는 코드 (http://bcho.tistory.com/1156) 와 별도의 새노트북에서 구현을 하도록 한다.


코드

import tensorflow as tf

import numpy as np

import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data


#이미 그래프가 있을 경우 중복이 될 수 있기 때문에, 기존 그래프를 모두 리셋한다.

tf.reset_default_graph()


num_filters1 = 32


x = tf.placeholder(tf.float32, [None, 784])

x_image = tf.reshape(x, [-1,28,28,1])


#  layer 1

W_conv1 = tf.Variable(tf.truncated_normal([5,5,1,num_filters1],

                                         stddev=0.1))

h_conv1 = tf.nn.conv2d(x_image, W_conv1,

                      strides=[1,1,1,1], padding='SAME')


b_conv1 = tf.Variable(tf.constant(0.1, shape=[num_filters1]))

h_conv1_cutoff = tf.nn.relu(h_conv1 + b_conv1)


h_pool1 =tf.nn.max_pool(h_conv1_cutoff, ksize=[1,2,2,1],

                       strides=[1,2,2,1], padding='SAME')


num_filters2 = 64


# layer 2

W_conv2 = tf.Variable(

           tf.truncated_normal([5,5,num_filters1,num_filters2],

                               stddev=0.1))

h_conv2 = tf.nn.conv2d(h_pool1, W_conv2,

                      strides=[1,1,1,1], padding='SAME')


b_conv2 = tf.Variable(tf.constant(0.1, shape=[num_filters2]))

h_conv2_cutoff = tf.nn.relu(h_conv2 + b_conv2)


h_pool2 =tf.nn.max_pool(h_conv2_cutoff, ksize=[1,2,2,1],

                       strides=[1,2,2,1], padding='SAME')


# fully connected layer

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*num_filters2])


num_units1 = 7*7*num_filters2

num_units2 = 1024


w2 = tf.Variable(tf.truncated_normal([num_units1, num_units2]))

b2 = tf.Variable(tf.constant(0.1, shape=[num_units2]))

hidden2 = tf.nn.relu(tf.matmul(h_pool2_flat, w2) + b2)


keep_prob = tf.placeholder(tf.float32)

hidden2_drop = tf.nn.dropout(hidden2, keep_prob)


w0 = tf.Variable(tf.zeros([num_units2, 10]))

b0 = tf.Variable(tf.zeros([10]))

k = tf.matmul(hidden2_drop, w0) + b0

p = tf.nn.softmax(k)


# prepare session

sess = tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()

saver.restore(sess, '/Users/terrycho/anaconda/work/cnn_session')


print 'reload has been done'


그래프 구현

코드를 살펴보면, #prepare session 부분 전까지는 이전 코드에서의 그래프를 정의하는 부분과 동일하다. 이 코드는 우리가 만든 컨볼루셔널 네트워크를 복원하는 부분이다.


변수 데이타 로딩

그래프의 복원이 끝나면, 저장한 세션의 값을 다시 로딩해서 학습된 W와 b값들을 다시 로딩한다.


# prepare session

sess = tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()

saver.restore(sess, '/Users/terrycho/anaconda/work/cnn_session')


이때 saver.restore 부분에서 앞의 예제에서 저장한 세션의 이름을 지정해준다.

HTML을 이용한 숫자 입력

그래프와 모델 복원이 끝났으면 이 모델을 이용하여, 숫자를 인식해본다.

테스트하기 편리하게 HTML로 마우스로 숫자를 그릴 수 있는 화면을 만들어보겠다.

주피터 노트북에서 새로운 Cell에 아래와 같은 내용을 입력한다.


코드

input_form = """

<table>

<td style="border-style: none;">

<div style="border: solid 2px #666; width: 143px; height: 144px;">

<canvas width="140" height="140"></canvas>

</div></td>

<td style="border-style: none;">

<button onclick="clear_value()">Clear</button>

</td>

</table>

"""


javascript = """

<script type="text/Javascript">

   var pixels = [];

   for (var i = 0; i < 28*28; i++) pixels[i] = 0

   var click = 0;


   var canvas = document.querySelector("canvas");

   canvas.addEventListener("mousemove", function(e){

       if (e.buttons == 1) {

           click = 1;

           canvas.getContext("2d").fillStyle = "rgb(0,0,0)";

           canvas.getContext("2d").fillRect(e.offsetX, e.offsetY, 8, 8);

           x = Math.floor(e.offsetY * 0.2)

           y = Math.floor(e.offsetX * 0.2) + 1

           for (var dy = 0; dy < 2; dy++){

               for (var dx = 0; dx < 2; dx++){

                   if ((x + dx < 28) && (y + dy < 28)){

                       pixels[(y+dy)+(x+dx)*28] = 1

                   }

               }

           }

       } else {

           if (click == 1) set_value()

           click = 0;

       }

   });

   

   function set_value(){

       var result = ""

       for (var i = 0; i < 28*28; i++) result += pixels[i] + ","

       var kernel = IPython.notebook.kernel;

       kernel.execute("image = [" + result + "]");

   }

   

   function clear_value(){

       canvas.getContext("2d").fillStyle = "rgb(255,255,255)";

       canvas.getContext("2d").fillRect(0, 0, 140, 140);

       for (var i = 0; i < 28*28; i++) pixels[i] = 0

   }

</script>

"""


다음 새로운 셀에서, 다음 코드를 입력하여, 앞서 코딩한 HTML 파일을 실행할 수 있도록 한다.


from IPython.display import HTML

HTML(input_form + javascript)


이제 앞에서 만든 두 셀을 실행시켜 보면 다음과 같이 HTML 기반으로 마우스를 이용하여 숫자를 입력할 수 있는 박스가 나오는것을 확인할 수 있다.



입력값 판정

앞의 HTML에서 그린 이미지는 앞의 코드의 set_value라는 함수에 의해서, image 라는 변수로 784 크기의 벡터에 저장된다. 이 값을 이용하여, 이 그림이 어떤 숫자인지를 앞서 만든 모델을 이용해서 예측을 해본다.


코드


p_val = sess.run(p, feed_dict={x:[image], keep_prob:1.0})


fig = plt.figure(figsize=(4,2))

pred = p_val[0]

subplot = fig.add_subplot(1,1,1)

subplot.set_xticks(range(10))

subplot.set_xlim(-0.5,9.5)

subplot.set_ylim(0,1)

subplot.bar(range(10), pred, align='center')

plt.show()

예측

예측을 하는 방법은 쉽다. 이미지 데이타가 image 라는 변수에 들어가 있기 때문에, 어떤 숫자인지에 대한 확률을 나타내는 p 의 값을 구하면 된다.


p_val = sess.run(p, feed_dict={x:[image], keep_prob:1.0})


를 이용하여 x에 image를 넣고, 그리고 dropout 비율을 0%로 하기 위해서 keep_prob를 1.0 (100%)로 한다. (예측이기 때문에 당연히 dropout은 필요하지 않다.)

이렇게 하면 이 이미지가 어떤 숫자인지에 대한 확률이 p에 저장된다.

그래프로 표현

그러면 이 p의 값을 찍어 보자


fig = plt.figure(figsize=(4,2))

pred = p_val[0]

subplot = fig.add_subplot(1,1,1)

subplot.set_xticks(range(10))

subplot.set_xlim(-0.5,9.5)

subplot.set_ylim(0,1)

subplot.bar(range(10), pred, align='center')

plt.show()


그래프를 이용하여 0~9 까지의 숫자 (가로축)일 확률을 0.0~1.0 까지 (세로축)으로 출력하게 된다.

다음은 위에서 입력한 숫자 “4”를 인식한 결과이다.



(보너스) 첫번째 컨볼루셔널 계층 결과 출력

컨볼루셔널 네트워크를 학습시키다 보면 종종 컨볼루셔널 계층을 통과하여 추출된 특징 이미지들이 어떤 모양을 가지고 있는지를 확인하고 싶을때가 있다. 그래서 각 필터를 통과한 값을 이미지로 출력하여 확인하고는 하는데, 여기서는 이렇게 각 필터를 통과하여 인식된 특징이 어떤 모양인지를 출력하는 방법을 소개한다.


아래는 우리가 만든 네트워크 중에서 첫번째 컨볼루셔널 필터를 통과한 결과 h_conv1과, 그리고 이 결과에 bias 값을 더하고 활성화 함수인 Relu를 적용한 결과를 출력하는 예제이다.


코드


conv1_vals, cutoff1_vals = sess.run(

   [h_conv1, h_conv1_cutoff], feed_dict={x:[image], keep_prob:1.0})


fig = plt.figure(figsize=(16,4))


for f in range(num_filters1):

   subplot = fig.add_subplot(4, 16, f+1)

   subplot.set_xticks([])

   subplot.set_yticks([])

   subplot.imshow(conv1_vals[0,:,:,f],

                  cmap=plt.cm.gray_r, interpolation='nearest')

plt.show()


x에 image를 입력하고, dropout을 없이 모든 네트워크를 통과하도록 keep_prob:1.0으로 주고, 첫번째 컨볼루셔널 필터를 통과한 값 h_conv1 과, 이 값에 bias와 Relu를 적용한 값 h_conv1_cutoff를 계산하였다.

conv1_vals, cutoff1_vals = sess.run(

   [h_conv1, h_conv1_cutoff], feed_dict={x:[image], keep_prob:1.0})


첫번째 필터는 총 32개로 구성되어 있기 때문에, 32개의 결과값을 imshow 함수를 이용하여 흑백으로 출력하였다.




다음은 bias와 Relu를 통과한 값인 h_conv_cutoff를 출력하는 예제이다. 위의 코드와 동일하며 subplot.imgshow에서 전달해주는 인자만 conv1_vals → cutoff1_vals로 변경되었다.


코드


fig = plt.figure(figsize=(16,4))


for f in range(num_filters1):

   subplot = fig.add_subplot(4, 16, f+1)

   subplot.set_xticks([])

   subplot.set_yticks([])

   subplot.imshow(cutoff1_vals[0,:,:,f],

                  cmap=plt.cm.gray_r, interpolation='nearest')

   

plt.show()


출력 결과는 다음과 같다



이제까지 컨볼루셔널 네트워크를 이용한 이미지 인식을 텐서플로우로 구현하는 방법을 MNIST(필기체 숫자 데이타)를 이용하여 구현하였다.


실제로 이미지를 인식하려면 전체적인 흐름은 같지만, 이미지를 전/후처리 해내야 하고 또한 한대의 머신이 아닌 여러대의 머신과 GPU와 같은 하드웨어 장비를 사용한다. 다음 글에서는 MNIST가 아니라 실제 칼라 이미지를 인식하는 방법에 대해서 데이타 전처리에서 부터 서비스까지 전체 과정에 대해서 설명하도록 하겠다.


예제 코드 : https://github.com/bwcho75/tensorflowML/blob/master/MNIST_CNN_Prediction.ipynb


파이어베이스 애널러틱스를 이용한 모바일 데이타 분석 #1-Hello Firebase

조대협 (http://bcho.tistory.com)


얼마전에 구글은 모바일 백앤드 플랫폼인 파이어베이스를 인수하고 이를 서비스로 공개하였다.

파이어 베이스는 모바일 백앤드의 종합 솔루션으로, 크래쉬 리포팅, 리모트 컨피그를 이용한 A/B 테스팅 플랫폼, 클라우드와 자동 동기화가 가능한 리얼타임 데이타 베이스, 사용자 인증 기능, 강력한 푸쉬 플랫폼 다양한 모바일 기기에 대해서 테스트를 해볼 수 있는 테스트랩 등, 모바일 앱 개발에 필요한 모든 서비스를 제공해주는 종합 패키지와 같은 플랫폼이라고 보면 된다. 안드로이드 뿐만 아니라 iOS까지 지원하여 모든 모바일 앱 개발에 공통적으로 사용할 수 있다.



그중에서 파이어베이스 애널러틱스 (Firebase analytics)는 모바일 부분은 모바일 앱에 대한 모든 이벤트를 수집 및 분석하여 자동으로 대쉬 보드를 통하여 분석을 가능하게 해준다.


이 글에서는 파이어베이스 전체 제품군중에서 파이어베이스 애널러틱스에 대해서 수회에 걸쳐서 설명을 하고자 한다.


파이어베이스 애널러틱스

이미 시장에는 모바일 앱에 대한 데이타 분석이 가능한 유료 또는 무료 제품이 많다.

대표적으로 야후의 flurry, 트위터 fabric, 구글 애널러틱스등이 대표적인 제품군인데, 그렇다면 파이어베이스가 애널러틱스가 가지고 있는 장단점은 무엇인가?


퍼널 분석 및 코호트 분석 지원

파이어베이스 애널러틱스는 데이타 분석 방법중에 퍼넬 분석과 코호트 분석을 지원한다.

퍼널 분석은 한글로 깔데기 분석이라고 하는데, 예를 들어 사용자가 가입한 후에, 쇼핑몰의 상품 정보를 보고  주문 및 결재를 하는 단계 까지 각 단계별로 사용자가 이탈하게 된다. 이 구조를 그려보면 깔데기 모양이 되는데,사용자 가입에서 부터 최종 목표인 주문 결재까지 이루도록 단계별로 이탈율을 분석하여 서비스를 개선하고, 이탈율을 줄이는데 사용할 수 있다.

코호트 분석은 데이타를 집단으로 나누어서 분석하는 방법으로 일일 사용자 데이타 (DAU:Daily Active User)그래프가 있을때, 일일 사용자가 연령별로 어떻게 분포가 되는지등을 나눠서 분석하여 데이타를 조금 더 세밀하게 분석할 수 있는 방법이다.


이러한 코호트 분석과 퍼넬 분석은 모바일 데이타 분석 플랫폼 중에서 일부만 지원하는데, 파이어베이스 애널러틱스는 퍼넬과 코호트 분석을 기본적으로 제공하고 있으며, 특히 코호트 분석으로 많이 사용되는 사용자 잔존율 (Retention 분석)의 경우 별다른 설정 없이도 기본으로 제공하고 있다.


<그림. 구글 파이어베이스의 사용자 잔존율 코호트 분석 차트>

출처 : https://support.google.com/firebase/answer/6317510?hl=en

무제한 앱 및 무제한 사용자 무료 지원

이러한 모바일 서비스 분석 서비스의 경우 사용자 수나 수집할 수 있는 이벤트 수나 사용할 수 있는 앱수에 제약이 있는데, 파이어베이스 애널러틱스의 경우에는 제약이 없다.

빅쿼리 연계 지원

가장 강력한 기능중의 하나이자, 이 글에서 주로 다루고자 하는 내용이 빅쿼리 연동 지원이다.

모바일 데이타 분석 서비스 플랫폼의 경우 대 부분 플랫폼 서비스의 형태를 띄기 때문에, 분석 플랫폼에서 제공해주는 일부 데이타만 볼 수 가 있고, 원본 데이타에 접근하는 것이 대부분 불가능 하다.

그래서 모바일 애플리케이션 서버에서 생성된 데이타나, 또는 광고 플랫폼등 외부 연동 플랫폼에서 온 데이타에 대한 연관 분석이 불가능하고, 원본 데이타를 통하여 여러가지 지표를 분석하는 것이 불가능하다.


파이어베이스 애널러틱스의 경우에는 구글의 데이타 분석 플랫폼이 빅쿼리 연동을 통하여 모든 데이타를 빅쿼리에 저장하여 간단하게 분석이 가능하다.

구글 빅쿼리에 대한 소개는 http://bcho.tistory.com/1116 를 참고하기 바란다.

구글의 빅쿼리는 아마존 S3나, 구글의 스토리지 서비스인 GCS 보다 저렴한 비용으로 데이타를 저장하면서도, 수천억 레코드에 대한 연산을 수십초만에 8~9000개의 CPU와 3~4000개의 디스크를 사용해서 끝낼만큼 어마어마한 성능을 제공하면서도, 사용료 매우 저렴하며 기존 SQL 문법을 사용하기 때문에, 매우 쉽게 접근이 가능하다.

모바일 데이타 분석을 쉽게 구현이 가능

보통 모바일 서비스에 대한 데이타 분석을 할때는 무료 서비스를 통해서 DAU나 세션과 같은 기본적인 정보 수집은 가능하지만, 추가적인 이벤트를 수집하여 저장 및 분석을 하거나 서버나 다른 시스템의 지표를 통합 분석 하는 것은 별도의 로그 수집 시스템을 모바일 앱과 서버에 만들어야 하였고, 이를 분석 및 저장하고 리포팅 하기 위해서 하둡이나 스파크와 같은 복잡한 빅데이타 기술을 사용하고 리포팅에도 많은 시간이 소요 되었다.


파이어베이스 애널러틱스를 이용하면, 손 쉽게, 추가 이벤트나 로그 정보를 기존의 로깅 프레임웍을 통하여 빅쿼리에 저장할 수 있고, 복잡한 하둡이나 스파크의 설치나 프로그래밍 없이 빅쿼리에서 간략하게 SQL만을 사용하여 분석을 하고 오픈소스 시각화 도구인 Jupyter 노트북이나 구글의 데이타스튜디오 (http://datastudio.google.com)을 통하여 시작화가 간단하기 때문에, 이제는 누구나 쉽게 빅데이타 로그를 수집하고 분석할 수 있게 된다.

실시간 데이타 분석은 지원하지 않음

파이어베이스 애널러틱스가 그러면 만능 도구이고 좋은 기능만 있는가? 그건 아니다. 파이어베이스 애널러틱스는 아직까지는 실시간 데이타 분석을 지원하고 있지 않다. 수집된 데이타는 보통 수시간이 지나야 대쉬 보드에 반영이 되기 때문에 현재 접속자나, 실시간 모니터링에는 적절하지 않다.

그래서 보완을 위해서 다른 모니터링 도구와 혼용해서 사용하는 게 좋다. 실시간 분석이 강한 서비스로는 트위터 fabric이나 Google analytics 등이 있다.

이러한 도구를 이용하여 데이타에 대한 실시간 분석을 하고, 정밀 지표에 대한 분석을 파이어베이스 애널러틱스를 사용 하는 것이 좋다.


파이어베이스 애널러틱스 적용해보기

백문이 불여일견이라고, 파이어베이스 애널러틱스를 직접 적용해보자.

https://firebase.google.com/ 사이트로 가서, 가입을 한 후에, “콘솔로 이동하기"를 통해서 파이어 베이스 콘솔로 들어가자.

프로젝트 생성하기

다음으로 파이어베이스 프로젝트를 생성한다. 상단 메뉴에서 “CREATE NEW PROJECT”를 선택하면 새로운 파이어 베이스 프로젝트를 생성할 수 있다. 만약에 기존에 사용하던 구글 클라우드 프로젝트등이 있으면 별도의 프로젝트를 생성하지 않고 “IMPORT GOOGLE PROJECT”를 이용하여 기존의 프로젝트를 불러와서 연결할 수 있다.



프로젝트가 생성되었으면 파이어베이스를 사용하고자 하는 앱을 등록해야 한다.

파이어베이스 화면에서 “ADD APP” 이라는 버튼을 누르면 앱을 추가할 수 있다.

아래는 앱을 추가하는 화면중 첫번째 화면으로 앱에 대한 기본 정보를 넣는 화면이다.

“Package name” 에, 파이어베이스와 연동하고자 하는 안드로이드 앱의 패키지 명을 넣는다.


ADD APP 버튼을 누르고 다음 단계로 넘어가면 google-services.json 이라는 파일이 자동으로 다운된다. 이 파일은 나중에 안드로이드 앱의 소스에 추가해야 하기 때문에 잘 보관한다.


Continue 버튼을 누르면 아래와 같이 다음 단계로 넘어간다. 다음 단계에서는 안드로이드 앱을 개발할때 파이어베이스를 연동하려면 어떻게 해야 하는지에 대한 가이드가 나오는데, 이 부분은 나중에 코딩 부분에서 설명할 예정이니 넘어가도록 하자.


자 이제 파이어베이스 콘솔에서, 프로젝트를 생성하고 앱을 추가하였다.

이제 연동을 할 안드로이드 애플리케이션을 만들어보자.

안드로이드 빌드 환경 설정

콘솔에서 앱이 추가되었으니, 이제 코드를 작성해보자, 아래 예제는 안드로이드 스튜디오 2.1.2 버전 (맥 OS 기준) 으로 작성되었다.


먼저 안드로이드 프로젝트를 생성하였다. 이때 반드시 안드로이드 프로젝트에서 앱 패키지 명은 앞에 파이어베이스 콘솔에서 지정한 com.terry.hellofirebase가 되어야 한다.

안드로이드 프로젝트에는 프로젝트 레벨의 build.gradle 파일과, 앱 레벨의 build.gradle 파일이 있는데



프로젝트 레벨의 build.gradle 파일에 classpath 'com.google.gms:google-services:3.0.0' 를 추가하여  다음과 같이 수정한다.


// Top-level build file where you can add configuration options common to all sub-projects/modules.


buildscript {

  repositories {

      jcenter()

  }

  dependencies {

      classpath 'com.android.tools.build:gradle:2.1.2'

      classpath 'com.google.gms:google-services:3.0.0'

      // NOTE: Do not place your application dependencies here; they belong

      // in the individual module build.gradle files

  }

}


allprojects {

  repositories {

      jcenter()

  }

}


task clean(type: Delete) {

  delete rootProject.buildDir

}



다음으로, 앱레벨의 build.gradle 파일도 dependencies 부분에    compile 'com.google.firebase:firebase-core:9.4.0' 를 추가하고, 파일 맨 아래 apply plugin: 'com.google.gms.google-services' 를 추가 하여 아래와 같이 수정한다.

apply plugin: 'com.android.application'


android {

  compileSdkVersion 24

  buildToolsVersion "24.0.2"


  defaultConfig {

      applicationId "com.terry.hellofirebase"

      minSdkVersion 16

      targetSdkVersion 24

      versionCode 1

      versionName "1.0"

  }

  buildTypes {

      release {

          minifyEnabled false

          proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'

      }

  }

}


dependencies {

  compile fileTree(dir: 'libs', include: ['*.jar'])

  testCompile 'junit:junit:4.12'

  compile 'com.android.support:appcompat-v7:24.2.0'

  compile 'com.google.firebase:firebase-core:9.4.0'

}

apply plugin: 'com.google.gms.google-services'



그리고 파이어베이스 콘솔에서 앱을 추가할때 다운된 google-services.json 파일을 app디렉토리에 복사한다.




이 예제의 경우에는 /Users/terrycho/AndroidStudioProjects/HelloFireBase에 프로젝트를 만들었기 때문에,  /Users/terrycho/AndroidStudioProjects/HelloFireBase/app 디렉토리에 복사하였다.


Gradle 파일 수정이 끝나고, google-services.json 파일을 복사하였으면 안드로이드 스튜디오는 gradle 파일이 변경이 되었음을 인지하고 sync를 하도록 아래 그림과 같이 “Sync now”라는 버튼이 상단에 표시된다.


“Sync now”를 눌러서 프로젝트를 동기화 한다.

예제 코드 만들기

이제 안드로이드 스튜디오의 프로젝트 환경 설정이 완료되었다. 이제, 예제 코드를 만들어 보자.

이 예제 코드는 단순하게, 텍스트 박스를 통해서 아이템 ID,이름, 그리고 종류를 입력 받아서, 파이어베이스 애널러틱스에 이벤트를 로깅하는 예제이다.

파이어베이스 애널러틱스 서버로 로그를 보낼 것이기 때문에, AndroidManifest 파일에 아래와 같이  수정하여 INTERNET과 ACCESS_NETWORK_STATE 권한을 추가한다.

<?xml version="1.0" encoding="utf-8"?>

<manifest xmlns:android="http://schemas.android.com/apk/res/android"

  package="com.terry.hellofirebase">

  <uses-permission android:name="android.permission.INTERNET" />

  <uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />

 

  <application

      android:allowBackup="true"

      android:icon="@mipmap/ic_launcher"

      android:label="@string/app_name"

      android:supportsRtl="true"

      android:theme="@style/AppTheme">

      <activity android:name=".MainActivity">

          <intent-filter>

              <action android:name="android.intent.action.MAIN" />


              <category android:name="android.intent.category.LAUNCHER" />

          </intent-filter>

      </activity>

  </application>


</manifest>


다음으로 화면을 구성해야 하는데, 우리가 구성하려는 화면 레이아웃은 대략 다음과 같다.



각각의 EditText 컴포넌트는 tv_contentsId, tv_contentsName,tv_contentsCategory로 지정하였다.

위의 레이아웃을 정의한 activity_main.xml은 다음과 같다.


<?xml version="1.0" encoding="utf-8"?>

<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"

  xmlns:tools="http://schemas.android.com/tools"

  android:layout_width="match_parent"

  android:layout_height="match_parent"

  android:paddingBottom="@dimen/activity_vertical_margin"

  android:paddingLeft="@dimen/activity_horizontal_margin"

  android:paddingRight="@dimen/activity_horizontal_margin"

  android:paddingTop="@dimen/activity_vertical_margin"

  tools:context="com.terry.hellofirebase.MainActivity">


  <LinearLayout

      android:orientation="vertical"

      android:layout_width="match_parent"

      android:layout_height="match_parent"

      android:layout_alignParentLeft="true"

      android:layout_alignParentStart="true">


      <TextView

          android:layout_width="wrap_content"

          android:layout_height="wrap_content"

          android:textAppearance="?android:attr/textAppearanceMedium"

          android:text="Contents ID"

          android:id="@+id/tv_contetnsId" />


      <EditText

          android:layout_width="match_parent"

          android:layout_height="wrap_content"

          android:id="@+id/txt_contentsId"

          android:layout_gravity="center_horizontal" />


      <TextView

          android:layout_width="wrap_content"

          android:layout_height="wrap_content"

          android:textAppearance="?android:attr/textAppearanceMedium"

          android:text="Contents Name"

          android:id="@+id/tv_contentsName" />


      <EditText

          android:layout_width="match_parent"

          android:layout_height="wrap_content"

          android:id="@+id/txt_contentsName" />


      <TextView

          android:layout_width="wrap_content"

          android:layout_height="wrap_content"

          android:textAppearance="?android:attr/textAppearanceMedium"

          android:text="Contents Category"

          android:id="@+id/tv_contentsCategory" />


      <EditText

          android:layout_width="match_parent"

          android:layout_height="wrap_content"

          android:id="@+id/txt_contentsCategory" />


      <Button

          android:layout_width="wrap_content"

          android:layout_height="wrap_content"

          android:text="Send Event"

          android:id="@+id/btn_sendEvent"

          android:layout_gravity="center_horizontal"

          android:onClick="onSendEvent" />

  </LinearLayout>

</RelativeLayout>


레이아웃 설계가 끝났으면, SEND EVENT 버튼을 눌렀을때, 이벤트를 파이어베이스 애널러틱스 서버로 보내는 코드를 만들어 보자.

MainActivity인 com.terry.hellofirebase.MainActivity 클래스의 코드는 다음과 같다.


package com.terry.hellofirebase;


import android.support.v7.app.AppCompatActivity;

import android.os.Bundle;

import android.view.View;

import android.widget.EditText;

import android.widget.Toast;


import com.google.firebase.analytics.FirebaseAnalytics;


public class MainActivity extends AppCompatActivity {


  // add firebase analytics object

  private FirebaseAnalytics mFirebaseAnalytics;


  @Override

  protected void onCreate(Bundle savedInstanceState) {

      super.onCreate(savedInstanceState);

      mFirebaseAnalytics = FirebaseAnalytics.getInstance(this);

      setContentView(R.layout.activity_main);

  }


  public void onSendEvent(View view){

      String contentsId;

      String contentsName;

      String contentsCategory;


      EditText txtContentsId = (EditText)findViewById(R.id.txt_contentsId);

      EditText txtContentsName = (EditText)findViewById(R.id.txt_contentsName);

      EditText txtContentsCategory = (EditText)findViewById(R.id.txt_contentsCategory);


      contentsId = txtContentsId.getText().toString();

      contentsName = txtContentsName.getText().toString();

      contentsCategory = txtContentsCategory.getText().toString();


      Bundle bundle = new Bundle();

      bundle.putString(FirebaseAnalytics.Param.ITEM_ID, contentsId);

      bundle.putString(FirebaseAnalytics.Param.ITEM_NAME, contentsName);

      bundle.putString(FirebaseAnalytics.Param.CONTENT_TYPE, contentsCategory);

      mFirebaseAnalytics.logEvent(FirebaseAnalytics.Event.SELECT_CONTENT, bundle);


      Toast.makeText(getApplicationContext(), "Sent event", Toast.LENGTH_LONG).show();

  }

}


MainActivity 클래스에 FirebaseAnalytics 객체를 mFirebaseAnalytics라는 이름으로 정의하고 onCreate메서드에서 FirebaseAnalytics.getInstance(this) 메서드를 이용하여 파이어베이스 애널러틱스 객체를 생성한다.


다음 onSendEvent라는 메서드를 구현한다. 이 메서드는 화면에서 “SEND EVENT”라는 버튼을 누르면 EditText 박스에서 입력된 값으로 SELECT_CONTENT라는 이벤트를 만들어서 파이어베이스 애널러틱스 서버로 보내는 기능을 한다.

컨텐츠 ID,NAME,CATEGORY를 EditText 박스에서 읽어온 후에, Bundle 이라는 객체를 만들어서 넣는다.

파이어베이스 애널러틱스 로그는 이벤트와 번들이라는 개념으로 구성이 된다.

이벤트는 로그인, 컨텐츠 보기, 물품 구매와 같은 이벤트이고, Bundle은 이벤트에 구체적인 인자를 묶어서 저장하는 객체이다. 위의 예제인 경우 SELECT_CONTENTS 라는 이벤트가 발생할때 컨텐츠 ID, 이름(Name), 종류(Category)를 인자로 하여, Bundle에 묶어서 전달하도록 하였다.

Bundle 클래스를 생성한후, bundle.putString(“인자명",”인자값") 형태로 Bundle 객체를 설정한 후에, mFirebaseAnalytics.logEvent(“이벤트명",”Bundle 객체") 메서드를 이용하여 SELECT_CONTENTS 이벤트에 앞서 작성한 Bundle을 통하여 인자를 전달하였다.


앱 개발이 모두 완료되었다. 이제 테스트를 해보자

실행하기

앱을 실행하고 아래와 같이 데이타를 넣어보자


컨텐츠 ID는 200, 컨텐츠 이름은 W, 그리고 컨텐츠 종류는 webtoon으로 입력하였다.

SEND EVENT 눌러서 이벤트를 보내서 파이어베이스 웹콘솔에 들어가서 Analytics 메뉴에 상단 메뉴인 “Events”를 선택하면 처음에는 아무런 값이 나오지 않는다.

앞에서 설명했듯이 파이어베이스 애널러틱스는 아직까지 실시간 분석을 지원하지 않기 때문에 수시간이 지난 후에야 그 값이 반영 된다.


본인의 경우 밤 12시에 테스트를 진행하고 아침 9시경에 확인을 하였더니 아래와 같은 결과를 얻을 수 있었다.



실제로 테스트 시에 select contents 이벤트를 3번을 보냈더니, Count가 3개로 나온다.

그러나 이벤트에 보낸 컨텐츠 ID, 이름 , 분류등은 나타나지 않는다. 기본 설정에서는 이벤트에 대한 디테일 정보를 얻기가 어렵다. 그래서 빅쿼리 연동이 필요한데 이는 후에 다시 다루도록 하겠다.


Dashboard 메뉴를 들어가면 다음과 같이 지역 분포나 단말명등 기본적인 정보를 얻을 수 있다.



이벤트와 이벤트 인자

앞서처럼 이벤트와 인자등을 정해줬음에도 불구하고 대쉬보드나 기타 화면에 수치들이 상세하지 않은 것을 인지할 수 있다. 정확한 데이타를 분석하려면 마찬가지로 정확한 데이타를 보내줘야 하는데, 화면 로그인이나 구매등과 같은 앱에서의 이벤트를 앱 코드내에 삽입해줘야 상세한 분석이 가능하다.

이벤트는 https://firebase.google.com/docs/reference/android/com/google/firebase/analytics/FirebaseAnalytics.Event 에 정의가 되어 있고, 각 이벤트별 인자에 대한 설명은 https://firebase.google.com/docs/reference/android/com/google/firebase/analytics/FirebaseAnalytics.Param 에 있는데, 이미 파이어베이스에서는 게임이나 미디어 컨텐츠, 쇼핑과 같은 주요 모바일 앱 시나리오에 대해서 이벤트와 인자들은 미리 정의해놓았다.

https://support.google.com/firebase/topic/6317484?hl=en&ref_topic=6386699

를 보면 모바일 앱의 종류에 따라서 어떠한 이벤트를 사용해야 하는지가 정의되어 있다.


또한 미리 정의되어 있는 이벤트 이외에도 사용자가 직접 이벤트를 정의해서 사용할 수 있다.  이러한 이벤트를 커스텀 이벤트라고 하는데 https://firebase.google.com/docs/analytics/android/events 를 참고하면 된다.


지금까지 간략하게 나마 파이어베이스 애널러틱스의 소개와 예제 코드를 통한 사용 방법을 알아보았다.

모바일 데이타 분석이나 빅데이타 분석에서 가장 중요한 것은 데이타를 모으는 것도 중요하지만, 모아진 데이타에 대한 지표 정의와 그 의미를 파악하는 것이 중요하다. 그래서 다음 글에서는 파이어베이스 애널러틱스에 정의된 이벤트의 종류와 그 의미 그리고, 대쉬 보드를 해석하는 방법에 대해서 설명하고, 그 후에 빅쿼리 연동을 통해서 상세 지표 분석을 하는 방법에 대해서 소개하고자 한다.


node.js에서 Redis 사용하기


조대협 (http://bcho.tistory.com)


Redis NoSQL 데이타 베이스의 종류로, mongoDB 처럼 전체 데이타를 영구히 저장하기 보다는 캐쉬처럼 휘발성이나 임시성 데이타를 저장하는데 많이 사용된다.

디스크에 데이타를 주기적으로 저장하기는 하지만, 기능은 백업이나 복구용으로 주로 사용할뿐 데이타는 모두 메모리에 저장되기 때문에, 빠른 접근 속도를 자랑한다.

 

이유 때문에 근래에는 memcached 다음의 캐쉬 솔루션으로 널리 사용되고 있는데, 간단하게 -밸류 (Key-Value)형태의 데이타 저장뿐만 아니라, 다양한 데이타 타입을 지원하기 때문에 응용도가 높고, node.js 호환 모듈이 지원되서 node.js 궁합이 좋다. 여러 node.js 클러스터링 하여 사용할때, node.js 인스턴스간 상태정보를 공유하거나, 세션과 같은 휘발성 정보를 저장하거나 또는 캐쉬등으로 다양하게 사용할 있다.

 

Redis 제공하는 기능으로는 키로 데이타를 저장하고 조회하는 Set/Get 기능이 있으며, 메세지를 전달하기 위한 큐로도 사용할 있다.

 

큐로써의 기능은 하나의 클라이언트가 다른 클라이언트로 메세지를 보내는 1:1 기능뿐 아니라, 하나의 클라이언트가 다수의 클라이언트에게 메세지를 발송하는 발행/배포 (Publish/Subscribe) 기능을 제공한다.




그림 1 RedisPublish/Subscribe의 개념 구조

 

재미있는 것중에 하나는 일반적인 Pub/Sub 시스템의 경우 Subscribe 하는 하나의 Topic에서만 Subscribe하는데 반해서, redis에서는 pattern matching 통해서 다수의 Topic에서 message subscribe 있다.

예를 들어 topic 이름이 music.pop music,classic 이라는 두개의 Topic 있을때, "PSUBSCRIBE music.*"라고 하면 두개의 Topic에서 동시에 message subscribe 있다.

 

자료 구조

 

Redis 가장 기본이 되는 자료 구조를 살펴보자. Redis 다양한 자료 구조를 지원하는데, 지원하는 자료 구조형은 다음과 같다.

1)       String

Key 대해서 문자열을 저장한다. 텍스트 문자열뿐만 아니라 숫자나 최대 512mbyte 까지의 바이너리도 저장할 있다.

 

2)       List

Key 대해서 List 타입을 저장한다. List에는 값들이 들어갈 있으며, INDEX 값을 이용해서 지정된 위치의 값을 넣거나 있고, 또는 push/pop 함수를 이용하여 리스트 앞뒤에 데이타를 넣거나 있다. 일반적인 자료 구조에서 Linked List 같은 자료 구조라고 생각하면 된다.

 

3)       Sets

Set 자료 구조는 집합이라고 생각하면 된다. Key 대해서 Set 저장할 있는데, List 구조와는 다르게 주의할점은 집합이기 때문에 같은 값이 들어갈 없다. 대신 집합의 특성을 이용한 집합 연산, 교집합, 합집합등의 연산이 가능하다.

 

4)       Sorted Set

SortedSet Set 동일하지만, 데이타를 저장할때, value 이외에, score 라는 값을 같이 저장한다. 그리고 score 라는 값에 따라서 데이타를 정렬(소팅)해서 저장한다. 순차성이나 순서가 중요한 데이타를 저장할때 유용하게 저장할 있다.

 

5)       Hashes

마지막 자료형으로는 Hashes 있는데, 해쉬 자료 구조를 생각하면 된다.Key 해쉬 테이블을 저장하는데, 해쉬 테이블에 저장되는 데이타는 (field, value) 형태로 field 해쉬의 키로 저장한다.

키가 있는 데이타를 군집하여 저장하는데 유용하며 데이타의 접근이 매우 빠르다. 순차적이지 않고 비순차적인 랜덤 액세스 데이타에 적절하다.

 

설명한 자료 구조를 Redis 저장되는 형태로 표현하면 다음과 같다.

 



Figure 36 redis의 자료 구조

 

기본적으로 /밸류 (Key/Value) 형태로 데이타가 저장되며, 밸류에 해당하는 데이타 타입은 앞서 언급하 String, List, Sets, SortedSets, Hashes 있다.

 

Redis 대한 설명은 여기서는 자세하게 하지 않는다. 독립적인 제품인 만큼 가지고 있는 기능과 운영에 신경써야할 부분이 많다. Redis 대한 자세한 설명은 http://redis.io 홈페이지를 참고하거나 정경석씨가 이것이 레디스다http://www.yes24.com/24/Goods/11265881?Acode=101 라는 책을 추천한다. 단순히 redis 대한 사용법뿐만 아니라, 레디스의 데이타 모델 설계에 대한 자세한 가이드를 제공하고 있다.

 

Redis 설치하기

개발환경 구성을 위해서 redis 설치해보자.

 

맥의 경우 애플리케이션 설치 유틸리티인 brew 이용하면 간단하게 설치할 있다.

%brew install redis

 

윈도우즈

안타깝게도 redis 공식적으로는 윈도우즈 인스톨을 지원하지 않는다. http://redis.io에서 소스 코드를 다운 받아서 컴파일을 해서 설치를 해야 하는데, 만약에 이것이 번거롭다면, https://github.com/rgl/redis/downloads 에서 다운로드 받아서 설치할 있다. 그렇지만 이경우에는 최신 버전을 지원하지 않는다.

그래서 vagrant 이용하여 우분투 리눅스로 개발환경을 꾸미고 위에 redis 설치하거나 https://redislabs.com/pricing https://www.compose.io  같은 클라우드 redis 환경을 사용하기를 권장한다. ( 클라우드 서비스의 경우 일정 용량까지 무료 또는 일정 기간까지 무료로 서비스를 제공한다.)

 

리눅스

리눅스의 경우 설치가 매우 간단하다. 우분투의 경우 패키지 메니저인 apt-get 이용해서 다음과 같이 설치하면 된다.

%sudo apt-get install redis-server

 

설치가 끝났으면 편하게 redis 사용하기 위해서 redis 클라이언트를 설치해보자.

여러 GUI 클라이언트들이 많지만, 편하게 사용할 있는 redis desktop 설치한다. http://redisdesktop.com/ 에서 다운 받은 후에 간단하게 설치할 있다.

 

이제 환경 구성이 끝났으니, redis 구동하고 제대로 동작하는지 테스트해보자

%redis-server

명령을 이용해서 redis 서버를 구동한다.

 



Figure 37 redis 기동 화면

 

redis desktop 이용해서 localhost 호스트에 Host 주소는 localhost TCP 포트는 6379 새로운 Connection 추가하여 연결한다.

 

 



Figure 38 redis desktop에서 연결을 설정하는 화면

 

연결이 되었으면 redis desktop에서 Console 연다.

 



Figure 39 redis desktop에서 콘솔을 여는 화면

 

Console에서 다음과 같이 명령어를 입력해보자

 

localhost:0>set key1 myvalue

OK

 

localhost:0>set key2 myvalue2

OK

 

localhost:0>get key2

myvalue2

 

localhost:0>

Figure 40 redis desktop에서 간단한 명령을 통해서 redis를 테스트 하는 화면


위의 명령은 key1 myvalue라는 값을 입력하고, key2 myvalue2라는 값을 입력한 후에, key2 입력된 값을 조회하는 명령이다.

 

Redis desktop에서, 디비를 조회해보면, 앞서 입력한 /밸류 값이 저장되어 있는 것을 다음과 같이 확인할 있다.

\


Figure 41 redis에 저장된 데이타를 redis desktop을 이용해서 조회하기

 

node.js에서 redis 접근하기

 

이제 node.js에서 redis 사용하기 위한 준비가 끝났다. 간단한 express API 만들어서 redis 캐쉬로 사용하여 데이타를 저장하고 조회하는 예제를 작성해보자

 

node.js redis 클라이언트는 여러 종류가 있다. http://redis.io/clients#nodejs

가장 널리 쓰는 클라이언트 모듈로는 node-redis https://github.com/NodeRedis/node_redis 있는데, 예제는 node-redis 클라이언트를 기준으로 설명한다.

 

예제는 profile URL에서 사용자 데이타를 JSON/POST 받아서 redis 저장하고, TTL(Time to Leave) 방식의 캐쉬 처럼 10 후에 삭제되도록 하였다.

그리고 GET /profile/{사용자 이름} 으로 redis 저장된 데이타를 조회하도록 하였다.

 

먼저 node-redis 모듈과, json 문서를 처리하기 위해서 JSON 모듈을 사용하기 때문에, 모듈을 설치하자

% npm install redis

% npm install JSON

 

package.json 모듈의 의존성을 다음과 같이 정의한다.

 

 

{

  "name": "RedisCache",

  "version": "0.0.0",

  "private": true,

  "scripts": {

    "start": "node ./bin/www"

  },

  "dependencies": {

    "body-parser": "~1.13.2",

    "cookie-parser": "~1.3.5",

    "debug": "~2.2.0",

    "express": "~4.13.1",

    "jade": "~1.11.0",

    "morgan": "~1.6.1",

    "serve-favicon": "~2.3.0",

    "redis":"~2.6.0",

    "JSON":"~1.0.0"

  }

}

 

Figure 42 redisJSON 모듈의 의존성이 추가된 package.json

 

다음으로 express 간단한 프로젝트를 만든 후에, app.js 다음과 같은 코드를 추가한다.

 

 

// redis example

var redis = require('redis');

var JSON = require('JSON');

client = redis.createClient(6379,'127.0.0.1');

 

app.use(function(req,res,next){

      req.cache = client;

      next();

})

app.post('/profile',function(req,res,next){

      req.accepts('application/json');

     

      var key = req.body.name;

      var value = JSON.stringify(req.body);

     

      req.cache.set(key,value,function(err,data){

           if(err){

                 console.log(err);

                 res.send("error "+err);

                 return;

           }

           req.cache.expire(key,10);

           res.json(value);

           //console.log(value);

      });

})

app.get('/profile/:name',function(req,res,next){

      var key = req.params.name;

     

      req.cache.get(key,function(err,data){

           if(err){

                 console.log(err);

                 res.send("error "+err);

                 return;

           }

 

           var value = JSON.parse(data);

           res.json(value);

      });

});

 

Figure 43 app.jsredis에 데이타를 쓰고 읽는 부분

 

redis 클라이언트와, JSON 모듈을 로딩한후, createClient 메서드를 이용해서, redis 대한 연결 클라이언트를 생성하자.

 

client = redis.createClient(6379,'127.0.0.1');

 

app.use(function(req,res,next){

      req.cache = client;

      next();

})

 

다음 연결 객체를 express router에서 쉽게 가져다 있도록, 미들웨어를 이용하여 req.cache 객체에 저장하도록 하자.

 

HTTP POST /profile 의해서 사용자 프로파일 데이타를 저장하는 부분을 보면

req.accepts('application/json'); 이용하여 JSON 요청을 받아드리도록 한다.

JSON내의 name 필드를 키로, 하고, JSON 전체를 밸류로 한다. JSON 객체 형태로 redis 저장할 있겠지만 경우 redis에서 조회를 하면 객체형으로 나오기 때문에 운영이 불편하다. 그래서 JSON.stringfy 이용하여 JSON 객체를 문자열로 변환하여 value 객체에 저장하였다.

다음 req.cache.set(key,value,function(err,data) 코드에서 redis 저장하기 위해서 redis 클라이언트를 req 객체에서 조회해온후, set 명령을 이용해서 /밸류 값을 저장한다. 저장이 끝나면 뒤에 인자로 전달된 콜백함수 호출 되는데, 콜백함수에서, req.cache.expire(key,10); 호출하여, 키에 대한 데이타 저장 시간을 10초로 설정한다. (10 후에는 데이타가 삭제된다.) 마지막으로 res.json(value); 이용하여 HTTP 응답에 JSON 문자열을 리턴한다.

 

HTTP GET으로 /profile/{사용자 이름} 요청을 받아서 키가 사용자 이름은 JSON 데이타를 조회하여 리턴하는 코드이다.

app.get('/profile/:name',function(req,res,next) 으로 요청을 받은 , URL에서 name 부분을 읽어서 키값으로 하고,

req.cache.get(key,function(err,data){ 이용하여, 키를 가지고 데이타를 조회한다. 콜백 함수 부분에서, 데이타가 문자열 형태로 리턴되는데, 이를 var value = JSON.parse(data); 이용하여, JSON 객체로 변환한 후에, res.json(value); 통해서 JSON 문자열로 리턴한다.

 

코드 작성이 끝났으면 테스트를 해보자 HTTP JSON/POST REST 호출을 보내야 하기 때문에, 별도의 클라이언트가 필요한데, 클라이언트는 구글 크롬 브라우져의 플러그인인 포스트맨(POSTMAN) 사용하겠다. https://chrome.google.com/webstore/detail/postman/fhbjgbiflinjbdggehcddcbncdddomop

 

포스트맨 설치가 끝났으면, 포스트맨에서 HTTP POST/JSON 방식으로 http://localhost:3000/profile 아래와 같이 요청을 보낸다.

 



Figure 44 포스트맨에서 HTTP POSTprofile 데이타를 삽입하는 화면

 

요청을 보낸후 바로 HTTP GET으로 http://localhost:3000/profile/terry 조회를 하면 아래와 같이 앞에서 입력한 데이타가 조회됨을 확인할 있다. 이때 위의 POST 요청을 보낸지 10 내에 조회를 해야 한다. 10초가 지나면 앞서 지정한 expire 의해서 자동으로 삭제가된다.



Figure 45 포스트맨에서 사용자 이름이 terry인 데이타를 조회하는 화면

 

Redisdesktop에서 확인을 해보면 아래와 같이 문자열로 terry 사용자에 대한 데이타가 저장되어 있는 것을 확인할 있다.



Figure 46 redis desktop 에서 입력된 데이타를 확인하는 화면

 

10초후에, 다시 조회를 해보면, terry 키로 가지는 데이타가 삭제된 것을 확인할 있다.

 

지금까지 가장 기본적인 redis 대한 소개와 사용법에 대해서 알아보았다. redis 뒤에 나올 node.js 클러스터링의 HTTP 세션을 저장하는 기능이나, Socket.IO 등에서도 계속해서 사용되는 중요한 솔루션이다. Redis 자체를 다루는 것이 아니라서 자세하게 파고 들어가지는 않았지만, 다소 운영이 까다롭고 특성을 파악해서 설계해야 하는 만큼 반드시 시간을 내서 redis 자체에 대해서 조금 자세하게 살펴보기를 권장한다.

monk 모듈을 이용한 mongoDB 연결


조대협 (http://bcho.tistory.com)


mongoDB 기반의 개발을 하기 위해서 mongoDB를 설치한다. https://www.mongodb.org/ 에서 OS에 맞는 설치 파일을 다운로드 받아서 설치한다.

설치가 된 디렉토리에 들어가서 설치디렉토리 아래 ‘./data’ 라는 디렉토리를 만든다. 이 디렉토리는 mongoDB의 데이타가 저장될 디렉토리이다.

 

mongoDB를 구동해보자.

% ./bin/mongod --dbpath ./data



Figure 1 mongoDB 구동화면


구동이 끝났으면 mongoDB에 접속할 클라이언트가 필요하다. DB에 접속해서 데이타를 보고 쿼리를 수행할 수 있는 클라이언트가 필요한데, 여러 도구가 있지만 많이 사용되는 도구로는 roboMongo라는 클라이언트가 있다.

https://robomongo.org/download 에서 다운로드 받을 수 있다. OS에 맞는 설치 파일을 다운로드 받아서 설치 후 실행한다.

 

설치 후에, Create Connection에서, 로컬호스트에 설치된 mongoDB를 연결하기 위해서 연결 정보를 기술하고, 연결을 만든다





Figure 2 robomongo에서 localhost에 있는 mongodb 연결 추가

 

주소는 localhost, 포트는 디폴트 포트로 27017를 넣으면 된다.

 

환경이 준비가 되었으면 간단한 테스트를 해보자. 테스트 전에 기본적인 개념을 숙지할 필요가 있는데, mongoDBNoSQL 계열중에서도 도큐먼트DB (Document DB)에 속한다. 기존 RDBMS에서 하나의 행이 데이타를 표현했다면, mogoDB는 하나의 JSON 파일이 하나의 데이타를 표현한다. JSON을 도큐먼트라고 하기 때문에, 도큐먼트 DB라고 한다.

 

제일 상위 개념은 DB의 개념에 대해서 알아보자, DB는 여러개의 테이블(컬렉션)을 저장하는 단위이다.

Robomongo에서 mydb라는 이름으로 DB 를 생성해보자



Figure 3 robomongo에서 새로운  DB를 추가 하는 화면

 

다음으로 생성된 DB안에, 컬렉션을 생성한다. 컬렉션은 RDBMS의 단일 테이블과 같은 개념이다.

Robomongo에서 다음과 같이 ‘users’라는 이름의 컬렉션을 생성한다



Figure 4 robomongo에서 컬렉션(Collection) 생성

 

users 컬렉션에는 userid를 키로 해서, sex(성별), city(도시) 명을 입력할 예정인데, userid가 키이기 때문에, userid를 통한 검색이나 소팅등이 발생한다. 그래서 userid를 인덱스로 지정한다.

인덱스 지정 방법은 createIndex 명령을 이용한다. 다음과 같이 robomongo에서 createIndex 명령을 이용하여 인덱스를 생성한다.



Figure 5 users 컬렉션에서 userid를 인덱스로 지정

 

mongoDB는 디폴트로, 각 컬렉션마다 “_id”라는 필드를 가지고 있다. 이 필드는 컬렉션 안의 데이타에 대한 키 값인데, 12 바이트의 문자열로 이루어져 있고 ObjectId라는 포맷으로 시간-머신이름,프로세스ID,증가값형태로 이루어지는 것이 일반적이다.

_id 필드에 userid를 저장하지 않고 별도로 인덱스를 만들어가면서 까지 userid 필드를 별도로 사용하는 것은 mongoDBNoSQL의 특성상 여러개의 머신에 데이타를 나눠서 저장한다. 그래서 데이타가 여러 머신에 골고루 분산되는 것이 중요한데, 애플리케이션상의 특정 의미를 가지고 있는 필드를 사용하게 되면 데이타가 특정 머신에 쏠리는 현상이 발생할 수 있다.

예를 들어서, 주민번호를 _id로 사용했다면, 데이타가 골고루 분산될것 같지만, 해당 서비스가 10~20대에만 인기있는 서비스라면, 10~20대 데이타를 저장하는 머신에만 데이타가 몰리게 되고, 10세이하나, 20세 이상의 데이타를 저장하는 노드에는 데이타가 적게 저장된다.

이런 이유등으로 mongoDB를 지원하는 node.js 드라이버에서는 _id 값을 사용할때, 앞에서 언급한 ObjectId 포맷을 따르지 않으면 에러를 내도록 설계되어 있다. 우리가 앞으로 살펴볼 mongoosemonk의 경우에도 마찬가지이다.

 

이제 데이타를 집어넣기 위한 테이블(컬렉션) 생성이 완료되었다.

다음 컬렉션 에 대한 CRUD (Create, Read, Update, Delete) 를 알아보자

SQL 문장과 비교하여, mongoDB에서 CRUD 에 대해서 알아보면 다음과 같다.

CRUD

SQL

MongoDB

Create

insert into users ("name","city") values("terry","seoul")

db.users.insert({userid:"terry",city:"seoul"})

Read

select * from users where id="terry"

db.users.find({userid:"terry"})

Update

update users set city="busan" where _id="terry"

db.users.update( {userid:"terry"}, {$set :{ city:"Busan" } } )

Delete

delete from users where _id="terry"

db.users.remove({userid:"terry"})

Figure 6 SQL문장과 mongoDB 쿼리 문장 비교


mongoDB에서 쿼리는 위와 같이 db.{Collection }.{명령어} 형태로 정의된다.

roboMongo에서 insert 쿼리를 수행하여 데이타를 삽입해보자



Figure 7 mongoDB에서 users 컬렉션에 데이타 추가

 

다음으로 삽입한 데이타를 find 명령을 이용해 조회해보자



Figure 8 mongoDB에서 추가된 데이타에 대한 확인

 

mongoDB에 대한 구조나 자세한 사용 방법에 대해서는 여기서는 설명하지 않는다.

http://www.tutorialspoint.com/mongodb/ mongoDB에 대한 전체적인 개념과 주요 쿼리들이 간략하게 설명되어 있으니 이 문서를 참고하거나, 자세한 내용은 https://docs.mongodb.org/manual/ 를 참고하기 바란다.

https://university.mongodb.com/ 에 가면 mongodb.com에서 운영하는 온라인 강의를 들을 수 있다. (무료인 과정도 있으니 필요하면 참고하기 바란다.)

 

mongoDBnode.js에서 호출하는 방법은 여러가지가 있으나 대표적인 두가지를 소개한다.

첫번째 방식은 mongoDB 드라이버를 이용하여 직접 mongoDB 쿼리를 사용하는 방식이고, 두번째 방식은 ODM (Object Document Mapper)를 이용하는 방식이다. ODM 방식은 자바나 다른 프로그래밍 언어의 ORM (Object Relational Mapping)과 유사하게 직접 쿼리를 사용하는 것이 아니라 맵퍼를 이용하여 프로그램상의 객체를 데이타와 맵핑 시키는 방식이다. 뒷부분에서 직접 코드를 보면 이해가 빠를 것이다.

 

Monk를 이용한 연결

첫번째로 mongoDB 네이티브 쿼리를 수행하는 방법에 대해서 소개한다. monk라는 node.jsmongoDB 클라이언트를 이용할 것이다.

monk 모듈을 이용하기 위해서 아래와 같이 package.jsonmonk에 대한 의존성을 추가한다.


{

  "name": "mongoDBexpress",

  "version": "0.0.0",

  "private": true,

  "scripts": {

    "start": "node ./bin/www"

  },

  "dependencies": {

    "body-parser": "~1.13.2",

    "cookie-parser": "~1.3.5",

    "debug": "~2.2.0",

    "express": "~4.13.1",

    "jade": "~1.11.0",

    "morgan": "~1.6.1",

    "serve-favicon": "~2.3.0",

    "monk":"~1.0.1"

  }

}

 

Figure 9 monk 모듈에 대한 의존성이 추가된 package.json

 

app.js에서 express가 기동할때, monk를 이용해서 mongoDB에 연결하도록 한다.

var monk = require('monk');

var db = monk('mongodb://localhost:27017/mydb');

 

var mongo = require('./routes/mongo.js');

app.use(function(req,res,next){

    req.db = db;

    next();

});

app.use('/', mongo);

Figure 10 monk를 이용하여 app.js에서 mongoDB 연결하기

 

mongoDB에 연결하기 위한 연결 문자열은 'mongodb://localhost:27017/mydb' mongo://{mongoDB 주소}:{mongoDB 포트}/{연결하고자 하는 DB} 으로 이 예제에서는 mongoDB 연결을 간단하게 IP,포트,DB명만 사용했지만, 여러개의 인스턴스가 클러스터링 되어 있을 경우, 여러 mongoDB로 연결을 할 수 있는 설정이나, Connection Pool과 같은 설정, SSL과 같은 보안 설정등 부가적인 설정이 많으니, 반드시 운영환경에 맞는 설정으로 변경하기를 바란다. 설정 방법은 http://mongodb.github.io/node-mongodb-native/2.1/reference/connecting/connection-settings/ 문서를 참고하자.

 

이때 주의깊게 살펴봐야 하는 부분이 app.use를 이용해서 미들웨어를 추가하였는데, req.dbmongodb 연결을 넘기는 것을 볼 수 있다. 미들웨어로 추가가 되었기 때문에 매번 HTTP 요청이 올때 마다 req 객체에는 db라는 변수로 mongodb 연결을 저장해서 넘기게 되는데, 이는 HTTP 요청을 처리하는 것이 router에서 처리하는 것이 일반적이기 때문에, routerdb 연결을 넘기기 위함이다. 아래 데이타를 삽입하는 라우터 코드를 보자

 

router.post('/insert', function(req, res, next) {

      var userid = req.body.userid;

      var sex = req.body.sex;

      var city = req.body.city;

     

      db = req.db;

      db.get('users').insert({'userid':userid,'sex':sex,'city':city},function(err,doc){

             if(err){

                console.log(err);

                res.status(500).send('update error');

                return;

             }

             res.status(200).send("Inserted");

            

         });

});

Figure 11 /routes/mongo.js 에서 데이타를 삽입하는 코드


req 객체에서 폼 필드를 읽어서 userid,sex,city등을 읽어내고, 앞의 app.js 에서 추가한 미들웨어에서 넘겨준 db 객체를 받아서 db.get('users').insert({'userid':userid,'sex':sex,'city':city},function(err,doc) 수행하여 데이타를 insert 하였다.

 

다음은 userid필드가 HTTP 폼에서 넘어오는 userid 일치하는 레코드를 지우는 코드 예제이다. Insert 부분과 크게 다르지 않고 remove 함수를 이용하여 삭제 하였다.


router.post('/delete', function(req, res, next) {

      var userid = req.body.userid;

     

      db = req.db;

      db.get('users').remove({'userid':userid},function(err,doc){

             if(err){

                console.log(err);

                res.status(500).send('update error');

                return;

             }

             res.status(200).send("Removed");

            

         });

});

Figure 12 /routes/mongo.js 에서 데이타를 삭제하는 코드

 

다음은 데이타를 수정하는 부분이다. Update 함수를 이용하여 데이타를 수정하는데,

db.get('users').update({userid:userid},{'userid':userid,'sex':sex,'city':city},function(err,doc){

와 같이 ‘userid’userid 인 필드의 데이타를 },{'userid':userid,'sex':sex,'city':city} 대치한다.

 

router.post('/update', function(req, res, next) {

      var userid = req.body.userid;

      var sex = req.body.sex;

      var city = req.body.city;

      db = req.db;

      db.get('users').update({userid:userid},{'userid':userid,'sex':sex,'city':city},function(err,doc){

      //db.get('users').update({'userid':userid},{$set:{'sex':'BUSAN'}},function(err,doc){

             if(err){

                console.log(err);

                res.status(500).send('update error');

                return;

             }

             res.status(200).send("Updated");

            

         });

});

Figure 13 /routes/mongo.js 에서 데이타를 수정하는 코드


전체 레코드를 대치하는게 아니라 특정 필드만 수정하고자 하면, $set: 쿼리를 이용하여, 수정하고자하는 필드만 아래와 같이 수정할 수 있다.

db.collection('users').updateOne({_id:userid},{$set:{'sex':'BUSAN'}},function(err,doc){

 

마지막으로 데이타를 조회하는 부분이다. /list URL은 전체 리스트를 리턴하는 코드이고, /get ?userid= 쿼리 스트링으로 정의되는 사용자 ID에 대한 레코드만을 조회해서 리턴한다.

router.get('/list', function(req, res, next) {

      db = req.db;

      db.get('users').find({},function(err,doc){

           if(err) console.log('err');

           res.send(doc);

      });

});

router.get('/get', function(req, res, next) {

      db = req.db;

      var userid = req.query.userid

      db.get('users').findOne({'userid':userid},function(err,doc){

           if(err) console.log('err');

           res.send(doc);

      });

});

Figure 14 /routes/mongo.js 에서 데이타를 조회하는 코드

 

이제 /routes/mongo.js 의 모든 코드 작업이 완료되었다. 이 코드를 호출하기 위한 HTML 폼을 작성하자.

 

<!DOCTYPE html>

<html>

<head>

<meta charset="UTF-8">

<title>Insert title here</title>

</head>

<body>

 

<h1> Native MongoDB Test Example</h1>

<form method='post' action='/insert' name='mongoform' >

      user <input type='text' size='10' name='userid'>

      <input type='submit' value='delete' onclick='this.form.action="/delete"' >

      <input type='button' value='get' onclick='location.href="/get?userid="+document.mongoform.userid.value' >

      <p>

      city <input type='text' size='10' name='city' >

      sex <input type='radio' name='sex' value='male'>male

      <input type='radio' name='sex' value='female'>female

      <p>

      <input type='submit' value='insert' onclick='this.form.action="/insert"' >

      <input type='submit' value='update' onclick='this.form.action="/update"' >

      <input type='button' value='list'  onclick='location.href="/list"' >

     

</form>

</body>

</html>

Figure 15 /public/monksample.html

 

node.js를 실행하고 http://localhost:3000/monksample.html 을 실행해보자



Figure 16 http://localhost:3000/monksample.html 실행 결과

 

아래 insert 버튼을 누르면, 채워진 필드로 새로운 레코드를 생성하고, update 버튼은 user 필드에 있는 사용자 이름으로된 데이타를 업데이트 한다. list 버튼은 컬렉션에서 전체 데이타를 조회해서 출력하고, delete 버튼은 user 필드에 있는 사용자 이름으로된 레코드를 삭제한다. get 버튼은 user 필드에 있는 사용자 이름으로 데이타를 조회하여 리턴한다.

다음은 list로 전체 데이타를 조회하는 화면이다.

 


Figure 17 /list를 수행하여 mongoDB에 저장된 전체 데이타를 조회하는 화면


이 코드의 전체 소스코드는 https://github.com/bwcho75/nodejs_tutorial/tree/master/mongoDBexpress 에 있으니 필요하면 참고하기 바란다


다음 글에서는  node.js의 mongoDB ODM 프레임웍인 mongoose 이용한 접근 방법에 대해서 알아보기로 한다.