tensorflow 基本函式(不斷更新哦)
阿新 • • 發佈:2018-12-20
1. tf.split(3, group, input) # 拆分函式
3 表示的是在第三個維度上, group表示拆分的次數, input 表示輸入的值
import tensorflow as tf import numpy as np x = [[1, 2], [3, 4]] Y = tf.split(axis=1, num_or_size_splits=2, value=x) sess = tf.Session() for y in Y: print(sess.run(y))
2. tf.concat(3, input) # 串接函式
3 表示的是在第三個維度上, input表示的是輸入,輸入一般都是列表
import tensorflow as tf x = [[1, 2], [3, 4]] y = tf.concat(x, axis=0) sess = tf.Session() print(sess.run(y))
3. tf.squeeze(input, squeeze_dims=[1, 2]) # 表示的是去除列數為1的維度, squeeze_dim 指定維度
import tensorflow as tf import numpy as np x = [[1, 2]] print(np.array(x).shape) y = tf.squeeze(x, axis=[0]) sess= tf.Session() print(sess.run(y))
4. tf.less_equal(a, b) a 可以是一個列表, b表示需要比較的數,如果比b大返回false,否者返回True
import tensorflow as tf import numpy as np raw_gt = [1, 2, 3, 4] y = tf.where(tf.less_equal(raw_gt, 2)) sess = tf.Session() print(sess.run(y))
5.tf.where(input) # 返回是真的序號,通過tf.where找出小於等於2的數的序號
import tensorflow as tf import numpy as np raw_gt = [1, 2, 3, 4] y = tf.where(tf.less_equal(raw_gt, 2)) sess = tf.Session() print(sess.run(y))
6. tf.gather # 根據序列號對資料進行取值,輸入的是input, index
import tensorflow as tf import numpy as np raw_gt = [3, 4, 5, 6] y = tf.gather(raw_gt, [[0], [1]]) sess = tf.Session() print(sess.run(y))