블로그 이미지
평범하게 살고 싶은 월급쟁이 기술적인 토론 환영합니다.같이 이야기 하고 싶으시면 부담 말고 연락주세요:이메일-bwcho75골뱅이지메일 닷컴. 조대협


Archive»


 

'예제'에 해당되는 글 23

  1. 2019.05.14 SRE #4-예제로 보는 SLI/SLO 정의 방법
  2. 2019.03.22 로그 프레임워크 #2 - 기본 로깅 및 JSON 포맷으로 로깅하기
  3. 2018.11.25 Istio #4 - Istio 설치와 BookInfo 예제 (2)
  4. 2018.04.15 Circuit breaker 패턴을 이용한 장애에 강한 MSA 서비스 구현하기 #2 - Spring에서 Circuit breaker 구현
  5. 2018.04.04 Circuit breaker 패턴을 이용한 장애에 강한 MSA 서비스 구현하기 #1 - Circuit breaker와 넷플릭스 Hystrix (1)
  6. 2017.08.30 Tensorflow Object Detection API를 이용한 물체 인식 #3-얼굴은 학습시켜보자
  7. 2017.05.16 연예인 얼굴 인식 모델을 만들어보자 - #1. 학습 데이타 준비하기 (4)
  8. 2017.04.03 텐서플로우의 세션,그래프 그리고 함수의 개념 (1)
  9. 2017.01.09 딥러닝을 이용한 숫자 이미지 인식 #2/2-예측 (12)
  10. 2016.12.28 텐서플로우 #3-숫자를 인식하는 모델을 만들어보자 (11)
  11. 2016.08.29 파이어베이스 애널러틱스를 이용한 모바일 데이타 분석 #1-Hello Firebase (5)
  12. 2016.08.29 세번째 책이 나왔습니다. (4)
  13. 2016.03.29 빠르게 훝어 보는 node.js - redis 사용하기
  14. 2016.03.18 빠르게 훝어 보는 node.js - promise를 이용한 node.js에서 콜백헬의 처리 (2)
  15. 2016.03.14 빠르게 훝어 보는 node.js - monk 모듈을 이용한 mongoDB 연결
  16. 2015.10.16 안드로이드 채팅 UI 만들기 #2 - 나인패치 이미지를 이용한 채팅 버블 (1)
  17. 2015.09.16 안드로이드 웹뷰(Webview)의 이해와 성능 최적화 방안 (3)
  18. 2015.09.15 안드로이드에서 REST API 호출하기 (2)
  19. 2014.05.17 하이버네이트 쉽게 입문하기 (기초)-환경설정,입력조회 개발 (5)
  20. 2014.04.24 빠르게 훝어 보는 node.js - #12 Socket.IO 4/4 - 채팅방 기능 추가하기 (4)
 

SRE #4-예제로 살펴보는 SLI/SLO 정의 방법

