Search code examples
pythonmatplotlibasynchronousprocessevent-loop

How to asynchronously run Matplotlib server-side with a timeout? The process hangs randomly


I'm trying to reproduce the ChatGPT Code Interpreter feature, where an LLM creates figures on demand by executing code.

Unfortunately, Matplotlib hangs about 20% of the time, and I have not managed to understand why.

I would like the implementation:

  • to be non-blocking for the rest of the server
  • to have a timeout in case the code is too long to execute

I made a first implementation:

import ast
import asyncio
import sys

import psutil


# Wall-clock limit, in seconds, for a single code execution; exec_python()
# passes it to run_with_timeout(), which hands it to asyncio.wait_for().
TIMEOUT = 5


async def exec_python(code: str) -> dict:
    """Execute Python code in a separate interpreter process.

    The code is first rewritten by ``preprocess_code`` (matplotlib override,
    echo of the last statement) and then run with a wall-clock limit of
    ``TIMEOUT`` seconds.

    Args:
        code (str): Python code to execute.

    Returns:
        dict: ``{"stdout": ..., "stderr": ...}`` holding the output of the run.
        On timeout, ``stdout`` is empty and ``stderr`` is
        ``"Execution timed out."``.
    """
    code = preprocess_code(code)
    stdout = ""
    stderr = ""
    try:
        stdout, stderr = await run_with_timeout(code, TIMEOUT)
    except asyncio.TimeoutError:
        # Any partial output from the killed process is discarded.
        stderr = "Execution timed out."
    return {"stdout": stdout, "stderr": stderr}


async def run_with_timeout(code: str, timeout: float) -> tuple:
    """Run ``code`` in a child interpreter and collect its output.

    Args:
        code: Python source to execute.
        timeout: Wall-clock limit in seconds.

    Returns:
        tuple[str, str]: ``(stdout, stderr)`` decoded as UTF-8 (undecodable
        bytes are replaced rather than raising) and stripped of surrounding
        whitespace.

    Raises:
        asyncio.TimeoutError: if the child does not finish within ``timeout``.
            The child and its descendants are killed and reaped first.
        asyncio.CancelledError: if the awaiting task is cancelled; the child
            is likewise killed and reaped before the error propagates, so it
            cannot outlive the request.
    """
    proc = await run(code)

    try:
        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
    except (asyncio.TimeoutError, asyncio.CancelledError):
        kill_process(proc.pid)
        # Reap the killed child; otherwise it lingers as a zombie process.
        await proc.wait()
        raise
    return (
        stdout.decode("utf-8", errors="replace").strip(),
        stderr.decode("utf-8", errors="replace").strip(),
    )


async def run(code: str):
    """Start ``code`` in a new Python interpreter process without waiting for it.

    ``sys.executable`` (the interpreter running this server) is used rather
    than whatever ``python`` resolves to on ``PATH``. The latter may be
    missing, may be a different Python version, or may lack matplotlib and
    the other packages installed alongside this server.

    Args:
        code: Python source passed to the interpreter via ``-c``.

    Returns:
        asyncio.subprocess.Process: the running process, with stdout and
        stderr connected to pipes and stdin connected to the null device.
    """
    # sys.executable may be empty or None in embedded interpreters.
    interpreter = sys.executable or "python"
    return await asyncio.create_subprocess_exec(
        interpreter,
        "-c",
        code,
        # Without this the child inherits the server's stdin, and an input()
        # call in generated code blocks until the timeout fires.
        stdin=asyncio.subprocess.DEVNULL,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )


def kill_process(pid: int):
    """Forcefully kill process ``pid`` and all of its descendants.

    Descendants are killed first, then the process itself. Processes that
    exit on their own while this runs are tolerated, so one vanished child
    cannot prevent the parent from being killed.

    Args:
        pid: Process id of the root process to kill.
    """
    try:
        parent = psutil.Process(pid)
        children = parent.children(recursive=True)
    except psutil.NoSuchProcess:
        print("Process already killed.")
        return

    for child in children:
        try:
            child.kill()
        except psutil.NoSuchProcess:
            # Exited between listing and killing; keep going with the rest.
            pass

    try:
        parent.kill()
    except psutil.NoSuchProcess:
        print("Process already killed.")
        return
    print(f"Killing Process {pid} (timed out)")


