
I want to add conditional control flow to my symbol, but it seems that a Python if-else is evaluated at symbol construction time. I want it to be evaluated at symbol run time instead.

a = mx.symbol.Variable(name='a')
b = mx.symbol.Variable(name='b')

if a>b:
    c = a-b
else:
    c = a+b

TensorFlow provides the tf.cond() operator for this. Is there a counterpart in MXNet?

Indhu Bharathi
Zehao Shi

1 Answer


You can use mx.symbol.where.

You can compute both a_minus_b and a_plus_b, then return an array in which each element is taken from either a_minus_b or a_plus_b, depending on the corresponding element of a separate condition array. Here is an example:

import mxnet as mx

a = mx.symbol.Variable(name='a')
b = mx.symbol.Variable(name='b')

# Build both branches; where() selects between them per element
a_minus_b = a - b
a_plus_b  = a + b

# Elementwise comparison, equivalent to a.__gt__(b)
gt = a > b

result = mx.symbol.where(condition=gt, x=a_minus_b, y=a_plus_b)

# Bind concrete input arrays and run the graph
ex = result.bind(ctx=mx.cpu(), args={'a': mx.nd.array([1, 2, 3]), 'b': mx.nd.array([3, 2, 1])})
r = ex.forward()

print(r[0].asnumpy())  # result should be [1+3, 2+2, 3-1]
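
To make the per-element selection explicit, here is a plain-Python sketch of the same computation. It is only a model of the semantics described above, not MXNet's implementation: the list-based where() helper is a hypothetical stand-in written for illustration, and it assumes all three inputs are flat lists of the same length.

```python
def where(condition, x, y):
    """Elementwise select: x[i] where condition[i] is nonzero, else y[i].

    Hypothetical list-based stand-in for illustration only.
    """
    if not (len(condition) == len(x) == len(y)):
        raise ValueError("condition, x and y must have the same length")
    return [xi if ci != 0 else yi for ci, xi, yi in zip(condition, x, y)]


a = [1.0, 2.0, 3.0]
b = [3.0, 2.0, 1.0]

# Both branches are computed for every element before any selection happens.
a_minus_b = [ai - bi for ai, bi in zip(a, b)]
a_plus_b = [ai + bi for ai, bi in zip(a, b)]

# The comparison produces a 1.0 / 0.0 mask, one entry per element.
gt = [1.0 if ai > bi else 0.0 for ai, bi in zip(a, b)]

result = where(gt, a_minus_b, a_plus_b)
print(result)  # [4.0, 4.0, 2.0]
```

Note that this is a select rather than a branch: both a_minus_b and a_plus_b are evaluated in full, and the condition only decides which value is kept at each position. That is fine for cheap elementwise expressions like these, but it is not the same as skipping the untaken branch entirely.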
Indhu Bharathi