Search code examples
pythonjavaencryptionmd5pycrypto

Converting AES 256 Java Encryption in python


I have enryption code written in Java, and I am trying to write it in python, But the issue is python encryption and decryption is working fine but when trying to decrypt the java encrypted string using python, its not working.

This is the Java code that I had.

import lombok.extern.slf4j.Slf4j;
import org.bouncycastle.jce.provider.BouncyCastleProvider;

import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.util.Arrays;
import java.util.Base64;


@Slf4j
public class EncryptUsingAES {


    /**
     * This function takes a String as input and encrypts it using AES Algorithm
     * @param stringToEncrypt
     * @param publicKey
     * @return
     */
    public String returnEncryptedString(String stringToEncrypt, String publicKey) {

        Security.addProvider(new BouncyCastleProvider());
        try{
            SecureRandom sr = new SecureRandom();
            byte[] salt = new byte[8];
            sr.nextBytes(salt);
            final byte[][] keyAndIV = generateKeyAndIV(salt, publicKey.getBytes(StandardCharsets.UTF_8), MessageDigest.getInstance("MD5"));
            Cipher cipher = Cipher.getInstance("AES/CBC/PKCS7Padding", BouncyCastleProvider.PROVIDER_NAME);
            cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(keyAndIV[0], "AES"), new IvParameterSpec(keyAndIV[1]));
            byte[] encryptedData = cipher.doFinal(stringToEncrypt.getBytes(StandardCharsets.UTF_8));
            byte[] prefixAndSaltAndEncryptedData = new byte[16 + encryptedData.length];

            System.arraycopy("Salted__".getBytes(StandardCharsets.UTF_8), 0, prefixAndSaltAndEncryptedData, 0, 8);

            System.arraycopy(salt, 0, prefixAndSaltAndEncryptedData, 8, 8);

            System.arraycopy(encryptedData, 0, prefixAndSaltAndEncryptedData, 16, encryptedData.length);
            String encryptedString = URLEncoder.encode(Base64.getEncoder().encodeToString(prefixAndSaltAndEncryptedData),"UTF-8");

            return encryptedString;
        } catch (Exception e) {
            return null;
        }
    }


    /**
     * this function creates key Initialisation Vector which is required for AES Encryption
     * @param salt
     * @param password
     * @param md
     * @return
     */
    protected byte[][] generateKeyAndIV(byte[] salt, byte[] password, MessageDigest md) {

        int keyLength = 32;
        int ivLength = 16;
        int iterations = 1;
        int digestLength = md.getDigestLength();
        int requiredLength = (keyLength + ivLength + digestLength - 1) / digestLength * digestLength;
        byte[] generatedData = new byte[requiredLength];
        int generatedLength = 0;

        try {
            md.reset();

            while (generatedLength < keyLength + ivLength) {

                if (generatedLength > 0)
                    md.update(generatedData, generatedLength - digestLength, digestLength);
                md.update(password);
                if (salt != null)
                    md.update(salt, 0, 8);
                md.digest(generatedData, generatedLength, digestLength);


                for (int i = 1; i < iterations; i++) {
                    md.update(generatedData, generatedLength, digestLength);
                    md.digest(generatedData, generatedLength, digestLength);
                }

                generatedLength += digestLength;
            }

            byte[][] result = new byte[2][];
            result[0] = Arrays.copyOfRange(generatedData, 0, keyLength);
            if (ivLength > 0)
                result[1] = Arrays.copyOfRange(generatedData, keyLength, keyLength + ivLength);

            return result;

        } catch (DigestException e) {
            log.error(OrderStatusSmsTrackerConstants.ORDER_STATUS_SMS_TRACKER_SERVICE+ 
            return null;
        } finally {

            Arrays.fill(generatedData, (byte)0);
        }
    }
}

And this is my python encryption and decryption code,

from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from Crypto.Random import get_random_bytes
import base64
import urllib.parse
import hashlib
import logging

