
I am trying to get the list of labels so I can compare it to my probability output. However, whenever I call iterator.getLabels(), it returns null instead of the list of labels.

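import java.io.File;
import java.util.ArrayList;
import java.util.List;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

int numLinesToSkip = 0;
char delimiter = ',';
int labelIndex = 0;      // column that holds the class index
int numClasses = 9;
int trainBatchSize = 10000;

RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
recordReader.initialize(new FileSplit(new File("myFile.csv")));

List<DataSet> trainingData = new ArrayList<>();
List<DataSet> testingData = new ArrayList<>();

DataSetIterator iterator = new RecordReaderDataSetIterator.Builder(recordReader, trainBatchSize)
    .classification(labelIndex, numClasses)
    .build();

while (iterator.hasNext()) {
    DataSet allData = iterator.next();
    allData.shuffle();
    // 65% of each batch goes to training, the rest to testing
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
    trainingData.add(testAndTrain.getTrain());
    testingData.add(testAndTrain.getTest());
    System.out.println(iterator.getLabels()); // prints null
}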

rngoode1

1 Answer


Unfortunately, the getLabels method is a bit misleading. It only gives you the string labels if the underlying RecordReader provides them.

CSVRecordReader, however, doesn't keep track of string labels, which is why you get null. There is no simple way to get the labels with your setup.
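With `.classification(labelIndex, numClasses)`, the network only ever sees integer class indices (0 to numClasses - 1) taken from the label column, so one workaround is to keep your own list of label names in index order and map each row of probability output back to a name. The sketch below is plain Java and assumes you already know which name corresponds to each index; `LabelLookup`, `argMax`, `predictedLabel` and the example names are hypothetical, not part of DL4J.

```java
import java.util.Arrays;
import java.util.List;

final class LabelLookup {

    // Returns the index of the largest probability in one output row.
    // On ties, the first (lowest) index wins.
    static int argMax(double[] probabilities) {
        if (probabilities.length == 0) {
            throw new IllegalArgumentException("empty probability row");
        }
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    // Maps the predicted class index to a label name.
    // Assumption: labelNames.get(i) is the name of class index i,
    // i.e. of the value i that appears in the CSV label column.
    static String predictedLabel(double[] probabilities, List<String> labelNames) {
        if (probabilities.length != labelNames.size()) {
            throw new IllegalArgumentException(
                "expected " + labelNames.size() + " probabilities, got " + probabilities.length);
        }
        return labelNames.get(argMax(probabilities));
    }

    public static void main(String[] args) {
        // Hypothetical names; in the question there would be 9 of them, in index order.
        List<String> labelNames = Arrays.asList("cat", "dog", "bird");
        double[] row = {0.1, 0.7, 0.2};
        System.out.println(predictedLabel(row, labelNames)); // prints "dog"
    }
}
```

To use it with a trained network, you would convert each row of the model's output to a `double[]` (for example with ND4J's `output.getRow(i).toDoubleVector()`; check this against your ND4J version) and pass it in together with your label list.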

Side note: I wouldn't recommend loading everything into a single DataSet and using the .splitTestAndTrain method for your test/train split. With any dataset large enough to actually benefit from deep learning, you are likely to run into memory issues with that approach.

Paul Dubs
  • So if I use `CSVRecordReader`, there is no way to know which labels my outputs refer to? – rngoode1 Jul 27 '22 at 17:26
  • The iterator just delegates to the underlying record reader for that, and unfortunately CSVRecordReader doesn't have labels. If you want them, my recommendation is to extend CSVRecordReader and override getLabels() with your own implementation. – Adam Gibson Jul 27 '22 at 21:33