
I wanted to rewrite the create_model function using the Keras Functional API. When I run the translated version on a TPU, though, I get an error about using a Placeholder in the create_model function. In the original example the author didn't put an explicit Placeholder into create_model. I am using the Keras Input function because I need to instantiate a Keras tensor to get started, and that is obviously a placeholder. Is there a way to get rid of the Placeholder inside my create_model function?

Here is a snippet of my code:

import tensorflow as tf

def create_model(data_format):

  # Choose the image layout expected by the convolution layers.
  if data_format == 'channels_first':
    input_shape = [1, 28, 28]
  else:
    assert data_format == 'channels_last'
    input_shape = [28, 28, 1]

  l = tf.keras.layers
  m = tf.keras.models

  # Functional API: start from a symbolic Keras Input and chain layers onto it.

  inputs = l.Input(shape=(28 * 28,))  # flat 784-pixel images

  # Reshape the flat vector into an image in the chosen data_format.
  visible = l.Reshape(target_shape=input_shape)(inputs)
  # ... (remaining layers not shown)
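For reference, here is a minimal sketch of how a complete Functional API version of create_model might look. The layer stack after the Reshape (two conv/pool blocks, a 1024-unit dense layer, dropout, 10 output logits) is an assumption loosely modeled on the usual TensorFlow MNIST CNN, not the original author's code, and it has not been checked against the TPU infeed error; it only shows that the graph builds and runs from a flat 784-pixel input.

```python
import tensorflow as tf

NUM_CLASSES = 10  # assumption: the 10 MNIST digit classes


def create_model(data_format):
    """Functional-API sketch of an MNIST CNN (hypothetical layer stack)."""
    if data_format == "channels_first":
        input_shape = (1, 28, 28)
    else:
        assert data_format == "channels_last"
        input_shape = (28, 28, 1)

    l = tf.keras.layers
    # Symbolic input: a batch of flat 784-pixel images.
    inputs = l.Input(shape=(28 * 28,))
    x = l.Reshape(target_shape=input_shape)(inputs)
    x = l.Conv2D(32, 5, padding="same", activation="relu",
                 data_format=data_format)(x)
    x = l.MaxPooling2D(2, 2, padding="same", data_format=data_format)(x)
    x = l.Conv2D(64, 5, padding="same", activation="relu",
                 data_format=data_format)(x)
    x = l.MaxPooling2D(2, 2, padding="same", data_format=data_format)(x)
    x = l.Flatten()(x)
    x = l.Dense(1024, activation="relu")(x)
    x = l.Dropout(0.4)(x)
    outputs = l.Dense(NUM_CLASSES)(x)  # unnormalized logits
    return tf.keras.Model(inputs=inputs, outputs=outputs)


model = create_model("channels_last")
logits = model(tf.zeros([2, 28 * 28]))  # dummy batch of two flat images
print(tuple(logits.shape))
```

Returning logits rather than softmax probabilities keeps the model usable with a loss that applies softmax internally, such as sparse softmax cross-entropy with `from_logits=True`.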

When I build it with the provided MNIST TPU code, I get the error:

placeholder outside of the infeed

But I also can't run it without the Placeholder, the way the Sequential code does. Or is there a way to do this?

craft

1 Answer


Why do you need to instantiate the Keras tensor with a placeholder? If you just need a model to use in Keras, you can use the following code snippet:

import tensorflow as tf

NUM_CLASSES = 10

def mnist_model(input_shape):
  """Creates an MNIST model."""
  model = tf.keras.models.Sequential()
  # Passing input_shape to the first layer lets Keras build the model
  # without an explicit tf.keras.layers.Input.
  model.add(
      tf.keras.layers.Conv2D(
          32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
  model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
  model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
  model.add(tf.keras.layers.Dropout(0.25))
  model.add(tf.keras.layers.Flatten())
  model.add(tf.keras.layers.Dense(128, activation='relu'))
  model.add(tf.keras.layers.Dropout(0.5))
  # Softmax over the 10 digit classes.
  model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))
  return model


def main():
  ...
  input_shape = (28, 28, 1)  # channels_last: height, width, channels
  model = mnist_model(input_shape)
  ...
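Another way to avoid declaring any Input layer up front is model subclassing: the layers are created in `__init__` and their weights are only built when the model is first called on real data. The sketch below mirrors the layer stack of `mnist_model` above; the class name `MnistModel` is hypothetical, and whether this sidesteps the TPU infeed error depends on the TensorFlow version and TPU setup, which has not been verified here.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 10


class MnistModel(tf.keras.Model):
    """Subclassed sketch: no Input layer; weights are built on the first call."""

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, (3, 3), activation="relu")
        self.conv2 = tf.keras.layers.Conv2D(64, (3, 3), activation="relu")
        self.pool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))
        self.drop1 = tf.keras.layers.Dropout(0.25)
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(128, activation="relu")
        self.drop2 = tf.keras.layers.Dropout(0.5)
        self.classifier = tf.keras.layers.Dense(num_classes,
                                                activation="softmax")

    def call(self, x, training=False):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.drop1(x, training=training)  # dropout only when training
        x = self.flatten(x)
        x = self.dense(x)
        x = self.drop2(x, training=training)
        return self.classifier(x)


model = MnistModel()
# The first call fixes the input shape (batch, 28, 28, 1) and builds the weights.
probs = model(tf.zeros([4, 28, 28, 1]))
print(tuple(probs.shape), np.allclose(probs.numpy().sum(axis=1), 1.0))
```

The trade-off is that a subclassed model has no static graph until it is called, so `model.summary()` and similar introspection only work after the first call.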
aman2930