In mxnet 1.4 using the Python API, suppose I do

import mxnet as mx

var = mx.sym.var('a')
print(var)  # <Symbol a>

var = mx.sym.var('b')
print(var)  # <Symbol b>

How can I access Symbol a by name?

I would like to do something like var_a = mx.sym.get_by_name('a').

I have checked this tutorial, the docs, and the source code, but couldn't find anything.

buechel

2 Answers

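Use the get_internals() or get_children() methods, then access individual symbols by index. get_internals() returns every node in the graph, while get_children() returns only the direct inputs of the symbol: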

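import mxnet as mx

a = mx.sym.var('a')
b = mx.sym.var('b')
tmp = a * b  # tmp keeps references to a and b in its graph

graph = tmp.get_internals()  # every node in the graph: a, b and the multiplication
print(graph)
print(graph[0])  # first node, i.e. Symbol a
print(tmp.get_children())  # only the direct inputs of tmp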

Output:

<Symbol group [a, b, _mul22]>
<Symbol a>
<Symbol group [a, b]>

Anirudh
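
Both answers start from a symbol you still hold (tmp here, last_layer below) and search its graph; neither relies on a global lookup by name. A minimal sketch that wraps this into a name-based lookup, using only the calls that appear in these answers (get_internals(), list_outputs(), integer indexing). The helper name get_by_name is hypothetical, and it assumes that variables are listed by list_outputs() under their bare name while operator outputs carry an '_output' suffix, as the 'conv0_output' name in the second answer suggests.

```python
def get_by_name(root, name):
    """Hypothetical helper: return the node called `name` from root's graph.

    `root` is expected to be an mx.sym.Symbol, but only get_internals(),
    list_outputs() and integer indexing are used, so it is duck-typed.
    Assumption: variables appear under their bare name (e.g. 'a'), while
    operator outputs carry an '_output' suffix (e.g. 'conv0_output').
    """
    internals = root.get_internals()
    outputs = internals.list_outputs()
    for candidate in (name, name + '_output'):
        if candidate in outputs:
            # Same list.index() trick as in the second answer;
            # if several outputs share a name, the first one wins.
            return internals[outputs.index(candidate)]
    raise KeyError('no symbol named %r in this graph' % name)
```

With the symbols from the first answer, get_by_name(tmp, 'a') should return the same node as graph[0], i.e. <Symbol a>.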

A somewhat messy way is to first get the list of output names with list_outputs(), look up the index of the name you want in that list, and then use this index to pick the symbol out of the group returned by get_internals().

# last_layer is the output symbol of your network
symbol_output = last_layer.get_internals()
symbol_output_list = symbol_output.list_outputs()
# say the name is 'conv0_output'
conv0_index = symbol_output_list.index('conv0_output')
print(conv0_index)
# 8  (depends on the network)
print(symbol_output[conv0_index])
# <Symbol conv0>

Yiding