huggingface-transformers

Stop model.generate


I'm using TextIteratorStreamer to stream generated text, and I run model.generate in a separate Thread:

thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

I want to create a cancel_event = asyncio.Event() and check cancel_event.is_set() in the streamer loop, so that I can stop model.generate from consuming GPU resources. How can I stop model.generate? Do I need to kill the thread, and if so, how?


Solution

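  • You don't need to kill the thread (Python has no supported way to kill a thread anyway). model.generate accepts a stopping_criteria argument and evaluates it after every generated token, so you can pass a criterion that checks your cancel event. generate then returns on its own shortly after the event is set, the streamer finishes, and the thread exits. Use a threading.Event rather than an asyncio.Event: the event is set in one thread and read in the generation thread, and asyncio primitives are not thread-safe. Something like this should work: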

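    from threading import Event, Thread

    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopCriteria(StoppingCriteria):
        """Tells model.generate to stop once the given threading.Event is set."""

        def __init__(self, event):
            self.event = event

        def __call__(self, input_ids, scores, **kwargs):
            # generate calls this after every new token; True means "stop now"
            return self.event.is_set()

    # threading.Event, not asyncio.Event: it is set in one thread and read in another
    cancel_event = Event()
    generation_kwargs["stopping_criteria"] = StoppingCriteriaList([StopCriteria(cancel_event)])

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # your existing code, e.g. the loop that reads from the streamer

    # when you want to cancel (safe to call from any thread):
    cancel_event.set()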
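To see why no thread killing is needed, here is a minimal, dependency-free sketch of the same cooperative-cancellation pattern. fake_generate is a hypothetical stand-in for model.generate that assumes the behaviour described above (every stopping criterion is checked after each token); it does not use torch or transformers.

```python
import threading
import time


class StopOnEvent:
    """Stand-in for a StoppingCriteria subclass: True once the event is set."""

    def __init__(self, event: threading.Event):
        self.event = event

    def __call__(self, tokens, scores=None, **kwargs) -> bool:
        return self.event.is_set()


def fake_generate(stopping_criteria, max_new_tokens=1000, out=None):
    """Hypothetical stand-in for model.generate: produces one token per step
    and checks every stopping criterion after each step."""
    tokens = out if out is not None else []
    for step in range(max_new_tokens):
        time.sleep(0.01)  # pretend this is one forward pass on the GPU
        tokens.append(step)
        if any(criterion(tokens) for criterion in stopping_criteria):
            break
    return tokens


cancel_event = threading.Event()
produced = []
thread = threading.Thread(
    target=fake_generate,
    kwargs={"stopping_criteria": [StopOnEvent(cancel_event)], "out": produced},
)
thread.start()

time.sleep(0.1)         # let it generate a few tokens
cancel_event.set()      # request cancellation from the main thread
thread.join(timeout=5)  # returns quickly: the loop sees the event on its next step

print(thread.is_alive(), len(produced) < 1000)  # prints: False True
```

The cancellation is cooperative: the worker polls the flag once per step, so the worst-case delay is a single step (one forward pass in the real model.generate).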
    

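    cancel_event.set() only asks generate to stop; generate returns after finishing its current step. If you need to be sure generation has actually finished (for example, before starting a new request on the same GPU), wait for the thread after setting the event:

    thread.join()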
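Because the cancel signal in the question comes from async code, here is a sketch of how the pieces might fit into an asyncio application. It assumes Python 3.9+ for asyncio.to_thread. fake_generate and the queue are hypothetical stand-ins for model.generate and TextIteratorStreamer (which is likewise a blocking iterator), and max_chunks is an illustrative cancellation trigger. The blocking calls (fetching the next chunk, joining the thread) run via asyncio.to_thread so they don't stall the event loop.

```python
import asyncio
import queue
import threading
import time

_END = object()  # sentinel marking the end of the stream


def fake_generate(out_q: queue.Queue, cancel_event: threading.Event) -> None:
    """Hypothetical stand-in for model.generate + TextIteratorStreamer:
    pushes one text chunk per step and stops once cancel_event is set."""
    try:
        for step in range(1000):
            time.sleep(0.01)  # pretend this is one forward pass
            out_q.put(f"tok{step} ")
            if cancel_event.is_set():  # what the stopping criterion checks
                break
    finally:
        out_q.put(_END)  # end-of-stream signal, like the streamer's when generate returns


async def stream(max_chunks: int) -> list:
    cancel_event = threading.Event()  # thread-safe, unlike asyncio.Event
    chunks: queue.Queue = queue.Queue()
    thread = threading.Thread(target=fake_generate, args=(chunks, cancel_event))
    thread.start()

    received = []
    while True:
        # Queue.get blocks, so run it in a worker thread to keep the loop responsive
        chunk = await asyncio.to_thread(chunks.get)
        if chunk is _END:
            break
        received.append(chunk)
        if len(received) >= max_chunks:  # e.g. the client asked to cancel
            cancel_event.set()  # generation stops after its current step

    # Wait for the generation thread to exit without blocking the event loop
    await asyncio.to_thread(thread.join)
    return received


result = asyncio.run(stream(max_chunks=3))
print(len(result) < 1000)
```

The consumer keeps draining until the end sentinel arrives, so it may receive one extra chunk produced before the worker noticed the event.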