1. 程式人生 > 分治法實現矩陣乘法

分治法實現矩陣乘法

name cout namespace size cas put 分治 ade add

整體思路分三步:「分」——把兩個矩陣各自四等分;「加減與乘」——對子矩陣做加減,並遞歸相乘得到 Strassen 的七個乘積;「拼」——把四個結果子塊拼回完整的結果矩陣。

#include <iostream>
#include <cstddef>
#include <cstdlib>
#include <ctime>

using namespace std;

int *InitMatrix(int row,int col);//Allocate a row*col int matrix with malloc; returns NULL on failure
void FillMatrix(int *MatrixA, int size);//Fill a size*size matrix with rand() % 5 values
void PrintMatrix(int *MatrixA,int size);//Print a size*size matrix to cout
void AddMatrix(int *MatrixIn1,int *MatrixIn2,int *MatrixOut,int size);//Element-wise add: MatrixOut = MatrixIn1 + MatrixIn2
void SubMatrix(int *MatrixIn1,int *MatrixIn2,int *MatrixOut,int size);//Element-wise subtract: MatrixOut = MatrixIn1 - MatrixIn2
void SplitMatrix(int *MatrixIn,int *MatrixOut,int size,int part);//Copy quadrant `part` (1 TL, 2 TR, 3 BL, 4 BR) of a size*size matrix
void StitchMatrix(int *PartA,int *PartB,int *PartC,int *PartD,int *Result,int size);//Inverse of the split: join four size*size quadrants into Result
void Strassen(int *MA,int *MB,int *MC,int size);  //Strassen multiplication: MC = MA * MB
void GradeSchool(int *MatrixIn1,int *MatrixIn2,int *MatrixOut,int size);//Naive triple-loop multiplication, used for comparison
 

/*
 * Demo driver: reads the matrix order from stdin, fills two random
 * MaSize x MaSize matrices, multiplies them with Strassen() and with the
 * naive GradeSchool() routine, and prints both results together with the
 * clock() readings taken around each multiplication.
 *
 * Returns 0 on success, 1 on invalid input or allocation failure.
 */
int main()
{
	clock_t StartTimeS,EndTimeS,StartTimeG,EndTimeG;
	int MaSize = 0;
	cout << "Please input the row of matrix(it must be index of two,like,2,4,8):";
	// Reject unreadable input and anything that is not a power of two >= 2:
	// the recursive split halves the size at every level down to 2x2, so any
	// other size would never reach the 2x2 base case.
	if (!(cin >> MaSize) || MaSize < 2 || (MaSize & (MaSize - 1)) != 0)
	{
		cout << "Invalid size: please input a power of two that is at least 2." << endl;
		return 1;
	}

	int *MA = InitMatrix(MaSize,MaSize);
	int *MB = InitMatrix(MaSize,MaSize);
	int *MC = InitMatrix(MaSize,MaSize);
	if (MA == NULL || MB == NULL || MC == NULL)
	{
		// free(NULL) is a no-op, so releasing all three is safe.
		free(MA);
		free(MB);
		free(MC);
		return 1;
	}

	FillMatrix(MA,MaSize);
	FillMatrix(MB,MaSize);
	cout << "Matrix A is:" << endl << endl;
	PrintMatrix(MA,MaSize);
	cout << "Matrix B is:" << endl << endl;
	PrintMatrix(MB,MaSize);
	cout << "Matrix A and B are generated!" << endl << "Start to caculate!" << endl;

	StartTimeS = clock();
	Strassen(MA,MB,MC,MaSize);
	EndTimeS = clock();
	cout << "After Strassen multiplication the result is:" << endl << endl;
	PrintMatrix(MC,MaSize);

	// MC is reused: the Strassen result printed above is overwritten here.
	StartTimeG = clock();
	GradeSchool(MA,MB,MC,MaSize);
	EndTimeG = clock();
	cout << "After Grade-School multiplication the result is:" << endl << endl;
	PrintMatrix(MC,MaSize);

	cout << "Strassen method starts at:" << StartTimeS << endl << "ends at:" << EndTimeS << endl;
	cout << "Grade-School method starts at:" << StartTimeG << endl << "ends at:" << EndTimeG << endl;
	// clock() returns processor-time ticks; convert the differences to seconds.
	cout << "Strassen method took " << static_cast<double>(EndTimeS - StartTimeS) / CLOCKS_PER_SEC << " s" << endl;
	cout << "Grade-School method took " << static_cast<double>(EndTimeG - StartTimeG) / CLOCKS_PER_SEC << " s" << endl;

	free(MA);
	free(MB);
	free(MC);

	return 0;
}

