I need to perform the matrix multiplications between the different layers of a neural network. That is, W0, W1, W2, ... Wn are the weights of the neural network and the input is data. The resulting matrices are:
Out1 = data * W0
Out2 = Out1 * W1
Out3 = Out2 * W2
.
.
.
OutN = Out(N-1) * Wn
I know the absolute maximum value in each weight matrix, and I also know that the input data values range from 0 to 1 (the inputs are normalized). The matrix multiplication is done in 16-bit fixed point. The weights are scaled to the optimal fixed-point format. For example, if the absolute maximum value in W0 is 2.5, I know that the minimum number of bits in the integer part is 2, and the fractional part will have 14 bits. Because the input data is in the range [0,1], I also know that its integer and fractional bits are 1.15.
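To make the format choice concrete, here is a minimal sketch of how the integer-bit count could be derived from the absolute maximum. The function names are hypothetical, and the sketch assumes the integer-bit count covers the magnitude only, which matches the figures above (2.5 gives 2 bits, 1.0 gives 1 bit); a signed two's-complement format would need one more bit for the sign.

```c
/* Hypothetical helper: number of integer bits needed to hold the magnitude
 * of max_abs. Assumption: the sign bit is NOT counted here. */
int magnitude_int_bits(double max_abs)
{
    int bits = 0;
    /* Grow until 2^bits is strictly greater than max_abs (capped at 16). */
    while (bits < 16 && (double)(1u << bits) <= max_abs)
        bits++;
    return bits;
}

/* Fractional bits left over in a 16-bit word for a given integer-bit count. */
int frac_bits_16(int int_bits)
{
    return 16 - int_bits;
}
```

With this, `magnitude_int_bits(2.5)` gives 2 integer bits (2.14) and `magnitude_int_bits(1.0)` gives 1 (1.15).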
My question is: how can I find the minimum number of bits in the integer part of the resulting matrix to avoid overflow? Is there any way to study and infer the maximum value in a matrix multiplication? I know about the determinant and the norm of a matrix, but I think the problem lies in runs of consecutive negative or positive values in the matrix rows and columns. For example, if I have this row vector and this column vector, and the result is in 8-bit fixed point:
A = [1, 2, 3, 4, 5, 6, -7, -8]
B = [1, 2, 3, 4, 5, 6, 7, 8]
A * B = (1*1) + (2*2) + (3*3) + (4*4) + (5*5) + (6*6) + (-7*7) + (-8*8) = 91 - 49 - 64 = -22
The running sum in the accumulator reaches 91 after the sixth term, so overflow occurs, although the final result is contained in [-64,63].
Another example: if I have this row vector and this column vector, and the result is in 8-bit fixed point:
A = [1, -2, 3, -4, 5, -6, 7, -8]
B = [1, 2, 3, 4, 5, 6, 7, 8]
A * B = (1*1) - (2*2) + (3*3) - (4*4) + (5*5) - (6*6) + (7*7) - (8*8) = -36
The sum accumulator never exceeds the maximum range for 8 bits at any moment.
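The difference between the two examples can be checked with a small sketch that records the lowest and highest value the running sum takes. The helper name `dot_with_peaks` is hypothetical, and plain `int`/`long` values stand in for the fixed-point words.

```c
#include <stddef.h>

/* Hypothetical helper: computes the dot product of a and b and records the
 * smallest and largest value the running sum takes along the way. */
long dot_with_peaks(const int *a, const int *b, size_t n,
                    long *min_sum, long *max_sum)
{
    long acc = 0;
    *min_sum = 0;
    *max_sum = 0;
    for (size_t i = 0; i < n; i++) {
        acc += (long)a[i] * b[i];          /* accumulate one product */
        if (acc < *min_sum) *min_sum = acc; /* track the lowest point */
        if (acc > *max_sum) *max_sum = acc; /* track the highest point */
    }
    return acc;
}
```

For the first pair of vectors the running sum peaks at 91 before ending at -22; for the second pair it stays between -36 and 28, so it is the ordering of the signs, not the final value, that decides how wide the accumulator has to be.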
To sum up: I'm looking for a way to analyze the weight matrices to avoid overflow in the sum accumulator. This is how I do the matrix multiplication (just an example, assuming matrices A and B have been scaled to 1.15 format):
A1 --> 1.15 bits
B1 --> 1.15 bits
A2 --> 1.15 bits
B2 --> 1.15 bits
mult_1 = (A1 * B1) >> 15; // Right shift by 15 bits to align the operands
mult_2 = (A2 * B2) >> 15; // Right shift by 15 bits to align the operands
sum_acc = mult_1 + mult_2; // Sum accumulator
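As a runnable version of the multiply-accumulate scheme above, here is a minimal C sketch. The 32-bit accumulator and the names `q15_mac` and `fits_in_int16` are assumptions added for illustration, so that an overflow of the 16-bit range can be detected after the sum instead of silently wrapping.

```c
#include <stddef.h>
#include <stdint.h>

/* Sketch of the scheme above: each 1.15 x 1.15 product is shifted right by
 * 15 bits and then added to the accumulator. Assumption: the accumulator is
 * 32 bits wide. Note that >> on a negative value is implementation-defined
 * in C; common compilers perform an arithmetic shift. */
int32_t q15_mac(const int16_t *a, const int16_t *b, size_t n)
{
    int32_t sum_acc = 0;
    for (size_t i = 0; i < n; i++) {
        int32_t mult = ((int32_t)a[i] * (int32_t)b[i]) >> 15;
        sum_acc += mult;
    }
    return sum_acc;
}

/* Returns 1 if v can be stored in a 16-bit signed word, 0 otherwise. */
int fits_in_int16(int32_t v)
{
    return v >= INT16_MIN && v <= INT16_MAX;
}
```

Checking `fits_in_int16` on the result only covers the final value; with a 16-bit accumulator the intermediate sums would have to be checked too.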