
How do I pass a generic function as an argument to another function in golang?


How do I modify the transformNumbers function so that it works with the generic functions doubleG and tripleG?

type transformFn func(int) int

func transformNumbers(numbers *[]int, transform transformFn) []int {
    dNumbers := []int{}
    for _, value := range *numbers {
        dNumbers = append(dNumbers, transform(value))
    }
    return dNumbers
}

func double(value int) int {
    return 2 * value
}

func triple(value int) int {
    return 3 * value
}

func doubleG[T int | float64 | float32](value T) T {
    return 2 * value
}

func tripleG[T int | float64 | float32](value T) T {
    return 3 * value
}

I am confused about the transformFn type.

I tried something like this:

func transformNumbers(numbers *[]int, transform func[T int|float64|float32](T)T) []int {
    dNumbers := []int{}
    for _, value := range *numbers {
        dNumbers = append(dNumbers, transform(value))
    }
    return dNumbers
}

but I get an error.



Solution

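  • Go function types cannot have type parameters, so func[T int|float64|float32](T)T is not a valid parameter type. A generic function has to be instantiated (given concrete type arguments) before it can be used as a value. If you want to pass a generic function to another function without instantiating it yourself, the receiving function must also be generic. Also, don't pass a pointer to a slice; that's unnecessary, because a slice value is already a header holding a pointer to a backing array. See Are slices passed by value?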

    For example:

    func transformNumbers[T int | float64 | float32](numbers []T, transform func(T) T) []T {
        dNumbers := []T{}
        for _, value := range numbers {
            dNumbers = append(dNumbers, transform(value))
        }
        return dNumbers
    }
    

    Testing it:

    fmt.Println(transformNumbers([]int{1, 2, 3}, doubleG))
    fmt.Println(transformNumbers([]float32{1, 2, 3}, doubleG))
    fmt.Println(transformNumbers([]float64{1, 2, 3}, tripleG))
    

    This will output (try it on the Go Playground):

    [2 4 6]
    [2 4 6]
    [3 6 9]
    

    You can of course create a transformFn type for the transformer functions, which must also be generic:

    type transformFn[T int | float64 | float32] func(T) T
    
    func transformNumbers[T int | float64 | float32](numbers []T, transform transformFn[T]) []T {
        dNumbers := []T{}
        for _, value := range numbers {
            dNumbers = append(dNumbers, transform(value))
        }
        return dNumbers
    }
    

    Usage is the same; try this one on the Go Playground.

    You can also create a separate type for the constraint so you don't have to repeat the allowed types:

    type allowedTypes interface {
        int | float64 | float32
    }
    
    type transformFn[T allowedTypes] func(T) T
    
    func transformNumbers[T allowedTypes](numbers []T, transform transformFn[T]) []T {
        dNumbers := []T{}
        for _, value := range numbers {
            dNumbers = append(dNumbers, transform(value))
        }
        return dNumbers
    }
    

    Also note that preallocating the result slice and assigning each transformed value by index (instead of calling append()) is faster, because the result never has to grow, so no reallocation and copying happens along the way. So do it like this:

    func transformNumbers[T int | float64 | float32](numbers []T, transform func(T) T) []T {
        dNumbers := make([]T, len(numbers))
        for i, value := range numbers {
            dNumbers[i] = transform(value)
        }
        return dNumbers
    }