
I went through the basic linear regression example in the documentation using SameDiff, which works fine in Scala 3.2. I then tried to train the linear regression with input and label INDArrays, and that's where I got stuck: I can't find a way to convert the INDArrays into the DataSetIterator that the fit method expects. I am using SameDiff instead of the ready-made layers because I want to use SameDiff to optimize other losses, for example for the EM algorithm or hyperbolic embeddings.

package main
import org.nd4j.autodiff.samediff._
import org.nd4j.linalg.factory.Nd4j
import scala.jdk.CollectionConverters._
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.weightinit.impl.XavierInitScheme
import org.nd4j.autodiff.samediff.TrainingConfig
import org.nd4j.linalg.learning.config.Adam
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator
import org.nd4j.common.primitives.Pair




@main def hello: Unit =

        val nIn = 4
        val nOut = 2

        val sd = SameDiff.create()


        //First: Let's create our placeholders. Shape: [minibatch, in/out]
        val input = sd.placeHolder("input", DataType.FLOAT, -1, nIn)
        val labels = sd.placeHolder("labels", DataType.FLOAT, -1, nOut)   //labels are [minibatch, nOut] to match labelArr below

        //Second: let's create our variables
        val weights = sd.`var`("weights", new XavierInitScheme('c', nIn, nOut), DataType.FLOAT, nIn, nOut)
        val bias = sd.zero("bias", nOut)   //zero-initialized bias of shape [nOut]


        //And define our forward pass:
        val out = input.mmul(weights).add(bias)    //Note: it's broadcast add here

        //And our loss function (done manually here for the purposes of this example):
        val difference = labels.sub(out)
        val sqDiff = sd.math().square(difference)
        val mse = sqDiff.mean("mse")

        //Let's create some mock data for this example:
        val minibatch = 10
        Nd4j.getRandom().setSeed(12345)
        val inputArr = Nd4j.rand(minibatch, nIn)
        val labelArr = Nd4j.rand(minibatch, nOut)
        println(labelArr)

        val placeholderData = Map("input" -> inputArr, "labels" -> labelArr).asJava


        //Execute forward pass:
        val loss = sd.output(placeholderData, "mse").get("mse")
        println("MSE: " + loss)

        //Calculate gradients:
        val gradMap = sd.calculateGradients(placeholderData, "weights", "bias").asScala.toMap
        System.out.println("Weights gradient:")
        System.out.println(gradMap.get("weights"))
        System.out.println("Bias gradient:")
        System.out.println(gradMap.get("bias"))

        val config = TrainingConfig.builder()
            .l2(1e-4)                          //L2 regularization
            .updater(Adam(1e-3))               //Adam updater with learning rate 1e-3
            .dataSetFeatureMapping("input")    //DataSet features feed the "input" placeholder
            .dataSetLabelMapping("labels")     //DataSet labels feed the "labels" placeholder
            .build()

        sd.setTrainingConfig(config)
        val data = INDArrayDataSetIterator(List(Pair(inputArr, labelArr)).asJava, 32)
        val hist = sd.fit(data) //Compile error: can't find a matching overload of fit

Additionally, I am unsure which SBT dependencies to use. I currently have:

libraryDependencies += "org.nd4j" % "nd4j" % "1.0.0-M2.1"
libraryDependencies += "org.deeplearning4j" % "deeplearning4j-core" % "1.0.0-M2.1"
libraryDependencies += "org.nd4j" % "nd4j-native-platform" % "1.0.0-M2.1"

1 Answer


You don't have to use an iterator. Just create a DataSet and pass that to fit:

val data = new DataSet(inputArr, labelArr)   //org.nd4j.linalg.dataset.DataSet
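
Then, assuming fit(DataSet) runs a single pass over the data, you can loop it for multiple epochs:

for (epoch <- 0 until 10)
  sd.fit(data)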

If you want to use an iterator, you can also use a ListDataSetIterator:

val iter = new ListDataSetIterator(Arrays.asList(data), batchSize)

where data is the DataSet created above: https://github.com/deeplearning4j/deeplearning4j/blob/f8bb11ddf40f86f8739ce21c303a7861addf4e15/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java
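
Putting the iterator route together in Scala 3 to match the question's code (a sketch; note that, as far as I know, the iterator overload of fit also takes a number of epochs, which is likely why the single-argument sd.fit(data) in the question failed to resolve):

import java.util.Arrays
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator
import org.nd4j.linalg.dataset.DataSet
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator

val batchSize = 32
val data = new DataSet(inputArr, labelArr)
val iter: DataSetIterator = new ListDataSetIterator(Arrays.asList(data), batchSize)

val hist = sd.fit(iter, 10)   //train for 10 epochs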

Adam Gibson
  • In the second example, where you use a DataSetIterator, what format does the variable data need to be in? – Lukas Tycho Jul 25 '23 at 14:26
  • NDArrays for both. I'm not sure what you mean by "format"; it depends on your problem. For classification, make sure the labels are probabilities for each class. For other problems like regression, they would just be normalized labels. – Adam Gibson Jul 29 '23 at 00:16
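
To illustrate the label formats mentioned in the last comment (a sketch; the shapes follow the nOut = 2 setup from the question):

import org.nd4j.linalg.factory.Nd4j

//Classification: each row is a probability distribution over the classes (here one-hot)
val classLabels = Nd4j.create(Array(
  Array(0f, 1f),   //sample 0 -> class 1
  Array(1f, 0f)    //sample 1 -> class 0
))

//Regression: rows are just the (normalized) target values
val regLabels = Nd4j.rand(2, 2)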