Search code examples
Tags: c, windows, aes-gcm, cng

How to chain BCryptEncrypt and BCryptDecrypt calls using AES in GCM mode?


Using the Windows CNG API, I am able to encrypt and decrypt individual blocks of data with authentication, using AES in GCM mode. I now want to encrypt and decrypt multiple buffers in a row.

According to documentation for CNG, the following scenario is supported:

If the input to encryption or decryption is scattered across multiple buffers, then you must chain calls to the BCryptEncrypt and BCryptDecrypt functions. Chaining is indicated by setting the BCRYPT_AUTH_MODE_IN_PROGRESS_FLAG flag in the dwFlags member.

If I understand it correctly, this means that I can invoke BCryptEncrypt sequentially on multiple buffers and obtain the authentication tag for the combined buffers at the end. Similarly, I can invoke BCryptDecrypt sequentially on multiple buffers while deferring the actual authentication check until the end. I cannot get that to work, though; it looks like the value of dwFlags is ignored. Whenever I use BCRYPT_AUTH_MODE_IN_PROGRESS_FLAG, I get a return value of 0xc000a002, which is equal to STATUS_AUTH_TAG_MISMATCH as defined in ntstatus.h.

Even though the parameter pbIV is marked as in/out, the elements pointed to by the parameter pbIV do not get modified by BCryptEncrypt(). Is that expected? I also looked at the field pbNonce in the BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO structure, pointed to by the pPaddingInfo pointer, but that one does not get modified either. I also tried "manually" advancing the IV, modifying the contents myself according to the counter scheme, but that did not help either.

What is the right procedure to chain the BCryptEncrypt and/or BCryptDecrypt functions successfully?


Solution

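  • I managed to get it to work. The MSDN documentation appears to be wrong: it should say to set BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG instead of BCRYPT_AUTH_MODE_IN_PROGRESS_FLAG. In addition, the chained calls below reuse the same BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO structure. That structure supplies a MAC-context buffer (pbMacContext), and each call also passes an IV buffer one block long (pbIV); CNG uses these buffers to carry state from one call to the next. The chaining flag is cleared on the final call so that CNG completes the authentication check.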

    #include <windows.h>
    #include <Bcrypt.h>
    #include <tchar.h>
    #include <assert.h>
    #include <cstdio>
    #include <cstdlib>
    #include <vector>
    #pragma comment(lib, "bcrypt.lib")
    
    // Builds a buffer of a_Length bytes whose values count up from 0 and wrap
    // around modulo 256 (0x00, 0x01, ..., 0xFF, 0x00, ...). This gives a
    // deterministic, easily recognizable plaintext for the round-trip check.
    std::vector<BYTE> MakePatternBytes(size_t a_Length)
    {
        std::vector<BYTE> pattern(a_Length);

        // BYTE is unsigned, so incrementing past 0xFF wraps back to 0x00.
        BYTE next = 0;
        for (std::vector<BYTE>::iterator it = pattern.begin(); it != pattern.end(); ++it)
        {
            *it = next++;
        }

        return pattern;
    }
    
    std::vector<BYTE> MakeRandomBytes(size_t a_Length)
    {
        std::vector<BYTE> result(a_Length);
        for (size_t i = 0; i < result.size(); i++)
        {
            result[i] = (BYTE)rand();
        }
    
        return result;
    }
    
    int _tmain(int argc, _TCHAR* argv[])
    {
        NTSTATUS bcryptResult = 0;
        DWORD bytesDone = 0;
    
        BCRYPT_ALG_HANDLE algHandle = 0;
        bcryptResult = BCryptOpenAlgorithmProvider(&algHandle, BCRYPT_AES_ALGORITHM, 0, 0);
        assert(BCRYPT_SUCCESS(bcryptResult) || !"BCryptOpenAlgorithmProvider");
    
        bcryptResult = BCryptSetProperty(algHandle, BCRYPT_CHAINING_MODE, (BYTE*)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
        assert(BCRYPT_SUCCESS(bcryptResult) || !"BCryptSetProperty(BCRYPT_CHAINING_MODE)");
    
        BCRYPT_AUTH_TAG_LENGTHS_STRUCT authTagLengths;
        bcryptResult = BCryptGetProperty(algHandle, BCRYPT_AUTH_TAG_LENGTH, (BYTE*)&authTagLengths, sizeof(authTagLengths), &bytesDone, 0);
        assert(BCRYPT_SUCCESS(bcryptResult) || !"BCryptGetProperty(BCRYPT_AUTH_TAG_LENGTH)");
    
        DWORD blockLength = 0;
        bcryptResult = BCryptGetProperty(algHandle, BCRYPT_BLOCK_LENGTH, (BYTE*)&blockLength, sizeof(blockLength), &bytesDone, 0);
        assert(BCRYPT_SUCCESS(bcryptResult) || !"BCryptGetProperty(BCRYPT_BLOCK_LENGTH)");
    
        BCRYPT_KEY_HANDLE keyHandle = 0;
        {
            const std::vector<BYTE> key = MakeRandomBytes(blockLength);
            bcryptResult = BCryptGenerateSymmetricKey(algHandle, &keyHandle, 0, 0, (PUCHAR)&key[0], key.size(), 0);
            assert(BCRYPT_SUCCESS(bcryptResult) || !"BCryptGenerateSymmetricKey");
        }
    
        const size_t GCM_NONCE_SIZE = 12;
        const std::vector<BYTE> origNonce = MakeRandomBytes(GCM_NONCE_SIZE);
        const std::vector<BYTE> origData  = MakePatternBytes(256);
    
        // Encrypt data as a whole
        std::vector<BYTE> encrypted = origData;
        std::vector<BYTE> authTag(authTagLengths.dwMinLength);
        {
            BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO authInfo;
            BCRYPT_INIT_AUTH_MODE_INFO(authInfo);
            authInfo.pbNonce = (PUCHAR)&origNonce[0];
            authInfo.cbNonce = origNonce.size();
            authInfo.pbTag   = &authTag[0];
            authInfo.cbTag   = authTag.size();
    
            bcryptResult = BCryptEncrypt
                (
                keyHandle,
                &encrypted[0], encrypted.size(),
                &authInfo,
                0, 0,
                &encrypted[0], encrypted.size(),
                &bytesDone, 0
                );
    
            assert(BCRYPT_SUCCESS(bcryptResult) || !"BCryptEncrypt");
            assert(bytesDone == encrypted.size());
        }
    
        // Decrypt data in two parts
        std::vector<BYTE> decrypted = encrypted;
        {
            DWORD partSize = decrypted.size() / 2;
    
            std::vector<BYTE> macContext(authTagLengths.dwMaxLength);
    
            BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO authInfo;
            BCRYPT_INIT_AUTH_MODE_INFO(authInfo);
            authInfo.pbNonce = (PUCHAR)&origNonce[0];
            authInfo.cbNonce = origNonce.size();
            authInfo.pbTag   = &authTag[0];
            authInfo.cbTag   = authTag.size();
            authInfo.pbMacContext = &macContext[0];
            authInfo.cbMacContext = macContext.size();
    
            // IV value is ignored on first call to BCryptDecrypt.
            // This buffer will be used to keep internal IV used for chaining.
            std::vector<BYTE> contextIV(blockLength);
    
            // First part
            authInfo.dwFlags = BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG;
            bcryptResult = BCryptDecrypt
                (
                keyHandle,
                &decrypted[0*partSize], partSize,
                &authInfo,
                &contextIV[0], contextIV.size(),
                &decrypted[0*partSize], partSize,
                &bytesDone, 0
                );
    
            assert(BCRYPT_SUCCESS(bcryptResult) || !"BCryptDecrypt");
            assert(bytesDone == partSize);
    
            // Second part
            authInfo.dwFlags &= ~BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG;
            bcryptResult = BCryptDecrypt
                (
                keyHandle,
                &decrypted[1*partSize], partSize,
                &authInfo,
                &contextIV[0], contextIV.size(),
                &decrypted[1*partSize], partSize,
                &bytesDone, 0
                );
    
            assert(BCRYPT_SUCCESS(bcryptResult) || !"BCryptDecrypt");
            assert(bytesDone == partSize);
        }
    
        // Check decryption
        assert(decrypted == origData);
    
        // Cleanup
        BCryptDestroyKey(keyHandle);
        BCryptCloseAlgorithmProvider(algHandle, 0);
    
        return 0;
    }