Tags: linq, f#, diffsharp

LINQ Expression.Add throws, but (+) works in F#


#r "nuget: DiffSharp.Core, 1.0.7-preview1873603133"
#r "nuget: DiffSharp.Backends.Reference, 1.0.7-preview1873603133"
#r "nuget: DiffSharp.Backends.Torch, 1.0.7-preview1873603133"
open DiffSharp
open DiffSharp.Util
let t3 = dsharp.tensor [[1.1; 2.2]; [1.1; 2.2]; [1.1; 2.2]]
1 + t3 // works

open System.Linq.Expressions
let addB = Expression.Parameter(typeof<int>, "b")
let addC = Expression.Parameter(typeof<Tensor>, "c")
Expression.Add(addC, addB) // throws InvalidOperationException:

(*
System.InvalidOperationException: The binary operator Add is not defined for the types 'DiffSharp.Tensor' and 'System.Int32'.
   at System.Linq.Expressions.Expression.GetUserDefinedBinaryOperatorOrThrow(ExpressionType binaryType, String name, Expression left, Expression right, Boolean liftToNull)
   at System.Linq.Expressions.Expression.Add(Expression left, Expression right, MethodInfo method)
   at System.Linq.Expressions.Expression.Add(Expression left, Expression right)
   at <StartupCode$FSI_0048>.$FSI_0048.main@()
Stopped due to error
*)

Why is 1 + t3 legal in F#, while the equivalent LINQ expression cannot be built? Also, if F# RFC FS-1043 ("Extension members become available to solve operator trait constraints") were implemented, would it help resolve this issue?
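A likely explanation, sketched here under stated assumptions: F# resolves 1 + t3 at compile time with its own overload resolution, which can coerce the int argument to whatever parameter type the matching (+) overload declares. Expression.Add does not use F# resolution at all. It searches the operand types by reflection for a static op_Addition method whose parameter types are reference-assignable from the operand types, and a value type such as Int32 is never reference-assignable to an interface or to obj. The example uses a hypothetical MiniTensor whose mixed (+) overload takes System.IConvertible; that DiffSharp's Tensor declares its scalar overloads in a similar way is an assumption, not something the error message shows. Boxing the int with Expression.Convert makes the operand types line up. Because the lookup only inspects methods declared on the operand types, extension operators of the kind proposed in FS-1043 would presumably not be visible to Expression.Add either.

```fsharp
open System
open System.Linq.Expressions

// Hypothetical stand-in for a tensor type. Assumption: its mixed (+) overload
// takes IConvertible rather than int, as DiffSharp's scalar overloads may.
type MiniTensor(v: float) =
    member _.Value = v
    static member (+) (a: MiniTensor, b: IConvertible) =
        MiniTensor(a.Value + b.ToDouble(null))

let c = Expression.Parameter(typeof<MiniTensor>, "c")
let b = Expression.Parameter(typeof<int>, "b")

// Expression.Add searches for op_Addition(MiniTensor, Int32). Int32 is a value
// type, so it is not reference-assignable to IConvertible and nothing matches.
let failsWithoutConvert =
    try
        Expression.Add(c, b) |> ignore
        false
    with
    | :? InvalidOperationException -> true

// Boxing the int to IConvertible first makes the operand types match the operator.
let body = Expression.Add(c, Expression.Convert(b, typeof<IConvertible>))
let add = Expression.Lambda<Func<MiniTensor, int, MiniTensor>>(body, c, b).Compile()
let result = add.Invoke(MiniTensor(1.5), 2)

printfn "failsWithoutConvert = %b, result = %.1f" failsWithoutConvert result.Value
```

Run in dotnet fsi, this prints failsWithoutConvert = true, result = 3.5. With the real Tensor, the analogous fix would be to convert addB to the parameter type of the relevant op_Addition overload, which can be checked with typeof<Tensor>.GetMethods().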


Solution

  • Following @NetMage's comments, the final solution looks like the following code snippets:

    The static operator members on Value:

    // Assumes this lives in namespace MathNet.Symbolics, which defines Approximation/Real.
    open MathNet.Numerics                 // BigRational, complex
    open MathNet.Numerics.LinearAlgebra   // Vector<'T>, Matrix<'T>
    open DiffSharp                        // Tensor

    type Value =
        | Number of BigRational
        | Approximation of Approximation
        | ComplexInfinity
        | PositiveInfinity
        | NegativeInfinity
        | Undefined
        | RealVec of Vector<float>
        | ComplexVec of Vector<complex>
        | RealMat of Matrix<float>
        | ComplexMat of Matrix<complex>
        | DSTen of Tensor
        with
            // Only the combinations needed so far are handled; any other
            // combination raises MatchFailureException at run time.
            static member (+) (vl : Value, vr : Value) =
                match vl, vr with
                | Number vlv, Number vrv -> Number (vlv + vrv)
                | Approximation (Real vlv), Approximation (Real vrv) -> Approximation (Real (vlv + vrv))
                | Approximation (Real vlv), DSTen dt -> DSTen (vlv + dt)
            static member (*) (vl : Value, vr : float) =
                match vl with
                | Approximation (Real vlv) -> Approximation (Real (vlv * vr))
            static member (*) (vl : float, vr : Value) =
                match vr with
                | Approximation (Real vrv) -> Approximation (Real (vl * vrv))
            static member (+) (vl : Value, vr : float) =
                match vl with
                | Approximation (Real vlv) -> Approximation (Real (vlv + vr))
            static member (+) (vl : float, vr : Value) =
                match vr with
                | Approximation (Real vrv) -> Approximation (Real (vl + vrv))
                | DSTen dt -> DSTen (vl + dt)
    

    The casting part:

        let exprObj2ValueToInject =
            ExprHelper.Quote<Func<obj, MathNet.Symbolics.Value>> (fun j ->
                match j with
                | :? Value -> (j :?> Value) 
                | _ when j.GetType() = typeof<float> ->
                    Value.Approximation (Approximation.Real (j :?> float))
                | :? Vector<float> ->
                    Value.RealVec (j :?> Vector<float>)
                | :? Matrix<float> ->
                    Value.RealMat (j :?> Matrix<float>)
                | _ ->
                    failwithf "orz010: %s, %A" (j.GetType().FullName) j
                )
                :> Expression :?> LambdaExpression
    

    The invocation part:

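        let xsvv =
            xsv
            |> List.map (fun xsExp ->
                // Box the argument to obj, then let the injected converter wrap it in a Value
                let casted = Expression.Convert(xsExp, typeof<obj>) :> Expression
                Expression.Invoke(exprObj2ValueToInject, [| casted |]) :> Expression)

        // Invoke the target lambda with the wrapped arguments
        let vLambda = Expression.Invoke(exprBy2Lambda, xsvv)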
    

    Whatever is passed in, I convert it to obj and use a match to decide how to wrap it in a Value.
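The conversion and invocation steps above can be put together in one small, self-contained sketch: each operand, whatever its static type, is boxed with Expression.Convert, passed through a converter with Expression.Invoke, and the resulting Value-typed expressions are combined with Expression.Add, which then finds op_Addition(Value, Value) by reflection. Everything here is hypothetical: Value is cut down to one scalar case plus Undefined (and returns Undefined instead of throwing for unsupported inputs), objToValue stands in for exprObj2ValueToInject, the converter is a constant delegate rather than a lambda built with the author's ExprHelper.Quote, and List.reduce with Expression.Add replaces exprBy2Lambda, whose definition is not shown.

```fsharp
open System
open System.Linq.Expressions

// Hypothetical, cut-down Value union with a (Value, Value) operator.
type Value =
    | Real of float
    | Undefined
    static member (+) (vl: Value, vr: Value) =
        match vl, vr with
        | Real a, Real b -> Real (a + b)
        | _ -> Undefined

// obj -> Value dispatch in the spirit of the converter above (cases trimmed).
let objToValue (j: obj) : Value =
    match j with
    | :? Value as v -> v
    | :? float as x -> Real x
    | :? int as n -> Real (float n)
    | _ -> Undefined

// Wrap the F# function as a constant delegate so the expression tree can invoke it.
let converter = Expression.Constant(Func<obj, Value>(objToValue))

// Operands with different static types: int, float and Value.
let a = Expression.Parameter(typeof<int>, "a")
let b = Expression.Parameter(typeof<float>, "b")
let c = Expression.Parameter(typeof<Value>, "c")

// Box each operand to obj, then run it through the converter.
let wrapped =
    [ a; b; c ]
    |> List.map (fun p ->
        let casted = Expression.Convert(p, typeof<obj>) :> Expression
        Expression.Invoke(converter, [| casted |]) :> Expression)

// Every operand is now a Value, so Expression.Add resolves op_Addition(Value, Value).
let sum = wrapped |> List.reduce (fun l r -> Expression.Add(l, r) :> Expression)

let f = Expression.Lambda<Func<int, float, Value, Value>>(sum, a, b, c).Compile()
printfn "%A" (f.Invoke(1, 2.5, Real 0.5))
```

This should print Real 4.0. The trade-off of the design is that type dispatch moves from compile time to run time: every call pays for boxing and a type test, and unsupported operand types only surface when the compiled delegate runs. A constant delegate is fine for Compile(), but a query provider that translates expression trees into another language could not translate it.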