
How to design the return type of a trait method?


I have started implementing a utility for SKUs in which a product can be measured in multiple units. The basic logic is:

  • Convert a list of amounts into a single amount in the base unit.
  • Split a single amount in the base unit into a list of human-readable amounts.

Example

let five = "5kg".parse::<Amount>()?;
let two = "2g".parse::<Amount>()?;

let sum = five + two;
let result = sum * 3;

let result = Weight.reduce(result)?;
assert_eq!(result, Amount::new(15006, Weight.base_unit()));

let result = Weight.split(result)?.into_iter().collect::<Vec<_>>();
assert_eq!(result, [Amount::new(15, kg()), Amount::new(6, g())]);

Design

The code design is as follows:

use std::ops::{Add, Mul};

trait Exchanger {
    type Err;

    fn rate(&self, unit: &Unit) -> Result<u32, Self::Err>;

    fn reduce<E>(&self, expr: E) -> Result<Amount, Self::Err>
    where
        for<'a> E: Reduce<&'a Self>,
        Self: Sized,
    {
        expr.reduce(self)
    }

    fn split<E>(&self, expr: E) -> Result<Split, Self::Err>
    where
        for<'a> E: Reduce<&'a Self>,
        Self: Sized,
    {
        let base = expr.reduce(self)?;
        todo!()
    }
}

impl<T: Exchanger> Exchanger for &T {
    type Err = T::Err;

    fn rate(&self, unit: &Unit) -> Result<u32, Self::Err> {
        (**self).rate(unit)
    }
}

trait Reduce<E: Exchanger> {
    fn reduce(self, e: E) -> Result<Amount, E::Err>;
}

trait Expression<E: Exchanger, Rhs, T>: Reduce<E> + Add<Rhs> + Mul<T> {}

// impls Expression<E, Rhs, T>
struct Amount(u32, Unit);

struct Unit(String);

// impls Expression<E, Rhs, T>
struct Sum<L, R>(L, R);

// impls Expression<E, Rhs, T>
struct Split(Vec<Amount>);

Problems

The design above has several problems and, I think, isn't elegant.

  1. I need to define an additional struct Split, because Exchanger::split may return either an Amount or a Sum. Introducing an interface like Expression would be enough for this, but when I tried Box<dyn Expression>, the trait takes generic parameters that I can't supply at compile time.

  2. I need to re-implement Expression for the new type Split.

  3. Whenever I add a new type T that implements Expression, I must also implement Add<T> for every type that already implements Expression. The number of combinations explodes, e.g.:

    struct NewType;
    impl Add<NewType> for NewType {}
    impl Add<Amount> for NewType {}
    impl Add<Sum> for NewType {}
    impl Add<Split> for NewType {}
    
    impl Add<NewType> for Amount {}
    impl Add<NewType> for Sum {}
    impl Add<NewType> for Split {}
    
    // impl all reference types, and so on.
    impl Add<NewType> for &Amount {}
    
  4. I also need to implement every trait a second time just for reference types, e.g.:

    trait Trait {}
    
    impl <T: Trait> Trait for &T{}
    

Is there a better way to achieve the current design? You can also refer to the GitHub repository to see the full code.


Solution

  • If you only need what you describe as the basic logic, I would do something like this:

    use std::{error::Error, fmt::Display, num::ParseIntError, ops::Add, str::FromStr};
    
    /// A type representing all units of a certain quantity (e.g kg, g, t, etc.)
    pub trait Unit: Sized {
        /// The base unit. All other units are multiples of it.
        fn base() -> Self;
    
        /// The magnitude of this unit with respect to the base unit. The base unit itself has
        /// magnitude 1.
        ///
        /// For example, if the base unit is the gram, then the kilogram would have a magnitude of
        /// 1000.
        fn magnitude(&self) -> u32;
    
        /// All units of this quantity, in decreasing order of magnitude.
        ///
        /// For example: [t, kg, g]
        fn sorted_units() -> Vec<Self>;
    }
    
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct Quantity<U> {
        pub amount: u32,
        pub unit: U,
    }
    
    impl<U> Quantity<U> {
        pub fn new(amount: u32, unit: U) -> Self {
            Self { amount, unit }
        }
    }
    
    impl<U: Unit> Quantity<U> {
        /// Splits this quantity into different units
        pub fn split(&self) -> Vec<Self> {
            let mut amount = self.amount * self.unit.magnitude();
            let mut split = vec![];
            for unit in U::sorted_units() {
                let mag = unit.magnitude();
                let whole = amount / mag;
                if whole != 0 {
                    split.push(Self {
                        amount: whole,
                        unit,
                    })
                }
                amount -= whole * mag;
            }
            split
        }
    }
    
    impl<U: Unit> Add for Quantity<U> {
        type Output = Self;
        /// Output always in terms of the "base unit"
        fn add(self, rhs: Self) -> Self::Output {
            Self {
                amount: self.amount * self.unit.magnitude() + rhs.amount * rhs.unit.magnitude(),
                unit: U::base(),
            }
        }
    }
    
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum ParseQuantityError<UErr> {
        InvalidAmount(ParseIntError),
        InvalidUnit(UErr),
    }
    
    impl<U: FromStr> FromStr for Quantity<U> {
        type Err = ParseQuantityError<<U as FromStr>::Err>;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            let first_non_digit_idx = s
                .char_indices()
                .find(|(_, c)| !c.is_ascii_digit())
                .map(|(idx, _)| idx)
                .unwrap_or(s.len());
    
            let amount = s[..first_non_digit_idx]
                .parse::<u32>()
                .map_err(ParseQuantityError::InvalidAmount)?;
    
            let unit = s[first_non_digit_idx..]
                .parse::<U>()
                .map_err(ParseQuantityError::InvalidUnit)?;
    
            Ok(Self { amount, unit })
        }
    }
    
    impl<UErr: Display> Display for ParseQuantityError<UErr> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                Self::InvalidAmount(err) => write!(f, "failed to parse `Quantity` amount: {err}"),
                Self::InvalidUnit(err) => write!(f, "failed to parse `Quantity` unit: {err}"),
            }
        }
    }
    
    impl<UErr: Error> Error for ParseQuantityError<UErr> {}
    

    This design has multiple advantages:

    • It's pretty small
    • Every dimension can be a separate type (most likely an enum), which makes weird operations like adding kilograms and liters impossible
    • Creating a new unit is only a matter of implementing the Unit trait and possibly the standard FromStr trait

    For example, to represent mass with multiple units:

    use std::{error::Error, fmt::Display, str::FromStr};
    
    pub type Mass = Quantity<MassUnit>;
    
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum MassUnit {
        G,
        Kg,
    }
    
    impl Unit for MassUnit {
        fn base() -> Self {
            Self::G
        }
    
        fn magnitude(&self) -> u32 {
            match self {
                Self::G => 1,
                Self::Kg => 1000,
            }
        }
    
        fn sorted_units() -> Vec<Self> {
            vec![Self::Kg, Self::G]
        }
    }
    
    #[derive(Debug)]
    pub struct InvalidMassUnit;
    
    impl FromStr for MassUnit {
        type Err = InvalidMassUnit;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "g" => Ok(Self::G),
                "kg" => Ok(Self::Kg),
                _ => Err(InvalidMassUnit),
            }
        }
    }
    
    impl Display for InvalidMassUnit {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "invalid mass unit")
        }
    }
    
    impl Error for InvalidMassUnit {}
    

    Which can be used like this:

    let five_kg = "5kg".parse::<Mass>()?;
    let two_g = "2g".parse::<Mass>()?;
    
    let sum = five_kg + two_g;
    assert_eq!(sum, Mass::new(5002, MassUnit::base()));
    
    let split = sum.split();
    assert_eq!(
        split,
        [Mass::new(5, MassUnit::Kg), Mass::new(2, MassUnit::G)]
    );
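
    One caveat: both add and split above multiply amount by magnitude() in u32, so a large quantity in a large unit overflows (a panic in debug builds, silent wrap-around in release builds by default). Below is a minimal sketch of an overflow-aware conversion. The helpers to_base and checked_add_in_base are hypothetical, not part of the code above, and they take the magnitude as a plain number so the snippet stays self-contained:

```rust
/// Hypothetical helper: converts an amount expressed in a unit of the given
/// magnitude to the base unit, returning `None` instead of overflowing.
fn to_base(amount: u32, magnitude: u32) -> Option<u32> {
    amount.checked_mul(magnitude)
}

/// Hypothetical helper: adds two `(amount, magnitude)` pairs and returns the
/// total in the base unit, or `None` if any intermediate step overflows.
fn checked_add_in_base(lhs: (u32, u32), rhs: (u32, u32)) -> Option<u32> {
    let l = to_base(lhs.0, lhs.1)?;
    let r = to_base(rhs.0, rhs.1)?;
    l.checked_add(r)
}
```

    The same idea could be folded into Quantity itself, for instance by offering a checked_add method that returns an Option next to the Add impl.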
    

    I've taken the liberty of removing the Sum type, but if you think you need it as an extra layer in your computation, adding it shouldn't be too hard.
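
    The question's example also multiplies a sum by a scalar (sum * 3), which the code above doesn't cover. Here is a sketch of how Mul<u32> might be added, assuming it should follow the same "output in the base unit" convention as Add. Unit, Quantity and MassUnit are trimmed-down copies of the definitions above so that the snippet compiles on its own:

```rust
use std::ops::Mul;

/// Trimmed-down copy of the `Unit` trait (no `sorted_units`).
pub trait Unit: Sized {
    fn base() -> Self;
    fn magnitude(&self) -> u32;
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Quantity<U> {
    pub amount: u32,
    pub unit: U,
}

impl<U: Unit> Mul<u32> for Quantity<U> {
    type Output = Self;
    /// Assumption: like `Add`, the result is expressed in the base unit.
    fn mul(self, factor: u32) -> Self::Output {
        Self {
            amount: self.amount * self.unit.magnitude() * factor,
            unit: U::base(),
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MassUnit {
    G,
    Kg,
}

impl Unit for MassUnit {
    fn base() -> Self {
        Self::G
    }

    fn magnitude(&self) -> u32 {
        match self {
            Self::G => 1,
            Self::Kg => 1000,
        }
    }
}
```

    With this in place, 5kg * 3 yields 15000 in the base unit, and the result can be passed to split as before.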


    I'm also assuming you really need to be able to represent similar quantities with different units; otherwise, the problem almost disappears, and you can replace every Quantity<U> type with a newtype, like so:

    pub struct Grams(u32);
    
    pub struct Millimeters(u32);
    

    You may still want to keep Quantity<U> as the return type of split, unless the splitting can live entirely in the Display impl.
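
    To illustrate the newtype route, here is a sketch in which the splitting lives only in the Display impl. The "5kg 2g" output format and the public field are assumptions made for this example:

```rust
use std::fmt;
use std::ops::Add;

/// A mass that is always stored in grams.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Grams(pub u32);

impl Add for Grams {
    type Output = Self;
    fn add(self, rhs: Self) -> Self::Output {
        Grams(self.0 + rhs.0)
    }
}

impl fmt::Display for Grams {
    /// Splits the stored value into kilograms and grams for display only.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (kg, g) = (self.0 / 1000, self.0 % 1000);
        match (kg, g) {
            // Less than one kilogram (including zero): grams only.
            (0, g) => write!(f, "{g}g"),
            // A whole number of kilograms: kilograms only.
            (kg, 0) => write!(f, "{kg}kg"),
            (kg, g) => write!(f, "{kg}kg {g}g"),
        }
    }
}
```

    Parsing "5kg" would then convert to grams immediately, so no unit ever needs to be stored alongside the amount.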

    Edit

    If the units can't necessarily be known at compile time, this is how I would do things:

    use std::{error::Error, fmt::Display, num::ParseIntError, str::FromStr};
    
    /// A type holding information about a certain type that can represent multiple units.
    pub trait Exchanger<U> {
        /// The base unit. All other units are multiples of it.
        fn base_unit(&self) -> U;
    
        /// The magnitude of this unit with respect to the base unit. The base unit itself has
        /// magnitude 1.
        ///
        /// For example, if the base unit is the gram, then the kilogram would have a magnitude of
        /// 1000.
        fn magnitude(&self, unit: &U) -> u32;
    
        /// All units of this quantity, in decreasing order of magnitude.
        ///
        /// For example: [t, kg, g]
        fn sorted_units(&self) -> Vec<U>;
    
        type ParseUnitError;
    
        fn parse_unit(&self, s: &str) -> Result<U, Self::ParseUnitError>;
    
        fn parse_quantity(
            &self,
            s: &str,
        ) -> Result<Quantity<U>, ParseQuantityError<Self::ParseUnitError>> {
            let first_non_digit_idx = s
                .char_indices()
                .find(|(_, c)| !c.is_ascii_digit())
                .map(|(idx, _)| idx)
                .unwrap_or(s.len());
    
            let amount = s[..first_non_digit_idx]
                .parse::<u32>()
                .map_err(ParseQuantityError::InvalidAmount)?;
    
            let unit = self
                .parse_unit(&s[first_non_digit_idx..])
                .map_err(ParseQuantityError::InvalidUnit)?;
    
            Ok(Quantity { amount, unit })
        }
    
        fn add(&self, lhs: Quantity<U>, rhs: Quantity<U>) -> Quantity<U> {
            Quantity {
                amount: lhs.amount * self.magnitude(&lhs.unit) + rhs.amount * self.magnitude(&rhs.unit),
                unit: self.base_unit(),
            }
        }
    
        /// Splits this quantity into different units
        fn split(&self, quantity: Quantity<U>) -> Vec<Quantity<U>> {
            let mut amount = quantity.amount * self.magnitude(&quantity.unit);
            let mut split = vec![];
            for unit in self.sorted_units() {
                let mag = self.magnitude(&unit);
                let whole = amount / mag;
                if whole != 0 {
                    split.push(Quantity {
                        amount: whole,
                        unit,
                    })
                }
                amount -= whole * mag;
            }
            split
        }
    }
    
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct Quantity<U> {
        pub amount: u32,
        pub unit: U,
    }
    
    impl<U> Quantity<U> {
        pub fn new(amount: u32, unit: U) -> Self {
            Self { amount, unit }
        }
    }
    
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum ParseQuantityError<UErr> {
        InvalidAmount(ParseIntError),
        InvalidUnit(UErr),
    }
    
    impl<U: FromStr> FromStr for Quantity<U> {
        type Err = ParseQuantityError<<U as FromStr>::Err>;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            let first_non_digit_idx = s
                .char_indices()
                .find(|(_, c)| !c.is_ascii_digit())
                .map(|(idx, _)| idx)
                .unwrap_or(s.len());
    
            let amount = s[..first_non_digit_idx]
                .parse::<u32>()
                .map_err(ParseQuantityError::InvalidAmount)?;
    
            let unit = s[first_non_digit_idx..]
                .parse::<U>()
                .map_err(ParseQuantityError::InvalidUnit)?;
    
            Ok(Self { amount, unit })
        }
    }
    
    impl<UErr: Display> Display for ParseQuantityError<UErr> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                Self::InvalidAmount(err) => write!(f, "failed to parse `Quantity` amount: {err}"),
                Self::InvalidUnit(err) => write!(f, "failed to parse `Quantity` unit: {err}"),
            }
        }
    }
    
    impl<UErr: Error> Error for ParseQuantityError<UErr> {}
    

    For example:

    use std::{error::Error, fmt::Display};
    
    use crate::Exchanger;
    
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct FooUnit {
        index: usize,
    }
    
    pub struct FooExchanger {
        units: Vec<(String, u32)>,
    }
    
    impl FooExchanger {
        pub fn new(base: String, mut units: Vec<(String, u32)>) -> Self {
            units.push((base, 1));
            // Sort by decreasing magnitude, so the base unit (magnitude 1) ends up last.
            units.sort_unstable_by_key(|(_, mag)| std::cmp::Reverse(*mag));
            Self { units }
        }
    }
    
    impl Exchanger<FooUnit> for FooExchanger {
        fn base_unit(&self) -> FooUnit {
            FooUnit {
                index: self.units.len() - 1,
            }
        }
    
        fn magnitude(&self, unit: &FooUnit) -> u32 {
            self.units[unit.index].1
        }
    
        fn sorted_units(&self) -> Vec<FooUnit> {
            (0..self.units.len())
                .map(|index| FooUnit { index })
                .collect()
        }
    
        type ParseUnitError = InvalidFooUnit;
        fn parse_unit(&self, s: &str) -> Result<FooUnit, Self::ParseUnitError> {
            let index = self
                .units
                .iter()
                .position(|(name, _)| name == s)
                .ok_or(InvalidFooUnit)?;
            Ok(FooUnit { index })
        }
    }
    
    #[derive(Debug)]
    pub struct InvalidFooUnit;
    
    impl Display for InvalidFooUnit {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "invalid unit")
        }
    }
    
    impl Error for InvalidFooUnit {}
    

    Which can be used like this:

    let exchanger = FooExchanger::new("g".to_string(), vec![("kg".to_string(), 1000)]);
    
    let five_kg = exchanger.parse_quantity("5kg").unwrap();
    let two_g = exchanger.parse_quantity("2g").unwrap();
    
    let sum = exchanger.add(five_kg, two_g);
    assert_eq!(sum, Quantity::new(5002, exchanger.base_unit()));
    
    let split = exchanger.split(sum);
    assert_eq!(
        split,
        [
            Quantity::new(5, exchanger.parse_unit("kg").unwrap()),
            Quantity::new(2, exchanger.parse_unit("g").unwrap())
        ]
    );
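
    The question also asks to convert a whole list of amounts into a single amount, whereas Exchanger::add only handles two. Here is a sketch of a fold over any number of quantities. The free function sum_all and the stand-in MagnitudeExchanger are hypothetical, and the trait is trimmed down to the two methods the fold needs so that the snippet compiles on its own:

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Quantity<U> {
    pub amount: u32,
    pub unit: U,
}

/// Trimmed-down copy of the `Exchanger` trait: only what the fold needs.
pub trait Exchanger<U> {
    fn base_unit(&self) -> U;
    fn magnitude(&self, unit: &U) -> u32;
}

/// Hypothetical helper: reduces any number of quantities to a single
/// quantity expressed in the exchanger's base unit.
pub fn sum_all<U, E, I>(exchanger: &E, quantities: I) -> Quantity<U>
where
    E: Exchanger<U> + ?Sized,
    I: IntoIterator<Item = Quantity<U>>,
{
    let amount: u32 = quantities
        .into_iter()
        .map(|q| q.amount * exchanger.magnitude(&q.unit))
        .sum();
    Quantity {
        amount,
        unit: exchanger.base_unit(),
    }
}

/// Hypothetical stand-in exchanger in which a unit is just its magnitude.
pub struct MagnitudeExchanger;

impl Exchanger<u32> for MagnitudeExchanger {
    fn base_unit(&self) -> u32 {
        1
    }

    fn magnitude(&self, unit: &u32) -> u32 {
        *unit
    }
}
```

    Keeping sum_all as a free function rather than a generic trait method leaves the trait usable as a trait object. An empty list yields an amount of 0 in the base unit.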