
How to pass an async function as a parameter in Rust PyO3


In plain Rust, when we need to pass an async function as an argument to another function, we do the following:

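use std::{future::Future, net::TcpStream};
pub fn f<'a, F>(
    test: &dyn Fn(&'a mut String, String, String, TcpStream) -> F,
) where
    F: Future<Output = ()> + 'a,
{
    // invoke `test(...)` and await the returned future here
}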
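For reference, here is a self-contained sketch of that pattern that compiles with only the standard library. It simplifies the signature above (no TcpStream, a generic Fn instead of &dyn Fn) and drives the returned future with a tiny hand-written block_on; in a real program you would .await it inside an async runtime such as Tokio. The names block_on, ThreadWaker, run_handler and shout are hypothetical.

```rust
use std::future::Future;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

// Minimal waker that unparks the thread blocked in `block_on`.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

// Hypothetical single-future executor: polls until the future is ready.
// A real program would use an async runtime instead.
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = Box::pin(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            Poll::Pending => thread::park(),
        }
    }
}

// Accepts "an async function": any callable that returns some Future.
async fn run_handler<H, Fut>(handler: H, input: String) -> String
where
    H: Fn(String) -> Fut,
    Fut: Future<Output = String>,
{
    // Calling the handler only creates the future; `.await` runs it.
    handler(input).await
}

// An ordinary `async fn` can be passed by name.
async fn shout(s: String) -> String {
    s.to_uppercase()
}

fn main() {
    let out = block_on(run_handler(shout, "hello".to_string()));
    println!("{out}"); // prints HELLO
}
```

Closures returning an async block work the same way, e.g. run_handler(|s: String| async move { s + "!" }, ...).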

But when I do the same in a #[pyfunction] that should receive an async Python function, I get an error.

For example:

async def fn():
    ...

Reading the PyO3 docs, I found that I can take a PyAny as a parameter.

But when I implement the following function:

pub fn start_server(test: PyAny) {
  test.call0();
}

I get the following error:

[rustc E0277] [E] the trait bound `pyo3::PyAny: pyo3::FromPyObject<'_>` is not satisfied

expected an implementor of trait `pyo3::FromPyObject<'_>`

note: required because of the requirements on the impl of `pyo3::FromPyObject<'_>` for `pyo3::PyAny`

How can I implement this? If it is not possible, could you recommend an alternative?

UPDATE:

I have found a workaround: create an empty struct and expose the function as a method on it, as shown below. But I would really prefer to avoid creating an empty struct.

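use pyo3::prelude::*;

#[pyclass]
struct Server {}

#[pymethods]
impl Server {
    #[new]
    fn new() -> Self {
        Self {}
    }

    // Calls the Python callable; whatever it returns is discarded here.
    fn start(&self, test: &PyAny) -> PyResult<()> {
        test.call0()?;
        Ok(())
    }
}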

But passing an async function as the argument produces this warning:

RuntimeWarning: coroutine
  s.start(h)
RuntimeWarning: Enable tracemalloc to get the object allocation traceback
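The warning comes from Python itself rather than from PyO3: calling an async def function only creates a coroutine object, and test.call0() throws that object away without ever awaiting it, so CPython complains when the coroutine is garbage-collected. A pure-Python sketch of the two behaviours follows; start_discarding and start_returning are hypothetical stand-ins for the Rust method, with test() playing the role of test.call0().

```python
import asyncio
import gc
import warnings


async def h():
    return "done"


def start_discarding(test):
    # Like `test.call0()` with the result ignored: the coroutine is created, then dropped.
    test()


def start_returning(test):
    # Hand the coroutine back so the caller can await it (or run it).
    return test()


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    start_discarding(h)
    gc.collect()  # make sure the dropped coroutine has been finalized

for w in caught:
    # On CPython this prints: RuntimeWarning coroutine 'h' was never awaited
    print(w.category.__name__, w.message)

coro = start_returning(h)
print(asyncio.run(coro))  # prints: done
```

Returning the coroutine to Python, as the answer below does, leaves the decision of when and where to await it to the caller.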

Solution

  • Your function needs to take a reference, i.e. &PyAny. PyAny as an owned value does not implement FromPyObject, which is why you got the error.

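    // lib.rs
    use pyo3::prelude::*;
    use pyo3::wrap_pyfunction;

    // Take any Python object by reference and call it with no arguments.
    // For an `async def` function this returns the (unawaited) coroutine object.
    #[pyfunction]
    fn foo(x: &PyAny) -> PyResult<&PyAny> {
        x.call0()
    }

    #[pymodule]
    fn async_pyo3(_py: Python, m: &PyModule) -> PyResult<()> {
        m.add_function(wrap_pyfunction!(foo, m)?)?;
        Ok(())
    }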
    
    
    # Python side
    import asyncio
    import async_pyo3

    async def bar():
        return "foo"

    awaitable = async_pyo3.foo(bar)  # <coroutine object bar at 0x7f8f6aa01340>
    print(asyncio.run(awaitable))  # prints: foo
    

    So moving the function into a method on Server was most likely not what fixed the compile error; that happened by coincidence, because you also changed test to &PyAny.

    There is a whole section in the PyO3 documentation about integrating Python and Rust async/await.