1. 程式人生 > >C++ 方矩陣乘法 + Strassen矩陣

C++ 方矩陣乘法 + Strassen矩陣

  這幾天看演算法導論,看到矩陣一章,就實現了一下。

下面是普通的矩陣乘法,複雜度為:n^3。

template<unsigned M,unsigned N, unsigned Q>
void Square_matrix_multiply(int(&A)[M][N], int(&B)[N][Q], int(&C)[M][Q]) {                 
	for (size_t i = 0;i != M;++i) {
		for (size_t j = 0;j != Q;++j) {
			C[i][j] = 0;
			for (size_t n = 0;n != N;++n) {
				C[i][j] += A[i][n] * B[n][j];
			}
		}
	}
}

函式接受三個二維陣列,A * B得到的矩陣賦值給C。

下面是分治策略的演算法。

template<typename T>
Matrix Square_max_matrix_multiply_recursive(const T &A, const T &B) {
	size_t n = A.rows();
	Matrix C(n, n);
	if (n == 1)
		return C = A.get()*B.get();
	else {
		MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2);    // 使用一個類MatirxRef
		MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2);    // 含有三個size_t型別。其中兩個實現座標,一個指明矩陣長度
		MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2);    // 進行分割
		C_11 = Square_max_matrix_multiply_recursive(A_11, B_11) + Square_max_matrix_multiply_recursive(A_12, B_21);  // Matrix::operator+;
		C_12 = Square_max_matrix_multiply_recursive(A_11, B_12) + Square_max_matrix_multiply_recursive(A_12, B_22);  // MatrixRef::operator=;
		C_21 = Square_max_matrix_multiply_recursive(A_21, B_11) + Square_max_matrix_multiply_recursive(A_22, B_21);
		C_22 = Square_max_matrix_multiply_recursive(A_21, B_12) + Square_max_matrix_multiply_recursive(A_22, B_22);
	}
	return C;
}

矩陣實現了一個Matrix類(具體實現在最下面),有一個建構函式:接受兩個size_t值l、r,生成l*r大小值全為0的矩陣。

Matrix::Matrix(size_t l, size_t r) : hight(l), width(r), data(make_shared<vector<int>>()) {
	data->resize(l*r);
}

其中hight為矩陣行高,width為列寬,data為shared_ptr,矩陣用vector實現。

A.rows()返回A的width長度(即方矩陣的邊長),Matrix(n,n)建立一個矩陣。

size_t rows() const {
		return width;
	}

如果n==1,通過Matrix的get函式返回第一個元素,也就是唯一的一個元素。

int Matrix::get() const {
	return (*data)[0];
}

為了不復制矩陣元素(如果可以複製矩陣元素的話,會簡單很多),另實現了一個MatrixRef,其含有:兩個size_t資料成員(實現座標點)、一個size_t資料成員(實現矩陣長度)、一個weak_ptr(指向vector<int>)。

MatrixRef含有兩個建構函式:一個接受Matrix加兩個size_t;一個接受MatrixRef加兩個size_t。都是為了指明引用範圍。

MatrixRef::MatrixRef(const Matrix &m, size_t line, size_t row) : wptr(m.data), 
      hight_startptr(line), width_startptr(row), length(m.rows() / 2) { }
MatrixRef::MatrixRef(const MatrixRef &mref, size_t line, size_t row) : wptr(mref.wptr),                 
      hight_startptr(mref.hight_startptr + line),
      width_startptr(mref.width_startptr + row), length(mref.rows() / 2) { }

wptr用data或wptr初始化,避免拷貝。length為rows()的返回值除以2,因為是分割為4個矩陣,行列各除以2。

要注意:接受MatrixRef的座標要加上之前的座標。

MatrixRef也有一個rows成員函式,為了遞迴呼叫。

size_t rows() const {
		return length;
	}

Square_max_matrix_multiply_recursive函式返回一個Matrix,Matrix實現了operator+,但是行列必須相等。

Matrix& Matrix::operator+=(const Matrix &rhs) {
	if (hight == rhs.hight && width == rhs.width) {
		for (size_t i = 0;i != size();++i)
			(*data)[i] += (*rhs.data)[i];
	}
	else
		throw std::logic_error("Not Matched");
	return *this;
}
Matrix operator+(const Matrix &lhs, const Matrix &rhs) {
	Matrix m(lhs);
	return m += rhs;
}

MatrixRef實現了一個operator=。

MatrixRef& MatrixRef::operator=(const Matrix &rhs) {   
	for (size_t i = 0;i != length;++i) {
		for (size_t j = 0;j != length;++j) {
			(*wptr.lock())[(i + hight_startptr)*length * 2 + j + width_startptr] = 
               rhs.get(i + 1, j + 1);  //注意:length*2  因為C也被分割了
		}
	}
	return *this;
}

