tf.split()函数的用法 - tensorflow

半兽人 发表于: 2019-04-16   最后更新时间: 2019-04-16 23:35:38  
{{totalSubscript}} 订阅, 8,029 游览

tensorflow的代码里经常看到tf.split()这个函数,我们来看看这个具体用法

tf.split(
    value,
    num_or_size_splits,
    axis=0,
    num=None,
    name='split'
)

把一个张量划分成几个子张量:

  • value:准备切分的张量
  • num_or_size_splits:准备切成几份
  • axis : 准备在第几个维度上进行切割

其中分割方式分为两种

  1. 如果num_or_size_splits传入的 是一个整数,那直接在axis=D这个维度上把张量平均切分成几个小张量

  2. 如果num_or_size_splits传入的是一个向量(这里向量各个元素的和要跟原本这个维度的数值相等)就根据这个向量有几个元素分为几项)
    举个例子

# 张量为(5, 30)
# 这个时候5是axis=0, 30是axis=1,如果要在axis=1这个维度上把这个张量拆分成三个子张量
# 传入向量时
split0, split1, split2 = tf.split(value, [4, 15, 11], 1)
tf.shape(split0)  # [5, 4]
tf.shape(split1)  # [5, 15]
tf.shape(split2)  # [5, 11]
# 传入整数时
split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1)
tf.shape(split0)  # [5, 10]

在来个详细的例子:

import tensorflow as tf

value = [[1, 2, 3, 4],
         [5, 6, 7, 8],
         [9, 10, 11, 12]]

print('axis=0时,拆分....')
split0, split1, split2 = tf.split(value, [1, 1, 1], 0)
with tf.Session() as sess:
    print(sess.run(split0))
    print("------------")
    print(sess.run(split1))
    print("------------")
    print(sess.run(split2))

print('axis=1时,拆分....')
split0, split1, split2 = tf.split(value, [1, 2, 1], 1)
with tf.Session() as sess:
    print(sess.run(split0))
    print("------------")
    print(sess.run(split1))
    print("------------")
    print(sess.run(split2))

运行结果:

[[1 2 3 4]]
------------
[[5 6 7 8]]
------------
[[ 9 10 11 12]]
axis=1时,拆分....
[[1]
 [5]
 [9]]
------------
[[ 2  3]
 [ 6  7]
 [10 11]]
------------
[[ 4]
 [ 8]
 [12]]
更新于 2019-04-16
在线,2小时前登录

查看TensorFlow更多相关的文章或提一个关于TensorFlow的问题,也可以与我们一起分享文章