TENSORFLOW 生成batch测试

import tensorflow as tf

IMAGE_SIZE = 256
height = IMAGE_SIZE
width = IMAGE_SIZE

# 我们要读两幅图片x.png, y.png
filename = ['x.png', 'y.png' ]
# string_input_producer会产生一个文件名队列
filename_queue = tf.train.string_input_producer(filename, num_epochs=6)
# reader从文件名队列中读数据。对应的方法是reader.read
reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)#此处读出的value为二进制数据
#tf.random_crop(reshaped_image, [height, width, 3])
# 解码png图片数据并切片
mypng = tf.image.decode_png(value, channels=3)
mypng = tf.random_crop(mypng, [height, width, 3])
#生成包括4张图片的batch,此步骤省去可以得到单独的图片tensor
mypng_batch = tf.train.batch([mypng], batch_size = 4)
with tf.Session() as sess:

# tf.train.string_input_producer定义了一个epoch变量,要对它进行初始化
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()  #创建一个协调器,管理线程
    # 使用start_queue_runners之后,才会开始填充队列
    threads = tf.train.start_queue_runners(coord=coord)
    i = 0
    while True:
        i += 1
        # 获取图片数据并保存
        try:#因为定义了 num_epochs=6,故6个 epoch后会抛出outofrange异常,要捕获
            image_data = sess.run(mypng_batch)
            print(image_data.shape , '+++++++++++++++')
        except Exception:
    print('ecpch搞定!')
    break
   with open('/home/chuchienshu/test_%d.jpg' % i, 'wb') as f:
f.write(image_data)   #此处写的是tensor数据,并不真的可以看到图片,可以写入上面提及的二进制数据

2 Feedbacks on “TENSORFLOW 生成batch测试”

    1. 你好,这只是一段toy example代码。对于以原图作为输入的数据集来说,图片对应的名称应该作为label,此处可理解label分别为x.png 与 y.png,其它的数据集,例如csv文件,二进制文件等,一般会将label数据与图片信息写在一起,读取时分离即可。

发表评论

电子邮件地址不会被公开。 必填项已用*标注