其中(i + hight_startptr)*length * 2 + j + width_startptr)為當前下標(vector對應矩陣的下標,非矩陣行列)。此函式將分割的C進行“拼合”。注意:length * 2  ,因為C也被分割了,不乘以2為C_11及C_12的長度,乘以2才是C的行列長寬,才能給C的給定位置賦值。

下面是Strassen矩陣演算法。

template<typename T, typename N>
Matrix Strassen_matrix_fit(const T &A, const N &B) { // 為2的冪的情況下
	size_t n = A.rows();
	Matrix C(n, n);
	if (n == 1) {
		return C = A.get()*B.get();
	}
	else {
		MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2);        // 使用一個類MatirxRef
		MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2);        // 含有三個size_t型別。其中兩個實現座標,一個指明矩陣長度
		MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2);        // 進行分割
		Matrix S1 = B_12 - B_22, S2 = A_11 + A_12, S3 = A_21 + A_22, S4 = B_21 - B_11, S5 = A_11 + A_22,     //MatrixRef的加、減
			S6 = B_11 + B_22, S7 = A_12 - A_22, S8 = B_21 + B_22, S9 = A_11 - A_21, S10 = B_11 + B_12;
		Matrix P1 = Strassen_matrix_fit(A_11, S1), P2 = Strassen_matrix_fit(S2, B_22),
			P3 = Strassen_matrix_fit(S3, B_11), P4 = Strassen_matrix_fit(A_22, S4),
			P5 = Strassen_matrix_fit(S5, S6), P6 = Strassen_matrix_fit(S7, S8), P7 = Strassen_matrix_fit(S9, S10);
		C_11 = P5 + P4 - P2 + P6;
		C_12 = P1 + P2;
		C_21 = P3 + P4;
		C_22 = P5 + P1 - P3 - P7;
	}
	return C;
}

此演算法較之前多了一個MatrixRef::operator-、以及MatrixRef::operator+。

Matrix& Matrix::operator-() {
	for (auto &f : *data)
		f = -f;
	return *this;
}
Matrix operator-(const MatrixRef &lhs, const MatrixRef &rhs) {
	Matrix ml(lhs), mr(rhs);
	return ml = -mr + ml;
}
Matrix operator+(const Matrix &lhs, const Matrix &rhs) {
	Matrix m(lhs);
	return m += rhs;
}

operator-用Matrix的取負、以及Matrix的加法,同時最重要的還有Matrix(const MatrixRef &)。MatrixRef將此物件引用範圍內的子矩陣建立一個區域性Matrix物件。

operator+用Matrix的加法與Matrix(const MatrixRef &)。

Matrix::Matrix(const MatrixRef &rhs) : hight(rhs.length), width(rhs.length), data(make_shared<vector<int>>()) {
	size_t max_size = static_cast<size_t>(sqrt(rhs.wptr.lock()->size()));            // 未分解的原式中的矩陣長度
	auto ivec = *rhs.wptr.lock();
	for (size_t i = 0; i != hight; ++i) {
		for (size_t j = 0; j != width; ++j) {
			data->push_back(ivec[(i + rhs.hight_startptr)*max_size + j + rhs.width_startptr]);
		}
	}
}

其中max_size為wptr所指的vector<int>的size,進行根號得到。max_size就是MatrixRef物件未分解(即未分割的C)的矩陣邊長。static_cast把sqrt返回的double轉未size_t,因為是方矩陣,所以不會損失精度。

(i + rhs.hight_startptr)*max_size + j + rhs.width_startptr)為MatrixRef物件引用範圍內對應vector的下標。此必需乘以max_size。

下面為不是2的冪的情況。

template<typename T, typename N>
Matrix Strassen_matrix(const T &A, const N &B) {
	size_t n = A.rows();
	double size = log(n) / log(2);
	size_t l_size = static_cast<size_t>(size);
	if (l_size != size) {
		size_t t_size = (l_size + 1)*(l_size + 1);
		Matrix a(t_size, t_size), b(t_size, t_size);
		a = A;
		b = B;
		Matrix C = Strassen_matrix_fit(a, b);
		Matrix c(n, n);
		c = C;
		return c;
	}
	else
		return Strassen_matrix_fit(A, B);
}

size與l_size比較可知是否為2的冪,如果是,執行else,不是,則執行if。

當不是2的冪時,我的思路是把它加0,拼成2的冪的形式。如下圖。

