tf.shape, x.shape, x.get_shape的區別
阿新 • • 發佈:2018-11-25
tf.shape, x.shape, x.get_shape的區別
tf.shape(x)
回傳的型別為tensorflow.python.framework.ops.Tensor,因此可以在計算圖中使用。在使用sess.run或x.eval()後可以得到x的具體形狀(不會有?的存在)。
x.shape, x.get_shape()
可以將這兩種寫法想成是一樣的功能,回傳的型別皆為tensorflow.python.framework.tensor_shape.TensorShape
使用示例
下面展示了這三種用法的不同之處:
x.shape及x.get_shape()可以單獨使用,得到tensor x在定義時的形狀。
tf.shape則可以被放入計算圖中,得到張量在執行後真正的形狀。同時它也可以被放入其他的tf運算內(此處將tf.shape放到tf.reshape內),成為計算圖的一部份。
// An highlighted block
x1 = np.arange(32).reshape(2,16)
x2 = np.arange(32).reshape(4,8)
a = tf. placeholder(shape=(None, 16), dtype=tf.float32)
b = tf.placeholder(shape=(None, 8), dtype=tf.float32)
c = tf.reshape(b, tf.shape(a))
# c = tf.reshape(b, a.shape) #not work
# c = tf.reshape(b, a.get_shape()) #not work
print(a.shape) #(?, 8)
print(a.get_shape()) #(?, 8)
with tf.Session() as sess:
print (sess.run(tf.shape(a), feed_dict={a: x1})) #[ 2 16]
print(sess.run(a, feed_dict={a: x1}))
"""
result:
[[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15.]
[16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.]]
"""
print(sess.run(b, feed_dict={b: x2}))
"""
result:
[[ 0. 1. 2. 3. 4. 5. 6. 7.]
[ 8. 9. 10. 11. 12. 13. 14. 15.]
[16. 17. 18. 19. 20. 21. 22. 23.]
[24. 25. 26. 27. 28. 29. 30. 31.]]
"""
print(sess.run(c, feed_dict={a: x1, b: x2}))
"""
result:
[[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15.]
[16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.]]
"""
參考連結
[1]https://stackoverflow.com/questions/37085430/tf-shape-get-wrong-shape-in-tensorflow