
DL4J running (not training) LSTM neural networks on GPUs with Apache Spark?


I need to run several (hundreds of) already-trained LSTM neural networks on real-time data, with new time steps fed in very frequently. These LSTM networks are implemented using deeplearning4j. To run all of them efficiently, I'd like them to use GPUs for their calculations, so that I can run hundreds of them against a large stream of real-time data.

I know I can train neural networks using GPUs.

My question is: can I also execute them on real-time data using rnnTimeStep() on GPUs?

Any pointers are very much appreciated. I have spent many hours searching but can't find anything on this, only material describing training on GPUs.

Don't worry about the GPU overhead; I'm accounting for it, and I know this is an unusual thing to be doing. I just need to know whether it's possible and, if so, any pointers on how to go about it.

Thanks!


Solution

  • Adam's answer doesn't really tell the whole story. You can do real-time inference on GPUs with Spark Streaming, but Spark makes it quite a bit harder than it could be. And because you have a hundred models to run inference with, it becomes quite a challenge.

    One big hurdle is that unless you are running a recent version of YARN, it has no notion of GPUs as resources. So you'll need a cluster whose configuration you control, so that you can keep the number of executors per node matched to the number of GPUs. If the cluster has to do other things as well, you'll have to use placement labels.
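    As a rough illustration of "one executor per GPU", here is a minimal sketch that builds the Spark properties for a hypothetical pool of 4 labelled nodes with 2 GPUs each. The property names are real Spark settings, but the `spark.*.resource.gpu.amount` ones only exist in Spark 3.0+ and need YARN 3.1+ with GPU scheduling enabled (depending on your setup you may also need a GPU discovery script). The class name, method and the "gpu" label are made up for the example.

```java
import java.util.LinkedHashMap;
import java.util.Map;

class GpuExecutorConfig {

    // Builds Spark properties that give each GPU exactly one single-core executor.
    // Assumes every node carrying the YARN label has the same number of GPUs.
    static Map<String, String> sparkProps(int gpuNodes, int gpusPerNode, String nodeLabel) {
        Map<String, String> props = new LinkedHashMap<>();
        // Fixed executor count (no dynamic allocation): one executor per GPU in the pool.
        props.put("spark.dynamicAllocation.enabled", "false");
        props.put("spark.executor.instances", Integer.toString(gpuNodes * gpusPerNode));
        // One core per executor, so a single task (one worker thread) drives its GPU at a time.
        props.put("spark.executor.cores", "1");
        // Spark 3.0+ resource scheduling: every executor and every task claims one GPU.
        props.put("spark.executor.resource.gpu.amount", "1");
        props.put("spark.task.resource.gpu.amount", "1");
        // Keep executors on the GPU machines; the label is whatever your cluster defines.
        props.put("spark.yarn.executor.nodeLabelExpression", nodeLabel);
        return props;
    }

    public static void main(String[] args) {
        // Hypothetical cluster: 4 GPU nodes with 2 GPUs each, labelled "gpu".
        sparkProps(4, 2, "gpu").forEach((k, v) -> System.out.println("--conf " + k + "=" + v));
    }
}
```

    The printed `--conf` lines can be passed to spark-submit. On older YARN versions without GPU scheduling, only the executor count, core count and node label part applies, and keeping one executor per GPU is up to you.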

    Assuming the configuration is ready to go, the next problem is the sheer number of models. In general, when using DL4J in Spark you'll want to use RDD#mapPartitions, so that a whole partition's worth of data lands on a single worker thread (which should equal 1 GPU). The map's job is to load the model (caching it in a thread local), break the partition up into minibatches, and feed them to Model#output. (DL4J/ND4J will handle mapping each thread onto 1 GPU.) By default a map in Spark runs over "the whole cluster", so the data is split evenly across all nodes. That means each node would load and unload each of the hundred models in series, which is inefficient and not exactly real time.
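    The per-partition pattern described above can be sketched in plain Java so it runs without Spark or DL4J on the classpath. `BatchModel` is a hypothetical stand-in for a trained DL4J network's output call, and `inferPartition` has the shape of the function you would hand to `mapPartitions`: it takes the whole partition's iterator, fetches a model cached in a `ThreadLocal`, and feeds fixed-size minibatches to it. This is a sketch of the idea, not DL4J's own Spark integration.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Supplier;

class PartitionInference {

    // Hypothetical stand-in for a trained network's batched output call.
    interface BatchModel {
        List<double[]> output(List<double[]> minibatch);
    }

    // One model per worker thread; with one single-core executor per GPU,
    // that works out to one model instance per GPU.
    private final ThreadLocal<BatchModel> cachedModel;

    PartitionInference(Supplier<BatchModel> loader) {
        // withInitial runs the (expensive) loader only the first time each thread asks.
        this.cachedModel = ThreadLocal.withInitial(loader);
    }

    // Whole partition in, predictions out, processed in minibatches of batchSize.
    List<double[]> inferPartition(Iterator<double[]> rows, int batchSize) {
        BatchModel model = cachedModel.get();
        List<double[]> results = new ArrayList<>();
        List<double[]> batch = new ArrayList<>(batchSize);
        while (rows.hasNext()) {
            batch.add(rows.next());
            if (batch.size() == batchSize) {
                results.addAll(model.output(batch));
                batch = new ArrayList<>(batchSize); // start a fresh minibatch
            }
        }
        if (!batch.isEmpty()) {
            results.addAll(model.output(batch)); // trailing partial minibatch
        }
        return results;
    }
}
```

    Because the model lives in the thread local, a second partition handled by the same worker thread reuses it instead of reloading it. That reuse is exactly what breaks down when one job has to cycle through a hundred different models.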

    If all 100 models are independent, one (not great) option is to amplify the data by creating a PairRDD of [ModelId, DataSet] (copying each DataSet 100 times) and doing a fancy ReduceByKey in a single Spark job. To reduce the killer shuffle (or if the models aren't independent), you'll need to create N Spark Streaming jobs, each with a limited maximum number of executors, listening on a Kafka topic. If the models are connected more like a DAG, you'll really start fighting Spark's model, and what you want in that situation is something more like Apache Storm.
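    To make the cost of the amplification option concrete, here is a plain-Java sketch of its two steps: fan each record out to every model ID, then group the copies by model so each group can be run by one loaded model. In Spark these would be a pair-producing flatMap followed by a keyed shuffle (the "fancy ReduceByKey" above). The grouping here is a stand-in for that step, and the class and method names are hypothetical. Note that the pair count is records × models, which is the data the shuffle has to move.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class ModelFanOut {

    // Amplification: one (modelId, record) pair per model for every incoming record.
    // Here copies share the same array; in Spark each copy would be serialized and shuffled.
    static List<Map.Entry<String, double[]>> amplify(List<double[]> records, List<String> modelIds) {
        List<Map.Entry<String, double[]>> pairs = new ArrayList<>(records.size() * modelIds.size());
        for (double[] record : records) {
            for (String modelId : modelIds) {
                pairs.add(Map.entry(modelId, record));
            }
        }
        return pairs;
    }

    // Keyed step: gather each model's copies under its ID, so one task can load
    // that model once and run it over all of its records.
    static Map<String, List<double[]>> groupByModel(List<Map.Entry<String, double[]>> pairs) {
        Map<String, List<double[]>> byModel = new LinkedHashMap<>();
        for (Map.Entry<String, double[]> pair : pairs) {
            byModel.computeIfAbsent(pair.getKey(), k -> new ArrayList<>()).add(pair.getValue());
        }
        return byModel;
    }
}
```

    With 100 models, every micro-batch of incoming data turns into 100 times as many pairs before any inference happens. That blow-up is why splitting the work into N separate streaming jobs, one per model, is usually preferable.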

    The last time I worked with Storm, it only presented tuples one at a time, so you'll have to configure it so that you can build minibatches that maximize GPU usage.
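    One common way to turn a one-at-a-time stream into GPU-sized minibatches is a size-or-timeout buffer inside the bolt. The sketch below is plain Java with a hypothetical `MiniBatcher` class, not a Storm API. `add` returns a full batch once `maxBatch` items have arrived, and `poll` (which you might call on a periodic tick, for example from Storm's tick tuples) flushes a partial batch that has waited longer than `maxWaitMillis`, so latency stays bounded when traffic is light.

```java
import java.util.ArrayList;
import java.util.List;

class MiniBatcher<T> {
    private final int maxBatch;
    private final long maxWaitMillis;
    private final List<T> buffer = new ArrayList<>();
    private long firstArrivalMillis = -1;

    MiniBatcher(int maxBatch, long maxWaitMillis) {
        if (maxBatch <= 0) throw new IllegalArgumentException("maxBatch must be positive");
        this.maxBatch = maxBatch;
        this.maxWaitMillis = maxWaitMillis;
    }

    // Adds one item; returns a full batch when maxBatch is reached, otherwise an empty list.
    List<T> add(T item, long nowMillis) {
        if (buffer.isEmpty()) firstArrivalMillis = nowMillis; // start the wait clock
        buffer.add(item);
        return buffer.size() >= maxBatch ? drain() : List.of();
    }

    // Called periodically; flushes a partial batch whose oldest item has waited too long.
    List<T> poll(long nowMillis) {
        if (!buffer.isEmpty() && nowMillis - firstArrivalMillis >= maxWaitMillis) return drain();
        return List.of();
    }

    // Hands back the buffered items and resets the buffer and its clock.
    private List<T> drain() {
        List<T> batch = new ArrayList<>(buffer);
        buffer.clear();
        firstArrivalMillis = -1;
        return batch;
    }
}
```

    Each non-empty list returned by `add` or `poll` would then be stacked into one input array and passed to the network in a single call. The batch size and wait time trade GPU utilisation against latency.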