
In MXNet 1.4, using the Python API, suppose I run:

```python
import mxnet as mx

tmp = mx.sym.var('a')
print(tmp)  # <Symbol a>

tmp = tmp + tmp
print(tmp)  # <Symbol _plus0>

tmp = mx.sym.var('b')
tmp = tmp + tmp
print(tmp)  # <Symbol _plus1>
```

I assume `<Symbol _plus0>` is still present somewhere in the graph. How can I list all symbols that currently live in my graph?

I would like to do something like mx.sym.list_all_symbols().

I have checked this tutorial, the docs, and the source code, but couldn't find anything.

buechel

1 Answer


Use MXNet's `viz` module to plot the network. You can also save the symbol graph to a JSON file and read through it to see all the symbols:

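```python
import mxnet as mx

a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
c = a + b               # element-wise addition (f does not use c, so c is not in f's graph)
d = a * b               # element-wise multiplication
e = mx.sym.dot(a, b)    # matrix multiplication
f = mx.sym.reshape(d + e, shape=(1, 4))  # reshape

# save the graph of f to a JSON file
f.save('fgraph-symbol.json')

# plot the graph (needs the graphviz package); in a notebook the returned
# graphviz.Digraph is shown inline, elsewhere call .view() on it
mx.viz.plot_network(symbol=f)
```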
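If you go the JSON route, the saved file can also be read programmatically instead of by eye. The sketch below is an illustration, not an MXNet API: it assumes the NNVM-style layout that `Symbol.save()` writes in MXNet 1.x (a top-level `nodes` list whose entries carry `op`, `name` and `inputs`, with variables marked by op `"null"`). `list_graph_nodes` and `SAMPLE` are hypothetical names, and `SAMPLE` is a hand-written, abbreviated stand-in for a real file, so the op names and counters your MXNet version writes may differ.

```python
import json


def list_graph_nodes(graph_json):
    """Return (name, op, input_names) for every node in an MXNet symbol JSON.

    Assumes the MXNet 1.x layout: a top-level "nodes" list whose entries have
    "op", "name" and "inputs", where each input is [node_index, output_index,
    version]. Variables (placeholders such as 'a' and 'b') have op == "null".
    """
    graph = json.loads(graph_json)
    nodes = graph["nodes"]
    result = []
    for node in nodes:
        # resolve each input entry's node index to that node's name
        input_names = [nodes[entry[0]]["name"] for entry in node["inputs"]]
        result.append((node["name"], node["op"], input_names))
    return result


# Hand-written, abbreviated example of the file layout (illustrative only;
# a real file also contains keys such as "node_row_ptr" and "attrs").
SAMPLE = """
{
  "nodes": [
    {"op": "null", "name": "a", "inputs": []},
    {"op": "null", "name": "b", "inputs": []},
    {"op": "elemwise_mul", "name": "_mul0", "inputs": [[0, 0, 0], [1, 0, 0]]},
    {"op": "dot", "name": "dot0", "inputs": [[0, 0, 0], [1, 0, 0]]},
    {"op": "elemwise_add", "name": "_plus0", "inputs": [[2, 0, 0], [3, 0, 0]]},
    {"op": "Reshape", "name": "reshape0", "inputs": [[4, 0, 0]]}
  ],
  "arg_nodes": [0, 1],
  "heads": [[5, 0, 0]]
}
"""

for name, op, inputs in list_graph_nodes(SAMPLE):
    print(f"{name:10s} {op:14s} {inputs}")

# With the real file written by f.save(...):
#   with open('fgraph-symbol.json') as fh:
#       nodes = list_graph_nodes(fh.read())
```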

[Figure: the graph that gets plotted]

Edit:

If you want a list of the symbols themselves and not just their names, use the `get_internals()` method. Run

```python
print(f.get_internals())
```

You will get a symbol group containing all the internal outputs of the graph, and you can access the individual symbols by their index in the group. The printed group looks like this:

```
<Symbol group [a, b, _mul11, dot24, _plus51, reshape24]>
```
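Because `get_internals()` returns a grouped symbol, you can loop over it to apply a function to every node, which is what the follow-up comment asks for. A minimal sketch, assuming MXNet 1.x behaviour where `list_outputs()` returns one name per output and integer indexing returns the matching sub-symbol; `map_internals` is a hypothetical helper name. It only sees nodes reachable from `f`, so `c` from the example above is not included.

```python
def map_internals(sym, fn):
    """Apply fn to every internal output of an MXNet symbol.

    Assumes sym is an mx.sym.Symbol from MXNet 1.x: get_internals() returns a
    grouped Symbol, list_outputs() gives one name per output (e.g. 'a',
    '_mul11_output'), and integer indexing returns the matching sub-symbol.
    Returns a dict mapping output name -> fn(sub_symbol).
    """
    internals = sym.get_internals()
    names = internals.list_outputs()
    return {name: fn(internals[i]) for i, name in enumerate(names)}


# Example usage (requires mxnet; `f` is the symbol built in the answer):
#   results = map_internals(f, lambda s: s.list_arguments())
#   for name, args in results.items():
#       print(name, args)
```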

You can also walk through the graph using `f.get_children()`, which returns the direct inputs of a node.
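A recursive walk with `get_children()` could look like the sketch below. It assumes MXNet 1.x behaviour, where `get_children()` returns a grouped symbol of the node's direct inputs, or `None` for a variable; `walk_graph` is a hypothetical helper name, and de-duplicating by `.name` assumes node names are unique within the graph.

```python
def walk_graph(sym):
    """Depth-first walk from an MXNet symbol towards its inputs.

    Assumes MXNet 1.x semantics: get_children() returns a grouped Symbol of
    the node's direct inputs, or None for a variable. Children are reached
    by integer index. Nodes are de-duplicated by .name, so a node shared by
    several consumers is visited once. Returns names in post-order, i.e.
    every node appears after all of its inputs.
    """
    seen = set()
    order = []

    def visit(node):
        if node.name in seen:
            return
        seen.add(node.name)
        children = node.get_children()
        if children is not None:
            for i in range(len(children.list_outputs())):
                visit(children[i])
        order.append(node.name)

    visit(sym)
    return order


# Example usage (requires mxnet; `f` is the symbol built in the answer):
#   for name in walk_graph(f):
#       print(name)
```

For the `f` from the answer, the walk visits inputs before the nodes that consume them, so the variables `a` and `b` come first and the reshape node comes last, even though `a` and `b` feed two different operators.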

Anirudh
  • Thanks for your answer, but that does not really give me programmatic access to them. Suppose I want to apply some function to each of the symbols in my graph. – buechel Oct 17 '19 at 09:49
  • If you want the symbols, your best bet is the `get_internals` method. Run `print(f.get_internals())` and you will get a symbol group containing all the symbol outputs. You can also walk through the graph using `f.get_children()`. – Anirudh Oct 18 '19 at 19:08