Tags: pytorch, artificial-intelligence, huggingface-transformers, sampling, logits

Top-p sampling not working. CUDA error: device-side assert triggered


I was trying to re-implement the model.generate() function of Hugging Face transformers models so that I could implement logit bias, which the standard function does not allow. But before I could get to that, I ran into a lot of problems with my top-p sampling.

Here's the code snippet:

generation_args = {
    "max_new_tokens": 500,
    "temperature": 0.4,  # Adjust temperature if needed for more or less randomness
    "do_sample": True,  # Enable sampling
    "top_p": 0.5,  # Set the cumulative probability for nucleus sampling
    "top_k": None,  # Optionally, you can set top_k if you want to use it alongside or instead of top_p
}


def top_p_filtering(logits, top_p):
    """Filter the logits using top-p (nucleus) sampling."""
    # Sort logits in descending order and get the sorted indices
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)

    # Compute the cumulative probabilities of the sorted logits
    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)

    # Create a mask for the tokens to keep
    sorted_indices_to_keep = cumulative_probs <= top_p

    # Ensure that at least one token is kept (the first token, which has the highest logit)
    sorted_indices_to_keep[..., 0] = True

    # Filter out the tokens to remove by setting their logits to negative infinity
    logits[sorted_indices[~sorted_indices_to_keep]] = float('-inf')

    return logits


def custom_generate(input_ids, streamer, max_new_tokens, temperature, top_p):
    past_key_values = None
    attention_mask = torch.ones(input_ids.shape, device=input_ids.device)

    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                use_cache=True
            )

        logits = outputs.logits[:, -1, :]  # Get logits of the last token

        # Apply temperature to logits
        if temperature != 1.0:
            logits = logits / temperature

        # Apply top-p sampling
        if top_p is not None and top_p < 1.0:
            logits = top_p_filtering(logits, top_p)
        print("1")
        next_token_probs = torch.nn.functional.softmax(logits, dim=-1)
        print("2")
        # Check if next_token_probs contains valid probabilities


        next_token_id = torch.multinomial(next_token_probs,
                                          num_samples=1)  
        print("3")
        streamer.put(next_token_id)  # Pass the tensor directly to the streamer

        input_ids = next_token_id  # Set the next input to the last generated token
        attention_mask = torch.cat(
            [attention_mask, torch.ones((attention_mask.shape[0], 1), device=attention_mask.device)], dim=1)

        past_key_values = outputs.past_key_values

        if next_token_id.item() == tokenizer.eos_token_id:  
            break

with torch.no_grad():
    custom_generate(input_ids, streamer, generation_args["max_new_tokens"], generation_args["temperature"], generation_args["top_p"])

The error that I face:

../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [10,0,0], thread: [63,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
Exception in thread Thread-18 (generate):
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/mnt/c/Users/User/Documents/EmpatheticChatBot/Inference-Server.py", line 130, in generate
    custom_generate(input_ids, streamer, generation_args["max_new_tokens"], generation_args["temperature"], generation_args["top_p"])
  File "/mnt/c/Users/User/Documents/EmpatheticChatBot/Inference-Server.py", line 108, in custom_generate
    next_token_id = torch.multinomial(next_token_probs,
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

The problem only appeared after I added top-p sampling.

I expected my sampling to work, as I have looked through my code maybe 30 times already. ChatGPT says this code is perfect and that my error is really hard to debug. My hypothesis is that values are being filtered incorrectly or set to "bad" values.


Solution

  • The problem is the indexing you're doing at this line:

    logits[sorted_indices[~sorted_indices_to_keep]] = float('-inf')

    For reasons I'll explain below, this causes an out-of-bounds index, which matches the "index out of bounds" assertion from IndexKernel.cu at the top of your log. Out-of-bounds indexing is a common cause of "CUDA error: device-side assert triggered" errors.
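A general tip for this class of error: CUDA kernels run asynchronously, so the traceback points at torch.multinomial even though the failing write happened earlier, inside top_p_filtering. One way to get a readable error is to replay the suspicious step on CPU, where PyTorch checks indices eagerly and raises an IndexError. The sketch below assumes a batch of one sequence and a made-up vocabulary of 10 tokens; replay_on_cpu and buggy_filter are hypothetical helpers, not PyTorch APIs. Alternatively, running the script with the environment variable CUDA_LAUNCH_BLOCKING=1, as the error message suggests, makes the assert surface at the call that actually failed.

```python
import torch

def buggy_filter(logits, sorted_indices, keep):
    # The line from the question: a flat 1D index tensor applied to a 2D tensor
    logits[sorted_indices[~keep]] = float("-inf")
    return logits

def replay_on_cpu(fn, *tensors):
    """Hypothetical debugging helper: rerun fn on CPU copies and return any IndexError."""
    cpu_args = [t.detach().cpu().clone() for t in tensors]
    try:
        fn(*cpu_args)
    except IndexError as err:
        return err
    return None

# Shapes like the question's loop: batch size 1, small stand-in vocabulary of 10 tokens
logits = torch.randn(1, 10)
_, sorted_indices = torch.sort(logits, descending=True)
keep = torch.zeros_like(sorted_indices, dtype=torch.bool)
keep[..., 0] = True  # keep only the top token, drop the other 9

err = replay_on_cpu(buggy_filter, logits, sorted_indices, keep)
print(repr(err))  # an IndexError naming an index that is out of bounds for dimension 0
```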

    Consider the following:

    import torch
    
    torch.manual_seed(42)
    
    top_p = 0.2
    
    logits = torch.randn(8, 128) # random logits
    
    # sort logits 
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    
    # calculate cumulative probs
    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    
    # apply top p threshold to cumulative probs
    sorted_indices_to_keep = cumulative_probs <= top_p
    
    # ensure at least one index is kept
    sorted_indices_to_keep[..., 0] = True
    
    # this is the problem: logits[sorted_indices[~sorted_indices_to_keep]] = float('-inf')
    print(logits.shape, sorted_indices[~sorted_indices_to_keep].shape)
    > torch.Size([8, 128]) torch.Size([989])
    

    When you index sorted_indices[~sorted_indices_to_keep], both inputs are of shape (8, 128), but the output is of shape (989,) (or similar depending on the random seed for the dummy logits).

    This happens because boolean-mask indexing, with a mask that has the same shape as the tensor, returns a flat 1D tensor of the selected elements in row-major order. In general each row of sorted_indices_to_keep has a different number of True values, so there is no way to return a clean 2D tensor where every row is the same size. PyTorch therefore always returns the unrolled vector of selected elements, even when the counts per row happen to match.
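A tiny example with a hand-picked mask, where both rows have exactly two True values, shows that the result is still flat (the tensor and mask values here are arbitrary illustrations):

```python
import torch

x = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
mask = torch.tensor([[True, False, True],
                     [False, True, True]])  # two True values in every row

selected = x[mask]
print(selected)        # elements picked in row-major order: 0, 2, 4, 5
print(selected.shape)  # a single dimension of length 4, not (2, 2)
```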

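    This means that when you compute logits[sorted_indices[~sorted_indices_to_keep]], you are using a long 1D tensor of vocabulary ids to index the first (batch) dimension of a small 2D tensor. In your loop the batch size is 1 (next_token_id.item() only works on a single element), so any id other than 0 is out of bounds. If you run this on CPU, you get an error like IndexError: index 20 is out of bounds for dimension 0 with size 8. On GPU, the same bounds check fails inside the indexing kernel, which is the CUDA device-side assert you are seeing.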
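Even when every index happens to be within range, the line is still wrong: a 1D integer index tensor selects whole rows along dimension 0 rather than individual vocabulary entries. In this sketch the shape (8, 128) mirrors the example above, and the index values 2 and 5 are arbitrary. The selected rows become entirely -inf, and softmax turns an all -inf row into NaN, which torch.multinomial refuses to sample from.

```python
import torch

logits = torch.zeros(8, 128)

# A 1D LongTensor index selects along dim 0 (rows), not along the vocab dim
flat_index = torch.tensor([2, 5])
logits[flat_index] = float("-inf")

print(torch.isinf(logits).all(dim=-1))  # True only for rows 2 and 5
probs = torch.softmax(logits, dim=-1)
print(torch.isnan(probs).any(dim=-1))   # the all -inf rows are now NaN
```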

    To fix this, build the keep-mask in sorted order and use the scatter operation to map it back to the original vocabulary order, row by row. Something like this:

    import torch
    
    def top_p_filtering(logits, top_p, shift_indices=True, debug=False):
        """Filter the logits using top-p (nucleus) sampling."""
        # Sort logits in descending order and get the sorted indices
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    
        # Compute the cumulative probabilities of the sorted logits
        cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    
        # Create a mask for the tokens to keep
        sorted_indices_to_keep = cumulative_probs <= top_p
        
        # Optional: shift the mask one position to the right. This also keeps the
        # first token whose cumulative probability crosses top_p, so the kept tokens
        # cover at least top_p of the probability mass. Skip this step to keep only
        # tokens whose cumulative probability is at or below top_p
        if shift_indices:
            sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
    
        # Ensure that at least one token is kept (the first token, which has the highest logit)
        sorted_indices_to_keep[..., 0] = True
        
        # Use scatter to create top_p mask
        mask = sorted_indices_to_keep.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_keep)
        
        # Optional debug check to make sure top_p is being honored
        # Note we need to compute probs before masking because applying softmax 
        # after masking will result in a distribution that sums to 1
        if debug:
            probs = torch.nn.functional.softmax(logits, dim=-1)
            probs[~mask] = 0
            print(probs.sum(-1))
        
        # Use mask to set logit vals to -inf
        logits[~mask] = float('-inf')
    
        return logits
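To see what the scatter call and the shift_indices option do, here is a small standalone sketch on one hand-picked distribution. The top_p_mask helper is hypothetical: it mirrors the masking logic of top_p_filtering but returns the mask instead of modifying logits. scatter(dim=1, index=sorted_indices, src=...) writes src[i][j] to position sorted_indices[i][j], which undoes the sort because each row of sorted_indices is a permutation of the vocabulary ids.

```python
import torch

def top_p_mask(logits, top_p, shift_indices=True):
    """Hypothetical helper: boolean mask (same shape as logits) of tokens kept by top-p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    keep_sorted = cumulative_probs <= top_p
    if shift_indices:
        # Also keep the first token whose cumulative probability crosses top_p
        keep_sorted[..., 1:] = keep_sorted[..., :-1].clone()
    keep_sorted[..., 0] = True
    # out[i][sorted_indices[i][j]] = keep_sorted[i][j]: back to vocabulary order
    return keep_sorted.scatter(dim=1, index=sorted_indices, src=keep_sorted)

# Token 0 has prob 0.2, token 1 has 0.5, token 2 has 0.3 (log-probs used as logits)
logits = torch.log(torch.tensor([[0.2, 0.5, 0.3]]))

# Sorted cumulative probs are 0.5, 0.8, 1.0 for tokens 1, 2, 0
print(top_p_mask(logits, 0.6, shift_indices=False))  # keeps only token 1
print(top_p_mask(logits, 0.6, shift_indices=True))   # keeps tokens 1 and 2
```

With the shift, the kept tokens carry 0.8 of the probability mass (at least top_p); without it, they carry 0.5 (at most top_p, unless the top token alone already exceeds it).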