Search code examples
c#socketsencryptioncrypto++

Encrypting in C# and Decrypting with Crypto++ does not work


Encryption :

    public static byte[] EncryptAES(Message msg)
    {

        byte[] encText; // This will keep the encrypted text
        byte[] encLength; // This will keep the length of the encrypted text
        byte[] finalEncText = null; // This keeps the encLength + encText (#####[encText] / [encLength][encText])

        // Building the plaintext message : 
        string plainText = msg.MessageCode.ToString();

        if (msg.Parameters != null)
            foreach (string parameter in msg.Parameters)
                plainText += parameter;

        // Encrypting the plaintext :
        encText = EncryptAES(plainText);

        string encLen = encText.Length.ToString();
        string fittedEncLen = MessageSender.FitStringIntoSize(encLen, Globals.MESSAGE_LENGTH_LEN); // Fit the length of the encrypted text into a certain size
        encLength = Encoding.ASCII.GetBytes(fittedEncLen); // convert the length into byte[]

        finalEncText = new byte[encLength.Length + encText.Length];
        System.Buffer.BlockCopy(encLength, 0, finalEncText, 0, encLength.Length);
        System.Buffer.BlockCopy(encText, 0, finalEncText, encLength.Length, encText.Length); // Copy the byte arrays into the new byte array


        return finalEncText;
    }

    private static byte[] EncryptAES(string text)
    {
        // This function encrypts a plaintext message using the aes key we have from the server

        if (AesKey == null || IV == null) // If we dont have an aes key / iv, dont encrypt
            return Encoding.ASCII.GetBytes(text);


        byte[] encryptedText;

        try
        {
            Aes aes = Aes.Create();
            aes.Mode = CipherMode.CBC;
            aes.Padding = PaddingMode.Zeros;
            aes.Key = Encoding.ASCII.GetBytes(AesKey);
            aes.IV = Encoding.ASCII.GetBytes(IV);

            ICryptoTransform cryptor = aes.CreateEncryptor(aes.Key, aes.IV);



            using (MemoryStream memStream = new MemoryStream())
            {
                using (CryptoStream crypotStream = new CryptoStream(memStream, cryptor, CryptoStreamMode.Write))
                {
                    using (StreamWriter writerStream = new StreamWriter(crypotStream))
                    {
                        writerStream.Write(text);
                    }
                    encryptedText = memStream.ToArray();
                }
            }

            aes.Dispose();
        }
        catch
        {
            // In case of an error while encrypting, dont encrypt
            encryptedText = Encoding.ASCII.GetBytes(text);
        }
        return encryptedText;
    }

[The added fittedEncLen is basically a prefix of fixed length of 5 chars, that contains the length of the encrypted message following it, before decrypting the server reads those 5 chars and then it decrypts the encrypted part]

