It looks like TensorFlow for Python and the currently released Jython version aren't compatible, so I'm writing my AI model (a GAN) in Java. I'm following a Python guide that builds the GAN with Keras and TensorFlow. I've figured out how to rebuild the neural network definitions in Java using DeepLearning4j. The problem is that I can't work out how to write the training function in DeepLearning4j.

Here's the Python training code I'm following:

import time

import tensorflow as tf

generator_optimizer = tf.keras.optimizers.Adam(generator_lr)  # generator_lr: learning rate for the generator
discriminator_optimizer = tf.keras.optimizers.Adam(discriminator_lr)  # discriminator_lr: learning rate for the discriminator


seed = tf.random.normal([num_examples_to_generate, noise_dim]) 

# ignore this; I already have dataset code set up
train_dataset = tf.data.Dataset.from_tensor_slices(train_images_scaled).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # "generator" and "discriminator" are models I've already set up under those names
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        # predefined function: cross-entropy loss of the generator (a scalar),
        # computed from the discriminator's output on the generated images
        gen_loss = generator_loss(fake_output)
        # predefined function: loss of the discriminator (a scalar)
        disc_loss = discriminator_loss(real_output, fake_output)

    
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()  # from the standard time module

        for structure_batch in dataset:
            train_step(structure_batch)

        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

As for the purpose of the GAN: it generates structures in Minecraft, with the different blocks one-hot encoded, but I don't think that's necessary information for this question.

I just want the above Python code "translated" into Java using the DeepLearning4j library. (I also have the TensorFlow library for Java, but I don't think it's directly compatible with DL4J.)

1 Answer

Use Keras import.

API reference: https://deeplearning4j.konduit.ai/v/en-1.0.0-m1/deeplearning4j/how-to-guides/keras-import/api-reference

For posterity, the method to call depends on whether you have a functional model:

public static ComputationGraph importKerasModelAndWeights(InputStream modelHdf5Stream) throws IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException

or a sequential model:

public static MultiLayerNetwork importKerasSequentialModelAndWeights(InputStream modelHdf5Stream, boolean enforceTrainingConfig) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException

Save the model as an H5 file, then load it into Deeplearning4j using one of the methods above. You can then use the Deeplearning4j dataset iterators to train your model.

For the input data pipeline, I would put that in a separate question if possible. If you want more details, please do ask in the comments.
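
As for the training function itself, the ordering of operations in the Python train_step can be sketched in plain Java. This is a minimal sketch under loud assumptions, not DL4J code: it uses a hypothetical toy generator g(z) = a*z + b and discriminator D(x) = sigmoid(w*x + c), hand-derived gradients in place of tf.GradientTape, and plain gradient descent in place of Adam. The class name ToyGanStep and all of its fields are made up for illustration. What it is meant to show is the sequence: one forward pass, both losses, both gradients, and only then both parameter updates.

```java
import java.util.Random;

// Toy, stdlib-only illustration of the train_step control flow. NOT DeepLearning4j code.
// Generator: g(z) = a*z + b. Discriminator: D(x) = sigmoid(w*x + c). Both are hypothetical.
final class ToyGanStep {
    static final double EPS = 1e-12; // keeps log() finite if the sigmoid saturates

    double a = 1.0, b = 0.0; // generator parameters
    double w = 0.0, c = 0.0; // discriminator parameters
    final double generatorLr;
    final double discriminatorLr;

    ToyGanStep(double generatorLr, double discriminatorLr) {
        this.generatorLr = generatorLr;
        this.discriminatorLr = discriminatorLr;
    }

    static double sigmoid(double s) {
        return 1.0 / (1.0 + Math.exp(-s));
    }

    /** Runs one training step and returns {genLoss, discLoss}, averaged over the batch. */
    double[] trainStep(double[] real, double[] noise) {
        int n = real.length; // assumes real.length == noise.length
        double genLoss = 0, discLoss = 0;
        double gradA = 0, gradB = 0, gradW = 0, gradC = 0;

        for (int i = 0; i < n; i++) {
            // Forward pass, mirroring generated_images / real_output / fake_output.
            double fake = a * noise[i] + b;
            double realOut = sigmoid(w * real[i] + c);
            double fakeOut = sigmoid(w * fake + c);

            // Binary cross-entropy losses, as in the usual GAN formulation.
            genLoss += -Math.log(fakeOut + EPS);
            discLoss += -Math.log(realOut + EPS) - Math.log(1.0 - fakeOut + EPS);

            // Generator gradients (stand-in for gen_tape.gradient), ignoring EPS.
            double dGenDFake = (fakeOut - 1.0) * w;
            gradA += dGenDFake * noise[i];
            gradB += dGenDFake;

            // Discriminator gradients (stand-in for disc_tape.gradient), ignoring EPS.
            gradW += (realOut - 1.0) * real[i] + fakeOut * fake;
            gradC += (realOut - 1.0) + fakeOut;
        }

        // Both gradients come from the same forward pass; updates are applied afterwards.
        // Plain gradient descent here, where the Python code uses Adam.
        a -= generatorLr * gradA / n;
        b -= generatorLr * gradB / n;
        w -= discriminatorLr * gradW / n;
        c -= discriminatorLr * gradC / n;

        return new double[] { genLoss / n, discLoss / n };
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        ToyGanStep gan = new ToyGanStep(0.05, 0.05);
        int batchSize = 32;

        for (int epoch = 0; epoch < 3; epoch++) {
            long start = System.nanoTime();
            double[] losses = new double[2];
            for (int batch = 0; batch < 100; batch++) {
                double[] real = new double[batchSize];
                double[] noise = new double[batchSize];
                for (int i = 0; i < batchSize; i++) {
                    real[i] = 4.0 + rng.nextGaussian(); // made-up "real" data
                    noise[i] = rng.nextGaussian();
                }
                losses = gan.trainStep(real, noise);
            }
            double seconds = (System.nanoTime() - start) / 1e9;
            System.out.printf("Time for epoch %d is %.3f sec (genLoss=%.4f, discLoss=%.4f)%n",
                    epoch + 1, seconds, losses[0], losses[1]);
        }
    }
}
```

A DL4J version would presumably keep the same ordering, with the networks' own gradient and updater machinery replacing the hand-written arithmetic.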

Adam Gibson
  • Hi, actually when I tried making a model in Python and importing it into Java, DL4J raised issues with my use of the "Conv3DTranspose" layer. It refused to recognize that layer, so I couldn't import the model, which is why I decided to construct the entire model in Java using DL4J's own library. I can do that; I just don't know how to write this one training function – Rishab Borah Jul 19 '21 at 22:58
  • First of all, next time that happens, please raise an issue. Feedback is crucial to making sure we support the features users need (this is true for every open source project, no matter which). The workaround would be to add a custom layer. You can define a custom layer using SameDiff (our equivalent of TensorFlow); start with just doing that. You can find more about that here: https://deeplearning4j.konduit.ai/keras-import/custom-layers – Adam Gibson Jul 19 '21 at 23:39
  • There is an existing DL4J layer functioning the same as Conv3DTranspose called "Deconvolution3D." Is there any way to register the Deconvolution3D layer to be mapped to the name "Conv3DTranspose" when importing? – Rishab Borah Jul 21 '21 at 21:35