
I was following the Image Classification and Text Generation tutorials. I've implemented transfer learning with fine-tuning on my own dataset, but I don't know how to access the labels when I make predictions. I transformed my data into the expected format (tf.data.Dataset), so I am using the Keras model for predictions. For example, to predict just one label I call: keras_model.predict(federated_train_data[0])

federated_train_data consists of the following elements:

(TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(None,), dtype=tf.int64, name=None))

The first tensor holds a batch of images (each 32x32 with 3 channels) and the second holds the corresponding integer-encoded labels.

My goal is to show the true and predicted labels for each image, for example: (figure: predicted classes)

TL;DR: Is there a way to access just the labels of a tf.data.Dataset?

ana

1 Answer


If federated_train_data is a tf.data.Dataset whose .element_spec property returns:

(TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(None,), dtype=tf.int64, name=None))

then you can iterate over it:

# Get the first batch
first_batch = next(iter(federated_train_data)) 

# Examine all batches
for batch in federated_train_data:
  print(batch)
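To experiment without the full federated pipeline, a small stand-in dataset with the same element_spec can be built from random NumPy arrays. This is a sketch only: the sizes (8 images, batch size 4, labels in [0, 10)) are arbitrary placeholders, not values from the question.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: 8 random 32x32 RGB images and int64 labels in [0, 10).
images = np.random.rand(8, 32, 32, 3).astype(np.float32)
labels = np.random.randint(0, 10, size=8).astype(np.int64)

# from_tensor_slices + batch yields (features, labels) tuples whose batch
# dimension is unknown (None), like the element_spec in the question.
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)
print(dataset.element_spec)

# as_numpy_iterator() yields each batch as a tuple of NumPy arrays.
for batch_images, batch_labels in dataset.as_numpy_iterator():
    print(batch_images.shape, batch_labels)
```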

From the .element_spec we know each batch is a 2-tuple of (features, labels), so the labels are at index 1:

labels = first_batch[1]

# Or unpack
features, labels = first_batch
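To get a dataset of just the labels, which is what the TL;DR asks for, tf.data.Dataset.map can drop the features from every element. The sketch below runs on a synthetic stand-in for federated_train_data (the data and sizes are assumptions) and then flattens the per-batch labels into one array.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for federated_train_data with the same element_spec.
images = np.random.rand(10, 32, 32, 3).astype(np.float32)
labels = np.random.randint(0, 10, size=10).astype(np.int64)
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

# map() unpacks each (features, labels) tuple into two arguments; keep only the labels.
labels_only = dataset.map(lambda features, batch_labels: batch_labels)

# Concatenate the per-batch label arrays (sizes 4, 4, 2) into one flat array.
all_labels = np.concatenate(list(labels_only.as_numpy_iterator()))
print(all_labels)
```

One caution on this design: if the real dataset was built with .shuffle(), it is reshuffled on each pass by default, so labels collected in one pass may not line up with predictions made in a separate pass. Pulling labels and predictions from the same batch inside a single loop avoids that.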

Combining this with the model predictions:

for batch in federated_train_data:
  features, labels = batch
  predictions = keras_model.predict(features)
  # Now we have all three pieces: features, labels, and predictions.
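A minimal sketch of that loop which collects true and predicted labels side by side, assuming the model outputs one score per class so that argmax gives the predicted class index. The toy untrained model, NUM_CLASSES, and the random data are hypothetical placeholders for your fine-tuned keras_model and real dataset.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 10  # hypothetical; use your dataset's number of classes

# Hypothetical stand-in data with the same element_spec as federated_train_data.
images = np.random.rand(10, 32, 32, 3).astype(np.float32)
labels = np.random.randint(0, NUM_CLASSES, size=10).astype(np.int64)
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

# Untrained toy model standing in for the fine-tuned keras_model.
keras_model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])

true_labels, predicted_labels = [], []
for features, batch_labels in dataset:
    # predict() returns one row of class scores per image; argmax picks the top class.
    scores = keras_model.predict(features, verbose=0)
    true_labels.append(batch_labels.numpy())
    predicted_labels.append(np.argmax(scores, axis=-1))

# Flatten the per-batch results so index i refers to the same image in both arrays.
true_labels = np.concatenate(true_labels)
predicted_labels = np.concatenate(predicted_labels)

for true_label, pred_label in zip(true_labels, predicted_labels):
    print(f"true: {true_label}  predicted: {pred_label}")
```

If federated_train_data is actually a Python list of per-client datasets (the indexing federated_train_data[0] in the question suggests this), the same loop can be run over each client dataset in turn. To display class names rather than indices, index a list of your class names with these integers.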
Zachary Garrett