class EncryptUsingAES:
    def return_encrypted_string(self, string_to_encrypt, public_key):
        try:
            salt = get_random_bytes(8)
            key_and_iv = self.generate_key_and_iv(salt, public_key.encode('utf-8'), hashlib.md5)

            cipher = AES.new(key_and_iv[0], AES.MODE_CBC, iv=key_and_iv[1])
            encrypted_data = cipher.encrypt(pad(string_to_encrypt.encode('utf-8'), AES.block_size))

            prefix_and_salt_and_encrypted_data = b"Salted__" + salt + encrypted_data
            encrypted_string = urllib.parse.quote(base64.b64encode(prefix_and_salt_and_encrypted_data))

            return encrypted_string
        except Exception as e:
            logging.error("Error when encrypting: " + str(e))
            return None
        
    def return_decrypted_string(self, encrypted_string, public_key):
        try:
            encrypted_data = base64.b64decode(urllib.parse.unquote(encrypted_string))
            salt = encrypted_data[8:16]
            encrypted_data = encrypted_data[16:]

            key_and_iv = self.generate_key_and_iv(salt, public_key.encode('utf-8'), hashlib.md5)

            cipher = AES.new(key_and_iv[0], AES.MODE_CBC, iv=key_and_iv[1])
            decrypted_data = unpad(cipher.decrypt(encrypted_data), AES.block_size)

            return decrypted_data.decode('utf-8')
        except Exception as e:
            logging.error("Error when decrypting: " + str(e))
            return None

    def generate_key_and_iv(self, salt, password, md):
        key_length = 32
        iv_length = 16
        iterations = 1
        digest_length = md().digest_size
        required_length = (key_length + iv_length + digest_length - 1) // digest_length * digest_length
        generated_data = bytearray([0]) * required_length
        generated_length = 0

        try:
            md = md()
            while generated_length < key_length + iv_length:
                if generated_length > 0:
                    md.update(generated_data[generated_length - digest_length:generated_length])
                md.update(password)
                if salt is not None:
                    md.update(salt[:8])
                generated_data[generated_length:generated_length + digest_length] = md.digest()

                for i in range(1, iterations):
                    md.update(generated_data[generated_length:generated_length + digest_length])
                    generated_data[generated_length:generated_length + digest_length] = md.digest()

                generated_length += digest_length

            key = generated_data[:key_length]
            iv = generated_data[key_length:key_length + iv_length]
            return key, iv

        except Exception as e:
            logging.error("Error when generating key and IV: " + str(e))
            return None
        finally:
            for i in range(len(generated_data)):
                generated_data[i] = 0

Can anyone help me on this.


Solution

  • The bug is in the key derivation: The Java implementation of the digest automatically resets the instance to the initial state after a digest() call, unlike the Python implementation.
    Therefore, in the Python implementation, the digest must be explicitly reset, e.g. by creating a new digest object.
    Attention: This is necessary in two places (whereby the second modification is only noticeable for iterations greater than 1).

    The following changes are required in generate_key_and_iv():

    ...
    try:
        while generated_length < key_length + iv_length:
            h = md() # create new instance (1.)
            if generated_length > 0:
                h.update(generated_data[generated_length - digest_length:generated_length])
            h.update(password)
            if salt is not None:
                h.update(salt[:8])
            generated_data[generated_length:generated_length + digest_length] = h.digest()
    
            for i in range(1, iterations):
                h = md() # create new instance (2.)
                h.update(generated_data[generated_length:generated_length + digest_length])
                generated_data[generated_length:generated_length + digest_length] = h.digest()
    
            generated_length += digest_length
            ...
    

    For a test (and only for the test), a static salt can be used in both implementations. With identical plaintext and password, both encryptions then provide the same ciphertext.

    For completeness, the key derivation implemented in the Java code is a lightweight implementation of EVP_BytesToKey().


    In addition, a safe='' should be passed as 2nd parameter in urllib.parse.quote() so that / is also replaced by an escape sequence.