Introduction to Algorithms: Python in Practice (4.2 Strassen's Algorithm for Matrix Multiplication)
阿新 · Published: 2019-02-09
4.2.1 Brute-force matrix multiplication
# Brute-force solution
def matrix_multiply(a, b):
    n = len(a)
    c = [[0] * n for i in range(n)]  # quickly create an n x n zero matrix
    for i in range(0, n):
        for j in range(0, n):
            for k in range(0, n):
                c[i][j] += a[i][k] * b[k][j]
    return c
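Here a and b are assumed to be n×n square matrices. The brute-force method uses three nested for loops, each running n iterations, so it takes Θ(n^3) time in total.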
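To make the Θ(n^3) count concrete, here is a minimal sketch that instruments the same triple loop with a counter. `brute_force_count` is a hypothetical helper introduced only for this illustration. It is not part of the original program.

```python
def brute_force_count(a, b):
    """Triple-loop product of two n x n matrices (hypothetical helper).

    Returns the product together with the number of scalar
    multiplications that were performed.
    """
    n = len(a)
    c = [[0] * n for _ in range(n)]
    mults = 0  # scalar multiplications performed so far
    for i in range(n):
        for j in range(n):
            for k in range(n):
                c[i][j] += a[i][k] * b[k][j]
                mults += 1
    return c, mults


for n in (1, 2, 4, 8):
    ones = [[1] * n for _ in range(n)]
    _, mults = brute_force_count(ones, ones)
    print(n, mults, n ** 3)  # the last two columns always agree
```

The inner statement runs exactly n·n·n times, whatever the matrix entries are. That is why the bound is tight, Θ rather than just O.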
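4.2.2 A simple divide-and-conquer algorithm for matrix multiplication
Strategy (assumption: A and B are both n×n square matrices, where n is an exact power of 2):
(1) Basic idea: to compute C = A·B, partition each of C, A and B into four n/2 × n/2 blocks. Multiply the blocks recursively using C11 = A11·B11 + A12·B21, C12 = A11·B12 + A12·B22, C21 = A21·B11 + A22·B21 and C22 = A21·B12 + A22·B22. Then reassemble the resulting C11, C12, C21, C22 into C.
(2) Base case: when the blocks have order 1, C is simply the product of the single entries of A and B.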
def division(a):  # split a matrix into four n/2 x n/2 blocks
    n = len(a) // 2
    a11 = [[0 for i in range(n)] for j in range(n)]
    a12 = [[0 for i in range(n)] for j in range(n)]
    a21 = [[0 for i in range(n)] for j in range(n)]
    a22 = [[0 for i in range(n)] for j in range(n)]
    for i in range(n):
        for j in range(n):
            a11[i][j] = a[i][j]
            a12[i][j] = a[i][j + n]
            a21[i][j] = a[i + n][j]
            a22[i][j] = a[i + n][j + n]
    return (a11, a12, a21, a22)

def matrix_combination(a11, a12, a21, a22):  # reassemble four blocks into one matrix
    n2 = len(a11)
    n = n2 * 2
    a = [[0 for col in range(n)] for row in range(n)]
    for i in range(0, n):
        for j in range(0, n):
            if i <= (n2 - 1) and j <= (n2 - 1):
                a[i][j] = a11[i][j]
            elif i <= (n2 - 1) and j > (n2 - 1):
                a[i][j] = a12[i][j - n2]  # column offset is the block size n2
            elif i > (n2 - 1) and j <= (n2 - 1):
                a[i][j] = a21[i - n2][j]
            else:
                a[i][j] = a22[i - n2][j - n2]
    return a

def matrix_add(a, b):  # entrywise matrix addition
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    for i in range(0, n):
        for j in range(0, n):
            c[i][j] = a[i][j] + b[i][j]
    return c

def matrix_devision_multiply(a, b):  # main routine of the simple divide-and-conquer method
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if n == 1:
        c[0][0] = a[0][0] * b[0][0]
    else:
        # unpack in the same order that division() returns the blocks
        (a11, a12, a21, a22) = division(a)
        (b11, b12, b21, b22) = division(b)
        c11 = matrix_add(matrix_devision_multiply(a11, b11), matrix_devision_multiply(a12, b21))
        c12 = matrix_add(matrix_devision_multiply(a11, b12), matrix_devision_multiply(a12, b22))
        c21 = matrix_add(matrix_devision_multiply(a21, b11), matrix_devision_multiply(a22, b21))
        c22 = matrix_add(matrix_devision_multiply(a21, b12), matrix_devision_multiply(a22, b22))
        c = matrix_combination(c11, c12, c21, c22)
    return c

a = [[1, 1, 1, 1], [1, 1, 1, 1], [2, 2, 2, 2], [2, 2, 2, 2]]
b = a
print(matrix_devision_multiply(a, b))
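The same recursion can be written more compactly with list slicing. Its four `join` arguments mirror the four block equations one to one. This is a sketch under the same power-of-2 assumption. The names `split`, `add`, `join` and `dc_multiply` are hypothetical and do not come from the original code. The non-symmetric 2×2 check in the usage line helps catch mistakes in block ordering, which a symmetric-looking test matrix can hide.

```python
def split(m):
    """Split a square matrix of even order into its four quadrants."""
    h = len(m) // 2
    return ([row[:h] for row in m[:h]], [row[h:] for row in m[:h]],
            [row[:h] for row in m[h:]], [row[h:] for row in m[h:]])


def add(x, y):
    """Entrywise sum of two equally sized matrices."""
    return [[p + q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]


def join(c11, c12, c21, c22):
    """Glue four quadrants back into one matrix."""
    return ([r1 + r2 for r1, r2 in zip(c11, c12)]
            + [r1 + r2 for r1, r2 in zip(c21, c22)])


def dc_multiply(a, b):
    """Simple divide and conquer: 8 recursive products per level (n a power of 2)."""
    if len(a) == 1:
        return [[a[0][0] * b[0][0]]]
    a11, a12, a21, a22 = split(a)
    b11, b12, b21, b22 = split(b)
    return join(add(dc_multiply(a11, b11), dc_multiply(a12, b21)),
                add(dc_multiply(a11, b12), dc_multiply(a12, b22)),
                add(dc_multiply(a21, b11), dc_multiply(a22, b21)),
                add(dc_multiply(a21, b12), dc_multiply(a22, b22)))


a = [[1, 1, 1, 1], [1, 1, 1, 1], [2, 2, 2, 2], [2, 2, 2, 2]]
print(dc_multiply(a, a))
# [[6, 6, 6, 6], [6, 6, 6, 6], [12, 12, 12, 12], [12, 12, 12, 12]]
print(dc_multiply([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

Slicing copies the blocks, just as `division` does, so this version is no faster. It is only shorter to read.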
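4.2.3 Strassen's algorithm for matrix multiplication
Strassen's algorithm builds on the simple divide-and-conquer method and makes the recursion tree less bushy. Each recursive call performs only 7 multiplications of n/2 × n/2 matrices instead of 8. The price for saving one multiplication is a constant number of extra matrix additions and subtractions, each taking Θ(n^2) time.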
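For reference, here are the seven products and the way they are recombined, written in the textbook's notation. They match the s1 to s10 and p1 to p7 in the code that follows. Below them are the two running-time recurrences, where the Θ(n^2) term covers the block additions and subtractions.

```latex
\begin{aligned}
P_1 &= A_{11}(B_{12}-B_{22}), & P_2 &= (A_{11}+A_{12})B_{22},\\
P_3 &= (A_{21}+A_{22})B_{11}, & P_4 &= A_{22}(B_{21}-B_{11}),\\
P_5 &= (A_{11}+A_{22})(B_{11}+B_{22}), & P_6 &= (A_{12}-A_{22})(B_{21}+B_{22}),\\
P_7 &= (A_{11}-A_{21})(B_{11}+B_{12}), & &\\
C_{11} &= P_5+P_4-P_2+P_6, & C_{12} &= P_1+P_2,\\
C_{21} &= P_3+P_4, & C_{22} &= P_5+P_1-P_3-P_7,\\[4pt]
T_{\text{simple}}(n) &= 8\,T(n/2)+\Theta(n^2) = \Theta(n^3), & &\\
T_{\text{Strassen}}(n) &= 7\,T(n/2)+\Theta(n^2) = \Theta(n^{\lg 7}) \approx \Theta(n^{2.81}). & &
\end{aligned}
```

Expanding C11 shows why the cancellation works. P5 + P4 − P2 + P6 reduces to A11·B11 + A12·B21, which is exactly the block formula from 4.2.2.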
# Strassen's algorithm for matrix multiplication
def matrix_strassen(a, b):
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if n == 1:
        c[0][0] = a[0][0] * b[0][0]
    else:
        (a11, a12, a21, a22) = division(a)
        (b11, b12, b21, b22) = division(b)
        # the ten sums/differences S1..S10 (keys=1: add, keys=0: subtract)
        s1 = matrix_add_sub(b12, b22, 0)
        s2 = matrix_add_sub(a11, a12, 1)
        s3 = matrix_add_sub(a21, a22, 1)
        s4 = matrix_add_sub(b21, b11, 0)
        s5 = matrix_add_sub(a11, a22, 1)
        s6 = matrix_add_sub(b11, b22, 1)
        s7 = matrix_add_sub(a12, a22, 0)
        s8 = matrix_add_sub(b21, b22, 1)
        s9 = matrix_add_sub(a11, a21, 0)
        s10 = matrix_add_sub(b11, b12, 1)
        # the seven recursive products P1..P7
        p1 = matrix_strassen(a11, s1)
        p2 = matrix_strassen(s2, b22)
        p3 = matrix_strassen(s3, b11)
        p4 = matrix_strassen(a22, s4)
        p5 = matrix_strassen(s5, s6)
        p6 = matrix_strassen(s7, s8)
        p7 = matrix_strassen(s9, s10)
        # combine the products into the four blocks of C
        c11 = matrix_add_sub(matrix_add_sub(matrix_add_sub(p5, p4, 1), p2, 0), p6, 1)
        c12 = matrix_add_sub(p1, p2, 1)
        c21 = matrix_add_sub(p3, p4, 1)
        c22 = matrix_add_sub(matrix_add_sub(matrix_add_sub(p5, p1, 1), p3, 0), p7, 0)
        c = matrix_combination(c11, c12, c21, c22)
    return c

def division(a):  # split a matrix into four n/2 x n/2 blocks
    n = len(a) // 2
    a11 = [[0 for i in range(n)] for j in range(n)]
    a12 = [[0 for i in range(n)] for j in range(n)]
    a21 = [[0 for i in range(n)] for j in range(n)]
    a22 = [[0 for i in range(n)] for j in range(n)]
    for i in range(n):
        for j in range(n):
            a11[i][j] = a[i][j]
            a12[i][j] = a[i][j + n]
            a21[i][j] = a[i + n][j]
            a22[i][j] = a[i + n][j + n]
    return (a11, a12, a21, a22)

def matrix_add_sub(a, b, keys):  # keys=1: a + b, keys=0: a - b
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if keys == 1:
        for i in range(n):
            for j in range(n):
                c[i][j] = a[i][j] + b[i][j]
    else:
        for i in range(n):
            for j in range(n):
                c[i][j] = a[i][j] - b[i][j]
    return c

def matrix_combination(a11, a12, a21, a22):  # reassemble four blocks into one matrix
    n2 = len(a11)
    n = n2 * 2
    a = [[0 for col in range(n)] for row in range(n)]
    for i in range(0, n):
        for j in range(0, n):
            if i <= (n2 - 1) and j <= (n2 - 1):
                a[i][j] = a11[i][j]
            elif i <= (n2 - 1) and j > (n2 - 1):
                a[i][j] = a12[i][j - n2]
            elif i > (n2 - 1) and j <= (n2 - 1):
                a[i][j] = a21[i - n2][j]
            else:
                a[i][j] = a22[i - n2][j - n2]
    return a

a = [[1, 1, 1, 1], [1, 1, 1, 1], [2, 2, 2, 2], [2, 2, 2, 2]]
b = a
print(matrix_strassen(a, b))
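To check the 7-versus-8 claim empirically, here is a self-contained sketch. It re-implements the same S/P formulas with list slicing and tallies the scalar multiplications performed at the 1×1 leaves. It also compares every result with a plain triple-loop product on random integer matrices. The helper names and the one-element-list counter are choices made for this illustration, not part of the original program.

```python
import random


def split(m):
    h = len(m) // 2
    return ([r[:h] for r in m[:h]], [r[h:] for r in m[:h]],
            [r[:h] for r in m[h:]], [r[h:] for r in m[h:]])


def add(x, y):
    return [[p + q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]


def sub(x, y):
    return [[p - q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]


def join(c11, c12, c21, c22):
    return ([r1 + r2 for r1, r2 in zip(c11, c12)]
            + [r1 + r2 for r1, r2 in zip(c21, c22)])


def strassen(a, b, counter):
    """Strassen product of two 2**k x 2**k matrices.

    counter is a one-element list; counter[0] accumulates the number of
    scalar multiplications done at the 1 x 1 base case.
    """
    if len(a) == 1:
        counter[0] += 1
        return [[a[0][0] * b[0][0]]]
    a11, a12, a21, a22 = split(a)
    b11, b12, b21, b22 = split(b)
    p1 = strassen(a11, sub(b12, b22), counter)
    p2 = strassen(add(a11, a12), b22, counter)
    p3 = strassen(add(a21, a22), b11, counter)
    p4 = strassen(a22, sub(b21, b11), counter)
    p5 = strassen(add(a11, a22), add(b11, b22), counter)
    p6 = strassen(sub(a12, a22), add(b21, b22), counter)
    p7 = strassen(sub(a11, a21), add(b11, b12), counter)
    c11 = add(sub(add(p5, p4), p2), p6)
    c12 = add(p1, p2)
    c21 = add(p3, p4)
    c22 = sub(sub(add(p5, p1), p3), p7)
    return join(c11, c12, c21, c22)


def naive(a, b):
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


random.seed(0)
for n in (1, 2, 4, 8):
    a = [[random.randint(-9, 9) for _ in range(n)] for _ in range(n)]
    b = [[random.randint(-9, 9) for _ in range(n)] for _ in range(n)]
    counter = [0]
    assert strassen(a, b, counter) == naive(a, b)
    print(n, counter[0], n ** 3)  # 7**k vs 8**k scalar multiplications, n = 2**k
```

For n = 8 the count is 343 instead of 512. The saving in multiplications is paid for with many more additions and list copies, so on small matrices in pure Python this version is usually not faster than the triple loop.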
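4.2.4 Modifying Strassen's algorithm for matrices whose order n is not a power of 2
The idea is to extend a matrix whose order is not a power of 2 to the next power of 2 by filling the extra rows and columns with zeros. After the result C has been reassembled, the extra rows and columns of zeros are discarded again. This is correct because the padding does not touch the original block: [[A, 0], [0, 0]]·[[B, 0], [0, 0]] = [[A·B, 0], [0, 0]]. On top of the Strassen program above, two functions are therefore added: matrix_expand and matrix_shrink. The main program first checks the order of the input matrices A and B. If it is a power of 2, nothing extra is needed and the ordinary Strassen algorithm is used directly. Otherwise, A and B are expanded first, and the computed result is then shrunk back to the original size.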
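The padding step can be sketched in a few lines. The target size is computed with integers only, so floating-point `log` cannot misjudge large powers of 2. The block also checks that the top-left n×n corner of the padded product equals the unpadded product. The names `next_power_of_two`, `pad` and `unpad` are hypothetical. For brevity the sketch uses a plain triple-loop `naive` product in place of Strassen, because any correct multiplier gives the same corner.

```python
def next_power_of_two(n):
    """Smallest power of 2 that is >= n (n >= 1), using integer arithmetic only."""
    return 1 << (n - 1).bit_length()


def pad(m, p):
    """Embed the n x n matrix m in the top-left corner of a p x p zero matrix."""
    n = len(m)
    return [row + [0] * (p - n) for row in m] + [[0] * p for _ in range(p - n)]


def unpad(m, n):
    """Keep only the top-left n x n corner of m."""
    return [row[:n] for row in m[:n]]


def naive(a, b):
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


a = [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2],
     [2, 2, 2, 2, 2], [3, 3, 3, 3, 3]]
p = next_power_of_two(len(a))
big = naive(pad(a, p), pad(a, p))
print(p, unpad(big, len(a)) == naive(a, a))  # 8 True
```

`(n - 1).bit_length()` is the number of bits needed to write n - 1. Shifting 1 left by that amount gives the first power of 2 that is not smaller than n, so 5 becomes 8 and 8 stays 8.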
# -*- coding: utf-8 -*-
# Strassen's algorithm for matrix multiplication
from math import ceil, floor, log
def matrix_strassen(a, b):
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if n == 1:
        c[0][0] = a[0][0] * b[0][0]
    else:
        (a11, a12, a21, a22) = division(a)
        (b11, b12, b21, b22) = division(b)
        s1 = matrix_add_sub(b12, b22, 0)
        s2 = matrix_add_sub(a11, a12, 1)
        s3 = matrix_add_sub(a21, a22, 1)
        s4 = matrix_add_sub(b21, b11, 0)
        s5 = matrix_add_sub(a11, a22, 1)
        s6 = matrix_add_sub(b11, b22, 1)
        s7 = matrix_add_sub(a12, a22, 0)
        s8 = matrix_add_sub(b21, b22, 1)
        s9 = matrix_add_sub(a11, a21, 0)
        s10 = matrix_add_sub(b11, b12, 1)
        p1 = matrix_strassen(a11, s1)
        p2 = matrix_strassen(s2, b22)
        p3 = matrix_strassen(s3, b11)
        p4 = matrix_strassen(a22, s4)
        p5 = matrix_strassen(s5, s6)
        p6 = matrix_strassen(s7, s8)
        p7 = matrix_strassen(s9, s10)
        c11 = matrix_add_sub(matrix_add_sub(matrix_add_sub(p5, p4, 1), p2, 0), p6, 1)
        c12 = matrix_add_sub(p1, p2, 1)
        c21 = matrix_add_sub(p3, p4, 1)
        c22 = matrix_add_sub(matrix_add_sub(matrix_add_sub(p5, p1, 1), p3, 0), p7, 0)
        c = matrix_combination(c11, c12, c21, c22)
    return c
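def matrix_expand(a):  # pad a (or b) with zeros up to the next power of 2
    n = len(a)
    p = 1
    while p < n:  # smallest power of 2 that is >= n, computed with integers only
        p *= 2
    c = [[0 for col in range(p)] for row in range(p)]  # p x p zero matrix
    for i in range(n):
        for j in range(n):
            c[i][j] = a[i][j]  # copy the original entries; the padding stays 0
    return c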
def matrix_shrink(a, b):  # cut the padded result a back to the size of the original matrix b
    n = len(b)
    c = [[0 for col in range(n)] for row in range(n)]
    for i in range(n):
        for j in range(n):
            c[i][j] = a[i][j]
    return c
def division(a):  # split a matrix into four n/2 x n/2 blocks
    n = len(a) // 2
    a11 = [[0 for i in range(n)] for j in range(n)]
    a12 = [[0 for i in range(n)] for j in range(n)]
    a21 = [[0 for i in range(n)] for j in range(n)]
    a22 = [[0 for i in range(n)] for j in range(n)]
    for i in range(n):
        for j in range(n):
            a11[i][j] = a[i][j]
            a12[i][j] = a[i][j + n]
            a21[i][j] = a[i + n][j]
            a22[i][j] = a[i + n][j + n]
    return (a11, a12, a21, a22)
def matrix_add_sub(a, b, keys):  # matrix addition/subtraction: keys=1 adds, keys=0 subtracts
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if keys == 1:
        for i in range(n):
            for j in range(n):
                c[i][j] = a[i][j] + b[i][j]
    else:
        for i in range(n):
            for j in range(n):
                c[i][j] = a[i][j] - b[i][j]
    return c
def matrix_combination(a11, a12, a21, a22):  # reassemble four blocks into one matrix
    n2 = len(a11)
    n = n2 * 2
    a = [[0 for col in range(n)] for row in range(n)]
    for i in range(0, n):
        for j in range(0, n):
            if i <= (n2 - 1) and j <= (n2 - 1):
                a[i][j] = a11[i][j]
            elif i <= (n2 - 1) and j > (n2 - 1):
                a[i][j] = a12[i][j - n2]
            elif i > (n2 - 1) and j <= (n2 - 1):
                a[i][j] = a21[i - n2][j]
            else:
                a[i][j] = a22[i - n2][j - n2]
    return a
a = [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3]]
b = a
n = len(a)
if (n & (n - 1)) == 0:  # n is a power of 2 (exact integer test, no floating-point log)
    print(matrix_strassen(a, b))
else:
    print(matrix_shrink(matrix_strassen(matrix_expand(a), matrix_expand(b)), a))
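This is the simplest approach I could think of, but it adds work. A 5×5 matrix, for example, is turned into an 8×8 matrix, which introduces many unnecessary multiplications involving the padded zeros. I have not yet looked for related material online, and my ability is limited, so please bear with me.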
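The extra work from padding can be estimated roughly by counting only the scalar multiplications that Strassen performs at its 1×1 leaves. After padding to p = 2^k, that count is 7^k. The sketch below compares it with the n^3 multiplications of the brute-force loop. It is a rough model under a stated assumption: additions, subtractions and copying are ignored. The helper names are hypothetical.

```python
def next_power_of_two(n):
    """Smallest power of 2 that is >= n (n >= 1)."""
    return 1 << (n - 1).bit_length()


def strassen_scalar_mults(n):
    """Leaf multiplications of Strassen after padding n to p = 2**k: 7**k.

    Rough model only: additions, subtractions and copies are not counted.
    """
    k = next_power_of_two(n).bit_length() - 1  # k = log2(p)
    return 7 ** k


for n in (4, 5, 8, 9, 16, 17):
    p = next_power_of_two(n)
    print(f"n={n:2d} padded to p={p:2d}: "
          f"Strassen {strassen_scalar_mults(n):6d}, brute force n^3 = {n ** 3:5d}")
```

For n = 5 the padded Strassen run needs 343 leaf multiplications against 125 for the triple loop. The count jumps by a factor of 7 each time n passes a power of 2, which is exactly the overhead the remark above describes.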
References:
1. Introduction to Algorithms, China Machine Press, Chapter 4, Section 2 (Strassen's algorithm for matrix multiplication).