Search code examples
go aws-sdk-go

Generic function with variable input/output types


Just playing with the AWS SDK for Go. When listing resources of different types, I tend to end up with a lot of very similar functions, like the two in the example below. Is there a way to rewrite them as one generic function that returns a specific type depending on what is passed in as a parameter?

Something like:

func generic(session, funcToCall, t, input) (interface{}, error) {}

Currently I have to do this (the functionality is the same; only the types change):

// getVolumes collects the volumes from every page returned by
// DescribeVolumes, using a client created from session s.
//
// Each response's NextToken is copied into the request for the following
// call; paging ends at the first response whose NextToken is nil. If any
// call fails, the error is returned together with a nil slice.
func getVolumes(s *session.Session) ([]*ec2.Volume, error) {
    svc := ec2.New(s)
    req := &ec2.DescribeVolumesInput{}
    volumes := make([]*ec2.Volume, 0)

    for {
        page, err := svc.DescribeVolumes(req)
        if err != nil {
            return nil, err
        }
        volumes = append(volumes, page.Volumes...)

        // A nil token marks the last page.
        if page.NextToken == nil {
            return volumes, nil
        }
        req.NextToken = page.NextToken
    }
}

// getVpcs collects the VPCs from every page returned by DescribeVpcs,
// using a client created from session s.
//
// Each response's NextToken is copied into the request for the following
// call; paging ends at the first response whose NextToken is nil. If any
// call fails, the error is returned together with a nil slice.
func getVpcs(s *session.Session) ([]*ec2.Vpc, error) {
    svc := ec2.New(s)
    req := &ec2.DescribeVpcsInput{}
    vpcs := make([]*ec2.Vpc, 0)

    for {
        page, err := svc.DescribeVpcs(req)
        if err != nil {
            return nil, err
        }
        vpcs = append(vpcs, page.Vpcs...)

        // A nil token marks the last page.
        if page.NextToken == nil {
            return vpcs, nil
        }
        req.NextToken = page.NextToken
    }
}

