텐서플로우 배치 처리
조대협 (http://bcho.tistory.com)
텐서플로우에서 파일에서 데이타를 읽은 후에, 배치처리로 placeholder에서 읽는 예제를 설명한다.
텐서의 shape 의 차원과 세션의 실행 시점등이 헷갈려서 시행착오가 많았기 때문에 글로 정리해놓는다.
큐와 파일처리에 대한 기본적인 내용은 아래글
http://bcho.tistory.com/1163
http://bcho.tistory.com/1165
데이타 포맷
읽어 드릴 데이타 포맷은 다음과 같다. 비행기 노선 정보에 대한 데이타로 “년도,항공사 코드, 편명"을 기록한 CSV 파일이다.
2014,VX,121
2014,WN,1873
2014,WN,2787
배치 처리 코드
이 데이타를 텐서 플로우에서 읽어서 배치로 place holder에 feeding 하는 코드 이다
먼저 read_data는 csv 파일에서 데이타를 읽어서 파싱을 한 후 각 컬럼을 year,flight,time 으로 리턴하는 함 수이다.
def read_data(file_name):
try:
csv_file = tf.train.string_input_producer([file_name],name='filename_queue')
textReader = tf.TextLineReader()
_,line = textReader.read(csv_file)
year,flight,time = tf.decode_csv(line,record_defaults=[ [1900],[""],[0] ],field_delim=',')
except:
print "Unexpected error:",sys.exc_info()[0]
exit()
return year,flight,time
string_input_producer를 통해서 파일명들을 큐잉해서 하나씩 읽는데,여기서는 편의상 하나의 파일만 읽도록 하였는데, 여러개의 파일을 병렬로 처리하고자 한다면, [file_name] 부분에 리스트 형으로 여러개의 파일 목록을 지정해주면 된다.
다음 각 파일을 TextReader를 이용하여 라인 단위로 읽은 후 decode_csv를 이용하여, “,”로 분리된 컬럼을 각각 읽어서 year,flight,time 에 저장하여 리턴하였다.
다음 함수는 read_data_batch 라는 함수인데, 앞에서 정의한 read_data 함수를 호출하여, 읽어드린 year,flight,time 을 배치로 묶어서 리턴하는 함수 이다.
def read_data_batch(file_name,batch_size=10):
year,flight,time = read_data(file_name)
batch_year,batch_flight,batch_time = tf.train.batch([year,flight,time],batch_size=batch_size)
return batch_year,batch_flight,batch_time
tf.train.batch 함수가 배치로 묶어서 리턴을 하는 함수인데, batch로 묶고자 하는 tensor 들을 인자로 준 다음에, batch_size (한번에 묶어서 리턴하고자 하는 텐서들의 개수)를 정해주면 된다.
위의 예제에서는 batch_size를 10으로 해줬기 때문에, batch_year = [ 1900,1901….,1909] 와 같은 형태로 10개의 년도를 하나의 텐서에 묶어서 리턴해준다.
즉 입력 텐서의 shape이 [x,y,z] 일 경우 tf.train.batch를 통한 출력은 [batch_size,x,y,z] 가 된다.(이 부분이 핵심)
메인 코드
자 이제 메인 코드를 보자
def main():
print 'start session'
#coornator 위에 코드가 있어야 한다
#데이타를 집어 넣기 전에 미리 그래프가 만들어져 있어야 함.
batch_year,batch_flight,batch_time = read_data_batch(TRAINING_FILE)
year = tf.placeholder(tf.int32,[None,])
flight = tf.placeholder(tf.string,[None,])
time = tf.placeholder(tf.int32,[None,])
tt = time * 10
tt = time * 10 이라는 공식을 실행하기 위해서 time 이라는 값을 읽어서 피딩하는 예제인데 먼저 read_data_batch를 이용하여 데이타를 읽는 그래프를 생성한다. 이때 주의해야할점은 이 함수를 수행한다고 해서, 바로 데이타를 읽기 시작하는 것이 아니라, 데이타의 흐름을 정의하는 그래프만 생성된다는 것을 주의하자
다음으로는 year,flight,time placeholder를 정의한다.
year,flight,time 은 0 차원의 scalar 텐서이지만, 값이 연속적으로 들어오기 때문에, [None, ] 로 정의한다.
즉 year = [1900,1901,1902,1903,.....] 형태이기 때문에 1차원 Vector 형태의 shape으로 [None, ] 로 정의한다.
Placeholder 들에 대한 정의가 끝났으면, 세션을 정의하고 데이타를 읽어드리기 위한 Queue runner를 수행한다. 앞의 과정까지 텐서 그래프를 다 그렸고, 이 그래프 값을 부어넣기 위해서, Queue runner 를 수행한 것이다.
with tf.Session() as sess:
try:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
Queue runner를 실행하였기 때문에 데이타가 데이타 큐로 들어가기 시작하고, 이 큐에 들어간 데이타를 읽어드리기 위해서, 세션을 실행한다.
y_,f_,t_ = sess.run([batch_year,batch_flight,batch_time])
print sess.run(tt,feed_dict={time:t_})
세션을 실행하면, batch_year,batch_flight,batch_time 값을 읽어서 y_,f_,t_ 변수에 각각 집어 넣은 다음에, t_ 값을 tt 공식의 time 변수에 feeding 하여, 값을 계산한다.
모든 작업이 끝났으면 아래와 같이 Queue runner를 정지 시킨다.
coord.request_stop()
coord.join(threads)
다음은 앞에서 설명한 전체 코드이다.
import tensorflow as tf
import numpy as np
import sys
TRAINING_FILE = '/Users/terrycho/dev/data/flight.csv'
## read training data and label
def read_data(file_name):
try:
csv_file = tf.train.string_input_producer([file_name],name='filename_queue')
textReader = tf.TextLineReader()
_,line = textReader.read(csv_file)
year,flight,time = tf.decode_csv(line,record_defaults=[ [1900],[""],[0] ],field_delim=',')
except:
print "Unexpected error:",sys.exc_info()[0]
exit()
return year,flight,time
def read_data_batch(file_name,batch_size=10):
year,flight,time = read_data(file_name)
batch_year,batch_flight,batch_time = tf.train.batch([year,flight,time],batch_size=batch_size)
return batch_year,batch_flight,batch_time
def main():
print 'start session'
#coornator 위에 코드가 있어야 한다
#데이타를 집어 넣기 전에 미리 그래프가 만들어져 있어야 함.
batch_year,batch_flight,batch_time = read_data_batch(TRAINING_FILE)
year = tf.placeholder(tf.int32,[None,])
flight = tf.placeholder(tf.string,[None,])
time = tf.placeholder(tf.int32,[None,])
tt = time * 10
with tf.Session() as sess:
try:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(10):
y_,f_,t_ = sess.run([batch_year,batch_flight,batch_time])
print sess.run(tt,feed_dict={time:t_})
print 'stop batch'
coord.request_stop()
coord.join(threads)
except:
print "Unexpected error:", sys.exc_info()[0]
main()
다음은 실행결과이다.
'빅데이타 & 머신러닝 > 머신러닝' 카테고리의 다른 글
머신러닝 라벨 데이타 타입에 대해서 (0) | 2017.04.10 |
---|---|
텐서플로우의 세션,그래프 그리고 함수의 개념 (1) | 2017.04.03 |
연예인 얼굴 인식 서비스를 만들어보자 #2-CSV에 있는 이미지 목록을 텐서로 읽어보자 (4) | 2017.03.15 |
연예인 얼굴 인식 서비스를 만들어보자 #1 - 학습 데이타 준비하기 (2) | 2017.03.14 |
텐서플로우 - 파일에서 학습데이타를 읽어보자#2 (Reader와 Decoder) (2) | 2017.03.11 |