1. 程式人生 > >矩陣乘積計算(Strassen)

矩陣乘積計算(Strassen)

矩陣乘積計算(Strassen)

問題描述

​ 已知A,B兩個矩陣計算其乘積C?

矩陣乘積數學公式:

​ 假設存在兩個矩陣A為m×n矩陣,B為k×l矩陣,若需要計算AB則必須n=k,若需要計算BA必須l=m否則無法進行計算,先假定n=k即B為n×l矩陣則AB的結果為一個m×l的矩陣並且該矩陣每個點的元素的值表示為 Cij 則:

這裡寫圖片描述

這裡寫圖片描述

方法一:直接計算

​ 直接利用多重for迴圈求出相關矩陣對應的點的值即可

//矩陣的資料結構,隨機矩陣,非特殊矩陣
struct array
{
    int **data;                 //資料域
int row; int col; }; /** * 初始化矩陣元素,用隨機數填充 * 只為研究演算法因此為進行相關的記憶體檢查 * flag用來標記是否生成空矩陣,即元素全部為0的矩陣 */ void init_array(struct array *ptr,const int row,const int col,int flag) { int i = 0,j = 0; ptr->data = (int **)malloc(sizeof(int)*row); //??記憶體分配 for
(i = 0; i < row; i++) { *(ptr->data + i) = (int*)malloc(sizeof(int)*col); //??記憶體分配 } ptr->col = col; ptr->row = row; srand(time(NULL)); for (i = 0; i < row; i++) { for (j = 0; j < col; j++) { if
(flag) { ptr->data[i][j] = rand() % ARRAY_PRCE; } else { ptr->data[i][j] = 0; } } } } /** * 列印矩陣元素 */ void print_array(const struct array *ptr, const char *msg) { int i, j; printf("%s\n", msg); for (i = 0; i < ptr->row; i++) { for (j = 0; j < ptr->col; j++) { printf("%4d", ptr->data[i][j]); } printf("\n"); } } /** * 銷燬記憶體 */ void delete_array(struct array *ptr) { int i = 0; for (i = 0; i < ptr->row; i++) { free(*(ptr->data + i)); *(ptr->data + i) = NULL; } free(ptr->data); } /******************************************************************************************* *******************************************************************************************/ /** * 矩陣乘法求解 * 問題描述:已知兩個可以進行相乘的矩陣,求的乘積後的結果 */ /** * 方法一:暴力直接求解 * 利用矩陣乘法規則直接進行求解羅列出每個點的值求的最終的矩陣 */ struct array mult_array(const struct array *ptr1, const struct array *ptr2) { int i = 0; int j = 0; int k = 0; struct array ptr; if (ptr1->col != ptr2->row) //檢查是否符合可以進行乘積的要求 { return; } init_array(&ptr, ptr1->row, ptr2->col, 0); for (i = 0; i < ptr.row;i ++) { for (j = 0; j < ptr.col; j++) { for (k = 0; k < ptr1->col; k++) { ptr.data[i][j] += ptr1->data[i][k] * ptr2->data[k][j]; } } } return ptr; }

執行效果

這裡寫圖片描述

時間複雜度為O( n3 )

方法二:分治演算法

​ 將矩陣分解為一個個小矩陣進行計算然後將計算結果合併得到相關的結果。源於矩陣服從分配率和結合律,並不支援交換律。

​ 三個矩陣本身就可以寫成下面的格式

這裡寫圖片描述

​ 那麼相關的計算可以寫成

這裡寫圖片描述

​ 同理A11等一些子矩陣也可以寫成相關的子矩陣,就這樣將矩陣不斷分解為小矩陣進行計算,最後歸併為一個矩陣。

​ 時間複雜度為O( n3 )

/**
 *  方法二:利用分治思想進行求解
 *  存在問題無法解決不同型別的矩陣的問題,要求矩陣的行列必須為2的n次方,若不符合要求可以使用
 *  補0來構造相關的矩陣
 */
Matrix* Matrix::merge_calc(const Matrix& x)
{
    if (x.row == 1)             //當前的矩陣為單個的元素
    {
        Matrix *ptr = new Matrix(x.row, x.col);
        ptr->clear((this->getElem(0, 0))*(x.getElem(0, 0)));
        return ptr;
    }

    //將第一個矩陣分解為四個子矩陣
    Matrix A11(0,0,row/2,col/2,*this);
    Matrix A12(row / 2, 0, row, col / 2, *this);
    Matrix A21(0, col / 2, row / 2, col, *this);
    Matrix A22(row / 2, col / 2, row, col, *this);
    //將第二個矩陣分解為四個子矩陣
    Matrix B11(0, 0, row / 2, col / 2, x);
    Matrix B12(row / 2, 0, row, col / 2, x);
    Matrix B21(0, col / 2, row / 2, col, x);
    Matrix B22(row / 2, col / 2, row, col, x);

    Matrix *C11 = Matrix::add(A11.merge_calc(B11), A12.merge_calc(B21));
    Matrix *C12 = Matrix::add(A11.merge_calc(B12), A12.merge_calc(B22));
    Matrix *C21 = Matrix::add(A21.merge_calc(B11), A22.merge_calc(B21));
    Matrix *C22 = Matrix::add(A21.merge_calc(B12), A22.merge_calc(B22));

    //將C11,C12,C21,C22合併為一個完整的矩陣
    Matrix* ptr = Matrix::merge(C11, C12, C21, C22);

    return ptr;
}

方法三:Strassen演算法

​ Strassen演算法同樣是使用分治的思想解決問題,只不過,不同的是當矩陣的階很大時就會採取一個遞推式進行計算相關遞推式為:

                            S1 = B12 - B22
                            S2 = A11 + A12
                            S3 = A21 + A22
                            S4 = B21 - B11
                            S5 = A11 + A22
                            S6 = B11 + B22
                            S7 = A12 - A22
                            S8 = B21 + B22
                            S9 = A11 - A21
                            S10 = B11 + B12 
                            P1 = A11 * S1
                            P2 = S2 * B22
                            P3 = S3 * B11
                            P4 = A22 * S4
                            P5 = S5 * S6
                            P6 = S7 * S8
                            P7 = S9 * S10
                            C11 = P5 + P4 - P2 + P6
                            C12 = P1 + P2
                            C21 = P3 + P4
                            C22 = P5 + P1 - P3 - P7

​ 其中A11,A12,A21,A22和B11,B12,B21,B22分別為兩個乘數A和B矩陣的四個子矩陣。C11,C12,C21,C22為最終的結果C矩陣的四個子矩陣。該遞推式是被數學家證明過的。

​ 該演算法的效率為O( n(log27) ),但是相對來說額外空間的使用也是很多的。

Matrix* Matrix::strassen_calc(const Matrix& x)
{
    if (x.row < 2)
    {
        return this->force_calc(x);
    }

    //將第一個矩陣分解為四個子矩陣
    Matrix A11(0, 0, row / 2, col / 2, *this);
    Matrix A12(row / 2, 0, row, col / 2, *this);
    Matrix A21(0, col / 2, row / 2, col, *this);
    Matrix A22(row / 2, col / 2, row, col, *this);
    //將第二個矩陣分解為四個子矩陣
    Matrix B11(0, 0, row / 2, col / 2, x);
    Matrix B12(row / 2, 0, row, col / 2, x);
    Matrix B21(0, col / 2, row / 2, col, x);
    Matrix B22(row / 2, col / 2, row, col, x);

    Matrix* S1 = B12 - B22;
    Matrix* S2 = A11 + A12;
    Matrix* S3 = A21 + A22;
    Matrix* S4 = B21 - B11;
    Matrix* S5 = A11 + A22;
    Matrix* S6 = B11 + B22;
    Matrix* S7 = A12 - A22;
    Matrix* S8 = B21 + B22;
    Matrix* S9 = A11 - A21;
    Matrix* S10 = B11 + B12;

    Matrix* P1 = B12 - B22;
    Matrix* P2 = B12 - B22;
    Matrix* P3 = B12 - B22;
    Matrix* P4 = B12 - B22;
    Matrix* P5 = B12 - B22;
    Matrix* P6 = B12 - B22;
    Matrix* P7 = B12 - B22;

    P1 = A11.strassen_calc(*S1);
    P2 = S2->strassen_calc(B22);
    P3 = S3->strassen_calc(B11);
    P4 = A22.strassen_calc(*S4);
    P5 = S5->strassen_calc(*S6);
    P6 = S7->strassen_calc(*S8);
    P7 = S9->strassen_calc(*S10);

    Matrix *C11 = Matrix::sub(Matrix::add(P5, P4), Matrix::sub(P2, P6));
    Matrix *C12 = Matrix::add(P1, P2);
    Matrix *C21 = Matrix::add(P3, P4);
    Matrix *C22 = Matrix::sub(Matrix::add(P5, P1), Matrix::add(P3, P7));

    return Matrix::merge(C11,C12,C21,C22);
}

執行效果:

這裡寫圖片描述

完整的程式碼

//Matrix.h
#pragma once
#ifndef _MATRIX_H_
#define _MATRIX_H_

#include <iostream>
#include <vector>
using std::vector;
#define VISE 5
#define GATE 16                 //用來限定使用哪種演算法進行計算

#include <cstdlib>
#include <ctime>

typedef int type;
class Matrix
{
private:
    int row;                            //行
    int col;                            //列
    vector<vector<type>> data;          //資料
public:
    Matrix(int row, int col) :data(row),row(row),col(col)                   //矩陣資料生成利用隨機數進行生成
    {
        for (int i = 0; i < row; i++)
        {
            data[i].resize(col);
        }

        srand(time(0));
        for (int i = 0; i < row; i++)
        {
            for (int j = 0; j < col; j++)
            {
                data[i][j] = rand() % VISE;
            }
        }
    }

    Matrix(int row1, int col1, int row2, int col2, const Matrix& x) :row(row2 - row1), col(col2 - col1),data(row)
    {
        for (int i = 0; i < row; i++)
        {
            data[i].resize(col);
        }

        for (int i = 0; i < row; i++)
        {
            for (int j = 0; j < col; j++)
            {
                data[i][j] = x.getElem(col1 + i, row1 + j);
            }
        }
    }

    Matrix(const Matrix& x)
    {
        *this = x;
    }

    //相關算數運算操作
    Matrix* operator+(const Matrix&);
    Matrix* operator-(const Matrix&);
    Matrix* operator*(const Matrix&);
    static Matrix* add(const Matrix*, const Matrix*);                               //+
    static Matrix* sub(const Matrix*, const Matrix*);                               //-
    static Matrix* merge(const Matrix*, const Matrix*,const Matrix*,const Matrix*); //將四個子矩陣合併為一個矩陣
    //獲取矩陣的相關元素
    vector<type> operator[](const int);             //取得row
    type getElem(const int,const int) const;        //獲取相關節點的資料
    void setElem(const int, const int, type);       //設定節點的資料   

    //計算乘法的演算法
    Matrix* force_calc(const Matrix&);              //直接暴力求解
    Matrix* merge_calc(const Matrix&);              //分治求解
    Matrix* strassen_calc(const Matrix&);           //Strassen演算法

    void show();                                    //列印矩陣
    bool isSimilar(const Matrix& x);                //行列相同即為同類型矩陣
    void clear(type);                               //設定矩陣中所有的元素為同一個指定的值                                

    ~Matrix();
};

#endif
//_MATRIX_H_
//Matrix.cpp
#include "Matrix.h"

Matrix* Matrix::operator*(const Matrix& x)
{
    if (x.row != this->col)
    {
        return nullptr;
    }

    if (x.row < VISE && x.col < VISE && row < VISE && col < VISE)
    {
        return this->merge_calc(x);
    }

    return this->strassen_calc(x);
}

/**
 *  方法一:暴力直接求解問題
 *  時間複雜度為O(n^3)
 */
Matrix* Matrix::force_calc(const Matrix& x)
{
    if (x.row != this->col)                                             //行列不同無法進行乘法,可以進行補零將相關矩陣填充為可使用的矩陣
    {                                                                   //這裡不進行相關的編寫
        return nullptr;
    }

    Matrix *ptr = new Matrix(row, x.col);
    ptr->clear(0);
    for (int i = 0; i < row; i++)
    {
        for (int j = 0; j < x.col; j++)
        {
            for (int k = 0; k < col; k++)
            {
                ptr->setElem(i, j, ptr->getElem(i, j) + getElem(i, k) * x.getElem(k, j));
            }
        }
    }

    return ptr;
}

void Matrix::clear(type cur = 0)
{
    for (int i = 0; i < row; i++)
    {
        for (int j = 0; j < col; j++)
        {
            data[i][j] = cur;
        }
    }
}

/**
 *  方法二:利用分治思想進行求解
 *  存在問題無法解決不同型別的矩陣的問題,要求矩陣的行列必須為2的n次方,若不符合要求可以使用
 *  補0來構造相關的矩陣
 */
Matrix* Matrix::merge_calc(const Matrix& x)
{
    if (x.row == 1)             //當前的矩陣為單個的元素
    {
        Matrix *ptr = new Matrix(x.row, x.col);
        ptr->clear((this->getElem(0, 0))*(x.getElem(0, 0)));
        return ptr;
    }

    //將第一個矩陣分解為四個子矩陣
    Matrix A11(0,0,row/2,col/2,*this);
    Matrix A12(row / 2, 0, row, col / 2, *this);
    Matrix A21(0, col / 2, row / 2, col, *this);
    Matrix A22(row / 2, col / 2, row, col, *this);
    //將第二個矩陣分解為四個子矩陣
    Matrix B11(0, 0, row / 2, col / 2, x);
    Matrix B12(row / 2, 0, row, col / 2, x);
    Matrix B21(0, col / 2, row / 2, col, x);
    Matrix B22(row / 2, col / 2, row, col, x);

    Matrix *C11 = Matrix::add(A11.merge_calc(B11), A12.merge_calc(B21));
    Matrix *C12 = Matrix::add(A11.merge_calc(B12), A12.merge_calc(B22));
    Matrix *C21 = Matrix::add(A21.merge_calc(B11), A22.merge_calc(B21));
    Matrix *C22 = Matrix::add(A21.merge_calc(B12), A22.merge_calc(B22));

    //將C11,C12,C21,C22合併為一個完整的矩陣
    Matrix* ptr = Matrix::merge(C11, C12, C21, C22);

    return ptr;
}

Matrix* Matrix::strassen_calc(const Matrix& x)
{
    if (x.row < 2)
    {
        return this->force_calc(x);
    }

    //將第一個矩陣分解為四個子矩陣
    Matrix A11(0, 0, row / 2, col / 2, *this);
    Matrix A12(row / 2, 0, row, col / 2, *this);
    Matrix A21(0, col / 2, row / 2, col, *this);
    Matrix A22(row / 2, col / 2, row, col, *this);
    //將第二個矩陣分解為四個子矩陣
    Matrix B11(0, 0, row / 2, col / 2, x);
    Matrix B12(row / 2, 0, row, col / 2, x);
    Matrix B21(0, col / 2, row / 2, col, x);
    Matrix B22(row / 2, col / 2, row, col, x);

    Matrix* S1 = B12 - B22;
    Matrix* S2 = A11 + A12;
    Matrix* S3 = A21 + A22;
    Matrix* S4 = B21 - B11;
    Matrix* S5 = A11 + A22;
    Matrix* S6 = B11 + B22;
    Matrix* S7 = A12 - A22;
    Matrix* S8 = B21 + B22;
    Matrix* S9 = A11 - A21;
    Matrix* S10 = B11 + B12;

    Matrix* P1 = B12 - B22;
    Matrix* P2 = B12 - B22;
    Matrix* P3 = B12 - B22;
    Matrix* P4 = B12 - B22;
    Matrix* P5 = B12 - B22;
    Matrix* P6 = B12 - B22;
    Matrix* P7 = B12 - B22;

    P1 = A11.strassen_calc(*S1);
    P2 = S2->strassen_calc(B22);
    P3 = S3->strassen_calc(B11);
    P4 = A22.strassen_calc(*S4);
    P5 = S5->strassen_calc(*S6);
    P6 = S7->strassen_calc(*S8);
    P7 = S9->strassen_calc(*S10);

    Matrix *C11 = Matrix::sub(Matrix::add(P5, P4), Matrix::sub(P2, P6));
    Matrix *C12 = Matrix::add(P1, P2);
    Matrix *C21 = Matrix::add(P3, P4);
    Matrix *C22 = Matrix::sub(Matrix::add(P5, P1), Matrix::add(P3, P7));

    return Matrix::merge(C11,C12,C21,C22);
}

/**
 *  將四個子矩陣合併為一個完整的矩陣
 *  也可以使用分治思想進行解決,以後可能會新增相關的功能
 */
Matrix* Matrix::merge(const Matrix* p1, const Matrix* p2,const Matrix* p3, const Matrix* p4)
{
    //不符合可以進行合併的條件
    if (!(p1->row == p2->row && p2->col == p4->col && p4->row == p3->row && p1->col == p3->col))
    {
        return nullptr;
    }

    Matrix* ptr = new Matrix(p1->row + p3->row, p2->col + p1->col);
    ptr->clear(0);
    //重新裝值
    for (int i = 0; i < p1->row; i++)
    {
        for (int j = 0; j < p1->col; j++)
        {
            ptr->setElem(i, j, p1->getElem(i, j));
        }
    }

    for (int i = 0; i < p2->row; i++)
    {
        for (int j = 0; j < p2->col; j++)
        {
            ptr->setElem(i, j + p1->col, p2->getElem(i, j));
        }
    }

    for (int i = 0; i < p3->row; i++)
    {
        for (int j = 0; j < p3->col; j++)
        {
            ptr->setElem(i + p1->row, j, p3->getElem(i, j));
        }
    }

    for (int i = 0; i < p4->row; i++)
    {
        for (int j = 0; j < p4->col; j++)
        {
            ptr->setElem(p1->row + i, p1->col + j, p4->getElem(i, j));
        }
    }

    return ptr;
}

Matrix* Matrix::sub(const Matrix* p1, const Matrix* p2)
{
    if (!(p1->col == p2->col && p1->row == p2->row))
    {
        return nullptr;
    }

    Matrix *ptr = new Matrix(p1->row, p1->col);
    for (int i = 0; i < p1->row; i++)
    {
        for (int j = 0; j < p1->col; j++)
        {
            ptr->setElem(i, j, (p1->getElem(i, j) - p2->getElem(i, j)));
        }
    }

    return ptr;
}

Matrix* Matrix::add(const Matrix* p1, const Matrix* p2)
{
    if (!(p1->col == p2->col && p1->row == p2->row))
    {
        return nullptr;
    }

    Matrix *ptr = new Matrix(p1->row, p1->col);
    for (int i = 0; i < p1->row; i++)
    {
        for (int j = 0; j < p1->col; j++)
        {
            ptr->setElem(i, j, (p1->getElem(i, j) + p2->getElem(i, j)));
        }
    }

    return ptr;
}

Matrix* Matrix::operator+(const Matrix& x)
{
    if (!isSimilar(x))
    {
        return nullptr;
    }

    Matrix *ptr = new Matrix(x.row, x.col);                             //記憶體需要釋放
    for (int i = 0; i < row; i++)
    {
        for (int j = 0; j < col; j++)
        {
            ptr->setElem(i, j, this->getElem(i, j) + x.getElem(i, j));
        }
    }

    return ptr;
}

Matrix* Matrix::operator-(const Matrix& x)
{
    if (!isSimilar(x))
    {
        return nullptr;
    }

    Matrix *ptr = new Matrix(x.row, x.col);                             //記憶體需要釋放
    for (int i = 0; i < row; i++)
    {
        for (int j = 0; j < col; j++)
        {
            ptr->setElem(i, j, this->getElem(i, j) - x.getElem(i, j));
        }
    }

    return ptr;
}

vector<type> Matrix::operator[](const int row)
{
    return data[row];
}

type Matrix::getElem(int row, int col)const
{
    return this->data[row][col];
}

void Matrix::setElem(int row, int col, type cur)
{
    this->data[row][col] = cur;
}

void Matrix::show()
{
    for (int i = 0; i < row; i++)
    {
        if (i == 0)
        {
            std::cout << "┏";
        }
        else if (i == row - 1)
        {
            std::cout << "┗";
        }
        else
        {
            std::cout << "┃";
        }

        for (int j = 0; j < col; j++)
        {
            std::cout.width(4);
            std::cout << data[i][j];
        }

        if (i == 0)
        {
            std::cout << "   ┓";
        }
        else if (i == row - 1)
        {
            std::cout << "   ┛";
        }
        else
        {
            std::cout << "   ┃";
        }

        std::cout << std::endl;
    }
}

bool Matrix::isSimilar(const Matrix& x)
{
    return x.row == this->row && this->col == x.col;
}

Matrix::~Matrix()
{
    this->row = 0;
    this->col = 0;
}