I have the following code:
import Metal
import MetalPerformanceShadersGraph

let mtlDevice = MTLCreateSystemDefaultDevice()!
let device = MPSGraphDevice(mtlDevice: mtlDevice)
let inputData: [UInt32] = [0, 45, 0, 0, 45, 81, 0, 54, 0, 0, 54, 81, 1, 45, 0, 1, 45, 81, 1, 54, 0, 1, 54, 81, 7, 63, 567, 7, 63, 648, 7, 72, 567, 7, 72, 648, 8, 63, 567, 8, 63, 648, 8, 72, 567, 8, 72, 648, 1, 9, 81, 1, 9, 162, 1, 18, 81, 1, 18, 162, 2, 9, 81, 2, 9, 162, 2, 18, 81, 2, 18, 162]
let inputTensor = MPSGraphTensorData(device: device, data: Data(bytes: inputData, count: inputData.count * 4),
shape: [3, 8, 3], dataType: .uInt32)
let graph = MPSGraph()
let inputPlaceholder = graph.placeholder(shape: [3, 8, 3], dataType: .uInt32, name: nil)
let output = graph.reductionSum(with: inputPlaceholder, axis: 2, name: nil)
let outputTensor = graph.run(feeds: [inputPlaceholder: inputTensor], targetTensors: [output], targetOperations: nil)[output]!
var outputData = [UInt32].init(repeating: 0, count: 3 * 8 * 1)
outputTensor.mpsndarray().readBytes(&outputData, strideBytes: nil)
print("output: \(outputTensor.shape) \(outputData)")
The intention is to reduce inputData, interpreted as a tensor of shape [3, 8, 3], along the third axis. That is, to obtain a tensor of shape [3, 8, 1] where each element is the sum of the 3 elements previously present in axis 2.
When I run the above code, I receive the following output:
output: [3, 8, 1] [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
This is surprising to me: every value is 0 or 1, and none of them matches the sums I was expecting.
Performing a similar operation in TensorFlow gives me the expected output:
import tensorflow as tf
data = tf.reshape(
[0, 45, 0, 0, 45, 81, 0, 54, 0, 0, 54, 81, 1, 45, 0, 1, 45, 81, 1, 54, 0, 1, 54, 81, 7, 63, 567, 7, 63, 648, 7, 72, 567, 7, 72, 648, 8, 63, 567, 8, 63, 648, 8, 72, 567, 8, 72, 648, 1, 9, 81, 1, 9, 162, 1, 18, 81, 1, 18, 162, 2, 9, 81, 2, 9, 162, 2, 18, 81, 2, 18, 162],
shape=[3, 8, 3])
tf.reduce_sum(data, axis=2)
Gives:
[[ 45, 126, 54, 135, 46, 127, 55, 136],
[637, 718, 646, 727, 638, 719, 647, 728],
[ 91, 172, 100, 181, 92, 173, 101, 182]]
as expected.
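As a sanity check that does not depend on TensorFlow, the same sums can be reproduced in plain Python. This is only a sketch of the reduction I intend, assuming the flat data is laid out row-major; the helper name `reduce_last_axis` is my own and not part of any library.

```python
# Flat input, identical to inputData in the Swift snippet above.
flat = [0, 45, 0, 0, 45, 81, 0, 54, 0, 0, 54, 81,
        1, 45, 0, 1, 45, 81, 1, 54, 0, 1, 54, 81,
        7, 63, 567, 7, 63, 648, 7, 72, 567, 7, 72, 648,
        8, 63, 567, 8, 63, 648, 8, 72, 567, 8, 72, 648,
        1, 9, 81, 1, 9, 162, 1, 18, 81, 1, 18, 162,
        2, 9, 81, 2, 9, 162, 2, 18, 81, 2, 18, 162]


def reduce_last_axis(values, shape):
    """Sum a row-major flat list over its last axis (hypothetical helper)."""
    outer, middle, inner = shape
    assert len(values) == outer * middle * inner
    result = []
    for i in range(outer):
        row = []
        for j in range(middle):
            # Each output element covers `inner` consecutive input values.
            start = (i * middle + j) * inner
            row.append(sum(values[start:start + inner]))
        result.append(row)
    return result


expected = reduce_last_axis(flat, (3, 8, 3))
for row in expected:
    print(row)
```

The three printed rows agree with the TensorFlow result, so these are the values I would expect reductionSum to produce (with an extra trailing axis of size 1).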
How can I obtain the same behaviour from reductionSum in MPSGraph? What explains the behaviour that I am observing from MPSGraph?