# Source prepended to code containing "import matplotlib.pyplot as plt".
# It forces the non-interactive Agg backend and replaces plt.show() with a
# function that prints the current figure as a base64-encoded PNG on a
# stdout line starting with "[BASE_64_IMG]", then closes all figures.
PLT_OVERRIDE_PREFIX = """
import matplotlib
matplotlib.use('Agg') # non-interactive backend
import matplotlib.pyplot as plt
import io
import base64

def custom_show(*args, **kwargs):
    # Accept (and ignore) plt.show()'s arguments, e.g. plt.show(block=False).
    buf = io.BytesIO()
    plt.gcf().savefig(buf, format='png')
    image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    print('[BASE_64_IMG]', image_base64)
    buf.close()
    # Close every figure so the next plot starts from a clean state.
    plt.close('all')

plt.show = custom_show
"""



def preprocess_code(code: str) -> str:
    """Prepare generated Python code for execution with ``python -c``.

    Two transformations are applied:

    * If the code contains ``import matplotlib.pyplot as plt``, the
      ``PLT_OVERRIDE_PREFIX`` snippet is prepended and the user's own
      unindented copy of that import line is dropped (the prefix performs
      it). The prefix is applied on every return path.
    * REPL-style echo of the final *top-level* statement:
      * A bare expression is wrapped in ``print(...)``, unless its text
        already starts with ``print``.
      * A plain, augmented or annotated assignment (with a value) is
        followed by ``print(<target>)``.

    The last statement is located with :mod:`ast`, so the following are all
    handled correctly:

    * multi-line statements;
    * statements nested inside blocks, which are left alone;
    * comments;
    * ``;``-separated statements.

    Code that does not parse is returned unchanged, apart from the prefix, so
    the interpreter reports the error itself.

    Args:
        code: Python source to prepare.

    Returns:
        The transformed source code.
    """
    plt_import = "import matplotlib.pyplot as plt"
    # The tokenizer treats \r\n and \r as line breaks; normalize them so our
    # line indexing agrees with the line numbers reported by ``ast``.
    code = code.replace("\r\n", "\n").replace("\r", "\n")
    override_prefix = ""
    code_lines = code.strip().split("\n")
    if plt_import in code:
        override_prefix = PLT_OVERRIDE_PREFIX + "\n"
        code_lines = [line for line in code_lines if line.rstrip() != plt_import]

    source = "\n".join(code_lines)
    try:
        tree = ast.parse(source)
    except (SyntaxError, ValueError):
        return override_prefix + source
    if not tree.body:
        # Empty code or comments only: nothing to echo.
        return override_prefix + source

    last = tree.body[-1]
    if isinstance(last, ast.Expr):
        stmt_src = ast.get_source_segment(source, last)
        if stmt_src.startswith("print"):
            # Assume the code already prints its result.
            return override_prefix + source
        # Wrap the expression in place, keeping whatever precedes it on its
        # first line ("a = 1; a") or follows it on its last line (a trailing
        # comment). AST column offsets are UTF-8 byte offsets.
        start, end = last.lineno - 1, last.end_lineno - 1
        head = code_lines[start].encode("utf-8")[: last.col_offset].decode("utf-8")
        tail = code_lines[end].encode("utf-8")[last.end_col_offset :].decode("utf-8")
        code_lines[start : end + 1] = [f"{head}print({stmt_src}){tail}"]
    elif isinstance(last, ast.Assign):
        target_src = ast.get_source_segment(source, last.targets[0])
        code_lines.append(f"print({target_src})")
    elif isinstance(last, (ast.AugAssign, ast.AnnAssign)) and last.value is not None:
        target_src = ast.get_source_segment(source, last.target)
        code_lines.append(f"print({target_src})")

    return override_prefix + "\n".join(code_lines)

I have already tried without success:

  • ipython rather than python
  • using threads rather than process
  • saving the image on disk rather than on buffer

What is extremely weird is that I cannot reproduce the bug using the code above, yet I see the error frequently in prod and on my machine.


Solution

  • I finally found the source of the bug, and it was NOT in the code interpreter (I should have expected this, since I could not reproduce the bug in a simplified setting).

    It turns out GPT does not always respect the tools' signatures. For instance, instead of returning a dict {"code": "..."}, it will sometimes return the code directly as a plain string.

    I improved the parsing to handle that case:

        try:
            parsed_args = json.loads(function_args)
            function_args = parsed_args
        except json.JSONDecodeError:
            function_args = {
                next(
                    iter(inspect.signature(function_to_call).parameters)
                ): function_args
            }