Search code examples
clinuxcudamultiprecision

multi-precision multiplication in CUDA


I am trying to implement multi-precision multiplication in CUDA. For doing that, I have implemented a kernel which should compute multiplication of uint32_t type operand with 256-bit operand and put the result in 288-bit array. So far, I have came up with this code:

__device__ __constant__ UN_256fe B_const;

 __global__ void multiply32x256Kernel(uint32_t A, UN_288bite* result){

uint8_t tid = blockIdx.x * blockDim.x + threadIdx.x;
//for managing warps
//uint8_t laineid = tid % 32; 
//allocate partial products into array of uint64_t 
__shared__ uint64_t partialMuls[8];
uint32_t carry, r;
if((tid < 8) && (tid != 0)){
    //compute partial products
    partialMuls[tid] = A * B_const.uint32[tid];

    //add partial products and propagate carry
    result->uint32[8] = (uint32_t)partialMuls[7];
    r = (partialMuls[tid] >> 32) + ((uint32_t)partialMuls[tid - 1]);
    carry = r < (partialMuls[tid] >> 32);
    result->uint32[0] = (partialMuls[0] >> 32);
    while(__any(carry)){

        r = r + carry;
        //new carry?        
        carry = r < carry;  
    } 
result->uint32[tid] = r;

}

and my data-type is :

typedef struct UN_256fe{

uint32_t uint32[8];

}UN_256fe;

typedef struct UN_288bite{

uint32_t uint32[9];

}UN_288bite;

My kernel works, but it gives me wrong result. I cannot debug inside the kernel, so I would appreciate if someone let me know where the problem is or how I can debug my code inside the kernel on tegra-ubuntu with cuda-6.0. Thanks


Solution

  • This answer has nothing to do with CUDA itself, but is a general C implementation.

    I can't quite follow what you are doing (especially with carry) but you could try this snippet based on my own big num functions. I defined dtype to make it easier to test with smaller fields. Note that I don't specifically use a carry, but carry forward the partial product.

    // little-endian
    #include <stdio.h>
    #include <stdint.h>
    #include <limits.h>
    
    #define dtype uint8_t           // for testing
    //#define dtype uint32_t        // for proper ver
    
    #define SHIFTS (sizeof(dtype)*CHAR_BIT)
    #define NIBBLES (SHIFTS/4)
    #define ARRLEN 8
    
    typedef struct UN_256fe {
        dtype uint[ARRLEN];
    } UN_256fe;
    
    typedef struct UN_288bite {
        dtype uint[ARRLEN+1];
    } UN_288bite;
    
    void multiply(UN_288bite *product, UN_256fe *operand, dtype multiplier)
    {
        int i;
        uint64_t partial = 0;
        for (i=0; i<ARRLEN; i++) {
            partial = partial + (uint64_t)multiplier * operand->uint[i];
            product->uint[i] = (dtype)partial;
            partial >>= SHIFTS;                     // carry
        }
        product->uint[i] = (dtype)partial;
    }
    
    int main(void)
    {
        int i;
        dtype multiplier = 0xAA;
        UN_256fe operand = { 1, 2, 3, 4, 5, 6, 7, 8};
        UN_288bite product;
    
        multiply(&product, &operand, multiplier);
    
        for(i=ARRLEN-1; i>=0; i--)
            printf("%0*X", NIBBLES, operand.uint[i]);
        printf("\n * %0*X = \n", NIBBLES, multiplier);
        for(i=ARRLEN; i>=0; i--)
            printf("%0*X", NIBBLES, product.uint[i]);
        printf("\n");
    
        return 0;
    }
    

    Program output for uint8_t

    0807060504030201
     * AA =
    0554A9FF54A9FF54AA