I need to perform matrix multiplication between the different layers of a neural network. That is, W0, W1, W2, ... Wn are the weights of the neural network and the input is data. The resulting matrices are:

Out1 = data * W0
Out2 = Out1 * W1
Out3 = Out2 * W2
.
.
.
OutN = Out(N-1) * Wn

I know the absolute maximum value in each weight matrix, and I also know that the input data values range from 0 to 1 (the inputs are normalized). The matrix multiplication is done in 16-bit fixed point. The weights are scaled to the optimal fixed-point format. For example, if the absolute maximum value in W0 is 2.5, I know that the minimum number of bits in the integer part is 2 and the fractional part will have 14 bits. Because the input data is in the range [0,1], I also know that its integer and fractional bits are 1.15.
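The format choice described above can be sketched as follows. This is a minimal illustration, not the asker's actual code: the name `q_format` and the `total_bits` parameter are assumptions, and it follows the convention of the examples in the text (2.5 gives 2.14, a [0,1] input gives 1.15), which does not count a separate sign bit. If the 16-bit word must also hold a sign bit, one fewer fractional bit would be available; that is not modeled here.

```python
import math


def q_format(max_abs, total_bits=16):
    """Return (integer_bits, fractional_bits) for a known absolute maximum.

    Assumption: same convention as the text, i.e. 2.5 -> 2.14 and
    1.0 -> 1.15, with no separately counted sign bit.
    """
    if max_abs <= 0:
        raise ValueError("max_abs must be positive")
    # Bits needed to represent the integer part of max_abs (at least 1).
    integer_bits = max(1, math.floor(math.log2(max_abs)) + 1)
    if integer_bits > total_bits:
        raise ValueError("value does not fit in the given word size")
    return integer_bits, total_bits - integer_bits


print(q_format(2.5))  # weights with absolute maximum 2.5
print(q_format(1.0))  # normalized input data
```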

My question is: how can I know the minimum number of bits in the integer part of the resulting matrix to avoid overflow? Is there any way to study and infer the maximum value in a matrix multiplication? I know about the determinant and the norm of a matrix, but I think the problem lies in runs of consecutive negative or positive values in the matrix rows and columns. For example, if I have this row vector and this column vector, and the result is in 8-bit fixed point:

A = [1, 2, 3, 4, 5, 6, -7, -8]
B = [1, 2, 3, 4, 5, 6, 7, 8]
A * B = (1*1) + (2*2) + (3*3) + (4*4) + (5*5) + (6*6) + (-7*7) + (-8*8) = 91 - 49 - 64 = -22

When the sum accumulator goes above 63 partway through the sum, overflow occurs, although the final result lies within [-64,63].

Another example: if I have this row vector and this column vector, and the result is in 8-bit fixed point:

A = [1, -2, 3, -4, 5, -6, 7, -8]
B = [1, 2, 3, 4, 5, 6, 7, 8]
A * B = (1*1) - (2*2) + (3*3) - (4*4) + (5*5) - (6*6) + (7*7) - (8*8) = -36

Here the sum accumulator never exceeds the maximum range for 8 bits at any point.
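The two examples can be checked by recording the accumulator after every multiply-add step. The sketch below is only an illustration: `partial_sums` and `overflows` are hypothetical helper names, and the default limits `lo=-64`, `hi=63` simply reuse the range quoted in the text.

```python
def partial_sums(a, b):
    """Return the accumulator value after each multiply-add step."""
    acc = 0
    history = []
    for x, y in zip(a, b):
        acc += x * y
        history.append(acc)
    return history


def overflows(a, b, lo=-64, hi=63):
    """True if any intermediate accumulator value leaves [lo, hi]."""
    return any(s < lo or s > hi for s in partial_sums(a, b))


B = [1, 2, 3, 4, 5, 6, 7, 8]
A_first = [1, 2, 3, 4, 5, 6, -7, -8]
A_second = [1, -2, 3, -4, 5, -6, 7, -8]

print(partial_sums(A_first, B), overflows(A_first, B))
print(partial_sums(A_second, B), overflows(A_second, B))
```

In the first case the running sum peaks well above the final value, while in the second case the alternating signs keep every intermediate value small, which is the ordering effect described above.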

To sum up: I'm looking for a way to analyze the weight matrices to avoid overflow in the sum accumulator. This is how I do the matrix multiplication (just an example, assuming matrices A and B have been scaled to the 1.15 format):

A1 --> 1.15 bits
B1 --> 1.15 bits
A2 --> 1.15 bits
B2 --> 1.15 bits
mult_1 = (A1 * B1) >> 15; // Right shift by 15 bits to realign the operands
mult_2 = (A2 * B2) >> 15; // Right shift by 15 bits to realign the operands
sum_acc = mult_1 + mult_2;  // Sum accumulator
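The multiply, shift, and accumulate sequence above can be sketched as runnable code. This is an illustration under assumptions: `to_q15`, `from_q15`, and `mac_q15` are hypothetical names, and Python integers are unbounded, so the sketch shows only how the binary point is realigned and does not model accumulator overflow.

```python
FRAC_BITS = 15  # fractional bits of the 1.15 format


def to_q15(x):
    """Convert a real value to its 1.15 integer representation."""
    return int(round(x * (1 << FRAC_BITS)))


def from_q15(v):
    """Convert a 1.15 integer representation back to a real value."""
    return v / (1 << FRAC_BITS)


def mac_q15(a_vals, b_vals):
    """Multiply-accumulate two sequences of 1.15 values.

    Each raw product has 30 fractional bits, so it is shifted right by 15
    to return to 15 fractional bits before being added to the accumulator.
    """
    acc = 0
    for a, b in zip(a_vals, b_vals):
        acc += (a * b) >> FRAC_BITS
    return acc


a = [to_q15(0.5), to_q15(0.25)]
b = [to_q15(0.5), to_q15(0.5)]
print(from_q15(mac_q15(a, b)))  # 0.5*0.5 + 0.25*0.5
```

Shifting each product before accumulating, as in the original, discards the low 15 bits of every term; accumulating the full products and shifting once at the end would keep more precision at the cost of a wider accumulator.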
Diego Ruiz
  • Hi @spektre And what happens when the matrix dimensions are large: from [hundreds x hundreds] to [thousands x thousands]? Maybe I don't understand you when you say: "you will end up with (2*(n-1)).15". Thanks. – Diego Ruiz Nov 28 '20 at 11:26
  • Ok, if the log function is applied to the maximum integer bits, it's possible to implement it in an FPGA. I will also look for papers or similar to find methods that infer the magnitude of a matrix multiplication. Thanks @Spektre!! – Diego Ruiz Nov 28 '20 at 11:37
  • Ok, I finished editing my answer... It should be correct now, as all the examples match. – Spektre Nov 28 '20 at 12:10

1 Answer

Let's consider an n=100 dimensional dot product (which is part of any matrix multiplication or convolution) in %4.13 fixed-point format as an example.

  1. Integer bits

    The max value in %4.13 is slightly below 2^4, so let's consider it to be: 15.999999

    Now an n-dimensional dot product has n multiplications and n-1 additions.

    15.999999*15.999999 + 15.999999*15.999999 + .... + 15.999999*15.999999
    

    Each multiplication adds up the integer bits of its operands:

    15.999999*15.999999 = 255.999999 -> ceil(log2(255)) = 8 = 2*(4)-> %8.13
    

    Now this value is added 99 times, so it's the same as:

    255.999999*99 = 25343.999999 -> ceil(log2(25343)) = 15 = ceil(8+log2(99)) -> %15.13
    

    So if n is the number of dimensions and i is the number of integer bits, the result needs:

    i' = ceil((i*2)+log2(n-1)) 
    

    integer bits... so:

    %1.? -> 99*( 1.999999^2) =   395.99 -> % 9.?
    %2.? -> 99*( 3.999999^2) =  1583.99 -> %11.?
    %3.? -> 99*( 7.999999^2) =  6335.99 -> %13.?
    %4.? -> 99*(15.999999^2) = 25343.99 -> %15.?
    
    i(1) = ceil((1*2)+log2(99)) = ceil(2+6.629) = 9
    i(2) = ceil((2*2)+log2(99)) = ceil(4+6.629) = 11
    i(3) = ceil((3*2)+log2(99)) = ceil(6+6.629) = 13
    i(4) = ceil((4*2)+log2(99)) = ceil(8+6.629) = 15
    
  2. Fractional bits

    OK, let's see what happens with multiplication:

    0.1b^2 = 0.01b        -> %?.1 -> %?.2
    0.01b^2 = 0.0001b     -> %?.2 -> %?.4
    0.001b^2 = 0.000001b  -> %?.3 -> %?.6
    

    so f' = 2*f, where f is the number of fractional bits. Addition does not change the bit width:

    0.1b*2 = 1.0b         -> %?.1 -> %?.1
    0.01b*2 = 0.1b        -> %?.2 -> %?.2
    0.001b*2 = 0.01b      -> %?.3 -> %?.3
    

    as the result will not be smaller than the operands. So when applying the fractional part to the dot product we will have:

    i' = ceil((i*2)+log2(n-1)) 
    f' = 2*f 
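The two rules above can be collected into one small function. This is a sketch under assumptions: `dot_product_format` is a hypothetical name, and letting the two operands have different formats (adding their integer bits and their fractional bits) is a generalization of the equal-format formulas, based on the statement that multiplication sums the bits of its operands.

```python
import math


def dot_product_format(i_a, f_a, i_b, f_b, n):
    """Return (integer_bits, fractional_bits) for an n-term dot product.

    With i_a == i_b == i and f_a == f_b == f this reduces to
    i' = ceil(2*i + log2(n-1)) and f' = 2*f.
    Assumption: n >= 2, since log2(n-1) is undefined for n = 1.
    """
    if n < 2:
        raise ValueError("n must be at least 2")
    integer_bits = math.ceil(i_a + i_b + math.log2(n - 1))
    fractional_bits = f_a + f_b
    return integer_bits, fractional_bits


# Reproduce the integer-bit results for n = 100 and i = 1..4.
for i in range(1, 5):
    print(i, dot_product_format(i, 13, i, 13, 100))
```

One caveat: a dot product sums n products, so a stricter bound would use log2(n) in place of log2(n-1). The two agree for n = 100 but differ for small n, such as n = 2, where the sum of two maximal products needs one extra integer bit.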
    
Spektre
  • The last question: suppose that I have these matrix dimensions: A_dims = [4096, 512], B_dims = [512, 4096] and these fixed-point formats: format_A = [4, 12], format_B = [1, 15]. Because the maximum number of operations involved in the matrix multiplication to produce one result is Columns(A) * Rows(B), the values for n and i will be: n = A_dims[1], i = format_A[0] * format_B[0], and the number of resulting integer bits will be: integer_bits = ceil(((i+1)*2)+log2(n-1)). The number of bits in the fractional part will always be the maximum of the two operands' fractional parts, right? – Diego Ruiz Nov 28 '20 at 12:38
  • @DiegoRuiz the fractional part is usually truncated... However, if you do not want to lose precision, then addition will not change it, but multiplication will again sum the fractional bits of the operands... Do you want a similar example like I did for the integer part? – Spektre Nov 28 '20 at 12:40
  • Another example is not necessary, I understand what you mean. You have helped me a lot so far. Thank you! – Diego Ruiz Nov 28 '20 at 12:45
  • @DiegoRuiz heh, meanwhile I did it anyway :) Also, `ceil(((i+1)*2)+log2(n-1))` in your comment is not correct; it should be `ceil((i*2)+log2(n-1))` – Spektre Nov 28 '20 at 12:52