
I am currently trying to set up a feed-forward NN using the MXNet R API. I would like to implement a custom loss function that uses fixed weights that I define myself, namely c(7,8,9). In TensorFlow, variables can be declared non-trainable, which ensures that they are not modified during training. This is exactly what I need for my weights! Unfortunately, I have not found any way to do this in the R API. Here is my code:

data <- mx.symbol.Variable('data')
label <- mx.symbol.Variable('label')
weights <- mx.symbol.Variable(name='weights')

... [some network layers]...

fc2 <- mx.symbol.FullyConnected(data=tanh3, num_hidden=length(predictable_errors))
softmax <- mx.symbol.SoftmaxActivation(data=fc2, name="softmax_activation")
weighted_l2 <- mx.symbol.sum(mx.symbol.square(softmax - label)*weights)
loss <- mx.symbol.MakeLoss(data=weighted_l2)

model <- mx.model.FeedForward.create(loss, X=train.x, y=train.y, ctx=mx.cpu(),
                                     arg.params=list(weights=mx.nd.array(array(c(7,8,9), dim=c(3,1)), mx.cpu())),
                                     num.round=1, learning.rate=0.05, momentum=0.9,
                                     array.batch.size=1, eval.metric=mx.metric.accuracy,
                                     epoch.end.callback=mx.callback.log.train.metric(1))
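The gradient of the weighted L2 loss in the code above is a quick way to see why the weights have to be frozen. Writing s_ik for the softmax output of sample i and class k, y_ik for the label, and w_k for the weights (notation introduced here for illustration):

```latex
L(w) = \sum_{i} \sum_{k=1}^{3} w_k \, (s_{ik} - y_{ik})^2,
\qquad
\frac{\partial L}{\partial w_k} = \sum_{i} (s_{ik} - y_{ik})^2 \;\ge\; 0 .
```

Because every partial derivative is non-negative, a plain gradient step w_k <- w_k - eta * dL/dw_k can only shrink each weight, and nothing bounds the weights from below. If they were left trainable, the optimizer could push them negative and make the loss arbitrarily small without improving the fit at all.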

I know that the Python API offers the function set_lr_mult, with which I could set the learning rate for "weights" to zero, but in R this does not seem to be an option. Do you have any suggestions?
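A learning-rate multiplier of zero freezes a parameter because, in momentum SGD, a parameter whose effective learning rate is zero keeps a zero velocity and therefore never moves, whatever its gradient is. The plain-Python sketch below illustrates this idea. It is not MXNet code: the function sgd_momentum_step, the parameter name fc2_weight and the gradient values are hypothetical. It assumes the textbook update v <- momentum * v - lr * g, w <- w + v, and omits weight decay and gradient rescaling.

```python
# Hypothetical sketch of per-parameter learning-rate multipliers (the idea
# behind Optimizer.set_lr_mult); this is NOT MXNet code.

def sgd_momentum_step(params, grads, moms, lr, momentum, lr_mult):
    """Apply one momentum-SGD step in place.

    params, grads, moms: dicts mapping parameter name -> list of floats.
    lr_mult: dict mapping parameter name -> multiplier (default 1.0).
    """
    for name, values in params.items():
        eff_lr = lr * lr_mult.get(name, 1.0)  # effective learning rate
        for j, g in enumerate(grads[name]):
            # velocity update: v <- momentum * v - eff_lr * g
            moms[name][j] = momentum * moms[name][j] - eff_lr * g
            values[j] += moms[name][j]


params = {"fc2_weight": [0.5, -0.2, 0.1], "weights": [7.0, 8.0, 9.0]}
moms = {name: [0.0] * len(v) for name, v in params.items()}
lr_mult = {"weights": 0.0}  # learning rate 0 for the loss weights

for _ in range(100):
    # Arbitrary nonzero gradients, only to show that "weights" never moves.
    grads = {"fc2_weight": [0.1, 0.1, 0.1], "weights": [0.3, 0.2, 0.1]}
    sgd_momentum_step(params, grads, moms, lr=0.05, momentum=0.9, lr_mult=lr_mult)

print(params["weights"])  # [7.0, 8.0, 9.0]
```

Even with momentum=0.9, the velocity of "weights" stays at 0.0 because it never receives a nonzero contribution, while fc2_weight is updated normally.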

Many thanks in advance!

ge.org

1 Answer


You can do this by using Module instead of FeedForward. With Module, you can pass the names of the fixed parameters that you don't want to train:

import mxnet as mx
# fixed_param_names: parameters that are excluded from training
model = mx.mod.Module(loss, data_names=['data'], label_names=['label'],
                      context=mx.cpu(), fixed_param_names=['weights'])

You can read more here
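As I understand it, fixed_param_names makes Module leave the listed parameters out of the update entirely. The plain-Python sketch below mimics that behavior for the weighted L2 loss from the question: it trains hypothetical logits (standing in for the fc2 output of a single sample) with plain gradient descent while "weights" stays at (7, 8, 9). All names here (softmax, weighted_l2, grads, logits) are illustrative and are not part of MXNet.

```python
import math


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


def weighted_l2(z, y, w):
    """Loss from the question: sum_k w_k * (softmax(z)_k - y_k)^2."""
    s = softmax(z)
    return sum(wk * (sk - yk) ** 2 for wk, sk, yk in zip(w, s, y))


def grads(z, y, w):
    """Analytic gradients of weighted_l2 with respect to z and w."""
    s = softmax(z)
    g_s = [2 * wk * (sk - yk) for wk, sk, yk in zip(w, s, y)]  # dL/ds_k
    dot = sum(gk * sk for gk, sk in zip(g_s, s))
    g_z = [sj * (gj - dot) for sj, gj in zip(s, g_s)]  # chain rule through softmax
    g_w = [(sk - yk) ** 2 for sk, yk in zip(s, y)]  # dL/dw_k, never applied below
    return {"logits": g_z, "weights": g_w}


params = {"logits": [0.0, 0.0, 0.0], "weights": [7.0, 8.0, 9.0]}
fixed_param_names = {"weights"}  # parameters excluded from training
y = [0.0, 0.0, 1.0]  # one-hot label for a single sample
lr = 0.05

start_loss = weighted_l2(params["logits"], y, params["weights"])
for _ in range(200):
    g = grads(params["logits"], y, params["weights"])
    for name in params:
        if name in fixed_param_names:
            continue  # frozen: skip the update entirely
        params[name] = [p - lr * gp for p, gp in zip(params[name], g[name])]
end_loss = weighted_l2(params["logits"], y, params["weights"])

print(params["weights"])  # [7.0, 8.0, 9.0]
print(end_loss < start_loss)  # True
```

Skipping the frozen names in the update loop is the simplest way to express the idea; the gradient for "weights" is still computed here only so that it is visible, whereas a real framework can avoid computing it at all.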

Karishma Malkan
  • Thank you for your reply! I have looked into the Module API, and it seems to be available only for Python. While I am prepared to switch to Python in the absence of alternatives, I would first like to explore all possibilities of the R API. – ge.org Feb 06 '17 at 07:55