I trying to use cblas.h (from openblas library) to compute the product of two matrices. Μore specifically, Ι have a double array A, with dimensions n*d, an array B with dimensions m*d and an array C with dimensions n*m. I want to compute the product A 'times' B transpose.
My code is
#include <cblas.h>
#include <stdio.h>
#include <stdlib.h>
void random_matrix(double *X, int rows, int cols);
void print_matrix(double *X, int rows, int cols);
int main(int argc, char** argv)
{
int n = atoi(argv[1]),
m = atoi(argv[2]),
d = atoi(argv[3]);
double *A, *B, *C;
A = malloc(n*d*sizeof(double));
B = malloc(m*d*sizeof(double));
C = malloc(n*m*sizeof(double));
random_matrix(A,n,d);
print_matrix(A,n,d);
random_matrix(B,m,d);
print_matrix(B,m,d);
cblas_dgemm(CblasRowMajor,
CblasNoTrans, CblasTrans, n,m,d,
1.0, A, n, B, m,
0.0, C, n
);
print_matrix(C,n,m);
return 0;
}
void random_matrix(double *X, int rows, int cols){
for(int i = 0; i < rows; i++)
for(int j = 0; j < cols; j++)
X[i*cols+j] = (double)rand() / RAND_MAX + (double)(rand()%10);
}
void print_matrix(double *X, int rows, int cols){
for(int i = 0; i < rows; i++) {
for(int j = 0; j < cols; j++) {
printf("%g ", X[i*cols+j]);
}printf(";\n");
}printf("\n\n");
}
When I run the program with n = 6, m = 5 and d = 2, the output is:
9.00001 8.75561 ;
2.53277 8.04704 ;
9.6793 5.3835 ;
2.83097 3.05346 ;
9.67115 2.38342 ;
9.41749 7.58898 ;
3.84617 8.09196 ;
5.416 6.91032 ;
7.26245 3.73608 ;
9.63264 1.99104 ;
2.24704 6.72266 ;
105.466 117.964 4.58702e-309 4.70574e-309 0 ;
-1.83255e-06 35.5969 39.9896 1.59969e-309 1.4802e-309 ;
0 -1.83255e-06 0 0 0 ;
0 0 -1.83255e-06 0 0 ;
0 0 0 -1.83255e-06 0 ;
0 0 0 0 -1.83255e-06 ;
Which is wrong because when I try it on octave, I get:
octave:52> A = [9.00001 8.75561 ;
> 2.53277 8.04704 ;
> 9.6793 5.3835 ;
> 2.83097 3.05346 ;
> 9.67115 2.38342 ;
> 9.41749 7.58898 ;];
octave:53> B = [3.84617 8.09196 ;
> 5.416 6.91032 ;
> 7.26245 3.73608 ;
> 9.63264 1.99104 ;
> 2.24704 6.72266 ;];
octave:54> A*B'
ans =
105.466 109.248 98.074 104.127 79.084
74.858 69.325 48.459 40.419 59.789
80.791 89.625 90.409 103.956 57.941
35.597 36.433 31.968 33.349 26.889
56.483 68.849 79.141 97.904 37.754
97.631 103.447 96.747 105.825 72.180