搜索
Hi~登录注册
查看: 1012|回复: 0

tensorflow使用range_input_producer多线程读取数据实例

[复制链接]

0

主题

0

帖子

10

积分

新手上路

Rank: 1

积分
10
发表于 2020-1-27 13:23:06 | 显示全部楼层 |阅读模式
先放关键代码:
  1. i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
复制代码
原理剖析:
第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时间,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;
0,1,2,0,1,2
队列只有6个元素,这样在训练的时间只能产生6个batch,迭代6次以后训练就结束。
如果num_epochs不指定,则队列内容是这样子:
0,1,2,0,1,2,0,1,2,0,1,2...
队列可以一直生成元素,训练的时间可以产生无限的batch,需要本身控制什么时间停止训练。
下面是完整的演示代码。
数据文件test.txt内容:
  1. 1234567891011121314151617181920212223242526272829303132333435
复制代码
main.py内容:
[code]import tensorflow as tfimport codecs BATCH_SIZE = 6NUM_EXPOCHES = 5  def input_producer(): array = codecs.open("test.txt").readlines()        array = map(lambda line: line.strip(), array) i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue() inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE]) return inputs  class Inputs(object): def __init__(self):  self.inputs = input_producer()  def main(*args, **kwargs): inputs = Inputs() init = tf.group(tf.initialize_all_variables(),     tf.initialize_local_variables()) sess = tf.Session() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) sess.run(init) try:  index = 0  while not coord.should_stop() and index
回复

使用道具 举报

游客
回复
您需要登录后才可以回帖 登录 | 点我注册

快速回复 返回顶部 返回列表