1 2 3                1 2 3 0            5 6 7
2 3 2     --->       2 3 2 0   --->     4 5 2 
3 2 1                3 2 1 0            3 5 6
                     0 0 0 0

然後得出結果時再切去周圍的零,其值是不變的。假如為n*n的矩陣,複雜度(n + k) ^ lg7。n + k 為最接近n的2的冪,其中0<k<n。

(n + k) ^ lg7 < (2n) ^ lg7 = 7 * n ^ lg7。

複雜度還是n ^ lg7。

加零還是切去零,我是通過賦值來實現的。

Matrix& Matrix::operator=(const Matrix &rhs) {
	if (hight == rhs.hight) {								//  rhs          this
		for (size_t i = 0; i != size(); ++i) {					        //	1 2 3		 1 2 3
			(*data)[i] = (*rhs.data)[i];						//	2 3 2   ->	 2 3 2
		}								        	//	3 2 1		 3 2 1
	}																		 
	else if (hight > rhs.hight) {							 	//	1 2 3		 1 2 3 0
		for (size_t i = 0;i != hight; ++i) {				        	//	2 3 2   ->	 2 3 2 0
			for (size_t j = 0, n = 1;j != width; ++j) {			        //	3 2 1		 3 2 1 0
				if (j >= rhs.width || i >= rhs.hight)			//				 0 0 0 0
					(*data)[i * width + j] = 0;
				else																
					(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j];	
			}
		}
	}
	else {											     //	1 2 3 4	      1 2 3 
			for (size_t i = 0;i != hight; ++i) {					     // 2 3 4 3   ->  2 3 4 
				for (size_t j = 0;j != width; ++j) {				     // 3 4 3 2	      3 4 3 
					(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j];     //	4 3 2 1		 
				}
			}
	}
	return *this;
}

有三種複製方式:當左矩陣邊長與右矩陣邊長相等,第一個,正常賦值。左>右時,左上角對其,剩餘的賦0。左<右時,左上角對其,多餘的切掉。

補充:

在我的電腦上,普通(n^3)的演算法與Strassen演算法在1500*1500左右的時候時間是差不多的,但是耗時達到30秒,之後Strassen演算法會出現明顯的優勢。在小於100*100的矩陣乘法時普通演算法耗時小於0.01秒,而Strassen可達到3秒,普通演算法有絕對的優勢。

在我的電腦上,把二維陣列擴充套件到300*300以上時會有棧溢位,這時可以上網搜尋找到相應的解決辦法。

END

設定MatirxRef類不知道好不好,畢竟矩陣拷貝也不影響複雜度。

肯定有很多值得改進的地方,也有不對的地方,可以評論提醒一下。

附:

Matrix標頭檔案。

#ifndef MATRIX_H
#define MATRIX_H
#include<iostream>
#include<memory>
#include<vector>
class MatrixRef;
class Matrix {
	friend Matrix operator+(const Matrix &, const Matrix &);
	friend Matrix operator-(const Matrix &, const Matrix &);
	friend std::ostream& operator<<(std::ostream&, const Matrix &);
	friend Matrix operator+(const MatrixRef &, const MatrixRef &);
	friend Matrix operator-(const MatrixRef &, const MatrixRef &);
	friend class MatrixRef;
public:
	Matrix();
	template<unsigned M, unsigned N>
	Matrix(int(&A)[M][N]) : hight(M), width(N), data(make_shared<vector<int>>()) {
		data->reserve(M*N);
		for (size_t i = 0;i != M;++i)
			for (size_t j = 0;j != N;++j)
				data->push_back(A[i][j]);
	}
	Matrix(size_t l, size_t r); // 建立一個行l、列r的零矩陣
	Matrix(const Matrix &rhs); // 深層次拷貝構造
	explicit Matrix(const MatrixRef &);  // 將MatrixRef轉換為Matrix 
	int& get(size_t l, size_t r); // 取得行L、列R的值
	const int& get(size_t l, size_t r) const;
	int get() const; // 得到第一個值
	Matrix& operator=(const Matrix &rhs); // 深層次拷貝賦值
	Matrix& operator+=(const Matrix &rhs);
	Matrix& operator=(int i); // 把一個為i的值賦給行為1、列為1的矩陣
	Matrix& operator-(); // 對矩陣取負
	size_t rows() const {
		return width;
	}
	size_t size() const {
		return hight * width;
	}
private:
	void check_situation(size_t l, size_t r) const {
		if (l > hight || r > width)
			throw std::range_error("Invalid range");
	}
	size_t hight = 1;
	size_t width = 1;
	std::shared_ptr<std::vector<int>> data;
};

