tensorflow高级用法

在CV工作中,为了实现data agumentation,我们常常需要对图片进行旋转,裁剪,平移,加噪等操作,而tensorflow只提供了部分图片操作功能。我们就需要手动实现一些效果,tf.map_fn() 及 tf.py_fn() 可供参考。tf.map_fn(fn, elems):接受一个函数对象,然后用该函数对象对集合(elems)中的每一个元素分别处理,

实现图片旋转功能
def get_imagbatch(filename, batch_size ):
    ......
   return tf.train.batch([mypng ], batch_size = batch_size )

def random_rotate_image(image):  
    angle = random.uniform(-30.0, 30.0)
    image_rotate = misc.imrotate(image, angle, )
    image_rotate = tf.cast(image_rotate, tf.float32)
    return image_rotate
firBatch = get_imagbatch(.....)
firBatch = tf.map_fn(lambda png:random_rotate_image(png), firBatch, tf.float32)

ref:http://blog.csdn.net/xukaiwen_2016/article/details/77571415?ABstrategy=codes_snippets_optimize_v4

https://github.com/agubrud/udacity_carnd_term1/blob/1b9c1da5ee3ddd3ac9ca6295854a732f501dd0d1/CarND-Traffic-Sign-Classifier-Project/writeup_gubrud_aaron.md

  1. tf.py_func() :用来将 一个 python 函数打包成一个 op, 测试了一下,很可惜,无法求导。

注意:如果python_func()函数有 string 参数的话,tf会把这个string参数 转换成 bytes 类型。

 ———————————分割线————————————-

用tf.py_func()时遇到的坑。

InvalidArgumentError (see above for traceback): ConcatOp : Ranks of all input tensors should match: shape[0] = [4,384,512,3] vs. shape[3] = [1,4,384,512,1]

打印出来的shape[3] 明明是[4, 384, 512, 1], 为毛会报这个错呢??

def _interact(gd_flows, pre_flows):

         ......
      return [np.array(maps, np.float32)]
maps = tf.py_func( _interact, [gd_flows, pre_flows],    [tf.float32], name='interact_map')
def _interact(gd_flows, pre_flows):

         ......
      return np.array(maps, np.float32)
maps = tf.py_func( _interact, [gd_flows, pre_flows],    tf.float32, name='interact_map')

只需要将返回值和Tout 的【】去掉。。。。。

发表评论

电子邮件地址不会被公开。 必填项已用*标注