Sending the message to the server [TCPClient] [C#] :

public int Send(Message message)
        {
            /*
             * Encrpyts the message and then sends it to the network stream.
             * 
             * Return code:
             *      0 on success.
             *      -1 on failure.
             */

            byte[] msg = Cryptography.EncryptAES(message); // Encrypt the message
            // Sending message

            try
            {
                this._networkStream.Write(msg, 0, msg.Length);
                this._networkStream.Flush();
            }
            catch
            {
                return -1;
            }
            return 0;
        }

Receiving [C++] :

wstring Helper::getWideStringPartFromSocket(SOCKET sc, int bytesNum)
{
    // This function reads the message from the socket, using wide string
    std::wstringstream cls;
    cls << getPartFromSocket(sc, bytesNum, 0);
    return cls.str();
}

char* Helper::getPartFromSocket(SOCKET sc, int bytesNum, int flags)
{
    if (bytesNum == 0)
        return "";

    char* data = new char[bytesNum + 1];
    int res = recv(sc, data, bytesNum, flags);

    if (res == INVALID_SOCKET)
    {
        string s = "Error while recieving from socket: ";
        s += to_string(sc);
        throw exception(s.c_str());
    }

    data[bytesNum] = 0;

    return data;
}



BufferedString* Helper::makeBufferedString(SOCKET sc)
    {
        /*
            The socket contains <length of encrypted message (unencrypted)> <encrypted message>.

            This function will read the length of the unencrypted message, read 
            the encrypted message, decrypt it, store it in a BufferedString
            object and return the object.

            Length of length number: MESSAGE_LENGTH_LEN.
        */

        int sizeOfMessage = Helper::getIntPartFromSocket(sc, MESSAGE_LENGTH_LEN);
        if (sizeOfMessage == 0)
            return NULL;

        wstring wideString = getWideStringPartFromSocket(sc, sizeOfMessage);
        string decrypted = "";

        if (wideString.length() < sizeOfMessage)
        {
            std::wstringstream cls;
            cls << wideString;
            cls << getWideStringPartFromSocket(sc, sizeOfMessage - wideString.length());
            wideString = cls.str();
        }

        SocketEncryptionKeychain* keyChain = SocketEncryptionKeychain::getKeychain(sc);

        if (keyChain != nullptr) // If the socket has a keychain, decrypt the message
            decrypted = Cryptography::decryptAES(wideString, keyChain->getKey(), keyChain->getIV()); // Try to decrypt the message
        else // If the keychain is null, just convert the widestring to a string
            decrypted = wideStringToString(wideString);

        return new BufferedString(decrypted);
    }

SocketEncryptionKeychain basically contains the AES Key and IV for each socket BufferedString is a class that contains the string, and you can read from it like you read from a socket [its a buffer that once you read from it, what you read is deleted] [basically a string buffer, nothing special]

Decrypting [C++]:

string Cryptography::decryptAES(wstring cipherText, byte aesKey[], byte iv[])
{
    if (aesKey == nullptr || iv == nullptr) // If the key or iv are null, dont decrypt
        return Helper::wideStringToString(cipherText);

    string plaintext;
    try
    {
        // Decrypt :
        byte* cipher = wideStringToByteArray(cipherText); // Convert the wide string to byte*

        CryptoPP::AES::Decryption aesDecryption(aesKey, 32);
        CryptoPP::CBC_Mode_ExternalCipher::Decryption ecbDecryption(aesDecryption, iv);


        CryptoPP::StreamTransformationFilter stfDecryptor(ecbDecryption, new CryptoPP::StringSink(plaintext), StreamTransformationFilter::ZEROS_PADDING);
        stfDecryptor.Put(cipher, cipherText.length());
        stfDecryptor.MessageEnd();

        Helper::safeDelete(cipher);
    }
    catch (CryptoPP::InvalidCiphertext& ex)
    {
        // In case of an error don't decrypt
        plaintext = Helper::wideStringToString(cipherText);
    }

    return plaintext;
}


byte* Cryptography::wideStringToByteArray(wstring text)
{
    // This function translates the wstring into a byte*
    byte* bytes = new byte[text.length()]; // Convert the wstring to byte*
    for (int i = 0; i < text.length(); i++)
    {
        bytes[i] = text[i];
    }
    return bytes;
}

[Helper::safeDelete is a function that just deletes the pointer and sets it as null]

The decryption only fails once in a while


Solution

  • So, the problem was while parsing the char* to wstring in the function

    The problem in this function is the way I parse it:

    wstring Helper::getWideStringPartFromSocket(SOCKET sc, int bytesNum)
    {
        // This function reads the message from the socket, using wide string
        std::wstringstream cls;
        cls << getPartFromSocket(sc, bytesNum, 0);
        return cls.str();
    }
    

    I used a wstringstream, and the encrypted text can sometimes contain null-terminating character.

    So instead of using a wstringstream I used this:

    wstring Helper::getWideStringPartFromSocket(SOCKET sc, int bytesNum)
    {
        // This function reads the message from the socket, using wide string
        char* readBuffer = getPartFromSocket(sc, bytesNum, 0);
    
        return wstring(&readBuffer[0], &readBuffer[bytesNum]);
    }
    

    and then it does not cut the message at null-character