It looks like TensorFlow for Python and the currently released Jython version aren't compatible, so I'm writing my AI model (a GAN) in Java. I'm following a Python guide that builds the GAN with Keras and TensorFlow, and I've already recreated my Python network-definition code in Java using DeepLearning4j. The problem is that I can't figure out how to set up the training function in DeepLearning4j.
Here's the Python training code I'm working from:
generator_optimizer = tf.keras.optimizers.Adam(generator_lr) # learning rate for generator
discriminator_optimizer = tf.keras.optimizers.Adam(discriminator_lr) # learning rate for discriminator
seed = tf.random.normal([num_examples_to_generate, noise_dim])
# ignore this; I already have dataset code set up
train_dataset = tf.data.Dataset.from_tensor_slices(train_images_scaled).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)  # generator model is already set up under the name "generator"

        real_output = discriminator(images, training=True)  # discriminator model is already set up as "discriminator"
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)  # predefined function that computes the generator's loss (cross-entropy) from the discriminator's output on the fake images
        disc_loss = discriminator_loss(real_output, fake_output)  # predefined function that computes the discriminator's loss as a decimal value

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()  # from the time module
        for structure_batch in dataset:
            train_step(structure_batch)
        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
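For reference, here's how far I've gotten mapping the setup lines at the top over to DL4J. As far as I can tell there's no standalone optimizer object in DL4J; Adam is set as the updater inside each network's configuration, and Nd4j.randn seems to cover tf.random.normal. The variable names and values here (generatorLr, noiseDim, etc.) are just my stand-ins, so correct me if I'm off:

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;

double generatorLr = 1e-4;        // placeholder value, stand-in for generator_lr
double discriminatorLr = 1e-4;    // placeholder value, stand-in for discriminator_lr
int numExamplesToGenerate = 16;   // stand-in for num_examples_to_generate
int noiseDim = 100;               // stand-in for noise_dim

// Instead of tf.keras.optimizers.Adam(lr), the updater goes into each
// network's configuration when it is built (same idea for the discriminator
// with discriminatorLr):
NeuralNetConfiguration.Builder genConf = new NeuralNetConfiguration.Builder()
        .updater(new Adam(generatorLr));

// Equivalent of seed = tf.random.normal([num_examples_to_generate, noise_dim])
INDArray seed = Nd4j.randn(numExamplesToGenerate, noiseDim);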
As for the GAN's purpose: it generates structures in Minecraft, with the different block types one-hot encoded, but I don't think that detail is necessary here.
I just want the above Python code "translated" into Java using the DeepLearning4j library. (I also have the TensorFlow library for Java, but as far as I can tell it isn't directly compatible with DL4J.)
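To show where I'm stuck, here's my rough sketch of what I think the training loop might look like in DL4J, since there's no GradientTape equivalent that I can find. Big caveats: I'm assuming generator and discriminator are MultiLayerNetwork objects, plus a third stacked gan network whose first genLayerCount layers mirror the generator and whose remaining layers mirror the discriminator with learning rate 0 (so gan.fit() only updates the generator part). The parameter-copying loops are my own guess at how to keep the three networks in sync, not something I've confirmed against the DL4J docs:

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

public class GanTrainer {
    private final MultiLayerNetwork generator;      // already configured
    private final MultiLayerNetwork discriminator;  // already configured
    private final MultiLayerNetwork gan;            // assumed: generator layers + frozen (lr 0) discriminator layers
    private final int noiseDim;
    private final int genLayerCount;                // number of layers belonging to the generator inside `gan`

    public GanTrainer(MultiLayerNetwork generator, MultiLayerNetwork discriminator,
                      MultiLayerNetwork gan, int noiseDim, int genLayerCount) {
        this.generator = generator;
        this.discriminator = discriminator;
        this.gan = gan;
        this.noiseDim = noiseDim;
        this.genLayerCount = genLayerCount;
    }

    public void train(DataSetIterator dataset, int epochs) {
        for (int epoch = 0; epoch < epochs; epoch++) {
            long start = System.currentTimeMillis();
            dataset.reset();
            while (dataset.hasNext()) {
                INDArray real = dataset.next().getFeatures();
                int batchSize = (int) real.size(0);

                // Discriminator step: real batch labeled 1, generated batch labeled 0
                INDArray fake = generator.output(Nd4j.randn(batchSize, noiseDim));
                discriminator.fit(Nd4j.vstack(real, fake),
                        Nd4j.vstack(Nd4j.ones(batchSize, 1), Nd4j.zeros(batchSize, 1)));

                // Generator step: copy the fresh discriminator weights into the
                // gan's frozen tail, then fit the gan with all-ones labels so only
                // the generator layers update, learning to fool the discriminator
                for (int i = 0; i < discriminator.getLayers().length; i++) {
                    gan.getLayer(genLayerCount + i).setParams(discriminator.getLayer(i).params());
                }
                gan.fit(Nd4j.randn(batchSize, noiseDim), Nd4j.ones(batchSize, 1));

                // Pull the updated generator weights back into the standalone generator
                for (int i = 0; i < genLayerCount; i++) {
                    generator.getLayer(i).setParams(gan.getLayer(i).params());
                }
            }
            System.out.println("Time for epoch " + (epoch + 1) + " is "
                    + (System.currentTimeMillis() - start) / 1000.0 + " sec");
        }
    }
}

Is this roughly the right pattern for DL4J, or is there a more direct equivalent of the GradientTape-based train_step above?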