I want to use a SASS instruction which (AFAICT) is not available via a PTX instruction as of CUDA 12.4. Namely, suppose it is HMMA.16816.F16 - a warp-wide matrix-multiply-and-add of half-precision data, with shape M=16, N=8, K=16 (IIANM).
The CUDA PTX ISA guide for CUDA 12.4 indicates in Section 9.7.13.3 that at FP16 precision, we only have PTX WMMA instructions with shape (M,N,K) being one of (16, 16, 16) or (32, 8, 16) or (8, 32, 16) - nothing smaller. But Section 9.7.13.1 says that smaller matrix shapes - (16, 8, 16), (16, 8, 8) and (8, 8, 4) - are supported.
Trying to use the intrinsics corresponding to these smaller shapes, e.g. __hmma_m16n8k16_ld_a, results in an error:
mma-smaller.hpp(86): error: identifier "__hmma_m16n8k16_ld_a" is undefined
__hmma_m16n8k16_ld_a((int*)&a, (const int*)p, ldm, 0);
^
So are these shapes supported in PTX, or are they not?
Note: I'm using an Ampere GPU.
TL;DR: You can issue such a SASS instruction through an appropriate choice of PTX-level mma instruction (not wmma), but AFAIK there is no corresponding documented C++ intrinsic for it at this time.
Longer: Let's start with some general background to disentangle some of these ideas. The mma class of instructions exists primarily to exercise the tensorcore units, which provide hardware-accelerated matrix-matrix multiply operations.
At the PTX level, these instructions come in several flavors: wmma.mma, mma, and wgmma.mma_async.
The wmma.mma instructions are distinguished by the fact that they also have corresponding matrix load and store instructions - they do not expose the per-thread register storage footprint directly. The mma instructions, on the other hand, take PTX register input/output directly.
At the CUDA C++ level, only wmma style operations are documented - that is a subset of possible tensorcore ops, and that subset corresponds to the PTX wmma.mma instructions, and that subset is also distinguished by the fact that matrix load/store functions are used, not direct register manipulation.
The documented C++ intrinsics look like wmma::mma_sync(...). There are no intrinsics documented in the C++ programming guide that look like __hmma_m16n8k16_ld_a.
Does PTX (8.3) not cover smaller-shape WMMA instructions?
Yes, you can issue a 16x8x16 (M,N,K) 16-bit floating point tensorcore op using PTX. It cannot be directly done using a (documented) C++ intrinsic, and in PTX I wouldn't use a wmma.mma instruction for it, I would use this mma PTX instruction - mma.m16n8k16. A detailed description for this with PTX register layout is here. An instruction-skeleton example is given here. The "target ISA notes" section later in that link provides hardware support info. Of note:
.f16 floating point type mma operation with .m16n8k16 shape requires sm_80 or higher.
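To make the register layout mentioned above a little more concrete, here is a host-only C++ sketch of which matrix element each lane's fragment slot refers to for the .f16 m16n8k16 shape. The names Coord, a_coord, b_coord, cd_coord and covered are made up for illustration, and the formulas are my reading of the fragment tables in the PTX ISA guide, so check them against the guide before relying on them.

```cpp
#include <cassert>
#include <set>
#include <utility>

// Hypothetical helper type: a (row, col) position inside a matrix tile.
struct Coord { int row; int col; };

// A is 16x16 (M x K). Each of the 32 lanes holds 8 halves, a0..a7.
Coord a_coord(int lane, int i) {
    int group = lane >> 2;   // "groupID" in the PTX ISA tables
    int tig   = lane % 4;    // "threadID_in_group"
    int row = ((i < 2) || (i >= 4 && i < 6)) ? group : group + 8;
    int col = tig * 2 + (i & 1) + (i >= 4 ? 8 : 0);
    return {row, col};
}

// B is 16x8 (K x N). Each lane holds 4 halves, b0..b3.
Coord b_coord(int lane, int i) {
    int group = lane >> 2;
    int tig   = lane % 4;
    int row = tig * 2 + (i & 1) + (i >= 2 ? 8 : 0);
    return {row, group};
}

// C and D are 16x8 (M x N). Each lane holds 4 values, c0..c3.
Coord cd_coord(int lane, int i) {
    int group = lane >> 2;
    int tig   = lane % 4;
    int row = (i < 2) ? group : group + 8;
    return {row, tig * 2 + (i & 1)};
}

// Counts the distinct elements touched by 32 lanes x `slots` fragment slots.
template <typename F>
int covered(F f, int slots) {
    std::set<std::pair<int, int>> seen;
    for (int lane = 0; lane < 32; ++lane)
        for (int i = 0; i < slots; ++i) {
            Coord c = f(lane, i);
            seen.insert({c.row, c.col});
        }
    return static_cast<int>(seen.size());
}
```

Under this mapping, a lane with lane%4 == 0 holds column 0 of A in slots a0 (row lane/4) and a2 (row lane/4 + 8), and holds row 1 of B in slot b1. Every element of each tile is owned by exactly one lane and slot.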
Here is a complete example (a modification of what I depicted here):
# cat t153.cu
#include <mma.h>
#include <cuda_fp16.h>
#include <cstdlib>
#include <iostream>
#include <stdio.h>
__global__ void mma_fp16_acc_fp32(float *out) {
  // per-thread fragments: 8 halves of A, 4 halves of B, 4 floats each of C and D
  float c[4] = {0., 0., 0., 0.};
  float d[4] = {0., 0., 0., 0.};
  half a[8] = {1., 1., 1., 1., 1., 1., 1., 1.};
  half b[4] = {1., 1., 1., 1.};
  // the above would set our input matrices to all 1
  // now let's modify some values
  if (threadIdx.x%4 == 0) {
    // set the first column of A to be 0, 1, 2, 3, ... 15
    a[0] = threadIdx.x/4; a[2] = threadIdx.x/4 + 8;
    // set the second row of B to 3,3,3, ... 3
    b[1] = 3;
  }
  // view the half arrays as 32-bit registers, each holding two packed halves
  unsigned const *A = reinterpret_cast<unsigned const *>(&a);
  unsigned const *B = reinterpret_cast<unsigned const *>(&b);
  float const *C = reinterpret_cast<float const *>(&c);
  float *D = reinterpret_cast<float *>(&d);
  asm(
    "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
    "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
    : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
    :
      "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
      "r"(B[0]), "r"(B[1]),
      "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
  );
  // out is row-major 16x8: D[0], D[1] are two adjacent floats of row threadIdx.x/4,
  // D[2], D[3] are the same columns of row threadIdx.x/4 + 8
  memcpy(out+threadIdx.x*2, D, 8);
  memcpy(out+8*8+threadIdx.x*2, D+2, 8);
}
int main() {
  float* h_C = (float*)malloc(16*8*sizeof(float));
  float* d_C;
  cudaMalloc(&d_C, 16*8*sizeof(float));
  mma_fp16_acc_fp32<<<1, 32>>>(d_C);
  cudaDeviceSynchronize();
  cudaMemcpy(h_C, d_C, 16*8*sizeof(float), cudaMemcpyDeviceToHost);
  // print the 16x8 result matrix, one row per line
  for (int i = 0; i < 16; i++){
    for (int j = 0; j < 8; j++) std::cout << h_C[i*8+j] << " ";
    std::cout << std::endl;
  }
}
# nvcc -o t153 t153.cu -arch=sm_89
# compute-sanitizer ./t153
========= COMPUTE-SANITIZER
17 17 17 17 17 17 17 17
18 18 18 18 18 18 18 18
19 19 19 19 19 19 19 19
20 20 20 20 20 20 20 20
21 21 21 21 21 21 21 21
22 22 22 22 22 22 22 22
23 23 23 23 23 23 23 23
24 24 24 24 24 24 24 24
25 25 25 25 25 25 25 25
26 26 26 26 26 26 26 26
27 27 27 27 27 27 27 27
28 28 28 28 28 28 28 28
29 29 29 29 29 29 29 29
30 30 30 30 30 30 30 30
31 31 31 31 31 31 31 31
32 32 32 32 32 32 32 32
========= ERROR SUMMARY: 0 errors
#
As indicated in the link I provided, these tensorcore ops compute:
D = A*B+C
In the above example, I have chosen to use/declare A and B as 16-bit floating point, whereas C and D are 32-bit floating point.
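As a sanity check on the printed result, the same D = A*B+C can be reproduced on the host with plain loops. This is only a CPU sketch with made-up names (MatD, reference_d). It rebuilds the inputs described in the kernel comments (A all ones except a first column of 0..15, B all ones except a second row of 3, C all zeros) and does not touch the GPU.

```cpp
#include <array>
#include <cassert>

constexpr int M = 16, N = 8, K = 16;
using MatD = std::array<std::array<float, N>, M>;

// Host reference for D = A*B + C with the inputs used by the example kernel.
MatD reference_d() {
    float A[M][K];
    float B[K][N];
    for (int i = 0; i < M; ++i)
        for (int k = 0; k < K; ++k)
            A[i][k] = (k == 0) ? static_cast<float>(i) : 1.0f;  // first column = row index
    for (int k = 0; k < K; ++k)
        for (int j = 0; j < N; ++j)
            B[k][j] = (k == 1) ? 3.0f : 1.0f;                   // second row = 3
    MatD D{};
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
            float acc = 0.0f;                                   // C is all zeros
            for (int k = 0; k < K; ++k) acc += A[i][k] * B[k][j];
            D[i][j] = acc;
        }
    return D;
}
```

Each entry of row i works out to i*1 + 1*3 + 14*1 = i + 17, which agrees with the 17 through 32 rows printed by the kernel.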
If we disassemble the above built code, we observe the following, indicating the SASS level tensorcore op in use:
# cuobjdump -sass ./t153
Fatbin elf code:
================
arch = sm_89
code version = [1,7]
host = linux
compile_size = 64bit
code for sm_89
Fatbin elf code:
================
arch = sm_89
code version = [1,7]
host = linux
compile_size = 64bit
code for sm_89
Function : _Z17mma_fp16_acc_fp32Pf
.headerflags @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM89 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM89)"
/*0000*/ IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] ; /* 0x00000a00ff017624 */
/* 0x000fc400078e00ff */
/*0010*/ S2R R9, SR_TID.X ; /* 0x0000000000097919 */
/* 0x000e220000002100 */
/*0020*/ HADD2 R3, -RZ.H0_H0, 1, 1 ; /* 0x3c003c00ff037430 */
/* 0x000fe20000000900 */
/*0030*/ IMAD.MOV.U32 R5, RZ, RZ, 0x3c00 ; /* 0x00003c00ff057424 */
/* 0x000fe200078e00ff */
/*0040*/ ULDC.64 UR4, c[0x0][0x118] ; /* 0x0000460000047ab9 */
/* 0x000fe20000000a00 */
/*0050*/ IMAD.MOV.U32 R14, RZ, RZ, 0x3c003c00 ; /* 0x3c003c00ff0e7424 */
/* 0x000fe400078e00ff */
/*0060*/ PRMT R0, R3.reuse, 0x7610, R0 ; /* 0x0000761003007816 */
/* 0x040fe20000000000 */
/*0070*/ IMAD.MOV.U32 R15, RZ, RZ, 0x3c003c00 ; /* 0x3c003c00ff0f7424 */
/* 0x000fe200078e00ff */
/*0080*/ PRMT R2, R3, 0x7610, R2 ; /* 0x0000761003027816 */
/* 0x000fe40000000002 */
/*0090*/ LOP3.LUT P0, RZ, R9, 0x3, RZ, 0xc0, !PT ; /* 0x0000000309ff7812 */
/* 0x001fda000780c0ff */
/*00a0*/ @!P0 SHF.R.U32.HI R4, RZ, 0x2, R9 ; /* 0x00000002ff048819 */
/* 0x000fe20000011609 */
/*00b0*/ @!P0 I2F.F16 R3, 0x3 ; /* 0x0000000300038906 */
/* 0x000fe60000200c00 */
/*00c0*/ @!P0 IADD3 R8, R4, 0x8, RZ ; /* 0x0000000804088810 */
/* 0x000fca0007ffe0ff */
/*00d0*/ @!P0 I2F.F16.U32 R0, R4 ; /* 0x0000000400008306 */
/* 0x000e300000200800 */
/*00e0*/ @!P0 I2F.F16.U32 R2, R8 ; /* 0x0000000800028306 */
/* 0x000e620000200800 */
/*00f0*/ PRMT R0, R0, 0x5410, R5 ; /* 0x0000541000007816 */
/* 0x001fe20000000005 */
/*0100*/ IMAD.MOV.U32 R5, RZ, RZ, 0x3c003c00 ; /* 0x3c003c00ff057424 */
/* 0x000fc600078e00ff */
/*0110*/ PRMT R4, R0, 0x5432, R3 ; /* 0x0000543200047816 */
/* 0x000fe20000000003 */
/*0120*/ IMAD.MOV.U32 R12, RZ, RZ, R0.reuse ; /* 0x000000ffff0c7224 */
/* 0x100fe400078e0000 */
/*0130*/ IMAD.MOV.U32 R3, RZ, RZ, 0x4 ; /* 0x00000004ff037424 */
/* 0x000fe200078e00ff */
/*0140*/ PRMT R13, R2, 0x7610, R0 ; /* 0x00007610020d7816 */
/* 0x002fe20000000000 */
/*0150*/ IMAD.SHL.U32 R2, R9, 0x2, RZ ; /* 0x0000000209027824 */
/* 0x000fc800078e00ff */
/*0160*/ IMAD.WIDE.U32 R2, R2, R3, c[0x0][0x160] ; /* 0x0000580002027625 */
/* 0x000fe400078e0003 */
/*0170*/ HMMA.16816.F32 R4, R12, R4, RZ ; /* 0x000000040c04723c */
/* 0x000f5e00000018ff */
/*0180*/ NOP ; /* 0x0000000000007918 */
/* 0x000fd00000000000 */
/*0190*/ STG.E.U8 [R2.64+0x4], R5 ; /* 0x0000040502007986 */
/* 0x0201e2000c101104 */
/*01a0*/ SHF.R.U32.HI R15, RZ, 0x18, R5.reuse ; /* 0x00000018ff0f7819 */
/* 0x100fe40000011605 */
/*01b0*/ SHF.R.U32.HI R17, RZ, 0x10, R5.reuse ; /* 0x00000010ff117819 */
/* 0x100fe20000011605 */
/*01c0*/ STG.E.U8 [R2.64], R4 ; /* 0x0000000402007986 */
/* 0x000fe2000c101104 */
/*01d0*/ SHF.R.U32.HI R19, RZ, 0x8, R5 ; /* 0x00000008ff137819 */
/* 0x000fe40000011605 */
/*01e0*/ SHF.R.U32.HI R9, RZ, 0x18, R4.reuse ; /* 0x00000018ff097819 */
/* 0x100fe20000011604 */
/*01f0*/ STG.E.U8 [R2.64+0x100], R6 ; /* 0x0001000602007986 */
/* 0x000fe2000c101104 */
/*0200*/ SHF.R.U32.HI R11, RZ, 0x10, R4.reuse ; /* 0x00000010ff0b7819 */
/* 0x100fe40000011604 */
/*0210*/ SHF.R.U32.HI R13, RZ, 0x8, R4 ; /* 0x00000008ff0d7819 */
/* 0x000fe20000011604 */
/*0220*/ STG.E.U8 [R2.64+0x104], R7 ; /* 0x0001040702007986 */
/* 0x000fe2000c101104 */
/*0230*/ SHF.R.U32.HI R21, RZ, 0x18, R6 ; /* 0x00000018ff157819 */
/* 0x000fc40000011606 */
/*0240*/ SHF.R.U32.HI R23, RZ, 0x10, R6.reuse ; /* 0x00000010ff177819 */
/* 0x100fe20000011606 */
/*0250*/ STG.E.U8 [R2.64+0x3], R9 ; /* 0x0000030902007986 */
/* 0x000fe2000c101104 */
/*0260*/ SHF.R.U32.HI R25, RZ, 0x8, R6 ; /* 0x00000008ff197819 */
/* 0x000fe40000011606 */
/*0270*/ SHF.R.U32.HI R27, RZ, 0x18, R7.reuse ; /* 0x00000018ff1b7819 */
/* 0x100fe20000011607 */
/*0280*/ STG.E.U8 [R2.64+0x2], R11 ; /* 0x0000020b02007986 */
/* 0x000fe2000c101104 */
/*0290*/ SHF.R.U32.HI R29, RZ, 0x10, R7.reuse ; /* 0x00000010ff1d7819 */
/* 0x100fe40000011607 */
/*02a0*/ SHF.R.U32.HI R5, RZ, 0x8, R7 ; /* 0x00000008ff057819 */
/* 0x001fe20000011607 */
/*02b0*/ STG.E.U8 [R2.64+0x1], R13 ; /* 0x0000010d02007986 */
/* 0x000fe8000c101104 */
/*02c0*/ STG.E.U8 [R2.64+0x7], R15 ; /* 0x0000070f02007986 */
/* 0x000fe8000c101104 */
/*02d0*/ STG.E.U8 [R2.64+0x6], R17 ; /* 0x0000061102007986 */
/* 0x000fe8000c101104 */
/*02e0*/ STG.E.U8 [R2.64+0x5], R19 ; /* 0x0000051302007986 */
/* 0x000fe8000c101104 */
/*02f0*/ STG.E.U8 [R2.64+0x103], R21 ; /* 0x0001031502007986 */
/* 0x000fe8000c101104 */
/*0300*/ STG.E.U8 [R2.64+0x102], R23 ; /* 0x0001021702007986 */
/* 0x000fe8000c101104 */
/*0310*/ STG.E.U8 [R2.64+0x101], R25 ; /* 0x0001011902007986 */
/* 0x000fe8000c101104 */
/*0320*/ STG.E.U8 [R2.64+0x107], R27 ; /* 0x0001071b02007986 */
/* 0x000fe8000c101104 */
/*0330*/ STG.E.U8 [R2.64+0x106], R29 ; /* 0x0001061d02007986 */
/* 0x000fe8000c101104 */
/*0340*/ STG.E.U8 [R2.64+0x105], R5 ; /* 0x0001050502007986 */
/* 0x000fe2000c101104 */
/*0350*/ EXIT ; /* 0x000000000000794d */
/* 0x000fea0003800000 */
/*0360*/ BRA 0x360; /* 0xfffffff000007947 */
Fatbin ptx code:
================
arch = sm_89
code version = [8,2]
host = linux
compile_size = 64bit
compressed
#
The indicated tensorcore SASS instruction is HMMA.16816.F32 R4, R12, R4, RZ
If you want to see HMMA.16816.F16, then switch the C and D matrices to 16-bit float, and modify the PTX instruction accordingly. Something like this:
# cat t154.cu
#include <mma.h>
#include <cuda_fp16.h>
#include <iostream>
#include <stdio.h>
// kernel name kept from t153.cu; here the accumulation (C and D) is fp16
__global__ void mma_fp16_acc_fp32(float *out) {
  half c[4] = {0., 0., 0., 0.};
  half d[4] = {0., 0., 0., 0.};
  half a[8] = {1., 1., 1., 1., 1., 1., 1., 1.};
  half b[4] = {1., 1., 1., 1.};
  // the above would set our input matrices to all 1
  // now let's modify some values
  if (threadIdx.x%4 == 0) {
    // set the first column of A to be 0, 1, 2, 3, ... 15
    a[0] = threadIdx.x/4; a[2] = threadIdx.x/4 + 8;
    // set the second row of B to 3,3,3, ... 3
    b[1] = 3;
  }
  // all four fragments are now half arrays viewed as packed 32-bit registers
  unsigned const *A = reinterpret_cast<unsigned const *>(&a);
  unsigned const *B = reinterpret_cast<unsigned const *>(&b);
  unsigned const *C = reinterpret_cast<unsigned const *>(&c);
  unsigned *D = reinterpret_cast<unsigned *>(&d);
  asm(
    "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
    "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
    : "=r"(D[0]), "=r"(D[1])
    :
      "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
      "r"(B[0]), "r"(B[1]),
      "r"(C[0]), "r"(C[1])
  );
  memcpy(...);
  memcpy(...);
}
int main() {
  ...
}
# nvcc -o t154 t154.cu -arch=sm_89
# cuobjdump -sass ./t154
Fatbin elf code:
================
arch = sm_89
code version = [1,7]
host = linux
compile_size = 64bit
code for sm_89
Fatbin elf code:
================
arch = sm_89
code version = [1,7]
host = linux
compile_size = 64bit
code for sm_89
Function : _Z17mma_fp16_acc_fp32Pf
.headerflags @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM89 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM89)"
/*0000*/ IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] ; /* 0x00000a00ff017624 */
/* 0x000fc400078e00ff */
/*0010*/ S2R R8, SR_TID.X ; /* 0x0000000000087919 */
/* 0x000e220000002100 */
/*0020*/ HADD2 R3, -RZ.H0_H0, 1, 1 ; /* 0x3c003c00ff037430 */
/* 0x000fe20000000900 */
/*0030*/ IMAD.MOV.U32 R6, RZ, RZ, 0x3c00 ; /* 0x00003c00ff067424 */
/* 0x000fe200078e00ff */
/*0040*/ ULDC.64 UR4, c[0x0][0x118] ; /* 0x0000460000047ab9 */
/* 0x000fe20000000a00 */
/*0050*/ IMAD.MOV.U32 R7, RZ, RZ, 0x3c003c00 ; /* 0x3c003c00ff077424 */
/* 0x000fe400078e00ff */
/*0060*/ PRMT R0, R3.reuse, 0x7610, R0 ; /* 0x0000761003007816 */
/* 0x040fe20000000000 */
/*0070*/ IMAD.MOV.U32 R14, RZ, RZ, 0x3c003c00 ; /* 0x3c003c00ff0e7424 */
/* 0x000fe200078e00ff */
/*0080*/ PRMT R2, R3, 0x7610, R2 ; /* 0x0000761003027816 */
/* 0x000fe20000000002 */
/*0090*/ IMAD.MOV.U32 R15, RZ, RZ, 0x3c003c00 ; /* 0x3c003c00ff0f7424 */
/* 0x000fe200078e00ff */
/*00a0*/ LOP3.LUT P0, RZ, R8, 0x3, RZ, 0xc0, !PT ; /* 0x0000000308ff7812 */
/* 0x001fda000780c0ff */
/*00b0*/ @!P0 SHF.R.U32.HI R4, RZ, 0x2, R8 ; /* 0x00000002ff048819 */
/* 0x000fe20000011608 */
/*00c0*/ @!P0 I2F.F16 R3, 0x3 ; /* 0x0000000300038906 */
/* 0x000fe60000200c00 */
/*00d0*/ @!P0 IADD3 R5, R4, 0x8, RZ ; /* 0x0000000804058810 */
/* 0x000fca0007ffe0ff */
/*00e0*/ @!P0 I2F.F16.U32 R0, R4 ; /* 0x0000000400008306 */
/* 0x000e300000200800 */
/*00f0*/ @!P0 I2F.F16.U32 R2, R5 ; /* 0x0000000500028306 */
/* 0x000e620000200800 */
/*0100*/ PRMT R0, R0, 0x5410, R6 ; /* 0x0000541000007816 */
/* 0x001fc80000000006 */
/*0110*/ PRMT R6, R0, 0x5432, R3 ; /* 0x0000543200067816 */
/* 0x000fe20000000003 */
/*0120*/ IMAD.MOV.U32 R12, RZ, RZ, R0.reuse ; /* 0x000000ffff0c7224 */
/* 0x100fe400078e0000 */
/*0130*/ IMAD.MOV.U32 R3, RZ, RZ, 0x4 ; /* 0x00000004ff037424 */
/* 0x000fe200078e00ff */
/*0140*/ PRMT R13, R2, 0x7610, R0 ; /* 0x00007610020d7816 */
/* 0x002fe20000000000 */
/*0150*/ IMAD.SHL.U32 R2, R8, 0x2, RZ ; /* 0x0000000208027824 */
/* 0x000fc800078e00ff */
/*0160*/ IMAD.WIDE.U32 R2, R2, R3, c[0x0][0x160] ; /* 0x0000580002027625 */
/* 0x000fe400078e0003 */
/*0170*/ HMMA.16816.F16 R6, R12, R6, RZ ; /* 0x000000060c06723c */
/* 0x000f5e00000008ff */
/*0180*/ NOP ; /* 0x0000000000007918 */
/* 0x000fd00000000000 */
/*0190*/ STG.E.U8 [R2.64], R6 ; /* 0x0000000602007986 */
/* 0x020fe2000c101104 */
/*01a0*/ SHF.R.U32.HI R5, RZ, 0x18, R6.reuse ; /* 0x00000018ff057819 */
/* 0x100fe40000011606 */
/*01b0*/ SHF.R.U32.HI R9, RZ, 0x10, R6.reuse ; /* 0x00000010ff097819 */
/* 0x100fe20000011606 */
/*01c0*/ STG.E.U8 [R2.64+0x4], R7 ; /* 0x0000040702007986 */
/* 0x000fe2000c101104 */
/*01d0*/ SHF.R.U32.HI R11, RZ, 0x8, R6 ; /* 0x00000008ff0b7819 */
/* 0x000fe40000011606 */
/*01e0*/ SHF.R.U32.HI R13, RZ, 0x18, R7.reuse ; /* 0x00000018ff0d7819 */
/* 0x100fe20000011607 */
/*01f0*/ STG.E.U8 [R2.64+0x3], R5 ; /* 0x0000030502007986 */
/* 0x000fe2000c101104 */
/*0200*/ SHF.R.U32.HI R15, RZ, 0x10, R7.reuse ; /* 0x00000010ff0f7819 */
/* 0x100fe40000011607 */
/*0210*/ SHF.R.U32.HI R17, RZ, 0x8, R7 ; /* 0x00000008ff117819 */
/* 0x000fe20000011607 */
/*0220*/ STG.E.U8 [R2.64+0x2], R9 ; /* 0x0000020902007986 */
/* 0x000fe8000c101104 */
/*0230*/ STG.E.U8 [R2.64+0x1], R11 ; /* 0x0000010b02007986 */
/* 0x000fe8000c101104 */
/*0240*/ STG.E.U8 [R2.64+0x7], R13 ; /* 0x0000070d02007986 */
/* 0x000fe8000c101104 */
/*0250*/ STG.E.U8 [R2.64+0x6], R15 ; /* 0x0000060f02007986 */
/* 0x000fe8000c101104 */
/*0260*/ STG.E.U8 [R2.64+0x5], R17 ; /* 0x0000051102007986 */
/* 0x000fe2000c101104 */
/*0270*/ EXIT ; /* 0x000000000000794d */
(I have removed non-essential lines due to hitting the character limit in my answer).
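One detail that both kernels and both listings rely on: the "r" operands are 32-bit registers that each carry two packed fp16 values, which is why the half arrays are reinterpreted as unsigned, and why the SASS loads constants such as 0x3c003c00. A small host-only sketch of that packing follows. The helper names are invented for illustration, and it assumes the little-endian layout where element 0 sits in the low 16 bits.

```cpp
#include <cassert>
#include <cstdint>

// Packs two 16-bit patterns into one 32-bit register image, element 0 in the low half
// (assumption: this matches a little-endian half[2] read back as unsigned).
constexpr std::uint32_t pack_f16x2(std::uint16_t lo, std::uint16_t hi) {
    return static_cast<std::uint32_t>(lo) | (static_cast<std::uint32_t>(hi) << 16);
}

// Extracts element 0 (low 16 bits) and element 1 (high 16 bits) again.
constexpr std::uint16_t low_half(std::uint32_t r)  { return static_cast<std::uint16_t>(r & 0xFFFFu); }
constexpr std::uint16_t high_half(std::uint32_t r) { return static_cast<std::uint16_t>(r >> 16); }

constexpr std::uint16_t kHalfOne   = 0x3C00;  // IEEE binary16 bit pattern of 1.0
constexpr std::uint16_t kHalfThree = 0x4200;  // IEEE binary16 bit pattern of 3.0
```

With these, pack_f16x2(kHalfOne, kHalfOne) gives 0x3C003C00, the constant that appears in the IMAD.MOV.U32 lines of the disassembly, and a lane that sets b[1] = 3 would hold pack_f16x2(kHalfOne, kHalfThree) in its first B register.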