Search code examples
python c++ c base64

How does Python's Binascii.a2b_base64 (base64.b64decode) work?


I read the CPython source code and implemented a function that mirrors Python's binascii.a2b_base64 function.

This function is located under path Python-3.12.4/Modules/binascii.c, line 387.

static PyObject *
binascii_a2b_base64_impl(PyObject *module, Py_buffer *data, int strict_mode)

I re-implemented this function in C++, following the original Python implementation, in order to better understand how Base64 decoding works.

However, I don't know why the function I implemented does not handle non-Base64 characters correctly.

I also went through the helper functions the original calls, such as _PyBytesWriter_Init, _PyBytesWriter_Alloc and _PyBytesWriter_Finish, concluded that they do not affect how non-Base64 characters are handled, and therefore left them out of my code.

When processing Base64 strings that comply with RFC 4648, and also when \n is the only non-Base64 character present, my function produces the same result as the corresponding Python function.
For example:

// Sample input: three newline-terminated lines of Base64 text, joined by
// adjacent string-literal concatenation into a single C string.
const char *encoded =
    "QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVpBQkNERUZHSElKS0xNTk9QUVJTVFVW\n"
    "V1hZWkFCQ0RFRkdISUpLTE1OT1BRUlNUVVZXWFlaQUJDREVGR0hJSktMTU5PUFFS\n"
    "U1RVVldYWVpBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWg==\n";

Using either my function or Python's binascii.a2b_base64 function will yield the same result as the following:

ABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZ

Here is the specific implementation of my code:

#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdint>
#include <stdexcept>

// The Base64 padding character; the decoder below compares each input byte
// against it before consulting the lookup table.
#define BASE64PAD '='

// Decoding table indexed by the input byte value (0-255), 16 entries per row.
// An entry in the range 0-63 is the 6-bit value of a Base64 alphabet
// character; 255 marks a byte that is not part of the alphabet.
constexpr uint8_t b64de_table[256] = {
    // 0x00-0x1F: no Base64 characters.
    255,255,255,255, 255,255,255,255, 255,255,255,255, 255,255,255,255,
    255,255,255,255, 255,255,255,255, 255,255,255,255, 255,255,255,255,
    // 0x20-0x2F: '+' -> 62, '/' -> 63.
    255,255,255,255, 255,255,255,255, 255,255,255, 62, 255,255,255, 63,
    // 0x30-0x3F: '0'-'9' -> 52-61. '=' (0x3D) holds 0 rather than 255; the
    // decoder checks for BASE64PAD before it reads this table.
    52 , 53, 54, 55,  56, 57, 58, 59,  60, 61,255,255, 255,  0,255,255,

    // 0x40-0x5F: 'A'-'Z' -> 0-25.
    255,  0,  1,  2,   3,  4,  5,  6,   7,  8,  9, 10,  11, 12, 13, 14,
    15 , 16, 17, 18,  19, 20, 21, 22,  23, 24, 25,255, 255,255,255,255,
    // 0x60-0x7F: 'a'-'z' -> 26-51.
    255, 26, 27, 28,  29, 30, 31, 32,  33, 34, 35, 36,  37, 38, 39, 40,
    41 , 42, 43, 44,  45, 46, 47, 48,  49, 50, 51,255, 255,255,255,255,

    // 0x80-0xFF: no Base64 characters.
    255,255,255,255, 255,255,255,255, 255,255,255,255, 255,255,255,255,
    255,255,255,255, 255,255,255,255, 255,255,255,255, 255,255,255,255,
    255,255,255,255, 255,255,255,255, 255,255,255,255, 255,255,255,255,
    255,255,255,255, 255,255,255,255, 255,255,255,255, 255,255,255,255,

    255,255,255,255, 255,255,255,255, 255,255,255,255, 255,255,255,255,
    255,255,255,255, 255,255,255,255, 255,255,255,255, 255,255,255,255,
    255,255,255,255, 255,255,255,255, 255,255,255,255, 255,255,255,255,
    255,255,255,255, 255,255,255,255, 255,255,255,255, 255,255,255,255};

