python, rust, pyo3

How to save a numpy.ndarray in a static or thread-local variable in a Python extension written in Rust with PyO3?


I am building a simple Python extension that processes numpy.ndarray objects using the rust-numpy crate. I want to save numpy.ndarray objects in static or thread-local variables for later processing:

use std::cell::RefCell;
use numpy::PyReadwriteArray1;
use pyo3::{Bound, pymodule, PyResult, types::PyModule};

thread_local! {
    static ARRAYS: RefCell<Vec::<PyReadwriteArray1<i32>>> = RefCell::new(Vec::<PyReadwriteArray1<i32>>::new());
}

#[pymodule]
fn rust_ext<'py>(m: &Bound<'py, PyModule>) -> PyResult<()> {
    #[pyfn(m)]
    fn register_array(mut x: PyReadwriteArray1<i32>) {
        x.as_array_mut()[0] = 100;
        ARRAYS.with_borrow_mut(|v| v.push(x));
    }

    Ok(())
}

But I get a lifetime compile error:

static ARRAYS: RefCell<Vec::<PyReadwriteArray1<i32>>> = RefCell::new(Vec::<PyReadwriteArray1<i32>>::new());
  |                                                   ^ expected named lifetime parameter

So is it possible to save Python objects in an extension written in Rust? If so, how do I fix the 'py lifetime issue?


Solution

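  • A PyReadwriteArray1<'py, i32> carries the 'py lifetime of the GIL token, so it cannot outlive the call it was passed to, while a static or thread-local needs a type with no such borrow. To hold a Python value outside of the GIL, you need it wrapped in a Py<T>, which has no lifetime parameter. You can get one from a PyReadwriteArray1 with .as_unbound(), which yields a &Py<PyArray1<i32>>; .clone_ref(py) turns that into an owned Py (a new reference to the same array, not a copy). So your thread local would look like this: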

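    use std::cell::RefCell;
    use numpy::PyArray1;
    use pyo3::prelude::*;

    thread_local! {
        // Py<T> has no 'py lifetime, so it can be stored here
        static ARRAYS: RefCell<Vec<Py<PyArray1<i32>>>> = RefCell::new(Vec::new());
    }

    // inside register_array(mut x: PyReadwriteArray1<i32>):
    let py = x.py();
    // as_unbound() gives &Py<...>; clone_ref(py) makes an owned, GIL-independent handle
    let arr = x.as_unbound().clone_ref(py);
    ARRAYS.with_borrow_mut(|v| v.push(arr));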

    To get the value back from the Py, you need to "unlock" it with a Python<'_> token proving that you hold the GIL. Your functions normally have one, either as a py: Python<'_> parameter or via .py() on an existing bound parameter (where neither is available, Python::with_gil provides one). You'll likely want to turn the Py into a Bound, either with .bind(py), which borrows the Py, or with .into_bound(py), which consumes it. See the Py documentation for other options.

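    From the bound PyArray1 you can then get a PyReadwriteArray1 again with .readwrite(), which panics if the array is already borrowed, or .try_readwrite(), which returns an error instead. Both are provided by the PyArrayMethods trait: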

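    use numpy::PyArrayMethods; // brings .readwrite() / .try_readwrite() into scope

    // just getting the first one as a demonstration (v[0] panics if ARRAYS is empty);
    // py is the Python<'_> token available in the enclosing function
    let arr = ARRAYS.with_borrow(|v| v[0].bind(py).clone());
    let x = arr.readwrite();

    // do whatever with the PyReadwriteArray1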
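The lifetime problem, and why the Py<T> route fixes it, can be reproduced without PyO3 or NumPy. In this standard-library sketch, the hypothetical Obj type (an Rc<RefCell<Vec<i32>>>) stands in for a Python array object, the hypothetical View<'a> stands in for PyReadwriteArray1<'py, i32>, and storing an Rc clone plays the role of as_unbound().clone_ref(py): the thread-local keeps a new strong reference to the same object rather than a copy. None of these names are PyO3 APIs; it is an analogy, not PyO3's implementation.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Hypothetical stand-in for a Python object: shared, reference-counted, mutable.
type Obj = Rc<RefCell<Vec<i32>>>;

/// Hypothetical borrowed view, playing the role of PyReadwriteArray1<'py, i32>:
/// it carries a lifetime, so it cannot outlive the borrow it was created from.
struct View<'a> {
    obj: &'a Obj,
}

// Rejected with a missing-lifetime error, like the one in the question:
// thread_local! { static VIEWS: RefCell<Vec<View>> = RefCell::new(Vec::new()); }

thread_local! {
    // Obj has no lifetime parameter (like Py<PyArray1<i32>>), so this compiles.
    static HANDLES: RefCell<Vec<Obj>> = RefCell::new(Vec::new());
}

/// Mutates through the short-lived view, then stores a new strong
/// reference to the same object (no data is copied).
fn register(view: View<'_>) {
    view.obj.borrow_mut()[0] = 100;
    HANDLES.with_borrow_mut(|v| v.push(Rc::clone(view.obj)));
}

/// Reads the first element of the first registered object, if any.
fn first_value() -> Option<i32> {
    HANDLES.with_borrow(|v| v.first().map(|o| o.borrow()[0]))
}
```

Because the stored handle points at the same object, a later write through the original is visible through it, much as in-place changes made from Python to a registered ndarray would be visible through the stored Py.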
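Choosing thread_local! over a static has a visible consequence: every OS thread gets its own Vec, so arrays registered while Python runs on one thread are not seen from another Python thread. A static behind a Mutex is shared instead; since Py<T> is Send and Sync in PyO3, a static Mutex<Vec<Py<PyArray1<i32>>>> should also work. The standard-library sketch below shows the difference, using u32 as a hypothetical stand-in for Py<PyArray1<i32>>; the function names are illustrative only.

```rust
use std::cell::RefCell;
use std::sync::Mutex;
use std::thread;

// Stand-in (assumption): in the extension this would be Py<PyArray1<i32>>.
type Handle = u32;

thread_local! {
    // One independent Vec per OS thread.
    static PER_THREAD: RefCell<Vec<Handle>> = RefCell::new(Vec::new());
}

// A single Vec shared by all threads; Mutex::new is const, so no lazy init is needed.
static SHARED: Mutex<Vec<Handle>> = Mutex::new(Vec::new());

fn register_per_thread(h: Handle) {
    PER_THREAD.with_borrow_mut(|v| v.push(h));
}

fn per_thread_len() -> usize {
    PER_THREAD.with_borrow(|v| v.len())
}

fn register_shared(h: Handle) {
    SHARED.lock().unwrap().push(h);
}

fn shared_len() -> usize {
    SHARED.lock().unwrap().len()
}

/// Runs `f` on a freshly spawned thread and returns its result.
fn on_other_thread<R, F>(f: F) -> R
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    thread::spawn(f).join().unwrap()
}
```

Registering two handles on one thread and then calling per_thread_len from another thread yields 0 there, while anything pushed into SHARED from any thread is visible everywhere.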
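Two details of the retrieval step can be tried out with the standard library alone. First, with_borrow only lends the Vec for the duration of the closure, so the snippet returns a cloned Bound (a cheap reference-count increment) rather than a reference; using first() instead of v[0] also avoids a panic when nothing has been registered yet. Second, rust-numpy checks array borrows at runtime, and the readwrite/try_readwrite pair behaves much like RefCell's borrow_mut/try_borrow_mut: one panics on a conflicting borrow, the other returns an error. This sketch uses String as a hypothetical stand-in for Py<PyArray1<i32>>; it is an analogy, not rust-numpy's actual mechanism.

```rust
use std::cell::{BorrowMutError, RefCell};

thread_local! {
    // String is a stand-in (assumption) for Py<PyArray1<i32>>.
    static ITEMS: RefCell<Vec<String>> = RefCell::new(Vec::new());
}

fn register(item: &str) {
    ITEMS.with_borrow_mut(|v| v.push(item.to_owned()));
}

/// Returns a clone of the first stored item, or None if nothing is registered.
/// A reference into the Vec cannot escape the closure, hence the clone.
fn first_item() -> Option<String> {
    ITEMS.with_borrow(|v| v.first().cloned())
}

/// Increments the cell, returning an error instead of panicking when the
/// cell is already borrowed (the try_readwrite-style behaviour).
fn try_increment(cell: &RefCell<i32>) -> Result<(), BorrowMutError> {
    *cell.try_borrow_mut()? += 1;
    Ok(())
}
```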