
There is a dl4j optimization that only works on GPUs: DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF).

I'd like to make that call only when the backend is a GPU.

In my Maven pom.xml I have:

<!-- CPU or GPU -->
<nd4j.backend>nd4j-native-platform</nd4j.backend>
<!--<nd4j.backend>nd4j-cuda-8.0-platform</nd4j.backend>-->

I looked at ways to read that value from Java, and all of them seem clunky. It would be much easier if I could ask dl4j or nd4j "What flavor of backend are we running?" and make the optimization call based on the answer.
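One way to avoid reading pom.xml at all is to check at runtime whether the CUDA backend class is on the classpath. This is a minimal Kotlin sketch, assuming the CUDA backend class is org.nd4j.linalg.jcublas.JCublasBackend (the name reported in the first answer below); the helper isClassPresent is hypothetical. Finding the class only shows that the CUDA artifact is on the classpath, not that nd4j actually chose it as the active backend.

```kotlin
// Hypothetical helper: true if a class with this fully qualified name can be
// loaded. initialize = false means the class's static initializers are not run.
fun isClassPresent(className: String): Boolean {
    val loader = Thread.currentThread().contextClassLoader ?: ClassLoader.getSystemClassLoader()
    return try {
        Class.forName(className, false, loader)
        true
    } catch (e: ClassNotFoundException) {
        false
    } catch (e: LinkageError) {
        // The class file exists but one of its dependencies is missing.
        false
    }
}

fun main() {
    // Assumption: this is the nd4j CUDA backend class name.
    val cudaOnClasspath = isClassPresent("org.nd4j.linalg.jcublas.JCublasBackend")
    println("CUDA backend on classpath: $cudaOnClasspath")
}
```

Passing false for initialize avoids triggering any backend setup just to answer the yes/no question.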

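Edit, based on the answers:

import org.nd4j.linalg.api.buffer.DataBuffer
import org.nd4j.linalg.api.buffer.util.DataTypeUtil
import org.nd4j.linalg.factory.Nd4jBackend

// Ask nd4j which backend it loads; switch to half precision only on CUDA.
Nd4jBackend.load().let { be ->
    val name = be.javaClass.simpleName
    println("nd4j Backend: $name")
    // The CUDA backend is JCublasBackend (no "gpu" in its name), so match "cublas".
    if (name.lowercase().contains("cublas")) {
        println("Optimizing for GPU")
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF)
    }
}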
Benjamin H

2 Answers


See if you can use Nd4j.backend. Printing it with CUDA enabled, I get:

org.nd4j.linalg.jcublas.JCublasBackend

and without cuda:

org.nd4j.linalg.cpu.nativecpu.CpuBackend
reden
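Because the two backends report different class names, a check like the one in the question's edit can match those exact names instead of the substring "gpu", which JCublasBackend does not contain. A minimal sketch, assuming these two fully qualified names are the only backends you need to tell apart; BackendKind and classifyBackend are hypothetical names. In real code you would pass the loaded backend's javaClass.name.

```kotlin
// Hypothetical classification of an nd4j backend by its fully qualified class name.
enum class BackendKind { CUDA, CPU, UNKNOWN }

fun classifyBackend(className: String): BackendKind = when (className) {
    // Class names as printed in the answer above.
    "org.nd4j.linalg.jcublas.JCublasBackend" -> BackendKind.CUDA
    "org.nd4j.linalg.cpu.nativecpu.CpuBackend" -> BackendKind.CPU
    else -> BackendKind.UNKNOWN
}

fun main() {
    val names = listOf(
        "org.nd4j.linalg.jcublas.JCublasBackend",
        "org.nd4j.linalg.cpu.nativecpu.CpuBackend"
    )
    for (name in names) {
        // Only a CUDA result should lead to the HALF data type call.
        println("$name -> ${classifyBackend(name)}")
    }
}
```

Exact matching fails safe: an unrecognized backend maps to UNKNOWN, so a caller that switches to HALF only on CUDA leaves it at the default data type.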

nd4j also prints the backend when it starts up; the startup output should include the vendor for the backend.

Adam Gibson