/**
 * Decode `length` bytes of Base64 text from `buffer` (the question's
 * original listing, kept unchanged apart from comments).
 *
 * @param buffer      Input characters; exactly `length` bytes are read.
 * @param length      In: number of input bytes. Out: the value of bin_len
 *                    on the success path.
 * @param strict_mode When true, non-alphabet characters and misplaced
 *                    padding raise an error instead of being ignored.
 * @return Pointer to the start of a buffer allocated with new[].
 * @throws std::runtime_error on allocation failure or any decoding error.
 */
uint8_t *
pyBase64Decode(const char *buffer, size_t &length,
               bool strict_mode = false)
{
    std::string error_message;

    const uint8_t *ascii_data = (const uint8_t *)buffer;
    size_t ascii_len = length;
    bool padding_started = 0;

    // NOTE(review): bin_len is derived from the raw input length, which still
    // includes every character the loop later ignores, and the division
    // rounds down when ascii_len is not a multiple of 4.
    size_t bin_len = ascii_len / 4 * 3; 
    uint8_t *bin_data = new (std::nothrow) uint8_t[bin_len + 1];
    if(!bin_data) {
        throw std::runtime_error("Failed to allocate memory for bin_data.");
    }
    // bin_data is used as the write cursor; bin_data_start keeps the base.
    uint8_t *bin_data_start = bin_data;
    // The terminator goes at the initial bin_len estimate, not at the end of
    // the bytes actually written.
    bin_data[bin_len] = 0x0;

    uint8_t leftchar = 0;   // bits carried over from the previous character
    uint32_t quad_pos = 0;  // position (0-3) inside the current 4-char group
    uint32_t pads = 0;      // '=' seen since the last data character

    if(strict_mode && (ascii_len > 0) && (*ascii_data == BASE64PAD)) {
        error_message = "Leading padding not allowed.";
        goto error_end;
    }

    size_t i;
    uint8_t this_ch;
    for(i = 0; i < ascii_len; ++i) {
        this_ch = ascii_data[i];

        if(this_ch == BASE64PAD) {
            padding_started = true;
            // If the current character is a padding character, the length
            // will be reduced by one to obtain the decoded true length.
            // NOTE(review): this runs for every '=', including ones that are
            // ignored below, and bin_len is unsigned, so it wraps around if
            // it is already 0.
            bin_len--;

            if(strict_mode && (!quad_pos)) {
                error_message = "Excess padding not allowed.";
                goto error_end;
            }

            // Enough pads to complete the current group: stop decoding.
            if((quad_pos >= 2) && (quad_pos + (++pads) >= 4)) {

                if(strict_mode && ((i + 1) < ascii_len)) {
                    error_message = "Excess data after padding.";
                    goto error_end;
                }

                goto done;
            }

            continue;
        }

        this_ch = b64de_table[this_ch];
        if(this_ch == 255) {
            if(strict_mode) {
                error_message = "Only base64 data is allowed.";
                goto error_end;
            }
            // NOTE(review): the character is skipped, but bin_len is not
            // reduced to account for it.
            continue;
        }

        if(strict_mode && padding_started) {
            error_message = "Discontinuous padding not allowed.";
            goto error_end;
        }

        pads = 0;

        // Pack four 6-bit values into three output bytes.
        switch(quad_pos) {
        case 0:
            quad_pos = 1;
            leftchar = this_ch;
            break;
        case 1:
            quad_pos = 2;
            *bin_data++ = (leftchar << 2) | (this_ch >> 4);
            leftchar = this_ch & 0xf;
            break;
        case 2:
            quad_pos = 3;
            *bin_data++ = (leftchar << 4) | (this_ch >> 2);
            leftchar = this_ch & 0x3;
            break;
        case 3:
            quad_pos = 0;
            *bin_data++ = (leftchar << 6) | (this_ch);
            leftchar = 0;
            break;
        }
    }

    // The input ended in the middle of a 4-character group.
    if(quad_pos) {
        if(quad_pos == 1) {
            char tmpMsg[128]{};
            snprintf(tmpMsg, sizeof(tmpMsg),
                    "Invalid base64-encoded string: "
                    "number of data characters (%zd) cannot be 1 more "
                    "than a multiple of 4",
                    (bin_data - bin_data_start) / 3 * 4 + 1);
            error_message = tmpMsg;
            goto error_end;
        } else {
            error_message = "Incorrect padding.";
            goto error_end;
        }
        // Shared error exit; the gotos above jump into this block.
        // NOTE(review): bin_data may have been advanced by the writes above,
        // so it is not necessarily the pointer returned by new[].
        error_end:
        delete[] bin_data;
        throw std::runtime_error(error_message);
    }

done:
    length = bin_len;
    return bin_data_start;
}

