블로그 이미지
평범하게 살고 싶은 월급쟁이 기술적인 토론 환영합니다.같이 이야기 하고 싶으시면 부담 말고 연락주세요:이메일-bwcho75골뱅이지메일 닷컴. 조대협


Archive»


 
 

딥러닝을 이용한 숫자 이미지 인식 #2/2


앞서 MNIST 데이타를 이용한 필기체 숫자를 인식하는 모델을 컨볼루셔널 네트워크 (CNN)을 이용하여 만들었다. 이번에는 이 모델을 이용해서 필기체 숫자 이미지를 인식하는 코드를 만들어 보자


조금 더 테스트를 쉽게 하기 위해서, 파이썬 주피터 노트북내에서 HTML 을 이용하여 마우스로 숫자를 그릴 수 있도록 하고, 그려진 이미지를 어떤 숫자인지 인식하도록 만들어 보겠다.



모델 로딩

먼저 앞의 예제에서 학습을한 모델을 로딩해보도록 하자.

이 코드는 주피터 노트북에서 작성할때, 모델을 학습 시키는 코드 (http://bcho.tistory.com/1156) 와 별도의 새노트북에서 구현을 하도록 한다.


코드

import tensorflow as tf

import numpy as np

import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data


#이미 그래프가 있을 경우 중복이 될 수 있기 때문에, 기존 그래프를 모두 리셋한다.

tf.reset_default_graph()


num_filters1 = 32


x = tf.placeholder(tf.float32, [None, 784])

x_image = tf.reshape(x, [-1,28,28,1])


#  layer 1

W_conv1 = tf.Variable(tf.truncated_normal([5,5,1,num_filters1],

                                         stddev=0.1))

h_conv1 = tf.nn.conv2d(x_image, W_conv1,

                      strides=[1,1,1,1], padding='SAME')


b_conv1 = tf.Variable(tf.constant(0.1, shape=[num_filters1]))

h_conv1_cutoff = tf.nn.relu(h_conv1 + b_conv1)


h_pool1 =tf.nn.max_pool(h_conv1_cutoff, ksize=[1,2,2,1],

                       strides=[1,2,2,1], padding='SAME')


num_filters2 = 64


# layer 2

W_conv2 = tf.Variable(

           tf.truncated_normal([5,5,num_filters1,num_filters2],

                               stddev=0.1))

h_conv2 = tf.nn.conv2d(h_pool1, W_conv2,

                      strides=[1,1,1,1], padding='SAME')


b_conv2 = tf.Variable(tf.constant(0.1, shape=[num_filters2]))

h_conv2_cutoff = tf.nn.relu(h_conv2 + b_conv2)


h_pool2 =tf.nn.max_pool(h_conv2_cutoff, ksize=[1,2,2,1],

                       strides=[1,2,2,1], padding='SAME')


# fully connected layer

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*num_filters2])


num_units1 = 7*7*num_filters2

num_units2 = 1024


w2 = tf.Variable(tf.truncated_normal([num_units1, num_units2]))

b2 = tf.Variable(tf.constant(0.1, shape=[num_units2]))

hidden2 = tf.nn.relu(tf.matmul(h_pool2_flat, w2) + b2)


keep_prob = tf.placeholder(tf.float32)

hidden2_drop = tf.nn.dropout(hidden2, keep_prob)


w0 = tf.Variable(tf.zeros([num_units2, 10]))

b0 = tf.Variable(tf.zeros([10]))

k = tf.matmul(hidden2_drop, w0) + b0

p = tf.nn.softmax(k)


# prepare session

sess = tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()

saver.restore(sess, '/Users/terrycho/anaconda/work/cnn_session')


print 'reload has been done'


그래프 구현

코드를 살펴보면, #prepare session 부분 전까지는 이전 코드에서의 그래프를 정의하는 부분과 동일하다. 이 코드는 우리가 만든 컨볼루셔널 네트워크를 복원하는 부분이다.


변수 데이타 로딩

그래프의 복원이 끝나면, 저장한 세션의 값을 다시 로딩해서 학습된 W와 b값들을 다시 로딩한다.


# prepare session

sess = tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()

saver.restore(sess, '/Users/terrycho/anaconda/work/cnn_session')


이때 saver.restore 부분에서 앞의 예제에서 저장한 세션의 이름을 지정해준다.

HTML을 이용한 숫자 입력

그래프와 모델 복원이 끝났으면 이 모델을 이용하여, 숫자를 인식해본다.

테스트하기 편리하게 HTML로 마우스로 숫자를 그릴 수 있는 화면을 만들어보겠다.

주피터 노트북에서 새로운 Cell에 아래와 같은 내용을 입력한다.


코드

input_form = """

<table>

<td style="border-style: none;">

<div style="border: solid 2px #666; width: 143px; height: 144px;">

<canvas width="140" height="140"></canvas>

</div></td>

<td style="border-style: none;">

<button onclick="clear_value()">Clear</button>

</td>

</table>

"""


javascript = """

<script type="text/Javascript">

   var pixels = [];

   for (var i = 0; i < 28*28; i++) pixels[i] = 0

   var click = 0;


   var canvas = document.querySelector("canvas");

   canvas.addEventListener("mousemove", function(e){

       if (e.buttons == 1) {

           click = 1;

           canvas.getContext("2d").fillStyle = "rgb(0,0,0)";

           canvas.getContext("2d").fillRect(e.offsetX, e.offsetY, 8, 8);

           x = Math.floor(e.offsetY * 0.2)

           y = Math.floor(e.offsetX * 0.2) + 1

           for (var dy = 0; dy < 2; dy++){

               for (var dx = 0; dx < 2; dx++){

                   if ((x + dx < 28) && (y + dy < 28)){

                       pixels[(y+dy)+(x+dx)*28] = 1

                   }

               }

           }

       } else {

           if (click == 1) set_value()

           click = 0;

       }

   });

   

   function set_value(){

       var result = ""

       for (var i = 0; i < 28*28; i++) result += pixels[i] + ","

       var kernel = IPython.notebook.kernel;

       kernel.execute("image = [" + result + "]");

   }

   

   function clear_value(){

       canvas.getContext("2d").fillStyle = "rgb(255,255,255)";

       canvas.getContext("2d").fillRect(0, 0, 140, 140);

       for (var i = 0; i < 28*28; i++) pixels[i] = 0

   }

</script>

"""


다음 새로운 셀에서, 다음 코드를 입력하여, 앞서 코딩한 HTML 파일을 실행할 수 있도록 한다.


from IPython.display import HTML

HTML(input_form + javascript)


이제 앞에서 만든 두 셀을 실행시켜 보면 다음과 같이 HTML 기반으로 마우스를 이용하여 숫자를 입력할 수 있는 박스가 나오는것을 확인할 수 있다.



입력값 판정

앞의 HTML에서 그린 이미지는 앞의 코드의 set_value라는 함수에 의해서, image 라는 변수로 784 크기의 벡터에 저장된다. 이 값을 이용하여, 이 그림이 어떤 숫자인지를 앞서 만든 모델을 이용해서 예측을 해본다.


코드


p_val = sess.run(p, feed_dict={x:[image], keep_prob:1.0})


fig = plt.figure(figsize=(4,2))

pred = p_val[0]

subplot = fig.add_subplot(1,1,1)

subplot.set_xticks(range(10))

subplot.set_xlim(-0.5,9.5)

subplot.set_ylim(0,1)

subplot.bar(range(10), pred, align='center')

plt.show()

예측

예측을 하는 방법은 쉽다. 이미지 데이타가 image 라는 변수에 들어가 있기 때문에, 어떤 숫자인지에 대한 확률을 나타내는 p 의 값을 구하면 된다.


p_val = sess.run(p, feed_dict={x:[image], keep_prob:1.0})


를 이용하여 x에 image를 넣고, 그리고 dropout 비율을 0%로 하기 위해서 keep_prob를 1.0 (100%)로 한다. (예측이기 때문에 당연히 dropout은 필요하지 않다.)

이렇게 하면 이 이미지가 어떤 숫자인지에 대한 확률이 p에 저장된다.

그래프로 표현

그러면 이 p의 값을 찍어 보자


fig = plt.figure(figsize=(4,2))

pred = p_val[0]

subplot = fig.add_subplot(1,1,1)

subplot.set_xticks(range(10))

subplot.set_xlim(-0.5,9.5)

subplot.set_ylim(0,1)

subplot.bar(range(10), pred, align='center')

plt.show()


그래프를 이용하여 0~9 까지의 숫자 (가로축)일 확률을 0.0~1.0 까지 (세로축)으로 출력하게 된다.

다음은 위에서 입력한 숫자 “4”를 인식한 결과이다.



(보너스) 첫번째 컨볼루셔널 계층 결과 출력

컨볼루셔널 네트워크를 학습시키다 보면 종종 컨볼루셔널 계층을 통과하여 추출된 특징 이미지들이 어떤 모양을 가지고 있는지를 확인하고 싶을때가 있다. 그래서 각 필터를 통과한 값을 이미지로 출력하여 확인하고는 하는데, 여기서는 이렇게 각 필터를 통과하여 인식된 특징이 어떤 모양인지를 출력하는 방법을 소개한다.


아래는 우리가 만든 네트워크 중에서 첫번째 컨볼루셔널 필터를 통과한 결과 h_conv1과, 그리고 이 결과에 bias 값을 더하고 활성화 함수인 Relu를 적용한 결과를 출력하는 예제이다.


코드


conv1_vals, cutoff1_vals = sess.run(

   [h_conv1, h_conv1_cutoff], feed_dict={x:[image], keep_prob:1.0})


fig = plt.figure(figsize=(16,4))


for f in range(num_filters1):

   subplot = fig.add_subplot(4, 16, f+1)

   subplot.set_xticks([])

   subplot.set_yticks([])

   subplot.imshow(conv1_vals[0,:,:,f],

                  cmap=plt.cm.gray_r, interpolation='nearest')

plt.show()


x에 image를 입력하고, dropout을 없이 모든 네트워크를 통과하도록 keep_prob:1.0으로 주고, 첫번째 컨볼루셔널 필터를 통과한 값 h_conv1 과, 이 값에 bias와 Relu를 적용한 값 h_conv1_cutoff를 계산하였다.

conv1_vals, cutoff1_vals = sess.run(

   [h_conv1, h_conv1_cutoff], feed_dict={x:[image], keep_prob:1.0})


첫번째 필터는 총 32개로 구성되어 있기 때문에, 32개의 결과값을 imshow 함수를 이용하여 흑백으로 출력하였다.




다음은 bias와 Relu를 통과한 값인 h_conv_cutoff를 출력하는 예제이다. 위의 코드와 동일하며 subplot.imgshow에서 전달해주는 인자만 conv1_vals → cutoff1_vals로 변경되었다.


코드


fig = plt.figure(figsize=(16,4))


for f in range(num_filters1):

   subplot = fig.add_subplot(4, 16, f+1)

   subplot.set_xticks([])

   subplot.set_yticks([])

   subplot.imshow(cutoff1_vals[0,:,:,f],

                  cmap=plt.cm.gray_r, interpolation='nearest')

   

plt.show()


출력 결과는 다음과 같다



이제까지 컨볼루셔널 네트워크를 이용한 이미지 인식을 텐서플로우로 구현하는 방법을 MNIST(필기체 숫자 데이타)를 이용하여 구현하였다.


실제로 이미지를 인식하려면 전체적인 흐름은 같지만, 이미지를 전/후처리 해내야 하고 또한 한대의 머신이 아닌 여러대의 머신과 GPU와 같은 하드웨어 장비를 사용한다. 다음 글에서는 MNIST가 아니라 실제 칼라 이미지를 인식하는 방법에 대해서 데이타 전처리에서 부터 서비스까지 전체 과정에 대해서 설명하도록 하겠다.


예제 코드 : https://github.com/bwcho75/tensorflowML/blob/master/MNIST_CNN_Prediction.ipynb


저작자 표시 비영리
신고
크리에이티브 커먼즈 라이선스
Creative Commons License

딥러닝을 이용한 숫자 이미지 인식 #1/2


조대협 (http://bcho.tistory.com)


지난 글(http://bcho.tistory.com/1154 ) 을 통해서 소프트맥스 회귀를 통해서, 숫자를 인식하는 모델을 만들어서 학습 시켜 봤다.

이번글에서는 소프트맥스보다 정확성이 높은 컨볼루셔널 네트워크를 이용해서 숫자 이미지를 인식하는 모델을 만들어 보겠다.


이 글의 목적은 CNN 자체의 설명이나, 수학적 이론에 대한 이해가 목적이 아니다. 최소한의 수학적 지식만 가지고, CNN 네트워크 모델을 텐서플로우로 구현하는데에 그 목적을 둔다. CNN을 이해하기 위해서는 Softmax 등의 함수를 이해하는게 좋기 때문에 가급적이면 http://bcho.tistory.com/1154 예제를 먼저 보고 이 문서를 보는게 좋다. 그 다음에 CNN 모델에 대한 개념적인 이해를 위해서 http://bcho.tistory.com/1149  문서를 참고하고 이 문서를 보는 것이 좋다.


이번 글은 CNN을 적용하는 것 이외에, 다음과 같은 몇가지 팁을 추가로 소개한다.

  • 학습이 된 모델을 저장하고 다시 로딩 하는 방법

  • 학습된 모델을 이용하여 실제로 주피터 노트북에서 글씨를 써보고 인식하는 방법

MNIST CNN 모델


우리가 만들고자 하는 모델은 두개의 컨볼루셔널 레이어(Convolutional layer)과, 마지막에 풀리 커넥티드 레이어 (fully connected layer)을 가지고 있는 컨볼루셔널 네트워크 모델(CNN) 이다.

모델의 모양을 그려보면 다음과 같다.


입력 데이타

입력으로 사용되는 데이타는 앞의 소프트맥스 예제에서 사용한 데이타와 동일한 손으로 쓴 숫자들이다. 각 숫자 이미지는 28x28 픽셀로 되어 있고, 흑백이미지이기 때문에 데이타는 28x28x1 행렬이 된다. (만약에 칼라 RGB라면 28x28x3이 된다.)

컨볼루셔널 계층

총 두 개의 컨볼루셔널 계층을 사용했으며, 각 계층에서 컨볼루셔널 필터를 사용해서, 특징을 추출한다음에, 액티베이션 함수 (Activation function)으로, ReLu를 적용한 후, 맥스풀링 (Max Pooling)을 이용하여, 주요 특징을 정리해낸다.

이와 같은 컨볼루셔널 필터를 두개를 중첩하여 적용하였다.

마지막 풀리 커넥티드 계층

컨볼루셔널 필터를 통해서 추출된 특징은 풀리 커넥티드 레이어(Fully connected layer)에 의해서 분류 되는데, 풀리 커넥티드 레이어는 하나의 뉴럴 네트워크를 사용하고, 그 뒤에 드롭아웃 (Dropout) 계층을 넣어서, 오버피팅(Overfitting)이 발생하는 것을 방지한다.  마지막으로 소프트맥스 (Softmax) 함수를 이용하여 0~9 열개의 숫자로 분류를 한다.


학습(트레이닝) 코드

이를 구현하기 위한 코드는 다음과 같다.


코드

import tensorflow as tf

import numpy as np

import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data



tf.reset_default_graph()


np.random.seed(20160704)

tf.set_random_seed(20160704)


# load data

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)


# define first layer

num_filters1 = 32


x = tf.placeholder(tf.float32, [None, 784])

x_image = tf.reshape(x, [-1,28,28,1])


W_conv1 = tf.Variable(tf.truncated_normal([5,5,1,num_filters1],

                                         stddev=0.1))

h_conv1 = tf.nn.conv2d(x_image, W_conv1,

                      strides=[1,1,1,1], padding='SAME')


b_conv1 = tf.Variable(tf.constant(0.1, shape=[num_filters1]))

h_conv1_cutoff = tf.nn.relu(h_conv1 + b_conv1)


h_pool1 = tf.nn.max_pool(h_conv1_cutoff, ksize=[1,2,2,1],

                        strides=[1,2,2,1], padding='SAME')


# define second layer

num_filters2 = 64


W_conv2 = tf.Variable(

           tf.truncated_normal([5,5,num_filters1,num_filters2],

                               stddev=0.1))

h_conv2 = tf.nn.conv2d(h_pool1, W_conv2,

                      strides=[1,1,1,1], padding='SAME')


b_conv2 = tf.Variable(tf.constant(0.1, shape=[num_filters2]))

h_conv2_cutoff = tf.nn.relu(h_conv2 + b_conv2)


h_pool2 = tf.nn.max_pool(h_conv2_cutoff, ksize=[1,2,2,1],

                        strides=[1,2,2,1], padding='SAME')


# define fully connected layer

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*num_filters2])