조대협 (http://bcho.tistory.com)


앞에서 SRE의 주요 지표인 SLO/SLI의 개념에 대해서 설명하였는데, 그러면 실제 서비스에서는 어떻게 SLO/SLI를 정의하는지에 대해서 알아본다.

SLI는 사용자 스토리당 3~5개 정도가 적당하다. 사용자 스토리는 로그인, 검색, 상품 상세 정보와 같이 하나의 기능을 의미한다고 보면된다.


아래 그림과 같은 간단한 게임 서비스가 있다고 가정하자. 이 서비스는 웹사이트를 가지고 있고, 그리고 앱을 통해서 접근이 가능한데, 내부적으로 API 서비스를 통해서 서비스가 된다. 내부 서비스에는 사용자 랭킹(Rank ), 사용자 프로파일 (User profiles) 등의 서비스가 있다.



이 서비스에서 "사용자 프로필" 에 대한 SLI를 정의해보도록 하자.

SLI 지표 레퍼런스

앞에서 설명한 SLI 지표로 주로 사용되는 지표들을 되집어 보면 다음과 같다.

  • 응답 시간 (Request latency) : 시스템의 응답시간

  • 에러율 (Error rate%)  : 전체 요청에서 실패한 요청의 비율

  • 처리량(Throughput) : 일반적으로 초당 처리량으로 측정하고 TPS (Thoughput per second) 또는 QPS (Query per second)라는 단위를 사용한다.

  • 가용성(availability)  : 시스템의 업타임 비율로, 앞에서 예를 들어 설명하였다.

  • 내구성(Durability-스토리지 시스템만 해당) : 스토리지 시스템에만 해당하는데, 장애에도 데이타가 유실되지 않을 확률이다.

이런 지표들이 워크로드 타입에 따라 어떤 지표들이 사용되는지 정의해놓은 정보를 다시 참고해 보면 다음과 같다.

  • 사용자에게 서비스를 제공하는 서비스 시스템 (웹,모바일등) : 가용성, 응답시간, 처리량

  • 스토리지 시스템(백업,저장 시스템): 가용성, 응답시간, 내구성

  • 빅데이터 분석 시스템 : 처리량, 전체 End-to-End 처리 시간

  • 머신러닝 시스템 : 서빙 응답시간, 학습 시간, 처리량, 가용성, 서빙 정확도


이 서비스는 "사용자에게 서비스를 제공하는 서비스 시스템 패턴" 이기 때문에, 이 중에서 가용성과 응답시간을 SLI로 사용하기로 한다.

가용성 SLI

가용성은 프로파일 페이지가 성공적으로 로드된 것으로 측정한다.

그러면 성공적으로 로드 되었다는 것은 어떻게 측정할 것인가? 그리고, 성공호출 횟수와 실패 횟수는 어떻게 측정할것인가? 에 대한 질문이 생긴다.

이 서비스는 웹기반 서비스이기 때문에, HTTP GET /profile/{users}와 /profile/{users}/avatar 가 성공적으로 호출된 비율을 측정하면 된다. 성공 호출은 어떻게 정의할것인가? HTTP response code 200번만 성공으로 생각할 수 있지만 5xx는 시스템 에러이지만 3xx, 4xx는 애플리케이션에서 처리하는 에러 처리 루틴이라고 봤을때, 3xx,4xx도 성공 응답에 포함시켜야 한다. 그래서 2xx,3xx,4xx의 횟수를 성공 호출로 카운트 한다.


그러면 이 응답을 어디서 수집해야 할것인가? 앞의 아키텍쳐 다이어그램을 보면 API/웹서비스 앞에 로드밸런서가 있는 것을 볼 수 있는데, 개별 서버 (VM)에서 측정하는 것이 아니라 앞단의 로드밸런서에서 측정해도 HTTP 응답 코드를 받을 수 있기 때문에, 로드밸런서의 HTTP 응답 코드를 카운트 하기로 한다.

응답시간 SLI

그러면 같은 방식으로 응답시간에 대한 SLI를 정의해보자

응답 시간은 프로파일 페이지가 얼마나 빨리 로드 되었는지를 측정한다. 그런데 빠르다는 기준은 무엇이고, 언제부터 언제까지를 로딩 시간으로 측정해야 할것인가?

이 서비스는 HTTP GET /profile/{users} 를 호출하기 때문에, 이 서비스가 100ms 를 임의의 기준값으로 하여, 이 값 대비의 응답시간으로 정의한다.

응답 시간 역시 가용성과 마찬가지로 로드밸런서에서 측정하도록 한다.


이렇게 SLI를 정의하였으면, 여기에 측정 기간과 목표값을 정해서 SLO를 정한다.



가용성 SLO는 28일 동안 99.95%의 응답이 성공한것으로 정의한다.

응답시간 SLO는 28일 동안 90%의 응답이 500ms 안에 도착하는 것으로 정의한다. 또는 좀더 발전된 방법으로 99% 퍼센타일의 응답의 90%가 500ms 안에 도착하는 것으로 높게 잡을 수 있지만, 처음 정한 SLO이기때문에, 이정도 수준으로 시작하고 점차 높여가는 모델을 사용한다.

복잡한 서비스의 SLI 의 정의

앞의 예제를 통해서 SLI와 SLO를 정의하는 방법에 대해서 알아보았다. 사용자 스토리 단위로 SLI를 정한다하더라도, 현실에서의 서비스는 훨씬 복잡하고 많은 개수를 갖는다.

SLI가 많아지면, 관련된 사람들이 전체 SLI를 보기 어렵기 때문에 조금 더 단순화되고 직관적인 지표가 필요하다.

예를 들어보다. 구글 플레이 스토어를 예를 들어봤을때, 구글 플레이스토어는 홈 화면, 검색, 카테고리별 앱 리스트 그리고 앱 상세 정보와 같이 크게 4가지 사용자 스토리로 정의할 수 있다.


이 4가지 사용자 스토리를 aggregation (합이나 평균)으로 합쳐서 하나의 지표인 탐색(Browse)라는 지표로 재 정의할 수 있다. 아래는 4개의 SLI를 각각 측정한 값이다.



이 개별 SLI들을 합쳐서 표현하면 다음과 같이 표현할 수 있다. 전체 SLI의 값을 합친 후에, 백분률로 표현하였다.


이 하나의 지표를 사용하면 4개의 기능에 대한 SLI를 대표할 수 있다. 이렇게 개별 SLI의 합이나 평균을 사용하는 경우는 대부분의 경우에는 충분하지만 특정 서비스가 비지니스 임팩트가 더 클 경우 이를 동일하게 취급해서 합해버리면 중요한 서비스가 나도 이 대표값에는 제대로 반영이 안될 수 있기 때문에, 필요한 경우 개별 SLI에 적절한 가중치를 곱해서 값을 계산하는 것도 방법이 된다.



본인은 구글 클라우드의 직원이며, 이 블로그에 있는 모든 글은 회사와 관계 없는 개인의 의견임을 알립니다.

댓글을 달아 주세요

로그 시스템 #2- 자바 로그 & JSON 로그 포맷

조대협 (http://bcho.tistory.com)


앞 글에서 간단하게 자바 로깅 프레임워크에 대해서 알아보았다. 그러면 앞에서 추천한 slf4j와 log4j2로 실제 로깅을 구현해보자

SLF4J + log4j2

메이븐 프로젝트를 열고 dependencies 부분에 아래 의존성을 추가한다. 버전은 최신 버전을 확인하도록 한다. artifactid가 log4j-slf4j-impl 이지만, log4j가 아니라 log4j2가 사용된다.


<dependency>

<groupId>org.apache.logging.log4j</groupId>

<artifactId>log4j-slf4j-impl</artifactId>

<version>2.11.2</version>

</dependency>


다음 log4j2의 설정 정보 파일인 log4j2.properties 파일을 src/main/resources 디렉토리 아래에 다음과 같이 생성한다. Appender나, Layout등 다양한 정보 설정이 있지만 그 내용은 나중에 자세하게 설명하도록 한다.


appenders=xyz


appender.xyz.type = Console

appender.xyz.name = myOutput

appender.xyz.layout.type = PatternLayout

appender.xyz.layout.pattern = [MYLOG %d{yy-MMM-dd HH:mm:ss:SSS}] [%p] [%c{1}:%L] - %m%n


rootLogger.level = info


rootLogger.appenderRefs = abc


rootLogger.appenderRef.abc.ref = myOutput


그리고 아래와 같이 코드를 만든다.

LoggerFactory를 이용해서 Logger를 가지고 온다. 현재 클래스 명에 대한 Logger 를 가지고 오는데, 위의 설정 파일을 보면 rootLogger만 설정하였기 때문에, rootLogger가 사용된다.

package com.terry.logging.helloworld;


import org.slf4j.Logger;

import org.slf4j.LoggerFactory;



public class App

{

   private static Logger log = LoggerFactory.getLogger(App.class);

   public static void main( String[] args )

   {

       System.out.println( "Hello World!" );

       

       log.info("Hello slf4j");

   }

}



가저온 logger를 이용해서 log.info로 로그를 출력한다.

콘솔로 출력된 로그는 아래와 같다.

[MYLOG 19-Mar-18 23:07:01:373] [INFO] [App:71] - Hello slf4j


JSON 포맷으로 로그 출력

근래에는 시스템이 분산 구조를 가지고 있기 때문에 텍스트 파일로 남겨서는 여러 분산된 서비스의 로그를 모아서 보기가 어렵다. 그래서, 이런 로그를 중앙 집중화된 서버로 수집 및 분석하는 구조를 가지는데, 수집 서버에서는 이 로그들을 구조화된 포맷으로 저장하는 경우가 일반적이다. 각 로그의 내용을 저장 구조의 개별 자료 구조(예를 들어 테이블의 컬럼)에 맵핑해서 저장하는데, 이를 위해서는 로그가 JSON,XML 또는 CSV와 같은 형태로 구조화가 되어 있어야 한다.

이런 구조화된 로그를 structured logging 이라고 한다. 로그 엔트리 하나를 JSON에 포함해서 출력하는 방법에 대해서 알아본다.

slf4j + logback

SLF4 + logback을 이용하여 레이아웃을 JSON으로 출력하는 코드이다.


package com.terry.logging.logback;


import java.util.Map;

import java.util.TreeMap;


import org.slf4j.Logger;

import org.slf4j.LoggerFactory;


import com.fasterxml.jackson.core.JsonProcessingException;

import com.fasterxml.jackson.databind.ObjectMapper;


public class App

{

   private static Logger log = LoggerFactory.getLogger(App.class);

   public static void main( String[] args ) throws JsonProcessingException

   {


       log.info("hello log4j");

   }

}


pom.xml에 아래와 같이 logback과 json 관련 dependency를 추가한다.


<dependencies>

<dependency>

<groupId>ch.qos.logback</groupId>

<artifactId>logback-classic</artifactId>

<version>1.1.7</version>

</dependency>


<dependency>

<groupId>ch.qos.logback.contrib</groupId>

<artifactId>logback-json-classic</artifactId>

<version>0.1.5</version>

</dependency>


<dependency>

<groupId>ch.qos.logback.contrib</groupId>

<artifactId>logback-jackson</artifactId>

<version>0.1.5</version>

</dependency>


<dependency>

<groupId>com.fasterxml.jackson.core</groupId>

<artifactId>jackson-databind</artifactId>

<version>2.9.3</version>

</dependency>

</dependencies>



마지막으로 src/main/resources.xml 파일을 아래와 같이 작성한다.  

<?xml version="1.0" encoding="UTF-8"?>

<configuration>

   <appender name="stdout" class="ch.qos.logback.core.ConsoleAppender">

       <encoder class="ch.qos.logback.core.encoder.LayoutWrappingEncoder">

           <layout class="ch.qos.logback.contrib.json.classic.JsonLayout">

               <timestampFormat>yyyy-MM-dd'T'HH:mm:ss.SSSX</timestampFormat>

               <timestampFormatTimezoneId>Etc/UTC</timestampFormatTimezoneId>


               <jsonFormatter class="ch.qos.logback.contrib.jackson.JacksonJsonFormatter">

                   <prettyPrint>true</prettyPrint>

               </jsonFormatter>

           </layout>

       </encoder>

   </appender>


   <root level="debug">

       <appender-ref ref="stdout"/>

   </root>

</configuration>


아래는 출력 결과이다. message 필드에 로그가 출력 된것을 볼 수 있다.


{

 "timestamp" : "2019-03-19T07:24:31.906Z",

 "level" : "INFO",

 "thread" : "main",

 "logger" : "com.terry.logging.logback.App",

 "message" : "hello log4j",

 "context" : "default"

}


slf4j + log4j2

다음은 slft4+log4j2 를 이용한 예제이다.  logback과 크게 다르지는 않다.

아래와 같이 pom.xml 의 dependencies에 아래 내용을 추가하자. json layout은 jackson을 사용하기 때문에 아래와 같이 jackson에 대한 의존성도 함께 추가한다.


<dependency>

<groupId>org.apache.logging.log4j</groupId>

<artifactId>log4j-slf4j-impl</artifactId>

<version>2.11.2</version>

</dependency>

<dependency>

<groupId>com.fasterxml.jackson.core</groupId>

<artifactId>jackson-core</artifactId>

<version>2.7.4</version>

</dependency>

<dependency>

<groupId>com.fasterxml.jackson.core</groupId>

<artifactId>jackson-databind</artifactId>

<version>2.7.4</version>

</dependency>

<dependency>

<groupId>com.fasterxml.jackson.core</groupId>

<artifactId>jackson-annotations</artifactId>

<version>2.7.4</version>

</dependency>


다음 아래와 같이 log4j2.properties 파일을 src/main/resources 폴더에 저장한다.


status = info


appender.ana_whitespace.type = Console

appender.ana_whitespace.name = ana_whitespace

appender.ana_whitespace.layout.type = JsonLayout

appender.ana_whitespace.layout.propertiesAsList = false

appender.ana_whitespace.layout.compact = false

appender.ana_whitespace.layout.eventEol = true

appender.ana_whitespace.layout.objectMessageAsJsonObject = true

appender.ana_whitespace.layout.complete= true

appender.ana_whitespace.layout.properties= true


rootLogger.level = info

rootLogger.appenderRef.ana_whitespace.ref = ana_whitespace


위에 보면 layout.type을 JsonLayout으로 지정하였다. 기타 다른 필드에 대한 정보는

정보는 https://logging.apache.org/log4j/2.0/manual/layouts.html 를 참고하기 바란다.


그리고 아래와 같이 코드를 이용해서 info 레벨의 로그를 출력해보자

package com.terry.logging.jsonlog;

import org.slf4j.Logger;

import org.slf4j.LoggerFactory;



public class App

{

private static Logger log = LoggerFactory.getLogger(App.class);

   public static void main( String[] args )

   {

       

       log.info("Hello json log");

       log.error("This is error");

       log.warn("this is warn");

   }

}


코드를 컴파일 하고 실행하면 아래와 같은 형태로 로그가 출력된다. 로그 출력 형태는 logback과는 많이 차이가 있다.


[

{

 "thread" : "main",

 "level" : "INFO",

 "loggerName" : "com.terry.logging.jsonlog.App",

 "message" : "Hello json log",

 "endOfBatch" : false,

 "loggerFqcn" : "org.apache.logging.slf4j.Log4jLogger",

 "instant" : {

   "epochSecond" : 1552923302,

   "nanoOfSecond" : 38337000

 },

 "contextMap" : { },

 "threadId" : 1,

 "threadPriority" : 5

}

, {

 "thread" : "main",

 "level" : "ERROR",

 "loggerName" : "com.terry.logging.jsonlog.App",

 "message" : "This is error",

 "endOfBatch" : false,

 "loggerFqcn" : "org.apache.logging.slf4j.Log4jLogger",

 "instant" : {

   "epochSecond" : 1552923302,

   "nanoOfSecond" : 109170000

 },

 "contextMap" : { },

 "threadId" : 1,

 "threadPriority" : 5

}

, {

 "thread" : "main",

 "level" : "WARN",

 "loggerName" : "com.terry.logging.jsonlog.App",

 "message" : "this is warn",

 "endOfBatch" : false,

 "loggerFqcn" : "org.apache.logging.slf4j.Log4jLogger",

 "instant" : {

   "epochSecond" : 1552923302,

   "nanoOfSecond" : 109618000

 },

 "contextMap" : { },

 "threadId" : 1,

 "threadPriority" : 5

}


]


json을 여러가지 포맷으로 출력할 수 있다. 위의 로그를  잘보면 로그 시작과 끝에 json 포맷을 맞추기 위해서 “[“와 “]”를 추가하고, 로그 레코드 집합당 “,”로 레코드를 구별한것을 볼 수 있다. 만약에 “[“,”]”를 로그 처음과 마지막에서 제거하고, 로그 레코드 집합동 “,”를 제거하고 newline으로만 분류하고 싶다면 log4j2.properties 파일에서 appender.ana_whitespace.layout.complete = false로 하면 된다.

아래는 layout.complete를 false로 하고 출력한 결과 이다.


{ ←  이부분에 “[” 없음

 "thread" : "main",

 "level" : "INFO",

 "loggerName" : "com.terry.logging.jsonlog.App",

 "message" : "Hello json log",

 "endOfBatch" : false,

 "loggerFqcn" : "org.apache.logging.slf4j.Log4jLogger",

 "instant" : {

   "epochSecond" : 1552923722,

   "nanoOfSecond" : 98574000

 },

 "contextMap" : { },

 "threadId" : 1,

 "threadPriority" : 5

} ←  이부분에 콤마가 없음

{

 "thread" : "main",

 "level" : "ERROR",

 "loggerName" : "com.terry.logging.jsonlog.App",

 "message" : "This is error",

 "endOfBatch" : false,

 "loggerFqcn" : "org.apache.logging.slf4j.Log4jLogger",

 "instant" : {

   "epochSecond" : 1552923722,

   "nanoOfSecond" : 167047000

 },

 "contextMap" : { },

 "threadId" : 1,

 "threadPriority" : 5

}

{

 "thread" : "main",

 "level" : "WARN",

 "loggerName" : "com.terry.logging.jsonlog.App",

 "message" : "this is warn",

 "endOfBatch" : false,

 "loggerFqcn" : "org.apache.logging.slf4j.Log4jLogger",

 "instant" : {

   "epochSecond" : 1552923722,

   "nanoOfSecond" : 167351000

 },

 "contextMap" : { },

 "threadId" : 1,

 "threadPriority" : 5

} ←  이부분에 “]” 없음


그리고 로그파일을 보는데, JSON의 경우에는 위와 같이 각 element 마다 줄을 바꿔서 사람이 읽기 좋은 형태이기는 하지만, 대신 매번 줄을 바꾸기 때문에 검색이 어려운 경우가 있다. 그래서 로그 레코드 하나를 줄 바꿈 없이 한줄에 모두 출력할 수 있도록 할 수 있는데, appender.ana_whitespace.layout.compact = true로 주면 된다. 아래는 옵션을 적용한 결과 이다.

{"thread":"main","level":"INFO","loggerName":"com.terry.logging.jsonlog.App","message":"Hello json log","endOfBatch":false,"loggerFqcn":"org.apache.logging.slf4j.Log4jLogger","instant":{"epochSecond":1552923681,"nanoOfSecond":430798000},"contextMap":{},"threadId":1,"threadPriority":5}

{"thread":"main","level":"ERROR","loggerName":"com.terry.logging.jsonlog.App","message":"This is error","endOfBatch":false,"loggerFqcn":"org.apache.logging.slf4j.Log4jLogger","instant":{"epochSecond":1552923681,"nanoOfSecond":491757000},"contextMap":{},"threadId":1,"threadPriority":5}

{"thread":"main","level":"WARN","loggerName":"com.terry.logging.jsonlog.App","message":"this is warn","endOfBatch":false,"loggerFqcn":"org.apache.logging.slf4j.Log4jLogger","instant":{"epochSecond":1552923681,"nanoOfSecond":492095000},"contextMap":{},"threadId":1,"threadPriority":5}



본인은 구글 클라우드의 직원이며, 이 블로그에 있는 모든 글은 회사와 관계 없는 개인의 의견임을 알립니다.

댓글을 달아 주세요

Istio #4 - 설치 및 BookInfo 예제

조대협 (http://bcho.tistory.com)

Istio 설치

그러면 직접 Istio 를 설치해보자, 설치 환경은 구글 클라우드의 쿠버네티스 환경을 사용한다. (쿠버네티스는 오픈소스이고, 대부분의 클라우드에서 지원하기 때문에 설치 방법은 크게 다르지 않다.)

쿠버네티스 클러스터 생성

콘솔에서 아래 그림과 같이 istio 라는 이름으로 쿠버네티스 클러스터를 생성한다. 테스트용이기 때문에, 한존에 클러스터를 생성하고, 전체 노드는 3개 각 노드는 4 CPU/15G 메모리로 생성하였다.



다음 작업은 구글 클라우드 콘솔에서 Cloud Shell내에서 진행한다.

커맨드 라인에서 작업을 할것이기 때문에, gCloud SDK를 설치(https://cloud.google.com/sdk/gcloud/) 한후에,

%gcloud auth login

gcloud 명령어에 사용자 로그인을 한다.


그리고 작업을 편리하게 하기 위해서 아래와 같이 환경 변수를 설정한다. 쿠버네티스 클러스터를 생성한 리전과 존을 환경 변수에 아래와 같이 설정한다. 예제에서는 asia-southeast1 리전에 asia-southeast1-c 존에 생성하였다. 그리고 마지막으로 생성한 쿠버네티스 이름을 환경 변수로 설정한다. 예제에서 생성한 클러스터명은 istio이다.

export GCP_REGION=asia-southeast1
export GCP_ZONE=asia-southeast1-c
export GCP_PROJECT_ID=$(gcloud info --format='value(config.project)')
export K8S_CLUSTER_NAME=istio

다음 kubectl 명령어를 사용하기 위해서, 아래과 같이 gcloud 명령어를 이용하여 Credential을 세팅한다

% gcloud container clusters get-credentials $K8S_CLUSTER_NAME \
   --zone $GCP_ZONE \
   --project $GCP_PROJECT_ID

Credential 설정이 제대로 되었는지

% kubectl get pod -n kube-system

명령어를 실행하여, 쿠버네티스 시스템 관련 Pod 목록이 잘 나오는지 확인한다.

Istio 설치

쿠버네티스 클러스터가 준비되었으면, 이제 Istio를 설치한다.

Helm 설치

Istio는 Helm 패키지 매니져를 통해서 설치 한다.

% curl https://raw.githubusercontent.com/kubernetes/helm/master/scripts/get > get_helm.sh
% chmod 700 get_helm.sh
% ./get_helm.sh

Istio 다운로드

Istio 를 다운로드 받는다. 아래는 1.0.4 버전을 다운 받는 스크립트이다.

% cd ~

% curl -L https://git.io/getLatestIstio | sh -

% cd istio-1.0.4

% export PATH=$PWD/bin:$PATH

Helm 초기화

Istio를 설치하기 위해서 Helm용 서비스 어카운트를 생성하고, Helm을 초기화 한다.

% kubectl create -f install/kubernetes/helm/helm-service-account.yaml

% helm init --service-account tiller

Istio 설치

다음 명령어를 이용하여 Istio를 설치한다. 설치시 모니터링을 위해서 모니터링 도구인 kiali,servicegraph 그리고 grafana 설치 옵션을 설정하여 아래와 같이 추가 설치 한다.

% helm install install/kubernetes/helm/istio \

--name istio \

--namespace istio-system \

--set tracing.enabled=true \

--set global.mtls.enabled=true \

--set grafana.enabled=true \

--set kiali.enabled=true \

--set servicegraph.enabled=true


설치가 제대로 되었는지 kubectl get pod명령을 이용하여, istio 네임스페이스의 Pod 목록을 확인해보자

% kubectl get pod -n istio-system




BookInfo 샘플 애플리케이션 설치

Istio 설치가 끝났으면, 사용법을 알아보기 위해서 간단한 예제 애플리케이션을 설치해보자, Istio에는 BookInfo (https://istio.io/docs/examples/bookinfo/)  라는 샘플 애플리케이션이 있다.

BookInfo 애플리케이션의 구조

아래 그림과 같이 productpage 서비스 안에, 책의 상세 정보를 보여주는 details 서비스와 책에 대한 리뷰를 보여주는 reviews 서비스로 구성이 되어 있다.  


시스템의 구조는 아래와 같은데, 파이썬으로 개발된 productpage 서비스가, 자바로 개발된 review 서비스과 루비로 개발된 details 서비스를 호출하는 구조이며, review 서비스는 v1~v3 버전까지 배포가 되어 있다. Review 서비스 v2~v3는 책의 평가 (별점)를 보여주는  Rating 서비스를 호출하는 구조이다



< 그림 Book Info 마이크로 서비스 구조 >

출처 : https://istio.io/docs/examples/bookinfo/

BookInfo 서비스 설치

Istio의 sidecar injection 활성화

Bookinfo 서비스를 설치하기 전에, Istio의 sidecar injection 기능을 활성화 시켜야 한다.

앞에서도 설명하였듯이 Istio는 Pod에 envoy 를 sidecar 패턴으로 삽입하여, 트래픽을 컨트롤 하는 구조이 다. Istio는 이 sidecar를 Pod 생성시 자동으로 주입 (inject)하는 기능이 있는데, 이 기능을 활성화 하기 위해서는 쿠버네티스의 해당 네임스페이스에 istio-injection=enabled 라는 라벨을 추가해야 한다.

다음 명령어를 이용해서 default 네임 스페이스에 istio-injection=enabled 라벨을 추가 한다.

% kubectl label namespace default istio-injection=enabled


라벨이 추가되었으면

% kubectl get ns --show-labels

를 이용하여 라벨이 제대로 적용이 되었는지 확인한다.


Bookinfo 애플리케이션 배포

Bookinfo 애플리케이션의 쿠버네티스 배포 스크립트는 samples/bookinfo 디렉토리에 들어있다. 아래 명령어를 실행해서 Bookinfo 앺ㄹ리케이션을 배포하자.

% kubectl apply -f samples/bookinfo/platform/kube/bookinfo.yaml


배포를 완료한 후 kubectl get pod 명령어를 실행해보면 다음과 같이 productpage, detail,rating 서비스가 배포되고, reviews 서비스는 v1~v3까지 배포된것을 확인할 수 있다.



Kubectl get svc 를 이용해서 배포되어 있는 서비스를 확인하자



Prodcutpcage,rating,reviews,details 서비스가 배포되어 있는데, 모두 ClusterIP 타입으로 배포가 되어 있기 때문에 외부에서는 접근이 불가능하다.


Istio gateway 설정

이 서비스를 외부로 노출 시키는데, 쿠버네티스의 Ingress나 Service는 사용하지 않고, Istio의 Gateway를 이용한다.

Istio의 Gateway는 쿠버네티스의 커스텀 리소스 타입으로, Istio로 들어오는 트래픽을 받아주는 엔드포인트 역할을 한다. 여러 방법으로 구현할 수 있으나, Istio에서는 디폴트로 배포되는 Gateway는 Pod 형식으로 배포되어 Load Balancer 타입의 서비스로 서비스 된다.


먼저 Istio Gateway를 등록한후에, Gateway를 통해 서비스할 호스트를 Virtual Service로 등록한다.


아래는 bookinfo에 대한 Gateway를 등록하는 Yaml 파일이다.


apiVersion: networking.istio.io/v1alpha3

kind: Gateway

metadata:

 name: bookinfo-gateway

spec:

 selector:

   istio: ingressgateway # use istio default controller

 servers:

 - port:

     number: 80

     name: http

     protocol: HTTP

   hosts:

   - "*"


selector를 이용해서 gateway 타입을 istio에서 디폴트로 제공하는 Gateway를 사용하였다. 그리고, HTTP프로토콜을 80 포트에서 받도록 하였다.

다음에는 이 Gateway를 통해서 트래픽을 받을 서비스를 Virtual Service로 등록해야 하는데, 그 구조는 다음과 같다.


apiVersion: networking.istio.io/v1alpha3

kind: VirtualService

metadata:

 name: bookinfo

spec:

 hosts:

 - "*"

 gateways:

 - bookinfo-gateway

 http:

 - match:

   - uri:

       exact: /productpage

   - uri:

       exact: /login

   - uri:

       exact: /logout

   - uri:

       prefix: /api/v1/products

   route:

   - destination:

       host: productpage

       port:

         number: 9080


spec에서 gateways 부분에 앞에서 정의한 bookinfo-gateway를 사용하도록 한다. 이렇게 하면 앞에서 만든 Gateway로 들어오는 트래픽은 이 Virtual Servivce로 들어와서 서비스 디는데, 여기서 라우팅 룰을 정의 한다 라우팅룰은 URL에 때해서 어느 서비스로 라우팅할 지를 정하는데 /productpage,/login,/lougout,/api/v1/products URL은 productpage:9080 으로 포워딩해서 서비스를 제공한다.


Gateway와 Virtual service 배포에 앞서서, Istio에 미리 설치되어 있는 gateway를 살펴보면, Istio default gateway는 pod로 배포되어 있는데, istio=ingressgateway 라는 라벨이 적용되어 있다. 확인을 위해서 kubectl get 명령을 이용해서 확인해보면 다음과 같다.

%kubectl get pod -n istio-system -l istio=ingressgateway



이 pod들은 istio-ingressgateway라는 이름으로 istio-system 네임스페이스에 배포되어 있다. kubectl get svc로 확인해보면 다음과 같다.

%kubectl get svc istio-ingressgateway -n istio-system --show-labels



그러면 bookinfo를 istio gateway에 등록해서 외부로 서비스를 제공해보자

% istioctl create -f samples/bookinfo/networking/bookinfo-gateway.yaml


게이트 웨이 배포가 끝나면, 앞에서 조회한 Istio gateway service의 IP (여기서는 35.197.159.13)에 접속해서 확인해보자

브라우져를 열고 http://35.197.159.13/productpage 로 접속해보면 아래와 같이 정상적으로 서비스에 접속할 수 있다.



모니터링 툴

서비스 설치가 끝났으면 간단한 테스트와 함께 모니터링 툴을 이용하여 서비스를 살펴보자

Istio를 설치하면 Prometheus, Grafana, Kiali,Jaeger 등의 모니터링 도구가 기본적으로 인스톨 되어 있다. 각각의 도구를 이용해서 지표들을 모니터링 해보자

Grafana를 이용한 서비스별 지표 모니터링

Grafana를 이용해서는 각 서비스들의 지표를 상세하게 모니터링할 수 있다.

먼저 아래 스크립트를 사용해서 간단하게 부하를 주자. 아래 스크립트는 curl 명령을 반복적으로 호출하여 http://35.197.159.13/productpage 페이지를 불러서 부하를 주는 스크립이다.


for i in {1..100}; do

curl -o /dev/null -s -w "%{http_code}" http://35.197.159.13/productpage

done


다음 Grafana 웹 콘솔에 접근해야 하는데, Grafana는 외부 서비스로 노출이 안되도록 설정이 되어 있기 때문에 kubectl을 이용해서 Grafana 콘솔에 트래픽을 포워딩 하도록 하자. Grafana는 3000번 포트에서 돌고 있기 때문에, localhost:3000 → Grafana Pod의 3000 번 포트로 트래픽을 포워딩 하도록 설정하자


kubectl -n istio-system port-forward $(kubectl -n istio-system get pod -l app=grafana -o jsonpath='{.items[0].metadata.name}') 3000:3000 &


다음 localhost:3000 번으로 접속해보면 다음과 같은 화면을 볼 수 있다.

각 서비스 productpage,review,rating,detail 페이지의 응답시간과 OPS (Operation Per Sec : 초당 처리량)을 볼 수 있다.




각 서비스를 눌러보면 다음과 같이 서비스별로 상세한 내용을 볼 수 있다. 응답 시간이나 처리량에 대한 트렌드나, Request의 사이즈등 다양한 정보를 볼 수 있다.



Jaeger를 이용한 분산 트렌젝션 모니터링

다음은 Jaeger 를 이용해 개별 분산 트렌젝션에 대해서 각 구간별 응답 시간을 모니터링 할 수 있다.

Istio는 각 서비스별로 소요 시간을 수집하는데, 이를 Jaeger 오픈소스를 쓰면 손쉽게 모니터링이 가능하다.

마찬가지로 Jaeger 역시 외부 서비스로 노출이 되지 않았기 때문에, kubectl 명령을 이용해서 로컬 PC에서 jaeger pod로 포트를 포워딩하도록 한다. Jaerger는 16686 포트에서 돌고 있기 localhost:16686 → Jaeger pod:16686으로 포워딩한다.


kubectl port-forward -n istio-system $(kubectl get pod -n istio-system -l app=jaeger -o jsonpath='{.items[0].metadata.name}') 16686:16686 &


Jaeger UI에 접속해서, 아래는 productpage의 호출 기록을 보는 화면이다. 화면 상단에는 각 호출별로 응답시간 분포가 나오고 아래는 개별 트렉젝션에 대한 히스토리가 나온다.



그중 하나를 선택해보면 다음과 같은 그림을 볼 수 있다.



호출이 istio-ingressgateway로 들어와서 Productpage를 호출하였다.

productpage는 순차적으로 productpage → detail 서비스를 호출하였고, 다음 productpage→ reviews → ratings 서비스를 호출한것을 볼 수 있고, 많은 시간이 reviews 호출에 소요된것을 확인할 수 있다.


Servicegraph를 이용한 서비스 토폴로지 모니터링

마이크로 서비스는 서비스간의 호출 관계가 복잡해서, 각 서비스의 관계를 시각화 해주는 툴이 있으면 유용한데, 대표적인 도구로는 service graph라는 툴과 kiali 라는 툴이 있다. BookInfo 예제를 위한 Istio 설정에는 servicegraph가 디폴트로 설치되어 있다.


마찬가지로 외부 서비스로 노출 되서 서비스 되지 않고 클러스터 주소의 8088 포트를 통해서 서비스 되고 있기 때문에, 아래와 같이 kubectl 명령을 이용해서 localhost:8088 → service graph pod의 8088포트로 포워딩하도록 한다.


kubectl -n istio-system port-forward $(kubectl -n istio-system get pod -l app=servicegraph -o jsonpath='{.items[0].metadata.name}') 8088:8088 &


그 후에, 웹 브루우져에서 http://localhost:8088/dotviz 를 접속해보면 서비스들의 관계를 볼 수 있다.



다음 글에서는 예제를 통해서 Istio에서 네트워크 경로 설정하는 부분에 대해서 더 자세히 알아보도록 하겠다.


본인은 구글 클라우드의 직원이며, 이 블로그에 있는 모든 글은 회사와 관계 없는 개인의 의견임을 알립니다.

댓글을 달아 주세요

  1. 2018.12.05 14:07  댓글주소  수정/삭제  댓글쓰기

    비밀댓글입니다

  2. cjshms@gmail.com 2018.12.05 14:08  댓글주소  수정/삭제  댓글쓰기

    항상 큰 도움이 되십니다. 화이팅!!!!!

Circuit breaker 패턴을 이용한 장애에 강한 MSA 서비스 구현하기 #2

Spring을 이용한 Circuit breaker 구현


조대협 (http://bcho.tistory.com)


앞의 글에서는 넷플릭스 Hystrix를 이용하여 Circuit break를 구현해보았다.

실제 개발에서 Hystix로 개발도 가능하지만, 보통 자바의 경우에는 Spring framework을 많이 사용하기 때문에 이번 글에서는 Spring framework을 이용한 Circuit breaker를 구현하는 방법을 알아보도록 한다.


다행이도 근래에 Spring은 넷플릭스의 MSA 패턴들을 구현화한 오픈 소스들을 Spring 오픈 소스 프레임웍안으로 활발하게 합치는 작업을 진행하고 있어서 어렵지 않게 구현이 가능하다.


구현하고자 하는 시나리오는 앞의 글에서 예제로 사용한 User service에서 Item Service를 호출하는 구조를 구현하고, User service에 circuit breaker를 붙여보도록 하겠다.

User service 코드 전체는 https://github.com/bwcho75/msa_pattern_sample/tree/master/user-spring-hystrix 에 그리고 Item Service 코드 전체는 https://github.com/bwcho75/msa_pattern_sample/tree/master/item-spring-hystrix 에 있다


Spring Circuit breaker 구현

User service pom.xml 정의

Hystrix circuit breaker를 사용하기 위해서는 pom.xml에 다음과 같이 hystrix 관련 라이브러리에 대한 의존성을 정의해줘야 한다.

<dependency>

<groupId>org.springframework.cloud</groupId>

<artifactId>spring-cloud-starter-hystrix</artifactId>

<version>1.4.4.RELEASE</version>

</dependency>

<dependency>

<groupId>org.springframework.cloud</groupId>

<artifactId>spring-cloud-starter-hystrix-dashboard</artifactId>

<version>1.4.4.RELEASE</version>

</dependency>

<dependency>

<groupId>org.springframework.boot</groupId>

<artifactId>spring-boot-starter-actuator</artifactId>

<version>1.5.11.RELEASE</version>

</dependency>


spring-cloud-starter-hystrix 는 Hystrix circuit breaker를 이용한 의존성이고 hystrix-dashboard와 actuator 는 hystix dash 보드를 띄우기 위한 의존성이다.



User service 구현

UserApplication

Circuit breaker를 이용하기 위해서는 User Service의 메인 함수인 UserApplication 에 Annotation으로 선언을 해준다.



package com.terry.circuitbreak.User;




import org.springframework.boot.SpringApplication;


import org.springframework.boot.autoconfigure.SpringBootApplication;


import org.springframework.cloud.client.circuitbreaker.EnableCircuitBreaker;


import org.springframework.cloud.netflix.hystrix.dashboard.EnableHystrixDashboard;




@SpringBootApplication


@EnableCircuitBreaker


@EnableHystrixDashboard


public class UserApplication {





public static void main(String[] args) {


SpringApplication.run(UserApplication.class, args);


}


}


위의 코드와 같이 @EnableCircuitBreaker Annotation을 추가해주면 Circuit breaker를 사용할 수 있고, 그리고 추가적으로 Hystrix 대쉬 보드를 사용할것이기 때문에, @EnableHystrixDashboard Annotation을 추가한다.

Item Service를 호출

그러면 UserSerivce에서 ItemService를 호출하는 부분을 구현해보도록 하자. Hystrix와 마찬가지로 Spring Hystrix에서도 타 서비스 호출은 Command로 구현한다.  아래는 Item Service에서 Item 목록을 가지고 오는 GetItemCommand 코드이다.

GetItemCommand

Hystrix Command와 거의 유사하지만 Command를  상속 받아서 사용하지 않고, Circuit breaker를 적용한 메서드에 간단하게  @HystrixCommand Annotation만을 추가하면 된다.


아래 코드를 자세하게 보자. 주의할점은 Item Service 호출을 RestTemplate API를 통해서하는데, RestTemplate 객체인 resetTemplate는 Autowrire로 생성한다.



@Service


public class GetItemCommand {



@Autowired


RestTemplate restTemplate;



  @Bean


  public RestTemplate restTemplate() {


      return new RestTemplate();


  }





// GetItem command


@HystrixCommand(fallbackMethod = "getFallback")


public List<User> getItem(String name)  {


List<User> usersList = new ArrayList<User>();



List<Item> itemList = (List<Item>)restTemplate.exchange("http://localhost:8082/users/"+name+"/items"


,HttpMethod.GET,null


,new ParameterizedTypeReference<List<Item>>() {}).getBody();


usersList.add(new User(name,"myemail@mygoogle.com",itemList));



return usersList;


}



// fall back method


// it returns default result


@SuppressWarnings("unused")


public List<User> getFallback(String name){


List<User> usersList = new ArrayList<User>();


usersList.add(new User(name,"myemail@mygoogle.com"));



return usersList;


}


}


Item Service를 호출하는 코드는 getItem(String name) 메서드이다. 여기에 Circuit breaker를 적용하기 때문에, 메서드 앞에  @HystrixCommand(fallbackMethod = "getFallback") Annotation을 정의하였다. 그리고 Item Service 장애시 호출한 fallback 메서드는 getFallback 메서드로 지정하였다.

getItem안에서는 ItemService를 RestTemplate을 이용하여 호출하고 그 결과를 List<User> 타입으로 반환한다.


앞서 정의한 Fallback은 getFallback() 메서드로 Circuit breaker를 적용한 원래 함수와 입력 (String name)과 출력 (List<User>) 인자가 동일하다.

Circuit breaker 테스트


User service와 Item Service를 기동한 상태에서 user service를 호출하면 아래와 같이 itemList에 Item Service가 리턴한 내용이 같이 반환 되는 것을 확인할 수 있다.


terrycho-macbookpro:~ terrycho$ curl localhost:8081/users/terry

[  

  {  

     "name":"terry",

     "email":"myemail@mygoogle.com",

     "itemList":[  

        {

           "name":"computer",

           "quantity":1

        },

        {

           "name":"mouse",

           "quantity":2

        }

     ]

  }

]


Item Service를 내려놓고 테스트를 해보면 지연 응답 없이 User service로 부터 응답이 리턴되고, 앞서 정의한 fallback 메서드에 의해서 itemList에 아무 값이 없인할 수 있다.


terrycho-macbookpro:~ terrycho$ curl localhost:8081/users/terry

[  

  {  

     "name":"terry",

     "email":"myemail@mygoogle.com",

     "itemList":[]

  }

]


Hystrix Dashboard

User service에서 Hystrix Dash board를 사용하도록 설정하였기 때문에, User Service의 호출 상태를 실시간으로 확인할 수 있다.


User serivce 서버의 URL인 localhost:8081에서 localhost:8081/hystrix.stream을 호출 해보면

아래와 같이 Circuit Breaker가 적용된 메서드의 상태 현황 정보가 계속해서 업데이트 되면서 출력하는 것을 확인할 수 있다.




그러면 대쉬보드에 접속해보자 대쉬 보드 URL은 http://{user service}/hystrix 이다. User service url이 localhost:8081이기 때문에 http://localhost:8081/hystrix로 접속해보자


대쉬 보드에서는 모니터링 할 서비스의 스트림 URL을 넣어줘야 하는데 위에서 설명한 http://localhost:8081/hystrix.stream 을 입력한다.


URL을 입력하고 모니터링을 하면 아래와 같이 Circuit breaker가 등록된 서비스들이 모니터링 된다.

아래 그림은 부하가 없을때 상태이다.


실제로 부하를 주게 되면 아래와 같이 그래프가 커져가면서 정상적인 호출이 늘어가는 것을 확인할 수 있고, 응답 시간들도 모니터링이 가능하다.


아래는 Circuit breaker를 통해서 호출되는 Item service를 죽였을때인데, 그래프가 붉은색으로 표시되면서 붉은색 숫자가 증가하는 것을 볼 수 있고 Item service가 장애이기 때문에, Circuit 의 상태가 Close에서 Open을 변경된것을 확인할 수 있다.



운영 적용에 앞서서 고려할점

앞에서 예제로 사용한 Dashboard는 어디까지나 테스트 수준에서 사용할만한 수준이지 실제 운영환경에 적용할때는 여러가지 고려가 필요하다. 특히 /hystrix , /hystrix.stream이 외부에서 접근이 가능하기 때문에,, 이에 대해서 이 두 URL이 외부로 접근하는 것을 막아야 하며, circuit의 상태에 대한 정보를 하나의 서비스만 아니라 여러 서비스에서 대용량 서비스에 적용할시에는 중앙 집중화된 대쉬보드가 필요하고 또한 많은 로그를 동시에 수집해야 하기 때문에, 대용량 백앤드가 필요하다. 이를 지원하기 위해서 넷플릭스에서는 터빈 (Turbine)이라는 이름으로, 중앙 집중화된 Hystrix 대쉬 보드 툴을 지원하고 있다. (https://github.com/Netflix/turbine/wiki)


이번 글에서는 Spring 프레임웍을 이용하여 Circuit breaker 패턴을 Hystrix 프레임웍을 이용하여 적용하는 방법을 알아보았다.


Spring을 사용하면 편리는 하지만 자바 스택만을 지원한다는 한계점을 가지고 있다. Circuit breaker를 이처럼 소프트웨어로 지원할 수 도 있지만, 소프트웨어가 아닌 인프라 설정을 이용해서 적용이 가능한데, envoryproxy 를 이용하면 코드 변경 없이 모든 플랫폼에 적용이 가능하다. 다음 글에서는 envoy proxy를 이용하여, circuit breaker를 사용하는 방법에 대해서 알아보도록 한다.

본인은 구글 클라우드의 직원이며, 이 블로그에 있는 모든 글은 회사와 관계 없는 개인의 의견임을 알립니다.

댓글을 달아 주세요

Circuit breaker 패턴을 이용한 장애에 강한 MSA 서비스 구현하기 #1

Circuit breaker와 넷플릭스 Hystrix

조대협 (http://bcho.tistory.com)

MSA에서 서비스간 장애 전파

마이크로 서비스 아키텍쳐 패턴은 시스템을 여러개의 서비스 컴포넌트로 나눠서 서비스 컴포넌트간에 호출하는 개념을 가지고 있다. 이 아키텍쳐는 장점도 많지만 반대로 몇가지 단점을 가지고 있는데 그중에 하나는 하나의 컴포넌트가 느려지거나 장애가 나면 그 장애가난 컴포넌트를 호출하는 종속된 컴포넌트까지 장애가 전파되는 특성을 가지고 있다.


이해를 돕기 위해서 아래 그림을 보자


Service A가 Service B를 호출하는 상황에서 어떤 문제로 인하여 Service B가 응답을 못하거나 또는 응답 속도가 매우 느려진 상황이라고 가정하자. Service A가 Service B에 대한 호출 시도를 하면, Service A에서 Service B를 호출한 쓰레드는 응답을 받지 못하기 때문에, 계속 응답을 기다리는 상태로 잡혀있게 된다. 지속해서 Service A가 Service B를 호출을 하게 되면 앞과 같은 원리로 각 쓰레드들이 응답을 기다리는 상태로 변하게 되고 결과적으로는 남은 쓰레드가 없어서 다른 요청을 처리할 수 없는 상태가 된다.

이렇게 Service B의 장애가 Service A에 영향을 주는 경우를 장애가 전파 되었다고 한다. 이 상황에서 Service A를 호출하는 서비스가 또 있다면, 같은 원리로 인하여 그 서비스까지 장애가 전파되서 전체 시스템이 장애 상태로 빠질 수 있다.

Circuit breaker 패턴

이런 문제를 해결하는 디자인 패턴이 Circuit breaker 라는 패턴이 있다.

기본적인 원리는 다음과 같다. 서비스 호출 중간 즉 위의 예제에서는 Service A와 Service B에 Circuit Breaker를 설치한다. Service B로의 모든 호출은 이 Circuit Breaker를 통하게 되고 Service B가 정상적인 상황에서는 트래픽을 문제 없이 bypass 한다.

.


만약에 Service B가 문제가 생겼음을 Circuit breaker가 감지한 경우에는 Service B로의 호출을 강제적으로 끊어서 Service A에서 쓰레드들이 더 이상 요청을 기다리지 않도록 해서 장애가 전파하는 것을 방지 한다. 강제적으로 호출을 끊으면 에러 메세지가 Service A에서 발생하기 때문에 장애 전파는 막을 수 있지만, Service A에서 이에 대한 장애 처리 로직이 별도로 필요하다.

이를 조금 더 발전 시킨것이 Fall-back 메시징인데, Circuit breaker에서 Service B가 정상적인 응답을 할 수 없을 때, Circuit breaker가 룰에 따라서 다른 메세지를 리턴하게 하는 방법이다.



예를 들어 Service A가 상품 목록을 화면에 뿌려주는 서비스이고, Service B가 사용자에 대해서 머신러닝을 이용하여 상품을 추천해주는 서비스라고 했을때, Service B가 장애가 나면 상품 추천을 해줄 수 없다.

이때 상품 진열자 (MD)등이 미리 추천 상품 목록을 설정해놓고, Service B가 장애가 난 경우 Circuit breaker에서 이 목록을 리턴해주게 하면 머신러닝 알고리즘 기반의 상품 추천보다는 정확도는 낮아지지만 최소한 시스템이 장애가 나는 것을 방지 할 수 있고 다소 낮은 확률로라도 상품을 추천하여 꾸준하게 구매를 유도할 수 있다.


이 패턴은 넷플릭스에서 자바 라이브러리인 Hystrix로 구현이 되었으며, Spring 프레임웍을 통해서도 손쉽게 적용할 수 있다.

이렇게 소프트웨어 프레임웍 차원에서 적용할 수 있는 방법도 있지만 인프라 차원에서 Circuit breaker를 적용하는 방법도 있는데, envoy.io 라는 프록시 서버를 이용하면 된다.

소프트웨어를 사용하는 경우 관리 포인트가 줄어드는 장점은 있지만, 코드를 수정해야 하는 단점이 있고, 프로그래밍 언어에 따른 종속성이 있다.

반대로 인프라적인 접근의 경우에는 코드 변경은 필요 없으나, Circuit breaker용 프록시를 관리해야하는 추가적인 운영 부담이 늘어나게 된다.


이 글에서는 넷플릭스의 Hystrix, Spring circuit breaker를 이용한 소프트웨어적인 접근 방법과 envoy.io를 이용한 인프라적인 접근 방법 양쪽을 모두 살펴보기로 한다.


넷플릭스 Hystrix

넷플릭스는 MSA를 잘 적용하고 있는 기업이기도 하지만, 적용되어 있는 MSA 디자인 패턴 기술들을 오픈소스화하여 공유하는 것으로도 유명하다. Hystrix는 그중에서 Circuit breaker 패턴을 자바 기반으로 오픈소스화한 라이브러리이다.  


Circuit breaker 자체를 구현한것 뿐만 아니라, 각 서비스의 상태를 한눈에 알아볼 수 있도록 대쉬보드를 같이 제공한다.


Hystrix 라이브러리 사용방법

Hystrix를 사용하기 위해서는 pom.xml에 다음과 같이 라이브러리 의존성을 추가해야 한다.

<dependency>

<groupId>com.netflix.hystrix</groupId>

<artifactId>hystrix-core</artifactId>

<version>1.5.4</version>

</dependency>

<dependency>

<groupId>com.netflix.rxjava</groupId>

<artifactId>rxjava-core</artifactId>

<version>0.20.7</version>

</dependency>


Circuit breaker는 Hystrix 내에서 Command 디자인 패턴으로 구현된다. 먼저 아래 그림과 같이 HystrixCommand 클래스를 상속받은 Command 클래스를 정의한 후에, run() 메서드를 오버라이드하여, run 안에 실제 명령어를 넣으면 된다. HystrixCommand 클래스를 상속받을때 runI()메서드에서 리턴값으로 사용할 데이타 타입을 <>에 정의한다.


public class CommandHelloWorld extends HystrixCommand<String>{

private String name;

CommandHelloWorld(String name){

super(HystrixCommandGroupKey.Factory.asKey("ExampleGroup"));

this.name = name;

}

@Override

protected String run() {

return "Hello" + name +"!";

}


이렇게 Command가 정의되었으면 호출 방법은 아래와 같다.


CommandHelloWorld helloWorldCommand = new CommandHelloWorld("World");

assertEquals("Hello World", helloWorldCommand.execute());


먼저 Command 클래스의 객체를 생성한 다음에, 객체.execute()를 이용해서 해당 command 를 실행하면 된다. 이렇게 하면, Command 클래스가 응답을 제대로 받지 못할때는 Circuit Breaker를 이용하여 연결을 강제적으로 끊고 에러 메세지등을 리턴하도록 된다.


전체 코드 샘플은 https://github.com/bwcho75/msa_pattern_sample/tree/master/hystrix 를 참고하기 바란다.

웹서비스에 적용하는 방법

대략적인 개념을 이해하였으면 실제로 이 패턴을 REST API로 구성된 MSA 기반의 서비스에 적용해보자.

두 개의 서비스 User와 Item이 있다고 가정하자 User 서비스가 REST API 호출을 이용하여 Item 서비스를 호출하는 구조라고 할때 이 User → Item 서비스로의 호출을 HystrixCommand를 이용하여 Circuit breaker로 구현해보도록 하자.


User 서비스의 전체 코드는 https://github.com/bwcho75/msa_pattern_sample/tree/master/UserService , Item 서비스의 전체코드는 https://github.com/bwcho75/msa_pattern_sample/tree/master/ItemService 에 있다.

각 코드는 Spring Web을 이용하여 구현되었으며 User → Item으로의 호출을 resttemplate을 이용하였다.


User → Item 서비스를 호출하여 해당 사용자에 속한 Item 목록을 읽어오는 Command를 GetCommand라고 하자, 코드는 대략 아래와 같다.


public class GetItemCommand extends HystrixCommand<List<User>>{

String name;

public GetItemCommand(String name) {

super(HystrixCommandGroupKey.Factory.asKey("ItemServiceGroup"));

this.name = name;

}


@Override

protected List<User> run() throws Exception {

List<User> usersList = new ArrayList<User>();

// call REST API

                                                (생략)

return usersList;

}

@Override

protected List<User> getFallback(){

List<User> usersList = new ArrayList<User>();

usersList.add(new User(name,"myemail@mygoogle.com"));

return usersList;

}

}


리턴 값이 List<User>이기 때문에, HystrixCommand <List<User>>를 상속하여 구현하였고, Item 서비스를 호출하는 부분은 run() 메서드에 구현한다. (restTemplate을 이용하여 호출하는 내용은 생략하였다.)


여기서 주목해야할 부분은 getFallBack() 함수인데, 호출되는 서비스 Item이 장애 일때는 이를 인지하고 getFallBack의 리턴값을 fallback 메세지로 호출한다.


Item과 User 서비스를 각각 실행한다.

%java -jar ./target/User-0.0.1-SNAPSHOT.jar

%java -jar ./target/Item-0.0.1-SNAPSHOT.jar


두 서비스를 실행 한후에 아래와 같이 User 서비스를 호출하면 다음과 같이 ItemList가 채워져서 정상적으로 리턴되는 것을 볼 수 있다.


terrycho-macbookpro:~ terrycho$ curl localhost:8081/users/terry

[{"name":"terry","email":"myemail@mygoogle.com","itemList":[{"name":"computer","qtetertertertertetttt


Item 서비스 서버를 인위적으로 죽인 상태에서 호출을 하면 다음과 같이 위에서 정의한 fall back 메세지와 같이 email이 “myemail@mygoogle.com”으로 호출되고 itemList는 비어 있는채로 리턴이 된다.


terrycho-macbookpro:~ terrycho$ curl localhost:8081/users/terry

[{"name":"terry","email":"myemail@mygoogle.com","itemList":[]}]


지금까지 간단하게나마 Circuit breaker 패턴과 넷플릭스의 Hystrix 오픈소스를 이용하여 Circuit breaker를 구현하는 방법에 대해서 알아보았다.

서비스 상태에 따라서 Circuit을 차단하는 방법등도 다양하고, Command 패턴을 처리하는 방법 (멀티 쓰레드, 세마포어 방식)등이 다양하기 때문에, 자세한 내부 동작 방법 및 구현 가이드는 https://github.com/Netflix/Hystrix/wiki/How-it-Works 를 참고하기 바란다.


Circuit breaker 패턴은 개인적인 생각에서는 MSA에서는 거의 필수적으로 적용해야 하는 패턴이라고 생각을 하지만 Hystrix를 이용하면 Command를 일일이 작성해야 하고, 이로 인해서 코드 복잡도가 올라갈 수 있다. 이를 간소화 하기 위해서 Spring 오픈소스에 이 Hystrix를 잘 추상화 해놓은 기능이 있는데, 그 부분 구현에 대해서는 다음글을 통해서 살펴보도록 한다.



본인은 구글 클라우드의 직원이며, 이 블로그에 있는 모든 글은 회사와 관계 없는 개인의 의견임을 알립니다.

댓글을 달아 주세요

  1. Jodu 2019.09.19 17:46  댓글주소  수정/삭제  댓글쓰기

    관리자의 승인을 기다리고 있는 댓글입니다

Object Detection API를 이용하여 커스텀 데이타 학습하기

얼굴인식 모델 만들기


조대협 (http://bcho.tistory.com)


이번글에서는 Tensorflow Object Detection API를 이용하여 직접 이미지를 인식할 수 있는 방법에 대해서 알아보자. 이미 가지고 있는 데이타를 가지고 다양한 상품에 대한 인식이나, 사람 얼굴에 대한 인식 모델을 머신러닝에 대한 전문적인 지식 없이도 손쉽게 만들 수 있다.


Object Detection API 설치

Object Detection API 설치는 http://bcho.tistory.com/1193http://bcho.tistory.com/1192 에서 이미 다뤘기 때문에 별도로 언급하지 않는다.

학습용 데이타 데이타 생성 및 준비

Object Detection API를 학습 시키기 위해서는 http://bcho.tistory.com/1193 예제와 같이 TFRecord 형태로 학습용 파일과 테스트용 파일이 필요하다. TFRecord 파일 포맷에 대한 설명은 http://bcho.tistory.com/1190 를 참고하면 된다.


이미지 파일을 TFRecord로 컨버팅하는 전체 소스 코드는 https://github.com/bwcho75/objectdetection/blob/master/custom/create_face_data.py 를 참고하기 바란다.

구글 클라우드 VISION API를 이용하여,얼굴이 있는지 여부를 파악하고, 얼굴 각도가 너무 많이 틀어진 경우에는 필터링 해낸후에,  얼굴의 위치 좌표를 추출하여 TFRecord 파일에 쓰는 흐름이다.

VISION API를 사용하기 때문에 반드시 서비스 어카운트 (Service Account/JSON 파일)를 구글 클라우드 콘솔에서 만들어서 설치하고 실행하기 바란다.


사용 방법은

python create_face_data.py {이미지 소스 디렉토리} {이미지 아웃풋 디렉토리} {TFRECORD 파일명}


형태로 사용하면 된다.

예) python ./custom/create_face_data.py /Users/terrycho/trainingdata_source /Users/terrycho/trainingdata_out


{이미지 소스 디렉토리} 구조는 다음과 같다.

{이미지 소스 디렉토리}/{라벨1}

{이미지 소스 디렉토리}/{라벨2}

{이미지 소스 디렉토리}/{라벨3}

:

예를 들어

/Users/terrycho/trainingdata_source/Alba

/Users/terrycho/trainingdata_source/Jessica

/Users/terrycho/trainingdata_source/Victoria

:

이런식이 된다.



명령을 실행하면, {이미지 아웃풋 디렉토리} 아래

  • 학습 파일은 face_training.record

  • 테스트 파일은 face_evaluation.record

  • 라벨맵은 face_label_map.pbtxt

로 생성된다. 이 세가지 파일이 Object Detection API를 이용한 학습에 필요하고 부가적으로 생성되는  csv 파일이 있는데

  • all_files.csv : 소스 디렉토리에 있는 모든 이미지 파일 목록

  • filtered_files.csv : 각 이미지명과, 라벨, 얼굴 위치 좌표 (사각형), 이미지 전체 폭과 높이

  • converted_result_files.csv : filtered_files에 있는 이미지중, 얼굴의 각도등이 이상한 이미지를 제외하고 학습과 테스트용 데이타 파일에 들어간 이미지 목록으로, 이미지 파일명, 라벨 (텍스트), 라벨 (숫자), 얼굴 좌표 (사각형) 을 저장한다.


여기서 사용한 코드는 간단한 테스트용 코드로, 싱글 쓰레드에 싱글 프로세스 모델로 대규모의 이미지를 처리하기에는 적절하지 않기 때문에, 운영환경으로 올리려면, Apache Beam등 분산 프레임웍을 이용하여 병렬 처리를 하는 것을 권장한다. http://bcho.tistory.com/1177 를 참고하기 바란다.


여기서는 학습하고자 하는 이미지의 바운드리(사각형 경계)를 추출하는 것을 VISION API를 이용해서 자동으로 했지만, 일반적인 경우는 이미지에서 각 경계를 수동으로 추출해서 학습데이타로 생성해야 한다




이런 용도로 사용되는 툴은 https://medium.com/towards-data-science/how-to-train-your-own-object-detector-with-tensorflows-object-detector-api-bec72ecfe1d9 문서에 따르면 FastAnnotationTool이나 ImageMagick 과 같은 툴을 추천하고 있다.



이렇게 학습용 파일을 생성하였으면 다음 과정은 앞의  http://bcho.tistory.com/1193 에서 언급한 절차와 크게 다르지 않다.

체크포인트 업로드

학습 데이타가 준비 되었으면 학습을 위한 준비를 하는데, 트랜스퍼 러닝 (Transfer learning)을 위해서 기존의 학습된 체크포인트 데이타를 다운 받아서 이를 기반으로 학습을 한다.

Tensorflow Object Detection API는 경량이고 단순한 모델에서 부터 정확도가 비교적 높은 복잡한 모델까지 지원하고 있지만, 복잡도가 높다고 해서 정확도가 꼭 높지는 않을 수 있다. 복잡한 모델일 수 록 학습 데이타가 충분해야 하기 때문에, 학습하고자 하는 데이타의 양과 클래스의 종류에 따라서 적절한 모델을 선택하기를 권장한다.


여기서는 faster_rcnn_inception_resnet_v2 모델을 이용했기 때문에 아래와 같이 해당 모델의 체크포인트 데이타를 다운로드 받는다.


curl -O http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_resnet_v2_atrous_coco_11_06_2017.tar.gz


파일의 압축을 푼 다음 체크 포인트 파일을 학습 데이타용 Google Cloud Storage (GCS) 버킷으로 업로드 한다.

gsutil cp faster_rcnn_inception_resnet_v2_atrous_coco_11_06_2017/model.ckpt.* gs://${YOUR_GCS_BUCKET}/data/





설정 파일 편집 및 업로드

다음 학습에 사용할 모델의 설정을 해야 하는데,  object_detection/samples/configs/ 디렉토리에 각 모델별 설정 파일이 들어 있으며, 여기서는 faster_rcnn_inception_resnet_v2_atrous_pets.config 파일을 사용한다.


이 파일에서 수정해야 하는 부분은 다음과 같다.

클래스의 수

클래스 수를 정의한다. 이 예제에서는 총 5개의 클래스로 분류를 하기 때문에 아래와 같이 5로 변경하였다.

 8 model {

 9   faster_rcnn {

10     num_classes: 5

11     image_resizer {

학습 데이타 파일 명 및 라벨명

학습에 사용할 학습데이타 파일 (tfrecord)와 라벨 파일명을 지정한다.

126 train_input_reader: {

127   tf_record_input_reader {

128     input_path: "gs://terrycho-facedetection/data/face_training.record"

129   }

130   label_map_path: "gs://terrycho-facedetection/data/face_label_map.pbtxt"

131 }


테스트 데이타 파일명 및 라벨 파일명

학습후 테스트에 사용할 테스트 파일 (tfrecord)과 라벨 파일명을 지정한다

140 eval_input_reader: {

141   tf_record_input_reader {

142     input_path: "gs://terrycho-facedetection/data/face_evaluation.record"

143   }

144   label_map_path: "gs://terrycho-facedetection/data/face_label_map.pbtxt"

145   shuffle: false

146   num_readers: 1


만약에 학습 횟수(스탭)을 조정하고 싶으면 num_steps 값을 조정한다. 디폴트 설정은 20만회인데, 여기서는 5만회로 수정하였다.

117   # never decay). Remove the below line to train indefinitely.
118   # num_steps: 200000
119   num_steps: 50000
120   data_augmentation_options {
121     random_horizontal_flip {
122     }


설정 파일 수정이 끝났으면 gsutil cp 명령을 이용하여 해당 파일을 GCS 버킷에 다음과 같이 업로드 한다.

gsutil cp object_detection/samples/configs/faster_rcnn_inception_resnet_v2_atrous_pets.config gs://${YOUR_GCS_BUCKET}/data/faster_rcnn_inception_resnet_v2_atrous_pets.config

코드 패키징

models/ 디렉토리에서 다음 명령을 수행하여, 모델 코드를 패키징한다.

python setup.py sdist

(cd slim && python setup.py sdist)



학습


gcloud ml-engine jobs submit training `whoami`_object_detection_`date +%s` \

   --job-dir=gs://${YOUR_GCS_BUCKET}/train \

   --packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz \

   --module-name object_detection.train \

   --region asia-east1 \

   --config object_detection/samples/cloud/cloud.yml \

   -- \

   --train_dir=gs://${YOUR_GCS_BUCKET}/train \

   --pipeline_config_path=gs://${YOUR_GCS_BUCKET}/data/faster_rcnn_resnet101_pets.config

모니터링

학습이 진행되면 텐서보드를 이용하여 학습 진행 상황을 모니터링할 수 있고, 또한 테스트 트레이닝을 수행하여, 모델에 대한 테스트를 동시 진행할 수 있다. http://bcho.tistory.com/1193 와 방법이 동일하니 참고하기 바란다.


학습을 시작하면 텐서보드를 통해서, Loss 값이 수렴하는 것을 확인할 수 있다.



결과

학습이 끝나면 텐서보드에서 테스트된 결과를 볼 수 있다. 이 예제의 경우 모델을 가장 복잡한 모델을 사용했는데 반하여, 총 5개의 클래스에 대해서 클래스당 약 40개정도의 학습 데이타를 사용했는데, 상대적으로 정확도가 낮았다. 실 서비스에서는 더 많은 데이타를 사용하기를 권장한다.



활용

학습된 모델을 활용하는 방법은 학습된 모델을 export 한후에, (Export 하는 방법은  http://bcho.tistory.com/1193 참고) export 된 모델을 로딩하여, 코드에서 불러서 사용하면 된다.

http://bcho.tistory.com/1192 참고



본인은 구글 클라우드의 직원이며, 이 블로그에 있는 모든 글은 회사와 관계 없는 개인의 의견임을 알립니다.

댓글을 달아 주세요

연예인 얼굴 인식 서비스를 만들어보자 #1 - 데이타 준비하기

 

CNN 에 대한 이론 공부와 텐서 플로우에 대한 기본 이해를 끝내서 실제로 모델을 만들어보기로 하였다.

CNN을 이용한 이미지 인식중 대중적인 주제로 얼굴 인식 (Face recognition)을 주제로 잡아서, 이 모델을 만들기로 하고 아직 실력이 미흡하여 호주팀에서 일하고 있는 동료인 Win woo 라는 동료에게 모델과 튜토리얼 개발을 부탁하였다.

 

이제 부터 연재하는 연예인 얼굴 인식 서비스는 Win woo 가 만든 코드를 기반으로 하여 설명한다. (코드 원본 주소 : https://github.com/wwoo/tf_face )

 

얼굴 데이타를 구할 수 있는곳

먼저 얼굴 인식 모델을 만들려면, 학습을 시킬 충분한 데이타가 있어야 한다. 사람 얼굴을 일일이 구할 수 도 없고, 구글이나 네이버에서 일일이 저장할 수 도 없기 때문에, 공개된 데이타셋을 활용하였는데, PubFig (Public Figures Face Database - http://www.cs.columbia.edu/CAVE/databases/pubfig/) 를 사용하였다.


 

이 데이타셋에는 약 200명에 대한 58,000여장의 이미지를 저장하고 있는데, 이 중의 일부만을 사용하였다.

Download 페이지로 가면, txt 파일 형태 (http://www.cs.columbia.edu/CAVE/databases/pubfig/download/dev_urls.txt) 로 아래와 같이

 

Abhishek Bachan 1 http://1.bp.blogspot.com/_Y7rzCyUABeI/SNIltEyEnjI/AAAAAAAABOg/E1keU_52aFc/s400/ash_abhishek_365x470.jpg 183,60,297,174 f533da9fbd1c770428c8961f3fa48950
Abhishek Bachan 2 http://1.bp.blogspot.com/_v9nTKD7D57Q/SQ3HUQHsp_I/AAAAAAAAQuo/DfPcHPX2t_o/s400/normal_14thbombaytimes013.jpg 49,71,143,165 e36a8b24f0761ec75bdc0489d8fd570b
Abhishek Bachan 3 http://2.bp.blogspot.com/_v9nTKD7D57Q/SL5KwcwQlRI/AAAAAAAANxM/mJPzEHPI1rU/s400/ERTYH.jpg 32,68,142,178 583608783525c2ac419b41e538a6925d

 

사람이름, 이미지 번호, 다운로드 URL, 사진 크기, MD5 체크섬을 이 필드로 저장되어 있다.

이 파일을 이용하여 다운로드 URL에서 사진을 다운받아서, 사람이름으로된 폴더에 저장한다.

물론 수동으로 할 수 없으니 HTTP Client를 이용하여, URL에서 사진을 다운로드 하게 하고, 이를 사람이름 폴더 별로 저장하도록 해야 한다.

 

HTTP Client를 이용하여 파일을 다운로드 받는 코드는 일반적인 코드이기 때문에 별도로 설명하지 않는다.

본인의 경우에는 Win이 만든 https://github.com/wwoo/tf_face/blob/master/tf/face_extract/pubfig_get.py 코드를 이용하여 데이타를 다운로드 받았다.

사용법은  https://github.com/wwoo/tf_face 에 나와 있는데,

 

$> python tf/face_extract/pubfig_get.py tf/face_extract/eval_urls.txt ./data

를 실행하면 ./data 디렉토리에 이미지를 다운로드 받아서 사람 이름별 폴더에 저장해준다.

evals_urls.txt에는 위에서 언급한 dev_urls.txt 형태의 데이타가 들어간다.


사람 종류가 너무 많으면 데이타를 정재하는 작업이 어렵고, (왜 어려운지는 뒤에 나옴) 학습 시간이 많이 걸리기 때문에, 약 47명의 데이타를 다운로드 받아서 작업하였다.

학습 데이타 준비에 있어서 경험

쓰레기 데이타 골라내기

데이타를 다운받고 나니, 아뿔사!! PubFig 데이타셋이 오래되어서 없는 이미지도 있고 학습에 적절하지 않은 이미지도 있다.


주로 학습에 적절하지 않은 데이타는 한 사진에 두사람 이상의 얼굴이 있거나, 이미지가 사라져서 위의 우측 그림처럼, 이미지가 없는 형태로 나오는 경우인데, 이러한 데이타는 어쩔 수 없이 눈으로 한장한장 다 걸러내야만 했는데, 이런 간단한 데이타 필터링 처리는 Google Cloud Vision API를 이용하여, 얼굴이 하나만 있는 사진만을 사용하도록 하여 필터링을 하였다.

학습 데이타의 분포

처음에 학습을 시작할때, 분류별로 데이타의 수를 다르게 하였다. 어렵게 모은 데이타를 버리기가 싫어서 모두 다 넣고 학습 시켰는데, 그랬더니 학습이 쏠리는 현상이 발생하였다.

예를 들어 안젤리나 졸리 300장, 브래드피트 100장, 제시카 알바 100장 이런식으로 학습을 시켰더니, 이미지 예측에서 안젤리나 졸리로 예측하는 경우가 많아졌다. 그래서 학습을 시킬때는 데이타수가 작은 쪽으로 맞춰서 각 클래스당 학습 데이타수가 같도록 하였다. 즉 위의 데이타의 경우에는 안젤리나 졸리 100장, 브래드피트 100장, 제시카 알바 100장식으로 데이타 수를 같게 해야했다.

라벨은 숫자로

라벨의 가독성을 높이기 위해서 라벨을 영문 이름으로 사용했는데, CNN 알고리즘에서 최종 분류를 하는 알고리즘은 softmax 로 그 결과 값을 0,1,2…,N식으로 라벨을 사용하기 때문에, 정수형으로 변환을 해줘야 하는데, 텐서 플로우 코드에서는 이게 그리 쉽지않았다. 그래서 차라리 처음 부터 학습 데이타를 만들때는 라벨을 정수형으로 만드는것이 더 효과적이다

얼굴 각도, 표정,메이크업, 선글라스 도 중요하다

CNN 알고리즘을 마법처럼 생각해서였을까? 데이타만 있다면 어떻게든 학습이 될 줄 알았다. 그러나 얼굴의 각도가 많이 다르거나 표정이 심하게 차이가 난 경우에는 다른 사람으로 인식이 되기 때문에 가능하면 비슷한 표정에 비슷한 각도의 사진으로 학습 시키는 것이 정확도를 높일 수 있다.


 

얼굴 각도의 경우 구글 클라우드 VISION API를 이용하면 각도를 추출할 수 있기 때문에 20도 이상 차이가 나는 사진은 필터링 하였고, 표정 부분도 VISION API를 이용하면 감정도를 분석할 수 있기 때문에 필터링이 가능하다. (아래서 설명하는 코드에서는 감정도 분석 부분은 적용하지 않았다)

또한 선글라스를 쓴 경우에도 다른 사람으로 인식할 수 있기 때문에 VISION API에서 물체 인식 기능을 이용하여 선글라스가 검출된 경우에는 학습 데이타에서 제거하였다.

이외에도 헤어스타일이나 메이크업이 심하게 차이가 나는 경우에는 다른 사람으로 인식되는 확률이 높기 때문에 이런 데이타도 가급적이면 필터링을 하는것이 좋다.

웹 크라울링의 문제점

데이타를 쉽게 수집하려고 웹 크라울러를 이용해서 구글 이미지 검색에서 이미지를 수집해봤지만, 정확도는 매우 낮게 나왔다.


 

https://www.youtube.com/watch?v=k5ioaelzEBM

<그림. 설현 얼굴을 웹 크라울러를 이용하여 수집하는 화면>

 

아래는 웹 크라울러를 이용하여 EXO 루한의 사진을 수집한 결과중 일부이다.


웹크라울러로 수집한 데이타는, 앞에서 언급한 쓰레기 데이타들이 너무 많다. 메이크업, 표정, 얼굴 각도, 두명 이상 있는 사진들이 많았고, 거기에 더해서 그 사람이 아닌 사람의 얼굴 사진까지 같이 수집이 되는 경우가 많았다.

웹 크라울링을 이용한 학습 데이타 수집은 적어도 얼굴 인식용 데이타 수집에 있어서는 좋은 방법은 아닌것 같다. 혹여나 웹크라울러를 사용하더라도 반드시 수동으로 직접 데이타를 검증하는 것이 좋다.

학습 데이타의 양도 중요하지만 질도 매우 중요하다

아이돌 그룹인 EXO와 레드벨벳의 사진을 웹 크라울러를 이용해서 수집한 후에 학습을 시켜보았다. 사람당 약 200장의 데이타로 8개 클래스 정도를 테스트해봤는데 정확도가 10%가 나오지를 않았다.

대신 데이타를 학습에 좋은 데이타를 일일이 눈으로 확인하여 클래스당 30장 정도를 수집해서 학습 시킨 결과 60% 정도의 정확도를 얻을 수 있었다.  양도 중요하지만 학습 데이타의 질적인 면도 중요하다.

중복데이타 처리 문제

데이타를 수집해본 결과, 중복되는 데이타가 생각보다 많았다. 중복 데이타를 걸러내기 위해서 파일의 MD5 해쉬 값을 추출해낸 후 이를 비교해서 중복되는 파일을 제거하였는데, 어느정도 효과를 볼 수 있었지만, 아래 이미지와 같이 같은 이미지지만, 편집이나 리사이즈가 된 이미지의 경우에는 다른 파일로 인식되서 중복 체크에서 검출되지 않았다.


연예인 얼굴 인식은 어렵다

얼굴 인식 예제를 만들면서 재미를 위해서 한국 연예인 얼굴을 수집하여 학습에 사용했는데, 제대로 된 학습 데이타를 구하기가 매우 어려웠다. 앞에서 언급한데로 메이크업이나 표정 변화가 너무 심했고, 어렸을때나 나이먹었을때의 차이등이 심했다. 간단한 공부용으로 사용하기에는 좋은 데이타는 아닌것 같다.

그러면 학습에 좋은 데이타는?

그러면 얼굴 인식 학습에 좋은 데이타는 무엇일까? 테스트를 하면서 내린 자체적인 결론은 정면 프로필 사진류가 제일 좋다. 특히 스튜디오에서 찍은 사진은 같은 조명에 같은 메이크업과 헤어스타일로 찍은 경우가 많기 때문에 학습에 적절하다. 또는 동영상의 경우에는 프레임을 잘라내면 유사한 표정과 유사한 각도, 조명등에 대한 데이타를 많이 얻을 수 있기 때문에 좋은 데이타 된다.

얼굴 추출하기

그러면 앞의 내용을 바탕으로 해서, 적절한 학습용 얼굴 이미지를 추출하는 프로그램을 만들어보자

포토샵으로 일일이 할 수 없기 때문에 얼굴 영역을 인식하는 API를 사용하기로한다. OPEN CV와 같은 오픈소스 라이브러리를 사용할 수 도 있지만 구글의 VISION API의 경우 얼굴 영역을 아주 잘 잘라내어주고,  얼굴의 각도나 표정을 인식해서 필터링 하는 기능까지 코드 수십줄만 가지고도 구현이 가능했기 때문에, VISION API를 사용하였다. https://cloud.google.com/vision/

VISION API ENABLE 하기

VISION API를 사용하기 위해서는 해당 구글 클라우드 프로젝트에서 VISION API를 사용하도록 ENABLE 해줘야 한다.

VISION API를 ENABLE하기 위해서는 아래 화면과 같이 구글 클라우드 콘솔 > API Manager 들어간후


 

+ENABLE API를 클릭하여 아래 그림과 같이 Vision API를 클릭하여 ENABLE 시켜준다.

 



 

SERVICE ACCOUNT 키 만들기

다음으로 이 VISION API를 호출하기 위해서는 API 토큰이 필요한데, SERVICE ACCOUNT 라는 JSON 파일을 다운 받아서 사용한다.

구글 클라우드 콘솔에서 API Manager로 들어간후 Credentials 메뉴에서 Create creadential 메뉴를 선택한후, Service account key 메뉴를 선택한다


 

다음 Create Service Account key를 만들도록 하고, accountname과 id와 같은 정보를 넣는다. 이때 중요한것이 이 키가 가지고 있는 사용자 권한을 설정해야 하는데, 편의상 모든 권한을 가지고 있는  Project Owner 권한으로 키를 생성한다.

 

(주의. 실제 운영환경에서 전체 권한을 가지는 키는 보안상의 위험하기 때문에 특정 서비스에 대한 접근 권한만을 가지도록 지정하여 Service account를 생성하기를 권장한다.)

 


 

Service account key가 생성이 되면, json 파일 형태로 다운로드가 된다.

여기서는 terrycho-ml-80abc460730c.json 이름으로 저장하였다.

 

예제 코드

그럼 예제를 보자 코드의 전문은 https://github.com/bwcho75/facerecognition/blob/master/com/terry/face/extract/crop_face.py 에 있다.

 

이 코드는 이미지 파일이 있는 디렉토리를 지정하고, 아웃풋 디렉토리를 지정해주면 이미지 파일을 읽어서 얼굴이 있는지 없는지를 체크하고 얼굴이 있으면, 얼굴 부분만 잘라낸 후에, 얼굴 사진을 96x96 사이즈로 리사즈 한후에,

70%의 파일들은 학습용으로 사용하기 위해서 {아웃풋 디렉토리/training/} 디렉토리에 저장하고

나머지 30%의 파일들은 검증용으로 사용하기 위해서 {아웃풋 디렉토리/validate/} 디렉토리에 저장한다.

 

그리고 학습용 파일 목록은 다음과 같이 training_file.txt에 파일 위치,사람명(라벨) 형태로 저장하고

/Users/terrycho/traning_datav2/training/wsmith.jpg,Will Smith

/Users/terrycho/traning_datav2/training/wsmith061408.jpg,Will Smith

/Users/terrycho/traning_datav2/training/wsmith1.jpg,Will Smith

 

검증용 파일들은 validate_file.txt에 마찬가지로  파일위치와, 사람명(라벨)을 저장한다.

사용 방법은 다음과 같다.

python com/terry/face/extract/crop_face.py “원본 파일이있는 디렉토리" “아웃풋 디렉토리"

(원본 파일 디렉토리안에는 {사람이름명} 디렉토리 아래에 사진들이 쭈욱 있는 구조라야 한다.)

 

자 그러면, 코드의 주요 부분을 살펴보자

 

VISION API 초기화 하기

  def __init__(self):

       # initialize library

       #credentials = GoogleCredentials.get_application_default()

       scopes = ['https://www.googleapis.com/auth/cloud-platform']

       credentials = ServiceAccountCredentials.from_json_keyfile_name(

                       './terrycho-ml-80abc460730c.json', scopes=scopes)

       self.service = discovery.build('vision', 'v1', credentials=credentials)

 

초기화 부분은 Google Vision API를 사용하기 위해서 OAuth 인증을 하는 부분이다.

scope를 googleapi로 정해주고, 인증 방식을 Service Account를 사용한다. credentials 부분에 service account key 파일인 terrycho-ml-80abc460730c.json를 지정한다.

 

얼굴 영역 찾아내기

다음은 이미지에서 얼굴을 인식하고, 얼굴 영역(사각형) 좌표를 리턴하는 함수를 보자

 

   def detect_face(self,image_file):

       try:

           with io.open(image_file,'rb') as fd:

               image = fd.read()

               batch_request = [{

                       'image':{

                           'content':base64.b64encode(image).decode('utf-8')

                           },

                       'features':[

                           {

                           'type':'FACE_DETECTION',

                           'maxResults':MAX_FACE,

                           },

                           {

                           'type':'LABEL_DETECTION',

                           'maxResults':MAX_LABEL,

                           }

                                   ]

                       }]

               fd.close()

       

           request = self.service.images().annotate(body={

                           'requests':batch_request, })

           response = request.execute()

           if 'faceAnnotations' not in response['responses'][0]:

                print('[Error] %s: Cannot find face ' % image_file)

                return None

               

           face = response['responses'][0]['faceAnnotations']

           label = response['responses'][0]['labelAnnotations']

           

           if len(face) > 1 :

               print('[Error] %s: It has more than 2 faces in a file' % image_file)

               return None

           

           roll_angle = face[0]['rollAngle']

           pan_angle = face[0]['panAngle']

           tilt_angle = face[0]['tiltAngle']

           angle = [roll_angle,pan_angle,tilt_angle]

           

           # check angle

           # if face skew angle is greater than > 20, it will skip the data

           if abs(roll_angle) > MAX_ROLL or abs(pan_angle) > MAX_PAN or abs(tilt_angle) > MAX_TILT:

               print('[Error] %s: face skew angle is big' % image_file)

               return None

           

           # check sunglasses

           for l in label:

               if 'sunglasses' in l['description']:

                 print('[Error] %s: sunglass is detected' % image_file)  

                 return None

           

           box = face[0]['fdBoundingPoly']['vertices']

           left = box[0]['x']

           top = box[1]['y']

               

           right = box[2]['x']

           bottom = box[2]['y']

               

           rect = [left,top,right,bottom]

               

           print("[Info] %s: Find face from in position %s and skew angle %s" % (image_file,rect,angle))

           return rect

       except Exception as e:

           print('[Error] %s: cannot process file : %s' %(image_file,str(e)) )

           

 

 

맨 처음에는 얼굴 영역을 추출하기전에, 같은 파일이 예전에 사용되었는지를 확인한다.

           image = Image.open(fd)  

 

           # extract hash from image to check duplicated image

           m = hashlib.md5()

           with io.BytesIO() as memf:

               image.save(memf, 'PNG')

               data = memf.getvalue()

               m.update(data)

 

           if image_hash in global_image_hash:

               print('[Error] %s: Duplicated image' %(image_file) )

               return None

           global_image_hash.append(image_hash)

 

이미지에서 md5 해쉬를 추출한후에, 이 해쉬를 이용하여 학습 데이타로 사용된 파일들의 해쉬와 비교한다. 만약에 중복되는 것이 없으면 이 해쉬를 리스트에 추가하고 다음 과정을 수행한다.

 

VISION API를 이용하여, 얼굴 영역을 추출하는데, 위의 코드에서 처럼 image_file을 읽은후에, batch_request라는 문자열을 만든다. JSON 형태의 문자열이 되는데, 이때 image라는 항목에 이미지 데이타를 base64 인코딩 방식으로 인코딩해서 전송한다. 그리고 VISION API는 얼굴인식뿐 아니라 사물 인식, 라벨인식등 여러가지 기능이 있기 때문에 그중에서 타입을 ‘FACE_DETECTION’으로 정의하여 얼굴 영역만 인식하도록 한다.

 

request를 만들었으면, VISION API로 요청을 보내면 응답이 오는데, 이중에서 response 엘리먼트의 첫번째 인자 ( [‘responses’][0] )은 첫번째 얼굴은 뜻하는데, 여기서 [‘faceAnnotation’]을 하면 얼굴에 대한 정보만을 얻을 수 있다. 이중에서  [‘fdBoundingPoly’] 값이 얼굴 영역을 나타내는 사각형이다. 이 갑ㄱㅅ을 읽어서 left,top,right,bottom 값에 세팅한 후 리턴한다.

 

그리고 얼굴의 각도 (상하좌우옆)를 추출하여, 얼국 각도가 각각 20도 이상 더 돌아간 경우에는 학습 데이타로 사용하지 않고 필터링을 해냈다.

다음은 각도를 추출하고 필터링을 하는 부분이다.

           roll_angle = face[0]['rollAngle']

           pan_angle = face[0]['panAngle']

           tilt_angle = face[0]['tiltAngle']

           angle = [roll_angle,pan_angle,tilt_angle]

           

           # check angle

           # if face skew angle is greater than > 20, it will skip the data

           if abs(roll_angle) > MAX_ROLL or abs(pan_angle) > MAX_PAN or abs(tilt_angle) > MAX_TILT:

               print('[Error] %s: face skew angle is big' % image_file)

               return None

 

 

VISION API에서 추가로 “FACE DETECTION” 뿐만 아니라 “LABEL_DETECTION” 을 같이 수행했는데 이유는 선글라스를 쓰고 있는 사진을 필터링하기 위해서 사용하였다. 아래는 선글라스 있는 사진을 검출하는  코드이다.

           # check sunglasses

           for l in label:

               if 'sunglasses' in l['description']:

                 print('[Error] %s: sunglass is detected' % image_file)  

                 return None

 

얼굴 잘라내고 리사이즈 하기

앞의 detect_face에서 필터링하고 찾아낸 얼굴 영역을 가지고 그 부분만 전체 사진에서 잘라내고, 잘라낸 얼굴을 학습에 적합하도록 같은 크기 (96x96)으로 리사이즈 한다.

이런 이미지 처리를 위해서 PIL (Python Imaging Library - http://www.pythonware.com/products/pil/)를 사용하였다.

   def crop_face(self,image_file,rect,outputfile):

       try:

           fd = io.open(image_file,'rb')

           image = Image.open(fd)  

           crop = image.crop(rect)

           im = crop.resize(IMAGE_SIZE,Image.ANTIALIAS)

           im.save(outputfile,"JPEG")

           fd.close()

           print('[Info] %s: Crop face %s and write it to file : %s' %(image_file,rect,outputfile) )

       except Exception as e:

           print('[Error] %s: Crop image writing error : %s' %(image_file,str(e)) )

image_file을 인자로 받아서 , rect 에 정의된 사각형 영역 만큼 crop를 해서 잘라내고, resize 함수를 이용하여 크기를 96x96으로 조정한후 (참고 IMAGE_SIZE = 96,96 로 정의되어 있다.) outputfile 경로에 저장하게 된다.        

 

실행을 해서 정재된 데이타는 다음과 같다.



  

생각해볼만한점들

이 코드는 간단한 토이 프로그램이기 때문에 간단하게 작성했지만 실제 운영환경에 적용하기 위해서는 몇가지 고려해야 할 사항이 있다.

먼저, 이 코드는 싱글 쓰레드로 돌기 때문에 속도가 상대적으로 느리다 그래서 멀티 쓰레드로 코드를 수정할 필요가 있으며, 만약에 수백만장의 사진을 정재하기 위해서는 한대의 서버로 되지 않기 때문에, 원본 데이타를 여러 서버로 나눠서 처리할 수 있는 분산 처리 구조가 고려되어야 한다.

또한, VISION API로 사진을 전송할때는 BASE64 인코딩된 구조로 서버에 이미지를 직접 전송하기 때문에, 자칫 이미지 사이즈들이 크면 네트워크 대역폭을 많이 잡아먹을 수 있기 때문에 가능하다면 식별이 가능한 크기에서 리사이즈를 한 후에, 서버로 전송하는 것이 좋다. 실제로 필요한 얼굴 크기는 96x96 픽셀이기 때문에 필요없이 1000만화소 고화질의 사진들을 전송해서 네트워크 비용을 낭비하지 않기를 바란다.

 

다음은 이렇게 정재한 파일들을 텐서플로우에서 읽어서 실제로 학습하는 모델을 만들어보겠다.


위의 코드를 멀티 프로세스&멀티쓰레드로 돌리는 아키텍쳐와 코드는 http://bcho.tistory.com/1177 글을 참고하기 바란다.

 

본인은 구글 클라우드의 직원이며, 이 블로그에 있는 모든 글은 회사와 관계 없는 개인의 의견임을 알립니다.

댓글을 달아 주세요

  1. Yonghan 2017.06.29 10:21  댓글주소  수정/삭제  댓글쓰기

    안녕하세요. 우선 감사드립니다. 딥러닝 관련 포스팅 잘보고있습니다.^^
    다름이아니라 질문이 있습니다.. json key값 코드에 작성 후
    위에 코드만 바로 실행하면 이미지가 정재되어 저장되나요?

  2. 조대협 2017.06.29 10:51 신고  댓글주소  수정/삭제  댓글쓰기

    네 아마도 될겁니다

  3. 2018.11.05 16:44  댓글주소  수정/삭제  댓글쓰기

    비밀댓글입니다

  4. junho 2018.11.14 19:21  댓글주소  수정/삭제  댓글쓰기

    안녕하세요. 이미지 정제 및 cnn 학습을 공부하고 있는 학생입니다!
    google vision api를 이용해 사물을 인식하는 부분을 공부하다가 블로그를 찾게 되었습니다.
    얼굴인식에서 사물인식으로 코드를 조금 수정해서 사용하고 싶은데, 생각보다 제 실력이 부족하다보니 어렵네요...
    사물인식을 할때 label detection으로 하는걸로 알고 있는데, 이 부분만 수정하고, 코드를 돌리면 되나요?

텐서플로우의 세션,그래프 그리고 함수의 개념


조대협 (http://bcho.tistory.com)


그래프와 세션에 대한 개념이 헷갈려서, 좋은 샘플이 하나 만들어져서 공유합니다.

텐서 플로우의 기본 작동 원리는 세션 시작전에 그래프를 정의해놓고, 세션을 시작하면 그 그래프가 실행되는 원리인데, 그래서 이 개념이 일반적인 프로그래밍 개념과 상의하여 헷갈리는 경우가 많다


즉, 세션을 시작해놓고 함수를 호출하는 케이스들이 대표적인데

http://bcho.tistory.com/1170 코드를 재 사용해서 이해해보도록 하자


이 코드를 보면, tt = time * 10 을 세션 시작전에 정의해놨는데, 이 코드를 함수로 바꾸면 아래와 같은 형태가 된다.


변경전 코드

def main():

   

   print 'start session'

   #coornator 위에 코드가 있어야 한다

   #데이타를 집어 넣기 전에 미리 그래프가 만들어져 있어야 함.

   batch_year,batch_flight,batch_time = read_data_batch(TRAINING_FILE)

   year = tf.placeholder(tf.int32,[None,],name='year')

   flight = tf.placeholder(tf.string,[None,],name='flight')

   time = tf.placeholder(tf.int32,[None,],name='time')

   

   tt = time * 10

   summary = tf.summary.merge_all()

   with tf.Session() as sess:

       summary_writer = tf.summary.FileWriter(LOG_DIR,sess.graph)

       try:


           coord = tf.train.Coordinator()

           threads = tf.train.start_queue_runners(sess=sess, coord=coord)


           for i in range(5):

               y_,f_,t_ = sess.run([batch_year,batch_flight,batch_time])

               print sess.run(tt,feed_dict={time:t_})

               #summary_str = sess.run(summary,feed_dict=feed_dict)

               #summary_writer.add_summary(summary_str,i)

               summary_writer.flush()         


변경후 코드

def create_graph(times):

   tt = times * 10

   return tt


def main():

   

   print 'start session'

   #coornator 위에 코드가 있어야 한다

   #데이타를 집어 넣기 전에 미리 그래프가 만들어져 있어야 함.

   batch_year,batch_flight,batch_time = read_data_batch(TRAINING_FILE)

   year = tf.placeholder(tf.int32,[None,],name='year')

   flight = tf.placeholder(tf.string,[None,],name='flight')

   time = tf.placeholder(tf.int32,[None,],name='time')

   

   r = create_graph(time)

   

   summary = tf.summary.merge_all()

   with tf.Session() as sess:

       summary_writer = tf.summary.FileWriter(LOG_DIR,sess.graph)

       try:


           coord = tf.train.Coordinator()

           threads = tf.train.start_queue_runners(sess=sess, coord=coord)


           for i in range(5):

               y_,f_,t_ = sess.run([batch_year,batch_flight,batch_time])

               print sess.run(r,feed_dict={time:t_})

               #summary_str = sess.run(summary,feed_dict=feed_dict)

               #summary_writer.add_summary(summary_str,i)

               summary_writer.flush()


변경후 코드는 tt = times * 10 을 create_graph라는 함수로 뺐는데, session 시작전에 함수를 호출한다. 언뜻 보면 개념이 헷갈릴 수 있는데, time 이라는 변수는 텐서플로우의 placeholder로 값이 읽혀지는 시점이 queue_runner를 시작해야 값을 읽을 수 있는 준비 상태가 되고, 실제로 값을 큐에서 읽으려면 session을 실행하고 feed_dict를 이용하여 feeding을 해줘야 값이 채워지기 때문에, 일반적인 프로그램상으로는 session을 시작한 후에 함수를 호출해야할것 같이 생각이 되지만, 앞에서도 언급했듯이 텐서플로우에서 프로그래밍의 개념은 그래프를 다 만들어놓은 후 (데이타가 처리되는 흐름을 모두 정의해놓고) 그 다음 session을 실행하여 그래프에 데이타를 채워놓는 개념이기 때문에, session이 정의되기 전에 함수 호출등을 이용해서 그래프를 정의해야 한다.


본인은 구글 클라우드의 직원이며, 이 블로그에 있는 모든 글은 회사와 관계 없는 개인의 의견임을 알립니다.

댓글을 달아 주세요

  1. 2017.09.13 16:44  댓글주소  수정/삭제  댓글쓰기

    비밀댓글입니다

딥러닝을 이용한 숫자 이미지 인식 #2/2


앞서 MNIST 데이타를 이용한 필기체 숫자를 인식하는 모델을 컨볼루셔널 네트워크 (CNN)을 이용하여 만들었다. 이번에는 이 모델을 이용해서 필기체 숫자 이미지를 인식하는 코드를 만들어 보자


조금 더 테스트를 쉽게 하기 위해서, 파이썬 주피터 노트북내에서 HTML 을 이용하여 마우스로 숫자를 그릴 수 있도록 하고, 그려진 이미지를 어떤 숫자인지 인식하도록 만들어 보겠다.



모델 로딩

먼저 앞의 예제에서 학습을한 모델을 로딩해보도록 하자.

이 코드는 주피터 노트북에서 작성할때, 모델을 학습 시키는 코드 (http://bcho.tistory.com/1156) 와 별도의 새노트북에서 구현을 하도록 한다.


코드

import tensorflow as tf

import numpy as np

import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data


#이미 그래프가 있을 경우 중복이 될 수 있기 때문에, 기존 그래프를 모두 리셋한다.

tf.reset_default_graph()


num_filters1 = 32


x = tf.placeholder(tf.float32, [None, 784])

x_image = tf.reshape(x, [-1,28,28,1])


#  layer 1

W_conv1 = tf.Variable(tf.truncated_normal([5,5,1,num_filters1],

                                         stddev=0.1))

h_conv1 = tf.nn.conv2d(x_image, W_conv1,

                      strides=[1,1,1,1], padding='SAME')


b_conv1 = tf.Variable(tf.constant(0.1, shape=[num_filters1]))

h_conv1_cutoff = tf.nn.relu(h_conv1 + b_conv1)


h_pool1 =tf.nn.max_pool(h_conv1_cutoff, ksize=[1,2,2,1],

                       strides=[1,2,2,1], padding='SAME')


num_filters2 = 64


# layer 2

W_conv2 = tf.Variable(

           tf.truncated_normal([5,5,num_filters1,num_filters2],

                               stddev=0.1))

h_conv2 = tf.nn.conv2d(h_pool1, W_conv2,

                      strides=[1,1,1,1], padding='SAME')


b_conv2 = tf.Variable(tf.constant(0.1, shape=[num_filters2]))

h_conv2_cutoff = tf.nn.relu(h_conv2 + b_conv2)


h_pool2 =tf.nn.max_pool(h_conv2_cutoff, ksize=[1,2,2,1],

                       strides=[1,2,2,1], padding='SAME')


# fully connected layer

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*num_filters2])


num_units1 = 7*7*num_filters2

num_units2 = 1024


w2 = tf.Variable(tf.truncated_normal([num_units1, num_units2]))

b2 = tf.Variable(tf.constant(0.1, shape=[num_units2]))

hidden2 = tf.nn.relu(tf.matmul(h_pool2_flat, w2) + b2)


keep_prob = tf.placeholder(tf.float32)

hidden2_drop = tf.nn.dropout(hidden2, keep_prob)


w0 = tf.Variable(tf.zeros([num_units2, 10]))

b0 = tf.Variable(tf.zeros([10]))

k = tf.matmul(hidden2_drop, w0) + b0

p = tf.nn.softmax(k)


# prepare session

sess = tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()

saver.restore(sess, '/Users/terrycho/anaconda/work/cnn_session')


print 'reload has been done'


그래프 구현

코드를 살펴보면, #prepare session 부분 전까지는 이전 코드에서의 그래프를 정의하는 부분과 동일하다. 이 코드는 우리가 만든 컨볼루셔널 네트워크를 복원하는 부분이다.


변수 데이타 로딩

그래프의 복원이 끝나면, 저장한 세션의 값을 다시 로딩해서 학습된 W와 b값들을 다시 로딩한다.


# prepare session

sess = tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()

saver.restore(sess, '/Users/terrycho/anaconda/work/cnn_session')


이때 saver.restore 부분에서 앞의 예제에서 저장한 세션의 이름을 지정해준다.

HTML을 이용한 숫자 입력

그래프와 모델 복원이 끝났으면 이 모델을 이용하여, 숫자를 인식해본다.

테스트하기 편리하게 HTML로 마우스로 숫자를 그릴 수 있는 화면을 만들어보겠다.

주피터 노트북에서 새로운 Cell에 아래와 같은 내용을 입력한다.


코드

input_form = """

<table>

<td style="border-style: none;">

<div style="border: solid 2px #666; width: 143px; height: 144px;">

<canvas width="140" height="140"></canvas>

</div></td>

<td style="border-style: none;">

<button onclick="clear_value()">Clear</button>

</td>

</table>

"""


javascript = """

<script type="text/Javascript">

   var pixels = [];

   for (var i = 0; i < 28*28; i++) pixels[i] = 0

   var click = 0;


   var canvas = document.querySelector("canvas");

   canvas.addEventListener("mousemove", function(e){

       if (e.buttons == 1) {

           click = 1;

           canvas.getContext("2d").fillStyle = "rgb(0,0,0)";

           canvas.getContext("2d").fillRect(e.offsetX, e.offsetY, 8, 8);

           x = Math.floor(e.offsetY * 0.2)

           y = Math.floor(e.offsetX * 0.2) + 1

           for (var dy = 0; dy < 2; dy++){

               for (var dx = 0; dx < 2; dx++){

                   if ((x + dx < 28) && (y + dy < 28)){

                       pixels[(y+dy)+(x+dx)*28] = 1

                   }

               }

           }

       } else {

           if (click == 1) set_value()

           click = 0;

       }

   });

   

   function set_value(){

       var result = ""

       for (var i = 0; i < 28*28; i++) result += pixels[i] + ","

       var kernel = IPython.notebook.kernel;

       kernel.execute("image = [" + result + "]");

   }

   

   function clear_value(){

       canvas.getContext("2d").fillStyle = "rgb(255,255,255)";

       canvas.getContext("2d").fillRect(0, 0, 140, 140);

       for (var i = 0; i < 28*28; i++) pixels[i] = 0

   }

</script>

"""


다음 새로운 셀에서, 다음 코드를 입력하여, 앞서 코딩한 HTML 파일을 실행할 수 있도록 한다.


from IPython.display import HTML

HTML(input_form + javascript)


이제 앞에서 만든 두 셀을 실행시켜 보면 다음과 같이 HTML 기반으로 마우스를 이용하여 숫자를 입력할 수 있는 박스가 나오는것을 확인할 수 있다.



입력값 판정

앞의 HTML에서 그린 이미지는 앞의 코드의 set_value라는 함수에 의해서, image 라는 변수로 784 크기의 벡터에 저장된다. 이 값을 이용하여, 이 그림이 어떤 숫자인지를 앞서 만든 모델을 이용해서 예측을 해본다.


코드


p_val = sess.run(p, feed_dict={x:[image], keep_prob:1.0})


fig = plt.figure(figsize=(4,2))

pred = p_val[0]

subplot = fig.add_subplot(1,1,1)

subplot.set_xticks(range(10))

subplot.set_xlim(-0.5,9.5)

subplot.set_ylim(0,1)

subplot.bar(range(10), pred, align='center')

plt.show()

예측

예측을 하는 방법은 쉽다. 이미지 데이타가 image 라는 변수에 들어가 있기 때문에, 어떤 숫자인지에 대한 확률을 나타내는 p 의 값을 구하면 된다.


p_val = sess.run(p, feed_dict={x:[image], keep_prob:1.0})


를 이용하여 x에 image를 넣고, 그리고 dropout 비율을 0%로 하기 위해서 keep_prob를 1.0 (100%)로 한다. (예측이기 때문에 당연히 dropout은 필요하지 않다.)

이렇게 하면 이 이미지가 어떤 숫자인지에 대한 확률이 p에 저장된다.

그래프로 표현

그러면 이 p의 값을 찍어 보자


fig = plt.figure(figsize=(4,2))

pred = p_val[0]

subplot = fig.add_subplot(1,1,1)

subplot.set_xticks(range(10))

subplot.set_xlim(-0.5,9.5)

subplot.set_ylim(0,1)

subplot.bar(range(10), pred, align='center')

plt.show()


그래프를 이용하여 0~9 까지의 숫자 (가로축)일 확률을 0.0~1.0 까지 (세로축)으로 출력하게 된다.

다음은 위에서 입력한 숫자 “4”를 인식한 결과이다.



(보너스) 첫번째 컨볼루셔널 계층 결과 출력

컨볼루셔널 네트워크를 학습시키다 보면 종종 컨볼루셔널 계층을 통과하여 추출된 특징 이미지들이 어떤 모양을 가지고 있는지를 확인하고 싶을때가 있다. 그래서 각 필터를 통과한 값을 이미지로 출력하여 확인하고는 하는데, 여기서는 이렇게 각 필터를 통과하여 인식된 특징이 어떤 모양인지를 출력하는 방법을 소개한다.


아래는 우리가 만든 네트워크 중에서 첫번째 컨볼루셔널 필터를 통과한 결과 h_conv1과, 그리고 이 결과에 bias 값을 더하고 활성화 함수인 Relu를 적용한 결과를 출력하는 예제이다.


코드


conv1_vals, cutoff1_vals = sess.run(

   [h_conv1, h_conv1_cutoff], feed_dict={x:[image], keep_prob:1.0})


fig = plt.figure(figsize=(16,4))


for f in range(num_filters1):

   subplot = fig.add_subplot(4, 16, f+1)

   subplot.set_xticks([])

   subplot.set_yticks([])

   subplot.imshow(conv1_vals[0,:,:,f],

                  cmap=plt.cm.gray_r, interpolation='nearest')

plt.show()


x에 image를 입력하고, dropout을 없이 모든 네트워크를 통과하도록 keep_prob:1.0으로 주고, 첫번째 컨볼루셔널 필터를 통과한 값 h_conv1 과, 이 값에 bias와 Relu를 적용한 값 h_conv1_cutoff를 계산하였다.

conv1_vals, cutoff1_vals = sess.run(

   [h_conv1, h_conv1_cutoff], feed_dict={x:[image], keep_prob:1.0})


첫번째 필터는 총 32개로 구성되어 있기 때문에, 32개의 결과값을 imshow 함수를 이용하여 흑백으로 출력하였다.




다음은 bias와 Relu를 통과한 값인 h_conv_cutoff를 출력하는 예제이다. 위의 코드와 동일하며 subplot.imgshow에서 전달해주는 인자만 conv1_vals → cutoff1_vals로 변경되었다.


코드


fig = plt.figure(figsize=(16,4))


for f in range(num_filters1):

   subplot = fig.add_subplot(4, 16, f+1)

   subplot.set_xticks([])

   subplot.set_yticks([])

   subplot.imshow(cutoff1_vals[0,:,:,f],

                  cmap=plt.cm.gray_r, interpolation='nearest')

   

plt.show()


출력 결과는 다음과 같다



이제까지 컨볼루셔널 네트워크를 이용한 이미지 인식을 텐서플로우로 구현하는 방법을 MNIST(필기체 숫자 데이타)를 이용하여 구현하였다.


실제로 이미지를 인식하려면 전체적인 흐름은 같지만, 이미지를 전/후처리 해내야 하고 또한 한대의 머신이 아닌 여러대의 머신과 GPU와 같은 하드웨어 장비를 사용한다. 다음 글에서는 MNIST가 아니라 실제 칼라 이미지를 인식하는 방법에 대해서 데이타 전처리에서 부터 서비스까지 전체 과정에 대해서 설명하도록 하겠다.


예제 코드 : https://github.com/bwcho75/tensorflowML/blob/master/MNIST_CNN_Prediction.ipynb


본인은 구글 클라우드의 직원이며, 이 블로그에 있는 모든 글은 회사와 관계 없는 개인의 의견임을 알립니다.

댓글을 달아 주세요

  1. 권오성 2017.01.11 01:14  댓글주소  수정/삭제  댓글쓰기

    html에서 재입력하고, 판정까지 재실행하면, error가 발생합니다.
    아마도 이유는 변수의 초기화와 관련 있는 듯 한데... 혹시 방법을 알고 계신지요???

  2. 2017.01.22 14:08  댓글주소  수정/삭제  댓글쓰기

    비밀댓글입니다

  3. 김영재 2017.02.10 22:51  댓글주소  수정/삭제  댓글쓰기

    좋은 자료 올려주셔서 감사합니다
    도움 많이 됐습니다.

  4. 성현준 2017.05.18 06:23  댓글주소  수정/삭제  댓글쓰기

    학습하고 있는 학생입니다.
    html canvas에 아무것도 입력이 되질않는데, 오류인가요?

  5. 1234 2017.06.11 20:46  댓글주소  수정/삭제  댓글쓰기

    ImportError: No module named 'IPython'

    아이파이썬 임폴트 에러가 나오는거같은대 어떻게해결해여되나요

  6. 1234 2017.06.11 21:46  댓글주소  수정/삭제  댓글쓰기

    관리자의 승인을 기다리고 있는 댓글입니다

  7. 1234 2017.06.11 21:51  댓글주소  수정/삭제  댓글쓰기

    관리자의 승인을 기다리고 있는 댓글입니다

  8. 2017.11.12 13:34  댓글주소  수정/삭제  댓글쓰기

    비밀댓글입니다

  9. 단결 2017.12.05 17:32  댓글주소  수정/삭제  댓글쓰기

    안녕하세요 블로그에서 많은 정보 얻고갑니다! 그런데 실행을 하다가 다른부분은 잘되는데 예측하는 값에서 결과가 안나오네요...

  10. 2018.04.18 16:21  댓글주소  수정/삭제  댓글쓰기

    비밀댓글입니다


텐서플로우로 모델을 만들어보자

Softmax를 이용한 숫자 인식

조대협 (http://bcho.tistory.com)


텐서플로우와 머신러닝에 대한 개념에 대해서 대략적으로 이해 했으면 간단한 코드를 한번 짜보자.

MNIST

그러면 이제 실제로 텐서플로우로 모델을 만들어서 학습을 시켜보자. 예제에 사용할 시나리오는 MNIST (Mixed National Institute of Standards and Technology database) 라는 데이타로, 손으로 쓴 숫자이다. 이 손으로 쓴 숫자 이미지를 0~9 사이의 숫자로 인식하는 예제이다.



이 예제는 텐서플로우 MNIST 튜토리얼 (https://www.tensorflow.org/tutorials/mnist/beginners/) 을 기반으로 작성하였는데, 설명이 빠진 부분과 소스코드 일부분이 수정되었으니 내용이 약간 다르다는 것을 인지해주기를 바란다.


MNIST 숫자 이미지를 인식하는 모델을 softmax 알고리즘을 이용하여 만든 후에, 트레이닝을 시키고, 정확도를 체크해보도록 하겠다.

데이타셋

MNIST 데이타는 텐서플로우 내에 라이브러리 형태로 내장이 되어 있어서 쉽게 사용이 가능하다. tensorflow.examples.tutorials.mnist 패키지에 데이타가 들어 있는데, read_data_sets 명령어를 이용하면 쉽게 데이타를 로딩할 수 있다.


데이타 로딩 코드

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data


mnist = input_data.read_data_sets('/tmp/tensorflow/mnist/input_data', one_hot=True)


Mnist 데이타셋에는 총 60,000개의 데이타가 있는데, 이 데이타는  크게 아래와 같이 세종류의 데이타 셋으로 나눠 진다. 모델 학습을 위한 학습용 데이타인 mnist.train 그리고, 학습된 모델을 테스트하기 위한 테스트 데이타 셋은 minst.test, 그리고 모델을 확인하기 위한 mnist.validation 데이타셋으로 구별된다.

각 데이타는 아래와 같이 학습용 데이타 55000개, 테스트용 10,000개, 그리고, 확인용 데이타 5000개로 구성되어 있다.


데이타셋 명

행렬 차원

데이타 종류

노트

mnist.train.images

55000 x 784

학습 이미지 데이타


mnist.train.labels

55000 x 10

학습 라벨 데이타


mnist.test.images

10000 x 784

테스트용 이미지 데이타


mnist.test.labels

10000 x 10

테스트용 라벨 데이타


mnist.validation.images

5000 x 784

확인용 이미지 데이타


mnist.validation.labels

5000 x 10

확인용 라벨 데이타



각 데이타셋은 학습을 위한 글자 이미지를 저장한 데이타 image 와, 그 이미지가 어떤 숫자인지를 나타낸 라벨 데이타인 label로 두개의 데이타 셋으로 구성되어 있다.

이미지

먼저 이미지 데이타를 보면 아래 그림과 같이 28x28 로 구성되어 있는데,


이를 2차원 행렬에서 1차원으로 쭈욱 핀 형태로 784개의 열을 가진 1차원 행렬로 변환되어 저장이 되어 있다.

mnist.train.image는 이러한 784개의 열로 구성된 이미지가 55000개가 저장이 되어 있다.


텐서플로우의 행렬을 나타내는 shape의 형태로는 shape=[55000,784] 이 된다.


마찬가지로, mnist.train.image 도 784개의 열로 구성된 숫자 이미지 데이타를 10000개를 가지고 있고 텐서플로우의 shape으로는 shape=[10000,784] 로 표현될 수 있다.


라벨

Label 은 이미지가 나타내는 숫자가 어떤 숫자인지를 나타내는 라벨 데이타로 10개의 숫자로 이루어진 1행 행렬이다. 0~9 순서로, 그 숫자이면 1 아니면 0으로 표현된다. 예를 들어 1인경우는 [0,1,0,0,0,0,0,0,0,0,0]  9인 경우는 [0,0,0,0,0,0,0,0,0,1] 로 표현된다.

이미지 데이타에 대한 라벨이기 때문에, 당연히 이미지 데이타 수만큼의 라벨을 가지게 된다.



Train 데이타 셋은 이미지가 55000개 였기 때문에, Train의 label의 수 역시도 55000개가 된다.


소프트맥스 회귀(Softmax regression)

숫자 이미지를 인식하는 모델은 많지만, 여기서는 간단한 알고리즘 중 하나인 소프트 맥스 회귀 모델을 사용하겠다.

소프트맥스 회귀에 대한 알고리즘 자체는 자세히 설명하지 않는다. 소프트맥스 회귀는 classification 알고리즘중의 하나로, 들어온 값이 어떤 분류인지 구분해주는 알고리즘이다.

예를 들어 A,B,C 3개의 결과로 분류해주는 소프트맥스의 경우 결과값은 [0.7,0.2,0.1] 와 같이 각각 A,B,C일 확률을 리턴해준다. (결과값의 합은 1.0이 된다.)


(cf. 로지스틱 회귀는 두 가지로만 분류가 가능하지만, 소프트맥스 회귀는 n 개의 분류로 구분이 가능하다.)


모델 정의

소프트맥스로 분류를 할때, x라는 값이 들어 왔을때, 분류를 한다고 가정했을때, 모델에서 사용하는 가설은 다음과 같다.  

y = softmax (W*x + b)

W는 weight, 그리고 b는 bias 값이다.

y는 최종적으로 10개의 숫자를 감별하는 결과가 나와야 하기 때문에, 크기가 10인 행렬이 되고,

10개의 결과를 만들기 위해서 W역시 10개가 되어야 하며, 이미지 하나는 784개의 숫자로 되어 있기 때문에, 10개의 값을 각각 784개의 숫자에 적용해야 하기 때문에, W는 784x10 행렬이 된다. 그리고, b 는 10개의 값에 각각 더하는 값이기 때문에, 크기가 10인 행렬이 된다.


이를 표현해보면 다음과 같은 그림이 된다.


이를 텐서플로우 코드로 표현하면 다음과 같다.

x = tf.placeholder(tf.float32, [None, 784])

W = tf.Variable(tf.zeros([784, 10]))

b = tf.Variable(tf.zeros([10]))

k = tf.matmul(x, W) + b

y = tf.nn.softmax(k)


우리가 구하고자 하는 값은 x 값으로 학습을 시켜서 0~9를 가장 잘 구별해내는 W와 b의 값을 찾는 일이다.


여기서 코드를 주의깊게 봤다면 하나의 의문이 생길것이다.

x의 데이타는 총 55000개로, 55000x784 행렬이 되고, W는 784x10 행렬이다. 이 둘을 곱하면, 55000x10 행렬이 되는데, b는 1x10 행렬로 차원이 달라서 합이 되지 않는다.

텐서플로우와 파이썬에서는 이렇게 차원이 다른 행렬을 큰 행렬의 크기로 늘려주는 기능이 있는데, 이를 브로드 캐스팅이라고 한다. (브로드 캐스팅 개념 참고 - http://bcho.tistory.com/1153)

브로드 캐스팅에 의해서 b는 55000x10 사이즈로 자동으로 늘어나고 각 행에는 첫행과 같은 데이타들로 채워지게 된다.


소프트맥스 알고리즘을 이해하고 사용해도 좋지만, 텐서플로우에는 이미 tf.nn.softmax 라는 함수로 만들어져 있고, 대부분 많이 알려진 머신러닝 모델들은 샘플들이 많이 있기 때문에, 대략적인 원리만 이해하고 가져다 쓰는 것을 권장한다. 보통 모델을 다 이해하려고 하다가 수학에서 부딪혀서 포기하는 경우가 많은데, 디테일한 모델을 이해하기 힘들면, 그냥 함수나 예제코드를 가져다 쓰는 방법으로 접근하자. 우리가 일반적인 프로그래밍에서도 해쉬테이블이나 트리와 같은 자료구조에 대해서 대략적인 개념만 이해하고 미리 정의된 라이브러리를 사용하지 직접 해쉬 테이블등을 구현하는 경우는 드물다.

코스트(비용) 함수

이 소프트맥스 함수에 대한 코스트 함수는 크로스엔트로피 (Cross entropy) 함수의 평균을 이용하는데, 복잡한 산식 없이 그냥 외워서 쓰자. 다행이도 크로스엔트로피 함수역시 함수로 구현이 되어있다.


Cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(tf.matmul(x, W) + b, y_))


가설에 의해 계산된 값 y를 넣지 않고 tf.matmul(x, W) + b 를 넣은 이유는 tf.nn.softmax_cross_entropy_with_logits 함수 자체가 softmax를 포함하기 때문이다.

y_은 학습을 위해서 입력된 값이다.


텐서플로우로 구현

자 그럼 학습을 위한 전체 코드를 보자


샘플코드

# Import data

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

 

mnist = input_data.read_data_sets('/tmp/tensorflow/mnist/input_data', one_hot=True)


# Create the model

x = tf.placeholder(tf.float32, [None, 784])

W = tf.Variable(tf.zeros([784, 10]))

b = tf.Variable(tf.zeros([10]))

k = tf.matmul(x, W) + b

y = tf.nn.softmax(k)


# Define loss and optimizer

y_ = tf.placeholder(tf.float32, [None, 10])                                                                               

learning_rate = 0.5

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(k, y_))

train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)


print ("Training")

sess = tf.Session()

init = tf.global_variables_initializer() #.run()

sess.run(init)

for _ in range(1000):

   # 1000번씩, 전체 데이타에서 100개씩 뽑아서 트레이닝을 함.  

   batch_xs, batch_ys = mnist.train.next_batch(100)

   sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})


print ('b is ',sess.run(b))

print('W is',sess.run(W))

데이타 로딩

# Import data

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

 

mnist = input_data.read_data_sets('/tmp/tensorflow/mnist/input_data', one_hot=True)


앞에서 데이타에 대해서 설명한것과 같이 데이타를 로딩하는 부분이다. read_data_sets에 들어가 있는 디렉토리는 샘플데이타를 온라인에서 다운 받는데, 그 데이타를 임시로 저장해놓을 위치이다.

모델 정의

다음은 소프트맥스를 이용하여 모델을 정의한다.

# Create the model

x = tf.placeholder(tf.float32, [None, 784])

W = tf.Variable(tf.zeros([784, 10]))

b = tf.Variable(tf.zeros([10]))

k = tf.matmul(x, W) + b

y = tf.nn.softmax(k)


x는 트레이닝 데이타를 저장하는 스테이크홀더, W는 Weight, b는 bias 값이고, 모델은 y = tf.nn.softmax(tf.matmul(x, W) + b) 이 된다.

코스트함수와 옵티마이저 정의

모델을 정의했으면 학습을 위해서, 코스트 함수를 정의한다.

# Define loss and optimizer

y_ = tf.placeholder(tf.float32, [None, 10])                                                                               

learning_rate = 0.5

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(k, y_))

train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)


코스트 함수는 크로스 엔트로피 함수의 평균값을 사용한다. 크로스엔트로피 함수는 아래와 같은 모양인데, 이 값을 전체 트레이닝 데이타셋의 수로 나눠 준다.  


그래서 최종적으로 cost 함수는 cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(k, y_)) 이 된다.

이 때 주의할점은 y가 아니라 k를 넣어야 한다. softmax_cross_entropy_with_logits 함수는 softmax를 같이 하기 때문에, 위의 y값은 이미 softmax를 해버린 함수이기 때문에 softmax가 중복될 수 있다.



이 코스트 함수를 가지고 코스트가 최소화가 되는 W와 b를 구해야 하는데, 옵티마이져를 사용한다. 여기서는 경사 하강법(Gradient Descent Optimizer)를 사용하였고 경사하강법에 대한 개념은 http://bcho.tistory.com/1141 를 참고하기 바란다.

GradientDescent에서 learning rate는 학습속도 인데, 학습 속도에 대한 개념은 http://bcho.tistory.com/1141 글을 참고하기 바란다.

세션 초기화  

print ("Training")

sess = tf.Session()

init = tf.global_variables_initializer() #.run()

sess.run(init)


tf.Session() 을 이용해서 세션을 만들고, global_variable_initializer()를 이용하여, 변수들을 모두 초기화한후, 초기화 값을 sess.run에 넘겨서 세션을 초기화 한다.

트레이닝 시작

세션이 생성되었으면 이제 트레이닝을 시작한다.

for _ in range(1000):

   # 1000번씩, 전체 데이타에서 100개씩 뽑아서 트레이닝을 함.  

   batch_xs, batch_ys = mnist.train.next_batch(100)

   sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})


여기서 주목할점은 Batch training 과 Stochastic training 인데, Batch training이란, 학습을 할때 전체 데이타를 가지고 한번에 학습을 하는게 아니라 전체 데이타셋을 몇 개로 쪼갠후 나눠서 트레이닝을 하는 방법을 배치 트레이닝이라고 한다. 그중에서 여기에 사용된 배치 방법은 Stochastic training 이라는 방법인데, 원칙대로라면 전체 55000개 의 학습데이타가 있기 때문에 배치 사이즈를 100으로 했다면, 100개씩 550번 순차적으로 데이타를 읽어서 학습을 해야겠지만, Stochastic training은 전체 데이타중 일부를 샘플링해서 학습하는 방법으로, 여기서는 배치 한번에 100개씩의 데이타를 뽑아서 1000번 배치로 학습을 하였다.

(텐서플로우 문서에 따르면, 전체 데이타를 순차적으로 학습 시키기에는 연산 비용이 비싸기 때문에, 샘플링을 해도 비슷한 정확도를 낼 수 있기 때문에, 예제 차원에서 간단하게, Stochastic training을 사용한것으로 보인다.)


결과값 출력

print ('b is ',sess.run(b))

print('W is',sess.run(W))


마지막으로 학습에서 구해진 W와 b를 출력해보자

다음은 실행 결과 스크린 샷이다.




먼저 앞에서 데이타를 로딩하도록 지정한 디렉토리에, 학습용 데이타를 다운 받아서 압축 받는 것을 확인할 수 있다. (Extracting.. 부분)

그 다음 학습이 끝난후에, b와 W 값이 출력되었다. W는 784 라인이기 때문에, 중간을 생략하고 출력되었으나, 각 행을 모두 찍어보면 아래와 같이 W 값이 들어가 있는 것을 볼 수 있다.


모델 검증

이제 모델을 만들고 학습을 시켰으니, 이 모델이 얼마나 정확하게 작동하는지를 테스트 해보자.  mnist.test.image 와 mnist.test.labels 데이타셋을 이용하여 테스트를 진행하는데, 앞에서 나온 모델에 mnist.test.image 데이타를 넣어서 예측을 한 후에, 그 결과를 mnist.test.labels (정답)과 비교해서 정답률이 얼마나 되는지를 비교한다.


다음은 모델 테스팅 코드이다. 이 코드를 위의 코드 뒤에 붙여서 실행하면 된다.


모델 검증 코드

print ("Testing model")

# Test trained model

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

print('accuracy ',sess.run(accuracy, feed_dict={x: mnist.test.images,

                                    y_: mnist.test.labels}))

print ("done")

   

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

코드를 보자, tf.argmax 함수를 이해해야 하는데, argmax(y,1)은 행렬 y에서 몇번째에 가장 큰 값이 들어가 있는지를 리턴해주는 함수이다. 아래 예제 코드를 보면


session = tf.InteractiveSession()


data = tf.constant([9,2,11,4])

idx = tf.argmax(data,0)

print idx.eval()


session.close()


[9,2,11,4] 에서 최대수는 11이고, 이 위치는 두번째 (0 부터 시작한다)이기 때문에 0을 리턴한다.

두번째 변수는 어느축으로 카운트를 할것인지를 선택한다. , 1차원 배열의 경우에는 0을 사용한다.

여기서 y는 2차원 행렬인데, 0이면 같은 열에서 최대값인 순서, 1이면 같은 행에서 최대값인 순서를 리턴한다.

그럼 원래 코드로 돌아오면 tf.argmax(y,1)은 y의 각행에서 가장 큰 값의 순서를 찾는다. y의 각행을 0~9으로 인식한 이미지의 확률을 가지고 있다.

아래는 4를 인식한 y 값인데, 4의 값이 0.7로 가장높기 (4일 확률이 70%, 3일 확률이 10%, 1일 확률이 20%로 이해하면 된다.) 때문에, 4로 인식된다.

여기서 tf.argmax(y,1)을 사용하면, 행별로 가장 큰 값을 리턴하기 때문에, 위의 값에서는 4가 리턴이된다.

테스트용 데이타에서 원래 정답이 4로 되어 있다면, argmax(y_,1)도 4를 리턴하기 때문에, tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))는 tf.equals(4,4)로 True를 리턴하게 된다.


모든 테스트 셋에 대해서 검증을 하고 나서 그 결과에서 True만 더해서, 전체 트레이닝 데이타의 수로 나눠 주면 결국 정확도가 나오는데, tf.cast(boolean, tf.float32)를 하면 텐서플로우의 bool 값을 float32 (실수)로 변환해준다. True는 1.0으로 False는 0.0으로 변환해준다. 이렇게 변환된 값들의 전체 평균을 구하면 되기 때문에, tf.reduce_mean을 사용한다.


이렇게 정확도를 구하는 함수가 정의되었으면 이제 정확도를 구하기 위해 데이타를 넣어보자

sess.run(accuracy, feed_dict={x: mnist.test.images,y_: mnist.test.labels})

x에 mnist.test.images 데이타셋으로 이미지 데이타를 입력받아서  y (예측 결과)를 계산하고, y_에는 mnist.test.labels 정답을 입력 받아서, y와 y_로 정확도 accuracy를 구해서 출력한다.


최종 출력된 accuracy 정확도는 0.9 로 대략 90% 정도가 나온다.


Testing model
('accuracy ', 0.90719998)
done


다른 알고리즘의 정확도는 http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html 를 참고하면 된다.


다음글에서는 소프트맥스 모델 대신 CNN (Convolutional Neural Network)를 이용하여, 조금 더 정확도가 높은  MNIST를 구현하고 테스트해보도록 하겠다.


참고 자료

  • 텐서플로우 MNIST https://www.tensorflow.org/tutorials/mnist/beginners/


2017년 1월 6일 추가

위의 코드 부분에 잘못된 부분이 있어서 수정합니다.


k = tf.matmul(x, W) + b

y = tf.nn.softmax(k)


# Define loss and optimizer

y_ = tf.placeholder(tf.float32, [None, 10])                                                                               

learning_rate = 0.5

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(k, y_))


https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.softmax_cross_entropy_with_logits.md 레퍼런스에 따르면


WARNING: This op expects unscaled logits, since it performs a softmax on logits internally for efficiency. Do not call this op with the output of softmax, as it will produce incorrect results.


tf.nn.softmax_cross_entropy_with_logits 함수는 softmax를 포함하고 있다. 그래서 softmax를 적용한 y를 넣으면 안되고 softmax 적용전인 k를 넣어야 한다.



본인은 구글 클라우드의 직원이며, 이 블로그에 있는 모든 글은 회사와 관계 없는 개인의 의견임을 알립니다.

댓글을 달아 주세요

  1. jchern 2017.01.19 11:40  댓글주소  수정/삭제  댓글쓰기

    아래 내용에 오류가 있는것 같네요
    [9,2,11,4] 에서 최대수는 11이고, 이 위치는 두번째 (0 부터 시작한다)이기 때문에 0을 리턴한다.
    => 2를 리턴한다

  2. kkw 2017.01.28 20:19  댓글주소  수정/삭제  댓글쓰기

    지나가다가 궁금해서 하나 여쭤볼게요. n차원 추상화에 대해서 어떻게 이해해야 하는지를 궁금합니다. 1차원은 선, 2차원은 면, 3차원은 공간이라고 생각을 하고 여기에 n차원을 끼워 맞추려 하다보니 4차원을 넘어가는 차원이라면 와닿지가 않아서요. n차원 추상화 개념을 어떻게 받아들이셨나요??

    • Jay 2018.06.07 17:57  댓글주소  수정/삭제

      일반적으로 알려진 차원으로만 이해하면 힘듭니다 2차원 이후부턴 2차원 배열 을 몇개 가지나 로 시도 해 보세요. N차원 : 2차원 * N-1개
      4차원 : 2차원 *2개

  3. ask0127 2017.03.22 13:38  댓글주소  수정/삭제  댓글쓰기

    W랑 b 를 print 하실때 W는 784 라인이기때문에 중간값이 생략되어 출력된다고 하셨는데 각 행을 모두 찍어보실때 어떻게 하셨나요? 저도 dimension이 너무 커서그런지 중간은 모두 '...'로 생략되어 나오는데...

    • bocky 2017.11.17 21:50  댓글주소  수정/삭제

      저는
      W_get = sess.run(W)
      으로 값을 받아서
      print('W(12,:) is',W_get[12,:])
      으로 출력했더니 한 행씩 값 찍어볼 수 있었습니다.

  4. likesea7 2017.07.16 08:36  댓글주소  수정/삭제  댓글쓰기

    TF 1.0 이상에서는 tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(k, y_)) 문이
    오류나네요. 변수를 선언해야한다고 합니다.

  5. Jisuk Lee 2018.03.25 20:29  댓글주소  수정/삭제  댓글쓰기

    likesea7님 말씀대로 아래처럼 변경해야 하네요.ㅎㅎ
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = k,labels = y_))

  6. 타이탄양 2018.04.29 01:24 신고  댓글주소  수정/삭제  댓글쓰기

    안녕하세요! 지금 대학교 3학년 재학중인 학생입니다. 학교수업과는 무관하게 머신러닝에대해서 찾아보면서 무작정 공부하던 차에 블로그에서 자료를 보면서 공부하게 되었습니다. 혹시 머신러닝과 관련해서 알고있으면 도움이 될 기반지식이나 꼭 알고 있어야 할 논리나 개념에 대해 알려주실 수 있으실까요..?

  7. 해피로퍼 2019.02.13 12:20  댓글주소  수정/삭제  댓글쓰기

    키야............이런것이었군요 감사합니다.

  8. 이승희 2019.11.30 19:06  댓글주소  수정/삭제  댓글쓰기

    관리자의 승인을 기다리고 있는 댓글입니다

  9. 이승희 2019.11.30 19:36  댓글주소  수정/삭제  댓글쓰기

    관리자의 승인을 기다리고 있는 댓글입니다

파이어베이스 애널러틱스를 이용한 모바일 데이타 분석 #1-Hello Firebase

조대협 (http://bcho.tistory.com)


얼마전에 구글은 모바일 백앤드 플랫폼인 파이어베이스를 인수하고 이를 서비스로 공개하였다.

파이어 베이스는 모바일 백앤드의 종합 솔루션으로, 크래쉬 리포팅, 리모트 컨피그를 이용한 A/B 테스팅 플랫폼, 클라우드와 자동 동기화가 가능한 리얼타임 데이타 베이스, 사용자 인증 기능, 강력한 푸쉬 플랫폼 다양한 모바일 기기에 대해서 테스트를 해볼 수 있는 테스트랩 등, 모바일 앱 개발에 필요한 모든 서비스를 제공해주는 종합 패키지와 같은 플랫폼이라고 보면 된다. 안드로이드 뿐만 아니라 iOS까지 지원하여 모든 모바일 앱 개발에 공통적으로 사용할 수 있다.



그중에서 파이어베이스 애널러틱스 (Firebase analytics)는 모바일 부분은 모바일 앱에 대한 모든 이벤트를 수집 및 분석하여 자동으로 대쉬 보드를 통하여 분석을 가능하게 해준다.


이 글에서는 파이어베이스 전체 제품군중에서 파이어베이스 애널러틱스에 대해서 수회에 걸쳐서 설명을 하고자 한다.


파이어베이스 애널러틱스

이미 시장에는 모바일 앱에 대한 데이타 분석이 가능한 유료 또는 무료 제품이 많다.

대표적으로 야후의 flurry, 트위터 fabric, 구글 애널러틱스등이 대표적인 제품군인데, 그렇다면 파이어베이스가 애널러틱스가 가지고 있는 장단점은 무엇인가?


퍼널 분석 및 코호트 분석 지원

파이어베이스 애널러틱스는 데이타 분석 방법중에 퍼넬 분석과 코호트 분석을 지원한다.

퍼널 분석은 한글로 깔데기 분석이라고 하는데, 예를 들어 사용자가 가입한 후에, 쇼핑몰의 상품 정보를 보고  주문 및 결재를 하는 단계 까지 각 단계별로 사용자가 이탈하게 된다. 이 구조를 그려보면 깔데기 모양이 되는데,사용자 가입에서 부터 최종 목표인 주문 결재까지 이루도록 단계별로 이탈율을 분석하여 서비스를 개선하고, 이탈율을 줄이는데 사용할 수 있다.

코호트 분석은 데이타를 집단으로 나누어서 분석하는 방법으로 일일 사용자 데이타 (DAU:Daily Active User)그래프가 있을때, 일일 사용자가 연령별로 어떻게 분포가 되는지등을 나눠서 분석하여 데이타를 조금 더 세밀하게 분석할 수 있는 방법이다.


이러한 코호트 분석과 퍼넬 분석은 모바일 데이타 분석 플랫폼 중에서 일부만 지원하는데, 파이어베이스 애널러틱스는 퍼넬과 코호트 분석을 기본적으로 제공하고 있으며, 특히 코호트 분석으로 많이 사용되는 사용자 잔존율 (Retention 분석)의 경우 별다른 설정 없이도 기본으로 제공하고 있다.


<그림. 구글 파이어베이스의 사용자 잔존율 코호트 분석 차트>

출처 : https://support.google.com/firebase/answer/6317510?hl=en

무제한 앱 및 무제한 사용자 무료 지원

이러한 모바일 서비스 분석 서비스의 경우 사용자 수나 수집할 수 있는 이벤트 수나 사용할 수 있는 앱수에 제약이 있는데, 파이어베이스 애널러틱스의 경우에는 제약이 없다.

빅쿼리 연계 지원

가장 강력한 기능중의 하나이자, 이 글에서 주로 다루고자 하는 내용이 빅쿼리 연동 지원이다.

모바일 데이타 분석 서비스 플랫폼의 경우 대 부분 플랫폼 서비스의 형태를 띄기 때문에, 분석 플랫폼에서 제공해주는 일부 데이타만 볼 수 가 있고, 원본 데이타에 접근하는 것이 대부분 불가능 하다.

그래서 모바일 애플리케이션 서버에서 생성된 데이타나, 또는 광고 플랫폼등 외부 연동 플랫폼에서 온 데이타에 대한 연관 분석이 불가능하고, 원본 데이타를 통하여 여러가지 지표를 분석하는 것이 불가능하다.


파이어베이스 애널러틱스의 경우에는 구글의 데이타 분석 플랫폼이 빅쿼리 연동을 통하여 모든 데이타를 빅쿼리에 저장하여 간단하게 분석이 가능하다.

구글 빅쿼리에 대한 소개는 http://bcho.tistory.com/1116 를 참고하기 바란다.

구글의 빅쿼리는 아마존 S3나, 구글의 스토리지 서비스인 GCS 보다 저렴한 비용으로 데이타를 저장하면서도, 수천억 레코드에 대한 연산을 수십초만에 8~9000개의 CPU와 3~4000개의 디스크를 사용해서 끝낼만큼 어마어마한 성능을 제공하면서도, 사용료 매우 저렴하며 기존 SQL 문법을 사용하기 때문에, 매우 쉽게 접근이 가능하다.

모바일 데이타 분석을 쉽게 구현이 가능

보통 모바일 서비스에 대한 데이타 분석을 할때는 무료 서비스를 통해서 DAU나 세션과 같은 기본적인 정보 수집은 가능하지만, 추가적인 이벤트를 수집하여 저장 및 분석을 하거나 서버나 다른 시스템의 지표를 통합 분석 하는 것은 별도의 로그 수집 시스템을 모바일 앱과 서버에 만들어야 하였고, 이를 분석 및 저장하고 리포팅 하기 위해서 하둡이나 스파크와 같은 복잡한 빅데이타 기술을 사용하고 리포팅에도 많은 시간이 소요 되었다.


파이어베이스 애널러틱스를 이용하면, 손 쉽게, 추가 이벤트나 로그 정보를 기존의 로깅 프레임웍을 통하여 빅쿼리에 저장할 수 있고, 복잡한 하둡이나 스파크의 설치나 프로그래밍 없이 빅쿼리에서 간략하게 SQL만을 사용하여 분석을 하고 오픈소스 시각화 도구인 Jupyter 노트북이나 구글의 데이타스튜디오 (http://datastudio.google.com)을 통하여 시작화가 간단하기 때문에, 이제는 누구나 쉽게 빅데이타 로그를 수집하고 분석할 수 있게 된다.

실시간 데이타 분석은 지원하지 않음

파이어베이스 애널러틱스가 그러면 만능 도구이고 좋은 기능만 있는가? 그건 아니다. 파이어베이스 애널러틱스는 아직까지는 실시간 데이타 분석을 지원하고 있지 않다. 수집된 데이타는 보통 수시간이 지나야 대쉬 보드에 반영이 되기 때문에 현재 접속자나, 실시간 모니터링에는 적절하지 않다.

그래서 보완을 위해서 다른 모니터링 도구와 혼용해서 사용하는 게 좋다. 실시간 분석이 강한 서비스로는 트위터 fabric이나 Google analytics 등이 있다.

이러한 도구를 이용하여 데이타에 대한 실시간 분석을 하고, 정밀 지표에 대한 분석을 파이어베이스 애널러틱스를 사용 하는 것이 좋다.


파이어베이스 애널러틱스 적용해보기

백문이 불여일견이라고, 파이어베이스 애널러틱스를 직접 적용해보자.

https://firebase.google.com/ 사이트로 가서, 가입을 한 후에, “콘솔로 이동하기"를 통해서 파이어 베이스 콘솔로 들어가자.

프로젝트 생성하기

다음으로 파이어베이스 프로젝트를 생성한다. 상단 메뉴에서 “CREATE NEW PROJECT”를 선택하면 새로운 파이어 베이스 프로젝트를 생성할 수 있다. 만약에 기존에 사용하던 구글 클라우드 프로젝트등이 있으면 별도의 프로젝트를 생성하지 않고 “IMPORT GOOGLE PROJECT”를 이용하여 기존의 프로젝트를 불러와서 연결할 수 있다.



프로젝트가 생성되었으면 파이어베이스를 사용하고자 하는 앱을 등록해야 한다.

파이어베이스 화면에서 “ADD APP” 이라는 버튼을 누르면 앱을 추가할 수 있다.

아래는 앱을 추가하는 화면중 첫번째 화면으로 앱에 대한 기본 정보를 넣는 화면이다.

“Package name” 에, 파이어베이스와 연동하고자 하는 안드로이드 앱의 패키지 명을 넣는다.


ADD APP 버튼을 누르고 다음 단계로 넘어가면 google-services.json 이라는 파일이 자동으로 다운된다. 이 파일은 나중에 안드로이드 앱의 소스에 추가해야 하기 때문에 잘 보관한다.


Continue 버튼을 누르면 아래와 같이 다음 단계로 넘어간다. 다음 단계에서는 안드로이드 앱을 개발할때 파이어베이스를 연동하려면 어떻게 해야 하는지에 대한 가이드가 나오는데, 이 부분은 나중에 코딩 부분에서 설명할 예정이니 넘어가도록 하자.


자 이제 파이어베이스 콘솔에서, 프로젝트를 생성하고 앱을 추가하였다.

이제 연동을 할 안드로이드 애플리케이션을 만들어보자.

안드로이드 빌드 환경 설정

콘솔에서 앱이 추가되었으니, 이제 코드를 작성해보자, 아래 예제는 안드로이드 스튜디오 2.1.2 버전 (맥 OS 기준) 으로 작성되었다.


먼저 안드로이드 프로젝트를 생성하였다. 이때 반드시 안드로이드 프로젝트에서 앱 패키지 명은 앞에 파이어베이스 콘솔에서 지정한 com.terry.hellofirebase가 되어야 한다.

안드로이드 프로젝트에는 프로젝트 레벨의 build.gradle 파일과, 앱 레벨의 build.gradle 파일이 있는데



프로젝트 레벨의 build.gradle 파일에 classpath 'com.google.gms:google-services:3.0.0' 를 추가하여  다음과 같이 수정한다.


// Top-level build file where you can add configuration options common to all sub-projects/modules.


buildscript {

  repositories {

      jcenter()

  }

  dependencies {

      classpath 'com.android.tools.build:gradle:2.1.2'

      classpath 'com.google.gms:google-services:3.0.0'

      // NOTE: Do not place your application dependencies here; they belong

      // in the individual module build.gradle files

  }

}


allprojects {

  repositories {

      jcenter()

  }

}


task clean(type