1

In Python, can I access a variable created inside the `main` function from another function? Should I use a global variable? Any help appreciated!

def main(dataset, n_h, n_y, batch_size, dev_split, n_epochs):
    """Build the input projection and the recurrent layer for the model.

    Returns nothing. Of the two objects created here, only ``RNN`` is
    reachable after this function returns, because it is declared
    ``global``; ``input_to_state`` is a plain local and is discarded.
    Among the parameters, only ``n_h`` is used in the code shown.
    NOTE(review): ``seq_u`` is neither a parameter nor a local, so it must
    exist as a module-level name when this runs.
    """
    input_to_state = Linear(name='input_to_state',
                            input_dim=seq_u.shape[-1],
                            output_dim=n_h)
    # Asker's question: is this `global` correct? Yes -- it makes the
    # assignment below bind the module-level name `RNN`, so other functions
    # can read `RNN` once main() has run.
    global RNN # correct?
    RNN = SimpleRecurrent(activation=Tanh(),
                          dim=n_h, name="RNN")


def predict(dev_X):
    # Asker's question: which of the next two lines correctly reaches
    # `input_to_state`? Neither, given how main() is written:
    # - `main.input_to_state` looks up an attribute on the function object.
    #   A function's local variables are not exposed as attributes, so this
    #   raises AttributeError (nothing shown sets such an attribute).
    # - The bare name `input_to_state` is looked up as a global, but main()
    #   keeps it local (no `global` declaration for it). It therefore raises
    #   NameError unless other code defines a module-level `input_to_state`.
    dev_transform = main.input_to_state.apply(dev_X) #?  call  "input_to_state", which one is correct?
    dev_transform = input_to_state.apply(dev_X) #?
    # Same question for `RNN`: `main.RNN` fails for the same reason. The bare
    # name `RNN` works once main() has run, since main() declares it global.
    dev_h = main.RNN.apply(dev_transform) #? call "RNN", which one is correct?
    dev_h = RNN.apply(dev_transform) #?
    # NOTE(review): dev_h is never returned, so predict() returns None.

if __name__ == "__main__":    
    # Asker's question: can predict() be defined here? A `def` inside this
    # `if` block does create a module-level function, but it needs an
    # indented body. As written, the statement after it sits at the same
    # indentation, so the `def` has no body and Python raises
    # IndentationError.
    def predict(dev_X): #  one more question: can predict function be added here?
    # `....` is a placeholder for loading the data; it is not valid Python.
    dataset =  ....
    # n_h, n_y, batch_size and dev_split are not defined in this block.
    main(dataset, n_h, n_y, batch_size, dev_split, 5000)
    # `dev_X` is not defined anywhere in this block.
    # NOTE(review): theano.function presumably expects symbolic output
    # variable(s), e.g. `predict(dev_X)`, rather than the Python function
    # `predict` itself -- confirm against the Theano documentation.
    get_predictions = theano.function([dev_X], predict) # call predict function
Rick
  • 189
  • 1
  • 2
  • 11

2 Answers

0

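Declare `input_to_state` and `RNN` as `global` inside `main`, so the assignments there rebind module-level names instead of creating locals. Also define them at module level first (e.g. as `None`), so the names exist even before `main` has run. Like this: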

# Module-level placeholders so both names exist even before main() runs.
input_to_state = None
RNN = None


def main(dataset, n_h, n_y, batch_size, dev_split, n_epochs):
    """Build the model layers and publish them as module-level globals.

    Only ``n_h`` (plus the module-level ``seq_u``) is used here; the other
    parameters are kept so the signature matches the caller's.
    """
    # `global` makes the assignments below rebind the module-level names
    # instead of creating locals that vanish when main() returns.
    global input_to_state
    global RNN
    input_to_state = Linear(name='input_to_state',
                            input_dim=seq_u.shape[-1],
                            output_dim=n_h)
    RNN = SimpleRecurrent(activation=Tanh(),
                          dim=n_h, name="RNN")


def predict(dev_X):
    """Apply the layers built by main() to ``dev_X`` and return the output.

    main() must have been called first; until then both globals are None.
    """
    # The globals are only read here, so no `global` statement is needed.
    dev_transform = input_to_state.apply(dev_X)
    dev_h = RNN.apply(dev_transform)
    # Return the result so callers can actually use it.
    return dev_h


if __name__ == "__main__":
    # dataset, n_h, n_y, batch_size, dev_split and dev_X are placeholders
    # for the caller's real values.
    main(dataset, n_h, n_y, batch_size, dev_split, 5000)
    dev_h = predict(dev_X)

However, I would not recommend this: global variables should be used as little as possible (more detail here).

A better solution is to return `input_to_state` and `RNN` at the end of `main` and pass them to `predict` explicitly, like this:

def main(dataset, n_h, n_y, batch_size, dev_split, n_epochs):
    """Build the model layers and hand them back to the caller.

    Only ``n_h`` (plus the module-level ``seq_u``) is used here; the other
    parameters are kept so the signature matches the caller's.

    Returns:
        A tuple ``(input_to_state, RNN)``; pass both on to ``predict``.
    """
    input_to_state = Linear(name='input_to_state',
                            input_dim=seq_u.shape[-1],
                            output_dim=n_h)
    RNN = SimpleRecurrent(activation=Tanh(),
                          dim=n_h, name="RNN")
    return input_to_state, RNN


def predict(dev_X, input_to_state, RNN):
    """Run ``dev_X`` through ``input_to_state`` and then ``RNN``.

    Both layer arguments only need an ``apply`` method.

    Returns:
        The output of ``RNN.apply`` on the transformed input.
    """
    dev_transform = input_to_state.apply(dev_X)
    dev_h = RNN.apply(dev_transform)
    return dev_h


if __name__ == "__main__":
    # dataset, n_h, n_y, batch_size, dev_split and dev_X are placeholders
    # for the caller's real values.
    input_to_state, RNN = main(dataset, n_h, n_y, batch_size, dev_split, 5000)
    dev_h = predict(dev_X, input_to_state, RNN)
-1

Try with this.

main.py

# Store main()'s result in a module-level variable of the running script.
# Other modules can then read it through sys.modules['__main__'], as sub.py
# does.
# NOTE(review): the value is whatever main() returns; if main() has no
# return statement, this is None. PEP 8 also reserves __double_underscore__
# names for Python itself, so a plain name would be safer.
__dataset__ = main(dataset, n_h, n_y, batch_size, dev_split, 5000)

sub.py

import sys

# Deliberately no `import main`. When main.py is run as a script it is
# registered as the module `__main__`, so `import main` would execute
# main.py a second time as a separate module (re-running main()) and form a
# circular import. The running script's globals are reachable through
# sys.modules instead.
# sys.modules['__main__'] is whichever script was launched, so this only
# works when main.py is that script and has already assigned `__dataset__`
# before importing this module; otherwise it raises AttributeError.
__dataset__ = sys.modules['__main__'].__dataset__


EDIT:
Another method is to keep the shared state in a class attribute (Python's equivalent of a static variable).

mclass.py

class MClass:
    """Namespace whose class attribute ``i`` is shared by every importer."""

    # A class attribute, so every module importing MClass sees one value
    # (Python's equivalent of a "static" variable).
    i = 0


# Executed once, when this module is first imported.
setattr(MClass, "i", 1)

main.py

# Demo of shared class-attribute state: this module and sub both import the
# same MClass object from mclass, so a change made through either one is
# visible in the other.
import sub
from mclass import MClass

# In the main file
# Importing mclass (first via sub) ran `MClass.i = 1`, so this prints 1.
print(MClass.i) # Outputs 1
MClass.i = 3
print(MClass.i) # Outputs 3

# In a subfile
# sub.mPrint() reads the same MClass.i, so it sees the 3 set above.
sub.mPrint() # Outputs 3
sub.set(10)
sub.mPrint() # Outputs 10

# And back in the main
# The value written by sub.set() is visible here too.
print(MClass.i) # Outputs 10

sub.py

from mclass import MClass


def mPrint():
    """Print the current shared value ``MClass.i``."""
    current = MClass.i
    print(current)


def set(n):
    """Store ``n`` as the new shared value ``MClass.i``."""
    # NOTE: this name shadows the builtin `set` inside this module; it is
    # kept because callers use `sub.set(...)`.
    setattr(MClass, "i", n)
Diblo Dk
  • 585
  • 10
  • 26