1

I am trying to build a simple neural network to represent a logical AND.

As I am new to machine learning and the Deep Java Library, I followed the beginner tutorial: https://docs.djl.ai/jupyter/tutorial/01_create_your_first_network.html

The result of the tutorial was good and I got the correct results.

I then modified the code to:

  1. read data from a CSV file
  2. use the data from the CSV file for training
  3. classify an input vector of two float values

The code is shown below. Unfortunately, the result of the classification is not as expected.

When I use:

float one [] = {1f,1f};
classify(one);

I get the result:

0: 0.5816633701324463
1: 0.4183366000652313

When I use:

float zero [] = {1f,0f};
classify(zero);

I get the result:

0: 0.5276625156402588
1: 0.47233742475509644

So there is something obviously wrong, but I do not know where to start:

  1. data / training
  2. model type
  3. network setup

Maybe someone can help me to find the solution and show me the mistake I am making.

Java code:

import ai.djl.*;
import ai.djl.training.*;
import java.io.IOException;
import java.nio.file.*;
import ai.djl.ndarray.types.*;
import ai.djl.training.loss.*;
import ai.djl.training.listener.*;
import ai.djl.training.evaluator.*;
import ai.djl.basicmodelzoo.basic.*;
import java.util.*;
import java.util.stream.*;
import org.apache.commons.csv.CSVFormat;
import ai.djl.ndarray.*;
import ai.djl.modality.*;
import ai.djl.translate.*;
import ai.djl.ndarray.NDList;
import ai.djl.translate.TranslateException;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.basicdataset.tabular.utils.Feature;

/**
 * Trains a small multilayer perceptron on the logical AND truth table read from a CSV file
 * and classifies two-element float vectors with the saved model.
 *
 * <p>Training and classification must agree on the model name, the model directory and the
 * network layout, so those are defined once here and shared by both paths.
 */
public class App 
{
    /** Name under which the model is saved and loaded. */
    private static final String MODEL_NAME = "mlpBlock";

    /** Directory the trained parameters are written to and read from. */
    private static final Path MODEL_DIR = Paths.get("build/mlp");

    /**
     * Number of passes over the training data. The data set has only four rows, so a couple of
     * epochs amounts to a handful of optimizer steps and leaves the network close to its random
     * initialisation (outputs near 0.5 / 0.5). NOTE(review): 2000 is a generous starting point,
     * tune it against the logged loss/accuracy.
     */
    private static final int EPOCHS = 2000;

    /**
     * Entry point. Pass {@code train} as the first argument to train and save the model;
     * with no arguments the previously saved model is loaded and a sample input is classified.
     *
     * @param args optional command-line arguments; only the first one is inspected
     */
    public static void main( String[] args )
    {
        boolean train = args != null && args.length > 0 && "train".equalsIgnoreCase(args[0]);

        if (train) {
            try {
                training();
            } catch (Exception e) {
                System.out.println("[ERROR] Could not train");
                e.printStackTrace();
            }
        } else {
            try {
                // sample input vector for the neural network: 1 AND 0
                float[] zero = {1f, 0f};

                classify(zero);
            } catch (Exception e) {
                System.out.println("[ERROR] Could not classify");
                e.printStackTrace();
            }
        }
    }

    /**
     * Builds the network block used for both training and inference: 2 inputs, 2 outputs and
     * one hidden layer of size 2. Keeping it in one place guarantees that the block used to
     * load parameters has the same layout as the one that was trained.
     */
    private static Mlp newBlock() {
        return new Mlp(2, 2, new int[] {2});
    }

    /**
     * Classify with the trained neural network and print one {@code class: probability}
     * line per class to standard output.
     *
     * @param input the two feature values to classify
     * @throws MalformedModelException if the saved model cannot be parsed
     * @throws IOException if the model files cannot be read
     * @throws TranslateException if pre- or post-processing of the prediction fails
     */
    static void classify(float [] input) throws MalformedModelException, IOException, TranslateException {

        Translator<float[], Classifications> translator = new Translator<float[], Classifications>() {

            @Override
            public NDList processInput(TranslatorContext ctx, float[] input) {
                NDArray array = ctx.getNDManager().create(input);
                return new NDList(array);
            }

            @Override
            public Classifications processOutput(TranslatorContext ctx, NDList list) {
                // turn the raw network outputs into probabilities along the class axis
                NDArray probabilities = list.singletonOrThrow().softmax(0);
                List<String> classNames = IntStream.range(0, 2).mapToObj(String::valueOf).collect(Collectors.toList());
                return new Classifications(classNames, probabilities);
            }

            @Override
            public Batchifier getBatchifier() {
                return Batchifier.STACK;
            }
        };

        // Model and predictor hold native resources; close them deterministically.
        try (Model model = Model.newInstance(MODEL_NAME)) {
            model.setBlock(newBlock());
            model.load(MODEL_DIR);

            try (var predictor = model.newPredictor(translator)) {
                var classifications = predictor.predict(input);

                List<String> names = classifications.getClassNames();
                List<Double> probabilities = classifications.getProbabilities();
                for (int i = 0; i < probabilities.size(); i++) {
                    System.out.println(names.get(i) + ": " + probabilities.get(i));
                }
            }
        }
    }

    /**
     * Train the neural network on {@code TrainingDataAND.csv} and save it to the model
     * directory.
     *
     * @throws IOException if the CSV file cannot be read or the model cannot be written
     * @throws TranslateException if the data set cannot be converted into batches
     */
    static void training() throws IOException, TranslateException {

        Path csvPath = Paths.get("TrainingDataAND.csv");

        CSVFormat csvFormat = CSVFormat.DEFAULT.withHeader();

        CsvDataset dataset = CsvDataset.builder()
            .optCsvFile(csvPath)
            .addFeature(new Feature("in1", true))
            .addFeature(new Feature("in2", true))
            .addLabel(new Feature("result", true))
            .setSampling(2, true)
            .setCsvFormat(csvFormat)
            .build();

        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
            .addEvaluator(new Accuracy())
            .addTrainingListeners(TrainingListener.Defaults.logging());

        // Model and trainer hold native resources; close them deterministically.
        try (Model model = Model.newInstance(MODEL_NAME)) {
            model.setBlock(newBlock());

            try (Trainer trainer = model.newTrainer(config)) {
                // input shape: (batch, features); only the feature count matters here
                trainer.initialize(new Shape(1, 2));

                EasyTrain.fit(trainer, EPOCHS, dataset, null);

                Files.createDirectories(MODEL_DIR);

                model.setProperty("Epoch", String.valueOf(EPOCHS));

                model.save(MODEL_DIR, MODEL_NAME);
            }
        }
    }
}

Training data (TrainingDataAND.csv):

in1,in2,result
1,1,1
1,0,0
0,1,0
0,0,0
Philipp S
  • 141
  • 1
  • 6

0 Answers0