num_units1 = 7*7*num_filters2

num_units2 = 1024


w2 = tf.Variable(tf.truncated_normal([num_units1, num_units2]))

b2 = tf.Variable(tf.constant(0.1, shape=[num_units2]))

hidden2 = tf.nn.relu(tf.matmul(h_pool2_flat, w2) + b2)


keep_prob = tf.placeholder(tf.float32)

hidden2_drop = tf.nn.dropout(hidden2, keep_prob)


w0 = tf.Variable(tf.zeros([num_units2, 10]))

b0 = tf.Variable(tf.zeros([10]))

k = tf.matmul(hidden2_drop, w0) + b0

p = tf.nn.softmax(k)


#define loss (cost) function

t = tf.placeholder(tf.float32, [None, 10])

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(k,t))

train_step = tf.train.AdamOptimizer(0.0001).minimize(loss)

correct_prediction = tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


# prepare session

sess = tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()


# start training

i = 0

for _ in range(1000):

   i += 1

   batch_xs, batch_ts = mnist.train.next_batch(50)

   sess.run(train_step,

            feed_dict={x:batch_xs, t:batch_ts, keep_prob:0.5})

   if i % 500 == 0:

       loss_vals, acc_vals = [], []

       for c in range(4):

           start = len(mnist.test.labels) / 4 * c

           end = len(mnist.test.labels) / 4 * (c+1)

           loss_val, acc_val = sess.run([loss, accuracy],

               feed_dict={x:mnist.test.images[start:end],

                          t:mnist.test.labels[start:end],

                          keep_prob:1.0})

           loss_vals.append(loss_val)

           acc_vals.append(acc_val)

       loss_val = np.sum(loss_vals)

       acc_val = np.mean(acc_vals)

       print ('Step: %d, Loss: %f, Accuracy: %f'

              % (i, loss_val, acc_val))


saver.save(sess, 'cnn_session')

sess.close()



데이타 로딩 파트

그러면 코드를 하나씩 살펴보도록 하자.

맨 처음 블럭은 데이타를 로딩하고 각종 변수를 초기화 하는 부분이다.

import tensorflow as tf

import numpy as np

import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data


#Call tf.reset_default_graph() before you build your model (and the Saver). This will ensure that the variables get the names you intended, but it will invalidate previously-created graphs.


tf.reset_default_graph()


np.random.seed(20160704)

tf.set_random_seed(20160704)


# load data

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)


Input_data 는 텐서플로우에 내장되어 있는 MNIST (손으로 쓴 숫자 데이타)셋으로, read_data_sets 메서드를 이요하여 데이타를 읽었다. 데이타 로딩 부분은 앞의 소프트맥스 MNIST와 같으니 참고하기 바란다.


여기서 특히 주목해야 할 부분은 tf.reset_default_graph()  인데, 주피터 노트북과 같은 환경에서 실행을 하게 되면, 주피터 커널을 리스타트하지 않는 이상 변수들의 컨택스트가 그대로 유지 되기 때문에, 위의 코드를 같은 커널에서 tf.reset_default_graph() 없이, 두 번 이상 실행하게 되면 에러가 난다. 그 이유는 텐서플로우 그래프를 만들어놓고, 그 그래프가 지워지지 않은 상태에서 다시 같은 그래프를 생성하면서 나오는 에러인데, tf.reset_default_graph() 메서드는 기존에 생성된 디폴트 그래프를 모두 삭제해서 그래프가 중복되는 것을 막아준다. 일반적인 파이썬 코드에서는 크게 문제가 없지만, 컨택스트가 계속 유지되는 주피터 노트북 같은 경우에는 발생할 수 있는 문제이니, 반드시 디폴트 그래프를 리셋해주도록 하자

첫번째 컨볼루셔널 계층

필터의 정의

다음은 첫번째 컨볼루셔널 계층을 정의 한다. 컨볼루셔널 계층을 이해하려면 컨볼루셔널 필터에 대한 개념을 이해해야 하는데, 다시 한번 되짚어 보자.

컨볼루셔널 계층에서 하는 일은 입력 데이타에 필터를 적용하여, 특징을 추출해 낸다.


이 예제에서 입력 받는 이미지 데이타는  28x28x1 행렬로 표현된 흑백 숫자 이미지이고, 예제 코드에서는 5x5x1 사이즈의 필터를 적용한다.

5x5x1 사이즈의 필터 32개를 적용하여, 총 32개의 특징을 추출할것이다.


코드

필터 정의 부분까지 코드로 살펴보면 다음과 같다.

# define first layer

num_filters1 = 32


x = tf.placeholder(tf.float32, [None, 784])

x_image = tf.reshape(x, [-1,28,28,1])


W_conv1 = tf.Variable(tf.truncated_normal([5,5,1,num_filters1],


x는 입력되는 이미지 데이타로, 2차원 행렬(28x28)이 아니라, 1차원 벡터(784)로 되어 있고, 데이타의 수는 무제한으로 정의하지 않았다. 그래서 placeholder정의에서 shape이 [None,784] 로 정의 되어 있다.  

예제에서는 연산을 편하게 하기 위해서 2차원 행렬을 사용할것이기 때문에, 784 1차원 벡터를 28x28x1 행렬로 변환을 해준다.

x_image는 784x무한개인 이미지 데이타 x를 , (28x28x1)이미지의 무한개 행렬로  reshape를 이용하여 변경하였다. [-1,28,28,1]은 28x28x1 행렬을 무한개(-1)로 정의하였다.


필터를 정의하는데, 필터는 앞서 설명한것과 같이 5x5x1 필터를 사용할것이고, 필터의 수는 32개이기 때문에, 필터 W_conv1의 차원(shape)은 [5,5,1,32] 가된다. (코드에서 32는 num_filters1 이라는 변수에 저장하여 사용하였다.) 그리고 W_conv1의 초기값은 [5,5,1,32] 차원을 가지는 난수를 생성하도록 tf.truncated_normal을 사용해서 임의의 수가 지정되도록 하였다.

필터 적용

필터를 정의했으면 필터를 입력 데이타(이미지)에 적용한다.


h_conv1 = tf.nn.conv2d(x_image, W_conv1,

                      strides=[1,1,1,1], padding='SAME')


필터를 적용하는 방법은 tf.nn.conv2d를 이용하면 되는데, 28x28x1 사이즈의 입력 데이타인 x_image에 앞에서 정의한 필터 W_conv1을 적용하였다.

스트라이드 (Strides)

필터는 이미지의 좌측 상단 부터 아래 그림과 같이 일정한 간격으로 이동하면서 적용된다.


이를 개념적으로 표현하면 다음과 같은 모양이 된다.


이렇게 필터를 움직이는 간격을 스트라이드 (Stride)라고 한다.

예제에서는 우측으로 한칸 그리고 끝까지 이동하면 아래로 한칸을 이동하도록 각각 가로와 세로의 스트라이드 값을 1로 세팅하였다.

코드에서 보면

h_conv1 = tf.nn.conv2d(x_image, W_conv1,

                      strides=[1,1,1,1], padding='SAME')

에서 strides=[1,1,1,1] 로 정의한것을 볼 수 있다. 맨앞과 맨뒤는 통상적으로 1을 쓰고, 두번째 1은 가로 스트라이드 값, 그리고 세번째 1은 세로 스트라이드 값이 된다.

패딩 (Padding)

위의 그림과 같이 필터를 적용하여 추출된 특징 행렬은 원래 입력된 이미지 보다 작게 된다.

연속해서 필터를 이런 방식으로 적용하다 보면 필터링 된 특징들이  작아지게되는데, 만약에 특징을  다 추출하기 전에 특징들이 의도하지 않게 유실되는 것을 막기 위해서 패딩이라는 것을 사용한다.


패딩이란, 입력된 데이타 행렬 주위로, 무의미한 값을 감싸서 원본 데이타의 크기를 크게 해서, 필터를 거치고 나온 특징 행렬의 크기가 작아지는 것을 방지한다.

또한 무의미한 값을 넣음으로써, 오버피팅이 발생하는 것을 방지할 수 있다. 코드상에서 padding 변수를 이용하여 패딩 방법을 정의하였다.


h_conv1 = tf.nn.conv2d(x_image, W_conv1,

                      strides=[1,1,1,1], padding='SAME')



padding=’SAME’을 주게 되면, 텐서플로우가 자동으로 패딩을 삽입하여 입력값과 출력값 (특징 행렬)의 크기가 같도록 한다. padding=’VALID’를 주게 되면, 패딩을 적용하지 않고 필터를 적용하여 출력값 (특징 행렬)의 크기가 작아진다.

활성함수 (Activation function)의 적용

필터 적용이 끝났으면, 이 필터링된 값에 활성함수를 적용한다. 컨볼루셔널 네트워크에서 일반적으로 사용하는 활성함수는 ReLu 함수이다.


코드

b_conv1 = tf.Variable(tf.constant(0.1, shape=[num_filters1]))

h_conv1_cutoff = tf.nn.relu(h_conv1 + b_conv1)


먼저 bias 값( y=WX+b 에서 b)인 b_conv1을 정의하고, tf.nn.relu를 이용하여, 필터된 결과(h_conv1)에 bias 값을 더한 값을 ReLu 함수로 적용하였다.

Max Pooling

추출된 특징 모두를 가지고 특징을 판단할 필요가 없이, 일부 특징만을 가지고도 특징을 판단할 수 있다. 즉 예를 들어서 고해상도의 큰 사진을 가지고도 어떤 물체를 식별할 수 있지만, 작은 사진을 가지고도 물체를 식별할 수 있다. 이렇게 특징의 수를 줄이는 방법을 서브샘플링 (sub sampling)이라고 하는데, 서브샘플링을 해서 전체 특징의 수를 의도적으로 줄이는 이유는 데이타의 크기를 줄이기 때문에, 컴퓨팅 파워를 절약할 수 있고, 데이타가 줄어드는 과정에서 데이타가 유실이 되기 때문에, 오버 피팅을 방지할 수 있다.


이러한 서브 샘플링에는 여러가지 방법이 있지만 예제에서는 맥스 풀링 (max pooling)이라는 방법을 사용했는데, 맥스 풀링은 풀링 사이즈 (mxn)로 입력데이타를 나눈후 그 중에서 가장 큰 값만을 대표값으로 추출하는 것이다.


아래 그림을 보면 원본 데이타에서 2x2 사이즈로 맥스 풀링을 해서 결과를 각 셀별로 최대값을 뽑아내었고, 이 셀을 가로 2칸씩 그리고 그다음에는 세로로 2칸씩 이동하는 stride 값을 적용하였다.


코드

h_pool1 = tf.nn.max_pool(h_conv1_cutoff, ksize=[1,2,2,1],

                        strides=[1,2,2,1], padding='SAME')


Max pooling은 tf.nn.max_pool이라는 함수를 이용해서 적용할 수 있는데, 첫번째 인자는 활성화 함수 ReLu를 적용하고 나온 결과 값인 h_conv1_cutoff 이고, 두 번째 인자인 ksize는 풀링 필터의 사이즈로 [1,2,2,1]은 2x2 크기로 묶어서 풀링을 한다는 의미이다.


다음 stride는 컨볼루셔널 필터 적용과 마찬가지로 풀링 필터를 가로와 세로로 얼마만큼씩 움직일 것인데, strides=[1,2,2,1]로, 가로로 2칸, 세로로 2칸씩 움직이도록 정의하였다.


행렬의 차원 변환

텐서플로우를 이용해서 CNN을 만들때 각각 개별의 알고리즘을 이해할 필요는 없지만 각 계층을 추가하거나 연결하기 위해서는 행렬의 차원이 어떻게 바뀌는지는 이해해야 한다.

다음 그림을 보자


첫번째 컨볼루셔널 계층은 위의 그림과 같이, 처음에 28x28x1 의 이미지가 들어가면 32개의 컨볼루셔널 필터 W를 적용하게 되고, 각각은 28x28x1의 결과 행렬을 만들어낸다. 컨볼루셔널 필터를 거치게 되면 결과 행렬의 크기는 작아져야 정상이지만, 결과 행렬의 크기를 입력 행렬의 크기와 동일하게 유지하도록 padding=’SAME’으로 설정하였다.

다음으로 bias 값 b를 더한후 (위의 그림에는 생략하였다) 에 이 값에 액티베이션 함수 ReLu를 적용하고 나면 행렬 크기에 변화 없이 28x28x1 행렬 32개가 나온다. 이 각각의 행렬에 size가 2x2이고, stride가 2인 맥스풀링 필터를 적용하게 되면 각각의 행렬의 크기가 반으로 줄어들어 14x14x1 행렬 32개가 리턴된다.


두번째 컨볼루셔널 계층


이제 두번째 컨볼루셔널 계층을 살펴보자. 첫번째 컨볼루셔널 계층과 다를 것이 없다.


코드

# define second layer

num_filters2 = 64


W_conv2 = tf.Variable(

           tf.truncated_normal([5,5,num_filters1,num_filters2],

                               stddev=0.1))

h_conv2 = tf.nn.conv2d(h_pool1, W_conv2,

                      strides=[1,1,1,1], padding='SAME')


b_conv2 = tf.Variable(tf.constant(0.1, shape=[num_filters2]))

h_conv2_cutoff = tf.nn.relu(h_conv2 + b_conv2)


h_pool2 = tf.nn.max_pool(h_conv2_cutoff, ksize=[1,2,2,1],

                        strides=[1,2,2,1], padding='SAME')


단 필터값인 W_conv2의 차원이 [5,5,32,64] ([5,5,num_filters1,num_filters2] 부분 )로 변경되었다.


W_conv2 = tf.Variable(

           tf.truncated_normal([5,5,num_filters1,num_filters2],

                               stddev=0.1))


필터의 사이즈가 5x5이고, 입력되는 값이 32개이기 때문에, 32가 들어가고, 총 64개의 필터를 적용하기 때문에 마지막 부분이 64가 된다.

첫번째 필터와 똑같이 stride를 1,1을 줘서 가로,세로로 각각 1씩 움직이고, padding=’SAME’으로 입력과 출력 사이즈를 같게 하였다.


h_pool2 = tf.nn.max_pool(h_conv2_cutoff, ksize=[1,2,2,1],

                        strides=[1,2,2,1], padding='SAME')


맥스풀링 역시 첫번째 필터와 마찬가지로 2,2 사이즈의 필터(ksize=[1,2,2,1]) 를 적용하고 stride값을 2,2로 줘서 (strides=[1,2,2,1]) 가로 세로로 두칸씩 움직이게 하여 결과의 크기가 반으로 줄어들게 하였다.


14x14 크기의 입력값 32개가 들어가서, 7x7 크기의 행렬 64개가 리턴된다.

풀리 커넥티드 계층

두개의 컨볼루셔널 계층을 통해서 특징을 뽑아냈으면, 이 특징을 가지고 입력된 이미지가 0~9 중 어느 숫자인지를 풀리 커넥티드 계층 (Fully connected layer)를 통해서 판단한다.


코드

# define fully connected layer

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*num_filters2])


num_units1 = 7*7*num_filters2

num_units2 = 1024


w2 = tf.Variable(tf.truncated_normal([num_units1, num_units2]))

b2 = tf.Variable(tf.constant(0.1, shape=[num_units2]))

hidden2 = tf.nn.relu(tf.matmul(h_pool2_flat, w2) + b2)


keep_prob = tf.placeholder(tf.float32)

hidden2_drop = tf.nn.dropout(hidden2, keep_prob)


w0 = tf.Variable(tf.zeros([num_units2, 10]))

b0 = tf.Variable(tf.zeros([10]))

k = tf.matmul(hidden2_drop, w0) + b0

p = tf.nn.softmax(k)


입력된 64개의 7x7 행렬을 1차원 행렬로 변환한다.


h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*num_filters2])


