淺談keras中的batch_dot,dot方法和TensorFlow的matmul
阿新 • • 發佈:2020-06-19
概述
在使用keras中的keras.backend.batch_dot和tf.matmul實現功能其實是一樣的智慧矩陣乘法,比如A,B,C,D,E,F,G,H,I,J,K,L都是二維矩陣,中間點表示矩陣乘法,AG 表示矩陣A 和G 矩陣乘法(A 的列維度等於G 行維度),WX=Z
import keras.backend as K import tensorflow as tf import numpy as np w = K.variable(np.random.randint(10,size=(10,12,4,5))) k = K.variable(np.random.randint(10,5,8))) z = K.batch_dot(w,k) print(z.shape) #(10,8)
import keras.backend as K import tensorflow as tf import numpy as np w = tf.Variable(np.random.randint(10,5)),dtype=tf.float32) k = tf.Variable(np.random.randint(10,8)),dtype=tf.float32) z = tf.matmul(w,8)
示例
from keras import backend as K a = K.ones((3,2)) b = K.ones((2,3,7)) c = K.dot(a,b) print(c.shape)
會輸出:
ValueError: Dimensions must be equal,but are 2 and 3 for ‘MatMul' (op: ‘MatMul') with input shapes: [60,2],[3,70].
from keras import backend as K a = K.ones((3,4)) b = K.ones((4,5)) c = K.dot(a,b) print(c.shape)#(3,5)
或者
import tensorflow as tf a = tf.ones((3,4)) b = tf.ones((4,5)) c = tf.matmul(a,5)
如果增加維度:
from keras import backend as K a = K.ones((2,4)) b = K.ones((7,b) print(c.shape)#(2,7,5)
這個矩陣乘法會沿著兩個矩陣最後兩個維度進行乘法,不是element-wise矩陣乘法
from keras import backend as K a = K.ones((1,2,4)) b = K.ones((8,b) print(c.shape)#(1,8,5)
keras的dot方法是Theano中的複製
from keras import backend as K a = K.ones((1,b) print(c.shape)# (1,5).
from keras import backend as K a = K.ones((9,2)) b = K.ones((9,5)) c = K.batch_dot(a,b) print(c.shape) #(9,5)
或者
import tensorflow as tf a = tf.ones((9,2)) b = tf.ones((9,5)
以上這篇淺談keras中的batch_dot,dot方法和TensorFlow的matmul就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。