tensorflow-tf.shape(x)、x.shape和x.get_shape()的区别

tf.shape(x)、x.shape和x.get_shape()的区别

对于Tensor来说

 1import tensorflow as tf
 2
 3input = tf.constant([[0,1,2],[3,4,5]])
 4
 5print(type(input.shape))
 6print(type(input.get_shape()))
 7print(type(tf.shape(input)))
 8
 9Out:
10<class 'tensorflow.python.framework.tensor_shape.TensorShape'>
11<class 'tensorflow.python.framework.tensor_shape.TensorShape'>
12<class 'tensorflow.python.framework.ops.Tensor'>

可以看到x.shapex.get_shape()都是返回TensorShape类型对象,而tf.shape(x)返回的是Tensor类型对象。

具体来说tf.shape()返回的是tensor,想要获取tensor具体的shape结果需要sess.run才行。而tf.get_shapex.shape返回的是一个元组,因此要想操作维度信息,则需要调用TensorShape的tf.as_list()方法,返回的是Python的list。

需要注意的是tf.get_shape()返回的是元组,不能放到sess.run()里面,这个里面只能放operation和tensor

对于placeholder来说

tf.placeholder占位符来说,如果shape设置的其中某一个是None,那么对于tf.shape,sess.run会报错,而tf.get_shape不会,它会在None位置显示“?”表示此位置的shape暂时未知。

 1a = tf.Variable(tf.constant(1.5, dtype=tf.float32, shape=[1,2,3,4,5,6,7]), name='a')
 2b = tf.placeholder(dtype=tf.int32, shape=[None, 3], name='b')
 3s1 = tf.shape(a)
 4s2 = a.get_shape()
 5print (s1)  # Tensor("Shape:0", shape=(7,), dtype=int32)
 6print (s2)  # 元组 (1, 2, 3, 4, 5, 6, 7)
 7
 8s11 = tf.shape(b)
 9s21 = b.get_shape()
10print (s11)  # Tensor("Shape_1:0", shape=(2,), dtype=int32)
11print (s21)  # 因为第一位设置的是None,所以这里的第一位显示问号表示暂时不确认 (?, 3)
12with tf.Session() as sess:
13    sess.run(tf.global_variables_initializer())
14    print (sess.run(s1))  # [1 2 3 4 5 6 7]
15    print (sess.run(s11))
16    # InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'b' with dtype int32
17    # [[Node: b = Placeholder[dtype=DT_INT32, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]