rust: pyo3 how to get a reference to a struct member


How do I edit Header.a via Packet.h.a?

#![allow(dead_code)]
use pyo3::prelude::*;

#[pyclass]
#[derive(Clone)]
pub struct Header {
    #[pyo3(get, set)]
    a: u32,
    #[pyo3(get, set)]
    b: u32,
}
#[pymethods]
impl Header {
    #[new]
    fn new(a: u32, b: u32) -> Self {
        Header { a, b }
    }
}

#[pyclass]
/// Structure used to hold an ordered list of headers
pub struct Packet {
    #[pyo3(get, set)]
    pub h: Header,
}
#[pymethods]
impl Packet {
    #[new]
    fn new() -> Self {
        Packet {
            h: Header { a: 0, b: 0 },
        }
    }
}

#[pymodule]
fn pyo3test(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<Header>()?;
    m.add_class::<Packet>()?;
    Ok(())
}

After running "maturin develop", in Python:

from pyo3test import *
p = Packet()
print(p.h.a) # prints 0
h = p.h
h.a = 1
print(h.a) # prints 1
print(p.h.a) # still prints 0
p.h.a = 1
print(p.h.a) # still prints 0

This seems to go against Python semantics: h is a reference to p.h, so an update to h should have updated p.h. How do I implement the getter so that it returns a reference to Packet.h?


Solution

  • Disclaimer: I'm not an expert on PyO3; I just have a rough understanding of how Rust and Python work, so take everything I say with a grain of salt.

    The problem here is that Rust and Python have very different memory models: Rust is ownership-based, while Python is reference-counted. This creates some challenges when implementing classes that are usable from both languages.

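    Specifically, for the getters and setters generated by #[pyo3(get, set)], it seems that PyO3 decided to clone() the field instead of handing out a reference-counted handle to it. From the PyO3 documentation: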

    For get the field type must implement both IntoPy<PyObject> and Clone.
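    To see why the question's h.a = 1 never reaches p, here is a plain-Rust model (std only, no PyO3) of what a cloning getter implies. The get_h and stored_a methods are hypothetical stand-ins, not PyO3's actual generated code. Every call returns a fresh copy, so writes to that copy never touch the Packet.

```rust
// Hypothetical std-only model of a cloning getter; not PyO3's generated code.
#[derive(Clone)]
pub struct Header {
    pub a: u32,
    #[allow(dead_code)] // kept only to mirror the question's struct
    pub b: u32,
}

pub struct Packet {
    h: Header,
}

impl Packet {
    pub fn new() -> Self {
        Packet { h: Header { a: 0, b: 0 } }
    }

    // Like a `#[pyo3(get)]` accessor: hands back a clone, never a reference.
    pub fn get_h(&self) -> Header {
        self.h.clone()
    }

    // Reads the value actually stored in the Packet, for comparison.
    pub fn stored_a(&self) -> u32 {
        self.h.a
    }
}

fn main() {
    let p = Packet::new();
    let mut h = p.get_h(); // Python: h = p.h
    h.a = 1; // Python: h.a = 1
    // Prints "h.a = 1, p.h.a = 0": the write only touched the copy.
    println!("h.a = {}, p.h.a = {}", h.a, p.stored_a());
}
```

    The same reasoning covers p.h.a = 1 in the question. Python first calls the getter, which produces a temporary copy, and then assigns to that copy. The copy is discarded immediately afterwards.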

    There are further incompatibilities. Python is dynamically typed, so a name can refer to a value of any type. For example, p.h.a = "Test" is perfectly valid Python, even though Header::a is clearly declared as a u32 in Rust. PyO3 can only reject it at runtime (with a TypeError) when the setter tries to convert the value. This, too, "seems to be against Python semantics".
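    The setter for a u32 field has to accept "some Python object" and can only check its type and range at runtime. Below is a std-only Rust sketch of that boundary check. The PyValue enum is a hypothetical stand-in for a Python object. PyO3's real conversion goes through its own extraction machinery, not through this enum.

```rust
use std::convert::TryFrom;

// Hypothetical stand-in for "some Python object"; not a PyO3 type.
pub enum PyValue {
    Int(i64),
    Str(String),
}

pub struct Header {
    a: u32,
}

impl Header {
    // Accept any value, then check its type and range at runtime,
    // the way a setter for a u32 field must.
    pub fn set_a(&mut self, value: PyValue) -> Result<(), String> {
        match value {
            PyValue::Int(i) => {
                // `?` returns early on overflow, so `a` is left unchanged.
                self.a = u32::try_from(i).map_err(|_| format!("{} does not fit in a u32", i))?;
                Ok(())
            }
            PyValue::Str(s) => Err(format!("expected an int, got str {:?}", s)),
        }
    }
}

fn main() {
    let mut h = Header { a: 0 };
    assert!(h.set_a(PyValue::Int(1)).is_ok());
    assert!(h.set_a(PyValue::Str("Test".to_string())).is_err()); // like p.h.a = "Test"
    assert!(h.set_a(PyValue::Int(-1)).is_err()); // out of range for u32
    println!("a = {}", h.a); // a = 1
}
```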


    That said, you can achieve something similar by using reference counters internally. However (at the time of writing, and as far as I can tell), you cannot expose those reference counters to Python; that is, you cannot store something like Arc<Header> and return it from a getter.
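    Here is how the internal reference counting mentioned above might look in plain Rust. This is a std-only sketch, and the header and a method names are illustrative assumptions, not PyO3 API. Handing out an Arc clone shares the single Header, so a write through the handle is visible from the Packet. The obstacle is only in passing such a handle across the PyO3 boundary.

```rust
use std::sync::{Arc, Mutex};

// Plain data header; no PyO3 involved.
pub struct Header {
    pub a: u32,
    #[allow(dead_code)] // kept only to mirror the question's struct
    pub b: u32,
}

// Packet owns one reference-counted, lock-protected Header.
pub struct Packet {
    h: Arc<Mutex<Header>>,
}

impl Packet {
    pub fn new() -> Self {
        Packet { h: Arc::new(Mutex::new(Header { a: 0, b: 0 })) }
    }

    // Hand out another owner of the same Header instead of a copy.
    pub fn header(&self) -> Arc<Mutex<Header>> {
        Arc::clone(&self.h)
    }

    // Read `a` from the Header the Packet owns.
    pub fn a(&self) -> u32 {
        self.h.lock().unwrap().a
    }
}

fn main() {
    let p = Packet::new();
    let h = p.header(); // second owner of the same Header
    h.lock().unwrap().a = 1; // write through the shared handle
    assert_eq!(p.a(), 1); // the Packet sees the change
    println!("p.a() = {}, owners = {}", p.a(), Arc::strong_count(&h)); // p.a() = 1, owners = 2
}
```

    The Mutex is there because an Arc on its own only gives shared, read-only access. The code below gets the same effect per field with atomics instead.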

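    What you can do is make Header itself behave like a reference-counted handle. Each field lives behind an Arc, so every clone of a Header, including the clone the Packet.h getter returns, shares the same underlying values: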

    use std::sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    };
    
    use pyo3::prelude::*;
    
    #[pyclass]
    #[derive(Clone)]
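    pub struct Header {
        // Each field lives behind an Arc, so every clone of a Header
        // (including the one the Packet.h getter returns) shares the same values.
        a: Arc<AtomicU32>,
        b: Arc<AtomicU32>,
    }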
    #[pymethods]
    impl Header {
        #[new]
        fn new(a: u32, b: u32) -> Self {
            Header {
                a: Arc::new(AtomicU32::new(a)),
                b: Arc::new(AtomicU32::new(b)),
            }
        }
    
        #[getter]
        fn get_a(&self) -> PyResult<u32> {
            Ok(self.a.load(Ordering::Acquire))
        }
    
        #[setter]
        fn set_a(&mut self, value: u32) -> PyResult<()> {
            self.a.store(value, Ordering::Release);
            Ok(())
        }
    
        #[getter]
        fn get_b(&self) -> PyResult<u32> {
            Ok(self.b.load(Ordering::Acquire))
        }
    
        #[setter]
        fn set_b(&mut self, value: u32) -> PyResult<()> {
            self.b.store(value, Ordering::Release);
            Ok(())
        }
    }
    
    #[pyclass]
    /// Structure used to hold an ordered list of headers
    pub struct Packet {
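        // The generated getter still clones the Header, but the clone shares
        // its Arcs with this field, so writes through it are visible here.
        #[pyo3(get, set)]
        pub h: Header,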
    }
    #[pymethods]
    impl Packet {
        #[new]
        fn new() -> Self {
            Packet {
                h: Header::new(0, 0),
            }
        }
    }
    
    /// A Python module implemented in Rust.
    #[pymodule]
    #[pyo3(name = "rust_python_test")]
    fn rust_python_test(_py: Python, m: &PyModule) -> PyResult<()> {
        m.add_class::<Header>()?;
        m.add_class::<Packet>()?;
        Ok(())
    }
    
    #!/usr/bin/env python3
    
    from rust_python_test import Packet
    
    p = Packet()
    print(p.h.a) # prints 0
    h = p.h
    h.a = 1
    print(h.a) # prints 1
    print(p.h.a) # prints 1
    p.h.a = 2
    print(p.h.a) # prints 2
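    The sharing in this design comes entirely from Arc, so you can check it without maturin or a Python interpreter. The std-only sketch below keeps the answer's Header fields and memory orderings but drops the PyO3 attributes. The free function packet_get_h is a hypothetical stand-in for the cloning getter of Packet.h.

```rust
use std::sync::{
    atomic::{AtomicU32, Ordering},
    Arc,
};

// Same layout as the answer's Header, minus the PyO3 attributes.
#[derive(Clone)]
pub struct Header {
    a: Arc<AtomicU32>,
    #[allow(dead_code)] // mirrors the answer; unused in this sketch
    b: Arc<AtomicU32>,
}

impl Header {
    pub fn new(a: u32, b: u32) -> Self {
        Header {
            a: Arc::new(AtomicU32::new(a)),
            b: Arc::new(AtomicU32::new(b)),
        }
    }

    pub fn get_a(&self) -> u32 {
        self.a.load(Ordering::Acquire)
    }

    // &self is enough: the atomic provides the interior mutability.
    pub fn set_a(&self, value: u32) {
        self.a.store(value, Ordering::Release);
    }
}

pub struct Packet {
    h: Header,
}

// Hypothetical stand-in for the generated `p.h` getter: it still clones,
// but the clone shares its Arcs with the original.
pub fn packet_get_h(p: &Packet) -> Header {
    p.h.clone()
}

fn main() {
    let p = Packet { h: Header::new(0, 0) };
    let h = packet_get_h(&p); // h = p.h
    h.set_a(1); // h.a = 1
    println!("{}", packet_get_h(&p).get_a()); // prints 1
    packet_get_h(&p).set_a(2); // p.h.a = 2
    println!("{}", h.get_a()); // prints 2
}
```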