
Is it possible to calculate the return type of a Rust function or trait method based on its arguments?


Can I achieve something similar to boost::math::tools::promote_args in Rust? See also "Idiomatic C++11 type promotion".

To be more specific: is it possible to calculate the return type of a function or trait method based on its arguments and ensure that the return type is the same as the type of one of the arguments?

Consider the following case. I have two structs:

#[derive(Debug, Clone, Copy)]
struct MySimpleType(f64);

#[derive(Debug, Clone, Copy)]
struct MyComplexType(f64, f64);

where MySimpleType can be promoted to MyComplexType via the From trait.

impl From<MySimpleType> for MyComplexType {
    fn from(src: MySimpleType) -> MyComplexType {
        let MySimpleType(x1) = src;
        MyComplexType(x1, 0.0)
    }
}

I want to write a function that takes two arguments of type MySimpleType or MyComplexType and returns a value of type MySimpleType if all arguments are of type MySimpleType; otherwise the function should return a value of type MyComplexType. Assuming I have implemented Add<Output=Self> for both types, I could do something like this:

trait Foo<S, T> {
    fn foo(s: S, t: T) -> Self;
}

impl<S, T, O> Foo<S, T> for O
    where O: From<S> + From<T> + Add<Output = Self>
{
    fn foo(s: S, t: T) -> Self {
        let s: O = From::from(s);
        let t: O = From::from(t);
        s + t
    }
}

but then the compiler doesn't know that O should be either S or T and I have to annotate most method calls.
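
To make the problem concrete, here is a minimal sketch of the first attempt together with the annotations it needs at the call site. The Add impls are the ones assumed above (their component-wise bodies are my assumption, since the question does not show them), and demo is a hypothetical helper added only for illustration.

```rust
use std::ops::Add;

#[derive(Debug, Clone, Copy, PartialEq)]
struct MySimpleType(f64);

#[derive(Debug, Clone, Copy, PartialEq)]
struct MyComplexType(f64, f64);

impl From<MySimpleType> for MyComplexType {
    fn from(src: MySimpleType) -> Self {
        MyComplexType(src.0, 0.0)
    }
}

// Assumed Add impls (component-wise); the question only says they exist.
impl Add for MySimpleType {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        MySimpleType(self.0 + rhs.0)
    }
}

impl Add for MyComplexType {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        MyComplexType(self.0 + rhs.0, self.1 + rhs.1)
    }
}

trait Foo<S, T> {
    fn foo(s: S, t: T) -> Self;
}

impl<S, T, O> Foo<S, T> for O
where
    O: From<S> + From<T> + Add<Output = O>,
{
    fn foo(s: S, t: T) -> Self {
        let s: O = From::from(s);
        let t: O = From::from(t);
        s + t
    }
}

// Hypothetical helper: every call has to name the output type somewhere.
fn demo() -> (MySimpleType, MyComplexType) {
    // The output type is fixed by the annotation on the binding...
    let simple: MySimpleType = Foo::foo(MySimpleType(1.0), MySimpleType(2.0));
    // ...or by spelling out the implementing type in the path.
    let complex = <MyComplexType as Foo<_, _>>::foo(MySimpleType(1.0), MyComplexType(2.0, 3.0));
    (simple, complex)
}
```

Note that nothing in this version ties O to S or T: annotating the first call with MyComplexType would also compile, which is exactly the unwanted promotion.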

My second attempt is to use a slightly different trait and write two implementations:

trait Foo<S, T> {
    fn foo(s: S, t: T) -> Self;
}

impl Foo<MySimpleType, MySimpleType> for MySimpleType {
    fn foo(s: MySimpleType, t: MySimpleType) -> Self {
        s + t
    }
}

impl<S, T> Foo<S, T> for MyComplexType
    where MyComplexType: From<S> + From<T>
{
    fn foo(s: S, t: T) -> Self {
        let s: MyComplexType = From::from(s);
        let t: MyComplexType = From::from(t);
        s + t
    }
}

but again, the compiler isn't able to figure out the return type of

Foo::foo(MySimpleType(1.0), MySimpleType(1.0))

The third attempt is similar to std::ops::{Add, Mul, ...}: use an associated type and write a specific implementation for each possible combination of argument types:

trait Foo<T> {
    type Output;
    fn foo(self, t: T) -> Self::Output;
}

impl<T: Add<Output=T>> Foo<T> for T {
    type Output = Self;
    fn foo(self, t: T) -> Self::Output {
        self + t
    }
}

impl Foo<MySimpleType> for MyComplexType {
    type Output = Self;
    fn foo(self, t: MySimpleType) -> Self::Output {
        let t: Self = From::from(t);
        self + t
    }
}

impl Foo<MyComplexType> for MySimpleType {
    type Output = MyComplexType;
    fn foo(self, t: MyComplexType) -> Self::Output {
        let s: MyComplexType = From::from(self);
        s + t
    }
}

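This seems to be the best solution until one needs a function with n arguments, because then one has to write 2^n - 1 impl statements (one blanket impl for the case where all arguments have the same type, plus one for each of the 2^n - 2 mixed combinations). Of course, this gets even worse if more than two types are considered.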
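
For reference, a sketch of how the third attempt behaves at the call site: the return type follows from the argument types without any annotation. The Add impls are again assumed to be component-wise (the question does not show them), and demo is a hypothetical helper.

```rust
use std::ops::Add;

#[derive(Debug, Clone, Copy, PartialEq)]
struct MySimpleType(f64);

#[derive(Debug, Clone, Copy, PartialEq)]
struct MyComplexType(f64, f64);

impl From<MySimpleType> for MyComplexType {
    fn from(src: MySimpleType) -> Self {
        MyComplexType(src.0, 0.0)
    }
}

// Assumed Add impls (component-wise); the question only says they exist.
impl Add for MySimpleType {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        MySimpleType(self.0 + rhs.0)
    }
}

impl Add for MyComplexType {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        MyComplexType(self.0 + rhs.0, self.1 + rhs.1)
    }
}

trait Foo<T> {
    type Output;
    fn foo(self, t: T) -> Self::Output;
}

// Same type on both sides: no promotion needed.
impl<T: Add<Output = T>> Foo<T> for T {
    type Output = Self;
    fn foo(self, t: T) -> Self::Output {
        self + t
    }
}

impl Foo<MySimpleType> for MyComplexType {
    type Output = Self;
    fn foo(self, t: MySimpleType) -> Self::Output {
        self + MyComplexType::from(t)
    }
}

impl Foo<MyComplexType> for MySimpleType {
    type Output = MyComplexType;
    fn foo(self, t: MyComplexType) -> Self::Output {
        MyComplexType::from(self) + t
    }
}

// Hypothetical helper: no type annotations on the intermediate bindings.
fn demo() -> (MySimpleType, MyComplexType) {
    let simple = MySimpleType(1.0).foo(MySimpleType(2.0));
    let complex = MySimpleType(1.0).foo(MyComplexType(2.0, 3.0));
    (simple, complex)
}
```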

===

Edit:

In my code I have multiple nested function calls, and I want to avoid unnecessary type promotion, since evaluating the functions is cheap for the simple type and expensive for the complex type. @MatthieuM.'s proposed solution does not achieve this. Please consider the following example:

use std::ops::Add;

trait Promote<Target> {
    fn promote(self) -> Target;
}

impl<T> Promote<T> for T {
    fn promote(self) -> T {
        self
    }
}

impl Promote<u64> for u32 {
    fn promote(self) -> u64 {
        self as u64
    }
}

fn foo<Result, Left, Right>(left: Left, right: Right) -> Result
    where Left: Promote<Result>,
        Right: Promote<Result>,
        Result: Add<Output = Result>
{
    // std::any::type_name is stable, so no feature gate or unsafe is needed.
    println!("============\nFoo called");
    println!("Left: {}", std::any::type_name::<Left>());
    println!("Right: {}", std::any::type_name::<Right>());
    println!("Result: {}", std::any::type_name::<Result>());
    left.promote() + right.promote()
}

fn bar<Result, Left, Right>(left: Left, right: Right) -> Result
    where Left: Promote<Result>,
        Right: Promote<Result>,
        Result: Add<Output = Result>
{
    left.promote() + right.promote()
}

fn baz<Result, A, B, C, D>(a: A, b: B, c: C, d: D) -> Result
    where A: Promote<Result>,
        B: Promote<Result>,
        C: Promote<Result>,
        D: Promote<Result>,
        Result: Add<Output = Result>
{
    let lhs = foo(a, b).promote();
    let rhs = bar(c, d).promote();
    lhs + rhs
}

fn main() {
    let one = baz(1u32, 1u32, 1u64, 1u32);
    println!("{}", one);
}

Solution

  • I would expect the simplest way to implement promotion is to create a Promote trait:

    trait Promote<Target> {
        fn promote(self) -> Target;
    }
    
    impl<T> Promote<T> for T {
        fn promote(self) -> T { self }
    }
    

    Note: I provide a blanket implementation as all types can be promoted to themselves.

    Using associated types is NOT an option here, because a single type can be promoted to multiple types; thus we just use a regular type parameter.
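
To illustrate that point, a small sketch in which u32 has more than one promotion target. The f64 impl and the promote_three_ways helper are hypothetical and exist only for this illustration; with an associated type, u32 could name just one target.

```rust
trait Promote<Target> {
    fn promote(self) -> Target;
}

impl<T> Promote<T> for T {
    fn promote(self) -> T {
        self
    }
}

impl Promote<u64> for u32 {
    fn promote(self) -> u64 {
        u64::from(self)
    }
}

// Hypothetical second target for the same source type.
impl Promote<f64> for u32 {
    fn promote(self) -> f64 {
        f64::from(self)
    }
}

// Hypothetical helper: the annotation on each binding picks the impl.
fn promote_three_ways(x: u32) -> (u32, u64, f64) {
    let same: u32 = x.promote();
    let wide: u64 = x.promote();
    let float: f64 = x.promote();
    (same, wide, float)
}
```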


    Using this, a simple example is:

    use std::ops::Add;

    impl Promote<u64> for u32 {
        fn promote(self) -> u64 { self as u64 }
    }
    
    fn add<Result, Left, Right>(left: Left, right: Right) -> Result
        where
            Left: Promote<Result>,
            Right: Promote<Result>,
            Result: Add<Output = Result>
    {
        left.promote() + right.promote()
    }
    
    fn main() {
        let one: u32 = add(1u32, 1u32);
        let two: u64 = add(1u32, 2u64);
        let three: u64 = add(2u64, 1u32);
        let four: u64 = add(2u64, 2u64);
        println!("{} {} {} {}", one, two, three, four);
    }
    

    The only issue is that in the case of two u32 arguments, the result type must be specified; otherwise the compiler cannot choose which Promote implementation to use: Promote<u32> or Promote<u64>.

    I am not sure if this is an issue in practice, however, since at some point you should have a concrete type to anchor type inference. For example:

    fn main() {
        let v = vec![add(1u32, 1u32), add(1u32, 2u64)];
        println!("{:?}", v);
    }
    

    compiles without a type hint, because add(1u32, 2u64) can only be u64, and since a Vec is a homogeneous collection, add(1u32, 1u32) has to return a u64 here as well.


    As you experienced, though, sometimes you need the ability to direct the result beyond what type inference can handle. It's fine, you just need another trait for it:

    trait PromoteTarget {
        type Output;
    }
    
    impl<T> PromoteTarget for (T, T) {
        type Output = T;
    }
    

    And then a little implementation:

    impl PromoteTarget for (u32, u64) {
        type Output = u64;
    }
    
    impl PromoteTarget for (u64, u32) {
        type Output = u64;
    }
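
Before moving on to the nested case, here is a sketch of what PromoteTarget buys for a plain two-argument function: the return type is computed from the argument types, so no annotation is needed even for two u32 values. The name add_promoted is hypothetical; the traits are the ones defined above.

```rust
use std::ops::Add;

trait Promote<Target> {
    fn promote(self) -> Target;
}

impl<T> Promote<T> for T {
    fn promote(self) -> T {
        self
    }
}

impl Promote<u64> for u32 {
    fn promote(self) -> u64 {
        u64::from(self)
    }
}

trait PromoteTarget {
    type Output;
}

impl<T> PromoteTarget for (T, T) {
    type Output = T;
}

impl PromoteTarget for (u32, u64) {
    type Output = u64;
}

impl PromoteTarget for (u64, u32) {
    type Output = u64;
}

// Hypothetical helper: the result type is the promotion target of (L, R).
fn add_promoted<L, R>(left: L, right: R) -> <(L, R) as PromoteTarget>::Output
where
    (L, R): PromoteTarget,
    L: Promote<<(L, R) as PromoteTarget>::Output>,
    R: Promote<<(L, R) as PromoteTarget>::Output>,
    <(L, R) as PromoteTarget>::Output: Add<Output = <(L, R) as PromoteTarget>::Output>,
{
    // Promote both sides to the common target, then add.
    let left: <(L, R) as PromoteTarget>::Output = left.promote();
    let right: <(L, R) as PromoteTarget>::Output = right.promote();
    left + right
}
```

With this, add_promoted(1u32, 1u32) is a u32 while add_promoted(1u32, 2u64) is a u64, at the cost of one PromoteTarget impl per mixed pair of types.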
    

    With that out of the way, we can rewrite baz's signature to correctly account for all intermediate types. Unfortunately I don't know of any way to introduce aliases in a where clause, so brace yourself:

    fn baz<Result, A, B, C, D>(a: A, b: B, c: C, d: D) -> Result
        where
            A: Promote<<(A, B) as PromoteTarget>::Output>,
            B: Promote<<(A, B) as PromoteTarget>::Output>,
            C: Promote<<(C, D) as PromoteTarget>::Output>,
            D: Promote<<(C, D) as PromoteTarget>::Output>,
            (A, B): PromoteTarget,
            (C, D): PromoteTarget,
            <(A, B) as PromoteTarget>::Output: Promote<Result> + Add<Output = <(A, B) as PromoteTarget>::Output>,
            <(C, D) as PromoteTarget>::Output: Promote<Result> + Add<Output = <(C, D) as PromoteTarget>::Output>,
            Result: Add<Output = Result>
    {
        let lhs = foo(a, b).promote();
        let rhs = bar(c, d).promote();
        lhs + rhs
    }
    

    Running this, for example on the playground, gives the following result:

    ============
    Foo called
    Left: u32
    Right: u32
    Result: u32
    4
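
The long where clause of baz can possibly be shortened with a generic type alias declared outside the function. The following is only a sketch under that assumption: the alias Promoted and the helper add (standing in for foo and bar, without the printing) are hypothetical names, and the intermediate bindings are annotated explicitly rather than left to inference.

```rust
use std::ops::Add;

trait Promote<Target> {
    fn promote(self) -> Target;
}

impl<T> Promote<T> for T {
    fn promote(self) -> T {
        self
    }
}

impl Promote<u64> for u32 {
    fn promote(self) -> u64 {
        u64::from(self)
    }
}

trait PromoteTarget {
    type Output;
}

impl<T> PromoteTarget for (T, T) {
    type Output = T;
}

impl PromoteTarget for (u32, u64) {
    type Output = u64;
}

impl PromoteTarget for (u64, u32) {
    type Output = u64;
}

// Hypothetical alias for the promotion target of a pair of types.
type Promoted<L, R> = <(L, R) as PromoteTarget>::Output;

// Stand-in for foo and bar from the question.
fn add<Out, L, R>(left: L, right: R) -> Out
where
    L: Promote<Out>,
    R: Promote<Out>,
    Out: Add<Output = Out>,
{
    let left: Out = left.promote();
    let right: Out = right.promote();
    left + right
}

fn baz<Out, A, B, C, D>(a: A, b: B, c: C, d: D) -> Out
where
    (A, B): PromoteTarget,
    (C, D): PromoteTarget,
    A: Promote<Promoted<A, B>>,
    B: Promote<Promoted<A, B>>,
    C: Promote<Promoted<C, D>>,
    D: Promote<Promoted<C, D>>,
    Promoted<A, B>: Promote<Out> + Add<Output = Promoted<A, B>>,
    Promoted<C, D>: Promote<Out> + Add<Output = Promoted<C, D>>,
    Out: Add<Output = Out>,
{
    // Each pair is added in its own (cheapest) common type...
    let lhs: Promoted<A, B> = add(a, b);
    let rhs: Promoted<C, D> = add(c, d);
    // ...and only the two partial results are promoted to the final type.
    let lhs: Out = lhs.promote();
    let rhs: Out = rhs.promote();
    lhs + rhs
}
```

Calling baz(1u32, 1u32, 1u64, 1u32) with a u64 result then adds the first pair as u32 and promotes only the partial sum, which matches the behaviour shown in the output above.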