How to use this function:

// Demo: decode a fixed Base64 literal, print the result and free the buffer.
int main()
{
    const char *encoded = "aGVsbG8sIHdvcmxkLg==";
    // In: number of encoded characters. Out: length reported by the decoder.
    size_t length = strlen(encoded);
    uint8_t *decoded = pyBase64Decode(encoded, length);
    printf("decoded: %s\n", decoded);
    // pyBase64Decode allocates the result with new[] and hands ownership to
    // the caller, so it has to be released here.
    delete[] decoded;
    return 0;
}

Here are a few samples for which Python and my code produce different results.

original decoded:

stackoverflow

original encoded:

c3RhY2tvdmVyZmxvdw==

sample 1:

original "c3##RhY2t...vdmV!?y~Zmxvdw=="
result of python "stackoverflow"
result of pyBase64Decode "stackoverflowP"[^print_method_1]
result of pyBase64Decode "stackoverflow"[^print_method_2] but, length: 19

sample 2:

original "c3\n\nRh~Y2tvd#$mVyZmx$vdw=="
result of python "stackoverflow"
result of pyBase64Decode "stackoverflow"[^print_method_1] but, length: 16

sample 3:

original "c3Rh$$$$$$$$$$$$$$$$$$$$$Y2tvdmVy###############Zmxvdw=="
result of python "stackoverflow"
result of pyBase64Decode "stackoverflowP\2;SP2;SPROFILE_"[^print_method_1] length: 40 Bytes
result of pyBase64Decode "stackoverflow"[^print_method_2] but, length: 40

[^print_method_1]: cout << std::string((char *)decoded, length) << endl;
[^print_method_2]: printf("%s", decoded);