Solution

  • Because you only deal with functions, it is possible to use the reflect package to generate functions at runtime.

    Using the object type (Volume, Vpc), it is possible to derive all the subsequent information and provide a fully generic implementation that is really DRY, at the expense of being more complex and slower.

    It is untested (you are welcome to help test and fix it), but something like this should put you on the right track:

    https://play.golang.org/p/mGjtYVG2OZS

    The registry idea comes from this answer: https://stackoverflow.com/a/23031445/4466350

    For reference, the Go documentation of the reflect package is at https://golang.org/pkg/reflect/

    package main
    
    import (
        "errors"
        "fmt"
        "reflect"
    )
    
    func main() {
        fmt.Printf("%T\n", getter(Volume{}))
        fmt.Printf("%T\n", getter(Vpc{}))
    }
    
    // Empty placeholder types used by this example; presumably stand-ins for
    // the similarly named aws-sdk-go types (TODO confirm).
    type (
        // DescribeVolumesInput is the placeholder request type for volumes.
        DescribeVolumesInput struct{}
        // DescribeVpcs is the placeholder request type for VPCs.
        // NOTE(review): the name lacks the "Input" suffix that
        // DescribeVolumesInput carries; confirm whether that is intended.
        DescribeVpcs struct{}

        // Volume is a placeholder resource type.
        Volume struct{}
        // Vpc is a placeholder resource type.
        Vpc struct{}

        // Session is the placeholder argument type accepted by New.
        Session struct{}

        // Client is the placeholder type returned by New.
        Client struct{}
    )
    
    // New returns a zero-value Client. The session argument is not used by
    // this stub.
    func New(s *Session) Client {
        var c Client
        return c
    }
    
    // typeRegistry maps a type's printed name, exactly as the %T verb renders
    // it (package-qualified, e.g. "main.DescribeVolumesInput"), to its
    // reflect.Type.
    var typeRegistry = map[string]reflect.Type{}

    // init fills typeRegistry with the request placeholder types.
    func init() {
        register := func(sample interface{}) {
            key := fmt.Sprintf("%T", sample)
            typeRegistry[key] = reflect.TypeOf(sample)
        }
        register(DescribeVolumesInput{})
        register(DescribeVpcs{})
    }
    
    // Reflection helpers for the built-in error interface.
    var (
        // errV is a variable of static type error; its address is used to
        // reach the interface type rather than a concrete error type.
        errV = errors.New("")
        // errType is the reflect.Type of the error interface itself.
        errType = reflect.TypeOf(&errV).Elem()
        // zeroErr is the zero (nil) value of the error interface type.
        zeroErr = reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())
        // nilErr is a one-element result list holding zeroErr.
        nilErr = []reflect.Value{zeroErr}
    )
    
    // getter builds, via reflection, a function of type
    // func(*Session) ([]*T, error), where T is the dynamic type of `of`, and
    // returns it as an interface{} for the caller to type-assert.
    //
    // The generated function obtains a client with New, looks up the client
    // method named "Describe<T>s" and calls it repeatedly. After each call it
    // appends the output's "<T>s" field to the result and copies the output's
    // NextToken into the input; it stops at the first output whose NextToken
    // is missing or nil. Lookup failures, unsupported shapes and errors
    // returned by the method are reported through the error result together
    // with a nil slice.
    //
    // `of` must be a non-nil, non-pointer value of a named struct type, for
    // example Volume{}.
    func getter(of interface{}) interface{} {
        elemType := reflect.TypeOf(of)
        // Name() is the bare type name ("Volume"). The %T verb would yield the
        // package-qualified "main.Volume", which can never match a method or
        // field name.
        plural := elemType.Name() + "s"
        methodName := "Describe" + plural

        errorType := reflect.TypeOf((*error)(nil)).Elem()
        outType := reflect.SliceOf(reflect.PtrTo(elemType))
        fnType := reflect.FuncOf(
            []reflect.Type{reflect.TypeOf(new(Session))},
            []reflect.Type{outType, errorType},
            false,
        )

        // fail builds the (nil slice, err) result list. The error is stored in
        // a Value of the error interface type so that the result types match
        // fnType exactly instead of carrying err's concrete type.
        fail := func(err error) []reflect.Value {
            boxed := reflect.New(errorType).Elem()
            boxed.Set(reflect.ValueOf(err))
            return []reflect.Value{reflect.Zero(outType), boxed}
        }

        fnBody := func(input []reflect.Value) []reflect.Value {
            client := reflect.ValueOf(New).Call(input)[0]

            meth := client.MethodByName(methodName)
            if !meth.IsValid() {
                return fail(fmt.Errorf("no such method %q", methodName))
            }
            sig := meth.Type()
            if sig.NumIn() != 1 || sig.In(0).Kind() != reflect.Ptr ||
                sig.In(0).Elem().Kind() != reflect.Struct ||
                sig.NumOut() != 2 || sig.Out(1) != errorType {
                return fail(fmt.Errorf("method %q has unsupported signature %v", methodName, sig))
            }

            // The input type is taken from the method's own signature, so no
            // name-based type lookup is needed.
            descInput := reflect.New(sig.In(0).Elem())
            inToken := descInput.Elem().FieldByName("NextToken")

            all := reflect.MakeSlice(outType, 0, 0)
            for {
                out := meth.Call([]reflect.Value{descInput})
                if errOut := out[1]; !errOut.IsNil() {
                    return []reflect.Value{reflect.Zero(outType), errOut}
                }

                // out[0] is the output; dereference it before field access
                // because FieldByName only works on struct values.
                result := reflect.Indirect(out[0])
                if result.Kind() != reflect.Struct {
                    return fail(fmt.Errorf("method %q returned no usable output", methodName))
                }

                items := result.FieldByName(plural)
                if !items.IsValid() {
                    return fail(fmt.Errorf("field not found %q", plural))
                }
                if items.Type() != outType {
                    return fail(fmt.Errorf("field %q has type %v, want %v", plural, items.Type(), outType))
                }
                all = reflect.AppendSlice(all, items)

                // A field that merely exists is not enough: a nil token marks
                // the last page, and looping on it would never terminate.
                next := result.FieldByName("NextToken")
                if !next.IsValid() || next.Kind() != reflect.Ptr || next.IsNil() {
                    break
                }
                if !inToken.IsValid() || !inToken.CanSet() || !next.Type().AssignableTo(inToken.Type()) {
                    return fail(fmt.Errorf("cannot pass NextToken to the input of %q", methodName))
                }
                inToken.Set(next)
            }
            return []reflect.Value{all, reflect.Zero(errorType)}
        }
        return reflect.MakeFunc(fnType, fnBody).Interface()
    }