Search code examples
Tags: c++, windows, encryption, aes, cng

Straightforward example of using Cryptography API: Next Generation (CNG) to encrypt data


I'd like to implement data encryption and decryption in a C++ application running on Windows. I've spent considerable time looking around the Web and am thinking I should probably use the Windows Cryptography API: Next Generation (CNG) functions (although I'm open to better alternatives).

What I find everywhere are complex examples that do all sorts of stuff. I don't feel that confident in this area, so I'd like to find a simple example. In the end, I need a method that takes a string and encrypts it, and another method that decrypts the data back to the string. The user would supply a password for both operations.

This must have been done countless times already. Can anyone point me to a complete and competent example? Ultimately, I'll end up with Encrypt() and Decrypt() methods.

Something that is both secure and performant would be ideal.


Solution

  • Before encrypting (and decrypting), you need to derive a key from the password with a key derivation function (for example, PBKDF2 with SHA-256). To prevent pre-computed dictionary attacks, you will also need a random string (called a salt) in addition to the password.
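    Next, pick a cipher algorithm (AES with a 256-bit key is a good choice) and a cipher mode (ECB mode is considered weak, so use another mode, for example CBC). CBC also requires one more random string, called an initialization vector (IV). Note that CBC on its own protects confidentiality but not integrity; if tampering is a concern, use an authenticated mode such as GCM or add a MAC.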

    So the encryption algorithm is:

    1. Generate random salt
    2. Derive key(password, salt) = key
    3. Generate random IV
    4. Encrypt(key, IV, plain text) = cipher text

    Input parameters: plain text, password
    Output parameters: cipher text, salt, IV

    The decryption algorithm is:

    1. Derive key(password, salt) = key
    2. Decrypt(key, IV, cipher text) = plain text

    Input parameters: cipher text, salt, IV, password
    Output parameters: plain text

    Sample code:

    #include <Windows.h>
    #include <bcrypt.h>
    
    #include <array>
    #include <iostream>
    #include <string>
    #include <vector>
    
    #pragma comment(lib, "bcrypt")
    
    /*
     * Fill buf[0 .. buf_len) with random bytes from the CNG "RNG" algorithm provider.
     * Returns the status of BCryptOpenAlgorithmProvider if opening the provider fails,
     * otherwise the status of BCryptGenRandom. The provider handle is closed before returning.
     */
    static NTSTATUS gen_random(BYTE* buf, ULONG buf_len)
    {
        BCRYPT_ALG_HANDLE rng = nullptr;
        const NTSTATUS open_status = BCryptOpenAlgorithmProvider(&rng, L"RNG", nullptr, 0);
        if (open_status != ERROR_SUCCESS) {
            return open_status;
        }

        const NTSTATUS gen_status = BCryptGenRandom(rng, buf, buf_len, 0);

        if (rng != nullptr) {
            BCryptCloseAlgorithmProvider(rng, 0);
        }
        return gen_status;
    }
    
    /*
     * Derive derived_key_len bytes of key material into derived_key with PBKDF2,
     * using HMAC-SHA256 as the PRF (the "SHA256" provider opened with
     * BCRYPT_ALG_HANDLE_HMAC_FLAG), the given salt and iteration count.
     * Returns the first failing CNG status, or the status of BCryptDeriveKeyPBKDF2.
     */
    static NTSTATUS derive_key(BYTE* pass, ULONG pass_len, BYTE* salt,
                               ULONG salt_len, const ULONG iteration, BYTE* derived_key, ULONG derived_key_len)
    {
        BCRYPT_ALG_HANDLE prf = nullptr;
        NTSTATUS result = BCryptOpenAlgorithmProvider(&prf, L"SHA256", nullptr, BCRYPT_ALG_HANDLE_HMAC_FLAG);
        if (result == ERROR_SUCCESS) {
            result = BCryptDeriveKeyPBKDF2(prf, pass, pass_len, salt, salt_len, iteration,
                                           derived_key, derived_key_len, 0);
        }

        if (prf != nullptr) {
            BCryptCloseAlgorithmProvider(prf, 0);
        }
        return result;
    }
    
    /*
     * Encrypt plain_text with AES in CBC mode and CNG block padding (BCRYPT_BLOCK_PADDING).
     *
     * key/key_len       : raw AES key bytes (the callers in this file pass 32 bytes)
     * plain_text/..._len: data to encrypt (not modified)
     * iv                : [out] resized to the AES block length and filled with random bytes;
     *                     must be stored alongside the cipher text for decryption
     * cipher_text       : [out] encrypted data; cleared if any step fails
     *
     * Returns the first failing CNG status, or ERROR_SUCCESS (0) on success.
     *
     * NOTE(review): CBC provides no integrity protection. Tampered or wrong-key cipher text
     * is only detected if BCryptDecrypt's padding check happens to fail. Consider AES-GCM,
     * or encrypt-then-MAC (HMAC over salt || iv || cipher_text), for production use.
     */
    static NTSTATUS do_encrypt(BYTE* key, ULONG key_len, BYTE* plain_text, ULONG plain_text_len,
                               std::vector<BYTE>& iv, std::vector<BYTE>& cipher_text)
    {
        NTSTATUS status = NTE_FAIL;
        BCRYPT_ALG_HANDLE hAlg = nullptr;
        BCRYPT_KEY_HANDLE hKey = nullptr;
        do {
            status = BCryptOpenAlgorithmProvider(&hAlg, L"AES", nullptr, 0);
            if (status != ERROR_SUCCESS) {
                break;
            }
    
            /* Set the chaining mode BEFORE creating the key (as Microsoft's CNG sample does),
               so the key object is created in CBC mode. Changing the algorithm handle after
               the key exists is not guaranteed to affect that key. */
            wchar_t mode[] = BCRYPT_CHAIN_MODE_CBC;
            status = BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, reinterpret_cast<BYTE*>(mode),
                                       static_cast<ULONG>(sizeof(mode)), 0);
            if (status != ERROR_SUCCESS) {
                break;
            }
    
            /* create key object (CNG allocates the key object memory itself) */
            status = BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, key, key_len, 0);
            if (status != ERROR_SUCCESS) {
                break;
            }
    
            /* generate a random IV of one block */
            ULONG block_len = 0;
            ULONG res = 0;
            status = BCryptGetProperty(hAlg, BCRYPT_BLOCK_LENGTH, reinterpret_cast<BYTE*>(&block_len),
                                       static_cast<ULONG>(sizeof(block_len)), &res, 0);
            if (status != ERROR_SUCCESS) {
                break;
            }
            iv.resize(block_len);
            status = gen_random(iv.data(), static_cast<ULONG>(iv.size()));
            if (status != ERROR_SUCCESS) {
                break;
            }
    
            /* BCryptEncrypt modifies the IV buffer, so always hand it a copy */
            std::vector<BYTE> iv_copy = iv;
    
            /* query the padded cipher text length */
            ULONG cipher_text_len = 0;
            status = BCryptEncrypt(hKey, plain_text, plain_text_len, nullptr, iv_copy.data(), static_cast<ULONG>(iv_copy.size()),
                                   nullptr, 0, &cipher_text_len, BCRYPT_BLOCK_PADDING);
            if (status != ERROR_SUCCESS) {
                break;
            }
            cipher_text.resize(static_cast<size_t>(cipher_text_len));
    
            /* fresh IV copy for the real call (defensive: never encrypt with a possibly-advanced IV) */
            iv_copy = iv;
            status = BCryptEncrypt(hKey, plain_text, plain_text_len, nullptr, iv_copy.data(), static_cast<ULONG>(iv_copy.size()),
                                   cipher_text.data(), cipher_text_len, &cipher_text_len, BCRYPT_BLOCK_PADDING);
            if (status != ERROR_SUCCESS) {
                break;
            }
            /* keep exactly the bytes that were written */
            cipher_text.resize(static_cast<size_t>(cipher_text_len));
        } while (0);
    
        /* never hand back a partially produced cipher text */
        if (status != ERROR_SUCCESS) {
            cipher_text.clear();
        }
    
        /* cleanup */
        if (hKey) {
            BCryptDestroyKey(hKey);
        }
        if (hAlg) {
            BCryptCloseAlgorithmProvider(hAlg, 0);
        }
        return status;
    }
    
    /*
     * Decrypt cipher_text produced by do_encrypt (AES-CBC with BCRYPT_BLOCK_PADDING).
     *
     * key/key_len            : raw AES key bytes (must match the encryption key)
     * cipher_text/..._len    : data to decrypt (not modified)
     * iv                     : the IV that do_encrypt returned (not modified; a copy is used)
     * plain_text             : [out] decrypted data with padding removed; on any failure it
     *                          is zeroed and cleared, so no partial plaintext is returned
     *
     * Returns the first failing CNG status, or ERROR_SUCCESS (0) on success. A wrong key or
     * corrupted data usually surfaces as a failure of the final BCryptDecrypt (padding check),
     * but CBC has no MAC, so that check is not a reliable integrity test.
     */
    static NTSTATUS do_decrypt(BYTE* key, ULONG key_len, BYTE* cipher_text, ULONG cipher_text_len,
                               const std::vector<BYTE>& iv, std::vector<BYTE>& plain_text)
    {
        NTSTATUS status = NTE_FAIL;
        BCRYPT_ALG_HANDLE hAlg = nullptr;
        BCRYPT_KEY_HANDLE hKey = nullptr;
        do {
            status = BCryptOpenAlgorithmProvider(&hAlg, L"AES", nullptr, 0);
            if (status != ERROR_SUCCESS) {
                break;
            }
    
            /* Set the chaining mode BEFORE creating the key, so the key object is created
               in CBC mode (matching do_encrypt). */
            wchar_t mode[] = BCRYPT_CHAIN_MODE_CBC;
            status = BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, reinterpret_cast<BYTE*>(mode),
                                       static_cast<ULONG>(sizeof(mode)), 0);
            if (status != ERROR_SUCCESS) {
                break;
            }
    
            /* create key object (CNG allocates the key object memory itself) */
            status = BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, key, key_len, 0);
            if (status != ERROR_SUCCESS) {
                break;
            }
    
            /* BCryptDecrypt modifies the IV buffer, so always hand it a copy */
            std::vector<BYTE> iv_copy = iv;
    
            /* query an upper bound for the plain text length (padding not yet removed) */
            ULONG plain_text_len = 0;
            status = BCryptDecrypt(hKey, cipher_text, cipher_text_len, nullptr, iv_copy.data(), static_cast<ULONG>(iv_copy.size()),
                                   nullptr, 0, &plain_text_len, BCRYPT_BLOCK_PADDING);
            if (status != ERROR_SUCCESS) {
                break;
            }
            plain_text.resize(static_cast<size_t>(plain_text_len));
    
            /* decrypt with a fresh IV copy */
            iv_copy = iv;
            status = BCryptDecrypt(hKey, cipher_text, cipher_text_len, nullptr, iv_copy.data(), static_cast<ULONG>(iv_copy.size()),
                                   plain_text.data(), plain_text_len, &plain_text_len, BCRYPT_BLOCK_PADDING);
            if (status != ERROR_SUCCESS) {
                break;
            }
            /* shrink to the real length once the padding has been stripped */
            plain_text.resize(static_cast<size_t>(plain_text_len));
        } while (0);
    
        /* on failure, wipe whatever was decrypted into the buffer before releasing it */
        if (status != ERROR_SUCCESS) {
            if (!plain_text.empty()) {
                SecureZeroMemory(plain_text.data(), plain_text.size());
            }
            plain_text.clear();
        }
    
        /* cleanup */
        if (hKey) {
            BCryptDestroyKey(hKey);
        }
        if (hAlg) {
            BCryptCloseAlgorithmProvider(hAlg, 0);
        }
        return status;
    }
    
    
    /*
     * Password-based encryption: random salt -> PBKDF2-HMAC-SHA256 key -> AES-256-CBC.
     *
     * pass/pass_len : password bytes (used as-is; no text-encoding conversion happens here)
     * plain_text    : data to encrypt
     * salt          : [out] 8 random bytes; must be stored with the cipher text
     * iv            : [out] random IV from do_encrypt; must be stored with the cipher text
     * cipher_text   : [out] encrypted data
     *
     * Returns the first failing status, or ERROR_SUCCESS (0) on success.
     * The derived key is wiped from the stack before returning.
     */
    NTSTATUS encrypt(BYTE* pass, ULONG pass_len, const std::vector<BYTE>& plain_text,
                     std::vector<BYTE>& salt, std::vector<BYTE>& iv, std::vector<BYTE>& cipher_text)
    {
        /* PBKDF2 iteration count. It is not stored with the output, so it MUST equal the
           value used in decrypt(). NOTE(review): current guidance (e.g. OWASP) recommends a much
           higher count for PBKDF2-HMAC-SHA256; raising it requires changing decrypt() too and
           makes previously encrypted data unreadable. */
        constexpr ULONG kIterations = 20000;
        /* NOTE(review): NIST SP 800-132 recommends a salt of at least 16 bytes. 8 is kept so
           callers that assume an 8-byte salt are unaffected (decrypt() accepts any salt length). */
        constexpr size_t kSaltLen = 8;
    
        NTSTATUS status = NTE_FAIL;
        salt.resize(kSaltLen);
        std::array<BYTE, 32> key{0x00};  /* 256-bit AES key */
        do {
            /* generate salt */
            status = gen_random(salt.data(), static_cast<ULONG>(salt.size()));
            if (status != ERROR_SUCCESS) {
                break;
            }
            /* derive the key from the password (PBKDF2 with HMAC-SHA256) */
            status = derive_key(pass, pass_len, salt.data(), static_cast<ULONG>(salt.size()), kIterations,
                                key.data(), static_cast<ULONG>(key.size()));
            if (status != ERROR_SUCCESS) {
                break;
            }
            /* encrypt (const_cast: the CNG input buffer is not written to) */
            status = do_encrypt(key.data(), static_cast<ULONG>(key.size()), const_cast<BYTE*>(plain_text.data()),
                                static_cast<ULONG>(plain_text.size()), iv, cipher_text);
        } while (0);
        SecureZeroMemory(key.data(), key.size());
        return status;
    }
    
    
    /*
     * Password-based decryption, the inverse of encrypt().
     *
     * pass/pass_len : password bytes (must be byte-identical to those given to encrypt)
     * salt, iv      : the values that encrypt() returned for this cipher text
     * cipher_text   : data to decrypt
     * plain_text    : [out] decrypted data (see do_decrypt for failure behaviour)
     *
     * Returns the first failing status, or ERROR_SUCCESS (0) on success.
     * The derived key is wiped from the stack before returning.
     */
    NTSTATUS decrypt(BYTE* pass, ULONG pass_len, const std::vector<BYTE>& salt, const std::vector<BYTE>& iv,
                     const std::vector<BYTE>& cipher_text, std::vector<BYTE>& plain_text)
    {
        /* PBKDF2 iteration count; MUST equal the value used in encrypt(), since it is not stored */
        constexpr ULONG kIterations = 20000;
    
        NTSTATUS status = NTE_FAIL;
        std::array<BYTE, 32> key{0x00};  /* 256-bit AES key */
        do {
            /* derive the key from the password using the same algorithm, salt and iteration count
               (const_cast: the salt buffer is only read) */
            status = derive_key(pass, pass_len, const_cast<BYTE*>(salt.data()), static_cast<ULONG>(salt.size()),
                                kIterations, key.data(), static_cast<ULONG>(key.size()));
            if (status != ERROR_SUCCESS) {
                break;
            }
            /* decrypt */
            status = do_decrypt(key.data(), static_cast<ULONG>(key.size()), const_cast<BYTE*>(cipher_text.data()),
                                static_cast<ULONG>(cipher_text.size()), iv, plain_text);
        } while (0);
        SecureZeroMemory(key.data(), key.size());
        return status;
    }