I am trying to build a simple neural network to represent a logical AND.
As I am new to machine learning and the Deep Java Library (DJL), I followed the beginner tutorial: https://docs.djl.ai/jupyter/tutorial/01_create_your_first_network.html
The tutorial worked fine and I got the expected results.
Then I modified the code to:
- read data from a CSV file
- use the data from the CSV file for training
- classify an input vector of two float values
The code is shown below. Unfortunately, the classification results are not what I expect.
When I use:
float[] one = {1f, 1f};
classify(one);
I get the result:
0: 0.5816633701324463
1: 0.4183366000652313
When I use:
float[] zero = {1f, 0f};
classify(zero);
I get the result:
0: 0.5276625156402588
1: 0.47233742475509644
So there is something obviously wrong, but I do not know where to start:
- data / training
- model type
- network setup
Maybe someone can help me find the mistake I am making.
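To narrow down the "data / training" part, here is a minimal sketch of how I could print the batches the CsvDataset actually produces (it reuses the same builder calls as in the training() method below; I have not verified that this is the recommended way to inspect a DJL dataset):
// Sketch: dump the batches the CsvDataset produces (extra import: ai.djl.training.dataset.Batch)
static void inspectDataset() throws IOException, TranslateException {
    CsvDataset dataset = CsvDataset.builder()
            .optCsvFile(Paths.get("TrainingDataAND.csv"))
            .addFeature(new Feature("in1", true))
            .addFeature(new Feature("in2", true))
            .addLabel(new Feature("result", true))
            .setSampling(2, true)
            .setCsvFormat(CSVFormat.DEFAULT.withHeader())
            .build();
    dataset.prepare();
    try (NDManager manager = NDManager.newBaseManager()) {
        for (Batch batch : dataset.getData(manager)) {
            System.out.println("features: " + batch.getData().singletonOrThrow());
            System.out.println("labels:   " + batch.getLabels().singletonOrThrow());
            batch.close();
        }
    }
}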
Java code:
import ai.djl.*;
import ai.djl.training.*;
import java.io.IOException;
import java.nio.file.*;
import ai.djl.ndarray.types.*;
import ai.djl.training.loss.*;
import ai.djl.training.listener.*;
import ai.djl.training.evaluator.*;
import ai.djl.basicmodelzoo.basic.*;
import java.util.*;
import java.util.stream.*;
import org.apache.commons.csv.CSVFormat;
import ai.djl.ndarray.*;
import ai.djl.modality.*;
import ai.djl.translate.*;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.basicdataset.tabular.utils.Feature;
public class App
{
    public static void main(String[] args)
    {
        boolean train = false;
        if (train) {
            try {
                training();
            } catch (Exception e) {
                System.out.println("[ERROR] Could not train");
                e.printStackTrace();
            }
        } else {
            try {
                // define some input vectors for the neural network
                float[] zero = {1f, 0f};
                float[] zero2 = {0f, 0f};
                float[] one = {1f, 1f};
                classify(zero);
            } catch (Exception e) {
                System.out.println("[ERROR] Could not classify");
                e.printStackTrace();
            }
        }
    }

    /**
     * Classify an input vector with the trained neural network.
     * @throws MalformedModelException
     * @throws IOException
     * @throws TranslateException
     */
    static void classify(float[] input) throws MalformedModelException, IOException, TranslateException {
        // load the previously saved MLP (2 inputs, 2 outputs, one hidden layer with 2 units)
        Path modelDir = Paths.get("build/mlp");
        Model model = Model.newInstance("mlpBlock");
        model.setBlock(new Mlp(2, 2, new int[] {2}));
        model.load(modelDir);

        // translator: float[] -> NDList for the network, NDList -> Classifications for the caller
        Translator<float[], Classifications> translator = new Translator<float[], Classifications>() {

            @Override
            public NDList processInput(TranslatorContext ctx, float[] input) {
                NDArray array = ctx.getNDManager().create(input);
                NDList ndList = new NDList();
                ndList.add(array);
                return ndList;
            }

            @Override
            public Classifications processOutput(TranslatorContext ctx, NDList list) {
                // turn the raw network output into probabilities for the classes "0" and "1"
                NDArray probabilities = list.singletonOrThrow().softmax(0);
                List<String> classNames = IntStream.range(0, 2).mapToObj(String::valueOf).collect(Collectors.toList());
                return new Classifications(classNames, probabilities);
            }

            @Override
            public Batchifier getBatchifier() {
                return Batchifier.STACK;
            }
        };

        var predictor = model.newPredictor(translator);
        var classifications = predictor.predict(input);
        for (int i = 0; i < classifications.getProbabilities().size(); i++) {
            System.out.println(classifications.getClassNames().get(i) + ": " + classifications.getProbabilities().get(i));
        }
    }

    /**
     * Train the neural network on the CSV data and save the model.
     * @throws IOException
     * @throws TranslateException
     */
    static void training() throws IOException, TranslateException {
        // read the AND truth table from the CSV file; batch size 2, shuffled
        Path csvPath = Paths.get("TrainingDataAND.csv");
        CSVFormat csvFormat = CSVFormat.DEFAULT.withHeader();
        CsvDataset dataset = CsvDataset.builder()
                .optCsvFile(csvPath)
                .addFeature(new Feature("in1", true))
                .addFeature(new Feature("in2", true))
                .addLabel(new Feature("result", true))
                .setSampling(2, true)
                .setCsvFormat(csvFormat)
                .build();

        // MLP with 2 inputs, 2 outputs and one hidden layer of 2 units
        Model model = Model.newInstance("mlpBlock");
        model.setBlock(new Mlp(2, 2, new int[] {2}));

        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());

        Trainer trainer = model.newTrainer(config);
        trainer.initialize(new Shape(1, 2));

        int epoch = 2;
        EasyTrain.fit(trainer, epoch, dataset, null);

        // save the trained model to build/mlp
        Path modelDir = Paths.get("build/mlp");
        Files.createDirectories(modelDir);
        model.setProperty("Epoch", String.valueOf(epoch));
        model.save(modelDir, "mlpBlock");
    }
}
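To separate the network setup from my Translator, I could also look at the raw (pre-softmax) output of the loaded model, for example with DJL's NoopTranslator. This is only a sketch: it assumes the extra imports ai.djl.inference.Predictor and ai.djl.translate.NoopTranslator, and that the no-arg NoopTranslator does not batchify, so the input already needs its batch dimension:
// Sketch: print the raw network output for one input, bypassing the custom Translator
static void rawOutput(float[] input) throws MalformedModelException, IOException, TranslateException {
    Model model = Model.newInstance("mlpBlock");
    model.setBlock(new Mlp(2, 2, new int[] {2}));
    model.load(Paths.get("build/mlp"));
    try (NDManager manager = NDManager.newBaseManager();
         Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator())) {
        // shape (1, 2): one sample with two features, because nothing batchifies the input here
        NDArray array = manager.create(input, new Shape(1, input.length));
        NDList output = predictor.predict(new NDList(array));
        System.out.println("raw output: " + output.singletonOrThrow());
    }
}
Calling rawOutput(new float[] {1f, 1f}) after training would then show the two values that my Translator turns into the probabilities above.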
Training data (TrainingDataAND.csv):
in1,in2,result
1,1,1
1,0,0
0,1,0
0,0,0
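For completeness, when testing I would replace the single classify(zero) call in main with something like the following loop, so that all four AND input combinations are checked in one run (just a sketch based on the code above):
// Sketch: classify the full AND truth table with the trained model
float[][] truthTable = { {0f, 0f}, {0f, 1f}, {1f, 0f}, {1f, 1f} };
for (float[] input : truthTable) {
    System.out.println("input: " + Arrays.toString(input));
    classify(input);
}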