Tags: java, scala, jvm, swig, torch

JVM: How to manage off-heap memory created by JNI


I'm building a Scala wrapper around the Torch libraries, using SWIG to build the glue layer. The bindings let me create tensors off-heap, which can only be freed by explicitly calling a static method of the library. However, I want to use the tensors in an imperative way, without having to worry about releasing the memory, much like any ordinary object in Java.
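
For context, with the raw SWIG bindings the tensor lifecycle is fully manual; a rough sketch (the TH.* calls are the same generated bindings used below):

val t = TH.THFloatTensor_new   // allocates native memory, invisible to the JVM's heap accounting
// ... use the tensor ...
TH.THFloatTensor_free(t)       // must be called explicitly, or the memory leaks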

The only way I can think of doing this is to (mis-)use the JVM's garbage collector in the following way:

A 'memory manager' keeps track of the amount of off-heap memory consumed, and when a threshold is reached, it calls System.gc().

import java.util.concurrent.atomic.AtomicLong

object MemoryManager {
  val Threshold: Long = 2L * 1024L * 1024L * 1024L // 2 GB
  val FloatSize = 4 // bytes per float

  // Running total of off-heap bytes currently tracked.
  private val hiMemMark = new AtomicLong(0)

  def dec(size: Long): Long = hiMemMark.addAndGet(-size * FloatSize)
  def inc(size: Long): Long = hiMemMark.addAndGet(size * FloatSize)

  // Called before each allocation; nudges the collector once the
  // tracked off-heap usage crosses the threshold.
  def memCheck(size: Long): Unit = {
    val level = inc(size)
    if (level > Threshold) {
      System.gc()
    }
  }
}

The tensors themselves are wrapped in a class with a finalize method that frees the off-heap memory, like so:

class Tensor private (val payload: SWIGTYPE_p_THFloatTensor) {
  def numel: Int = TH.THFloatTensor_numel(payload)

  // Invoked (eventually) by the garbage collector; releases the native
  // memory and tells the memory manager it is gone.
  override def finalize(): Unit = {
    MemoryManager.dec(numel)
    TH.THFloatTensor_free(payload)
  }
}

Tensor creation is done by a factory method that notifies the memory manager. For example, to create a Tensor of zeros:

object Tensor {
  def zeros(shape: List[Int]): Tensor = {
    MemoryManager.memCheck(shape.product)
    val storage = ... // boilerplate
    val t = TH.THFloatTensor_new
    TH.THFloatTensor_zeros(t, storage)
    new Tensor(t)
  }
}
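
For clarity, here is how I intend this to behave in use (the 2 GB threshold and float size are from the MemoryManager above):

val t = Tensor.zeros(List(1000, 1000))
println(t.numel) // 1000000 elements, i.e. 4 MB of off-heap floats tracked
// Once more than 2 GB of floats have been requested in total, memCheck
// calls System.gc(); finalize() on unreachable tensors then frees the
// native memory and decrements the counter.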

I realize this is a naive approach, but can I get away with it? It seems to work fine, even when running in parallel (which generates lots of superfluous calls to System.gc(), but otherwise nothing untoward). Or can you think of a better solution?

Thank you.


Solution

  • There's a more deterministic option: explicitly managed regions of memory.

    So, roughly, suppose we had a class like this:

    import scala.collection.mutable.ArrayBuffer

    class Region private () {
      private val registered = ArrayBuffer.empty[() => Unit]
      def register(finalizer: () => Unit): Unit = registered += finalizer
      def releaseAll(): Unit = {
        // Run every finalizer even if one throws, so a failure early in
        // the list cannot leak the resources registered after it.
        registered.foreach { f =>
          try f() catch { case e: Throwable => e.printStackTrace() }
        }
      }
    }
    

    We could then have a method implementing the so-called "loan pattern", which gives us a fresh region and handles the deallocation:

    object Region {
      def run[A](f: Region => A): A = {
        val r = new Region
        try f(r) finally r.releaseAll()
      }
    }
    

    Then something that requires manual deallocation could be described as taking an implicit Region:

    class Leakable(i: Int)(implicit r: Region) {
      // Class body is constructor body, so you can register finalizers
      r.register(() => println(s"Deallocated foo $i"))
    
      def foo() = println(s"Foo: $i")
    }
    

    You would then be able to use it in a fairly boilerplate-free way:

    Region.run { implicit r =>
      val a = new Leakable(1)
      val b = new Leakable(2)
      b.foo()
      a.foo()
    }
    

    This code produces the following output:

    Foo: 2
    Foo: 1
    Deallocated foo 1
    Deallocated foo 2
    

    This approach is a little limiting (if you try to assign a Leakable to a variable outside the closure passed to run, its lifetime will not be extended past the region), but it will be faster and is guaranteed to work even if calls to System.gc are disabled.
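
    To make the limitation concrete, here is a sketch of the escape pitfall (reusing the Leakable class from above):

    // Anti-pattern: the value outlives the region that owns its resource.
    val escaped = Region.run { implicit r => new Leakable(3) }
    // releaseAll() has already run by this point, so the finalizer
    // registered by Leakable(3) has fired. With the toy println finalizer
    // this merely prints early; with real off-heap memory, calling
    // escaped.foo() here would be a use-after-free.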