在tensorflow
的代码里经常看到tf.split()
这个函数,我们来看看这个具体用法
tf.split(
value,
num_or_size_splits,
axis=0,
num=None,
name='split'
)
把一个张量划分成几个子张量:
- value:准备切分的张量
- num_or_size_splits:准备切成几份
- axis : 准备在第几个维度上进行切割
其中分割方式分为两种
如果
num_or_size_splits
传入的 是一个整数,那直接在axis=D这个维度上把张量平均切分成几个小张量如果
num_or_size_splits
传入的是一个向量(这里向量各个元素的和要跟原本这个维度的数值相等)就根据这个向量有几个元素分为几项)
举个例子
# 张量为(5, 30)
# 这个时候5是axis=0, 30是axis=1,如果要在axis=1这个维度上把这个张量拆分成三个子张量
# 传入向量时
split0, split1, split2 = tf.split(value, [4, 15, 11], 1)
tf.shape(split0) # [5, 4]
tf.shape(split1) # [5, 15]
tf.shape(split2) # [5, 11]
# 传入整数时
split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1)
tf.shape(split0) # [5, 10]
在来个详细的例子:
import tensorflow as tf
value = [[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]]
print('axis=0时,拆分....')
split0, split1, split2 = tf.split(value, [1, 1, 1], 0)
with tf.Session() as sess:
print(sess.run(split0))
print("------------")
print(sess.run(split1))
print("------------")
print(sess.run(split2))
print('axis=1时,拆分....')
split0, split1, split2 = tf.split(value, [1, 2, 1], 1)
with tf.Session() as sess:
print(sess.run(split0))
print("------------")
print(sess.run(split1))
print("------------")
print(sess.run(split2))
运行结果:
[[1 2 3 4]]
------------
[[5 6 7 8]]
------------
[[ 9 10 11 12]]
axis=1时,拆分....
[[1]
[5]
[9]]
------------
[[ 2 3]
[ 6 7]
[10 11]]
------------
[[ 4]
[ 8]
[12]]