7

Say I have two very big matrices A (M-by-N) and B (N-by-M). I need the diagonal of A*B. Computing the full A*B requires M*M*N multiplications, while computing the diagonal of it only requires M*N multiplications since there's no need to compute the elements that will end up outside the diagonal.

Does MATLAB realize this and on-the-fly-optimize diag(A*B) automagically, or am I better off using a for loop in this case?

Luis Mendo
Emil Lundberg
  • Typically, how big are those numbers, N and M? – Divakar May 15 '14 at 15:24
  • The application where I encountered this question is where A is weight vectors in an artificial neural network and B is the difference between the input vector and another weight matrix. In this particular case M would be somewhere between 100 and 1000, and N = 4. So my original application may not actually qualify as "very big matrices". :P – Emil Lundberg May 15 '14 at 15:30
  • If all the benchmarking done in the answers suggests that `diag(A*B)` isn't optimized on the fly, contact MathWorks with a feature request. – Sam Roberts May 15 '14 at 17:34

4 Answers

11

One can also implement diag(A*B) as sum(A.*B',2) (for real matrices; with complex data, use the non-conjugate transpose B.' instead of B'). Let's benchmark this along with all the other implementations suggested for this question.

The different methods implemented as functions are listed below for benchmarking purposes:

  1. Sum-multiplication method-1

    function out = sum_mult_method1(A,B)
    
    out = sum(A.*B',2);
    
  2. Sum-multiplication method-2

    function out = sum_mult_method2(A,B)
    
    out = sum(A.'.*B).';
    
  3. For-loop method

    function out = for_loop_method(A,B)
    
    M = size(A,1);
    out = zeros(M,1);
    for i=1:M
        out(i) = A(i,:) * B(:,i);
    end
    
  4. Full/Direct-multiplication method

    function out = direct_mult_method(A,B)
    
    out = diag(A*B);
    
  5. Bsxfun-method

    function out = bsxfun_method(A,B)
    
    out = sum(bsxfun(@times,A,B.'),2);
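
Before timing anything, it may be worth confirming that the five variants return the same vector. The following is only a minimal sketch: the sizes and the variable names (`ref`, `d1`, `d2`, `d3`, `d5`) are illustrative assumptions, and the methods are inlined rather than called as the functions listed above. It uses integer-valued test data so that every sum is exact in double precision and `isequal` is a fair comparison.

```matlab
% Sanity check (illustrative sizes): all five variants should agree exactly
% on integer-valued data, because no rounding occurs in the sums.
M = 200; N = 4;
A = randi(9, M, N);
B = randi(9, N, M);

ref = diag(A*B);                        % 4. full multiplication, then diagonal
d1  = sum(A.*B.', 2);                   % 1. row sums of A .* B.'
d2  = sum(A.'.*B, 1).';                 % 2. column sums of A.' .* B, transposed
d3  = zeros(M, 1);                      % 3. explicit for loop
for i = 1:M
    d3(i) = A(i,:) * B(:,i);
end
d5  = sum(bsxfun(@times, A, B.'), 2);   % 5. bsxfun variant

assert(isequal(ref, d1, d2, d3, d5));
```

With floating-point data such as `rand`, an exact `isequal` test may fail even though the methods are equivalent, so a tolerance-based comparison would be needed instead.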
    

Benchmarking Code

num_runs = 1000;
M_arr = [100 200 500 1000];
N = 4;

%// Warm up tic/toc.
tic();
elapsed = toc();
tic();
elapsed = toc();

for k2 = 1:numel(M_arr)
    M = M_arr(k2);

    fprintf('\n')
    disp(strcat('*** Benchmarking sizes are M =',num2str(M),' and N = ',num2str(N)));

    A = randi(9,M,N);
    B = randi(9,N,M);

    disp('1. Sum-multiplication method-1');
    tic
    for k = 1:num_runs
        out1 = sum_mult_method1(A,B);
    end
    toc
    clear out1

    disp('2. Sum-multiplication method-2');
    tic
    for k = 1:num_runs
        out2 = sum_mult_method2(A,B);
    end
    toc
    clear out2

    disp('3. For-loop method');
    tic
    for k = 1:num_runs
        out3 = for_loop_method(A,B);
    end
    toc
    clear out3

    disp('4. Direct-multiplication method');
    tic
    for k = 1:num_runs
        out4 = direct_mult_method(A,B);
    end
    toc
    clear out4

    disp('5. Bsxfun method');
    tic
    for k = 1:num_runs
        out5 = bsxfun_method(A,B);
    end
    toc
    clear out5

end

Results

*** Benchmarking sizes are M =100 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.015242 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.015180 seconds.
3. For-loop method
Elapsed time is 0.192021 seconds.
4. Direct-multiplication method
Elapsed time is 0.065543 seconds.
5. Bsxfun method
Elapsed time is 0.054149 seconds.

*** Benchmarking sizes are M =200 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.009138 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.009428 seconds.
3. For-loop method
Elapsed time is 0.435735 seconds.
4. Direct-multiplication method
Elapsed time is 0.148908 seconds.
5. Bsxfun method
Elapsed time is 0.030946 seconds.

*** Benchmarking sizes are M =500 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.033287 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.026405 seconds.
3. For-loop method
Elapsed time is 0.965260 seconds.
4. Direct-multiplication method
Elapsed time is 2.832855 seconds.
5. Bsxfun method
Elapsed time is 0.034923 seconds.

*** Benchmarking sizes are M =1000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.026068 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.032850 seconds.
3. For-loop method
Elapsed time is 1.775382 seconds.
4. Direct-multiplication method
Elapsed time is 13.764870 seconds.
5. Bsxfun method
Elapsed time is 0.044931 seconds.

Intermediate Conclusions

It looks like the sum-multiplication methods are the best approaches, though the bsxfun approach seems to be catching up with them as M increases from 100 to 1000.

Next, higher benchmarking sizes were tested with just the sum-multiplication and bsxfun methods. The sizes were -

M_arr = [1000 2000 5000 10000 20000 50000];

The results are -

*** Benchmarking sizes are M =1000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.030390 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.032334 seconds.
5. Bsxfun method
Elapsed time is 0.047377 seconds.

*** Benchmarking sizes are M =2000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.040111 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.045132 seconds.
5. Bsxfun method
Elapsed time is 0.060762 seconds.

*** Benchmarking sizes are M =5000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.099986 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.103213 seconds.
5. Bsxfun method
Elapsed time is 0.117650 seconds.

*** Benchmarking sizes are M =10000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.375604 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.273726 seconds.
5. Bsxfun method
Elapsed time is 0.226791 seconds.

*** Benchmarking sizes are M =20000 and N =4
1. Sum-multiplication method-1
Elapsed time is 1.906839 seconds.
2. Sum-multiplication method-2
Elapsed time is 1.849166 seconds.
5. Bsxfun method
Elapsed time is 1.344905 seconds.

*** Benchmarking sizes are M =50000 and N =4
1. Sum-multiplication method-1
Elapsed time is 5.159177 seconds.
2. Sum-multiplication method-2
Elapsed time is 5.081211 seconds.
5. Bsxfun method
Elapsed time is 3.866018 seconds.

Alternate benchmarking code (with `timeit`)

num_runs = 1000;
M_arr = [1000 2000 5000 10000 20000 50000 100000 200000 500000 1000000];
N = 4;

timeall = zeros(5,numel(M_arr));
for k2 = 1:numel(M_arr)
    M = M_arr(k2);

    A = rand(M,N);
    B = rand(N,M);

    f = @() sum_mult_method1(A,B);
    timeall(1,k2) = timeit(f);
    clear f

    f = @() sum_mult_method2(A,B);
    timeall(2,k2) = timeit(f);
    clear f

    f = @() bsxfun_method(A,B);
    timeall(5,k2) = timeit(f);
    clear f

end

figure,
hold on
plot(M_arr,timeall(1,:),'-ro')
plot(M_arr,timeall(2,:),'-ko')
plot(M_arr,timeall(5,:),'-.b')
legend('sum-method1','sum-method2','bsxfun-method')
xlabel('M ->')
ylabel('Time(sec) ->')

Plot

[Plot: time in seconds versus M for sum-method1, sum-method2 and bsxfun-method]

Final Conclusions

The sum-multiplication methods appear to be fastest up to roughly M = 5000; beyond that, the bsxfun method seems to have a slight edge.

Future Work

One can look into varying N and study the performances for the implementations mentioned here.
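
As a starting point for such a study, here is a rough sketch that fixes M and sweeps N using `timeit`. The values of `M` and `N_arr` are arbitrary assumptions chosen only for illustration, and only two of the methods are timed (inlined as anonymous functions); no results are claimed here.

```matlab
% Hypothetical sweep over N at a fixed M (sizes are illustrative only).
M = 2000;
N_arr = [4 16 64 256];

timeall = zeros(2, numel(N_arr));   % row 1: sum method, row 2: bsxfun method
for k = 1:numel(N_arr)
    N = N_arr(k);
    A = rand(M, N);
    B = rand(N, M);

    timeall(1,k) = timeit(@() sum(A.*B.', 2));
    timeall(2,k) = timeit(@() sum(bsxfun(@times, A, B.'), 2));
end

% One column per value of N; times are in seconds.
disp(timeall)
```

Timings from a sweep like this depend on the machine and MATLAB release, so they should be read as relative comparisons only.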

Divakar
  • Those times look pretty small to be reliable and you're not doing any warm up. If you're not familiar with doing reliable benchmarking, but you have a recent Matlab, I suggest using `timeit`. Also, how about testing `bsxfun`? – horchler May 15 '14 at 15:54
  • @horchler Well those numbers are what OP is using as suggested in comments. Looking into timeit now. – Divakar May 15 '14 at 15:55
  • Those may be the sizes of the matrices, but that doesn't mean that you can't do more runs and perhaps more importantly, warm up the functions before timing. – horchler May 15 '14 at 15:57
4

Yes, this is one of the rare cases where a for loop is better.

I ran the following script through the profiler:

M = 5000;
N = 5000;

A = rand(M, N); B = rand(N, M);
product = A*B;
diag1 = diag(product);

A = rand(M, N); B = rand(N, M);
diag2 = diag(A*B);

A = rand(M, N); B = rand(N, M);
diag3 = zeros(M,1);
for i=1:M
    diag3(i) = A(i,:) * B(:,i);
end

I reset A and B between each test just in case MATLAB would try to speed anything up by caching.

Result (edited for brevity):

  time   calls  line
  6.29       1    5 product = A*B; 
< 0.01       1    6 diag1 = diag(product); 

  5.46       1    9 diag2 = diag(A*B); 

             1   12 diag3 = zeros(M,1); 
             1   13 for i=1:M 
  0.52    5000   14     diag3(i) = A(i,:) * B(:,i); 
< 0.01    5000   15 end 

As we can see, the for loop variant is an order of magnitude faster than the other two in this case. While the diag(A*B) variant is actually faster than the diag(product) variant, it's marginal at best.

I tried some different values of M and N, and in my tests the for loop variant is slower only if M=1.

Emil Lundberg
3

Actually, you can do this faster than a for loop using the wonders of bsxfun:

diag4 = sum(bsxfun(@times,A,B.'),2)

This is about twice as fast as the explicit for loop on my machine for large matrices (2,000-by-2,000 and bigger), and it is faster for matrices larger than about 500-by-500.

Note that all of these methods will produce numerically different results because of the different orders of summation and multiplication.
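
To get a feel for how large those differences are, one could compare the methods against each other with a tolerance rather than with exact equality. This is only a sketch: the sizes, the fixed seed, and the relative tolerance of `1e-10` are assumptions made for illustration, not values taken from the measurements above.

```matlab
% Compare three ways of computing diag(A*B) on floating-point data.
rng(0);                          % fixed seed so the run is repeatable
M = 500; N = 500;                % illustrative sizes
A = rand(M, N);
B = rand(N, M);

d_full = diag(A*B);              % full product, then diagonal

d_loop = zeros(M, 1);            % explicit for loop
for i = 1:M
    d_loop(i) = A(i,:) * B(:,i);
end

d_bsx = sum(bsxfun(@times, A, B.'), 2);   % bsxfun variant

% Largest absolute deviations from the full-product result.
err_loop = max(abs(d_full - d_loop));
err_bsx  = max(abs(d_full - d_bsx));
fprintf('max |full - loop|   = %g\n', err_loop);
fprintf('max |full - bsxfun| = %g\n', err_bsx);

% The deviations may be nonzero, but should be tiny relative to the values.
tol = 1e-10 * max(abs(d_full));
assert(err_loop < tol && err_bsx < tol);
```

In practice this means results from different methods should be compared with a tolerance, not with `isequal`.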

horchler
  • +1 for this approach, and for the very appropriate note on numerical differences. If you have the benchmarking code maybe you could compare my solution too and tell the results? – Luis Mendo May 15 '14 at 15:42
  • @LuisMendo: The `bsxfun` method is about 20% faster than your solution for the 5,000-by-5,000 case. I think it might be the parallelization implicit in the method. – horchler May 15 '14 at 15:51
  • Thanks! Good to know. That's one more reason to like `bsxfun`. However, `.*` and `sum` [are claimed](http://www.mathworks.com/matlabcentral/answers/95958-which-matlab-functions-benefit-from-multithreaded-computation) to benefit from parallelization too... – Luis Mendo May 15 '14 at 15:54
3

You can compute only the diagonal elements without a loop: just use

sum(A.'.*B).'

or

sum(A.*B.',2)
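
A small worked example may make it clearer why both forms give the diagonal: element i of diag(A*B) is the dot product of row i of A with column i of B, which is exactly what the element-wise product followed by a sum computes. The matrices below are made up for illustration. The sketch also passes the dimension `1` to `sum` explicitly in the first form; this is an added precaution, since without it `sum` works along the first non-singleton dimension and would collapse the result to a scalar when N = 1.

```matlab
% Tiny hand-checkable example: A is 3-by-2, B is 2-by-3.
A = [1 2; 3 4; 5 6];
B = [1 0 2; 0 1 3];

d_ref  = diag(A*B);            % reference: full product, then diagonal
d_cols = sum(A.'.*B, 1).';     % column sums of A.' .* B, transposed
d_rows = sum(A.*B.', 2);       % row sums of A .* B.'

% Each column of the displayed matrix is [1; 4; 28].
disp([d_ref d_cols d_rows])
```

Here the third diagonal element is 5*2 + 6*3 = 28, which matches the third row sum of A .* B.'.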
Luis Mendo
  • Or `sum(A.*B',2)`. I have some benchmarking results on this, coming up! Hope that's okay! – Divakar May 15 '14 at 15:42
  • @Divakar Great! I had just asked @horchler for that. BTW, I tend to avoid the `,2` in `sum` on the basis that it's slower, but now I'm not sure that feeling is correct – Luis Mendo May 15 '14 at 15:44
  • Yeah, I realized shortly after submitting my answer, and was just about to edit that in. :) – Emil Lundberg May 15 '14 at 15:47
  • @LuisMendo Added separate benchmarking results for your implementation too! `timeit` is what I am looking into now, to verify if the benchmarking results presented here are reliable. – Divakar May 15 '14 at 16:53