C++ 方矩陣乘法 + Strassen矩陣
這幾天看演算法導論,看到矩陣一章,就實現了一下。
下面是普通的矩陣乘法,複雜度為:n^3。
template<unsigned M,unsigned N, unsigned Q> void Square_matrix_multiply(int(&A)[M][N], int(&B)[N][Q], int(&C)[M][Q]) { for (size_t i = 0;i != M;++i) { for (size_t j = 0;j != Q;++j) { C[i][j] = 0; for (size_t n = 0;n != N;++n) { C[i][j] += A[i][n] * B[n][j]; } } } }
函式接受三個二維陣列,A * B得到的矩陣賦值給C。
下面是分治策略的演算法。
template<typename T> Matrix Square_max_matrix_multiply_recursive(const T &A, const T &B) { size_t n = A.rows(); Matrix C(n, n); if (n == 1) return C = A.get()*B.get(); else { MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2); // 使用一個類MatirxRef MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2); // 含有三個size_t型別。其中兩個實現座標,一個指明矩陣長度 MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2); // 進行分割 C_11 = Square_max_matrix_multiply_recursive(A_11, B_11) + Square_max_matrix_multiply_recursive(A_12, B_21); // Matrix::operator+; C_12 = Square_max_matrix_multiply_recursive(A_11, B_12) + Square_max_matrix_multiply_recursive(A_12, B_22); // MatrixRef::operator=; C_21 = Square_max_matrix_multiply_recursive(A_21, B_11) + Square_max_matrix_multiply_recursive(A_22, B_21); C_22 = Square_max_matrix_multiply_recursive(A_21, B_12) + Square_max_matrix_multiply_recursive(A_22, B_22); } return C; }
矩陣實現了一個Matrix類(具體實現在最下面),有一個建構函式:接受兩個size_t值l、r,生成l*r大小值全為0的矩陣。
Matrix::Matrix(size_t l, size_t r) : hight(l), width(r), data(make_shared<vector<int>>()) {
data->resize(l*r);
}
其中hight為矩陣行高,width為列寬,data為shared_ptr,矩陣用vector實現。
A.rows()返回A的width長度(即方矩陣的邊長),Matrix(n,n)建立一個矩陣。
size_t rows() const { return width; }
如果n==1,通過Matrix的get函式返回第一個元素,也就是唯一的一個元素。
int Matrix::get() const {
return (*data)[0];
}
為了不復制矩陣元素(如果可以複製矩陣元素的話,會簡單很多),另實現了一個MatrixRef,其含有:兩個size_t資料成員(實現座標點)、一個size_t資料成員(實現矩陣長度)、一個weak_ptr(指向vector<int>)。
MatrixRef含有兩個建構函式:一個接受Matrix加兩個size_t;一個接受MatrixRef加兩個size_t。都是為了指明引用範圍。
MatrixRef::MatrixRef(const Matrix &m, size_t line, size_t row) : wptr(m.data),
hight_startptr(line), width_startptr(row), length(m.rows() / 2) { }
MatrixRef::MatrixRef(const MatrixRef &mref, size_t line, size_t row) : wptr(mref.wptr),
hight_startptr(mref.hight_startptr + line),
width_startptr(mref.width_startptr + row), length(mref.rows() / 2) { }
wptr用data或wptr初始化,避免拷貝。length為rows()的返回值除以2,因為是分割為4個矩陣,行列各除以2。
要注意:接受MatrixRef的座標要加上之前的座標。
MatrixRef也有一個rows成員函式,為了遞迴呼叫。
size_t rows() const {
return length;
}
Square_max_matrix_multiply_recursive函式返回一個Matrix,Matrix實現了operator+,但是行列必須相等。
Matrix& Matrix::operator+=(const Matrix &rhs) {
if (hight == rhs.hight && width == rhs.width) {
for (size_t i = 0;i != size();++i)
(*data)[i] += (*rhs.data)[i];
}
else
throw std::logic_error("Not Matched");
return *this;
}
Matrix operator+(const Matrix &lhs, const Matrix &rhs) {
Matrix m(lhs);
return m += rhs;
}
MatrixRef實現了一個operator=。
MatrixRef& MatrixRef::operator=(const Matrix &rhs) {
for (size_t i = 0;i != length;++i) {
for (size_t j = 0;j != length;++j) {
(*wptr.lock())[(i + hight_startptr)*length * 2 + j + width_startptr] =
rhs.get(i + 1, j + 1); //注意:length*2 因為C也被分割了
}
}
return *this;
}
其中(i + hight_startptr)*length * 2 + j + width_startptr)為當前下標(vector對應矩陣的下標,非矩陣行列)。此函式將分割的C進行“拼合”。注意:length * 2 ,因為C也被分割了,不乘以2為C_11及C_12的長度,乘以2才是C的行列長寬,才能給C的給定位置賦值。
下面是Strassen矩陣演算法。
template<typename T, typename N>
Matrix Strassen_matrix_fit(const T &A, const N &B) { // 為2的冪的情況下
size_t n = A.rows();
Matrix C(n, n);
if (n == 1) {
return C = A.get()*B.get();
}
else {
MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2); // 使用一個類MatirxRef
MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2); // 含有三個size_t型別。其中兩個實現座標,一個指明矩陣長度
MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2); // 進行分割
Matrix S1 = B_12 - B_22, S2 = A_11 + A_12, S3 = A_21 + A_22, S4 = B_21 - B_11, S5 = A_11 + A_22, //MatrixRef的加、減
S6 = B_11 + B_22, S7 = A_12 - A_22, S8 = B_21 + B_22, S9 = A_11 - A_21, S10 = B_11 + B_12;
Matrix P1 = Strassen_matrix_fit(A_11, S1), P2 = Strassen_matrix_fit(S2, B_22),
P3 = Strassen_matrix_fit(S3, B_11), P4 = Strassen_matrix_fit(A_22, S4),
P5 = Strassen_matrix_fit(S5, S6), P6 = Strassen_matrix_fit(S7, S8), P7 = Strassen_matrix_fit(S9, S10);
C_11 = P5 + P4 - P2 + P6;
C_12 = P1 + P2;
C_21 = P3 + P4;
C_22 = P5 + P1 - P3 - P7;
}
return C;
}
此演算法較之前多了一個MatrixRef::operator-、以及MatrixRef::operator+。
Matrix& Matrix::operator-() {
for (auto &f : *data)
f = -f;
return *this;
}
Matrix operator-(const MatrixRef &lhs, const MatrixRef &rhs) {
Matrix ml(lhs), mr(rhs);
return ml = -mr + ml;
}
Matrix operator+(const Matrix &lhs, const Matrix &rhs) {
Matrix m(lhs);
return m += rhs;
}
operator-用Matrix的取負、以及Matrix的加法,同時最重要的還有Matrix(const MatrixRef &)。MatrixRef將此物件引用範圍內的子矩陣建立一個區域性Matrix物件。
operator+用Matrix的加法與Matrix(const MatrixRef &)。
Matrix::Matrix(const MatrixRef &rhs) : hight(rhs.length), width(rhs.length), data(make_shared<vector<int>>()) {
size_t max_size = static_cast<size_t>(sqrt(rhs.wptr.lock()->size())); // 未分解的原式中的矩陣長度
auto ivec = *rhs.wptr.lock();
for (size_t i = 0; i != hight; ++i) {
for (size_t j = 0; j != width; ++j) {
data->push_back(ivec[(i + rhs.hight_startptr)*max_size + j + rhs.width_startptr]);
}
}
}
其中max_size為wptr所指的vector<int>的size,進行根號得到。max_size就是MatrixRef物件未分解(即未分割的C)的矩陣邊長。static_cast把sqrt返回的double轉未size_t,因為是方矩陣,所以不會損失精度。
(i + rhs.hight_startptr)*max_size + j + rhs.width_startptr)為MatrixRef物件引用範圍內對應vector的下標。此必需乘以max_size。
下面為不是2的冪的情況。
template<typename T, typename N>
Matrix Strassen_matrix(const T &A, const N &B) {
size_t n = A.rows();
double size = log(n) / log(2);
size_t l_size = static_cast<size_t>(size);
if (l_size != size) {
size_t t_size = (l_size + 1)*(l_size + 1);
Matrix a(t_size, t_size), b(t_size, t_size);
a = A;
b = B;
Matrix C = Strassen_matrix_fit(a, b);
Matrix c(n, n);
c = C;
return c;
}
else
return Strassen_matrix_fit(A, B);
}
size與l_size比較可知是否為2的冪,如果是,執行else,不是,則執行if。
當不是2的冪時,我的思路是把它加0,拼成2的冪的形式。如下圖。
1 2 3 1 2 3 0 5 6 7
2 3 2 ---> 2 3 2 0 ---> 4 5 2
3 2 1 3 2 1 0 3 5 6
0 0 0 0
然後得出結果時再切去周圍的零,其值是不變的。假如為n*n的矩陣,複雜度(n + k) ^ lg7。n + k 為最接近n的2的冪,其中0<k<n。
(n + k) ^ lg7 < (2n) ^ lg7 = 7 * n ^ lg7。
複雜度還是n ^ lg7。
加零還是切去零,我是通過賦值來實現的。
Matrix& Matrix::operator=(const Matrix &rhs) {
if (hight == rhs.hight) { // rhs this
for (size_t i = 0; i != size(); ++i) { // 1 2 3 1 2 3
(*data)[i] = (*rhs.data)[i]; // 2 3 2 -> 2 3 2
} // 3 2 1 3 2 1
}
else if (hight > rhs.hight) { // 1 2 3 1 2 3 0
for (size_t i = 0;i != hight; ++i) { // 2 3 2 -> 2 3 2 0
for (size_t j = 0, n = 1;j != width; ++j) { // 3 2 1 3 2 1 0
if (j >= rhs.width || i >= rhs.hight) // 0 0 0 0
(*data)[i * width + j] = 0;
else
(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j];
}
}
}
else { // 1 2 3 4 1 2 3
for (size_t i = 0;i != hight; ++i) { // 2 3 4 3 -> 2 3 4
for (size_t j = 0;j != width; ++j) { // 3 4 3 2 3 4 3
(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j]; // 4 3 2 1
}
}
}
return *this;
}
有三種複製方式:當左矩陣邊長與右矩陣邊長相等,第一個,正常賦值。左>右時,左上角對其,剩餘的賦0。左<右時,左上角對其,多餘的切掉。
補充:
在我的電腦上,普通(n^3)的演算法與Strassen演算法在1500*1500左右的時候時間是差不多的,但是耗時達到30秒,之後Strassen演算法會出現明顯的優勢。在小於100*100的矩陣乘法時普通演算法耗時小於0.01秒,而Strassen可達到3秒,普通演算法有絕對的優勢。
在我的電腦上,把二維陣列擴充套件到300*300以上時會有棧溢位,這時可以上網搜尋找到相應的解決辦法。
END
設定MatirxRef類不知道好不好,畢竟矩陣拷貝也不影響複雜度。
肯定有很多值得改進的地方,也有不對的地方,可以評論提醒一下。
附:
Matrix標頭檔案。
#ifndef MATRIX_H
#define MATRIX_H
#include<iostream>
#include<memory>
#include<vector>
class MatrixRef;
class Matrix {
friend Matrix operator+(const Matrix &, const Matrix &);
friend Matrix operator-(const Matrix &, const Matrix &);
friend std::ostream& operator<<(std::ostream&, const Matrix &);
friend Matrix operator+(const MatrixRef &, const MatrixRef &);
friend Matrix operator-(const MatrixRef &, const MatrixRef &);
friend class MatrixRef;
public:
Matrix();
template<unsigned M, unsigned N>
Matrix(int(&A)[M][N]) : hight(M), width(N), data(make_shared<vector<int>>()) {
data->reserve(M*N);
for (size_t i = 0;i != M;++i)
for (size_t j = 0;j != N;++j)
data->push_back(A[i][j]);
}
Matrix(size_t l, size_t r); // 建立一個行l、列r的零矩陣
Matrix(const Matrix &rhs); // 深層次拷貝構造
explicit Matrix(const MatrixRef &); // 將MatrixRef轉換為Matrix
int& get(size_t l, size_t r); // 取得行L、列R的值
const int& get(size_t l, size_t r) const;
int get() const; // 得到第一個值
Matrix& operator=(const Matrix &rhs); // 深層次拷貝賦值
Matrix& operator+=(const Matrix &rhs);
Matrix& operator=(int i); // 把一個為i的值賦給行為1、列為1的矩陣
Matrix& operator-(); // 對矩陣取負
size_t rows() const {
return width;
}
size_t size() const {
return hight * width;
}
private:
void check_situation(size_t l, size_t r) const {
if (l > hight || r > width)
throw std::range_error("Invalid range");
}
size_t hight = 1;
size_t width = 1;
std::shared_ptr<std::vector<int>> data;
};
class MatrixRef {
friend Matrix operator-(const MatrixRef &, const MatrixRef &);
friend Matrix operator+(const MatrixRef &, const MatrixRef &);
friend class Matrix;
public:
MatrixRef(const Matrix &m, size_t line, size_t row);
MatrixRef(const MatrixRef &mref, size_t line, size_t row);
MatrixRef& operator=(const Matrix &rhs); // 對C_11、C_12、C_13、C_14進行賦值拼接的函式
int& get() const; // 得到第一個值
size_t rows() const {
return length;
}
private:
std::weak_ptr<std::vector<int>> wptr;
size_t hight_startptr;
size_t width_startptr;
size_t length;
};
Matrix operator+(const Matrix &, const Matrix &);
Matrix operator-(const Matrix &, const Matrix &);
Matrix operator+(const MatrixRef &, const MatrixRef &);
Matrix operator-(const MatrixRef &, const MatrixRef &);
std::ostream& operator<<(std::ostream&, const Matrix &);
template<typename T>
Matrix Square_max_matrix_multiply_recursive(const T &A, const T &B) {
size_t n = A.rows();
Matrix C(n, n);
if (n == 1)
return C = A.get()*B.get();
else {
MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2);// 使用一個類MatirxRef
MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2);// 含有三個size_t型別。其中兩個實現座標,一個指明矩陣長度
MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2);// 進行分割
C_11 = Square_max_matrix_multiply_recursive(A_11, B_11) + Square_max_matrix_multiply_recursive(A_12, B_21);// Matrix::operator+;
C_12 = Square_max_matrix_multiply_recursive(A_11, B_12) + Square_max_matrix_multiply_recursive(A_12, B_22);// MatrixRef::operator=;
C_21 = Square_max_matrix_multiply_recursive(A_21, B_11) + Square_max_matrix_multiply_recursive(A_22, B_21);
C_22 = Square_max_matrix_multiply_recursive(A_21, B_12) + Square_max_matrix_multiply_recursive(A_22, B_22);
}
return C;
}
template<typename T, typename N>
Matrix Strassen_matrix_fit(const T &A, const N &B) { // 為2的冪的情況下
size_t n = A.rows();
Matrix C(n, n);
if (n == 1) {
return C = A.get()*B.get();
}
else {
MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2);// 使用一個類MatirxRef
MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2);// 含有三個size_t型別。其中兩個實現座標,一個指明矩陣長度
MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2);// 進行分割
Matrix S1 = B_12 - B_22, S2 = A_11 + A_12, S3 = A_21 + A_22, S4 = B_21 - B_11, S5 = A_11 + A_22, //MatrixRef的加、減
S6 = B_11 + B_22, S7 = A_12 - A_22, S8 = B_21 + B_22, S9 = A_11 - A_21, S10 = B_11 + B_12;
Matrix P1 = Strassen_matrix_fit(A_11, S1), P2 = Strassen_matrix_fit(S2, B_22),
P3 = Strassen_matrix_fit(S3, B_11), P4 = Strassen_matrix_fit(A_22, S4),
P5 = Strassen_matrix_fit(S5, S6), P6 = Strassen_matrix_fit(S7, S8), P7 = Strassen_matrix_fit(S9, S10);
C_11 = P5 + P4 - P2 + P6;
C_12 = P1 + P2;
C_21 = P3 + P4;
C_22 = P5 + P1 - P3 - P7;
}
return C;
}
template<typename T, typename N>
Matrix Strassen_matrix(const T &A, const N &B) {
size_t n = A.rows();
double size = log(n) / log(2);
size_t l_size = static_cast<size_t>(size);
if (l_size != size) {
size_t t_size = (l_size + 1)*(l_size + 1);
Matrix a(t_size, t_size), b(t_size, t_size);
a = A;
b = B;
Matrix C = Strassen_matrix_fit(a, b);
Matrix c(n, n);
c = C;
return c;
}
else
return Strassen_matrix_fit(A, B);
}
#endif
標頭檔案實現。
#include"Matrix1.h"
#include<math.h>
using namespace std;
Matrix::Matrix() :data(make_shared<vector<int>>()) {
data->push_back(0);
}
Matrix::Matrix(size_t l, size_t r) : hight(l), width(r), data(make_shared<vector<int>>()) {
data->resize(l*r);
}
Matrix::Matrix(const Matrix &rhs) : hight(rhs.hight), width(rhs.width), data(make_shared<vector<int>>()) {
for (size_t i = 0;i != size(); ++i)
data->push_back((*rhs.data)[i]);
}
Matrix::Matrix(const MatrixRef &rhs) : hight(rhs.length), width(rhs.length), data(make_shared<vector<int>>()) {
size_t max_size = static_cast<size_t>(sqrt(rhs.wptr.lock()->size())); // 未分解的原式中的矩陣長度
auto ivec = *rhs.wptr.lock();
for (size_t i = 0; i != hight; ++i) {
for (size_t j = 0; j != width; ++j) {
data->push_back(ivec[(i + rhs.hight_startptr)*max_size + j + rhs.width_startptr]);
}
}
}
int& Matrix::get(size_t l, size_t r) {
check_situation(l, r);
return (*data)[--l * width + --r];
}
const int& Matrix::get(size_t l, size_t r) const {
check_situation(l, r);
return (*data)[--l * width + --r];
}
int Matrix::get() const {
return (*data)[0];
}
Matrix& Matrix::operator=(const Matrix &rhs) {
if (hight == rhs.hight) { // rhs this
for (size_t i = 0; i != size(); ++i) { // 1 2 3 1 2 3
(*data)[i] = (*rhs.data)[i]; // 2 3 2 -> 2 3 2
} // 3 2 1 3 2 1
}
else if (hight > rhs.hight) { // 1 2 3 1 2 3 0
for (size_t i = 0;i != hight; ++i) { // 2 3 2 -> 2 3 2 0
for (size_t j = 0, n = 1;j != width; ++j) { // 3 2 1 3 2 1 0
if (j >= rhs.width || i >= rhs.hight) // 0 0 0 0
(*data)[i * width + j] = 0;
else
(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j];
}
}
}
else { // 1 2 3 4 1 2 3
for (size_t i = 0;i != hight; ++i) { // 2 3 4 3 -> 2 3 4
for (size_t j = 0;j != width; ++j) { // 3 4 3 2 3 4 3
(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j]; // 4 3 2 1
}
}
}
return *this;
}
Matrix& Matrix::operator+=(const Matrix &rhs) {
if (hight == rhs.hight && width == rhs.width) {
for (size_t i = 0;i != size();++i)
(*data)[i] += (*rhs.data)[i];
}
else
throw std::logic_error("Not Matched");
return *this;
}
Matrix& Matrix::operator=(int i) {
if (hight == width && hight == 1)
(*data)[0] = i;
return *this;
}
Matrix& Matrix::operator-() {
for (auto &f : *data)
f = -f;
return *this;
}
Matrix operator+(const Matrix &lhs, const Matrix &rhs) {
Matrix m(lhs);
return m += rhs;
}
Matrix operator-(const Matrix &lhs,const Matrix &rhs) {
Matrix m(rhs);
return m = -m + lhs;
}
MatrixRef::MatrixRef(const Matrix &m, size_t line, size_t row) : wptr(m.data), hight_startptr(line), width_startptr(row), length(m.rows() / 2) { }
MatrixRef::MatrixRef(const MatrixRef &mref, size_t line, size_t row) : wptr(mref.wptr), hight_startptr(mref.hight_startptr + line),
width_startptr(mref.width_startptr + row), length(mref.rows() / 2) { }
MatrixRef& MatrixRef::operator=(const Matrix &rhs) {
for (size_t i = 0;i != length;++i) {
for (size_t j = 0;j != length;++j) {
(*wptr.lock())[(i + hight_startptr)*length * 2 + j + width_startptr] = rhs.get(i + 1, j + 1); //注意:length*2 因為C也被分割了
}
}
return *this;
}
int& MatrixRef::get() const {
return (*wptr.lock())[static_cast<size_t>(hight_startptr*sqrt(wptr.lock()->size())) + width_startptr];
}
Matrix operator+(const MatrixRef &lhs, const MatrixRef &rhs) {
Matrix ml(lhs), mr(rhs);
return ml += mr;
}
Matrix operator-(const MatrixRef &lhs, const MatrixRef &rhs) {
Matrix ml(lhs), mr(rhs);
return ml = -mr + ml;
}
ostream& operator<<(ostream &os, const Matrix &m) {
int i = 0;
for (auto f : *m.data) {
cout << f;
if (++i == m.width) {
std::cout << '\n';
i = 0;
}
else
cout << ' ';
}
return os;
}
END