I have the following code:

import Foundation
import Metal
import MetalPerformanceShadersGraph

let mtlDevice = MTLCreateSystemDefaultDevice()!
let device = MPSGraphDevice(mtlDevice: mtlDevice)

let inputData: [UInt32] = [0, 45, 0, 0, 45, 81, 0, 54, 0, 0, 54, 81, 1, 45, 0, 1, 45, 81, 1, 54, 0, 1, 54, 81, 7, 63, 567, 7, 63, 648, 7, 72, 567, 7, 72, 648, 8, 63, 567, 8, 63, 648, 8, 72, 567, 8, 72, 648, 1, 9, 81, 1, 9, 162, 1, 18, 81, 1, 18, 162, 2, 9, 81, 2, 9, 162, 2, 18, 81, 2, 18, 162]
let inputTensor = MPSGraphTensorData(device: device, data: Data(bytes: inputData, count: inputData.count * 4),
                                shape: [3, 8, 3], dataType: .uInt32)

let graph = MPSGraph()
let inputPlaceholder = graph.placeholder(shape: [3, 8, 3], dataType: .uInt32, name: nil)
let output = graph.reductionSum(with: inputPlaceholder, axis: 2, name: nil)

let outputTensor = graph.run(feeds: [inputPlaceholder: inputTensor], targetTensors: [output], targetOperations: nil)[output]!
var outputData = [UInt32].init(repeating: 0, count: 3 * 8 * 1)
outputTensor.mpsndarray().readBytes(&outputData, strideBytes: nil)

print("output: \(outputTensor.shape) \(outputData)")

The intention is to reduce inputData, interpreted as a tensor of shape [3, 8, 3], along the third axis (axis 2). That is, I want a tensor of shape [3, 8, 1] in which each element is the sum of the 3 elements that were previously laid out along axis 2.

When I run the above code, I receive the following output:

output: [3, 8, 1] [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

This is surprising to me because it is clearly not the sum I was expecting: the first element, for example, should be 0 + 45 + 0 = 45.

Performing a similar operation in TensorFlow gives me the expected output:

import tensorflow as tf

data = tf.reshape(
    [0, 45, 0, 0, 45, 81, 0, 54, 0, 0, 54, 81, 1, 45, 0, 1, 45, 81, 1, 54, 0, 1, 54, 81, 7, 63, 567, 7, 63, 648, 7, 72, 567, 7, 72, 648, 8, 63, 567, 8, 63, 648, 8, 72, 567, 8, 72, 648, 1, 9, 81, 1, 9, 162, 1, 18, 81, 1, 18, 162, 2, 9, 81, 2, 9, 162, 2, 18, 81, 2, 18, 162],
    shape=[3, 8, 3])

tf.reduce_sum(data, axis=2)

Gives:

[[ 45, 126,  54, 135,  46, 127,  55, 136],
   [637, 718, 646, 727, 638, 719, 647, 728],
   [ 91, 172, 100, 181,  92, 173, 101, 182]]

as expected.
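
To make explicit what I mean by "sum along axis 2", here is a minimal plain-Python sketch of the same reduction that does not depend on TensorFlow. It assumes the flat data is laid out in row-major order, and `reduce_last_axis` is only a helper name I made up for illustration, not part of any library:

```python
# Plain-Python reference for the reduction described above (no TensorFlow).
# Assumption: `flat` is the [3, 8, 3] tensor stored in row-major order.
flat = [0, 45, 0, 0, 45, 81, 0, 54, 0, 0, 54, 81,
        1, 45, 0, 1, 45, 81, 1, 54, 0, 1, 54, 81,
        7, 63, 567, 7, 63, 648, 7, 72, 567, 7, 72, 648,
        8, 63, 567, 8, 63, 648, 8, 72, 567, 8, 72, 648,
        1, 9, 81, 1, 9, 162, 1, 18, 81, 1, 18, 162,
        2, 9, 81, 2, 9, 162, 2, 18, 81, 2, 18, 162]


def reduce_last_axis(values, shape):
    """Sum a row-major flat list of a 3-D shape along its last axis.

    Hypothetical helper for illustration; returns a nested list of
    shape [d0][d1].
    """
    d0, d1, d2 = shape
    assert len(values) == d0 * d1 * d2, "data does not match shape"
    return [
        # Each output element sums d2 consecutive values in the flat buffer.
        [sum(values[(i * d1 + j) * d2:(i * d1 + j + 1) * d2]) for j in range(d1)]
        for i in range(d0)
    ]


result = reduce_last_axis(flat, (3, 8, 3))
for row in result:
    print(row)
```

This prints the same three rows as the TensorFlow result above, so those are the 24 values I would expect to read back from the MPSGraph output buffer.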

How can I obtain the same behaviour from reductionSum in MPSGraph? What explains the behaviour that I am observing from MPSGraph?

konsolas