다음으로 풀리 커넥티드 레이어에 넣는데, 이때 입력값은 64x7x7 개의 벡터 값을 1024개의 뉴런을 이용하여 학습한다.


w2 = tf.Variable(tf.truncated_normal([num_units1, num_units2]))

b2 = tf.Variable(tf.constant(0.1, shape=[num_units2]))


그래서 w2의 값은 [num_units1,num_units2]로 num_units1은 64x7x7 로 입력값의 수를, num_unit2는 뉴런의 수를 나타낸다. 다음 아래와 같이 이 뉴런으로 계산을 한 후 액티베이션 함수 ReLu를 적용한다.


hidden2 = tf.nn.relu(tf.matmul(h_pool2_flat, w2) + b2)


다음 레이어에서는 드롭 아웃을 정의하는데, 드롭 아웃은 오버피팅(과적합)을 막기 위한 계층으로, 원리는 다음 그림과 같이 몇몇 노드간의 연결을 끊어서 학습된 데이타가 도달하지 않도록 하여서 오버피팅이 발생하는 것을 방지하는 기법이다.


출처 : http://cs231n.github.io/neural-networks-2/


텐서 플로우에서 드롭 아웃을 적용하는 것은 매우 간단하다. 아래 코드와 같이 tf.nn.dropout 이라는 함수를 이용하여, 앞의 네트워크에서 전달된 값 (hidden2)를 넣고 keep_prob에, 연결 비율을 넣으면 된다.

keep_prob = tf.placeholder(tf.float32)

hidden2_drop = tf.nn.dropout(hidden2, keep_prob)


연결 비율이란 네트워크가 전체가 다 연결되어 있으면 1.0, 만약에 50%를 드롭아웃 시키면 0.5 식으로 입력한다.

드롭 아웃이 끝난후에는 결과를 가지고 소프트맥스 함수를 이용하여 10개의 카테고리로 분류한다.


w0 = tf.Variable(tf.zeros([num_units2, 10]))

b0 = tf.Variable(tf.zeros([10]))

k = tf.matmul(hidden2_drop, w0) + b0

p = tf.nn.softmax(k)

비용 함수 정의

여기까지 모델 정의가 끝났다. 이제 이 모델을 학습 시키기 위해서 비용함수(코스트 함수)를 정의해보자.

코스트 함수는 크로스엔트로피 함수를 이용한다.

#define loss (cost) function

t = tf.placeholder(tf.float32, [None, 10])

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(k,t))

train_step = tf.train.AdamOptimizer(0.0001).minimize(loss)


k는 앞의 모델에 의해서 앞의 모델에서

k = tf.matmul(hidden2_drop, w0) + b0

p = tf.nn.softmax(k)


으로 softmax를 적용하기 전의 값이다.  Tf.nn.softmax_cross_entropy_with_logits 는 softmax가 포함되어 있는 함수이기 때문에, p를 적용하게 되면 softmax 함수가 중첩 적용되기 때문에, softmax 적용전의 값인 k 를 넣었다.


WARNING: This op expects unscaled logits, since it performs a softmax on logits internally for efficiency. Do not call this op with the output of softmax, as it will produce incorrect results

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.softmax_cross_entropy_with_logits.md


t는 플레이스 홀더로 정의하였는데, 나중에 학습 데이타 셋에서 읽을 라벨 (그 그림이 0..9 중 어느 숫자인지)이다.


그리고 이 비용 함수를 최적화 하기 위해서 최적화 함수 AdamOptimizer를 사용하였다.

(앞의 소프트맥스 예제에서는 GradientOptimizer를 사용하였는데, 일반적으로 AdamOptimizer가 좀 더 무난하다.)

학습

이제 모델 정의와, 모델의 비용함수와 최적화 함수까지 다 정의하였다. 그러면 이 그래프들을 데이타를 넣어서 학습 시켜보자.  학습은 배치 트레이닝을 이용할것이다.


학습 도중 학습의 진행상황을 보기 위해서 학습된 모델을 중간중간 테스트할것이다. 테스트할때마다 학습의 정확도를 측정하여 출력하는데, 이를 위해서 정확도를 계산하는 함수를 아래와 같이 정의한다.


#define validation function

correct_prediction = tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


correct_prediction은 학습 결과와 입력된 라벨(정답)을 비교하여 맞았는지 틀렸는지를 리턴한다.

argmax는 인자에서 가장 큰 값의 인덱스를 리턴하는데, 0~9 배열이 들어가 있기 때문에 가장 큰 값이 학습에 의해 예측된 숫자이다. p는 예측에 의한 결과 값이고, t는 라벨 값이다 이 두 값을 비교하여 가장 큰 값이 있는 인덱스가 일치하면 예측이 성공한것이다.

correct_pediction은 bool 값이기 때문에, 이 값을 숫자로 바꾸기 위해서 tf.reduce_mean을 사용하여, accuracy에 저장하였다.


이제 학습을 세션을 시작하고, 변수들을 초기화 한다.

# prepare session

sess = tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()


다음 배치 학습을 시작한다.

# start training

i = 0

for _ in range(10000):

   i += 1

   batch_xs, batch_ts = mnist.train.next_batch(50)

   sess.run(train_step,

            feed_dict={x:batch_xs, t:batch_ts, keep_prob:0.5})

   if i % 500 == 0:

       loss_vals, acc_vals = [], []

       for c in range(4):

           start = len(mnist.test.labels) / 4 * c

           end = len(mnist.test.labels) / 4 * (c+1)

           loss_val, acc_val = sess.run([loss, accuracy],

               feed_dict={x:mnist.test.images[start:end],

                          t:mnist.test.labels[start:end],

                          keep_prob:1.0})

           loss_vals.append(loss_val)

           acc_vals.append(acc_val)

       loss_val = np.sum(loss_vals)

       acc_val = np.mean(acc_vals)

       print ('Step: %d, Loss: %f, Accuracy: %f'

              % (i, loss_val, acc_val))


학습은 10,000번 루프를 돌면서 한번에 50개씩 배치로 데이타를 읽어서 학습을 진행하고, 500 번째 마다 중각 학습 결과를 출력한다. 중간 학습 결과에서는 10,000 중 몇번째 학습인지와, 비용값 그리고 정확도를 출력해준다.


코드를 보자


   batch_xs, batch_ts = mnist.train.next_batch(50)


MNIST 학습용 데이타 셋에서 50개 단위로 데이타를 읽는다. batch_xs에는 학습에 사용할 28x28x1 사이즈의 이미지와, batch_ts에는 그 이미지에 대한 라벨 (0..9중 어떤 수인지) 가 들어 있다.

읽은 데이타를 feed_dict를 통해서 피딩(입력)하고 트레이닝 세션을 시작한다.


  sess.run(train_step,

            feed_dict={x:batch_xs, t:batch_ts, keep_prob:0.5})


이때 마지막 인자에 keep_prob를 0.5로 피딩하는 것을 볼 수 있는데, keep_prob는 앞의 드롭아웃 계층에서 정의한 변수로 드롭아웃을 거치지 않을 비율을 정의한다. 여기서는 0.5 즉 50%의 네트워크를 인위적으로 끊도록 하였다.


배치로 학습을 진행하다가 500번 마다 중간중간 정확도와 학습 비용을 계산하여 출력한다.

   if i % 500 == 0:

       loss_vals, acc_vals = [], []


여기서 주목할 점은 아래 코드 처럼 한번에 검증을 하지 않고 테스트 데이타를 4등분 한후, 1/4씩 테스트 데이타를 로딩해서 학습비용(loss)와 학습 정확도(accuracy)를 계산하는 것을 볼 수 있다.


       for c in range(4):

           start = len(mnist.test.labels) / 4 * c

           end = len(mnist.test.labels) / 4 * (c+1)

           loss_val, acc_val = sess.run([loss, accuracy],

               feed_dict={x:mnist.test.images[start:end],

                          t:mnist.test.labels[start:end],

                          keep_prob:1.0})

           loss_vals.append(loss_val)

           acc_vals.append(acc_val)


이유는 한꺼번에 많은 데이타를 로딩해서 검증을 할 경우 메모리 문제가 생길 수 있기 때문에, 4번에 나눠 걸쳐서 읽고 검증한 다음에 아래와 같이 학습 비용은 4번의 학습 비용을 합하고, 정확도는 4번의 학습 정확도를 평균으로 내어 출력하였다.


       loss_val = np.sum(loss_vals)

       acc_val = np.mean(acc_vals)

       print ('Step: %d, Loss: %f, Accuracy: %f'

              % (i, loss_val, acc_val))

학습 결과 저장

학습을 통해서 최적의 W와 b값을 구했으면 이 값을 예측에 이용해야 하는데, W 값들이 많고, 이를 일일이 출력해서 파일로 저장하는 것도 번거롭고 해서, 텐서플로우에서는 학습된 모델을 저장할 수 있는 기능을 제공한다. 학습을 통해서 계산된 모든 변수 값을 저장할 수 있는데,  앞에서 세션을 생성할때 생성한 Saver (saver = tf.train.Saver())를 이용하면 현재 학습 세션을  저장할 수 있다.


코드

saver.save(sess, 'cnn_session')

sess.close()


이렇게 하면 현재 디렉토리에 cnn_session* 형태의 파일로 학습된 세션 값들이 저장된다.

그래서 추후 예측을 할때 다시 학습할 필요 없이 이 파일을 로딩해서, 모델의 값들을 복귀한 후에, 예측을 할 수 있다. 이 파일을 읽어서 예측을 하는 것은 다음글에서 다루기로 한다.


예제 코드 : https://github.com/bwcho75/tensorflowML/blob/master/MNIST_CNN_Training.ipynb


저작자 표시 비영리
신고
크리에이티브 커먼즈 라이선스
Creative Commons License

텐서플로우 #2 - 행렬과 텐서플로우


