I'm working on a generic method that accepts parameters whose types should be inferred from the generic type arguments of a given interface. I'd like to achieve compile-time type safety similar to the following (non-compiling) code:
Note: this is a method on the PradResult class.
```csharp
/// <summary>
/// Applies the specified IOperation to this PradResult.
/// </summary>
/// <typeparam name="T">The type of the operation to apply.</typeparam>
/// <param name="p1">The first parameter for the operation.</param>
/// <param name="p2">The second parameter for the operation.</param>
/// <param name="parameters">Additional parameters for the operation.</param>
/// <returns>A new PradResult after applying the operation.</returns>
// Not valid C#: T.Args[n] and an unbound IPradOperation<,> in a constraint are hypothetical syntax.
public PradResult Then<T>(T.Args[0] p1, T.Args[1] p2, params object[] parameters)
    where T : IPradOperation<,>, new()
{
    var opType = typeof(T);
    IOperation op = (IOperation)Activator.CreateInstance(opType);
    var forwardMethod = opType.GetMethod("Forward");
    var forwardParams = parameters.Prepend(p2).Prepend(p1);
    var length = parameters.Length;
    ...
}
```
The idea is that the p1 and p2 parameters should automatically have their types inferred based on the first and second generic type arguments of the IPradOperation interface implemented by T.
However, I understand that C# doesn't currently support this syntax directly. Are there any workarounds or alternative patterns I can use to achieve this kind of compile-time type inference in my generic method?
I want the Then method to be as user-friendly as possible, so that callers don't have to specify the types of 'T.Args[0]' and 'T.Args[1]', which in my example would both be 'Matrix'.
For the sake of completeness, here is the IPradOperation interface:
```csharp
/// <summary>
/// A PRAD Operation.
/// </summary>
/// <typeparam name="T">The first parameter to the forward function.</typeparam>
/// <typeparam name="TR">The return type of the forward function.</typeparam>
public interface IPradOperation<T, TR>
{
}
```
And the interface is used like this:
```csharp
public class AmplifiedSigmoidOperation : Operation, IPradOperation<Matrix, Matrix>
{
    ...

    /// <summary>
    /// The forward pass of the operation.
    /// </summary>
    /// <param name="input">The input for the operation.</param>
    /// <returns>The output for the operation.</returns>
    public Matrix Forward(Matrix input)
    {
        this.input = input;
        int numRows = input.Length;
        int numCols = input[0].Length;
        this.Output = new Matrix(numRows, numCols);
        for (int i = 0; i < numRows; i++)
        {
            for (int j = 0; j < numCols; j++)
            {
                this.Output[i][j] = 1.0f / (1.0f + PradMath.Pow(PradMath.PI - 2, -input[i][j]));
            }
        }

        return this.Output;
    }

    ...
}
```
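For reference, the loop body in Forward applies the following function to each element. Rewriting the base π − 2 as an exponential shows that it is a standard logistic sigmoid σ(z) = 1/(1 + e^(−z)) with its input scaled by ln(π − 2), which is about 0.1324:

$$
f(x) = \frac{1}{1 + (\pi - 2)^{-x}} = \frac{1}{1 + e^{-x \ln(\pi - 2)}} = \sigma\big(x \ln(\pi - 2)\big), \qquad \ln(\pi - 2) \approx 0.1324
$$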
So from a usage perspective, if C# supported this, I could write:
```csharp
[Fact]
public void TestThenGeneric()
{
    Tensor tensor = new Tensor(new int[] { 200, 300, 400 }, 5f);
    PradOp op = new PradOp(tensor);
    Matrix m1 = new Matrix(1, 4);
    Matrix m2 = new Matrix(2, 3);

    // This would work because AmplifiedSigmoidOperation
    // implements IPradOperation<Matrix, Matrix>
    op.SeedResult
        .Then<AmplifiedSigmoidOperation>(m1, m2, tensor, 1)
        .Then<MatrixMultiplyOperation>(m1, m2); // etc.
}
```
IntelliSense would then guide me to the correct types for m1 and m2.
To respond to a comment:
I understand that this wouldn't work for a class that implements the interface more than once, for example: class AmplifiedSigmoidOperation : Operation, IPradOperation<Matrix, Matrix>, IPradOperation<int, Vector>.
I would expect the compiler to reject T.Args[0] and T.Args[1] when they are ambiguous. In unambiguous cases, however, it would work and be more concise for users of the library.
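For comparison, C# can already infer interface type arguments when the operation is passed as an argument typed as the interface, rather than named as a type argument. The sketch below is illustrative only: Matrix, SigmoidLikeOp, InterfaceInferenceDemo, and Describe are hypothetical stand-ins, not the library's types. With a single implementation of IPradOperation<,>, the compiler infers T = Matrix and TR = Matrix from the argument. If the class also implemented IPradOperation<int, Vector>, TR could not be inferred and the call would fail with CS0411 unless the type arguments were supplied explicitly, which matches the ambiguity described above.

```csharp
using System;

// Hypothetical stand-in for the library's Matrix type.
public sealed class Matrix
{
    public Matrix(int rows, int cols) { Rows = rows; Cols = cols; }
    public int Rows { get; }
    public int Cols { get; }
}

// Marker interface, as in the question.
public interface IPradOperation<T, TR> { }

// Hypothetical operation with exactly one IPradOperation<,> implementation.
public sealed class SigmoidLikeOp : IPradOperation<Matrix, Matrix> { }

public static class InterfaceInferenceDemo
{
    // T and TR are inferred from the operation argument: the compiler looks for
    // the single IPradOperation<,> that the argument's type implements.
    public static string Describe<T, TR>(IPradOperation<T, TR> operation, T input)
    {
        return $"{typeof(T).Name} -> {typeof(TR).Name}";
    }
}
```

Usage: InterfaceInferenceDemo.Describe(new SigmoidLikeOp(), new Matrix(1, 4)) returns "Matrix -> Matrix" without any type arguments at the call site.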
Also, one way to reduce the ambiguity could, theoretically, be something like:
```csharp
public PradResult Then<T>(T.Args[0] p1, T.Args[1] p2, params object[] parameters)
    where T : IPradOperation<class, class>, new()
```

or

```csharp
public PradResult Then<T>(T.Args[0] p1, T.Args[1] p2, params object[] parameters)
    where T : IPradOperation<A1, A2>, new()
    where A1 : class
    where A2 : class
```
With that, a class could implement both IPradOperation<int, int> and IPradOperation<Matrix, Matrix> without ambiguity, since int is a value type and only the Matrix version would satisfy the class constraints.
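The closest form that compiles today moves A1 and A2 into the method's own type parameter list. This is a sketch with hypothetical types (Matrix, ScaleOp, ConstraintDemo). It does give compile-time checking of p1 and p2, but C# does not infer type arguments from constraints, and a call must supply either all type arguments or none. Because T never appears in the parameter list, callers still have to write Then<ScaleOp, Matrix, Matrix>(m1, m2), which is exactly the verbosity the question is trying to avoid.

```csharp
using System;

// Hypothetical stand-ins for the library types.
public sealed class Matrix { }

public interface IPradOperation<T, TR> { }

public sealed class ScaleOp : IPradOperation<Matrix, Matrix> { }

public static class ConstraintDemo
{
    // Legal version of the proposed syntax: A1 and A2 are explicit type parameters
    // tied to T through the constraint. T cannot be inferred from the arguments,
    // so every call must list all three type arguments.
    public static string Then<T, A1, A2>(A1 p1, A2 p2)
        where T : IPradOperation<A1, A2>, new()
        where A1 : class
        where A2 : class
    {
        var op = new T();
        return $"{op.GetType().Name}({typeof(A1).Name}, {typeof(A2).Name})";
    }
}
```

Usage: ConstraintDemo.Then<ScaleOp, Matrix, Matrix>(new Matrix(), new Matrix()) returns "ScaleOp(Matrix, Matrix)"; omitting the type arguments produces CS0411.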
It's similar in spirit to how C# has evolved to include features like tuple deconstruction or pattern matching, which enhance expressiveness and type safety.
I've come up with a workaround that I believe addresses the issues.
I used a generic base class that specifies the input and output types for each operation's forward pass, which provides compile-time type checking and reduces runtime errors.
It uses the Curiously Recurring Template Pattern (CRTP): each derived operation class inherits from the generic base class PradOperationBase and passes itself as the first type argument. Because Then takes its operation argument as a PradOperationBase<...>, the compiler infers the concrete operation type, along with its input and output types, from the argument's base class.
Because a class can have only one base class, each operation has exactly one set of input and output types, which eliminates the ambiguity that multiple implementations of an interface would cause.
The types are well-defined and inferred, so IntelliSense is accurate. The result is a clean, chainable, intuitive API with compile-time type checking.
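A minimal, self-contained sketch of why the inference works, assuming a simplified base class without OperationBase, IOperation, or Backward (Matrix, DoubleOperation, and CrtpInferenceDemo are hypothetical). The compiler finds the single PradOperationBase<,,> that the argument's type derives from and reads TOperation, TParamType, and TReturnType off it, so Forward can also be called with static types instead of through reflection.

```csharp
using System;

// Hypothetical stand-in for the library's Matrix type.
public sealed class Matrix
{
    public Matrix(double value) { Value = value; }
    public double Value { get; }
}

// Simplified one-input base class; the CRTP constraint ties TOperation to the subclass.
public abstract class PradOperationBase<TOperation, TParamType, TReturnType>
    where TOperation : PradOperationBase<TOperation, TParamType, TReturnType>
{
    public abstract TReturnType Forward(TParamType input);
}

// The derived class passes itself as TOperation.
public sealed class DoubleOperation : PradOperationBase<DoubleOperation, Matrix, Matrix>
{
    public override Matrix Forward(Matrix input) => new Matrix(input.Value * 2);
}

public static class CrtpInferenceDemo
{
    // All three type arguments are inferred from the operation's base class,
    // so callers write Apply(new DoubleOperation(), m) with no angle brackets.
    public static TReturnType Apply<TOperation, TParamType, TReturnType>(
        PradOperationBase<TOperation, TParamType, TReturnType> operation,
        TParamType input)
        where TOperation : PradOperationBase<TOperation, TParamType, TReturnType>
    {
        // Forward is statically typed here, so no reflection is needed.
        return operation.Forward(input);
    }
}
```

Usage: CrtpInferenceDemo.Apply(new DoubleOperation(), new Matrix(3)) returns a Matrix whose Value is 6, and passing anything other than a Matrix as the second argument is a compile-time error.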
```csharp
// Base class for operations with one input
public abstract class PradOperationBase<TOperation, TParamType, TReturnType> : OperationBase, IOperation
    where TOperation : PradOperationBase<TOperation, TParamType, TReturnType>
{
    public abstract TReturnType Forward(TParamType input);

    public abstract BackwardResult Backward(Matrix dOutput);
}

// Base class for operations with two inputs
public abstract class PradOperationBase<TOperation, TParam1Type, TParam2Type, TReturnType> : OperationBase, IOperation
    where TOperation : PradOperationBase<TOperation, TParam1Type, TParam2Type, TReturnType>
{
    public abstract TReturnType Forward(TParam1Type input1, TParam2Type input2);

    public abstract BackwardResult Backward(Matrix dOutput);
}

// Example operation implementations
public class AmplifiedSigmoidOperation : PradOperationBase<AmplifiedSigmoidOperation, Matrix, Matrix>
{
    // Implementation details...
}

public class MatrixMultiplyOperation : PradOperationBase<MatrixMultiplyOperation, Matrix, Matrix, Matrix>
{
    // Implementation details...
}

// PradResult class with Then methods
public class PradResult
{
    // Then method for operations with one input
    public PradResult Then<TOperation, TParamType, TReturnType>(
        PradOperationBase<TOperation, TParamType, TReturnType> operation)
        where TOperation : PradOperationBase<TOperation, TParamType, TReturnType>
    {
        // Implementation details...
    }

    // Then method for operations with two inputs
    public PradResult Then<TOperation, TParam1Type, TParam2Type, TReturnType>(
        PradOperationBase<TOperation, TParam1Type, TParam2Type, TReturnType> operation,
        TParam2Type param2)
        where TOperation : PradOperationBase<TOperation, TParam1Type, TParam2Type, TReturnType>
    {
        // Implementation details...
    }
}

// Usage example
[Fact]
public void TestThenGeneric()
{
    Tensor tensor = new Tensor(new int[] { 200, 300, 400 }, 5f);
    PradOp op = new PradOp(tensor);
    Matrix m1 = new Matrix(1, 4);
    Matrix m2 = new Matrix(2, 3);

    op.SeedResult
        .Then(new AmplifiedSigmoidOperation())
        .Then(new MatrixMultiplyOperation(), m1)
        .Then(new RMAD.LossOps.MeanSquaredErrorLossOperation(), m2);
}
```
Now a library user only needs to know the name of the operation; the IDE shows the constructor parameters and the types of the instances to pass in for the forward pass.
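I also made the PradOp class partial so that it can be extended with custom nested classes (all parts of a partial class must be compiled into the same assembly, so this applies to code built together with PradOp):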
```csharp
/// <summary>
/// The ops for the PradOp class.
/// </summary>
public partial class PradOp
{
    /// <summary>
    /// Operation types.
    /// </summary>
    public class Ops
    {
        /// <summary>
        /// Gets the add Gaussian noise op.
        /// </summary>
        public static Func<double, AddGaussianNoiseOperation> AddGaussianNoiseOp => (d) => new AddGaussianNoiseOperation(d);

        /// <summary>
        /// Gets the Amplified sigmoid op.
        /// </summary>
        public static Func<AmplifiedSigmoidOperation> AmplifiedSigmoidOp => () => new AmplifiedSigmoidOperation();

        ...
    }
}
```
This hides construction behind factory delegates, which makes the design more flexible.
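A sketch of the extension point, with hypothetical types: a second part of PradOp, compiled into the same assembly as the first, adds a CustomOps nested class whose factory is invoked the same way as PradOp.Ops.AddGaussianNoiseOp(6d). ClipOperation and CustomOps are illustrative names, not part of the library, and the operation classes are reduced to empty stand-ins.

```csharp
using System;

// Hypothetical stand-ins for operation types.
public sealed class AmplifiedSigmoidOperation { }

public sealed class ClipOperation
{
    public ClipOperation(double limit) { Limit = limit; }
    public double Limit { get; }
}

// Library-side part of PradOp (simplified).
public partial class PradOp
{
    public class Ops
    {
        public static Func<AmplifiedSigmoidOperation> AmplifiedSigmoidOp => () => new AmplifiedSigmoidOperation();
    }
}

// Second part of the same partial class, compiled into the same assembly,
// adding a hypothetical CustomOps nested class with its own factory.
public partial class PradOp
{
    public class CustomOps
    {
        public static Func<double, ClipOperation> ClipOp => limit => new ClipOperation(limit);
    }
}
```

Usage: PradOp.CustomOps.ClipOp(0.5) invokes the delegate returned by the property and yields a ClipOperation with Limit 0.5.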
Usage:
```csharp
[Fact]
public void TestThenGeneric()
{
    Tensor tensor = new Tensor(new int[] { 200, 300, 400 }, 5f);
    PradOp op = new PradOp(tensor);
    Matrix m1 = new Matrix(1, 4);
    Matrix m2 = new Matrix(2, 3);

    op.SeedResult
        .Then(PradOp.Ops.AmplifiedSigmoidOp)
        .Then(PradOp.Ops.AddGaussianNoiseOp(6d))
        .Then(PradOp.Ops.MatrixMultiplyOp, m1)
        .Then(PradOp.LossOps.MeanSquaredErrorLossOp, m2);
}
```
I added the following overloads so that Then also accepts a Func while still leveraging CRTP and automatic type inference:
```csharp
public PradResult Then<TOperation, TParamType, TReturnType>(
    Func<PradOperationBase<TOperation, TParamType, TReturnType>> opFunc)
    where TOperation : PradOperationBase<TOperation, TParamType, TReturnType>
{
    var operation = opFunc.Invoke();
    var opType = operation.GetType();
    IOperation op = operation;
    var forwardMethod = opType.GetMethod("Forward");
    return this.InnerThenGeneric(
        new object?[] { this.PradOp.GetCurrentTensor() },
        typeof(TReturnType),
        forwardMethod,
        op,
        (v) => new BackwardResult[] { op.Backward(v.ToMatrix()) });
}

public PradResult Then<TOperation, TParam1Type, TParam2Type, TReturnType>(
    Func<PradOperationBase<TOperation, TParam1Type, TParam2Type, TReturnType>> opFunc,
    TParam2Type param2)
    where TOperation : PradOperationBase<TOperation, TParam1Type, TParam2Type, TReturnType>
{
    var operation = opFunc.Invoke();
    var opType = operation.GetType();
    IOperation op = operation;
    var forwardMethod = opType.GetMethod("Forward");
    return this.InnerThenGeneric(
        new object?[] { this.PradOp.GetCurrentTensor(), param2 },
        typeof(TReturnType),
        forwardMethod,
        op,
        (v) => new BackwardResult[] { op.Backward(v.ToMatrix()) });
}
```
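The Func overloads infer their type arguments because Func<out TResult> is covariant: given a Func<NegateOperation>, the compiler infers through TResult to the operation's PradOperationBase<,,> base class. Since opFunc then returns a statically typed PradOperationBase, Forward could be called directly instead of through GetMethod("Forward"). The sketch below assumes a heavily simplified, hypothetical PradResult that only stores the current value as object; NegateOperation and the Value property are illustrative and do not reflect the library's InnerThenGeneric or backward-pass plumbing.

```csharp
using System;

// Hypothetical stand-in for the library's Matrix type.
public sealed class Matrix
{
    public Matrix(double value) { Value = value; }
    public double Value { get; }
}

public abstract class PradOperationBase<TOperation, TParamType, TReturnType>
    where TOperation : PradOperationBase<TOperation, TParamType, TReturnType>
{
    public abstract TReturnType Forward(TParamType input);
}

public sealed class NegateOperation : PradOperationBase<NegateOperation, Matrix, Matrix>
{
    public override Matrix Forward(Matrix input) => new Matrix(-input.Value);
}

// Hypothetical, heavily simplified PradResult that only tracks the current value.
public sealed class PradResult
{
    public PradResult(object value) { Value = value; }
    public object Value { get; }

    public PradResult Then<TOperation, TParamType, TReturnType>(
        Func<PradOperationBase<TOperation, TParamType, TReturnType>> opFunc)
        where TOperation : PradOperationBase<TOperation, TParamType, TReturnType>
    {
        var operation = opFunc();
        // Forward is bound at compile time, so no reflection lookup is needed.
        // The cast is the one runtime check left, because Value is stored as object.
        TReturnType output = operation.Forward((TParamType)Value);
        return new PradResult(output);
    }
}
```

Usage: with Func<NegateOperation> factory = () => new NegateOperation();, the call new PradResult(new Matrix(4)).Then(factory) infers all three type arguments through the Func's covariance and produces a result whose Value is a Matrix holding -4.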