So, I am trying to figure out Strassen's method for matrix multiplication. I am using C++, but it could be any language. Right now, it looks like this:

 typedef vector<long int> ROW;
 typedef vector<ROW> MATRIX;

 void getQuad(const MATRIX& IN, MATRIX& OUT0, MATRIX& OUT1,
                    MATRIX& OUT2, MATRIX& OUT3)
 { /*determine quadrants*/ }

 void strassen(const MATRIX& A, const MATRIX& B, MATRIX& C)
 {
      if (A.size() == 2 && A[0].size() == 2) // we know it's 2x2, stop
      {
           // Get M1-M7 vars and set MATRIX C with them
      }
      else
      {
           /*
             getQuad(...) returns the quadrants
             ___________
             | X0 | X1 |
             -----------
             | X2 | X3 |
             -----------
           */

        MATRIX A0,A1,A2,A3;
        getQuad(A,A0,A1,A2,A3);

        MATRIX B0,B1,B2,B3;
        getQuad(B,B0,B1,B2,B3);
      }
 }
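The `getQuad` stub above could be filled in along these lines. This is a minimal sketch, assuming `IN` is square with an even number of rows (which always holds when the size is a power of 2) and spelling out `std::` instead of relying on `using namespace std`. The output order follows the diagram in the comment: `OUT0` top-left, `OUT1` top-right, `OUT2` bottom-left, `OUT3` bottom-right.

```cpp
#include <cstddef>
#include <vector>

typedef std::vector<long int> ROW;
typedef std::vector<ROW> MATRIX;

// Splits a square n x n matrix (n even) into four n/2 x n/2 quadrants:
//   OUT0 | OUT1
//   -----+-----
//   OUT2 | OUT3
void getQuad(const MATRIX& IN, MATRIX& OUT0, MATRIX& OUT1,
             MATRIX& OUT2, MATRIX& OUT3)
{
    const std::size_t h = IN.size() / 2;  // assumes IN is square, size even
    OUT0.assign(h, ROW(h));
    OUT1.assign(h, ROW(h));
    OUT2.assign(h, ROW(h));
    OUT3.assign(h, ROW(h));
    for (std::size_t i = 0; i < h; ++i) {
        for (std::size_t j = 0; j < h; ++j) {
            OUT0[i][j] = IN[i][j];          // top-left
            OUT1[i][j] = IN[i][j + h];      // top-right
            OUT2[i][j] = IN[i + h][j];      // bottom-left
            OUT3[i][j] = IN[i + h][j + h];  // bottom-right
        }
    }
}
```

Each quadrant is a fresh copy, which is simple but allocates; the comments on the answer discuss avoiding that.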

I am not sure where to go next with the individual quadrants, i.e. how to derive the M1-M7 matrices at this point. I would imagine that the M1-M7 matrices (as opposed to the primitive values in the base case) would be used in the same manner as in the base case. I am just not certain what the unraveling would look like here.
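For reference, the 2x2 base case described in the code comment ("Get M1-M7 vars and set MATRIX C with them") can be written out with scalar M1 to M7 as below; `strassen2x2` is a hypothetical name used only for this sketch. The recursive case uses exactly the same seven formulas and the same four combinations for C: each scalar becomes an n/2 x n/2 quadrant, `+`/`-` become element-wise matrix addition/subtraction, and each `*` becomes a recursive call.

```cpp
#include <vector>

typedef std::vector<long int> ROW;
typedef std::vector<ROW> MATRIX;

// Multiplies two 2x2 matrices with Strassen's seven products.
MATRIX strassen2x2(const MATRIX& A, const MATRIX& B)
{
    const long a11 = A[0][0], a12 = A[0][1], a21 = A[1][0], a22 = A[1][1];
    const long b11 = B[0][0], b12 = B[0][1], b21 = B[1][0], b22 = B[1][1];

    // Seven multiplications instead of the naive eight.
    const long m1 = (a11 + a22) * (b11 + b22);
    const long m2 = (a21 + a22) * b11;
    const long m3 = a11 * (b12 - b22);
    const long m4 = a22 * (b21 - b11);
    const long m5 = (a11 + a12) * b22;
    const long m6 = (a21 - a11) * (b11 + b12);
    const long m7 = (a12 - a22) * (b21 + b22);

    // Combine them into the four entries of C.
    return MATRIX{{m1 + m4 - m5 + m7, m3 + m5},
                  {m2 + m4, m1 - m2 + m3 + m6}};
}
```

With A = {{1, 2}, {3, 4}} and B = {{5, 6}, {7, 8}} this gives {{19, 22}, {43, 50}}, the same as ordinary multiplication.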

I know it's a bit difficult to read someone else's code, but hopefully it is clear.

I am certain that my base case is correct and that I am splitting the matrix correctly; I am just unclear about where to go next. Perhaps I have structured my algorithm wrong.

Serenity
basil

1 Answer

I believe you missed the main point of the Strassen algorithm: it is recursive. In pseudo-code the algorithm would look something like this:

MATRIX strassen(const MATRIX&a, const MATRIX&b) {
    int aw = a.width();
    int ah = a.height();
    int bw = b.width();
    int bh = b.height();

    if (aw != bh)
        throw some_exception();

    // this form of the Strassen algorithm needs square matrices whose size is a power of 2
    int max_size = max({aw, ah, bw});
    int extended_size = next_pow_2(max_size);
    MATRIX aEx = a.extend(extended_size, extended_size);
    MATRIX bEx = b.extend(extended_size, extended_size);
    MATRIX cEx = strassenImpl(aEx, bEx);

    // truncate back from power of 2 to real size
    return cEx.truncate(ah, bw);
}


MATRIX strassenImpl(const MATRIX&A, const MATRIX&B) {
    // if matrix size is relatively small it is faster to do the usual straightforward multiplication
    if (A.size() <= threshold) {
        return usualMultiply(A, B);
    }
    // alternatively, with a threshold of 1, matrix multiplication is just multiplication of single values
    //if (A.size() == 1) {
    //    return MATRIX(A[0][0]*B[0][0]);
    //} 
    else {
        MATRIX A11, A12, A21, A22;
        getQuad(A, A11, A12, A21, A22);

        MATRIX B11, B12, B21, B22;
        getQuad(B, B11, B12, B21, B22);

        // recursive calls, note that we don't need to go through the extension step
        // here because if the size is a power of 2, half of the size is also a power of 2
        MATRIX M1 = strassenImpl(A11 + A22, B11 + B22);
        MATRIX M2 = strassenImpl(A21 + A22, B11);
        MATRIX M3 = strassenImpl(A11, B12 - B22);
        MATRIX M4 = strassenImpl(A22, B21 - B11);
        MATRIX M5 = strassenImpl(A11 + A12, B22);
        MATRIX M6 = strassenImpl(A21 - A11, B11 + B12);
        MATRIX M7 = strassenImpl(A12 - A22, B21 + B22);

        MATRIX C11 = M1 + M4 - M5 + M7;
        MATRIX C12 = M3 + M5;           
        MATRIX C21 = M2 + M4;
        MATRIX C22 = M1 - M2 + M3 + M6;

        MATRIX C = buildFromQuads(C11, C12, C21, C22);
        return C;
    }
}
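A runnable version of `strassenImpl` from the pseudo-code, restricted to square matrices whose size is already a power of 2 (the padding done by `strassen` is left out). Assumptions: `threshold = 2` is an arbitrary cut-off, not a tuned value; `add` and `sub` are hypothetical helper names standing in for the pseudo-code's `+` and `-` on matrices; `usualMultiply`, `getQuad` and `buildFromQuads` are filled in here only as simple sketches.

```cpp
#include <cstddef>
#include <vector>

typedef std::vector<long int> ROW;
typedef std::vector<ROW> MATRIX;

// Assumed cut-off: blocks of this size or smaller use plain multiplication.
const std::size_t threshold = 2;

// Element-wise X + Y (stands in for the pseudo-code's "+").
MATRIX add(const MATRIX& X, const MATRIX& Y)
{
    MATRIX Z = X;
    for (std::size_t i = 0; i < Z.size(); ++i)
        for (std::size_t j = 0; j < Z[i].size(); ++j)
            Z[i][j] += Y[i][j];
    return Z;
}

// Element-wise X - Y (stands in for the pseudo-code's "-").
MATRIX sub(const MATRIX& X, const MATRIX& Y)
{
    MATRIX Z = X;
    for (std::size_t i = 0; i < Z.size(); ++i)
        for (std::size_t j = 0; j < Z[i].size(); ++j)
            Z[i][j] -= Y[i][j];
    return Z;
}

// Straightforward O(n^3) product of two n x n matrices.
MATRIX usualMultiply(const MATRIX& A, const MATRIX& B)
{
    const std::size_t n = A.size();
    MATRIX C(n, ROW(n, 0));
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t k = 0; k < n; ++k)
            for (std::size_t j = 0; j < n; ++j)
                C[i][j] += A[i][k] * B[k][j];
    return C;
}

// Copies the four n/2 x n/2 quadrants of IN (n even) into Q11, Q12, Q21, Q22.
void getQuad(const MATRIX& IN, MATRIX& Q11, MATRIX& Q12, MATRIX& Q21, MATRIX& Q22)
{
    const std::size_t h = IN.size() / 2;
    Q11.assign(h, ROW(h)); Q12.assign(h, ROW(h));
    Q21.assign(h, ROW(h)); Q22.assign(h, ROW(h));
    for (std::size_t i = 0; i < h; ++i)
        for (std::size_t j = 0; j < h; ++j) {
            Q11[i][j] = IN[i][j];
            Q12[i][j] = IN[i][j + h];
            Q21[i][j] = IN[i + h][j];
            Q22[i][j] = IN[i + h][j + h];
        }
}

// Inverse of getQuad: places four h x h blocks into one 2h x 2h matrix.
MATRIX buildFromQuads(const MATRIX& C11, const MATRIX& C12,
                      const MATRIX& C21, const MATRIX& C22)
{
    const std::size_t h = C11.size();
    MATRIX C(2 * h, ROW(2 * h));
    for (std::size_t i = 0; i < h; ++i)
        for (std::size_t j = 0; j < h; ++j) {
            C[i][j] = C11[i][j];
            C[i][j + h] = C12[i][j];
            C[i + h][j] = C21[i][j];
            C[i + h][j + h] = C22[i][j];
        }
    return C;
}

// A and B must be n x n with n a power of 2.
MATRIX strassenImpl(const MATRIX& A, const MATRIX& B)
{
    if (A.size() <= threshold)
        return usualMultiply(A, B);

    MATRIX A11, A12, A21, A22, B11, B12, B21, B22;
    getQuad(A, A11, A12, A21, A22);
    getQuad(B, B11, B12, B21, B22);

    // Seven recursive products instead of the naive eight.
    const MATRIX M1 = strassenImpl(add(A11, A22), add(B11, B22));
    const MATRIX M2 = strassenImpl(add(A21, A22), B11);
    const MATRIX M3 = strassenImpl(A11, sub(B12, B22));
    const MATRIX M4 = strassenImpl(A22, sub(B21, B11));
    const MATRIX M5 = strassenImpl(add(A11, A12), B22);
    const MATRIX M6 = strassenImpl(sub(A21, A11), add(B11, B12));
    const MATRIX M7 = strassenImpl(sub(A12, A22), add(B21, B22));

    return buildFromQuads(add(sub(add(M1, M4), M5), M7),   // C11
                          add(M3, M5),                     // C12
                          add(M2, M4),                     // C21
                          add(add(sub(M1, M2), M3), M6));  // C22
}
```

Comparing `strassenImpl(A, B)` against `usualMultiply(A, B)` on a few power-of-2 sizes is a cheap way to catch a mistyped formula, such as a product accidentally assigned to the wrong `M` variable.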
SergGr
  • Yes, I understand; I was just worried about performance bottlenecks when copying vectors in C++. Algorithmically, I already know matrix addition/subtraction will be a problem. – basil Jun 09 '18 at 18:14
  • @basil The Strassen algorithm uses more memory than a naive implementation and I think this is unavoidable. You can reduce memory consumption a bit, but for that you'll need to wrap raw `vector`s in a custom `Matrix` class that supports creating another `Matrix` as a "view" of a range of its indices. In this way you can avoid separate memory allocations for the quadrant matrices, but the seven `M` matrices will need additional memory anyway. – SergGr Jun 09 '18 at 18:26
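The "view" idea from the last comment could look roughly like the sketch below. `MatrixView` and `addViews` are hypothetical names, not part of any library; the sketch assumes square views whose size is a power of 2. Taking a quadrant only records an offset and a size, so no elements are copied, while a sum such as A11 + A22 (and each of the seven `M` products) still has to be materialized into its own storage, as the comment says.

```cpp
#include <cstddef>
#include <vector>

typedef std::vector<long int> ROW;
typedef std::vector<ROW> MATRIX;

// Non-owning square window into an existing MATRIX covering rows
// [r0, r0 + n) and columns [c0, c0 + n). Creating one copies no elements,
// but the viewed MATRIX must outlive the view.
struct MatrixView {
    const MATRIX* m;
    std::size_t r0, c0, n;

    long int at(std::size_t i, std::size_t j) const
    {
        return (*m)[r0 + i][c0 + j];
    }

    // Quadrant q of this view (0 = top-left, 1 = top-right,
    // 2 = bottom-left, 3 = bottom-right); again no elements are copied.
    MatrixView quad(int q) const
    {
        const std::size_t h = n / 2;
        return MatrixView{m, r0 + (q / 2) * h, c0 + (q % 2) * h, h};
    }
};

// Element-wise X + Y written into fresh storage: many of the M1..M7
// arguments are sums like this, so they still need their own memory.
MATRIX addViews(const MatrixView& X, const MatrixView& Y)
{
    MATRIX Z(X.n, ROW(X.n));
    for (std::size_t i = 0; i < X.n; ++i)
        for (std::size_t j = 0; j < X.n; ++j)
            Z[i][j] = X.at(i, j) + Y.at(i, j);
    return Z;
}
```

A full implementation along these lines would also need the recursive multiply to accept views, so that arguments that are plain quadrants (such as `B11` in M2) are never copied.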