I'll try to keep it as brief as possible.

I have this map: Map<Neuron,Float> connections. It contains a Neuron object as the key and the weight of the connection as the value.

The Neuron class has a method getOutput() that returns the output value of the neuron.

What I want to do is go over every neuron in the map, calculate neuron.getOutput() * connections.get(neuron), and add all of those products together into one variable.

Is there an elegant way to do this with Java streams? Maybe with reduce? I tried, but couldn't get it to work properly.

inputConnections.keySet().stream().reduce(
            0f,
            (accumulatedFloat, inputNeuron) -> accumulatedFloat + inputConnections.get(inputNeuron),
            Float::sum);

I guess the 0f results in everything getting multiplied by 0.

This code seems to work, but I'd like a more elegant solution.

AtomicReference<Float> tmp = new AtomicReference<>(0f);
inputConnections.keySet().forEach(inputNeuron -> {
    // new Float(...) is deprecated; autoboxing the float result does the same job
    tmp.updateAndGet(v -> (float) (v + inputNeuron.getOutput() * inputConnections.get(inputNeuron)));
});
Wewius

3 Answers

You can also achieve the same with mapToDouble and sum:

inputConnections.entrySet().stream().mapToDouble(entry -> entry.getKey().getOutput() * entry.getValue()).sum()
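If you want to try this one-liner in isolation, here is a small self-contained sketch. The Neuron class below is a hypothetical stand-in (the question doesn't show the real one) and assumes getOutput() returns a float; the class and method names are made up for the demo.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class MapToDoubleSum {

    // Hypothetical stand-in for the question's Neuron class; assumes getOutput() returns float.
    static final class Neuron {
        private final float output;

        Neuron(float output) {
            this.output = output;
        }

        float getOutput() {
            return output;
        }
    }

    // Sums output * weight over all connections; each float product is widened to double.
    static double weightedSum(Map<Neuron, Float> connections) {
        return connections.entrySet().stream()
                .mapToDouble(entry -> entry.getKey().getOutput() * entry.getValue())
                .sum();
    }

    public static void main(String[] args) {
        Map<Neuron, Float> connections = new LinkedHashMap<>();
        connections.put(new Neuron(1.0f), 4.0f);   // 1 * 4 = 4
        connections.put(new Neuron(2.0f), 3.0f);   // 2 * 3 = 6
        connections.put(new Neuron(0.5f), 16.0f);  // 0.5 * 16 = 8
        System.out.println(weightedSum(connections)); // prints 18.0
    }
}
```

Note that mapToDouble turns the stream into a DoubleStream, so sum() returns a primitive double rather than a Float.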
Pavel

Your approach using reduce is almost correct. As in your second code snippet, you have to multiply the neuron's output by the value from the map (inputConnections.get(..)). Streaming over entrySet() gives you the neuron and its weight together:

inputConnections.entrySet()
        .stream()
        .reduce(0f,
                (result, entry) -> result + entry.getKey().getOutput() * entry.getValue(),
                Float::sum);
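A runnable sketch comparing the question's first attempt with the corrected reduce. The identity 0f is not the problem: 0 is the neutral element for addition, so it doesn't zero anything out. The first attempt simply never calls getOutput(), so it returns the sum of the weights. Neuron here is a hypothetical stand-in that assumes getOutput() returns float; if the real method returns double, the accumulator's expression becomes a double and would need a cast back to float.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class ReduceSum {

    // Hypothetical stand-in for the question's Neuron class; assumes getOutput() returns float.
    static final class Neuron {
        private final float output;

        Neuron(float output) {
            this.output = output;
        }

        float getOutput() {
            return output;
        }
    }

    // The question's first attempt: it only adds up the weights and never calls getOutput().
    static float weightsOnly(Map<Neuron, Float> connections) {
        return connections.keySet().stream().reduce(
                0f,
                (acc, neuron) -> acc + connections.get(neuron),
                Float::sum);
    }

    // Corrected version: the accumulator multiplies each neuron's output by its weight.
    static float weightedSum(Map<Neuron, Float> connections) {
        return connections.entrySet().stream().reduce(
                0f, // identity: neutral element for addition
                (acc, entry) -> acc + entry.getKey().getOutput() * entry.getValue(),
                Float::sum); // combiner: merges partial results when the stream runs in parallel
    }

    public static void main(String[] args) {
        Map<Neuron, Float> connections = new LinkedHashMap<>();
        connections.put(new Neuron(1.0f), 4.0f);
        connections.put(new Neuron(2.0f), 3.0f);
        connections.put(new Neuron(0.5f), 16.0f);
        System.out.println(weightsOnly(connections)); // 4 + 3 + 16 -> prints 23.0
        System.out.println(weightedSum(connections)); // 4 + 6 + 8 -> prints 18.0
    }
}
```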
Thiyagu

You can also use parallel streams (but be careful: that only makes sense with huge datasets). If you need other statistics in addition to the sum, the summarizingDouble collector is helpful:

DoubleSummaryStatistics sumStat = connections.entrySet().parallelStream()
                .map(entry -> Double.valueOf(entry.getKey().getOutput() * entry.getValue()))
                .collect(Collectors.summarizingDouble(Double::doubleValue));
System.out.println(sumStat);

Example output: DoubleSummaryStatistics{count=3, sum=18.000000, min=4.000000, average=6.000000, max=8.000000}

Here is a shorter version (thank you Rob for your comment):

sumStat = connections.entrySet().parallelStream()
        .mapToDouble(entry -> entry.getKey().getOutput() * entry.getValue())
        .summaryStatistics();
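A runnable sketch of the shorter summaryStatistics() version, using three made-up connections whose products are 4, 6 and 8, matching the example output above. Neuron is again a hypothetical stand-in with a float getOutput(). The values are printed through the getters because DoubleSummaryStatistics.toString() formats numbers using the default locale.

```java
import java.util.DoubleSummaryStatistics;
import java.util.LinkedHashMap;
import java.util.Map;

public class SummaryStatsDemo {

    // Hypothetical stand-in for the question's Neuron class; assumes getOutput() returns float.
    static final class Neuron {
        private final float output;

        Neuron(float output) {
            this.output = output;
        }

        float getOutput() {
            return output;
        }
    }

    // Collects count, sum, min, average and max of output * weight in a single pass.
    static DoubleSummaryStatistics stats(Map<Neuron, Float> connections) {
        return connections.entrySet().parallelStream()
                .mapToDouble(entry -> entry.getKey().getOutput() * entry.getValue())
                .summaryStatistics();
    }

    public static void main(String[] args) {
        Map<Neuron, Float> connections = new LinkedHashMap<>();
        connections.put(new Neuron(1.0f), 4.0f);   // product 4
        connections.put(new Neuron(2.0f), 3.0f);   // product 6
        connections.put(new Neuron(0.5f), 16.0f);  // product 8
        DoubleSummaryStatistics s = stats(connections);
        // prints: count=3, sum=18.0, min=4.0, average=6.0, max=8.0
        System.out.println("count=" + s.getCount() + ", sum=" + s.getSum()
                + ", min=" + s.getMin() + ", average=" + s.getAverage() + ", max=" + s.getMax());
    }
}
```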
Roland J.
  • If you `mapToDouble`, you can replace the `collect` step with `summaryStatistics()`. – Rob Spoor Jan 06 '23 at 07:32
  • This looks interesting. I'm not quite sure if I could use this in my application though. My neural networks will be very small. But I'll have a lot of individual small networks that are going to be processed. – Wewius Jan 06 '23 at 16:54
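Regarding many small networks: with tiny maps, a parallel stream inside each individual sum is likely to cost more in overhead than it saves. One option is to keep each sum sequential and parallelize over the networks instead. A minimal sketch, assuming each network's connections are available as a Map<Neuron, Float> collected in a List; the list layout, the Neuron stand-in and the method names are hypothetical, and whether this is actually faster for your workload should be measured.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ManyNetworksDemo {

    // Hypothetical stand-in for the question's Neuron class; assumes getOutput() returns float.
    static final class Neuron {
        private final float output;

        Neuron(float output) {
            this.output = output;
        }

        float getOutput() {
            return output;
        }
    }

    // Sequential weighted sum for one small connection map (no parallel overhead).
    static double weightedSum(Map<Neuron, Float> connections) {
        return connections.entrySet().stream()
                .mapToDouble(e -> e.getKey().getOutput() * e.getValue())
                .sum();
    }

    // Parallelizes across the many small maps instead of inside each one.
    // toArray() keeps the encounter order, so sums[i] belongs to networks.get(i).
    static double[] weightedSums(List<Map<Neuron, Float>> networks) {
        return networks.parallelStream()
                .mapToDouble(ManyNetworksDemo::weightedSum)
                .toArray();
    }

    public static void main(String[] args) {
        List<Map<Neuron, Float>> networks = new ArrayList<>();
        for (int i = 1; i <= 1000; i++) {
            Map<Neuron, Float> connections = new HashMap<>();
            connections.put(new Neuron(i), 1.0f);     // product i
            connections.put(new Neuron(1.0f), 2.0f);  // product 2
            networks.add(connections);
        }
        double[] sums = weightedSums(networks);
        System.out.println(sums[0] + " " + sums[999]); // prints 3.0 1002.0
    }
}
```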