/*
 * Allocate an uninitialized row x col int matrix with malloc (the size is
 * only known at run time). The caller releases it with free().
 *
 * Returns a pointer to the first element, or NULL (after printing
 * "Error in InitMatrix!") when a dimension is not positive, when the byte
 * count would overflow size_t, or when malloc fails.
 */
int *InitMatrix(int row,int col)
{
	int *p = NULL;
	// Negative ints converted to size_t wrap around; e.g. (-1)*(-1)*sizeof(int)
	// would silently wrap back to a tiny allocation, so reject them up front.
	if (row > 0 && col > 0)
	{
		const size_t rows = static_cast<size_t>(row);
		const size_t cols = static_cast<size_t>(col);
		const size_t maxBytes = static_cast<size_t>(-1);
		// Make sure rows * cols * sizeof(int) fits in size_t before computing it.
		if (rows <= maxBytes / sizeof(int) / cols)
			p = static_cast<int *>(malloc(rows * cols * sizeof(int)));
	}
	if (p == NULL)
		cout << "Error in InitMatrix!" << endl;
	return p;
}


/*
 * Populate a size x size row-major matrix with values rand() % 5 (0..4),
 * consuming rand() in row-major order. Does nothing when size <= 0.
 */
void FillMatrix( int *MatrixA, int size)
{
	int *cell = MatrixA;
	for (int r = 0; r < size; ++r)
	{
		for (int c = 0; c < size; ++c)
		{
			*cell = rand() % 5;
			++cell;
		}
	}
}

/*
 * Write a size x size row-major matrix to cout: every element is followed
 * by a tab, each row is terminated with endl, and one extra endl closes the
 * whole matrix.
 */
void PrintMatrix(int *MatrixA,int size)
{
	const int *rowStart = MatrixA;
	for (int r = 0; r < size; ++r, rowStart += size)
	{
		for (int c = 0; c < size; ++c)
			cout << rowStart[c] << "\t";
		cout << endl; // end of row r
	}
	cout << endl;
}

/*
 * Element-wise sum of two size x size matrices:
 * MatrixOut[k] = MatrixIn1[k] + MatrixIn2[k] for every k < size*size.
 * Each output element depends only on the inputs at the same index.
 */
void AddMatrix(int *MatrixIn1,int *MatrixIn2,int *MatrixOut,int size)
{
	const int count = size * size;
	int idx = 0;
	while (idx < count)
	{
		MatrixOut[idx] = MatrixIn1[idx] + MatrixIn2[idx];
		++idx;
	}
}

/*
 * Element-wise difference of two size x size matrices:
 * MatrixOut[k] = MatrixIn1[k] - MatrixIn2[k] for every k < size*size,
 * visited in increasing k.
 */
void SubMatrix(int *MatrixIn1,int *MatrixIn2,int *MatrixOut,int size)
{
	const int *lhs = MatrixIn1;
	const int *rhs = MatrixIn2;
	int *dst = MatrixOut;
	for (int remaining = size * size; remaining > 0; --remaining)
		*dst++ = *lhs++ - *rhs++;
}


/*
 * Copy one quadrant of a size x size row-major matrix into MatrixOut, which
 * must hold (size/2) x (size/2) ints stored row-major.
 *
 * part selects the quadrant:
 *   1 = top-left, 2 = top-right, 3 = bottom-left, 4 = bottom-right.
 * Any other value prints "Error in SplitMatrix!" and leaves MatrixOut
 * untouched. size is expected to be even so the quadrants tile the matrix.
 */