Solution

  • I didn't check the source code of the Python implementation, but the issue in your C++ implementation is in the if(this_ch == 255) block:

    In this case the length of the output will be impacted, as you had based bin_len on ascii_len, where the latter included these non-base64 characters. But as they have no bearing on the output, bin_len will need to be adapted.

    I would suggest to fix this as follows:

    • Keep track of the number of characters that have no data representation, including = and non-base64 characters. You could use a variable skip for that, initialised at 0.

    • At the very end, in the done section, recalculate bin_len with that information, and only then place the \0 terminator in bin_data.

    Here is your code with just those modifications. Comments indicate where the changes occur:

    /**
     * Decode `length` bytes of Base64 text from `buffer`.
     *
     * In non-strict mode, characters outside the Base64 alphabet and
     * misplaced '=' are ignored; in strict mode they raise an error.
     * Decoding stops at the pad sequence that completes a 4-character group.
     *
     * @param buffer      Input characters; exactly `length` bytes are read.
     * @param length      In: number of input bytes. Out: number of decoded
     *                    bytes (the terminating 0 byte is not counted).
     * @param strict_mode When true, reject anything that is not well-formed
     *                    Base64.
     * @return Buffer allocated with new[] holding the decoded bytes followed
     *         by a 0 byte; the caller releases it with delete[].
     * @throws std::runtime_error on allocation failure or malformed input.
     */
    uint8_t *
    pyBase64Decode(const char *buffer, size_t &length,
                   bool strict_mode = false)
    {
        size_t skip = 0; // Add this variable: input characters that carry no data
        std::string error_message;

        const uint8_t *ascii_data = reinterpret_cast<const uint8_t *>(buffer);
        size_t ascii_len = length;
        bool padding_started = false;

        // Upper bound for the output size. Round UP to whole 4-character
        // groups: rounding down lets unpadded input such as "QUI" write past
        // the end of the buffer before the padding error is detected.
        size_t bin_len = (ascii_len + 3) / 4 * 3;
        uint8_t *bin_data = new (std::nothrow) uint8_t[bin_len + 1];
        if(!bin_data) {
            throw std::runtime_error("Failed to allocate memory for bin_data.");
        }
        // bin_data is the write cursor; bin_data_start keeps the base address,
        // which is what must be returned, terminated and freed.
        uint8_t *bin_data_start = bin_data;

        uint8_t leftchar = 0;   // bits carried over from the previous character
        uint32_t quad_pos = 0;  // position (0-3) inside the current 4-char group
        uint32_t pads = 0;      // '=' seen since the last data character

        if(strict_mode && (ascii_len > 0) && (*ascii_data == BASE64PAD)) {
            error_message = "Leading padding not allowed.";
            goto error_end;
        }

        size_t i;
        uint8_t this_ch;
        for(i = 0; i < ascii_len; ++i) {
            this_ch = ascii_data[i];

            if(this_ch == BASE64PAD) {
                padding_started = true;
                skip++; // Instead of decreasing bin_len, increase the new variable
                if(strict_mode && (!quad_pos)) {
                    error_message = "Excess padding not allowed.";
                    goto error_end;
                }
                if((quad_pos >= 2) && (quad_pos + (++pads) >= 4)) {
                    if(strict_mode && ((i + 1) < ascii_len)) {
                        error_message = "Excess data after padding.";
                        goto error_end;
                    }
                    // Decoding stops here, so whatever follows this pad (for
                    // example a trailing "\n") is never visited by the loop.
                    // Count it as skipped too, or the length is overstated.
                    skip += ascii_len - (i + 1);
                    goto done;
                }
                continue;
            }
            this_ch = b64de_table[this_ch];
            if(this_ch == 255) {
                if(strict_mode) {
                    error_message = "Only base64 data is allowed.";
                    goto error_end;
                }
                skip++; // Non-base64 characters contribute nothing to the output
                continue;
            }
            if(strict_mode && padding_started) {
                error_message = "Discontinuous padding not allowed.";
                goto error_end;
            }
            pads = 0;
            // Pack four 6-bit values into three output bytes.
            switch(quad_pos) {
            case 0:
                quad_pos = 1;
                leftchar = this_ch;
                break;
            case 1:
                quad_pos = 2;
                *bin_data++ = (leftchar << 2) | (this_ch >> 4);
                leftchar = this_ch & 0xf;
                break;
            case 2:
                quad_pos = 3;
                *bin_data++ = (leftchar << 4) | (this_ch >> 2);
                leftchar = this_ch & 0x3;
                break;
            case 3:
                quad_pos = 0;
                *bin_data++ = (leftchar << 6) | (this_ch);
                leftchar = 0;
                break;
            }
        }

        // The input ended in the middle of a 4-character group.
        if(quad_pos) {
            if(quad_pos == 1) {
                char tmpMsg[128]{};
                // %td is the conversion for ptrdiff_t, the type of a pointer
                // difference.
                snprintf(tmpMsg, sizeof(tmpMsg),
                        "Invalid base64-encoded string: "
                        "number of data characters (%td) cannot be 1 more "
                        "than a multiple of 4",
                        (bin_data - bin_data_start) / 3 * 4 + 1);
                error_message = tmpMsg;
                goto error_end;
            } else {
                error_message = "Incorrect padding.";
                goto error_end;
            }
            // Shared error exit; the gotos above jump into this block.
            error_end:
            // Free the pointer returned by new[], not the advanced cursor.
            delete[] bin_data_start;
            throw std::runtime_error(error_message);
        }

    done:
        // Recalculate: every 4 data characters yield 3 bytes, and a final
        // group of 2 or 3 data characters yields 1 or 2 bytes.
        bin_len = (ascii_len - skip) * 3 / 4;
        length = bin_len;
        // ...and place the terminator only now, relative to the start of the
        // buffer rather than the write cursor.
        bin_data_start[bin_len] = 0x0;
        return bin_data_start;
    }