Search code examples
aiogram

How do I run a function that blocks the thread in Python / aiogram 3?


To describe the problem in more detail, here is a small example:

import asyncio
from time import sleep
import logging

from aiogram import Bot, Dispatcher, Router
from aiogram.filters import Command
from aiogram.types import Message


# Placeholder: replace with the real bot token (consider loading it from an
# environment variable instead of hard-coding it in source).
token = "<token>"

# Every handler below is registered on `router`, which is attached to the
# dispatcher right after both are created.
dispatcher = Dispatcher()
router = Router()
dispatcher.include_router(router)

bot = Bot(token=token)


def block_function(delay: float = 2) -> str:
    """Simulate a slow, synchronous (thread-blocking) computation.

    Args:
        delay: Seconds to block the calling thread; defaults to 2 as in the
            original example.

    Returns:
        The string ``"Complete!"`` once the delay has elapsed.
    """
    # time.sleep blocks the whole calling thread; when called directly from a
    # coroutine it therefore stalls the event loop for `delay` seconds.
    sleep(delay)
    return "Complete!"


@router.message(Command("start"))
async def start(message: Message) -> None:
    """Greet the sender, then run the blocking demo function.

    ``block_function()`` is deliberately called directly (not via a thread):
    this is the blocking behavior the question demonstrates.
    """
    user = message.from_user
    # Both `from_user` and `username` are optional in Telegram updates; fall
    # back to the display name so the greeting never reads "Hello, None!".
    if user is None:
        name = "there"
    else:
        name = user.username or user.full_name
    await message.answer(f"Hello, {name}!")
    await message.answer("Start...")
    await message.answer(block_function())


@router.message(Command("escape"))
async def stop_polling(message: Message) -> None:
    """Reply with a farewell and stop the dispatcher's polling loop.

    NOTE(review): the sender is not checked, so any user who sends /escape
    shuts the bot down; consider restricting this command to an admin ID.
    """
    farewell = "Bye!"
    await message.answer(text=farewell)
    await dispatcher.stop_polling()


async def main() -> None:
    """Enable INFO-level logging, then poll Telegram until polling is stopped."""
    logging.basicConfig(level=logging.INFO)
    polling = dispatcher.start_polling(bot)
    await polling


if __name__ == "__main__":
    asyncio.run(main())

I intentionally made block_function() synchronous (blocking). Because it is called directly inside an async handler, it blocks the event loop, so for 2 seconds the bot cannot receive or process requests from any user.

Question: what do I need to do to run the blocking function in a separate thread?

Known solution: make the synchronous function asynchronous. Unfortunately, in the real project this function performs mathematical computations and cannot be made asynchronous.


Solution

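  • Thanks to Zwick Vitaly, I found the answer. Below is the corrected code, which uses asyncio.to_thread (available since Python 3.9), along with a few additional examples. Note that to_thread runs the function in a worker thread. This frees the event loop while the function sleeps, waits on I/O, or runs C code that releases the GIL. A pure-Python CPU-bound loop, however, still competes with the event loop for the GIL. For heavy computations, a process pool (loop.run_in_executor with concurrent.futures.ProcessPoolExecutor) gives true parallelism.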

    import asyncio
    from io import BytesIO
    import logging
    from time import sleep, time
    
    from aiogram import Bot, Dispatcher, Router
    from aiogram.filters import Command
    from aiogram.types import Message, BufferedInputFile
    
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import numpy as np
    
    # Render off-screen with the non-GUI Agg backend and give every figure
    # the ggplot look.
    mpl.use("agg")
    plt.style.use("ggplot")


    # Placeholder: replace with the real bot token (consider loading it from
    # an environment variable instead of hard-coding it in source).
    token = "<YOUR TOKEN HERE>"

    # Every handler below is registered on `router`, which is attached to the
    # dispatcher right after both are created.
    dispatcher = Dispatcher()
    router = Router()
    dispatcher.include_router(router)

    bot = Bot(token=token)
    
    
    def block_plot_function(username: str) -> bytes:
        """Render a sin(x**2 - 1) plot titled for *username* and return PNG bytes.

        The figure is built with the object-oriented ``matplotlib.figure.Figure``
        API instead of ``pyplot``. pyplot keeps one global "current figure", so
        two calls running at the same time in worker threads (this function is
        dispatched via ``asyncio.to_thread``) could draw into the same figure,
        and closing "all" figures would affect other threads. A private
        ``Figure`` per call shares no state and is freed once unreferenced.

        Args:
            username: Text inserted into the plot title.

        Returns:
            The rendered image encoded as PNG.
        """
        from matplotlib.figure import Figure

        x = np.linspace(0, 10, 1000)
        y = np.sin(x**2 - 1)

        fig = Figure()
        ax = fig.subplots()
        ax.plot(x, y)
        ax.set_title(f"example for {username}")
        ax.set_xlabel("x")
        ax.set_ylabel("y")

        # getvalue() returns the whole buffer regardless of position, so no
        # seek is needed; the `with` block closes the buffer on every path.
        with BytesIO() as buffer:
            fig.savefig(buffer, format="png")
            return buffer.getvalue()
    
    
    def block_math_function(iterations: int = 10_000_000) -> str:
        """Run a deliberately slow, CPU-bound numeric loop and report its duration.

        Args:
            iterations: Number of repeated square-root passes over the array;
                defaults to the original 10,000,000.

        Returns:
            ``"Complete in <seconds> sec"`` with the elapsed time of the loop.
        """
        from time import perf_counter

        # np.abs keeps every value non-negative, so np.sqrt never yields NaN
        # (nor the "invalid value encountered in sqrt" RuntimeWarning).
        values = np.abs(np.random.randn(50) * 1e9)

        # perf_counter is monotonic and high-resolution, unlike time(), which
        # can jump if the system clock is adjusted mid-measurement.
        start = perf_counter()
        for _ in range(iterations):
            values = np.sqrt(values)

        return f"Complete in {perf_counter() - start} sec"
    
    
    def block_sleep_function(seconds: float = 2) -> str:
        """Block the calling thread for *seconds*, then report completion.

        Args:
            seconds: How long to sleep; defaults to 2 as in the original demo.

        Returns:
            The string ``"Complete!"``.
        """
        sleep(seconds)
        return "Complete!"
    
    
    @router.message(Command("check_sleep"))
    async def check_sleep(message: Message) -> None:
        """Announce the check, run the blocking sleep demo off-loop, reply with its result.

        ``asyncio.to_thread`` executes the synchronous function in a worker
        thread, so the event loop keeps handling other updates meanwhile.
        """
        await message.answer("Start sleep check...")
        outcome = await asyncio.to_thread(block_sleep_function)
        await message.answer(outcome)
    
    
    @router.message(Command("check_math"))
    async def check_math(message: Message) -> None:
        """Announce the check, run the CPU-bound demo in a worker thread, reply with its timing.

        NOTE(review): ``block_math_function`` is a pure-Python loop over tiny
        NumPy calls, so the worker thread likely holds the GIL for most of its
        run and the event loop only gets short time slices. For heavy
        computations, a process pool (``loop.run_in_executor`` with a
        ``ProcessPoolExecutor``) would keep the bot more responsive.
        """
        await message.answer("Start math check...")
        report = await asyncio.to_thread(block_math_function)
        await message.answer(report)
    
    
    @router.message(Command("check_plot"))
    async def check_plot(message: Message) -> None:
        """Render a plot in a worker thread and send it back as a photo.

        The sender's name goes into the plot title. Both ``from_user`` and
        ``username`` are optional in Telegram updates, so fall back to the
        display name (or a generic label) instead of rendering "None" or
        raising AttributeError.
        """
        user = message.from_user
        if user is None:
            uname = "anonymous"
        else:
            uname = user.username or user.full_name
        await message.answer("Start plot check...")
        values = await asyncio.to_thread(block_plot_function, username=uname)
        await message.answer_photo(BufferedInputFile(values, "example"))
    
    
    @router.message(Command("escape"))
    async def stop_polling(message: Message) -> None:
        """Reply with a farewell and stop the dispatcher's polling loop.

        NOTE(review): the sender is not checked, so any user who sends /escape
        shuts the bot down; consider restricting this command to an admin ID.
        """
        farewell = "Bye!"
        await message.answer(text=farewell)
        await dispatcher.stop_polling()
    
    
    async def main() -> None:
        """Enable INFO-level logging, then poll Telegram until polling is stopped."""
        logging.basicConfig(level=logging.INFO)
        polling = dispatcher.start_polling(bot)
        await polling


    if __name__ == "__main__":
        asyncio.run(main())