There is an optimization for dl4j that only works with GPUs:
DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF)
I'd like to only make that call if the backend is a GPU.
In my Maven pom.xml, I've got
<!-- CPU or GPU -->
<nd4j.backend>nd4j-native-platform</nd4j.backend>
<!--<nd4j.backend>nd4j-cuda-8.0-platform</nd4j.backend>-->
I've looked at ways to read that value from Java, but all of them seem clunky. It would be much easier if I could ask dl4j or nd4j "What flavor of backend are we running?" and make the optimization call based on the answer.
Edit from answer:
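Nd4jBackend.load().let { be ->
    println("nd4j Backend: ${be.javaClass.simpleName}")
    // The CUDA backend class is JCublasBackend, so match on "cublas";
    // neither backend's class name contains "gpu".
    if (be.javaClass.simpleName.contains("cublas", ignoreCase = true)) {
        println("Optimizing for GPU")
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF)
    }
}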
See if you can use Nd4j.getBackend(). Printing it with CUDA enabled, I get:
org.nd4j.linalg.jcublas.JCublasBackend
and without CUDA:
org.nd4j.linalg.cpu.nativecpu.CpuBackend
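Based on those two class names, the check can be reduced to a string test. The following is only a minimal sketch: the class name BackendCheck and the method isGpuBackend are hypothetical helpers, not part of nd4j, and the sketch assumes the CUDA backend's class name keeps containing "cublas" (or "cuda"), which may change between nd4j versions.

```java
import java.util.Locale;

// Hypothetical helper (not part of nd4j): classifies a backend by class name.
final class BackendCheck {

    // Assumption: the CUDA backend's class name contains "cublas" or "cuda",
    // as in org.nd4j.linalg.jcublas.JCublasBackend quoted above.
    static boolean isGpuBackend(String backendClassName) {
        if (backendClassName == null) {
            return false;
        }
        String name = backendClassName.toLowerCase(Locale.ROOT);
        return name.contains("cublas") || name.contains("cuda");
    }

    public static void main(String[] args) {
        // The two class names reported in the answer above.
        System.out.println(isGpuBackend("org.nd4j.linalg.jcublas.JCublasBackend"));
        System.out.println(isGpuBackend("org.nd4j.linalg.cpu.nativecpu.CpuBackend"));
    }
}
```

In an application with nd4j on the classpath, the argument would presumably be Nd4j.getBackend().getClass().getName(), with the setDTypeForContext call made only when the helper returns true.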