Tags: python, rust, pyo3

PyO3 convert rust struct to &PyAny


I have a struct

#[pyclass]
pub struct DynMat {
   ...
}

and I have this function

#[pyfunction]
#[text_signature = "(tensor_or_scalar, /)"]
pub fn exp<'py>(py: Python<'py>, tensor_or_scalar: &'py PyAny) -> PyResult<&'py PyAny> {
    // I need to return &PyAny because I might either return PyFloat or DynMat
    if let Ok(scalar) = tensor_or_scalar.cast_as::<PyFloat>() {
        let scalar: &PyAny = PyFloat::new(py, scalar.extract::<f64>()?.exp());
        Ok(scalar)
    } else if let Ok(tensor) = tensor_or_scalar.cast_as::<PyCell<DynMat>>() {
        let mut tensor:PyRef<DynMat> = tensor.try_borrow()?;
        let tensor:DynMat = tensor.exp()?;
        // what now? How to return tensor
    }
}

The question is: how can I return a Rust struct marked with pyclass from a function that expects PyResult<&'py PyAny>?


Solution

  • I assume it's the tensor you want to return.

    If your return type were PyResult<DynMat>, you could just return the value and let the automatic conversion kick in. But I assume that, depending on whether you have a scalar or a tensor, you will return different types.

    So, right now you have tensor as an owned DynMat value, and we need to move it to the Python heap. Here's what that looks like:

    let tensor_as_py = Py::new(py, tensor)?.into_ref(py);
    return Ok(tensor_as_py);
    
    

    PS: You can write your attempt at conversion more concisely too:

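    pub fn blablabla(tensor_or_scalar: &PyAny) {
      // extract() returns a PyResult, so the annotation wraps PyRefMut in PyResult
      let tensor: PyResult<PyRefMut<DynMat>> = tensor_or_scalar.extract();
      if let Ok(tensor) = tensor {
        let tensor = tensor.exp();
        // ...
      }
    }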
    

    But looking at your code, there's one more thing that's confusing me:

    To exponentiate the tensor, you're borrowing it mutably. That suggests to me that the exponentiation will be in place. So why do you then need to return it, too?

    Or is this meant to be a reference back to the original tensor? In that case, I'd get rid of the variable shadowing so that you can just return the PyRefMut<DynMat>, which you can convert to a &PyAny via from or into.

    But actually, tensor.exp()? seems to return an owned value of type DynMat, so it seems a new tensor is created after all. In that case, yes, you need to move it from Rust to the Python heap with the Py::new method shown above.
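
The in-place versus new-value distinction discussed above can be sketched in plain Rust, with no PyO3 involved. ToyMat is a hypothetical stand-in for DynMat (the real struct's fields are not shown in the question), and both method names are invented for illustration. An in-place method needs `&mut self` and returns nothing, while a method that builds a new matrix only needs `&self` and returns an owned value:

```rust
// Hypothetical stand-in for DynMat; the real struct's fields are not shown in the question.
#[derive(Debug, Clone, PartialEq)]
pub struct ToyMat {
    pub data: Vec<f64>,
}

impl ToyMat {
    // In-place variant: requires a mutable borrow and has nothing new to return.
    pub fn exp_in_place(&mut self) {
        for x in self.data.iter_mut() {
            *x = x.exp();
        }
    }

    // Out-of-place variant: a shared borrow is enough, and the caller receives
    // a brand-new owned value that still has to be moved somewhere.
    pub fn exp_new(&self) -> ToyMat {
        ToyMat {
            data: self.data.iter().map(|x| x.exp()).collect(),
        }
    }
}
```

If tensor.exp() in the question behaves like exp_new here, a shared borrow (the PyRef obtained from try_borrow() in the question) would presumably be enough, and the owned result is what gets handed to Py::new.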

    EDIT: Previous version used as_ref(py) instead of into_ref(py). The former borrows from the Py<_> object to give you a reference, but the latter actually consumes the Py<_> object.

    The documentation explains exactly this use case: https://docs.rs/pyo3/0.13.2/pyo3/prelude/struct.Py.html#method.into_ref
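
The borrow-versus-consume difference between as_ref(py) and into_ref(py) can be illustrated with a loose plain-Rust analogy. This is only a sketch built on the standard library's Box, not PyO3 itself: Box::as_ref borrows from the box, while Box::leak consumes it. The analogy is imperfect, because Box::leak never frees the value, whereas the linked documentation describes into_ref as handing the object over to PyO3's own storage. ToyMat and the function names are hypothetical.

```rust
// Plain-Rust analogy only; ToyMat is a hypothetical stand-in for DynMat.
#[derive(Debug, PartialEq)]
pub struct ToyMat {
    pub data: Vec<f64>,
}

// Borrowing (analogous to as_ref): the returned reference is tied to the
// handle, so the handle has to outlive it.
pub fn borrow_from_handle(handle: &Box<ToyMat>) -> &ToyMat {
    handle.as_ref()
}

// Consuming (analogous to into_ref): the handle is moved in and given up, and
// the returned reference no longer depends on any local variable.
// Caveat of the analogy: Box::leak never frees the value.
pub fn consume_handle(handle: Box<ToyMat>) -> &'static ToyMat {
    Box::leak(handle)
}

// Returning a borrow of a temporary handle is rejected by the compiler:
// pub fn broken() -> &'static ToyMat {
//     Box::new(ToyMat { data: vec![] }).as_ref()
// }
```

The commented-out function has the same shape as calling as_ref(py) on a freshly created Py<_> and returning the result: the reference would borrow from a handle that is dropped at the end of the statement.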