조대협 (http://bcho.tistory.com)


머신러닝은 거의 모든 연산을 행렬을 활용한다. 텐서플로우도 이 행렬을 기반으로 하고, 이 행렬의 차원을 shape 라는 개념으로 표현하는데, 행렬에 대한 기본적이 개념이 없으면 헷갈리기 좋다. 그래서 이 글에서는 간략하게 행렬의 기본 개념과 텐서플로우내에서 표현 방법에 대해서 알아보도록 한다.


행렬의 기본 개념 훝어보기

행과 열

행렬의 가장 기본 개념은 행렬이다. mxn 행렬이 있을때, m은 행, n은 열을 나타내며, 행은 세로의 줄수, 열은 가로줄 수 를 나타낸다. 아래는 3x4 (3행4열) 행렬이다.


곱셈


곱셈은 앞의 행렬에서 행과, 뒤의 행렬의 열을 순차적으로 곱해준다.

아래 그림을 보면 쉽게 이해가 될것이다.



이렇게 앞 행렬의 행과 열을 곱해나가면 결과적으로 아래와 같은 결과가 나온다.


이때 앞의 행렬의 열과, 뒤의 행렬의 행이 같아야 곱할 수 있다.

즉 axb 행렬과 mxn 행렬이 있을때, 이 두 행렬을 곱하려면 b와 m이 같아야 한다.

그리고 이 두 행렬을 곱하면 axn 사이즈의 행렬이 나온다.

행렬의 덧셈과 뺄셈

행렬의 덧셈과 뺄셈은 단순하다. 같은 행과 열에 있는 값을 더하거나 빼주면 되는데, 단지 주의할점은 덧셈과 뺄샘을 하는 두개의 행렬의 차원이 동일해야 한다.


텐서 플로우에서 행렬의 표현

행렬에 대해서 간단하게 되짚어 봤으면, 그러면 텐서 플로우에서는 어떻게 행렬을 표현하는지 알아보자


을 하는 코드를 살펴보자


예제코드

import tensorflow as tf


x = tf.constant([ [1.0,2.0,3.0] ])

w = tf.constant([ [2.0],[2.0],[2.0] ])

y = tf.matmul(x,w)

print x.get_shape()


sess = tf.Session()

init = tf.global_variables_initializer()

sess.run(init)

result = sess.run(y)


print result


실행 결과

(1, 3)
[[ 12.]]



텐서플로우에서 행렬의 곱셈은 일반 * 를 사용하지 않고, 텐서플로우 함수  “tf.matmul” 을 사용한다.

중간에, x.get_shape()를 통해서, 행렬 x의 shape를 출력했는데, shape는 행렬의 차원이라고 생각하면 된다. x는 1행3열인 1x3 행렬이기 때문에, 위의 결과와 같이 (1,3)이 출력된다.


앞의 예제에서는 contant 에 저장된 행렬에 대한 곱셈을 했는데, 당연히 Variable 형에서도 가능하다.


예제 코드

import tensorflow as tf


x = tf.Variable([ [1.,2.,3.] ], dtype=tf.float32)

w = tf.constant([ [2.],[2.],[2.]], dtype=tf.float32)

y = tf.matmul(x,w)


sess = tf.Session()

init = tf.global_variables_initializer()

sess.run(init)

result = sess.run(y)


print result


Constant 및 Variable 뿐 아니라,  PlaceHolder에도 행렬로 저장이 가능하다 다음은 PlaceHolder에 행렬 데이타를 feeding 해주는 예제이다.

입력 데이타 행렬 x는 PlaceHolder 타입으로 3x3 행렬이고, 여기에 곱하는 값 w는 1x3 행렬이다.


예제 코드는 다음과 같다.


예제코드

import tensorflow as tf


input_data = [ [1.,2.,3.],[1.,2.,3.],[2.,3.,4.] ] #3x3 matrix

x = tf.placeholder(dtype=tf.float32,shape=[None,3])

w = tf.Variable([ [2.],[2.],[2.] ], dtype = tf.float32) #3x1 matrix

y = tf.matmul(x,w)


sess = tf.Session()

init = tf.global_variables_initializer()

sess.run(init)

result = sess.run(y,feed_dict={x:input_data})


print result


실행결과

[[ 12.]
[ 12.]
[ 18.]]


이 예제에서 주의 깊게 봐야 할부분은 placeholder x 를 정의하는 부분인데, shape=[None,3] 으로 정의했다 3x3 행렬이기 때문에, shape=[3,3]으로 지정해도 되지만 None 이란, 갯수를 알수 없음을 의미하는 것으로, 텐서플로우 머신러닝 학습에서 학습 데이타가 계속해서 들어오고  학습 때마다 데이타의 양이 다를 수 있기 때문에, 이를 지정하지 않고 None으로 해놓으면 들어오는 숫자 만큼에 맞춰서 저장을 한다.

브로드 캐스팅

텐서플로우 그리고 파이썬으로 행렬 프로그래밍을 하다보면 헷갈리는 개념이 브로드 캐스팅이라는 개념이 있다. 먼저 다음 코드를 보자


예제코드

import tensorflow as tf


input_data = [

    [1,1,1],[2,2,2]

   ]

x = tf.placeholder(dtype=tf.float32,shape=[2,3])

w  =tf.Variable([[2],[2],[2]],dtype=tf.float32)

b  =tf.Variable([4],dtype=tf.float32)

y = tf.matmul(x,w)+b


print x.get_shape()

sess = tf.Session()

init = tf.global_variables_initializer()

sess.run(init)

result = sess.run(y,feed_dict={x:input_data})


print result


실행결과

(2, 3)
[[ 24.]
[ 48.]]


행렬 x는 2x3 행렬이고 w는 3x1 행렬이다. x*w를 하면 2*1 행렬이 나온다.

문제는 +b 인데, b는 1*1 행렬이다. 행렬의 덧셈과 뺄셈은 차원이 맞아야 하는데, 이 경우 더하고자 하는 대상은 2*1, 더하려는 b는 1*1로 행렬의 차원이 다르다. 그런데 어떻게 덧셈이 될까?

이 개념이 브로드 캐스팅이라는 개념인데, 위에서는 1*1인 b행렬을 더하는 대상에 맞게 2*1 행렬로 자동으로 늘려서 (stretch) 계산한다.


브로드 캐스팅은 행렬 연산 (덧셈,뺄셈,곱셈)에서 차원이 맞지 않을때, 행렬을 자동으로 늘려줘서(Stretch) 차원을 맞춰주는 개념으로 늘리는 것은 가능하지만 줄이는 것은 불가능하다.


브로드 캐스팅 개념은 http://scipy.github.io/old-wiki/pages/EricsBroadcastingDoc 에 잘 설명되어 있으니 참고하기 바란다. (아래 그림은 앞의 링크를 참조하였다.)


아래는 4x3 행렬 a와 1x3 행렬 b를 더하는 연산인데, 차원이 맞지 않기 때문에, 행렬 b의 열을 늘려서 1x3 → 4x3 으로 맞춰서 연산한 예이다.


만약에 행렬 b가 아래 그림과 같이 1x4 일 경우에는 열을 4 → 3으로 줄이고, 세로 행을 1→ 4 로 늘려야 하는데, 앞에서 언급한바와 같이, 브로드 캐스팅은 행이나 열을 줄이는 것은 불가능하다.


다음은 양쪽 행렬을 둘다 늘린 케이스 이다.

4x1 행렬 a와 1x3 행렬 b를 더하면 양쪽을 다 수용할 수 있는 큰 차원인 4x3 행렬로 변환하여 덧셈을 수행한다.



텐서플로우 행렬 차원 용어


텐서플로우에서는 행렬을 차원에 따라서 다음과 같이 호칭한다.

행렬이 아닌 숫자나 상수는 Scalar, 1차원 행렬을 Vector, 2차원 행렬을 Matrix, 3차원 행렬을 3-Tensor 또는 cube, 그리고 이 이상의 다차원 행렬을 N-Tensor라고 한다.


그리고 행렬의 차원을 Rank라고 부른다. scalar는 Rank가 0, Vector는 Rank 가 1, Matrix는 Rank가 2가 된다.


저작자 표시 비영리
신고
크리에이티브 커먼즈 라이선스
Creative Commons License

텐서플로우-#1 자료형의 이해

빅데이타/머신러닝 | 2016.12.09 22:42 | Posted by 조대협

텐서플로우-#1 자료형의 이해


조대협 (http://bcho.tistory.com)


딥러닝에 대한 대략적인 개념을 익히고 실제로 코딩을 해보려고 하니, 모 하나를 할때 마다 탁탁 막힌다. 파이썬이니 괜찮겠지 했는데, (사실 파이썬도 다 까먹어서 헷갈린다.) 이건 라이브러리로 도배가 되어 있다.

당연히 텐서플로우 프레임웍은 이해를 해야 하고, 데이타를 정재하고 시각화 하는데, numpy,pandas와 같은 추가적인 프레임웍에 대한 이해가 필요하다.


node.js 시작했을때도 자바스크립트 때문에 많이 헤매고 몇달이 지난후에야 어느정도 이해하게 되었는데, 역시나 차근차근 기초 부터 살펴봐야 하지 않나 싶다.


텐서 플로우에 대해 공부한 내용들을 하나씩 정리할 예정인데, 이 컨텐츠들은 유투브의 이찬우님의 강의를 기반으로 정리하였다. 무엇보다 한글이고 개념을 쉽게 풀어서 정리해주시기 때문에, 왠만한 교재 보다 났다.

https://www.youtube.com/watch?v=a74pFg8paVc


텐서플로우 환경 설정

텐서 플로우 환경을 설정 하는 방법은 쉽지 않다. 텐서플로우 뿐 아니라, 여러 파이썬 버전과 그에 맞는 라이브러리도 함께 설정해야 하기 때문에 여간 까다로운게 아닌데, 텐서플로우 환경은 크게 대략 두 가지 환경으로 쉽게 설정이 가능하다.

구글 데이타랩

첫번째 방법은 구글에서 주피터 노트북을 도커로 패키징해놓은 패키지를 이용하는 방법이다. 도커 패키지안에, numpy,pandas,matplotlib,tensorflow,python 등 텐서플로우 개발에 필요한 모든 환경이 패키징 되어 있다. 데이타 랩 설치 방법은 http://bcho.tistory.com/1134 링크를 참고하면 된다.

도커 런타임이 설치되어 있다면, 데이타랩 환경 설정은 10분이면 충분하다.

아나콘다

다음 방법은 일반적으로 가장 많이 사용하는 방법인데, 파이썬 수학관련 라이브러리를 패키징해놓은 아나콘다를 이용하는 방법이 있다. 자세한 환경 설정 방법은 https://www.tensorflow.org/versions/r0.12/get_started/os_setup.html#anaconda-installation 를 참고하기 바란다. 아나콘다를 설치해놓고, tensorflow 환경(environment)를 정의한 후에, 주피터 노트북을 설치하면 된다. http://stackoverflow.com/questions/37061089/trouble-with-tensorflow-in-jupyter-notebook 참고


Tensorflow 환경을 만든 후에,

$ source activate tensorflow

를 실행해서 텐서 플로우 환경으로 전환한후, 아래와 같이 ipython 을 설치한후에, 주피터 (jupyter) 노트북을 설치하면 된다.

(tensorflow) username$ conda install ipython
(tensorflow) username$ pip install jupyter #(use pip3 for python3)


아나콘다 기반의 텐서플로우 환경 설정은 나중에 시간이 될때 다른 글을 통해서 다시 설명하도록 하겠다.

텐서플로우의 자료형

텐서플로우는 뉴럴네트워크에 최적화되어 있는 개발 프레임웍이기 때문에, 그 자료형과, 실행 방식이 약간 일반적인 프로그래밍 방식과 상의하다. 그래서 삽질을 많이 했다.


상수형 (Constant)

상수형은 말 그대로 상수를 저장하는 데이타 형이다.

  • tf.constant(value, dtype=None, shape=None, name='Const', verify_shape=False)

와 같은 형태로 정의 된다. 각 정의되는 내용을 보면

  • value : 상수의 값이다.

  • dtype : 상수의 데이타형이다. tf.float32와 같이 실수,정수등의 데이타 타입을 정의한다.

  • shape : 행렬의 차원을 정의한다. shape=[3,3]으로 정의해주면, 이 상수는 3x3 행렬을 저장하게 된다.

  • name : name은 이 상수의 이름을 정의한다. name에 대해서는 나중에 좀 더 자세하게 설명하도록 하겠다.

간단한 예제를 하나 보자.

a,b,c 상수에, 각각 5,10,2 의 값을 넣은 후에, d=a*b+c 를 계산해서 계산 결과 d를 출력하려고 한다.

import tensorflow as tf


a = tf.constant([5],dtype=tf.float32)

b = tf.constant([10],dtype=tf.float32)

c = tf.constant([2],dtype=tf.float32)


d = a*b+c


print d

그런데, 막상 실행해보면, a*b+c의 값이 아니라 다음과 같이 Tensor… 라는 문자열이 출력된다.


Tensor("add_8:0", shape=(1,), dtype=float32)

그래프와 세션의 개념

먼저 그래프와 세션이라는 개념을 이해해야 텐서플로우의 프로그래밍 모델을 이해할 수 있다.

위의 d=a*b+c 에서 d 역시 계산을 수행하는 것이 아니라 다음과 같이 a*b+c 그래프를 정의하는 것이다.


실제로 값을 뽑아내려면, 이 정의된 그래프에 a,b,c 값을 넣어서 실행해야 하는데, 세션 (Session)을 생성하여,  그래프를 실행해야 한다. 세션은 그래프를 인자로 받아서 실행을 해주는 일종의 러너(Runner)라고 생각하면 된다.


자 그러면 위의 코드를 수정해보자


import tensorflow as tf


a = tf.constant([5],dtype=tf.float32)

b = tf.constant([10],dtype=tf.float32)

c = tf.constant([2],dtype=tf.float32)


d = a*b+c


sess = tf.Session()

result = sess.run(d)

print result



tf.Session()을 통하여 세션을 생성하고, 이 세션에 그래프 d를 실행하도록 sess.run(d)를 실행한다

이 그래프의 실행결과는 리턴값으로 result에 저장이 되고, 출력을 해보면 다음과 같이 정상적으로 52라는 값이 나오는 것을 볼 수 있다.


플레이스 홀더 (Placeholder)

자아 이제 상수의 개념을 알았으면, 이제는 플레이스 홀더에 대해서 알아보자.

y = x * 2 를 그래프를 통해서 실행한다고 하자. 입력값으로는 1,2,3,4,5를 넣고, 출력은 2,4,6,8,10을 기대한다고 하자. 이렇게 여러 입력값을 그래프에서 넣는 경우는 머신러닝에서 y=W*x + b 와 같은 그래프가 있다고 할 때, x는 학습을 위한 데이타가 된다.

즉 지금 살펴보고자 하는 데이타 타입은 학습을 위한 학습용 데이타를 위한 데이타 타입이다.


y=x*2를 정의하면 내부적으로 다음과 같은 그래프가 된다.


그러면, x에는 값을 1,2,3,4,5를 넣어서 결과값을 그래프를 통해서 계산해 내야한다. 개념적으로 보면 다음과 같다.



이렇게 학습용 데이타를 담는 그릇을 플레이스홀더(placeholder)라고 한다.

플레이스홀더에 대해서 알아보면, 플레이스 홀더의 위의 그래프에서 x 즉 입력값을 저장하는 일종의 통(버킷)이다.

tf.placeholder(dtype,shape,name)

으로 정의된다.

플레이스 홀더 정의에 사용되는 변수들을 보면

  • dtype : 플레이스홀더에 저장되는 데이타형이다. tf.float32와 같이 실수,정수등의 데이타 타입을 정의한다.

  • shape : 행렬의 차원을 정의한다. shapre=[3,3]으로 정의해주면, 이 플레이스홀더는 3x3 행렬을 저장하게 된다.

  • name : name은 이 플레이스 홀더의 이름을 정의한다. name에 대해서는 나중에 좀 더 자세하게 설명하도록 하겠다.


그러면 이 x에 학습용 데이타를 어떻게 넣을 것인가? 이를 피딩(feeding)이라고 한다.

다음 예제를 보자


import tensorflow as tf


input_data = [1,2,3,4,5]

x = tf.placeholder(dtype=tf.float32)

y = x * 2


sess = tf.Session()

result = sess.run(y,feed_dict={x:input_data})


print result


처음 input_data=[1,2,3,4,5]으로 정의하고

다음으로 x=tf.placeholder(dtype=tf.float32) 를 이용하여, x를 float32 데이타형을 가지는 플레이스 홀더로 정의하다. shape은 편의상 생략하였다.

그리고 y=x * 2 로 그래프를 정의하였다.


세션이 실행될때, x라는 통에 값을 하나씩 집어 넣는데, (앞에서도 말했듯이 이를 피딩이라고 한다.)

sess.run(y,feed_dict={x:input_data}) 와 같이 세션을 통해서 그래프를 실행할 때, feed_dict 변수를 이용해서 플레이스홀더 x에, input_data를 피드하면, 세션에 의해서 그래프가 실행되면서 x는 feed_dict에 의해서 정해진 피드 데이타 [1,2,3,4,5]를 하나씩 읽어서 실행한다.


변수형 (Variable)

마지막 데이타형은 변수형으로,

y=W*x+b 라는 학습용 가설이 있을때, x가 입력데이타 였다면, W와 b는 학습을 통해서 구해야 하는 값이 된다.  이를 변수(Variable)이라고 하는데, 변수형은 Variable 형의 객체로 생성이 된다.


  • tf.Variable.__init__(initial_value=None, trainable=True, collections=None, validate_shape=True, caching_device=None, name=None, variable_def=None, dtype=None, expected_shape=None, import_scope=None)


변수형에 값을 넣는 것은 다음과 같이 한다.


var = tf.Variable([1,2,3,4,5], dtype=tf.float32)


자 그러면 값을 넣어보고 코드를 실행해보자


import tensorflow as tf


input_data = [1,2,3,4,5]

x = tf.placeholder(dtype=tf.float32)

W = tf.Variable([2],dtype=tf.float32)

y = W*x


sess = tf.Session()

result = sess.run(y,feed_dict={x:input_data})


print result


우리가 기대하는 결과는 다음과 같다. y=W*x와 같은 그래프를 가지고,


x는 [1,2,3,4,5] 값을 피딩하면서, 변수 W에 지정된 2를 곱해서 결과를 내기를 바란다.

그렇지만 코드를 실행해보면 다음과 같이 에러가 출력되는 것을 확인할 수 있다.



이유는 텐서플로우에서 변수형은 그래프를 실행하기 전에 초기화를 해줘야 그 값이 변수에 지정이 된다.


세션을 초기화 하는 순간 변수 W에 그 값이 지정되는데, 초기화를 하는 방법은 다음과 같이 변수들을 global_variables_initializer() 를 이용해서 초기화 한후, 초기화된 결과를 세션에 전달해 줘야 한다.


init = tf.global_variables_initializer()

sess.run(init)


그러면 초기화를 추가한 코드를 보자


import tensorflow as tf


input_data = [1,2,3,4,5]

x = tf.placeholder(dtype=tf.float32)

W = tf.Variable([2],dtype=tf.float32)

y = W*x


sess = tf.Session()

init = tf.global_variables_initializer()

sess.run(init)

result = sess.run(y,feed_dict={x:input_data})


print result


초기화를 수행한 후, 코드를 수행해보면 다음과 같이 우리가 기대했던 결과가 출력됨을 확인할 수 있다.



텐서플로우를 처음 시작할때, Optimizer나 모델등에 대해 이해하는 것도 중요하지만, “데이타를 가지고 학습을 시켜서 적정한 값을 찾는다" 라는 머신러닝 학습 모델의 특성상, 모델을 그래프로 정의하고, 세션을 만들어서 그래프를 실행하고, 세션이 실행될때 그래프에 동적으로 값을 넣어가면서 (피딩) 실행한 다는 기본 개념을 잘 이해해야, 텐서플로우 프로그래밍을 제대로 시작할 수 있다.


저작자 표시 비영리
신고
크리에이티브 커먼즈 라이선스
Creative Commons License

머신러닝의 과학습 / 오버피팅의 개념


조대협 (http://bcho.tistory.com)


머신 러닝을 공부하다보면 자주 나오는 용어 중에 하나가 오버피팅 (Overfitting)이다.

과학습이라고도 하는데, 그렇다면 오버 피팅은 무엇일까?


머신 러닝을 보면 결과적으로 입력 받은 데이타를 놓고, 데이타를 분류 (Classification) 하거나 또는 데이타에 인접한 그래프를 그리는 (Regression) , “선을 그리는 작업이다.”

그러면 선을 얼마나 잘 그리느냐가 머신 러닝 모델의 정확도와 연관이 되는데, 다음과 같이 붉은 선의 샘플 데이타를 받아서, 파란선을 만들어내는 모델을 만들었다면 잘 만들어진 모델이다. (기대하는)


언더 피팅


만약에 학습 데이타가 모자라거나 학습이 제대로 되지 않아서, 트레이닝 데이타에 가깝게 가지 못한 경우에는 다음과 같이 그래프가 트레이닝 데이타에서 많이 떨어진것을 볼 수 있는데, 이를 언더 피팅 (under fitting)이라고 한다.



오버 피팅

오버 피팅은 반대의 경우로, 다음 그림과 같이 트레이닝 데이타에 그래프가 너무 정확히 맞아 들어갈때 발생한다.


샘플 데이타에 너무 정확하게 학습이 되었기 때문에, 샘플데이타를 가지고 판단을 하면 100%에 가까운 정확도를 보이지만 다른 데이타를 넣게 되면, 정확도가 급격하게 떨어지는 문제이ㅏㄷ.

오버피팅의 해결

이런 오버피팅 문제를 해결하는 방법으로는 여러가지가 있는데 대표적인 방법으로는

  • 충분히 많은 학습 데이타를 넣거나

  • 피쳐의 수를 줄이거나

  • Regularization (정규화)를 이용하는 방법이 있다.



그림 출처 : 출처 : https://kousikk.wordpress.com/2014/11/20/problem-of-overfitting-in-machine-learning/




저작자 표시 비영리
신고
크리에이티브 커먼즈 라이선스
Creative Commons License

Docker Kubernetes의 UI

클라우드 컴퓨팅 & NoSQL/google cloud | 2016.11.26 23:38 | Posted by 조대협

Docker Kubernetes UI


조대협 (http://bcho.tistory.com)


오늘 도커 밋업에서 Kubernetes 발표가 있어서, 발표전에 데모를 준비하다 보니, 구글 클라우드의 Kubernetes 서비스인 GKE (Google Container Engine)에서 Kubernetes UI를 지원하는 것을 확인했다.


Google Container Service (GKE)


GKE는 구글 클라우드의 도커 클라우드 서비스이다. 도커 컨테이너를 관리해주는 서비스로는 Apache mesos, Docker Swarm 그리고 구글의 Kuberenetes 가 있는데, GKE는 이 Kuberentes 기반의 클라우드 컨테이너 서비스이다.


대부분의 이런 컨테이너 관리 서비스는 아직 개발중으로 운영에 적용하기에는 많은 부가적인 기능이 필요한데, 사용자 계정 인증이나, 로깅등이 필요하기 때문에, 운영환경에 적용하기는 아직 쉽지 않은데, GKE 서비스는 운영 환경에서 도커 서비스를 할 수 있도록 충분한 완성도를 제공한다. 이미 Pocketmon go 서비스도 이미 GKE를 사용하고 있다.


Kubernetes UI


예전에 Kubernetes를 테스트할 때 단점은 아직 모든 관리와 모니터링을 대부분 CLI로 해야 하기 때문에 사용성이 떨어지는데, 이번 GKE에서는 웹 UI 콘솔을 제공한다.


구글 GKE 콘솔에서 Kuberentes 클러스터를 선택하며 우측에 Connect 버튼이 나오는데, 


이 버튼을 누르면, Kubernetes 웹 UI를 띄울 수 있는 명령어가 출력된다.

아래와 같이 나온 명령어를 커맨드 창에서 실행시키고 htt://localhost:8001/ui 에 접속하면 Kubernetes 웹 콘솔을 볼 수 있다. 


Kubernetes 의 웹콘솔은 다음과 같은 모양이다.



Kubernetes의 주요 컴포넌트인 Pods, Service, Replication Controller , Nodes 등의 상태 모니터링은 물론이고, 배포 역시 이 웹 콘솔에서 가능하다.


예를 들어  gcr.io/terrycho-sandbox/hello-node:v1 컨테이너 이미지를 가지고, Pod 를 생성하고, Service를 정의해서 배포를 하려면 다음과 같은 명령을 이용해야 한다.


1. hello-node 라는 pod를 생성한다. 

% kubectl run hello-node --image=gcr.io/terrycho-sandbox/hello-node:v1 --port=8080


2. 생성된 pod를 service를 정의해서 expose 한다.

kubectl expose deployment hello-node --type="LoadBalancer"


이런 설정들을 CLI로 하면 익숙해지면 쉽지만 익숙해지기전까지는 번거로운데,

아래 그림과 같이, 간단하게 웹 UI에서 Pod와 서비스들을 한번에 정의할 수 있다.





배포가 완료된 후에는 각 Pod의 상황이나, Pod를 호스팅하고 있는 Nodes 들의 상황등 다양한 정보를 매우 쉽게 모니터링이 가능하다. (cf. CLI를 이용할 경우 CLI 명령어를 잘 알아야 가능하다.)


GKE에 대한 튜토리얼은 https://cloud.google.com/container-engine/docs/tutorials  에 있는데,

추천하는 튜토리얼은

가장 간단한 튜토리얼 node.js 웹앱을 배포하는  http://kubernetes.io/docs/hellonode/

와 WordPress와 MySQL을 배포하는 https://cloud.google.com/container-engine/docs/tutorials/persistent-disk/

을 추천한다.


도커가 아직까지 운영 환경에 사례가 국내에 많지 않고, GKE도 GUI 가 없어서 그다지 지켜보지 않았는데, 다시 파볼만한 정도의 완성도가 된듯.


참고로 테스트를 해보니 VM을 3개 만들어놓고 컨테이너를 7개인가 배포했는데, VM은 3개로 유지된다. 즉 하나의 VM에 여러개의 컨테이너가 배포되는 형태인데, 작은 서비스들이 많은 경우에는 자원 사용 효율이 좋을듯. 이런 관점에서 봤는때는 VM 기반의 서비스보다 컨테이너 서비스를 쓰는 장점이 확실히 보이는듯 하다




저작자 표시 비영리
신고
크리에이티브 커먼즈 라이선스
Creative Commons License

파이어베이스 애널러틱스를 이용한 모바일 데이타 분석

#4 주피터 노트북을 이용한 파이어베이스 데이타 분석 및 시각화

조대협 (http://bcho.tistory.com)

노트북의 개념

빅데이타 분석에서 리포팅 도구중 많이 사용되는 제품군 중의 하나가 노트북이라는 제품군이다. 대표적인 제품으로는 오픈소스 제품중 주피터(https://ipython.org/notebook.html) 와 제플린(https://zeppelin.apache.org/) 이 있다.

노트북은 비지니스에 전달하기 위한 멋진 액셀이나 대쉬보드와 같은 리포트 보다는 데이타를 다루는 데이타 과학자와 같은 사람들이 사용하는 분석도구인데, 제품의 이름 처럼 노트북의 개념을 가지고 있다.

예를 들어서 설명해보자 우리가 수학문제를 풀려면 연습장을 펴놓고 공식을 사용해가면서 하나하나 문제를 풀어나간다. 이처럼, 빅데이타 분석을 하려면, 여러데이타를 분석해가면서 그 과정을 노트하고 노트한 결과를 기반으로 다음 단계의 문제를 풀어나가는 것이 통상적인데, 노트북 소프트웨어는 문제 풀이에 있어서 기존의 연습장 노트와 같은 사용자 경험을 제공한다.

이러한 노트북 소프트웨어의 특징은 메모를 위한 글과, 계산을 위한 소스 코드를 한페이지에 같이 적을 수 있고, 이 소스 코드는 노트북 내에서 실행이 가능하고 결과도 같은 페이지에 출력해준다.


다음 화면은 본인이 작성했던 노트북의 일부로 딥러닝 프레임웍인 텐서플로우에 대해서 공부하면서 간단하게 문법과 샘플 코드를 노트북에 정리한 예이다.



데이타랩

구글의 데이타랩(https://cloud.google.com/datalab/) 은 오픈소스 주피터 노트북을 구글 클라우드 플랫폼에 맞게 기능을 추가한 노트북이다. 기본이 되는 주피터 노트북이 오픈소스이기 때문에, 데이타랩 역시 오프소스로 코드가 공개되어 있다.


데이타랩은 기본으로 파이썬 언어를 지원하며, 빅쿼리 연동등을 위해서 SQL과, 자바 스크립트를 지원한다.

또한 머신러닝의 딥러닝 프레임웍인 텐서플로우도 지원하고 있다.

데이타랩에서 연동할 수 있는 데이타는 구글 클라우드상의 VM이나, 빅쿼리, Google Cloud Storage

데이타랩은 오픈소스로 별도의 사용료가 부가되지 않으며, 사용 목적에 따라서 VM에 설치해서 실행할 수 도 있고, 로컬 데스크탑에 설치해서 사용할 수 도 있다. 도커로 패키징이 되어 있기 때문에 도커 환경만 있다면 손쉽게 설치 및 실행이 가능하다.

데이타 랩 설치

이 글에서는 로컬 맥북 환경에 데이타랩을 설치해서 데이타를 분석 해보도록 하자.

데이타 랩은 앞에서 언급한것과 같이 구글 클라우드 플랫폼 상의 VM에 설치할 수 도 있고, 맥,윈도우 기반의 로컬 데스크탑에도 설치할 수 있다. 각 플랫폼별 설치 가이드는  https://cloud.google.com/datalab/docs/quickstarts/quickstart-local 를 참고하기 바란다. 이 문서에서는 맥 OS를 기반으로 설치하는 방법을 설명한다.


데이타 랩은 컨테이너 솔루션인 도커로 패키징이 되어 있다. 그래서 도커 런타임을 설치해야 한다.

https://www.docker.com/products/docker 에서 도커 런타임을 다운 받아서 설치한다.

도커 런타임을 설치하면 애플리케이션 목록에 다음과 같이 고래 모양의 도커 런타임 아이콘이 나오는 것을 확인할 수 있다.



하나 주의할점이라면 맥에서 예전의 도커 런타임은 오라클의 버추얼 박스를 이용했었으나, 제반 설정등이 복잡하기 때문에, 이미 오라클 버추얼 박스 기반의 도커 런타임을 설치했다면 이 기회에, 도커 런타임을 새로 설치하기를 권장한다.

다음으로 도커 사용을 도와주는 툴로 Kitematic 이라는 툴을 설치한다. (https://kitematic.com/) 이 툴은 도커 컨테이너에 관련한 명령을 내리거나 이미지를 손쉽게 관리할 수 있는 GUI 환경을 제공한다.


Kitematic의 설치가 끝났으면 데이타랩 컨테이너 이미지를 받아서 실행해보자, Kitematic 좌측 하단의 “Dokcer CLI” 버튼을 누르면, 도커 호스트 VM의 쉘 스크립트를 수행할 수 있는 터미널이 구동된다.


터미널에서 다음 명령어를 실행하자


docker run -it -p 8081:8080 -v "${HOME}:/content" \

  -e "PROJECT_ID=terrycho-firebase" \

  gcr.io/cloud-datalab/datalab:local


데이타랩은 8080 포트로 실행이 되고 있는데, 위에서 8081:8080은  도커 컨테이너안에서 8080으로 실행되고 있는 데이타 랩을 외부에서 8081로 접속을 하겠다고 정의하였고, PROJECT_ID는 데이타랩이 접속할 구글 클라우드 프로젝트의 ID를 적어주면 된다.

명령을 실행하면, 데이타랩 이미지가 다운로드 되고 실행이 될것이다.

실행이 된 다음에는 브라우져에서 http://localhost:8081로 접속하면 다음과 같이 데이타랩이 수행된 것을 볼 수 있다.


데이타랩을 이용한 파이어베이스 애널러틱스 데이타 분석 (책에서는 위치 이동 할것 파이어 베이스로)

데이타랩이 설치되었으면, 파이어베이스 애널러틱스를 이용하여 빅쿼리에 수집한 로그를 분석해보자

데이타 랩에서 “+Notebook” 버튼을 눌러서 새로운 노트북을 생성하자

생성된 노트북으로 들어가서 “Add Code” 버튼을 누르고, 생성된 코드 블록 박스에 아래와 같은 SQL을 추가하자


%%sql

SELECT user_dim.app_info.app_instance_id, user_dim.device_info.device_category, user_dim.device_info.user_default_language, user_dim.device_info.platform_version, user_dim.device_info.device_model, user_dim.geo_info.country, user_dim.geo_info.city, user_dim.app_info.app_version, user_dim.app_info.app_store, user_dim.app_info.app_platform

FROM [terrycho-firebase:my_ios.app_events_20160830]


%%sql은 빅쿼리 SQL을 수행하겠다는 선언이다.

다음에 SQL 문장을 기술했는데, 테이블은 terrycho-firebase 프로젝트의 my_ios 데이타셋의 app_events_20160830 테이블에서 쿼리를 하였다.

2016년 8월 30일의 iOS 앱에서 올라온 사용자 관련 정보를 쿼리하는 내용이다. (디바이스 정보, 국가등)

다음은 쿼리 결과 이다.



다음 쿼리는 2016년 6월 1일의 안드로이드와 iOS 접속자에 대해서 국가별 사용자 수 통계를 내는 쿼리이다.


%%sql

SELECT

 user_dim.geo_info.country as country,

 EXACT_COUNT_DISTINCT( user_dim.app_info.app_instance_id ) as users

FROM

[firebase-analytics-sample-data:android_dataset.app_events_20160601],

 [firebase-analytics-sample-data:ios_dataset.app_events_20160601]

GROUP BY

 country

ORDER BY

 users DESC




다음은 2016년 6월 1일 사용자중, 안드로이드와 iOS 모두에서 사용자가 사용하는 언어별로 쿼리를 하는 내용이다.


%%sql

SELECT

 user_dim.user_properties.value.value.string_value as language_code,

 EXACT_COUNT_DISTINCT(user_dim.app_info.app_instance_id) as users,

FROM [firebase-analytics-sample-data:android_dataset.app_events_20160601],

 [firebase-analytics-sample-data:ios_dataset.app_events_20160601]

WHERE

user_dim.user_properties.key = "language"

GROUP BY

language_code

ORDER BY

users DESC


쿼리 결과



이번에는 차트를 사용하는 방법을 알아보자, 안드로이드 로그에서 이벤트 로그중에, 많이 나오는 로그 20개에 대한 분포도를 파이 차트로 그려내는 예제이다.

%%sql --module events

SELECT event_dim.name as event_name, COUNT(event_dim.name) as event_count  

FROM [firebase-analytics-sample-data:android_dataset.app_events_20160601]

GROUP BY event_name

ORDER BY event_count DESC

LIMIT 20


쿼리 결과를 --module 명령을 이용하여 events라는 모듈에 저장한후


%%chart pie --fields event_name,event_count --data events

title: Event count

height: 400

width: 800

pieStartAngle: 20

slices:

 0:

   offset: .2


구글 차트 명령을 이용하여 pie 차트를 그린다. 필드는 앞의 모듈에서 쿼리한 event_name과 event_count 필드를 이용하고, 데이타는 앞에서 정의한 “events” 모듈에서 읽어온다.

차트 실행 결과는 다음과 같다.



이외에도 Tensorflow 연동이나 GCS를 연동하는 방법, 그리고 구글 차트 이외에 일반 plot 함수를 이용하여 그래프를 그리는 등 다양한 기능을 제공하는데, 이에 대한 자세한 설명은 데이타랩을 설치하면 /docs/README.md 파일을 참조하면 다양한 가이드를 찾을 수 있다.



저작자 표시 비영리
신고
크리에이티브 커먼즈 라이선스
Creative Commons License

node.js에서 Redis 사용하기


조대협 (http://bcho.tistory.com)


Redis NoSQL 데이타 베이스의 종류로, mongoDB 처럼 전체 데이타를 영구히 저장하기 보다는 캐쉬처럼 휘발성이나 임시성 데이타를 저장하는데 많이 사용된다.

디스크에 데이타를 주기적으로 저장하기는 하지만, 기능은 백업이나 복구용으로 주로 사용할뿐 데이타는 모두 메모리에 저장되기 때문에, 빠른 접근 속도를 자랑한다.

 

이유 때문에 근래에는 memcached 다음의 캐쉬 솔루션으로 널리 사용되고 있는데, 간단하게 -밸류 (Key-Value)형태의 데이타 저장뿐만 아니라, 다양한 데이타 타입을 지원하기 때문에 응용도가 높고, node.js 호환 모듈이 지원되서 node.js 궁합이 좋다. 여러 node.js 클러스터링 하여 사용할때, node.js 인스턴스간 상태정보를 공유하거나, 세션과 같은 휘발성 정보를 저장하거나 또는 캐쉬등으로 다양하게 사용할 있다.

 

Redis 제공하는 기능으로는 키로 데이타를 저장하고 조회하는 Set/Get 기능이 있으며, 메세지를 전달하기 위한 큐로도 사용할 있다.

 

큐로써의 기능은 하나의 클라이언트가 다른 클라이언트로 메세지를 보내는 1:1 기능뿐 아니라, 하나의 클라이언트가 다수의 클라이언트에게 메세지를 발송하는 발행/배포 (Publish/Subscribe) 기능을 제공한다.




그림 1 RedisPublish/Subscribe의 개념 구조

 

재미있는 것중에 하나는 일반적인 Pub/Sub 시스템의 경우 Subscribe 하는 하나의 Topic에서만 Subscribe하는데 반해서, redis에서는 pattern matching 통해서 다수의 Topic에서 message subscribe 있다.

예를 들어 topic 이름이 music.pop music,classic 이라는 두개의 Topic 있을때, "PSUBSCRIBE music.*"라고 하면 두개의 Topic에서 동시에 message subscribe 있다.

 

자료 구조

 

Redis 가장 기본이 되는 자료 구조를 살펴보자. Redis 다양한 자료 구조를 지원하는데, 지원하는 자료 구조형은 다음과 같다.

1)       String

Key 대해서 문자열을 저장한다. 텍스트 문자열뿐만 아니라 숫자나 최대 512mbyte 까지의 바이너리도 저장할 있다.

 

2)       List

Key 대해서 List 타입을 저장한다. List에는 값들이 들어갈 있으며, INDEX 값을 이용해서 지정된 위치의 값을 넣거나 있고, 또는 push/pop 함수를 이용하여 리스트 앞뒤에 데이타를 넣거나 있다. 일반적인 자료 구조에서 Linked List 같은 자료 구조라고 생각하면 된다.

 

3)       Sets

Set 자료 구조는 집합이라고 생각하면 된다. Key 대해서 Set 저장할 있는데, List 구조와는 다르게 주의할점은 집합이기 때문에 같은 값이 들어갈 없다. 대신 집합의 특성을 이용한 집합 연산, 교집합, 합집합등의 연산이 가능하다.

 

4)       Sorted Set

SortedSet Set 동일하지만, 데이타를 저장할때, value 이외에, score 라는 값을 같이 저장한다. 그리고 score 라는 값에 따라서 데이타를 정렬(소팅)해서 저장한다. 순차성이나 순서가 중요한 데이타를 저장할때 유용하게 저장할 있다.

 

5)       Hashes

마지막 자료형으로는 Hashes 있는데, 해쉬 자료 구조를 생각하면 된다.Key 해쉬 테이블을 저장하는데, 해쉬 테이블에 저장되는 데이타는 (field, value) 형태로 field 해쉬의 키로 저장한다.

키가 있는 데이타를 군집하여 저장하는데 유용하며 데이타의 접근이 매우 빠르다. 순차적이지 않고 비순차적인 랜덤 액세스 데이타에 적절하다.

 

설명한 자료 구조를 Redis 저장되는 형태로 표현하면 다음과 같다.

 



Figure 36 redis의 자료 구조

 

기본적으로 /밸류 (Key/Value) 형태로 데이타가 저장되며, 밸류에 해당하는 데이타 타입은 앞서 언급하 String, List, Sets, SortedSets, Hashes 있다.

 

Redis 대한 설명은 여기서는 자세하게 하지 않는다. 독립적인 제품인 만큼 가지고 있는 기능과 운영에 신경써야할 부분이 많다. Redis 대한 자세한 설명은 http://redis.io 홈페이지를 참고하거나 정경석씨가 이것이 레디스다http://www.yes24.com/24/Goods/11265881?Acode=101 라는 책을 추천한다. 단순히 redis 대한 사용법뿐만 아니라, 레디스의 데이타 모델 설계에 대한 자세한 가이드를 제공하고 있다.

 

Redis 설치하기

개발환경 구성을 위해서 redis 설치해보자.

 

맥의 경우 애플리케이션 설치 유틸리티인 brew 이용하면 간단하게 설치할 있다.

%brew install redis

 

윈도우즈

안타깝게도 redis 공식적으로는 윈도우즈 인스톨을 지원하지 않는다. http://redis.io에서 소스 코드를 다운 받아서 컴파일을 해서 설치를 해야 하는데, 만약에 이것이 번거롭다면, https://github.com/rgl/redis/downloads 에서 다운로드 받아서 설치할 있다. 그렇지만 이경우에는 최신 버전을 지원하지 않는다.

그래서 vagrant 이용하여 우분투 리눅스로 개발환경을 꾸미고 위에 redis 설치하거나 https://redislabs.com/pricing https://www.compose.io  같은 클라우드 redis 환경을 사용하기를 권장한다. ( 클라우드 서비스의 경우 일정 용량까지 무료 또는 일정 기간까지 무료로 서비스를 제공한다.)

 

리눅스

리눅스의 경우 설치가 매우 간단하다. 우분투의 경우 패키지 메니저인 apt-get 이용해서 다음과 같이 설치하면 된다.

%sudo apt-get install redis-server

 

설치가 끝났으면 편하게 redis 사용하기 위해서 redis 클라이언트를 설치해보자.

여러 GUI 클라이언트들이 많지만, 편하게 사용할 있는 redis desktop 설치한다. http://redisdesktop.com/ 에서 다운 받은 후에 간단하게 설치할 있다.

 

이제 환경 구성이 끝났으니, redis 구동하고 제대로 동작하는지 테스트해보자

%redis-server

명령을 이용해서 redis 서버를 구동한다.

 



Figure 37 redis 기동 화면

 

redis desktop 이용해서 localhost 호스트에 Host 주소는 localhost TCP 포트는 6379 새로운 Connection 추가하여 연결한다.

 

 



Figure 38 redis desktop에서 연결을 설정하는 화면

 

연결이 되었으면 redis desktop에서 Console 연다.

 



Figure 39 redis desktop에서 콘솔을 여는 화면

 

Console에서 다음과 같이 명령어를 입력해보자

 

localhost:0>set key1 myvalue

OK

 

localhost:0>set key2 myvalue2

OK

 

localhost:0>get key2

myvalue2

 

localhost:0>

Figure 40 redis desktop에서 간단한 명령을 통해서 redis를 테스트 하는 화면


위의 명령은 key1 myvalue라는 값을 입력하고, key2 myvalue2라는 값을 입력한 후에, key2 입력된 값을 조회하는 명령이다.

 

Redis desktop에서, 디비를 조회해보면, 앞서 입력한 /밸류 값이 저장되어 있는 것을 다음과 같이 확인할 있다.

\


Figure 41 redis에 저장된 데이타를 redis desktop을 이용해서 조회하기

 

node.js에서 redis 접근하기

 

이제 node.js에서 redis 사용하기 위한 준비가 끝났다. 간단한 express API 만들어서 redis 캐쉬로 사용하여 데이타를 저장하고 조회하는 예제를 작성해보자

 

node.js redis 클라이언트는 여러 종류가 있다. http://redis.io/clients#nodejs

가장 널리 쓰는 클라이언트 모듈로는 node-redis https://github.com/NodeRedis/node_redis 있는데, 예제는 node-redis 클라이언트를 기준으로 설명한다.

 

예제는 profile URL에서 사용자 데이타를 JSON/POST 받아서 redis 저장하고, TTL(Time to Leave) 방식의 캐쉬 처럼 10 후에 삭제되도록 하였다.

그리고 GET /profile/{사용자 이름} 으로 redis 저장된 데이타를 조회하도록 하였다.

 

먼저 node-redis 모듈과, json 문서를 처리하기 위해서 JSON 모듈을 사용하기 때문에, 모듈을 설치하자

% npm install redis

% npm install JSON

 

package.json 모듈의 의존성을 다음과 같이 정의한다.

 

 

{

  "name": "RedisCache",

  "version": "0.0.0",

  "private": true,

  "scripts": {

    "start": "node ./bin/www"

  },

  "dependencies": {

    "body-parser": "~1.13.2",

    "cookie-parser": "~1.3.5",

    "debug": "~2.2.0",

    "express": "~4.13.1",

    "jade": "~1.11.0",

    "morgan": "~1.6.1",

    "serve-favicon": "~2.3.0",

    "redis":"~2.6.0",

    "JSON":"~1.0.0"

  }

}

 

Figure 42 redisJSON 모듈의 의존성이 추가된 package.json

 

다음으로 express 간단한 프로젝트를 만든 후에, app.js 다음과 같은 코드를 추가한다.

 

 

// redis example

var redis = require('redis');

var JSON = require('JSON');

client = redis.createClient(6379,'127.0.0.1');

 

app.use(function(req,res,next){

      req.cache = client;

      next();

})

app.post('/profile',function(req,res,next){

      req.accepts('application/json');

     

      var key = req.body.name;

      var value = JSON.stringify(req.body);

     

      req.cache.set(key,value,function(err,data){

           if(err){

                 console.log(err);

                 res.send("error "+err);

                 return;

           }

           req.cache.expire(key,10);

           res.json(value);

           //console.log(value);

      });

})

app.get('/profile/:name',function(req,res,next){

      var key = req.params.name;

     

      req.cache.get(key,function(err,data){

           if(err){

                 console.log(err);

                 res.send("error "+err);

                 return;

           }

 

           var value = JSON.parse(data);

           res.json(value);

      });

});

 

Figure 43 app.jsredis에 데이타를 쓰고 읽는 부분

 

redis 클라이언트와, JSON 모듈을 로딩한후, createClient 메서드를 이용해서, redis 대한 연결 클라이언트를 생성하자.

 

client = redis.createClient(6379,'127.0.0.1');

 

app.use(function(req,res,next){

      req.cache = client;

      next();

})

 

다음 연결 객체를 express router에서 쉽게 가져다 있도록, 미들웨어를 이용하여 req.cache 객체에 저장하도록 하자.

 

HTTP POST /profile 의해서 사용자 프로파일 데이타를 저장하는 부분을 보면

req.accepts('application/json'); 이용하여 JSON 요청을 받아드리도록 한다.

JSON내의 name 필드를 키로, 하고, JSON 전체를 밸류로 한다. JSON 객체 형태로 redis 저장할 있겠지만 경우 redis에서 조회를 하면 객체형으로 나오기 때문에 운영이 불편하다. 그래서 JSON.stringfy 이용하여 JSON 객체를 문자열로 변환하여 value 객체에 저장하였다.

다음 req.cache.set(key,value,function(err,data) 코드에서 redis 저장하기 위해서 redis 클라이언트를 req 객체에서 조회해온후, set 명령을 이용해서 /밸류 값을 저장한다. 저장이 끝나면 뒤에 인자로 전달된 콜백함수 호출 되는데, 콜백함수에서, req.cache.expire(key,10); 호출하여, 키에 대한 데이타 저장 시간을 10초로 설정한다. (10 후에는 데이타가 삭제된다.) 마지막으로 res.json(value); 이용하여 HTTP 응답에 JSON 문자열을 리턴한다.

 

HTTP GET으로 /profile/{사용자 이름} 요청을 받아서 키가 사용자 이름은 JSON 데이타를 조회하여 리턴하는 코드이다.

app.get('/profile/:name',function(req,res,next) 으로 요청을 받은 , URL에서 name 부분을 읽어서 키값으로 하고,

req.cache.get(key,function(err,data){ 이용하여, 키를 가지고 데이타를 조회한다. 콜백 함수 부분에서, 데이타가 문자열 형태로 리턴되는데, 이를 var value = JSON.parse(data); 이용하여, JSON 객체로 변환한 후에, res.json(value); 통해서 JSON 문자열로 리턴한다.

 

코드 작성이 끝났으면 테스트를 해보자 HTTP JSON/POST REST 호출을 보내야 하기 때문에, 별도의 클라이언트가 필요한데, 클라이언트는 구글 크롬 브라우져의 플러그인인 포스트맨(POSTMAN) 사용하겠다. https://chrome.google.com/webstore/detail/postman/fhbjgbiflinjbdggehcddcbncdddomop

 

포스트맨 설치가 끝났으면, 포스트맨에서 HTTP POST/JSON 방식으로 http://localhost:3000/profile 아래와 같이 요청을 보낸다.

 



Figure 44 포스트맨에서 HTTP POSTprofile 데이타를 삽입하는 화면

 

요청을 보낸후 바로 HTTP GET으로 http://localhost:3000/profile/terry 조회를 하면 아래와 같이 앞에서 입력한 데이타가 조회됨을 확인할 있다. 이때 위의 POST 요청을 보낸지 10 내에 조회를 해야 한다. 10초가 지나면 앞서 지정한 expire 의해서 자동으로 삭제가된다.



Figure 45 포스트맨에서 사용자 이름이 terry인 데이타를 조회하는 화면

 

Redisdesktop에서 확인을 해보면 아래와 같이 문자열로 terry 사용자에 대한 데이타가 저장되어 있는 것을 확인할 있다.



Figure 46 redis desktop 에서 입력된 데이타를 확인하는 화면

 

10초후에, 다시 조회를 해보면, terry 키로 가지는 데이타가 삭제된 것을 확인할 있다.

 

지금까지 가장 기본적인 redis 대한 소개와 사용법에 대해서 알아보았다. redis 뒤에 나올 node.js 클러스터링의 HTTP 세션을 저장하는 기능이나, Socket.IO 등에서도 계속해서 사용되는 중요한 솔루션이다. Redis 자체를 다루는 것이 아니라서 자세하게 파고 들어가지는 않았지만, 다소 운영이 까다롭고 특성을 파악해서 설계해야 하는 만큼 반드시 시간을 내서 redis 자체에 대해서 조금 자세하게 살펴보기를 권장한다.

저작자 표시 비영리
신고
크리에이티브 커먼즈 라이선스
Creative Commons License

monk 모듈을 이용한 mongoDB 연결


조대협 (http://bcho.tistory.com)


mongoDB 기반의 개발을 하기 위해서 mongoDB를 설치한다. https://www.mongodb.org/ 에서 OS에 맞는 설치 파일을 다운로드 받아서 설치한다.

설치가 된 디렉토리에 들어가서 설치디렉토리 아래 ‘./data’ 라는 디렉토리를 만든다. 이 디렉토리는 mongoDB의 데이타가 저장될 디렉토리이다.

 

mongoDB를 구동해보자.

% ./bin/mongod --dbpath ./data



Figure 1 mongoDB 구동화면


구동이 끝났으면 mongoDB에 접속할 클라이언트가 필요하다. DB에 접속해서 데이타를 보고 쿼리를 수행할 수 있는 클라이언트가 필요한데, 여러 도구가 있지만 많이 사용되는 도구로는 roboMongo라는 클라이언트가 있다.

https://robomongo.org/download 에서 다운로드 받을 수 있다. OS에 맞는 설치 파일을 다운로드 받아서 설치 후 실행한다.

 

설치 후에, Create Connection에서, 로컬호스트에 설치된 mongoDB를 연결하기 위해서 연결 정보를 기술하고, 연결을 만든다





Figure 2 robomongo에서 localhost에 있는 mongodb 연결 추가

 

주소는 localhost, 포트는 디폴트 포트로 27017를 넣으면 된다.

 

환경이 준비가 되었으면 간단한 테스트를 해보자. 테스트 전에 기본적인 개념을 숙지할 필요가 있는데, mongoDBNoSQL 계열중에서도 도큐먼트DB (Document DB)에 속한다. 기존 RDBMS에서 하나의 행이 데이타를 표현했다면, mogoDB는 하나의 JSON 파일이 하나의 데이타를 표현한다. JSON을 도큐먼트라고 하기 때문에, 도큐먼트 DB라고 한다.

 

제일 상위 개념은 DB의 개념에 대해서 알아보자, DB는 여러개의 테이블(컬렉션)을 저장하는 단위이다.

Robomongo에서 mydb라는 이름으로 DB 를 생성해보자



Figure 3 robomongo에서 새로운  DB를 추가 하는 화면

 

다음으로 생성된 DB안에, 컬렉션을 생성한다. 컬렉션은 RDBMS의 단일 테이블과 같은 개념이다.

Robomongo에서 다음과 같이 ‘users’라는 이름의 컬렉션을 생성한다



Figure 4 robomongo에서 컬렉션(Collection) 생성

 

users 컬렉션에는 userid를 키로 해서, sex(성별), city(도시) 명을 입력할 예정인데, userid가 키이기 때문에, userid를 통한 검색이나 소팅등이 발생한다. 그래서 userid를 인덱스로 지정한다.

인덱스 지정 방법은 createIndex 명령을 이용한다. 다음과 같이 robomongo에서 createIndex 명령을 이용하여 인덱스를 생성한다.



Figure 5 users 컬렉션에서 userid를 인덱스로 지정

 

mongoDB는 디폴트로, 각 컬렉션마다 “_id”라는 필드를 가지고 있다. 이 필드는 컬렉션 안의 데이타에 대한 키 값인데, 12 바이트의 문자열로 이루어져 있고 ObjectId라는 포맷으로 시간-머신이름,프로세스ID,증가값형태로 이루어지는 것이 일반적이다.

_id 필드에 userid를 저장하지 않고 별도로 인덱스를 만들어가면서 까지 userid 필드를 별도로 사용하는 것은 mongoDBNoSQL의 특성상 여러개의 머신에 데이타를 나눠서 저장한다. 그래서 데이타가 여러 머신에 골고루 분산되는 것이 중요한데, 애플리케이션상의 특정 의미를 가지고 있는 필드를 사용하게 되면 데이타가 특정 머신에 쏠리는 현상이 발생할 수 있다.

예를 들어서, 주민번호를 _id로 사용했다면, 데이타가 골고루 분산될것 같지만, 해당 서비스가 10~20대에만 인기있는 서비스라면, 10~20대 데이타를 저장하는 머신에만 데이타가 몰리게 되고, 10세이하나, 20세 이상의 데이타를 저장하는 노드에는 데이타가 적게 저장된다.

이런 이유등으로 mongoDB를 지원하는 node.js 드라이버에서는 _id 값을 사용할때, 앞에서 언급한 ObjectId 포맷을 따르지 않으면 에러를 내도록 설계되어 있다. 우리가 앞으로 살펴볼 mongoosemonk의 경우에도 마찬가지이다.

 

이제 데이타를 집어넣기 위한 테이블(컬렉션) 생성이 완료되었다.

다음 컬렉션 에 대한 CRUD (Create, Read, Update, Delete) 를 알아보자

SQL 문장과 비교하여, mongoDB에서 CRUD 에 대해서 알아보면 다음과 같다.

CRUD

SQL

MongoDB

Create

insert into users ("name","city") values("terry","seoul")

db.users.insert({userid:"terry",city:"seoul"})

Read

select * from users where id="terry"

db.users.find({userid:"terry"})

Update

update users set city="busan" where _id="terry"

db.users.update( {userid:"terry"}, {$set :{ city:"Busan" } } )

Delete

delete from users where _id="terry"

db.users.remove({userid:"terry"})

Figure 6 SQL문장과 mongoDB 쿼리 문장 비교


mongoDB에서 쿼리는 위와 같이 db.{Collection }.{명령어} 형태로 정의된다.

roboMongo에서 insert 쿼리를 수행하여 데이타를 삽입해보자



Figure 7 mongoDB에서 users 컬렉션에 데이타 추가

 

다음으로 삽입한 데이타를 find 명령을 이용해 조회해보자



Figure 8 mongoDB에서 추가된 데이타에 대한 확인

 

mongoDB에 대한 구조나 자세한 사용 방법에 대해서는 여기서는 설명하지 않는다.

http://www.tutorialspoint.com/mongodb/ mongoDB에 대한 전체적인 개념과 주요 쿼리들이 간략하게 설명되어 있으니 이 문서를 참고하거나, 자세한 내용은 https://docs.mongodb.org/manual/ 를 참고하기 바란다.

https://university.mongodb.com/ 에 가면 mongodb.com에서 운영하는 온라인 강의를 들을 수 있다. (무료인 과정도 있으니 필요하면 참고하기 바란다.)

 

mongoDBnode.js에서 호출하는 방법은 여러가지가 있으나 대표적인 두가지를 소개한다.

첫번째 방식은 mongoDB 드라이버를 이용하여 직접 mongoDB 쿼리를 사용하는 방식이고, 두번째 방식은 ODM (Object Document Mapper)를 이용하는 방식이다. ODM 방식은 자바나 다른 프로그래밍 언어의 ORM (Object Relational Mapping)과 유사하게 직접 쿼리를 사용하는 것이 아니라 맵퍼를 이용하여 프로그램상의 객체를 데이타와 맵핑 시키는 방식이다. 뒷부분에서 직접 코드를 보면 이해가 빠를 것이다.

 

Monk를 이용한 연결

첫번째로 mongoDB 네이티브 쿼리를 수행하는 방법에 대해서 소개한다. monk라는 node.jsmongoDB 클라이언트를 이용할 것이다.

monk 모듈을 이용하기 위해서 아래와 같이 package.jsonmonk에 대한 의존성을 추가한다.


{

  "name": "mongoDBexpress",

  "version": "0.0.0",

  "private": true,

  "scripts": {

    "start": "node ./bin/www"

  },

  "dependencies": {

    "body-parser": "~1.13.2",

    "cookie-parser": "~1.3.5",

    "debug": "~2.2.0",

    "express": "~4.13.1",

    "jade": "~1.11.0",

    "morgan": "~1.6.1",

    "serve-favicon": "~2.3.0",

    "monk":"~1.0.1"

  }

}

 

Figure 9 monk 모듈에 대한 의존성이 추가된 package.json

 

app.js에서 express가 기동할때, monk를 이용해서 mongoDB에 연결하도록 한다.

var monk = require('monk');

var db = monk('mongodb://localhost:27017/mydb');

 

var mongo = require('./routes/mongo.js');

app.use(function(req,res,next){

    req.db = db;

    next();

});

app.use('/', mongo);

Figure 10 monk를 이용하여 app.js에서 mongoDB 연결하기

 

mongoDB에 연결하기 위한 연결 문자열은 'mongodb://localhost:27017/mydb' mongo://{mongoDB 주소}:{mongoDB 포트}/{연결하고자 하는 DB} 으로 이 예제에서는 mongoDB 연결을 간단하게 IP,포트,DB명만 사용했지만, 여러개의 인스턴스가 클러스터링 되어 있을 경우, 여러 mongoDB로 연결을 할 수 있는 설정이나, Connection Pool과 같은 설정, SSL과 같은 보안 설정등 부가적인 설정이 많으니, 반드시 운영환경에 맞는 설정으로 변경하기를 바란다. 설정 방법은 http://mongodb.github.io/node-mongodb-native/2.1/reference/connecting/connection-settings/ 문서를 참고하자.

 

이때 주의깊게 살펴봐야 하는 부분이 app.use를 이용해서 미들웨어를 추가하였는데, req.dbmongodb 연결을 넘기는 것을 볼 수 있다. 미들웨어로 추가가 되었기 때문에 매번 HTTP 요청이 올때 마다 req 객체에는 db라는 변수로 mongodb 연결을 저장해서 넘기게 되는데, 이는 HTTP 요청을 처리하는 것이 router에서 처리하는 것이 일반적이기 때문에, routerdb 연결을 넘기기 위함이다. 아래 데이타를 삽입하는 라우터 코드를 보자

 

router.post('/insert', function(req, res, next) {

      var userid = req.body.userid;

      var sex = req.body.sex;

      var city = req.body.city;

     

      db = req.db;

      db.get('users').insert({'userid':userid,'sex':sex,'city':city},function(err,doc){

             if(err){

                console.log(err);

                res.status(500).send('update error');

                return;

             }

             res.status(200).send("Inserted");

            

         });

});

Figure 11 /routes/mongo.js 에서 데이타를 삽입하는 코드


req 객체에서 폼 필드를 읽어서 userid,sex,city등을 읽어내고, 앞의 app.js 에서 추가한 미들웨어에서 넘겨준 db 객체를 받아서 db.get('users').insert({'userid':userid,'sex':sex,'city':city},function(err,doc) 수행하여 데이타를 insert 하였다.

 

다음은 userid필드가 HTTP 폼에서 넘어오는 userid 일치하는 레코드를 지우는 코드 예제이다. Insert 부분과 크게 다르지 않고 remove 함수를 이용하여 삭제 하였다.


router.post('/delete', function(req, res, next) {

      var userid = req.body.userid;

     

      db = req.db;

      db.get('users').remove({'userid':userid},function(err,doc){

             if(err){

                console.log(err);

                res.status(500).send('update error');

                return;

             }

             res.status(200).send("Removed");

            

         });

});

Figure 12 /routes/mongo.js 에서 데이타를 삭제하는 코드

 

다음은 데이타를 수정하는 부분이다. Update 함수를 이용하여 데이타를 수정하는데,

db.get('users').update({userid:userid},{'userid':userid,'sex':sex,'city':city},function(err,doc){

와 같이 ‘userid’userid 인 필드의 데이타를 },{'userid':userid,'sex':sex,'city':city} 대치한다.

 

router.post('/update', function(req, res, next) {

      var userid = req.body.userid;

      var sex = req.body.sex;

      var city = req.body.city;

      db = req.db;

      db.get('users').update({userid:userid},{'userid':userid,'sex':sex,'city':city},function(err,doc){

      //db.get('users').update({'userid':userid},{$set:{'sex':'BUSAN'}},function(err,doc){

             if(err){

                console.log(err);

                res.status(500).send('update error');

                return;

             }

             res.status(200).send("Updated");

            

         });

});

Figure 13 /routes/mongo.js 에서 데이타를 수정하는 코드


전체 레코드를 대치하는게 아니라 특정 필드만 수정하고자 하면, $set: 쿼리를 이용하여, 수정하고자하는 필드만 아래와 같이 수정할 수 있다.

db.collection('users').updateOne({_id:userid},{$set:{'sex':'BUSAN'}},function(err,doc){

 

마지막으로 데이타를 조회하는 부분이다. /list URL은 전체 리스트를 리턴하는 코드이고, /get ?userid= 쿼리 스트링으로 정의되는 사용자 ID에 대한 레코드만을 조회해서 리턴한다.

router.get('/list', function(req, res, next) {

      db = req.db;

      db.get('users').find({},function(err,doc){

           if(err) console.log('err');

           res.send(doc);

      });

});

router.get('/get', function(req, res, next) {

      db = req.db;

      var userid = req.query.userid

      db.get('users').findOne({'userid':userid},function(err,doc){

           if(err) console.log('err');

           res.send(doc);

      });

});

Figure 14 /routes/mongo.js 에서 데이타를 조회하는 코드

 

이제 /routes/mongo.js 의 모든 코드 작업이 완료되었다. 이 코드를 호출하기 위한 HTML 폼을 작성하자.

 

<!DOCTYPE html>

<html>

<head>

<meta charset="UTF-8">

<title>Insert title here</title>

</head>

<body>

 

<h1> Native MongoDB Test Example</h1>

<form method='post' action='/insert' name='mongoform' >

      user <input type='text' size='10' name='userid'>

      <input type='submit' value='delete' onclick='this.form.action="/delete"' >

      <input type='button' value='get' onclick='location.href="/get?userid="+document.mongoform.userid.value' >

      <p>

      city <input type='text' size='10' name='city' >

      sex <input type='radio' name='sex' value='male'>male

      <input type='radio' name='sex' value='female'>female

      <p>

      <input type='submit' value='insert' onclick='this.form.action="/insert"' >

      <input type='submit' value='update' onclick='this.form.action="/update"' >

      <input type='button' value='list'  onclick='location.href="/list"' >

     

</form>

</body>

</html>

Figure 15 /public/monksample.html

 

node.js를 실행하고 http://localhost:3000/monksample.html 을 실행해보자



Figure 16 http://localhost:3000/monksample.html 실행 결과

 

아래 insert 버튼을 누르면, 채워진 필드로 새로운 레코드를 생성하고, update 버튼은 user 필드에 있는 사용자 이름으로된 데이타를 업데이트 한다. list 버튼은 컬렉션에서 전체 데이타를 조회해서 출력하고, delete 버튼은 user 필드에 있는 사용자 이름으로된 레코드를 삭제한다. get 버튼은 user 필드에 있는 사용자 이름으로 데이타를 조회하여 리턴한다.

다음은 list로 전체 데이타를 조회하는 화면이다.

 


Figure 17 /list를 수행하여 mongoDB에 저장된 전체 데이타를 조회하는 화면


이 코드의 전체 소스코드는 https://github.com/bwcho75/nodejs_tutorial/tree/master/mongoDBexpress 에 있으니 필요하면 참고하기 바란다


다음 글에서는  node.js의 mongoDB ODM 프레임웍인 mongoose 이용한 접근 방법에 대해서 알아보기로 한다.


저작자 표시 비영리
신고
크리에이티브 커먼즈 라이선스
Creative Commons License

Google Cloud Vision API 사용하기


구글 클라우드 비젼 API 사용하기

조대협 (http://bcho.tistory.com)






빅데이타와 머신러닝과 같은 기술이 요즘 인터넷을 매우고 있는 시대에, 구글이 얼마전 이미지 디텍션 (Image detection)이 가능한, Cloud Vision API라는 오픈 API를 발표하였다. 현재는 베타버전 상태인데, 호기심에 빠르게 한번 테스트를 해봤다.


node.js를 이용하여, 간단한 테스트 프로그램을 만들어서 테스트를 해봤는데, 구현에 걸리는 시간은 불과 10분이 안된듯... (node.js는 역시 프로토타이핑용으로는 정말 좋은듯)


Cloud Vision API 억세스 권한 얻기


Cloud Vision API는 현재 베타 상태이다. 접근을 하려면 별도로 요청을 해야 접근 권한을 받을 수 있다. https://cloud.google.com/vision/ 에서 권한을 신청하면 심사를 하고 권한을 준다. 



Cloud Vision API 활성화 하기

권한을 얻은 후에는 Google Cloud Platform API 관리자 콘솔에 접속해서 Cloud Vision API 사용을 활성화해야 한다. 



다음으로 API를 외부에서 호출하기 위한 API 키를 발급받아야 하는데, Google Cloud Platform API에서는 API 키에 대한 다양한 접근 방식을 제공한다. OAuth 방식의 접근, 서버를 위한 API 키 발급 방식등 여러 방식을 지원하는데, 이 예제에서는 서비스 계정키를 JSON 형태로 다운 받아서 사용하도록 한다.


계정키를 생성 하는 방법은 Google Cloud Platform 콘솔에서 “사용자 인증 정보” 메뉴에서 “서비스 계정 키 만들기” 메뉴를 통해서 키를 생성하고 JSON 파일을 다운로드 받을 수 있다.





Cloud Vision API 호출 하기

https://cloud.google.com/vision/docs/getting-started?utm_source=product-announcement&utm_medium=email&utm_campaign=2016-02-Vision-API&utm_content=NoFT 


를 보면 Cloud Vision API를 호출하는 방법이 자세하게 설명되어 있다. 제공하는 기능에 비해서 API 사용법은 무지 간단한데, JSON으로 REST 방식으로 API를 호출하면 분석 결과를 JSON으로 리턴해준다. 이때 이미지 파일은 이 JSON에 base64 인코딩으로 첨부를 하면 된다. 어렵지 않으니 문서를 한번 쭈욱 보면서 흐름을 따라가보기를 권장한다. 


node.js 모듈 준비하기

base64 인코딩 모듈을 넣어서 샘플 요청을 만드는 것 조차 귀찮아서 node.js의 npm모듈이 이미 있는지 찾아보기로 하였다. 역시나 있다. 테스트를 위해서 사용한 모듈은 google-vision-api-client 라는 모듈이다.https://www.npmjs.com/package/google-vision-api-client


%npm install google-vision-api-client


명령을 사용해서 모듈을 설치한다. 설치가 끝난후에, 모듈이 설치된 디렉토리 (내 경우에는 “/Users/terry/node_modules/google-visionapi-client” - Mac OS)에 들어가서 index.js 파일을 열어보자. 이 모듈은 Cloud Vision API가 예전 알파버전이었을때 개발된 후 업데이트가 되지 않아서 현재 베타 버전의 Cloud Vision API호출이 안된다. Cloud Vision API의 End point가 변경되었기 때문인데.

index.js 파일에서 baseurl의 값을 다음과 같이 바꿔주자. (Cloud vision API의 베타 버전 URL로 변경)


var baseurl = ‘https://vision.googleapis.com/v1/images:annotate';


이제 Cloud Vision API를 호출하기 위한 준비가 끝났다.


node.js로 Cloud Vision API 호출 하기

이제 google-vision-api-client를 이용하여 API를 호출해보자.다음은 google-vision-api-client에서 제공하는 예제이다.


var vision = require('google-vision-api-client');

var requtil = vision.requtil;

 

//Prepare your service account from trust preview certificated project

var jsonfile = '/Users/terry/dev/ws/nodejs/GoogleVisionAPISample/My Project-eee0a2d4532a.json';

 

//Initialize the api

vision.init(jsonfile);

 

//Build the request payloads

var d = requtil.createRequests().addRequest(

requtil.createRequest('/Users/terry/images/dale2.jpg')

.withFeature('FACE_DETECTION', 3)

.withFeature('LABEL_DETECTION', 2)

.build());

 

//Do query to the api server

vision.query(d, function(e, r, d){

if(e) console.log('ERROR:', e);

  console.log(JSON.stringify(d));

});



<예제. visionAPI.js >

코드를 작성하고, jsonfile 경로에 앞에서 다운받은 서비스 계정키 JSON 파일의 경로를 적어주면 된다.


그리고, createRequest 부분에, Google Cloud Vision API로 분석하고자 하는 이미지 파일명을 적고, withFeature라는 메서드를 이용해서 어떤 분석을 할것인지를 명시한다. (이 부분은 뒤에서 다시 설명한다.)


그러면 다음 명령을 통해서 실행을 해보자


%node visionAPI.js


실행을 하면 실행 결과를 json 형태로 리턴을 해주는데, 

테스트에서 사용한 이미지와 결과는 다음과 같다.



<그림. 테스트에서 사용한 이미지 >



{  

   "responses":[  

      {  

         "faceAnnotations":[  

            {  

               "boundingPoly":{  

                  "vertices":[  

                     {  

                        "x":122,

                        "y":52

                     },

                     {  

                        "x":674,

                        "y":52

                     },

                     {  

                        "x":674,

                        "y":693

                     },

                     {  

                        "x":122,

                        "y":693

                     }

                  ]

               },

               "fdBoundingPoly":{  

                  "vertices":[  

                     {  

                        "x":176,

                        "y":208

                     },

                             중략...

                     }

                  ]

               },

               "landmarks":[  

                  {  

                     "type":"LEFT_EYE",

                     "position":{  

                        "x":282.99844,

                        "y":351.67017,

                        "z":-0.0033840234

                     }

                  },

                  {  

                     "type":"RIGHT_EYE",

                     "position":{  

                        "x":443.8624,

                        "y":336.31445,

                        "z":-35.029751

                     }

                  },

               중략...

                  }

               ],

               "rollAngle":-3.8402841,

               "panAngle":-12.196975,

               "tiltAngle":-0.68598062,

               "detectionConfidence":0.8096019,

               "landmarkingConfidence":0.64295566,

               "joyLikelihood":"LIKELY",

               "sorrowLikelihood":"VERY_UNLIKELY",

               "angerLikelihood":"VERY_UNLIKELY",

               "surpriseLikelihood":"VERY_UNLIKELY",

               "underExposedLikelihood":"VERY_UNLIKELY",

               "blurredLikelihood":"VERY_UNLIKELY",

               "headwearLikelihood":"VERY_UNLIKELY"

            }

         ],

         "labelAnnotations":[  

            {  

               "mid":"/m/068jd",

               "description":"photograph",

               "score":0.92346138

            },

            {  

               "mid":"/m/09jwl",

               "description":"musician",

               "score":0.86925673

            }

         ]

      }

   ]

}

<그림. 결과 JSON>


결과를 살펴보면, 눈코입의 위치와, 감정상태등을 상세하게 리턴해주는 것을 볼 수 있다.

joyLikeihood 는 기쁨 감정, sorrowLikeihood는 슬픈 감정들을 나타내는데, 

Cloud Vision API는 여러개의 Feature를 동시에 분석이 가능하다.


예를 들어 얼굴 인식과 로고 인식을 같이 활용하여, 특정 브랜드를 보고 있을때 사람들의 얼굴 표정이 어떤지를 분석함으로써 대략적인 브랜드에 대한 반응을 인식한다던지. 특정 랜드마크와 표정 분석을 통해서 장소에 대한 분석 (재미있는 곳인지? 슬픈 곳인지.) 등으로 활용이 가능하다.


Cloud Vision API의 이미지 분석 기능


Cloud Vision API는 여러가지 형태의 이미지 분석 기능을 제공하는데, 대략적인 내용을 훝어보면 다음과 같다.


  • Label Detection - 사진속에 있는 사물을 찾아준다. 가구, 동물 , 음식등을 인지해서 리턴해준다.
  • Logo Detection - 사진속에서 회사 로고와 같은 로고를 찾아준다.
  • Landmark Detection - 사진속에서 유명한 랜드 마크 (남산 타워, 경복궁등과 같은 건축물이나 자연 경관 이름)를 찾아준다.
  • Face Detection - 사진속에서 사람 얼굴을 찾아준다. 이게 좀 재미있는데 눈코입의 위치등을 리턴하는 것을 물론 표정을 분석하여 감정 상태를 분석하여 리턴해준다. 화가 났는지 기쁜 상태인지 슬픈 상태인지
  • Safe Search Detection - 사진 컨텐츠의 위험도(? 또는 건전성)을 검출해주는데, 성인 컨텐츠, 의학 컨텐츠, 폭력 컨텐츠등의 정도를 검출해준다.
  • Optical Character Recognition - 문자 인식


그외에도 몇가지 추가적인 Feature를 제공하고 있으니 https://cloud.google.com/vision/ 문서를 참고하기 바란다.


Cloud Vision API는 무슨 의미를 제공하는가?


사실 Cloud Vision API 자체로만으로도 대단히 흥미로운 기능을 가지고 있지만, Cloud Vision API는 또 다른 의미를 가지고 있다고 본다.

머신러닝이나 빅데이타, 그리고 인공 지능은 데이타 과학자와 대규모 하드웨어 자원을 가지고 있는 업체의 전유물이었지만, 근래에는 이러한 빅데이타 관련 기술들이 클라우드를 기반으로 하여 API로 제공됨으로써, 누구나 쉽게 빅데이타 기반의 분석 기술을 쉽게 활용할 수 있는 시대가 되어가고 있다.

이미 구글이나 Microsoft의 경우 머신러닝 알고리즘을 클라우드 API로 제공하고 있고, 대규모 데이타 분석의 경우에도 Google Analytics나 Yahoo Flurry등을 통해서 거의 무료로 제공이 되고 있다. (코드 몇줄이면 앱에 추가도 가능하다.)

이러한 접근성을 통해서 많은 서비스와 앱들이 고급 데이타 분석 알고리즘과 인공지능 기능들을 사용할 수 있는 보편화의 시대에 들어 선것이 아닐까?







저작자 표시 비영리
신고
크리에이티브 커먼즈 라이선스
Creative Commons License

안드로이드에서 ListView를 이용한 채팅 UI 만들기


조대협 (http://bcho.tistory.com)


안드로이드 프로그래밍 기본 개념이 어느정도 잡혀가기 시작하니, 몬가 만들어봐야겠다는 생각이 들어서 생각하던중에 결론 낸것이, 간단한 채팅 서비스, 기존에 node.js 하면서 웹용 채팅을 만들어보기도 했고, 찾아보니, 안드로이드용 SocketIO 라이브러리도 잘되어 있어서 서버 연계도 어려울것이 없을것 같고, 또한 메세지가 왔을때 푸쉬 알림을 써야 하는 등 이것저것 실습이 될것 같아서, 결국은 채팅으로 정했다.


서버나 연계 코드 구현보다, 가장 어려운게 역시나 UI 디자인과 프로그래밍인데, 가장 쉬운 방법으로는 ListView를 사용하는 방법이 무난하다. (결국 코딩을 하고 나니 여러가지 한계를 느껴서 다른  UI를 찾으려고 하고는 있지만)


궁극적으로 만들고자 하는 UI는 카카오톡 처럼 말풍선이 나오는 UI이다. 





말풍선은 Ninepatch (나인패치) 이미지 라는 것으로 만들 수 가 있는데, 나인 패치 이미지에 대해서는 나중에 알아보도록 하고, 여기서는 말풍선 없이, 화면에 스크롤되서 내려가는 텍스트 기반의 대화창을 만드는 것을 먼저 진행하도록 한다.  스크롤이 되는 채팅창은 ListView 컴포넌트를 사용해서 구현하는데, 아래 그림과 같이 지난 메세지는 화면에 나오지 않고 현재 대화되는 상황만 보이도록 한다. 





리스트뷰(ListView) 에 대해서


그렇다면 리스트뷰는 무엇인가? 리스트뷰는 안드로이드 뷰 그룹 (View Group)의 일종으로, 스크롤이 가능한 아이템들의 리스트를 출력해주는 그룹이다.



아답터(Adaptor) 의 개념


채팅용 리스트뷰를 만들기 위해서는 아답터(Adaptor)의 개념을 이해해야 하는데, 아답터는 크게 두가지 기능을 한다. 리스트뷰에서 보여질 아이템들을 저장하는 데이타 저장소의 역할과, 리스트뷰안에 아이템이 그려질때 이를 렌더링하는 역할을 한다.


add 메서드를 이용하여, 아이템을 추가하고

getItem(int index)를 이용하여, index  번째의 아이템을 리턴하며 (자바의 일반적인 List형과 유사하다)

View getView(int position, xxx )이 중요한데, position 번째의 아이템을 화면에 출력할때 렌더링하는 역할을 한다.


그러면 실제로 작동하는 코드를 만들어보자. 이 예제에서는 텍스트를 입력하면 리스트 뷰에 추가되고, 텍스트가 입력됨에 따라 쭈욱 아래로 리스트뷰가 자동으로 스크롤되는 예제이다. 





아답터 클래스 구현


제일 먼저 채팅 메세지 리스트를 저장할 아답터 클래스를 구현해보자


package com.example.terry.simplelistview;


import android.content.Context;

import android.graphics.Color;

import android.view.LayoutInflater;

import android.view.View;

import android.view.ViewGroup;

import android.widget.ArrayAdapter;

import android.widget.TextView;


import java.util.ArrayList;

import java.util.List;


/**

 * Created by terry on 2015. 10. 7..

 */

public class ChatMessageAdapter extends ArrayAdapter {


    List msgs = new ArrayList();


    public ChatMessageAdapter(Context context, int textViewResourceId) {

        super(context, textViewResourceId);

    }


    //@Override

    public void add(ChatMessage object){

        msgs.add(object);

        super.add(object);

    }


    @Override

    public int getCount() {

        return msgs.size();

    }


    @Override

    public ChatMessage getItem(int index) {

        return (ChatMessage) msgs.get(index);

    }


    @Override

    public View getView(int position, View convertView, ViewGroup parent) {

        View row = convertView;

        if (row == null) {

            // inflator를 생성하여, chatting_message.xml을 읽어서 View객체로 생성한다.

            LayoutInflater inflater = (LayoutInflater) this.getContext().getSystemService(Context.LAYOUT_INFLATER_SERVICE);

            row = inflater.inflate(R.layout.chatting_message, parent, false);

        }


        // Array List에 들어 있는 채팅 문자열을 읽어

        ChatMessage msg = (ChatMessage) msgs.get(position);


        // Inflater를 이용해서 생성한 View에, ChatMessage를 삽입한다.

        TextView msgText = (TextView) row.findViewById(R.id.chatmessage);

        msgText.setText(msg.getMessage());

        msgText.setTextColor(Color.parseColor("#000000"));


        return row;


    }

}


add나 getItem,getCount등은 메세지를 Java List에 저장하고, n 번째 메세지를 리턴하거나 전체 크기를 저장하는 방식으로 구현한다.


가장 중요한 부분은 getView 메서드인데,리스트의 n 번째 아이템에 대한 내용을 화면에 출력하는 역할을 하며 여기서는 두 단계를 거쳐서 렌더링을 진행한다.


첫번째로는 인플레이터 (inflator)를 사용하여, 아이템을 렌더링할 View 컴포넌트를 ~/layout/XX.XML 에서 읽어오는 역할을 한다.

인플레이터에 대해서 간략하게 짚고 넘어가면, View등 화면등을 디자인할 때 안드로이드에서는 XML 을 사용하여 쉽게 View를 정의할 수 있다. 이렇게 정의된 XML을 실제 View 자바 Object로 생성을 해주는 것이 인플레이터(inflator)이다. 


    LayoutInflater inflater = (LayoutInflater) this.getContext().getSystemService(Context.LAYOUT_INFLATER_SERVICE); 


를 통해서 인플레이터를 생성하고, 이 인플레이터를 통해서 각 아이템을 출력해줄 View를 생성하여 row라는 변수에 저장한다.


 row = inflater.inflate(R.layout.chatting_message, parent, false);


이때, 레이아웃을 정의한 XML 파일은 chatting_message.xml 로 안드로이드 프로젝트의 ~/layout 디렉토리 아래에 있으며, 인플레이트 할때는 R.layout.chatting_message라는 이름으로 지칭한다.


chatting_message.xml 의 내용은 다음과 같다.


<?xml version="1.0" encoding="utf-8"?>

<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"

    android:orientation="vertical" android:layout_width="match_parent"

    android:layout_height="match_parent">


    <TextView

        android:layout_width="match_parent"

        android:layout_height="match_parent"

        android:textAppearance="?android:attr/textAppearanceMedium"

        android:text="Chat message"

        android:id="@+id/chatmessage"

        android:gravity="left|center_vertical|center_horizontal"

        android:layout_marginLeft="20dp" />

</LinearLayout>



이렇게 아이템을 표시할 View가 생성되었으면, 그 안에 알맹이를 채워넣어야 하는데, 채팅 메세지를 저장하는 List 객체에서, position 번째의 채팅 메세지를 읽어온 후에, row 뷰 안에 있는 TextView에 그 채팅 메세지를 채워 넣는다. 


getView 코드에서 주의해서 봐야할 부분이


  View row = convertView;

        if (row == null) { …


인데, 가만히 보면 row가 null 일 경우에만 인플레이터를 이용해서 row를 생성하는 것을 볼 수 있다.


ListView의 특징중 하나는, 아이템을 랜더링 하는 View 객체가 매번 생성되는 것이 아니라, 해당 아이템에 대해서 이미 생성되어 있는 View가 있다면, getView( .. ,View convertView, ..) 를 통해서 인자로 전달된다.


그래서, convertView가 null 인지를 체크하고, 만약에 null이 아닌 경우에만 View를 생성한다. 


여기까지가 채팅 메세지를 저장하고, 각 메세지를 렌더링 해주는 ListView 용 아답터의 구현이었다. 

그러면 아답터를 이용한 ListView를 사용하여 채팅 메세지를 출력하는 부분을 구현해보자


아래는 ListView를 안고 있는 MainActivity이다.


package com.example.terry.simplelistview;


import android.database.DataSetObserver;

import android.support.v7.app.ActionBarActivity;

import android.os.Bundle;

import android.view.Menu;

import android.view.MenuItem;

import android.view.View;

import android.widget.AbsListView;

import android.widget.EditText;

import android.widget.ListView;



public class MainActivity extends ActionBarActivity {

    ChatMessageAdapter chatMessageAdapter;


    @Override

    protected void onCreate(Bundle savedInstanceState) {

        super.onCreate(savedInstanceState);

        setContentView(R.layout.activity_main);

    }


    @Override

    public boolean onCreateOptionsMenu(Menu menu) {


        // Inflate the menu; this adds items to the action bar if it is present.

        getMenuInflater().inflate(R.menu.menu_main, menu);


        chatMessageAdapter = new ChatMessageAdapter(this.getApplicationContext(),R.layout.chatting_message);

        final ListView listView = (ListView)findViewById(R.id.listView);

        listView.setAdapter(chatMessageAdapter);

        listView.setTranscriptMode(ListView.TRANSCRIPT_MODE_ALWAYS_SCROLL); // 이게 필수


        // When message is added, it makes listview to scroll last message

        chatMessageAdapter.registerDataSetObserver(new DataSetObserver() {

            @Override

            public void onChanged() {

                super.onChanged();

                listView.setSelection(chatMessageAdapter.getCount()-1);

            }

        });

        return true;

    }


    @Override

    public boolean onOptionsItemSelected(MenuItem item) {

        // Handle action bar item clicks here. The action bar will

        // automatically handle clicks on the Home/Up button, so long

        // as you specify a parent activity in AndroidManifest.xml.

        int id = item.getItemId();


        //noinspection SimplifiableIfStatement

        if (id == R.id.action_settings) {

            return true;

        }


        return super.onOptionsItemSelected(item);

    }


    public void send(View view){

        EditText etMsg = (EditText)findViewById(R.id.etMessage);

        String strMsg = (String)etMsg.getText().toString();

        chatMessageAdapter.add(new ChatMessage(strMsg));

    }

}


먼저 send는 “SEND” 버튼을 눌렀을때, 화면상에서 채팅 메세지를 읽어드려서 Adapter에 저장하는 역할을 한다.

가장 중요한 메서드는  onCreateOptionsMenu인데, 이 메서드의 주요 내용은, Adapter를 생성하여, Adapter를 listView에 바인딩한다.


  chatMessageAdapter = new ChatMessageAdapter(this.getApplicationContext(),R.layout.chatting_message);

        final ListView listView = (ListView)findViewById(R.id.listView);

        listView.setAdapter(chatMessageAdapter);


다음 기능으로는, 새로운 아이템이 listView에 추가되었을때, 맨 아래로 화면을 스크롤 하는 기능인데, 이는 Adapter에 DataObserver를 바인딩해서, 데이타가 바뀌었을때 즉 채팅 메세지가 추가되었을때, listView에서 리스트업 되는 아이템 목록을 맨 아래로 이동 시킨다.


        // When message is added, it makes listview to scroll last message