
I am trying to implement a dot product layer, but it doesn't seem to work correctly. Here are the two implementations I have so far.

Implementation 1

// dot product of betas and factors
graphBuilder.addVertex("layer_product_",
        new ElementWiseVertex(ElementWiseVertex.Op.Product),
        "layer_output_Betas_", "layer_output_F_");
graphBuilder.addLayer("layer_dot_",
        new DenseLayer.Builder().nOut(1)
                .activation(new SumActivation()).build(),
        "layer_product_");

SumActivation.java

public class SumActivation extends BaseActivationFunction {
    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        Nd4j.getExecutioner().execAndReturn(new Sum(in, 1));
        return in;
    }

    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        assertShape(in, epsilon);
        Nd4j.getExecutioner().execAndReturn(new Sum(in, 1));
        return new Pair<>(in, null);
    }

    @Override
    public String toString() {
        return "reduce-sum";
    }
}

Implementation 2

// dot product of betas and factors
graphBuilder.addVertex("layer_product_",
        new ElementWiseVertex(ElementWiseVertex.Op.Product),
        "layer_output_Betas_", "layer_output_F_");
graphBuilder.addLayer("layer_dot_",
        new GlobalPoolingLayer.Builder().poolingDimensions(1)
                .poolingType(PoolingType.SUM).build(),
        "layer_product_");
graphBuilder.inputPreProcessor("layer_dot_", new FeedForwardToRnnPreProcessor());

The input will be an n x 6 matrix, and the output of the dot product should be an n x 1 vector. Is there something wrong with what I am doing here?
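
To make the expected result concrete, here is a minimal plain-Java sketch of the computation I am after: multiply the two n x 6 inputs element-wise, then sum each row so that every example in the batch gets its own scalar. It deliberately uses no DL4J calls, and the names RowWiseDot and rowWiseDot are hypothetical, chosen only for this illustration.

```java
import java.util.Arrays;

// Plain-Java illustration only; RowWiseDot and rowWiseDot are hypothetical
// names and are not part of DL4J or ND4J.
final class RowWiseDot {

    // Returns an n x 1 matrix whose i-th entry is the dot product of
    // row i of a with row i of b (element-wise product, then row sum).
    static double[][] rowWiseDot(double[][] a, double[][] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("row counts differ");
        }
        double[][] out = new double[a.length][1];
        for (int i = 0; i < a.length; i++) {
            if (a[i].length != b[i].length) {
                throw new IllegalArgumentException("column counts differ in row " + i);
            }
            double sum = 0.0;
            for (int j = 0; j < a[i].length; j++) {
                sum += a[i][j] * b[i][j];
            }
            out[i][0] = sum;
        }
        return out;
    }

    public static void main(String[] args) {
        // Two examples (n = 2), six features each; values are made up.
        double[][] betas = {{1, 2, 3, 4, 5, 6}, {0, 1, 0, 1, 0, 1}};
        double[][] factors = {{1, 1, 1, 1, 1, 1}, {2, 2, 2, 2, 2, 2}};
        System.out.println(Arrays.deepToString(rowWiseDot(betas, factors)));
    }
}
```

Each row is reduced independently, so a batch of n rows should give n different values in general, not one value repeated and not a single 1x1 result.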

harshvardhan.agr
  • Do you get any specific error message, or is the math not coming out right? – Paul Dubs May 11 '20 at 07:16
  • @PaulDubs For Implementation 1 the math is not right: all rows in the output have the same value. Implementation 2 errors out when I try to use a batch size greater than 1. It always gives me a 1x1 output instead of an nx1 vector. – harshvardhan.agr May 11 '20 at 12:03
