|
先放关键代码:- i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
复制代码 原理剖析:
第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时间,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;
0,1,2,0,1,2
队列只有6个元素,这样在训练的时间只能产生6个batch,迭代6次以后训练就结束。
如果num_epochs不指定,则队列内容是这样子:
0,1,2,0,1,2,0,1,2,0,1,2...
队列可以一直生成元素,训练的时间可以产生无限的batch,需要本身控制什么时间停止训练。
下面是完整的演示代码。
数据文件test.txt内容: - 1234567891011121314151617181920212223242526272829303132333435
复制代码 main.py内容:
[code]import tensorflow as tfimport codecs BATCH_SIZE = 6NUM_EXPOCHES = 5 def input_producer(): array = codecs.open("test.txt").readlines() array = map(lambda line: line.strip(), array) i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue() inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE]) return inputs class Inputs(object): def __init__(self): self.inputs = input_producer() def main(*args, **kwargs): inputs = Inputs() init = tf.group(tf.initialize_all_variables(), tf.initialize_local_variables()) sess = tf.Session() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) sess.run(init) try: index = 0 while not coord.should_stop() and index |
|