
Associated types for generic struct?


I want to emulate the x86 al/ah/ax byte-addressing behavior, where al and ah address the low and high bytes of the 16-bit ax register. I want it to be generic so that I can use it not just with u16/u8 pairs, but also with u32/u16 and u64/u32.
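For reference, the al/ah/ax aliasing itself can be expressed with plain shifts and masks, independent of memory layout. This is a minimal sketch; the helper names al, ah and set_ah are hypothetical, chosen only for illustration.

```rust
/// Low byte of a 16-bit register (like x86 al within ax).
fn al(ax: u16) -> u8 {
    ax as u8 // truncating cast keeps bits 0..8
}

/// High byte of a 16-bit register (like x86 ah within ax).
fn ah(ax: u16) -> u8 {
    (ax >> 8) as u8
}

/// Overwrite only the high byte, leaving the low byte untouched.
fn set_ah(ax: u16, value: u8) -> u16 {
    (ax & 0x00FF) | ((value as u16) << 8)
}

fn main() {
    let ax: u16 = 0x1234;
    assert_eq!(al(ax), 0x34);
    assert_eq!(ah(ax), 0x12);
    assert_eq!(set_ah(ax, 0xAB), 0xAB34);
    println!("al = {:#04x}, ah = {:#04x}", al(ax), ah(ax));
}
```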

#[derive(Clone, Copy)]
#[repr(C)] // Ensure field order
struct Halves<H> {
    lo: H,
    hi: H,
}

#[derive(Clone, Copy)]
union Addressed<T: Copy, H: Copy> {
    whole: T,
    halves: Halves<H>,
}
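With the two-parameter version, reading a field back requires unsafe, and which byte ends up in lo depends on the host's byte order (on a little-endian host such as x86, lo overlays the low byte). A small sketch follows; adding #[repr(C)] to the union is an assumption made here so that both fields are guaranteed to start at offset 0.

```rust
#[derive(Clone, Copy)]
#[repr(C)] // Ensure field order
struct Halves<H> {
    lo: H,
    hi: H,
}

// Assumption: #[repr(C)] on the union pins every field at offset 0.
#[derive(Clone, Copy)]
#[repr(C)]
union Addressed<T: Copy, H: Copy> {
    whole: T,
    halves: Halves<H>,
}

fn main() {
    // Nothing stops a mismatched pairing such as Addressed<u16, u32>;
    // the compiler only sees two independent type parameters.
    let ax: Addressed<u16, u8> = Addressed { whole: 0x1234 };

    // SAFETY: every bit pattern is a valid pair of u8 values.
    let halves = unsafe { ax.halves };

    if cfg!(target_endian = "little") {
        assert_eq!((halves.lo, halves.hi), (0x34, 0x12));
    } else {
        assert_eq!((halves.lo, halves.hi), (0x12, 0x34));
    }
}
```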

You'd need to use the above as Addressed<u16, u8>, but the u8 is implied by the u16. It'd be better expressed as Addressed<u16>.

Clearly, H depends on T, like an associated type. Is there a mechanism in Rust to enforce a specific type H for a given type T?
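An associated type does express this kind of dependency: the implementing type, not the caller, picks the half type, and since a type can implement a given trait only once, u16 can only ever be paired with one half type. A minimal type-level sketch; the trait HasHalf, the struct Pair and the function check are hypothetical names.

```rust
use std::mem::size_of;

// Hypothetical trait: each whole type names exactly one half type.
trait HasHalf: Copy {
    type Half: Copy;
}

impl HasHalf for u16 { type Half = u8; }
impl HasHalf for u32 { type Half = u16; }
impl HasHalf for u64 { type Half = u32; }

// Only one type parameter: the half type is looked up through T.
struct Pair<T: HasHalf> {
    lo: T::Half,
    hi: T::Half,
}

// Two halves occupy exactly as many bytes as the whole type.
fn check<T: HasHalf>() -> bool {
    size_of::<Pair<T>>() == size_of::<T>()
}

fn main() {
    let p: Pair<u32> = Pair { lo: 0x5678, hi: 0x1234 };
    assert_eq!((p.lo, p.hi), (0x5678_u16, 0x1234_u16));
    assert!(check::<u16>() && check::<u32>() && check::<u64>());
}
```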


Solution

  • Yes, you can achieve this with a custom trait and an associated type (as @kmdreko mentioned). Here's one way to do it:

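    trait Register: Copy {
        // Each whole-register type fixes its own half-width type.
        type Half: Copy;
        fn split(self) -> Halves<Self::Half>;
    }

    // Tuples have no guaranteed layout, so use a #[repr(C)] struct instead.
    // Field order follows the target's byte order so that `lo` always
    // overlays the low-order bytes of the whole value (x86 is little-endian).
    #[derive(Clone, Copy)]
    #[repr(C)]
    struct Halves<H> {
        #[cfg(target_endian = "little")]
        lo: H,
        hi: H,
        #[cfg(target_endian = "big")]
        lo: H,
    }

    macro_rules! register {
        ($whole: ty, $half: ty) => {
            impl Register for $whole {
                type Half = $half;
                fn split(self) -> Halves<$half> {
                    Halves {
                        lo: self as $half, // truncating cast keeps the low half
                        hi: (self >> (Self::BITS / 2)) as $half,
                    }
                }
            }
        };
    }

    register!(u16, u8);
    register!(u32, u16);
    register!(u64, u32);
    register!(u128, u64);

    // #[repr(C)] guarantees that both fields start at offset 0.
    #[derive(Clone, Copy)]
    #[repr(C)]
    union Addressed<R: Register> {
        whole: R,
        halves: Halves<R::Half>,
    }

    impl<R: Register> Addressed<R> {
        fn whole(whole: R) -> Self {
            Self { whole }
        }

        fn halves(whole: R) -> Self {
            Self {
                halves: whole.split(),
            }
        }

        fn lo(self) -> R::Half {
            // SAFETY: both fields are plain integers covering the same bytes,
            // so any bit pattern is a valid Halves<R::Half>.
            unsafe { self.halves.lo }
        }

        fn hi(self) -> R::Half {
            // SAFETY: see `lo`.
            unsafe { self.halves.hi }
        }
    }

    fn main() {
        let short = Addressed::whole(0x1234_u16);
        assert_eq!((short.lo(), short.hi()), (0x34, 0x12)); // al = 0x34, ah = 0x12

        let int = Addressed::halves(0x1234_5678_u32);
        assert_eq!((int.lo(), int.hi()), (0x5678, 0x1234));

        let long = Addressed::whole(42_u64);
        let longlong = Addressed::halves(42_u128);
        assert_eq!((long.lo(), long.hi()), (42, 0));
        assert_eq!((longlong.lo(), longlong.hi()), (42, 0));
    }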
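If the union overlay is not a hard requirement, the same associated-type idea also works without unsafe and without any dependence on byte order: store the whole value and derive or patch the halves with shifts and masks. This is a sketch under that assumption; Split, Reg, join, set_lo and set_hi are hypothetical names, and Split is a separate trait rather than the answer's Register.

```rust
trait Split: Copy {
    type Half: Copy;
    fn lo(self) -> Self::Half;
    fn hi(self) -> Self::Half;
    fn join(lo: Self::Half, hi: Self::Half) -> Self;
}

macro_rules! split {
    ($whole:ty, $half:ty) => {
        impl Split for $whole {
            type Half = $half;
            fn lo(self) -> $half {
                self as $half // truncating cast keeps the low half
            }
            fn hi(self) -> $half {
                (self >> <$half>::BITS) as $half
            }
            fn join(lo: $half, hi: $half) -> Self {
                let shift = <$half>::BITS;
                (lo as $whole) | ((hi as $whole) << shift)
            }
        }
    };
}

split!(u16, u8);
split!(u32, u16);
split!(u64, u32);
split!(u128, u64);

/// A register whose halves can be read and written like x86 al/ah.
#[derive(Clone, Copy)]
struct Reg<R: Split>(R);

impl<R: Split> Reg<R> {
    fn lo(self) -> R::Half { self.0.lo() }
    fn hi(self) -> R::Half { self.0.hi() }
    fn set_lo(&mut self, v: R::Half) { self.0 = R::join(v, self.0.hi()); }
    fn set_hi(&mut self, v: R::Half) { self.0 = R::join(self.0.lo(), v); }
}

fn main() {
    let mut ax = Reg(0x1234_u16);
    assert_eq!((ax.lo(), ax.hi()), (0x34, 0x12));
    ax.set_hi(0xAB);
    assert_eq!(ax.0, 0xAB34);

    let mut eax = Reg(0x1234_5678_u32);
    eax.set_lo(0xFFFF);
    assert_eq!(eax.0, 0x1234_FFFF);
}
```

Unlike the union, this version cannot hand out a mutable reference to a half, which an emulator might want; that is the trade-off for staying in safe code.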