Tags: c++, security, floating-point, numerical-methods, sgx

Side-Channel-Resistant Math Functions for C++


I'm working on an SGX project processing secret data, and at some point, I need to evaluate the natural logarithm of a floating point number. The evaluation process should be side-channel-resistant, meaning its running time and memory access patterns are to be independent of its input and output.

Is there such an implementation out there in the wild? Has the problem been addressed in the literature?


Solution

  • The sgx tag suggests that your hardware platform is a recent x86-64 Intel CPU that also supports AVX2 and the FMA operation. The key to an implementation whose run time and memory access pattern are invariant is the avoidance of branches. If the compiler cooperates and converts simple conditional assignments into appropriate conditional-move or blend instructions, the implementation of my_logf() below should work fine. However, relying on compiler code generation is brittle: of the various compilers offered by Compiler Explorer, I could only get clang to deliver something close to the desired result, with all branches converted except one (the conditional assignment a = t in the handling of denormal inputs).

    So you will likely have to do manual work to enforce result selection through appropriate instructions instead of branchy code, for example by using intrinsics.
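    As an illustration of selection through intrinsics, here is a minimal sketch assuming an x86-64 target, where SSE is part of the baseline. The helper names select_lt() and fixup_denormal() are hypothetical. A comparison intrinsic produces an all-ones or all-zeros mask, which then combines the two candidates with bitwise AND, ANDNOT and OR, so no branch is needed. Compilers are still free to transform intrinsics, so the generated code should be inspected.

```cpp
#include <xmmintrin.h>  // SSE intrinsics (baseline on x86-64)

// Hypothetical helper: branch-free (a < b) ? x : y.
// _mm_cmplt_ss sets the low lane to all-ones if a < b, else to all-zeros
// (all-zeros also if either operand is NaN); the mask then picks x or y.
static inline float select_lt (float a, float b, float x, float y)
{
    __m128 mask = _mm_cmplt_ss (_mm_set_ss (a), _mm_set_ss (b));
    __m128 res  = _mm_or_ps (_mm_and_ps (mask, _mm_set_ss (x)),      // x where mask set
                             _mm_andnot_ps (mask, _mm_set_ss (y)));  // y where mask clear
    return _mm_cvtss_f32 (res);
}

// Hypothetical branch-free version of the denormal fix-up in my_logf():
//     t = a * 8388608.0f;  if (a < 1.17549435e-38f) a = t;
static inline float fixup_denormal (float a)
{
    float t = a * 8388608.0f;  // a * 2^23, always computed
    return select_lt (a, 1.17549435e-38f, t, a);
}
```

    With these, the conditional assignment a = t could become a = fixup_denormal (a); the accompanying i = -23.0f assignment would need the same treatment.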

    As EOF pointed out in the comments, eliminating branches is a necessary but not sufficient condition, as individual floating-point operations can also have variable run time, even if they are just additions, multiplications, and FMAs. This is not a problem on architectures that handle special operands such as subnormals (often called denormals) at full speed, e.g. GPUs. However, it has been an issue on the x86 processors I have worked with and worked on. Usually, the most severe variability occurs due to denormal results, with a much smaller impact from denormal source operands.
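    To make the operand classes concrete: a binary32 value is subnormal when its exponent field is zero and its mantissa field is nonzero. The sketch below (the helper name is hypothetical) tests this directly on the bit pattern, which is equivalent to std::fpclassify (x) == FP_SUBNORMAL when denormals are not flushed to zero. Note that a product of two normal numbers, such as 1e-20f * 1e-20f, already yields a subnormal result.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper: true if x is a binary32 subnormal (denormal),
// i.e. exponent field == 0 and mantissa field != 0. Zero is not subnormal.
static inline bool is_subnormal_bits (float x)
{
    uint32_t u;
    std::memcpy (&u, &x, sizeof u);  // reinterpret the bit pattern
    return ((u & 0x7f800000u) == 0) && ((u & 0x007fffffu) != 0);
}
```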

    The code shown below contains multiple operations that use the original argument a as a source operand, which exposes it to the risk of runtime variation due to denormal inputs. Whether potential variations in runtime rise above the noise level (for example, due to variability in pipeline state at the point the function is called) should be carefully tested on the specific platform(s) where one intends to deploy the code.
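    One rough starting point for such testing is to time many calls with normal versus subnormal arguments. The template below is a minimal sketch, not a rigorous methodology, and avg_ns_per_call() is a hypothetical name. It uses std::chrono::steady_clock, and a volatile input and sink so the calls are neither hoisted out of the loop nor optimized away. Real measurements would need many repetitions, a pinned core, and a look at the whole distribution rather than a single mean.

```cpp
#include <chrono>
#include <cstddef>

// Hypothetical timing helper: average nanoseconds per call of f(x) over reps calls.
template <typename F>
double avg_ns_per_call (F f, float x, std::size_t reps)
{
    volatile float in = x;       // re-read every iteration, so f(in) cannot be hoisted
    volatile float sink = 0.0f;  // stored every iteration, so calls are not removed
    auto start = std::chrono::steady_clock::now ();
    for (std::size_t k = 0; k < reps; k++) {
        sink = f (in);
    }
    auto stop = std::chrono::steady_clock::now ();
    (void)sink;
    std::chrono::duration<double, std::nano> elapsed = stop - start;
    return elapsed.count () / (double)reps;
}
```

    For example, one might compare avg_ns_per_call (my_logf, 1.0f, 10000000) against avg_ns_per_call (my_logf, 1.0e-40f, 10000000).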

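    #include <cstdint>
    #include <cstring>
    #include <cmath>
    
    /* reinterpret bit pattern of IEEE-754 binary32 as a 32-bit signed integer */
    int32_t __float_as_int (float a)
    {
        int32_t r;
        std::memcpy (&r, &a, sizeof r);
        return r;
    }
    
    /* reinterpret bit pattern of a 32-bit signed integer as IEEE-754 binary32 */
    float __int_as_float (int32_t a)
    {
        float r;
        std::memcpy (&r, &a, sizeof r);
        return r;
    }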
    
    /* maximum error 0.85417 ulp */
    float my_logf (float a)
    {
        float m, r, s, t, i, f, u;
        int32_t e;
    
        /* result for exceptional cases */
        u = a + a;  // silence NaNs if necessary
        if (a  < 0.0f) u =  0.0f / 0.0f; //  NaN
        if (a == 0.0f) u = -1.0f / 0.0f; // -Inf
        
        /* result for non-exceptional cases */
        i = 0.0f;
    
        /* fix up denormal input if needed */
        t = a * 8388608.0f;
        if (a < 1.17549435e-38f) {
            a = t;
            i = -23.0f;
        }
    
        /* split argument into exponent and mantissa parts */
        e = (__float_as_int (a) - 0x3f2aaaab) & 0xff800000;
        m = __int_as_float (__float_as_int (a) - e);
        i = fmaf ((float)e, 1.19209290e-7f, i);
    
        /* m in [2/3, 4/3] */
        f = m - 1.0f;
        s = f * f;
        /* Compute log1p(f) for f in [-1/3, 1/3] */
        r =             -0.130310059f; 
        t =              0.140869141f; 
        r = fmaf (r, s, -0.121489234f);
        t = fmaf (t, s,  0.139809728f);
        r = fmaf (r, s, -0.166844666f);
        t = fmaf (t, s,  0.200121239f);
        r = fmaf (r, s, -0.249996305f);
        r = fmaf (t, f, r);
        r = fmaf (r, f,  0.333331943f);
        r = fmaf (r, f, -0.500000000f);
        r = fmaf (r, s, f);
        r = fmaf (i, 0.693147182f, r); // log(2) 
    
        /* late selection between exceptional and non-exceptional result */
        if (!((a > 0.0f) && (a <= 3.40282347e+38f))) r = u;
    
        return r;
    }
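    The exponent/mantissa split in the code above may look opaque. The following sketch restates it in isolation, assuming a positive normal argument; LogSplit and split_log_arg() are hypothetical names. 0x3f2aaaab is the bit pattern of 0.666666687f (about 2/3). Subtracting it before masking moves the exponent boundary from 1 to 2/3, so the mantissa part lands in [2/3, 4/3] and f = m - 1 lies in [-1/3, 1/3], the interval on which the polynomial approximates log1p(f).

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper type: for a positive normal float a, a == m * 2^e.
struct LogSplit {
    float   m;  // mantissa part, in [2/3, 4/3]
    int32_t e;  // exponent part
};

// Restates the split used in my_logf(). Only valid for positive normal a.
static inline LogSplit split_log_arg (float a)
{
    int32_t ia;
    std::memcpy (&ia, &a, sizeof ia);
    // exponent-field bits of (bits(a) - bits(2/3)); the low 23 bits are zero
    int32_t e = (int32_t)((uint32_t)(ia - 0x3f2aaaab) & 0xff800000u);
    int32_t im = ia - e;        // remove that power of two from a
    LogSplit s;
    std::memcpy (&s.m, &im, sizeof s.m);
    s.e = e / 8388608;          // exact: e is a multiple of 2^23
    return s;
}
```

    For instance, 1.5f splits into m = 0.75f and e = 1 rather than m = 1.5f and e = 0.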
    

    The potential issues identified above can be addressed by performing both the special-case handling in the logarithm computation and the result selection with portable integer-based code. The obvious trade-off is a loss of performance. Handling of denormal arguments requires normalization based on a count of leading zeros (CLZ). While x86 processors have instructions for this (BSR, and LZCNT on newer CPUs), they might not be accessible in a portable fashion from C++. But a portable implementation with invariant runtime can be constructed in a straightforward way. This leads to a branchless implementation that I would expect to work well with most compilers, but double-checking the generated machine code will be essential. I used Compiler Explorer to verify that it compiles as desired with gcc 11.1 and clang 12.0.1.
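    To show what CLZ-based normalization of a denormal argument amounts to, here is a minimal sketch using the GCC/Clang builtin __builtin_clz, i.e. exactly the kind of compiler-specific access that the portable clz() in the listing below avoids; normalize_subnormal() is a hypothetical name. Shifting the leading 1 up to bit 23 and adjusting the exponent field yields the bit pattern of a * 2^23, the same value the multiplication by 8388608.0f produces in the first version.

```cpp
#include <cstdint>

// Hypothetical helper: normalize a positive subnormal binary32 bit pattern.
// Returns the bit pattern of a * 2^23. Precondition: 0 < bits < 0x00800000.
// __builtin_clz is a GCC/Clang extension (undefined for 0), not portable C++.
static inline uint32_t normalize_subnormal (uint32_t bits)
{
    int shift = __builtin_clz (bits) - 8;  // moves the leading 1 to bit 23
    // the shifted leading 1 contributes 1 to the exponent field;
    // (23 - shift) supplies the rest of the exponent adjustment
    return (bits << shift) + ((uint32_t)(23 - shift) << 23);
}
```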

    #include <stdio.h>
    #include <stdlib.h>
    #include <stdint.h>
    #include <string.h>
    #include <math.h>
    
    /* reinterpret bit pattern of IEEE-754 binary32 as a 32-bit signed integer */ 
    int32_t float_as_int32 (float a)
    {
        int32_t r;
        memcpy (&r, &a, sizeof r);
        return r;
    }
    
    /* reinterpret bit pattern of a 32-bit signed integer as IEEE-754 binary32 */
    float int32_as_float (int32_t a)
    {
        float r; 
        memcpy (&r, &a, sizeof r); 
        return r;
    }
    
    /* branch free implementation of ((cond) ? a : b). cond must be in {0,1} */
    int32_t mux (int cond, int32_t a, int32_t b)
    {
        return (1 - cond) * b + cond * a;
    }
    
    /* leading zero count with invariant runtime */
    int clz (uint32_t a)
    {
        // Algorithm by aqrit, https://stackoverflow.com/a/58827596/780717
        int n = 158 - (((uint32_t)float_as_int32 ((float)(int32_t)(a & ~(a >> 1)))) >> 23);
        n = mux (n < 0, 0, n);   // clamp below
        n = mux (n > 32, 32, n); // clamp above
        return n;
    }
    
    /* Compute natural logarithm with a maximum error of 0.85089 ulp */
    float my_logf (float a)
    {
        float m, r, s, t, i, f;
        int32_t e, ia, ii, it, iu, im, shift, excp;
        const int32_t abs_mask   = 0x7fffffffu;
        const int32_t qnan_bit   = 0x00400000u;
        const int32_t pos_infty  = 0x7f800000u;
        const int32_t neg_infty  = 0xff800000u;
        const int32_t indefinite = 0xffc00000u;
        const int32_t zero_float = 0x00000000u; // 0.0f
        const int32_t one_float  = 0x3f800000u; // 1.0f
        const int32_t tiny_float = 0x00800000u; // 1.17549435e-38f
        const int32_t huge_float = 0x7f7fffffu; // 3.40282347e+38f 
    
        ia = float_as_int32 (a);
    
        /* result for exceptional cases */
        iu = mux ((int32_t)ia < 0, indefinite, ia); // return QNaN INDEFINITE
        iu = mux ((ia & abs_mask) == 0, neg_infty, iu); // return -Inf
        iu = mux ((ia & abs_mask) > pos_infty, ia | qnan_bit, iu); // convert to QNaN
    
        /* result for non-exceptional cases */
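        /* normalize denormal input. The normalized value only matters for
           positive denormals; for all other inputs it is either not selected
           or overridden by the exceptional-case result. The shift count is
           masked and the shifts use unsigned arithmetic so that they remain
           well-defined for every input (clz - 8 is negative for normals) */
        shift = (clz (ia) - 8) & 31;
        it = (int32_t)(((uint32_t)ia << shift) + ((uint32_t)(23 - shift) << 23));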
        ii = mux (ia < tiny_float, -23, 0);
        it = mux (ia < tiny_float, it, ia);
    
        /* split argument into exponent and mantissa parts */
        e = (it - 0x3f2aaaab) & 0xff800000;
        m = int32_as_float (it - e);
        i = fmaf ((float)e, 1.19209290e-7f, (float)ii);
    
        /* m in [2/3, 4/3] */
        f = m - 1.0f;
        s = f * f;
    
        /* Compute log1p(f) for f in [-1/3, 1/3] */
        r =             -0.130310059f; 
        t =              0.140869141f; 
        r = fmaf (r, s, -0.121483363f);
        t = fmaf (t, s,  0.139814854f);
        r = fmaf (r, s, -0.166846141f);
        t = fmaf (t, s,  0.200120345f);
        r = fmaf (r, s, -0.249996200f);
        r = fmaf (t, f, r);
        r = fmaf (r, f,  0.333331972f);
        r = fmaf (r, f, -0.500000000f);
        r = fmaf (r, s, f);
        r = fmaf (i, 0.693147182f, r); // log(2) 
    
        /* late selection between exceptional and non-exceptional result */
        excp = ((uint32_t)ia - 1) > ((uint32_t)huge_float - 1);
        iu = mux (excp, iu, zero_float);
        im = mux (excp, zero_float, one_float);
        r = fmaf (int32_as_float (im), r, int32_as_float (iu));
    
        return r;
    }
    
    /* reinterpret bit pattern of IEEE-754 binary64 as a 64-bit unsigned integer */ 
    uint64_t double_as_uint64 (double a)
    {
        uint64_t r;
        memcpy (&r, &a, sizeof r);
        return r;
    }
    
    double floatUlpErr (float res, double ref)
    {
        uint64_t i, j, err;
        int expoRef;
        
        /* ulp error cannot be computed if either operand is NaN, infinity, zero */
        if (isnan (res) || isnan (ref) || isinf (res) || isinf (ref) ||
            (res == 0.0f) || (ref == 0.0f)) {
            return 0.0;
        }
        /* Convert the float result to an "extended float". This is like a float
           with 56 instead of 24 effective mantissa bits
        */
        i = ((uint64_t)(uint32_t)float_as_int32 (res)) << 32;
        /* Convert the double reference to an "extended float". If the reference is
           >= 2^129, we need to clamp to the maximum "extended float". If reference
           is < 2^-126, we need to denormalize because of float's limited exponent
           range.
        */
        expoRef = (int)(((double_as_uint64(ref) >> 52) & 0x7ff) - 1023);
        if (expoRef >= 129) {
            j = (double_as_uint64(ref) & 0x8000000000000000ULL) |
                0x7fffffffffffffffULL;
        } else if (expoRef < -126) {
            j = ((double_as_uint64(ref) << 11) | 0x8000000000000000ULL) >> 8;
            j = j >> (-(expoRef + 126));
            j = j | (double_as_uint64(ref) & 0x8000000000000000ULL);
        } else {
            j = ((double_as_uint64(ref) << 11) & 0x7fffffffffffffffULL) >> 8;
            j = j | ((uint64_t)(expoRef + 127) << 55);
            j = j | (double_as_uint64(ref) & 0x8000000000000000ULL);
        }
        err = (i < j) ? (j - i) : (i - j);
        return err / 4294967296.0;
    }
    
    int main (void)
    {
        uint32_t diff, refi, resi, argi = 0;
        float reff, res, arg;
        double ref, ulp, maxulp = 0;
    
        do {
            arg = int32_as_float (argi);
            ref = log ((double)arg);
            reff = (float)ref;
            res = my_logf (arg);
            ulp = floatUlpErr (res, ref);
            if (ulp > maxulp) maxulp = ulp;
            resi = float_as_int32 (res);
            refi = float_as_int32 (reff);
            diff = (resi > refi) ? (resi - refi) : (refi - resi);
            if (diff > 1) {
                printf ("error: arg=%15.6a res=%15.6a ref=%15.6a\n", arg, res, ref);
                return EXIT_FAILURE;
            }
            argi++;
        } while (argi);
        printf ("maximum ulp error: %.5f\n", maxulp);
        return EXIT_SUCCESS;
    }
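    The constant-time clz() above (algorithm by aqrit) is the least obvious part of the listing. Below it is restated with a slow loop-based reference, clz_ref(), for cross-checking; both function names are hypothetical, and the clamps use plain ternaries for brevity where the answer uses its branch-free mux(). Why it works: a & ~(a >> 1) keeps the leading 1 bit and clears the bit directly below it, so the int-to-float conversion cannot round up to the next power of two. The float's exponent field is then 127 + (31 - clz), hence clz = 158 - field. a == 0 gives field 0 (n = 158, clamped to 32), and a >= 2^31 turns into a negative int32_t whose sign bit makes n negative (clamped to 0).

```cpp
#include <cstdint>
#include <cstring>

// Restatement of the float-conversion CLZ trick from the listing above.
static int clz_float_trick (uint32_t a)
{
    float f = (float)(int32_t)(a & ~(a >> 1));  // exact power-of-two exponent
    uint32_t bits;
    std::memcpy (&bits, &f, sizeof bits);
    int n = 158 - (int)(bits >> 23);            // 158 = 127 + 31
    n = n < 0 ? 0 : n;                          // answer uses mux() here
    n = n > 32 ? 32 : n;
    return n;
}

// Slow reference: count zero bits from the top until the first 1 bit.
static int clz_ref (uint32_t a)
{
    int n = 0;
    for (uint32_t mask = 0x80000000u; mask != 0 && !(a & mask); mask >>= 1) {
        n++;
    }
    return n;
}
```

    Looping clz_float_trick() against clz_ref() over a large range of inputs is a cheap way to gain confidence in the trick on a given compiler.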