
Generic square root in Swift


I'm building a generic vector class in Swift that supports three element types: Float, Double and Int. This works so far, but I run into an issue when I try to calculate the length of a vector.

The length of a vector is the square root of (x² + y²). But since my vector class is generic, x and y are of the generic type T.

Swift's sqrt function accepts a Double as its argument, not a value of a generic type.

Is there any way to use the sqrt function with generic parameters?

Here is a snippet of the code I use for the vector length and the dot product:

protocol ArithmeticType {
    func + (left: Self, right: Self) -> Self
    func - (left: Self, right: Self) -> Self
    func * (left: Self, right: Self) -> Self
    func / (left: Self, right: Self) -> Self
    prefix func - (left: Self) -> Self

    func toDouble() -> Double
}

extension Double: ArithmeticType {
    func toDouble() -> Double {
        return Double(self)
    }
}

extension Float: ArithmeticType {
    func toDouble() -> Double {
        return Double(self)
    }
}

extension Int: ArithmeticType {
    func toDouble() -> Double {
        return Double(self)
    }
}

class Vector<T where T: ArithmeticType, T: Comparable> {
    var length: T { return sqrt((self ⋅ self).toDouble()) }
}

infix operator ⋅ { associativity left }
func ⋅<T: ArithmeticType> (left: Vector<T>, right: Vector<T>) -> T {
    var result: T? = nil

    for (index, value) in enumerate(left.values) {
        let additive = value * right.values[index]

        if result == nil {
            result = additive
        } else if let oldResult = result {
            result = oldResult + additive
        }
    }

    if let unwrappedResult = result {
        return unwrappedResult
    }
}

Solution

  • I see that you're using a custom ArithmeticType protocol to constrain the generic type.

    My approach would be to declare two required methods in that protocol, toDouble() and fromDouble(), and implement both in the Float, Double and Int extensions. Note that fromDouble() should be a static method, since it creates a new T rather than operating on an existing one.

    This way you can convert T to Double, call sqrt() on it, and convert the result back from Double to T.

    Lastly, there's a bug in your code: if left is an empty vector, the loop body never executes, so result keeps its initial value of nil. The function then has nothing valid to return: force-unwrapping result would crash at runtime, and the if let in the code shown in the question leaves a path with no return statement, which does not compile.
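
The round trip described above can be sketched roughly as follows. This is a minimal sketch, not the asker's actual class: it assumes current Swift syntax rather than the older syntax used in the question, trims the protocol down to the two operators the example needs, uses a struct with a hypothetical dot(_:) method as a stand-in for the ⋅ operator, and starts the sum at zero so that an empty vector has a well-defined result. The truncating conversion chosen for Int is also an assumption.

```swift
import Foundation

// Trimmed-down version of the question's protocol, with the two
// conversions suggested in the answer added as requirements.
protocol ArithmeticType {
    static func + (left: Self, right: Self) -> Self
    static func * (left: Self, right: Self) -> Self

    func toDouble() -> Double
    static func fromDouble(_ value: Double) -> Self
}

extension Double: ArithmeticType {
    func toDouble() -> Double { return self }
    static func fromDouble(_ value: Double) -> Double { return value }
}

extension Float: ArithmeticType {
    func toDouble() -> Double { return Double(self) }
    static func fromDouble(_ value: Double) -> Float { return Float(value) }
}

extension Int: ArithmeticType {
    func toDouble() -> Double { return Double(self) }
    // Assumption: truncate toward zero when converting back to Int.
    static func fromDouble(_ value: Double) -> Int { return Int(value) }
}

// Hypothetical stand-in for the question's Vector class.
struct Vector<T: ArithmeticType> {
    var values: [T]

    // Dot product. Starting from zero means an empty vector yields zero
    // instead of leaving an optional result unset.
    func dot(_ other: Vector<T>) -> T {
        var result = T.fromDouble(0)
        for (value, otherValue) in zip(values, other.values) {
            result = result + value * otherValue
        }
        return result
    }

    // T -> Double, sqrt, then Double -> T.
    var length: T {
        return T.fromDouble(sqrt(dot(self).toDouble()))
    }
}

let doubleVector = Vector<Double>(values: [3, 4])
print(doubleVector.length)  // prints 5.0

let intVector = Vector<Int>(values: [3, 4])
print(intVector.length)     // prints 5
```

One consequence of this design is that the length of an Int vector is truncated to a whole number, so a vector such as [1, 1] reports a length of 1. If that is not acceptable, having length return Double instead of T may be the simpler choice.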