
Big number computation in Swift for RSA implementation


I'm trying to implement an RSA algorithm in Swift for the CryptoSwift lib (to fix #63). The algorithm itself is working, but I need to improve big number computation performance for it to run in a reasonable time.

I implemented my own GiantUInt struct (storing the number as an array of UInt8 bytes, least significant byte first) to perform arithmetic on RSA-sized numbers (2048 bits long, for example), but it is too slow (mainly the remainder operation, though I think everything can be improved):

precedencegroup PowerPrecedence { higherThan: MultiplicationPrecedence }
infix operator ^^ : PowerPrecedence

public struct GiantUInt: Equatable, Comparable, ExpressibleByIntegerLiteral, ExpressibleByArrayLiteral {
  
  // Properties
  
  public let bytes: Array<UInt8>
  
  // Initialization
  
  public init(_ raw: Array<UInt8>) {
    var bytes = raw
    
    while bytes.last == 0 {
      bytes.removeLast()
    }
    
    self.bytes = bytes
  }
  
  // ExpressibleByIntegerLiteral
  
  public typealias IntegerLiteralType = UInt8
  
  public init(integerLiteral value: UInt8) {
    self = GiantUInt([value])
  }
  
  // ExpressibleByArrayLiteral
  
  public typealias ArrayLiteralElement = UInt8
  
  public init(arrayLiteral elements: UInt8...) {
    self = GiantUInt(elements)
  }
    
  // Equatable
  
  public static func == (lhs: GiantUInt, rhs: GiantUInt) -> Bool {
    lhs.bytes == rhs.bytes
  }
  
  // Comparable
  
  public static func < (lhs: GiantUInt, rhs: GiantUInt) -> Bool {
    // Compare from the most significant byte down; missing bytes count as 0
    for i in (0 ..< max(lhs.bytes.count, rhs.bytes.count)).reversed() {
      let l = lhs.bytes[safe: i] ?? 0
      let r = rhs.bytes[safe: i] ?? 0
      if l < r {
        return true
      } else if l > r {
        return false
      }
    }
    
    return false
  }
  
  // Operations
  
  public static func + (lhs: GiantUInt, rhs: GiantUInt) -> GiantUInt {
    var bytes = [UInt8]()
    var r: UInt8 = 0
    
    for i in 0 ..< max(rhs.bytes.count, lhs.bytes.count) {
      let res = UInt16(rhs.bytes[safe: i] ?? 0) + UInt16(lhs.bytes[safe: i] ?? 0) + UInt16(r)
      r = UInt8(res >> 8)
      bytes.append(UInt8(res & 0xff))
    }
    
    if r != 0 {
      bytes.append(r)
    }
    
    return GiantUInt(bytes)
  }
  
  public static func - (lhs: GiantUInt, rhs: GiantUInt) -> GiantUInt {
    var bytes = [UInt8]()
    var borrow: UInt8 = 0
    
    for i in 0 ..< max(lhs.bytes.count, rhs.bytes.count) {
      let l = UInt16(lhs.bytes[safe: i] ?? 0)
      let r = UInt16(rhs.bytes[safe: i] ?? 0) + UInt16(borrow)
      borrow = l < r ? 1 : 0
      // Borrow 256 from the next byte when needed, so the UInt16 math never underflows
      let res = (UInt16(borrow) << 8) + l - r
      bytes.append(UInt8(res & 0xff))
    }
    
    // An unsigned result cannot be negative: a leftover borrow means lhs < rhs
    precondition(borrow == 0, "GiantUInt subtraction underflow")
    
    return GiantUInt(bytes)
  }
  
  public static func * (lhs: GiantUInt, rhs: GiantUInt) -> GiantUInt {
    var offset = 0
    var sum = [GiantUInt]()
    
    for rbyte in rhs.bytes {
      var bytes = [UInt8](repeating: 0, count: offset)
      var r: UInt8 = 0
      
      for lbyte in lhs.bytes {
        let res = UInt16(rbyte) * UInt16(lbyte) + UInt16(r)
        r = UInt8(res >> 8)
        bytes.append(UInt8(res & 0xff))
      }
      
      if r != 0 {
        bytes.append(r)
      }
      
      sum.append(GiantUInt(bytes))
      offset += 1
    }
    
    return sum.reduce(0, +)
  }
  
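  public static func % (lhs: GiantUInt, rhs: GiantUInt) -> GiantUInt {
    // A zero divisor would make the loop below run forever
    precondition(rhs != 0, "GiantUInt division by zero")
    var remainder = lhs
    
    // This needs serious optimization (but works): one subtraction per multiple of rhs
    while remainder >= rhs {
      remainder = remainder - rhs
    }
  
    return remainder
  }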
  
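  static func ^^ (base: GiantUInt, exponent: GiantUInt) -> GiantUInt {
    var result = GiantUInt([1])
    
    // Left-to-right square-and-multiply must walk the exponent from its most
    // significant bit down; walking from the least significant bit (as before)
    // computes base^(bit-reversed exponent), e.g. 2 gave base^1 instead of base^2.
    // Leading zero bits only square 1, which is harmless.
    for byte in exponent.bytes.reversed() {
      for i in (0 ..< 8).reversed() {
        result = result * result
        if (byte >> i) & 1 == 1 {
          result = result * base
        }
      }
    }
    
    return result
  }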
  
  public static func exponentiateWithModulus(rhs base: GiantUInt, lhs exponent: GiantUInt, modulus: GiantUInt) -> GiantUInt {
    var result = GiantUInt([1])
    
    // Same most-significant-bit-first square-and-multiply as ^^,
    // reducing modulo `modulus` after every product to keep numbers small
    for byte in exponent.bytes.reversed() {
      for i in (0 ..< 8).reversed() {
        result = (result * result) % modulus
        if (byte >> i) & 1 == 1 {
          result = (result * base) % modulus
        }
      }
    }
    
    return result
  }
  
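}

// Bounds-checked subscript used by the operators above: returns nil instead of
// trapping when the index is out of range. If your project already defines
// subscript(safe:), drop this copy.
extension Array {
  subscript(safe index: Int) -> Element? {
    indices.contains(index) ? self[index] : nil
  }
}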

(this file is available here on my fork)

How can I improve it to make it (a lot) quicker?


Solution

  • How can I improve it to make it (a lot) quicker?

    Don't use bytes

    Performance depends on the number of "digits" in the big numbers, so more, smaller digits is worse and fewer, larger digits is better. For example, multiplying a pair of 2048-bit numbers with bytes as digits takes "256 digits * 256 digits = 65536 multiplications of digits", while with 64-bit integers as digits it takes "32 digits * 32 digits = 1024 multiplications of digits", which is 64 times fewer.
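    A minimal Swift sketch of multiplication with 64-bit digits ("limbs"), assuming a number is stored as a little-endian [UInt64] array (index 0 holds the least significant 64 bits). That representation and the name multiply are assumptions for illustration, not the question's GiantUInt. It relies on the standard library's multipliedFullWidth(by:), which returns the full 128-bit product of two UInt64 values as (high, low), and addingReportingOverflow(_:), which reports the carry:

```swift
/// Schoolbook multiplication of little-endian base-2^64 numbers
/// (a[0] / b[0] are the least significant 64-bit limbs).
func multiply(_ a: [UInt64], _ b: [UInt64]) -> [UInt64] {
    // An n-limb times an m-limb number always fits in n + m limbs,
    // so the result is allocated once, up front.
    var result = [UInt64](repeating: 0, count: a.count + b.count)
    for i in 0..<a.count {
        var carry: UInt64 = 0
        for j in 0..<b.count {
            // Full 128-bit product of two limbs.
            let (high, low) = a[i].multipliedFullWidth(by: b[j])
            // Add the existing result limb and the running carry to the low half.
            let (s1, o1) = low.addingReportingOverflow(result[i + j])
            let (s2, o2) = s1.addingReportingOverflow(carry)
            result[i + j] = s2
            // a*b + r + c <= 2^128 - 1, so the new carry always fits in 64 bits.
            carry = high + (o1 ? 1 : 0) + (o2 ? 1 : 0)
        }
        result[i + b.count] = carry
    }
    return result
}
```

    For two 2048-bit inputs (32 limbs each) the inner loop body runs 32 * 32 = 1024 times.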

    Prefer destructive operations

    With big numbers, something like "a = b + c" makes the CPU deal with three sets of cache lines (for a, b and c), whereas something like "a += b" only deals with two. For very large numbers this can be the difference between "it all fits in cache" and "performance is ruined by cache misses".
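    As a sketch of a destructive operation, again assuming a hypothetical little-endian [UInt64] representation (64-bit limbs, index 0 least significant), subtraction can overwrite its left operand in place, so only two arrays are touched and no result array is allocated. The name subtractInPlace is illustrative:

```swift
/// Destructive subtraction `a -= b` on little-endian [UInt64] limbs.
/// Only `a` and `b` are read and only `a` is written; no result array is allocated.
/// Requires a >= b and b.count <= a.count (unsigned values can't go negative).
func subtractInPlace(_ a: inout [UInt64], _ b: [UInt64]) {
    precondition(b.count <= a.count, "b has more limbs than a")
    var borrow: UInt64 = 0
    for i in 0..<a.count {
        let bi: UInt64 = i < b.count ? b[i] : 0
        // At most one of these two subtractions can wrap around.
        let (d1, o1) = a[i].subtractingReportingOverflow(bi)
        let (d2, o2) = d1.subtractingReportingOverflow(borrow)
        a[i] = d2
        borrow = (o1 || o2) ? 1 : 0
    }
    precondition(borrow == 0, "subtraction underflow: a < b")
}
```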

    Don't use append

    Calls like bytes.append(r) involve a buffer capacity check and possibly a reallocation of the underlying buffer. That overhead is unnecessary and avoidable: you can determine the size of the result in advance (at most one digit longer than the longer operand for addition, and the sum of the two operands' lengths for multiplication), create an array of the correct size up front, and then calculate the result without capacity checks or resizing.
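    A sketch of addition written this way, assuming hypothetical little-endian [UInt64] limbs (index 0 least significant); the function name add is illustrative. Swift array subscripts are still bounds-checked unless you go through something like withUnsafeMutableBufferPointer, but the capacity checks and reallocations of append are gone:

```swift
/// Addition of little-endian [UInt64] limbs without `append`.
/// The sum of an n-limb and an m-limb number needs at most max(n, m) + 1 limbs,
/// so the result is allocated once at that size and filled by index.
func add(_ a: [UInt64], _ b: [UInt64]) -> [UInt64] {
    let (longer, shorter) = a.count >= b.count ? (a, b) : (b, a)
    var result = [UInt64](repeating: 0, count: longer.count + 1)
    var carry: UInt64 = 0
    for i in 0..<longer.count {
        let s: UInt64 = i < shorter.count ? shorter[i] : 0
        let (s1, o1) = longer[i].addingReportingOverflow(s)
        let (s2, o2) = s1.addingReportingOverflow(carry)
        result[i] = s2
        carry = (o1 || o2) ? 1 : 0
    }
    // The top limb is 0 when there is no final carry; trim it if you keep numbers normalized.
    result[longer.count] = carry
    return result
}
```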

    Don't use multiplication for squaring

    For squaring, the number of "multiplications of digits" can be almost halved by relying on the fact that both numbers are the same number. To understand this, assume you're doing 4321 * 4321 in decimal, write its digits least significant first (1, 2, 3, 4), and express the intermediate values as a grid like this:

         1        2         3           4
        --------------
     1 | 1*1    + 2*10    + 3*100     + 4*1000 +
     2 | 2*10   + 4*100   + 6*1000    + 8*10000 +
     3 | 3*100  + 6*1000  + 9*10000   + 12*100000 +
     4 | 4*1000 + 8*10000 + 12*100000 + 16*1000000
    

    You can see that the top right "not quite half" of the grid is a mirror of the bottom left "not quite half", so you could do this instead:

         1          2           3              4
        --------------
     1 | 1*1
     2 | (2*10)*2 + 4*100
     3 | (3*100   + 6*1000)*2 + 9*10000
     4 | (4*1000  + 8*10000   + 12*100000)*2 + 16*1000000
    

    Of course this can be rearranged so that the *2 only happens once, like "result = (2*10 + 3*100 + 6*1000 + 4*1000 + 8*10000 + 12*100000) * 2 + 1*1 + 4*100 + 9*10000 + 16*1000000".
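    The same idea on hypothetical little-endian [UInt64] limbs (64-bit digits, index 0 least significant) might look like the following sketch; the function name square is illustrative. Each cross product a[i] * a[j] with i < j is computed once, all cross products are doubled with a single one-bit left shift, and the diagonal squares are added last, giving n(n-1)/2 + n limb multiplications instead of n * n:

```swift
/// Squares a little-endian [UInt64] number (a[0] is the least significant limb).
func square(_ a: [UInt64]) -> [UInt64] {
    let n = a.count
    var result = [UInt64](repeating: 0, count: 2 * n)

    // 1. Accumulate each cross product a[i] * a[j] with i < j exactly once.
    for i in 0..<n {
        var carry: UInt64 = 0
        for j in (i + 1)..<n {
            let (high, low) = a[i].multipliedFullWidth(by: a[j])
            let (s1, o1) = low.addingReportingOverflow(result[i + j])
            let (s2, o2) = s1.addingReportingOverflow(carry)
            result[i + j] = s2
            carry = high + (o1 ? 1 : 0) + (o2 ? 1 : 0)  // cannot overflow
        }
        result[i + n] = carry
    }

    // 2. Double all cross products at once with a one-bit left shift.
    var topBit: UInt64 = 0
    for k in 0..<(2 * n) {
        let next = result[k] >> 63
        result[k] = (result[k] << 1) | topBit
        topBit = next
    }

    // 3. Add the diagonal squares a[i] * a[i], whose low half belongs at limb 2i.
    var carry: UInt64 = 0
    for i in 0..<n {
        let (high, low) = a[i].multipliedFullWidth(by: a[i])
        let (s1, o1) = result[2 * i].addingReportingOverflow(low)
        let (s2, o2) = s1.addingReportingOverflow(carry)
        result[2 * i] = s2
        let mid: UInt64 = (o1 ? 1 : 0) + (o2 ? 1 : 0)
        let (s3, o3) = result[2 * i + 1].addingReportingOverflow(high)
        let (s4, o4) = s3.addingReportingOverflow(mid)
        result[2 * i + 1] = s4
        carry = (o3 ? 1 : 0) + (o4 ? 1 : 0)
    }
    return result
}
```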

    Your modulo can/should be improved

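    One approach is to shift the divisor left until it's larger than the numerator (keeping track of the shift count), then do "while shift count isn't zero { right shift divisor; if divisor is not larger than numerator, subtract divisor from numerator; decrease shift count }". This costs one shift, one comparison and at most one subtraction per bit of the quotient, whereas your current loop does one subtraction per multiple of the divisor (which, when reducing a 4096-bit product modulo a 2048-bit modulus, can be on the order of 2^2048 subtractions).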
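    A sketch of that shift-and-subtract approach, assuming hypothetical little-endian [UInt64] limbs (64-bit digits, index 0 least significant); all helper names are illustrative. Shifts are done one bit at a time to keep it short; a faster version would shift by whole limbs where possible and would trim leading zero limbs from the result:

```swift
/// Compares two little-endian [UInt64] numbers; missing high limbs count as zero.
/// Returns -1, 0 or 1.
func compare(_ a: [UInt64], _ b: [UInt64]) -> Int {
    for i in stride(from: max(a.count, b.count) - 1, through: 0, by: -1) {
        let x: UInt64 = i < a.count ? a[i] : 0
        let y: UInt64 = i < b.count ? b[i] : 0
        if x != y { return x < y ? -1 : 1 }
    }
    return 0
}

/// Shifts left by one bit in place; the caller must leave a spare zero limb on top.
func shiftLeft1(_ a: inout [UInt64]) {
    var carry: UInt64 = 0
    for i in 0..<a.count {
        let next = a[i] >> 63
        a[i] = (a[i] << 1) | carry
        carry = next
    }
    assert(carry == 0, "no spare limb left")
}

/// Shifts right by one bit in place.
func shiftRight1(_ a: inout [UInt64]) {
    var carry: UInt64 = 0
    for i in stride(from: a.count - 1, through: 0, by: -1) {
        let next = a[i] & 1
        a[i] = (a[i] >> 1) | (carry << 63)
        carry = next
    }
}

/// a -= b in place; the caller guarantees a >= b, so limbs of b beyond a.count are zero.
func subtractInPlace(_ a: inout [UInt64], _ b: [UInt64]) {
    var borrow: UInt64 = 0
    for i in 0..<a.count {
        let bi: UInt64 = i < b.count ? b[i] : 0
        let (d1, o1) = a[i].subtractingReportingOverflow(bi)
        let (d2, o2) = d1.subtractingReportingOverflow(borrow)
        a[i] = d2
        borrow = (o1 || o2) ? 1 : 0
    }
}

/// numerator % divisor by shift-and-subtract.
func bigRemainder(_ numerator: [UInt64], _ divisor: [UInt64]) -> [UInt64] {
    precondition(divisor.contains { $0 != 0 }, "division by zero")
    var n = numerator
    // One spare limb lets the divisor be shifted just past the numerator without growing.
    let width = max(numerator.count, divisor.count) + 1
    var d = divisor + [UInt64](repeating: 0, count: width - divisor.count)
    var shift = 0
    // Shift the divisor left until it is larger than the numerator.
    while compare(d, n) <= 0 {
        shiftLeft1(&d)
        shift += 1
    }
    // Undo the shifts one at a time, subtracting whenever the divisor fits.
    while shift != 0 {
        shiftRight1(&d)
        if compare(d, n) <= 0 {
            subtractInPlace(&n, d)
        }
        shift -= 1
    }
    return n
}
```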

    Don't use Swift or High Level Languages

    Most CPUs have special instructions that make working with big numbers much more efficient, and even basic things (e.g. "add with carry") can't be expressed directly in most high level languages. The consequence is that using the best digit size (e.g. 64-bit digits on 64-bit CPUs) is painful, so you end up implementing algorithms differently, and then the compiler can't optimize properly because the algorithm is different.

    The best performance can only be achieved by using assembly language (e.g. a native library that can be used from Swift code). You can see this difference clearly if you compare (e.g.) the GMP library (which uses lots of hand-written assembly language) against mini-GMP (which doesn't use assembly language and merely aims to be "not more than 10 times slower for numbers up to a few hundred bits").