class MatrixRef {
	friend Matrix operator-(const MatrixRef &, const MatrixRef &);
	friend Matrix operator+(const MatrixRef &, const MatrixRef &);
	friend class Matrix;
public:
	MatrixRef(const Matrix &m, size_t line, size_t row);
	MatrixRef(const MatrixRef &mref, size_t line, size_t row);
	MatrixRef& operator=(const Matrix &rhs); // 對C_11、C_12、C_13、C_14進行賦值拼接的函式
	int& get() const; // 得到第一個值
	size_t rows() const {
		return length;
	}
private:
	std::weak_ptr<std::vector<int>> wptr;
	size_t hight_startptr;
	size_t width_startptr;
	size_t length;
};
Matrix operator+(const Matrix &, const Matrix &);
Matrix operator-(const Matrix &, const Matrix &);
Matrix operator+(const MatrixRef &, const MatrixRef &);
Matrix operator-(const MatrixRef &, const MatrixRef &);
std::ostream& operator<<(std::ostream&, const Matrix &);

template<typename T>
Matrix Square_max_matrix_multiply_recursive(const T &A, const T &B) {
	size_t n = A.rows();
	Matrix C(n, n);
	if (n == 1)
		return C = A.get()*B.get();
	else {
		MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2);// 使用一個類MatirxRef
		MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2);// 含有三個size_t型別。其中兩個實現座標,一個指明矩陣長度
		MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2);// 進行分割
		C_11 = Square_max_matrix_multiply_recursive(A_11, B_11) + Square_max_matrix_multiply_recursive(A_12, B_21);// Matrix::operator+;
		C_12 = Square_max_matrix_multiply_recursive(A_11, B_12) + Square_max_matrix_multiply_recursive(A_12, B_22);// MatrixRef::operator=;
		C_21 = Square_max_matrix_multiply_recursive(A_21, B_11) + Square_max_matrix_multiply_recursive(A_22, B_21);
		C_22 = Square_max_matrix_multiply_recursive(A_21, B_12) + Square_max_matrix_multiply_recursive(A_22, B_22);
	}
	return C;
}
template<typename T, typename N>
Matrix Strassen_matrix_fit(const T &A, const N &B) { // 為2的冪的情況下
	size_t n = A.rows();
	Matrix C(n, n);
	if (n == 1) {
		return C = A.get()*B.get();
	}
	else {
		MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2);// 使用一個類MatirxRef
		MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2);// 含有三個size_t型別。其中兩個實現座標,一個指明矩陣長度
		MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2);// 進行分割
		Matrix S1 = B_12 - B_22, S2 = A_11 + A_12, S3 = A_21 + A_22, S4 = B_21 - B_11, S5 = A_11 + A_22, //MatrixRef的加、減
			S6 = B_11 + B_22, S7 = A_12 - A_22, S8 = B_21 + B_22, S9 = A_11 - A_21, S10 = B_11 + B_12;
		Matrix P1 = Strassen_matrix_fit(A_11, S1), P2 = Strassen_matrix_fit(S2, B_22),
			P3 = Strassen_matrix_fit(S3, B_11), P4 = Strassen_matrix_fit(A_22, S4),
			P5 = Strassen_matrix_fit(S5, S6), P6 = Strassen_matrix_fit(S7, S8), P7 = Strassen_matrix_fit(S9, S10);
		C_11 = P5 + P4 - P2 + P6;
		C_12 = P1 + P2;
		C_21 = P3 + P4;
		C_22 = P5 + P1 - P3 - P7;
	}
	return C;
}
template<typename T, typename N>
Matrix Strassen_matrix(const T &A, const N &B) {
	size_t n = A.rows();
	double size = log(n) / log(2);
	size_t l_size = static_cast<size_t>(size);
	if (l_size != size) {
		size_t t_size = (l_size + 1)*(l_size + 1);
		Matrix a(t_size, t_size), b(t_size, t_size);
		a = A;
		b = B;
		Matrix C = Strassen_matrix_fit(a, b);
		Matrix c(n, n);
		c = C;
		return c;
	}
	else
		return Strassen_matrix_fit(A, B);
}


#endif

標頭檔案實現。

#include"Matrix1.h"
#include<math.h>
using namespace std;

