
I'm writing a fully connected layer using Tensorflow/Keras (TF version 2.1, Python 3.7 on Windows), but I've found that if I reshape my weights tensor before multiplying by it, Tensorflow doesn't seem to be able to calculate the gradient, even if I just reshape it to its own shape. Consider the following layer code:

import tensorflow as tf
import numpy as np

class FCLayer(tf.keras.layers.Layer):
    def __init__(self,output_size,cause_error = False):
        super(FCLayer,self).__init__()
        self.output_size = output_size
        self.cause_error = cause_error

    def build(self,input_shape): 
        self.input_size = input_shape[1]        
        weights = self.add_weight(shape=(self.input_size,
                                         self.output_size),
                                 initializer='random_normal',
                                 trainable=True)

        if self.cause_error:
            self.weights2 = tf.reshape( weights,
                                        shape = (self.input_size,
                                                 self.output_size))
        else:
            self.weights2 = weights

    def call(self, inputs):
        return tf.matmul(inputs, self.weights2)    

If this is used with cause_error = True, then I get the following output when training on MNIST for 4 epochs (specific training code included below):

Train on 60000 samples, validate on 10000 samples
Epoch 1/4
WARNING:tensorflow:Gradients do not exist for variables ['sequential/dummy_layer/Variable:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['sequential/dummy_layer/Variable:0'] when minimizing the loss.
60000/60000 [==============================] - 1s 20us/sample - loss: 2.4131 - accuracy: 0.0722 - val_loss: 2.3963 - val_accuracy: 0.0834
Epoch 2/4
60000/60000 [==============================] - 1s 12us/sample - loss: 2.4122 - accuracy: 0.0722 - val_loss: 2.3953 - val_accuracy: 0.0836
Epoch 3/4
60000/60000 [==============================] - 1s 12us/sample - loss: 2.4112 - accuracy: 0.0724 - val_loss: 2.3944 - val_accuracy: 0.0838
Epoch 4/4
60000/60000 [==============================] - 1s 13us/sample - loss: 2.4102 - accuracy: 0.0725 - val_loss: 2.3933 - val_accuracy: 0.0839

This is just a warning, but it is clear that the model is not really improving, and it obviously needs those gradients.

If I set cause_error=False I instead get the expected output (no warnings, modest improvements):

Train on 60000 samples, validate on 10000 samples
Epoch 1/4
60000/60000 [==============================] - 1s 16us/sample - loss: 2.3671 - accuracy: 0.1527 - val_loss: 2.3445 - val_accuracy: 0.1508
Epoch 2/4
60000/60000 [==============================] - 1s 12us/sample - loss: 2.3293 - accuracy: 0.1596 - val_loss: 2.3072 - val_accuracy: 0.1610
Epoch 3/4
60000/60000 [==============================] - 1s 13us/sample - loss: 2.2939 - accuracy: 0.1683 - val_loss: 2.2722 - val_accuracy: 0.1720
Epoch 4/4
60000/60000 [==============================] - 1s 13us/sample - loss: 2.2609 - accuracy: 0.1784 - val_loss: 2.2397 - val_accuracy: 0.1847

I suspect I need to somehow tell Tensorflow to keep track of the gradients, but am not quite sure how. It seems to do it automatically when I use tf.matmul, and I'm pretty sure this kind of code used to work in TF 1.
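
To see the missing gradient directly, outside of model.fit, I can run a check along these lines (a minimal sketch; it assumes the FCLayer class above, and the variable names are just illustrative):

# Minimal gradient check: build the problematic layer, run one forward pass
# under a GradientTape, and ask for the gradient of the loss w.r.t. its weights.
layer = FCLayer(10, cause_error=True)
x = tf.constant(np.random.rand(8, 784).astype('float32'))

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(layer(x))

# The weight variable gets no gradient, matching the warning from model.fit.
print(tape.gradient(loss, layer.trainable_variables))  # prints [None]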

The specific training code I used was (adapted from the MNIST tutorial):

batch_size = 128
num_classes = 10
epochs = 4

# input image dimensions
img_rows, img_cols = 28, 28

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()


x_train = x_train.reshape(x_train.shape[0], img_rows* img_cols)
x_test = x_test.reshape(x_test.shape[0], img_rows*img_cols)
input_shape = (img_rows * img_cols,)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

model = tf.keras.models.Sequential()

dummy_layer = FCLayer(10, cause_error = True)
model.add( dummy_layer )
model.add( tf.keras.layers.Dense(10, activation='softmax') )

model.compile(loss=tf.keras.losses.categorical_crossentropy,
              optimizer=tf.keras.optimizers.Adadelta(),
              metrics=['accuracy'])

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))
reltar

2 Answers


The problem is related to eager execution in TF 2: any operation such as tf.reshape is run the moment it is encountered, and build is only called a single time for a given model. Now, what is happening is that you are creating a tensor weights2, which is a reshaped version of the tf.Variable weights but is not itself a tf.Variable (ops generally return tensors, not variables). Because this happens eagerly, no "record" of the operation is kept, so weights2 has no connection to weights. Thus, when it is used in the model call, weights cannot be updated. This does not happen in the else case because there, weights2 is just another name referring to the actual tf.Variable weights.
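
To see the difference in isolation, here is a minimal standalone sketch (my own toy example, not from the question): reshaping a tf.Variable outside a GradientTape yields a plain tensor the tape knows nothing about, while doing the same reshape inside the tape (which is what happens when it is in call) keeps the connection:

import tensorflow as tf

v = tf.Variable([[1.0, 2.0], [3.0, 4.0]])

# Reshape done once, outside any tape -- analogous to doing it in build().
# The result is a plain tensor with no recorded link back to v.
v_outside = tf.reshape(v, (2, 2))

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(v_outside * 2.0)
print(tape.gradient(loss, v))  # None: the tape never saw an op involving v

# Reshape done inside the tape -- analogous to doing it in call().
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.reshape(v, (2, 2)) * 2.0)
print(tape.gradient(loss, v))  # a 2x2 tensor of 2.0s: gradients flow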

Two ways of fixing this:

  1. Use assign in build to do the reshape in place (note that I use self.w because self.weights is a reserved name for Keras layers):

    def build(self,input_shape): 
        self.input_size = input_shape[1]        
        self.w = self.add_weight(shape=(self.input_size,
                                              self.output_size),
                                       initializer='random_normal',
                                       trainable=True)
    
        if self.cause_error:
            self.w.assign(tf.reshape(self.w,
                                       shape = (self.input_size,
                                                self.output_size)))
    

This causes no error/warning, but it might not be what you want, because you are modifying the original weights in place, so the untransformed version is lost. I suppose you would rather use a modified version of the weights on each call. In that case:

  2. Do the reshape in the call method instead:

class FCLayer(tf.keras.layers.Layer):
    def __init__(self,output_size,cause_error = False):
        super(FCLayer,self).__init__()
        self.output_size = output_size
        self.cause_error = cause_error

    def build(self,input_shape): 
        self.input_size = input_shape[1]        
        self.w = self.add_weight(shape=(self.input_size,
                                          self.output_size),
                                   initializer='random_normal',
                                   trainable=True)
    def call(self, inputs):
        # The reshape now happens inside call, so it is recorded as part of the
        # forward pass and gradients can flow back to self.w.
        weights2 = tf.reshape(self.w, (self.input_size, self.output_size))
        return tf.matmul(inputs, weights2)

This works because the reshape operation is now part of the model's call graph, i.e. we can trace back that weights2 actually came from weights, and gradients can flow.
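
As a quick sanity check (a small sketch of my own, assuming the FCLayer definition directly above), you can verify that the gradient with respect to the layer's weight is no longer None:

import numpy as np
import tensorflow as tf

layer = FCLayer(10)
x = tf.constant(np.random.rand(32, 784).astype('float32'))

with tf.GradientTape() as tape:
    y = layer(x)              # the first call triggers build()
    loss = tf.reduce_sum(y)

grads = tape.gradient(loss, layer.trainable_variables)
print([g is not None for g in grads])  # [True]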

xdurch0
  • Thanks, that is very helpful, both in fixing my issue and in helping me understand how Tensorflow handles these things. I didn't realize that operations performed in build weren't included in the model call graph (I guess the name should have been a clue). – reltar May 13 '20 at 09:35

The likely cause of that behaviour is the lack of a @tf.function decorator on the build function, i.e.:

@tf.function
def build(self, input_shape):
    self.input_size = input_shape[1]
    weights = self.add_weight(shape=(self.input_size,
                                     self.output_size),
                              initializer='random_normal',
                              trainable=True)

    if self.cause_error:
        self.weights2 = tf.reshape(weights,
                                   shape=(self.input_size,
                                          self.output_size))
    else:
        self.weights2 = weights

Why is it vital? The Python Tensorflow API is just an interface to the actual implementation in C/C++. The moment you provide your custom operation (like tf.reshape) in Python to be executed as part of the graph, you have to instruct the module to compile this part of the code into "native" Tensorflow.

It doesn't matter that your reshape didn't actually reshape anything. You have "interrupted" the default execution path and "injected" Python code. @tf.function should fix it.

Lukasz Tracewski
  • `tf.function` has nothing to do with gradient flow. `tf.reshape` is not a "custom operation" but a built-in tensorflow function. By your logic (if I don't misunderstand your answer), no model should ever work without `tf.function` because they all use many `tf.some_operation` functions under the hood. Actually, adding `tf.function` as you proposed results in an error for me (very deeply nested so I'd rather not post it). – xdurch0 May 13 '20 at 08:30
  • Actually, now that I think about it, I believe you are on the right track -- the issue is most likely related to eager execution and `tf.function` would deactivate it, which in principle would fix it. Sorry if I came across as rude -- but your solution still doesn't quite work, unfortunately. – xdurch0 May 13 '20 at 08:32
  • @xdurch0 Not at all, I appreciate the critical attitude :). Thanks for sharing the solution +1. – Lukasz Tracewski May 26 '20 at 09:46