I've been trying to implement Strassen's algorithm for matrix multiplication for the past couple of hours and have had trouble getting the correct product. I think one of the my helper functions (helpSub,createProd, helpProduct) may be the issue or the format of my strass2 function (order of commands, etc). Any tips would be welcome because I'm totally stumped. I've been using two 4 x 4 matrices as tester matrices. I've tried tons of variations of p1-p7 and c1-c4 that I've seen on the internet but none seem to work. Below is the class I've created.
/* @author williamnewman
public class strassen2 {
//Main Strassen multiplication function
//BASE CASE:
int [][] strass2(int[][] x, int[][]y){
if(x.length == 1 && y.length == 1){
System.out.println("Donezo");
int [][] nu = new int[1][1];
nu[0][0] = x[0][0] * y[0][0];
return nu;
}
else{
int[][] a,b,c,d,e,f,g,h;
int dim = x.length/2;
//Dividing two matrices into 8 sub matrices
System.out.println("A<B<C");
a = helpSub(0,0,x);
C(a);
b = helpSub(0,dim,x);
C(b);
c = helpSub(dim,0,x);
C(c);
d = helpSub(dim,dim,x);
C(d);
e = helpSub(0,0,y);
C(e);
f = helpSub(0,dim,y);
C(f);
g = helpSub(dim,0,y);
C(g);
h = helpSub(dim,dim,y);
C(h);
int[][] p1,p2,p3,p4,p5,p6,p7;
//Creating p1-p7
/
p1 = strass2(a,subtract(f,h));
p2 = strass2(h, add(a,b));
p3 = strass2(e,add(c,d));
p4 = strass2(d,subtract(g,e));
p5 = strass2(add(a,d),add(e,h));
p6 = strass2(subtract(b,d),add(g,h));
p7 = strass2(subtract(a,c),add(e,f));
int [][] prod;
int [][] c1,c2,c3,c4;
//Creating c1-c4
c1 = subtract(add(p6,p5),subtract(p4,p2));
c2 = add(p1,p2);
c3 = add(p3,p4);
c4 = subtract(add(p1,p5),subtract(p3,p7));
C(c1);
System.out.println("C1::");
C(c2);
System.out.println("C2::");
C(c3);
System.out.println("C3::");
C(c4);
System.out.println("C4::");
//CREATES PRODUCT MATRIX
prod = createProd(c1,c2,c3,c4);
return prod;
}
}
//Creates product matrix from c1-c4
int[][] createProd(int[][] c1, int[][] c2, int[][] c3, int[][] c4){
int[][] product = new int[c1.length*2][c1.length*2];
int mid = c1.length;
int fin = c1.length * 2;
helpProduct(0,0,mid,mid,product,c1);
helpProduct(0,mid,mid,fin,product,c2);
helpProduct(mid,0,fin,mid,product,c3);
helpProduct(mid,mid,fin,fin,product,c4);
System.out.println();
System.out.println("PRODUCT::!:");
C(product);
return product;
}
//Helper function to create larger matrix from submatrices
void helpProduct(int x, int y, int z1, int z2,int[][] product, int[][] a1){
int indR = 0;
int indC = 0;
for(int i = x; i < z1; i++){
indC = 0;
for(int j = y; j < z2; j++){
product[i][j] = a1[indR][indC];
indC++;
}
indR++;
}
}
int[][] helpSub(int x, int y, int[][] mat){
int[][] sub = new int[mat.length/2][mat.length/2];
for(int i1 = 0, i2=x; i1 < (mat.length/2); i1++, i2++)
for(int j1 = 0, j2=y; j1<(mat.length/2); j1++, j2++)
{
sub[i1][j1] = mat[i2][j2];
// System.out.println(sub[i1][j1]);
}
return sub;
}
//Normal Matrix Multiplication Function
int[][] multiply(int[][]a,int[][]b){
MM nu = new MM(a,b);
return nu.product;
}
//Adds one matrix to the next
int[][] add(int[][]a, int[][]b){
int [][] nu = new int[a.length][a[0].length];
for(int i = 0; i < a.length; i++){
for(int j = 0; j < a[i].length;j++){
nu[i][j] = a[i][j] + b[i][j];
}
}
return nu;
}
//Subtracts second matrix from the first
int[][] subtract(int[][] a, int[][] b){
int [][] sub = new int[a.length][a.length];
//System.out.println("made it");
for(int i = 0; i < a.length; i++){
for(int j = 0; j < a[i].length;j++){
sub[i][j] = a[i][j] - b[i][j];
}
}
return sub;
}
//Prints the matrix
void C(int[][] product){
for(int i = 0; i <product.length; i++){
for(int j = 0; j < product[i].length; j++){
System.out.print(product[j][i] + " ");
}
System.out.println();
}
}
}
If anything is confusing let me know and I'll update the question!
Here is the main function::
public static void main(String[] args) {
int [][]a = {{1,2,3,4},
{4,3,2,1},
{1,2,3,4},
{4,3,2,1}};
int [][]b = {{3,4,5,6},
{3,4,5,6},
{5,4,3,2},
{5,4,3,2}
};
MM a1 = new MM(a,b);
a1.C();
int[][] prod;
System.out.println("----");
strassen2 a2 = new strassen2();
prod = a2.strass2(a,b);
a2.C(prod);
}
}
Here are the results so far (the expected result is the first 4x4 matrix shown and the actual result is the last 4x4 matrix shown):
EXPECTED:
44 40 36 32
36 40 44 48
44 40 36 32
36 40 44 48
----
ACTUAL::
70 78 50 42
86 86 34 34
30 38 30 38
38 54 38 54
I'm pretty sure my helpSub() functions work because they produced the corrected a-h. However, there might be a problem with the parameters I use in the strass2 recursive calls. I'm sorry if it's not specific enough I'm just a bit burnt out on it and was curious if anyone saw any glaring issues.