Matrix::Matrix() :data(make_shared<vector<int>>()) {
	data->push_back(0);
}
Matrix::Matrix(size_t l, size_t r) : hight(l), width(r), data(make_shared<vector<int>>()) {
	data->resize(l*r);
}
Matrix::Matrix(const Matrix &rhs) : hight(rhs.hight), width(rhs.width), data(make_shared<vector<int>>()) {
		for (size_t i = 0;i != size(); ++i)                        
			data->push_back((*rhs.data)[i]);														
}
Matrix::Matrix(const MatrixRef &rhs) : hight(rhs.length), width(rhs.length), data(make_shared<vector<int>>()) {
	size_t max_size = static_cast<size_t>(sqrt(rhs.wptr.lock()->size())); // 未分解的原式中的矩陣長度
	auto ivec = *rhs.wptr.lock();
	for (size_t i = 0; i != hight; ++i) {
		for (size_t j = 0; j != width; ++j) {
			data->push_back(ivec[(i + rhs.hight_startptr)*max_size + j + rhs.width_startptr]);
		}
	}
}
int& Matrix::get(size_t l, size_t r) {
	check_situation(l, r);
	return (*data)[--l * width + --r];
}
const int& Matrix::get(size_t l, size_t r) const {
	check_situation(l, r);
	return (*data)[--l * width + --r];
}
int Matrix::get() const {
	return (*data)[0];
}
Matrix& Matrix::operator=(const Matrix &rhs) {
	if (hight == rhs.hight) {												//  rhs          this
		for (size_t i = 0; i != size(); ++i) {								//	1 2 3		 1 2 3
			(*data)[i] = (*rhs.data)[i];									//	2 3 2   ->	 2 3 2
		}																	//	3 2 1		 3 2 1
	}																		 
	else if (hight > rhs.hight) {							 				//	1 2 3		 1 2 3 0
		for (size_t i = 0;i != hight; ++i) {								//	2 3 2   ->	 2 3 2 0
			for (size_t j = 0, n = 1;j != width; ++j) {						//	3 2 1		 3 2 1 0
				if (j >= rhs.width || i >= rhs.hight)						//				 0 0 0 0
					(*data)[i * width + j] = 0;
				else																
					(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j];	
			}
		}
	}
	else {																	 //	1 2 3 4		  1 2 3 
			for (size_t i = 0;i != hight; ++i) {							 //	2 3 4 3   ->  2 3 4 
				for (size_t j = 0;j != width; ++j) {				    	 //	3 4 3 2	      3 4 3 
					(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j]; //	4 3	2 1		 
				}
			}
	}
	return *this;
}
Matrix& Matrix::operator+=(const Matrix &rhs) {
	if (hight == rhs.hight && width == rhs.width) {
		for (size_t i = 0;i != size();++i)
			(*data)[i] += (*rhs.data)[i];
	}
	else
		throw std::logic_error("Not Matched");
	return *this;
}
Matrix& Matrix::operator=(int i) {
	if (hight == width && hight == 1)
		(*data)[0] = i;
	return *this;
}
Matrix& Matrix::operator-() {
	for (auto &f : *data)
		f = -f;
	return *this;
}

Matrix operator+(const Matrix &lhs, const Matrix &rhs) {
	Matrix m(lhs);
	return m += rhs;
}
Matrix operator-(const Matrix &lhs,const Matrix &rhs) {
	Matrix m(rhs);
	return m = -m + lhs;
}

MatrixRef::MatrixRef(const Matrix &m, size_t line, size_t row) : wptr(m.data), hight_startptr(line), width_startptr(row), length(m.rows() / 2) { }
MatrixRef::MatrixRef(const MatrixRef &mref, size_t line, size_t row) : wptr(mref.wptr), hight_startptr(mref.hight_startptr + line),
width_startptr(mref.width_startptr + row), length(mref.rows() / 2) { }
MatrixRef& MatrixRef::operator=(const Matrix &rhs) {   
	for (size_t i = 0;i != length;++i) {
		for (size_t j = 0;j != length;++j) {
			(*wptr.lock())[(i + hight_startptr)*length * 2 + j + width_startptr] = rhs.get(i + 1, j + 1);  //注意:length*2  因為C也被分割了
		}
	}
	return *this;
}
int& MatrixRef::get() const {
	return (*wptr.lock())[static_cast<size_t>(hight_startptr*sqrt(wptr.lock()->size())) + width_startptr];
}
Matrix operator+(const MatrixRef &lhs, const MatrixRef &rhs) {
	Matrix ml(lhs), mr(rhs);
	return ml += mr;
}
Matrix operator-(const MatrixRef &lhs, const MatrixRef &rhs) {
	Matrix ml(lhs), mr(rhs);
	return ml = -mr + ml;
}
ostream& operator<<(ostream &os, const Matrix &m) {
	int i = 0;
	for (auto f : *m.data) {
		cout << f;
		if (++i == m.width) {
			std::cout << '\n';
			i = 0;
		}
		else
			cout << ' ';
	}
	return os;
}

END