I have a struct
#[pyclass]
pub struct DynMat {
...
}
and this function:
#[pyfunction]
#[text_signature = "(tensor_or_scalar, /)"]
pub fn exp<'py>(py: Python<'py>, tensor_or_scalar: &'py PyAny) -> PyResult<&'py PyAny> {
// I need to return &PyAny because I might either return PyFloat or DynMat
if let Ok(scalar) = tensor_or_scalar.cast_as::<PyFloat>() {
let scalar: &PyAny = PyFloat::new(py, scalar.extract::<f64>()?.exp());
Ok(scalar)
} else if let Ok(tensor) = tensor_or_scalar.cast_as::<PyCell<DynMat>>() {
let mut tensor:PyRef<DynMat> = tensor.try_borrow()?;
let tensor:DynMat = tensor.exp()?;
// what now? How to return tensor
}
}
The question is: how can I return a Rust struct marked with #[pyclass] from a function whose return type is PyResult<&'py PyAny>?
I assume it's the tensor you want to return.
If your return type were PyResult<DynMat>, you could just return that and let the automatic conversion kick in. But I assume that, depending on whether you get a scalar or a tensor, you will return different types.
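If you would rather keep a concrete return type than fall back to &PyAny, one option is to model "scalar or tensor" as an enum. The sketch below is plain Rust with no PyO3 involved: `DynMat` is a hypothetical stand-in (a flat Vec<f64>) and `ScalarOrTensor` is a made-up name. To return such an enum from a #[pyfunction], you would presumably still need to implement PyO3's `IntoPy<PyObject>` for it by hand; that wiring is an assumption and is not shown.

```rust
// Hypothetical stand-in for the real `DynMat`: just a flat list of elements.
#[derive(Debug, Clone, PartialEq)]
pub struct DynMat {
    data: Vec<f64>,
}

impl DynMat {
    // Element-wise exponential returning a new matrix (assumed behaviour;
    // the real `DynMat::exp` in the question returns a Result).
    pub fn exp(&self) -> DynMat {
        DynMat {
            data: self.data.iter().map(|x| x.exp()).collect(),
        }
    }
}

// A closed set of possible values instead of a dynamically typed `&PyAny`.
#[derive(Debug, PartialEq)]
pub enum ScalarOrTensor {
    Scalar(f64),
    Tensor(DynMat),
}

// Dispatch on the variant and return the same kind of value we were given.
pub fn exp(value: &ScalarOrTensor) -> ScalarOrTensor {
    match value {
        ScalarOrTensor::Scalar(x) => ScalarOrTensor::Scalar(x.exp()),
        ScalarOrTensor::Tensor(m) => ScalarOrTensor::Tensor(m.exp()),
    }
}
```

Calling `exp(&ScalarOrTensor::Scalar(0.0))` gives back `ScalarOrTensor::Scalar(1.0)`, and a tensor input gives back a tensor, so the compiler checks that every input kind is handled.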
So, right now you have tensor as a DynMat in an owned value, and we need to move it to the Python heap. Here's what that looks like:
let tensor_as_py = Py::new(py, tensor)?.into_ref(py);
return Ok(tensor_as_py);
PS: You can write your attempt at conversion more concisely too:
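pub fn exp<'py>(py: Python<'py>, tensor_or_scalar: &'py PyAny) -> PyResult<&'py PyAny> {
    // extract() returns a PyResult, so the annotated type must be PyResult<PyRefMut<_>>
    let tensor: PyResult<PyRefMut<DynMat>> = tensor_or_scalar.extract();
    if let Ok(tensor) = tensor {
        let tensor = tensor.exp()?;
        // ...
    }
    // ...
}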
But looking at your code, there's one more thing that's confusing me:
To exponentiate the tensor, you're borrowing it mutably. That suggests to me that the exponentiation will be in place. So why do you then need to return it, too?
Or is this meant to be a reference back to the original tensor? In that case, I'd get rid of the variable shadowing so that you can just return the PyRefMut<DynMat>, which you can convert to a &PyAny via from or into.
But actually, tensor.exp()? seems to return an owned value of type DynMat, so it looks like a new tensor is created after all. In that case, yes, you need to move it from Rust to the Python heap with the Py::new method shown above.
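The distinction between an in-place and an out-of-place exponential shows up directly in the method signatures. The sketch below again uses a hypothetical `DynMat` (a flat Vec<f64>), and `exp_in_place` is a made-up name. Both methods are simplified to skip the Result that the question's `exp` returns. An in-place method needs `&mut self` and has nothing to return; a method that builds a new tensor only needs `&self`.

```rust
// Hypothetical stand-in for the real `DynMat`.
#[derive(Debug, Clone, PartialEq)]
pub struct DynMat {
    data: Vec<f64>,
}

impl DynMat {
    // In-place: mutates the existing elements, so it needs `&mut self`
    // and there is no new value to hand back.
    pub fn exp_in_place(&mut self) {
        for x in self.data.iter_mut() {
            *x = x.exp();
        }
    }

    // Out-of-place: reads the elements through `&self` and returns a fresh
    // matrix, which is the value you would then move to the Python heap.
    pub fn exp(&self) -> DynMat {
        DynMat {
            data: self.data.iter().map(|x| x.exp()).collect(),
        }
    }
}
```

With the second signature a shared borrow (PyRef) is enough; PyRefMut is only needed if you go the in-place route.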
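EDIT:
A previous version used as_ref(py) instead of into_ref(py). The former borrows from the Py<_> object to give you a reference, so that reference can't outlive the local Py<_>; the latter consumes the Py<_> object and gives you a reference with the 'py lifetime, which is what this function needs to return.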
The documentation explains exactly this use case: https://docs.rs/pyo3/0.13.2/pyo3/prelude/struct.Py.html#method.into_ref