void SplitMatrix(int *MatrixIn,int *MatrixOut,int size,int part)
{
	const int n = size/2; // side length of one quadrant
	int rowOffset = 0;
	int colOffset = 0;
	switch(part)
	{
		case 1: rowOffset = 0; colOffset = 0; break; // top-left
		case 2: rowOffset = 0; colOffset = n; break; // top-right
		case 3: rowOffset = n; colOffset = 0; break; // bottom-left
		case 4: rowOffset = n; colOffset = n; break; // bottom-right
		default:
			cout << "Error in SplitMatrix!" << endl;
			return;
	}
	for (int i = 0; i < n; i ++)
	{
		// The source row stride is the FULL width `size`, not the quadrant
		// width n; indexing with n would read the wrong elements for size >= 4.
		const int *src = MatrixIn + (i + rowOffset)*size + colOffset;
		for (int j = 0; j < n; j ++)
		{
			MatrixOut[i*n + j] = src[j];
		}
	}
}

/*
 * Reassemble four size x size row-major quadrants into a (2*size) x (2*size)
 * row-major Result laid out as
 *     PartA | PartB
 *     ------+------
 *     PartC | PartD
 * This is the inverse of splitting a matrix into quadrants.
 */
void StitchMatrix(int *PartA,int *PartB,int *PartC,int *PartD,int *Result,int size)
{
	const int width = 2 * size; // row stride of Result
	for (int r = 0; r < size; ++r)
	{
		int *upper = Result + r * width;          // row r of the top half
		int *lower = Result + (r + size) * width; // row r of the bottom half
		const int base = r * size;                // row r inside each quadrant
		for (int c = 0; c < size; ++c)
		{
			upper[c]        = PartA[base + c];
			upper[c + size] = PartB[base + c];
			lower[c]        = PartC[base + c];
			lower[c + size] = PartD[base + c];
		}
	}
}
/*
/Strassen算法:
/分塊,分到2*2
*/
void Strassen(int *MA,int *MB,int *MC,int size)
{
	int n = size/2;
	if (2 == size)//這樣就不用分了,以及分到最後執行這個不用再遞歸 
    {  
        int p1,p2,p3,p4,p5,p6,p7;  
        p1 = MA[0]*(MB[1]-MB[3]) ;  
        p2 = (MA[0] + MA[1])*MB[3] ;  
        p3 = (MA[2] + MA[3])*MB[0] ;  
        p4 = MA[3]*(MB[2] - MB[0]) ;  
        p5 = (MA[0] + MA[3])*(MB[0] + MB[3]) ;  
        p6 = (MA[1] - MA[3])*(MB[2] + MB[3]) ;  
        p7 = (MA[0] - MA[2])*(MB[0] + MB[1]) ;  
        MC[0] = p5 + p4 - p2 + p6 ;  
        MC[1] = p1 + p2 ;  
        MC[2] = p3 + p4 ;  
        MC[3] = p5 + p1 -p3 - p7 ;  
        return ;      
    }  
    else
	{
		int *MA1 = NULL,*MA2 = NULL,*MA3 = NULL,*MA4 = NULL;
		int *MB1 = NULL,*MB2 = NULL,*MB3 = NULL,*MB4 = NULL;
		int *MC1 = NULL,*MC2 = NULL,*MC3 = NULL,*MC4 = NULL;
		int *p1 = NULL,*p2 = NULL,*p3 = NULL,*p4 = NULL,*p5 = NULL,*p6 = NULL,*p7 = NULL;
		int *TEMP1 = NULL,*TEMP2 = NULL;
		
		
		MA1 = InitMatrix(n,n);
		MA2 = InitMatrix(n,n);
		MA3 = InitMatrix(n,n);
		MA4 = InitMatrix(n,n);
		MB1 = InitMatrix(n,n);
		MB2 = InitMatrix(n,n);
		MB3 = InitMatrix(n,n);
		MB4 = InitMatrix(n,n);
		MC1 = InitMatrix(n,n);
		MC2 = InitMatrix(n,n);
		MC3 = InitMatrix(n,n);
		MC4 = InitMatrix(n,n);
		p1 = InitMatrix(n,n);
		p2 = InitMatrix(n,n);
		p3 = InitMatrix(n,n);
		p4 = InitMatrix(n,n);
		p5 = InitMatrix(n,n);
		p6 = InitMatrix(n,n);
		p7 = InitMatrix(n,n);
		TEMP1 = InitMatrix(n,n);
		TEMP2 = InitMatrix(n,n);
		
		SplitMatrix(MA,MA1,size,1);SplitMatrix(MA,MA2,size,2);SplitMatrix(MA,MA3,size,3);SplitMatrix(MA,MA4,size,4);
		SplitMatrix(MB,MB1,size,1);SplitMatrix(MB,MB2,size,2);SplitMatrix(MB,MB3,size,3);SplitMatrix(MB,MB4,size,4);

		/*///////////////////
		/* p1=a(f-h)
		/* p2=h(a+b)
		/* p3=e(c+d)
		/* p4=d(g+e)
		/* p5=(e+h)(a+d)
		/* p6=(g+h)(b-d)
		/* p7=(a-c)(e+f)
		/*A a1  b2     B  e1  f2
		/*  c3  d4        g3  h4
		///////////////////*/
		
		//p1
		SubMatrix(MB2,MB4,TEMP1,n);
		Strassen(MA1,TEMP1,p1,n);
		//p2
		AddMatrix(MA1,MA2,TEMP1,n);
		Strassen(MB4,TEMP1,p2,n);
		//P3
		AddMatrix(MA3,MA4,TEMP1,n);
		Strassen(MB1,TEMP1,p3,n);	
		//P4
		AddMatrix(MB3,MB1,TEMP1,n);
		Strassen(MA4,TEMP1,p4,n);
		//P5
		AddMatrix(MB1,MB4,TEMP1,n);
		AddMatrix(MA1,MA4,TEMP2,n);
		Strassen(TEMP1,TEMP2,p5,n);
		//P6
		AddMatrix(MB3,MB4,TEMP1,n);
		SubMatrix(MA2,MA4,TEMP1,n);
		Strassen(TEMP1,TEMP2,p6,n);
		//P7
		AddMatrix(MB1,MB2,TEMP1,n);
		SubMatrix(MA1,MA3,TEMP2,n);
		Strassen(TEMP1,TEMP2,p7,n);
		
		//C1=P5+P4+P6-P2
		AddMatrix(p5,p4,TEMP1,n);
		AddMatrix(TEMP1,p6,TEMP2,n);
		SubMatrix(TEMP2,p2,MC1,n);
		
		//C2=P1+P2
		AddMatrix(p1,p2,MC2,n);
		
		//C3=P3+P4
		AddMatrix(p3,p4,MC3,n);
		
		//C4=P5+P1-P3-P7
		AddMatrix(p5,p1,TEMP1,n);
		SubMatrix(TEMP1,p3,TEMP2,n);
		SubMatrix(TEMP2,p7,MC4,n);
		
		StitchMatrix(MC1,MC2,MC3,MC4,MC,n);
		
		free(MA1);free(MA2);free(MA3);free(MA4);
		free(MB1);free(MB2);free(MB3);free(MB4);
		free(MC1);free(MC2);free(MC3);free(MC4);
		free(p1);free(p2);free(p3);free(p4);free(p5);free(p6);free(p7);
		free(TEMP1);free(TEMP2);
		
		return ;
	} 
}

/*
 * Reference O(size^3) multiplication used to cross-check Strassen():
 * MatrixOut = MatrixIn1 * MatrixIn2 for size x size row-major matrices.
 * Each output cell is first zeroed and then accumulated in place.
 */
void GradeSchool(int *MatrixIn1,int *MatrixIn2,int *MatrixOut,int size)
{
	for (int row = 0; row < size; ++row)
	{
		const int *lhsRow = MatrixIn1 + row * size;
		int *outRow = MatrixOut + row * size;
		for (int col = 0; col < size; ++col)
		{
			int &cell = outRow[col];
			cell = 0;
			for (int k = 0; k < size; ++k)
				cell += lhsRow[k] * MatrixIn2[k * size + col];
		}
	}
}

分治法實現矩陣乘法